Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions smdebug/core/index_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def read_index_files(
self.logger.debug(f'Loaded Index Files: {",".join(index_files)}')
for index_file in index_files:
if self.index_file_cache.has_not_read(index_file):

step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
if (
range_steps is not None and step_in_range(range_steps, step)
Expand All @@ -319,9 +320,15 @@ def read_index_files(
object_requests.append(
ReadObjectRequest(format(f"s3://{self.bucket_name}/") + index_file)
)
self.index_file_cache.add(index_file, start_after_key)
self.logger.debug(f"Will read index_file: {index_file}")
self.index_file_cache.add(index_file, start_after_key)
else:
self.logger.debug(
f"index_file:{index_file} Indexcache contents:{self.index_file_cache.lookup_set}"
)

responses = S3Handler.get_objects(object_requests)
assert len(responses) == len(object_requests)
return responses, steps, start_after_key, workers

def list_index_files(self, start_after_key=None):
Expand Down Expand Up @@ -416,7 +423,11 @@ def read_index_files(
start_after_index = bisect_left(index_files, start_after_key)
else:
start_after_index = 0
self.logger.debug(f"Found index_files:{index_files}")
index_files = index_files[start_after_index:] # ignore files we have already read
self.logger.debug(
f"Curtailed Found index_files to :{index_files} start_after_index:{start_after_index} start_after_key:{start_after_key}"
)
for index_file in index_files:
if self.index_file_cache.has_not_read(index_file):
step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
Expand All @@ -428,9 +439,15 @@ def read_index_files(
self.logger.debug(
f"Sagemaker-Debugger: Read {os.path.getsize(index_file)} bytes from file {index_file}"
)
self.logger.debug(f"Will read index file:{index_file}")
with open(index_file) as f:
responses.append(f.read().encode())
self.index_file_cache.add(index_file, start_after_key)
self.index_file_cache.add(index_file, start_after_key)
else:
self.logger.debug(
f"IndexFile:{index_file} Indexcache contents:{self.index_file_cache.lookup_set}"
)

if len(index_files) > 0:
start_after_key = index_files[-1] # Last file that we have read
return responses, steps, start_after_key, workers
17 changes: 11 additions & 6 deletions smdebug/trials/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,14 @@ def _populate_workers_for_global_step(self, step, worker) -> None:
if step not in self.workers_for_global_step:
self.workers_for_global_step[step] = set()
self.workers_for_global_step[step].add(worker)
self.logger.debug(f"Populated workers for global step:{step} worker: {worker}")

if (
len(self.workers_for_global_step[step]) == self.num_workers
and step > self.last_complete_step
):
self.last_complete_step = step
self.logger.debug(f"Populating last completing step to: {step}")

def _populate_global_step_to_tensor_name_map(self, tensor: TensorLocation, step_num) -> None:
"""
Expand Down Expand Up @@ -514,13 +517,14 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
"""
all_steps = self.steps(mode=mode, show_incomplete_steps=True)
bisect_idx = bisect_left(all_steps, step)
g_step = self._global_step_currently(mode, step)

if bisect_idx < len(all_steps):
if all_steps[bisect_idx] > step:
if self.last_complete_step > step:
if self.last_complete_step > g_step:
return StepState.UNAVAILABLE
return StepState.NOT_YET_AVAILABLE
elif all_steps[bisect_idx] == step:
g_step = self.global_step(mode, step)
if len(self.workers_for_global_step[g_step]) == self.num_workers:
return StepState.AVAILABLE
elif self.loaded_all_steps is True:
Expand All @@ -531,9 +535,9 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
f"Step {step} of mode {mode} was marked complete because the job is complete"
)
return StepState.AVAILABLE
elif step <= self.last_complete_step:
elif g_step <= self.last_complete_step:
self.logger.info(
f"Step {step} of mode {mode} was written only by workers: {self.workers_for_global_step[step]}"
f"Step {step} of mode {mode} was written only by workers: {self.workers_for_global_step[g_step]}"
)
self.logger.info(
f"Step {step} of mode {mode} was marked complete because the last complete step is {self.last_complete_step}"
Expand All @@ -552,7 +556,7 @@ def _load_tensors(self):
def _update_last_index_token(self, new_index_token: str) -> None:
"""
This function updates the last_index_token in the following scenarios:
1. last_complete_step > last_index_token_step :
1. last_complete_step >= last_index_token_step :
this means that the token isn't pointing to the latest completed step
2. number of steps available ( complete or incomplete ) - (last_completed_step+1) > window_size_limit:
we maintain a window to stop querying for older steps that have not completed.
Expand All @@ -569,7 +573,7 @@ def _update_last_index_token(self, new_index_token: str) -> None:
)

# Case 1:
if self.last_complete_step > last_index_token_step:
if self.last_complete_step >= last_index_token_step:
prefix = IndexFileLocationUtils.get_prefix_from_index_file(new_index_token)
# sort lexicographically and select the last worker
last_worker = sorted(list(self.worker_set))[-1]
Expand All @@ -579,6 +583,7 @@ def _update_last_index_token(self, new_index_token: str) -> None:
self.last_index_token = IndexFileLocationUtils.get_index_key_for_step(
prefix, self.last_complete_step, last_worker_serialized
)
self.logger.debug(f"Updated last index token to:{self.last_index_token}")

# Case 2:
available_step = self._global_to_mode.keys()
Expand Down