Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion smdebug/core/hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Standard Library
import atexit
import os
import re as _re
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional, Set, Union
Expand Down Expand Up @@ -480,7 +481,7 @@ def export_collections(self):
return
num_workers = 1 # Override
self.collection_manager.set_num_workers(num_workers)
collection_file_name = f"{self.worker}_collections.json"
collection_file_name = f"{self.worker}_{os.getpid()}_collections.json"
self.collection_manager.export(self.out_dir, collection_file_name)

def _get_reduction_tensor_name(self, tensor_name, reduction_name, abs):
Expand Down
6 changes: 3 additions & 3 deletions smdebug/core/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_step_num_str(self):

def get_filename(self):
step_num_str = self.get_step_num_str()
event_filename = f"{step_num_str}_{self.worker_name}.tfevents"
event_filename = f"{step_num_str}_{self.worker_name}_{os.getpid()}.tfevents"
return event_filename

@classmethod
Expand All @@ -48,7 +48,7 @@ def match_regex(cls, s):
@classmethod
def load_filename(cls, s, print_error=True):
event_file_name = os.path.basename(s)
m = re.search("(.*)_(.*).tfevents$", event_file_name)
m = re.search("(.*)_(.*)_(.*).tfevents$", event_file_name)
if m:
step_num = int(m.group(1))
worker_name = m.group(2)
Expand Down Expand Up @@ -127,7 +127,7 @@ def next_index_prefix_for_step(step_num):
def _get_index_key(trial_prefix, step_num, worker_name):
index_prefix_for_step_str = IndexFileLocationUtils.get_index_prefix_for_step(step_num)
step_num_str = format(step_num, "012")
index_filename = format(f"{step_num_str}_{worker_name}.json")
index_filename = format(f"{step_num_str}_{worker_name}_{os.getpid()}.json")
index_key = os.path.join(trial_prefix, "index", index_prefix_for_step_str, index_filename)
return index_key

Expand Down
4 changes: 2 additions & 2 deletions smdebug/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_worker_name_from_collection_file(filename: str) -> str:
:param filename: str
:return: worker_name: str
"""
worker_name_regex = re.compile(".*/collections/.+/(.+)_collections.(json|ts)")
worker_name_regex = re.compile(".*/collections/.+/(.+)_(.+)_collections.(json|ts)")
worker_name = re.match(worker_name_regex, filename).group(1)
if worker_name[0] == "_":
worker_name = deserialize_tf_device(worker_name)
Expand All @@ -203,7 +203,7 @@ def parse_worker_name_from_file(filename: str) -> str:
:return: worker_name: str
"""
# worker_2 = /tmp/ts-logs/index/000000001/000000001230_worker_2.json
worker_name_regex = re.compile(".+\/\d+_(.+)\.(json|csv|tfevents)$")
worker_name_regex = re.compile(".+\/\d+_(.+)_(.+)\.(json|csv|tfevents)$")
worker_name_regex_match = re.match(worker_name_regex, filename)
if worker_name_regex_match is None:
raise IndexReaderException(f"Invalid File Found: {filename}")
Expand Down
4 changes: 2 additions & 2 deletions smdebug/tensorflow/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ def export_collections(self):
if len(self.device_map):
for device, serialized_device in self.device_map.items():
if self.save_all_workers is True or device == self.chief_worker:
collection_file_name = f"{serialized_device}_collections.json"
collection_file_name = f"{serialized_device}_{os.getpid()}_collections.json"
self.collection_manager.export(self.out_dir, collection_file_name)
return

# below is used in these cases
# if mirrored and device_map is empty (CPU training)
# if horovod/param server and worker == chief worker
collection_file_name = f"{self.worker}_collections.json"
collection_file_name = f"{self.worker}_{os.getpid()}_collections.json"
self.collection_manager.export(self.out_dir, collection_file_name)

def _get_num_workers(self):
Expand Down