
Commit 942b4a8

hongpeng-guo authored and mohitjain2504 committed
[Tune][Fix] Remove the clear_checkpoint function during Trial restoration error handling. (ray-project#48532)
This PR removes the `clear_checkpoint` function, so that Tune doesn't try to "restart trials from scratch". `clear_checkpoint` solved for a legacy use case that doesn't apply anymore, and "restoration failures" are also now an edge case for function Trainables and Ray Train usage.

---------

Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: mohitjain2504 <[email protected]>
1 parent 36561dc · commit 942b4a8

2 files changed: +19 -11 lines changed

python/ray/tune/experiment/trial.py

Lines changed: 2 additions & 8 deletions
@@ -793,11 +793,11 @@ def get_error(self) -> Optional[TuneError]:
         return None
 
     def _handle_restore_error(self, exc: Exception):
+        # For Restoration errors, we only increment the restore failure count
+        # if the number of failures exceeds the restore retry limit.
         if self.temporary_state.num_restore_failures >= int(
             os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
         ):
-            # Restore was unsuccessful, try again without checkpoint.
-            self.clear_checkpoint()
             self.run_metadata.num_failures += 1
         else:
             self.temporary_state.num_restore_failures += 1
@@ -883,12 +883,6 @@ def should_checkpoint(self):
     def has_checkpoint(self) -> bool:
         return self.checkpoint is not None
 
-    def clear_checkpoint(self):
-        if self.latest_checkpoint_result:
-            self.latest_checkpoint_result.checkpoint = None
-        self.temporary_state.restoring_from = None
-        self.run_metadata.invalidate_cache()
-
     def on_checkpoint(self, checkpoint_result: _TrainingResult):
         """Hook for handling checkpoints taken by the Trainable.
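To make the resulting control flow concrete, here is a minimal, runnable sketch of the retry-budget logic that `_handle_restore_error` implements after this change. The `_State` holder and the free `handle_restore_error` function are purely illustrative stand-ins, not Ray Tune API; only the counter names and the `TUNE_RESTORE_RETRY_NUM` environment variable mirror the diff above.

```python
import os


class _State:
    """Illustrative stand-in for the trial's counters (not a Ray Tune class)."""

    def __init__(self):
        self.num_restore_failures = 0  # consecutive failed restore attempts
        self.num_failures = 0  # failures counted against FailureConfig.max_failures


def handle_restore_error(state: _State) -> None:
    # A failed restore becomes a "real" trial failure only once the retry
    # budget (TUNE_RESTORE_RETRY_NUM, default 0) is exhausted. The trial's
    # checkpoint is no longer cleared, so each retry restores from the latest
    # checkpoint instead of restarting the trial from scratch.
    if state.num_restore_failures >= int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)):
        state.num_failures += 1
    else:
        state.num_restore_failures += 1


if __name__ == "__main__":
    os.environ["TUNE_RESTORE_RETRY_NUM"] = "2"
    state = _State()
    for _ in range(4):  # simulate four consecutive restore errors
        handle_restore_error(state)
    print(state.num_restore_failures, state.num_failures)  # -> 2 2
```

With a retry budget of 2, the first two restore errors are absorbed as retries and only subsequent errors count toward `max_failures`.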

python/ray/tune/tests/test_tuner_restore.py

Lines changed: 17 additions & 3 deletions
@@ -537,7 +537,21 @@ def test_tuner_restore_latest_available_checkpoint(
 
 @pytest.mark.parametrize("retry_num", [0, 2])
 def test_restore_retry(ray_start_2_cpus, tmpdir, retry_num):
-    """Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`."""
+    """
+    Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`.
+
+    This unit test holds the following hyperparameters:
+    - `retry_num`: Maximum number of retry attempts for restoring a trial.
+      This value is assigned to the environment variable `TUNE_RESTORE_RETRY_NUM`.
+      If the restoration fails after retry_num attempts, the trial increments its
+      counter of total number of failures by 1.
+
+    - `retry_num_to_fail`: Number of restore attempts to fail. In this test,
+      retry_num_to_fail is set to 2, causing the first two restore attempts to fail.
+
+    - `max_failures`: Maximum allowable failures during training. Here, max_failures is
+      set to 2, meaning the training process will terminate after two total failures.
+    """
 
     class MockTrainable(Trainable):
         """A trainable that can generate one failure during training and
@@ -546,7 +560,7 @@ class MockTrainable(Trainable):
         def setup(self, config):
             self.idx = 0
             self.tag_file_path = config["tag_file_path"]
-            self.retry_num_to_fail = config.get("retry_num_to_fail", 2)
+            self.retry_num_to_fail = 2
            self._is_restored = False
 
        def step(self):
@@ -592,7 +606,7 @@ def load_checkpoint(self, checkpoint_dir):
             name="tryout_restore",
             stop={"training_iteration": 5},
             storage_path=str(tmpdir),
-            failure_config=FailureConfig(max_failures=1),
+            failure_config=FailureConfig(max_failures=2),
             checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
         ),
         param_space={"tag_file_path": tag_file},
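Outside of the test, the same knobs are set by the user. The sketch below shows roughly how a run that tolerates flaky restores could be configured, assuming Ray 2.x import paths (`ray.train` config classes, `ray.tune.Tuner`); `MyTrainable` is a hypothetical stand-in for the test's `MockTrainable` and is not part of this PR.

```python
import json
import os

from ray import tune
from ray.train import CheckpointConfig, FailureConfig, RunConfig


class MyTrainable(tune.Trainable):
    """Hypothetical trainable standing in for the test's MockTrainable."""

    def setup(self, config):
        self.idx = 0

    def step(self):
        self.idx += 1
        return {"score": self.idx}

    def save_checkpoint(self, checkpoint_dir):
        # Persist the minimal state needed to resume after a restore.
        with open(os.path.join(checkpoint_dir, "state.json"), "w") as f:
            json.dump({"idx": self.idx}, f)

    def load_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "state.json")) as f:
            self.idx = json.load(f)["idx"]


# Give each trial up to two in-place restore retries (from its latest
# checkpoint) before a restore error counts against max_failures.
os.environ["TUNE_RESTORE_RETRY_NUM"] = "2"

tuner = tune.Tuner(
    MyTrainable,
    run_config=RunConfig(
        name="tryout_restore",
        stop={"training_iteration": 5},
        failure_config=FailureConfig(max_failures=2),
        checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
    ),
)
results = tuner.fit()
```

Setting `TUNE_RESTORE_RETRY_NUM=2` gives each trial two in-place restore retries from its latest checkpoint before a restore error is charged against `FailureConfig(max_failures=2)`.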
