
Commit 631fd75

Enable file rotation for Horovod trace file (#33)
* Hvd file reader and rotation of files

Co-authored-by: Anirudh <[email protected]>
1 parent 394ecc8 commit 631fd75

File tree: 18 files changed, +836 −27 lines
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
"""
This script is a simple MNIST training script which uses Horovod and the TensorFlow 2.x Keras interface.
It is designed to be used with SageMaker Debugger in an official SageMaker Framework container (i.e. an AWS Deep Learning Container).
You will notice that this script looks exactly like a normal TensorFlow training script.
The hook needed by SageMaker Debugger to save tensors during training will be automatically added in those environments.
The hook will load its configuration from the JSON configuration that SageMaker places in the training container,
based on the configuration provided through the SageMaker Python SDK when creating a job.
For more information, please refer to https://github.com/awslabs/sagemaker-debugger/blob/master/docs/sagemaker.md

This script has been adapted from an example in the Horovod repository: https://github.com/uber/horovod
"""

# Standard Library
import argparse
from datetime import datetime

# Third Party
import horovod.tensorflow.keras as hvd
import tensorflow.compat.v2 as tf


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def get_data(batch_size):
    (mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(
        path="mnist-%d.npz" % hvd.rank()
    )

    dataset = tf.data.Dataset.from_tensor_slices(
        (
            tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
            tf.cast(mnist_labels, tf.int64),
        )
    )
    dataset = dataset.repeat().shuffle(10000).batch(batch_size)
    return dataset


def get_model():
    mnist_model = tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(32, [3, 3], activation="relu", input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(64, [3, 3], activation="relu"),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )
    return mnist_model


def train(model, dataset, epoch):
    # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
    # uses hvd.DistributedOptimizer() to compute gradients.
    # `opt` is the Horovod-wrapped optimizer created at module level below.
    model.compile(
        loss=tf.losses.SparseCategoricalCrossentropy(),
        optimizer=opt,
        metrics=["accuracy"],
        experimental_run_tf_function=False,
    )

    # Create a TensorBoard callback
    logs = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")

    tboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=logs, histogram_freq=1, profile_batch=2
    )
    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),
        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),
        # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first three epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
        tboard_callback,
    ]
    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
    if hvd.rank() == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint("checkpoint-{epoch}.h5"))

    # Horovod: write logs on worker 0.
    verbose = 1 if hvd.rank() == 0 else 0

    # Train the model.
    # Horovod: adjust number of steps based on number of GPUs.
    model.fit(
        dataset,
        steps_per_epoch=500 // hvd.size(),
        callbacks=callbacks,
        epochs=epoch,
        verbose=verbose,
    )


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description="Tensorflow2 MNIST Example")
    parser.add_argument("--use_only_cpu", type=str2bool, default=False)
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for")
    parser.add_argument("--model_dir", type=str, default="/tmp/mnist_model")

    args = parser.parse_args()

    # constants
    lr = 0.001
    batch_size = 64

    # Horovod: initialize library.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")

    # Horovod: adjust learning rate based on number of GPUs.
    opt = tf.optimizers.Adam(lr * hvd.size())

    # Horovod: add Horovod DistributedOptimizer.
    opt = hvd.DistributedOptimizer(opt)

    dataset = get_data(batch_size)
    mnist_model = get_model()

    train(model=mnist_model, dataset=dataset, epoch=args.num_epochs)
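For local experimentation, a script like this is typically launched with Horovod's launcher, one process per GPU, for example `horovodrun -np 4 python <script>.py --num_epochs 2` (the new file's path is not shown in this excerpt). Inside a SageMaker job, the cluster layout and the debugger/profiler configuration come from the SageMaker Python SDK instead, as the module docstring notes.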

smdebug/core/hook.py

Lines changed: 5 additions & 0 deletions
@@ -228,6 +228,7 @@ def __init__(
         self.timeline_writer = TimelineFileWriter(
             profiler_config_parser=self.profiler_config_parser
         )
+        self.hvd_reader = None

         if is_sagemaker_job() and SageMakerFileMetricsWriter is not None:
             self.metrics_writer = SageMakerFileMetricsWriter()
@@ -515,6 +516,10 @@ def _cleanup(self):

         self.timeline_writer.close()

+        # close the Horovod file reader thread if it has been enabled
+        if self.hvd_reader and self.hvd_reader.enabled:
+            self.hvd_reader.close()
+
         training_has_ended(self.out_dir)
         if self.first_process is True:
             remove_claim_file(self.out_dir)

smdebug/core/locations.py

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ class TraceFileLocation:
     # $ENV_BASE_FOLDER/framework/pevents/$START_TIME_YYYYMMDDHR/
     # $FILEEVENTENDTIMEUTCINEPOCH_{$ENV_NODE_ID}_model_timeline.json
     @staticmethod
-    def get_file_location(timestamp, base_dir):
+    def get_file_location(timestamp, base_dir, suffix=PYTHONTIMELINE_SUFFIX):
         env_base_location = base_dir
         date_hour = time.strftime(
             TRACE_DIRECTORY_FORMAT, time.gmtime(timestamp / CONVERT_TO_MICROSECS)
@@ -137,7 +137,7 @@ def get_file_location(timestamp, base_dir):
             + "_"
             + worker_id
             + "_"
-            + PYTHONTIMELINE_SUFFIX,
+            + suffix,
         )
         return file_path
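With the suffix parameterized, the path convention in the comment above applies to every timeline type, not just the Python timeline. A rough sketch of how a rotated Horovod trace file path is composed, assuming an illustrative base_dir and node id, that TRACE_DIRECTORY_FORMAT is the year-month-day-hour format the comment suggests, and that the Horovod suffix constant's value is "horovod_timeline.json" (none of these values are shown in this diff):

import time

# all values below are illustrative assumptions, not taken from this commit
timestamp_us = 1600000000 * 1000000  # event end time, microseconds since epoch
date_hour = time.strftime("%Y%m%d%H", time.gmtime(timestamp_us / 1000000))
file_path = "/opt/ml/output/profiler/framework/pevents/%s/%d_%s_%s" % (
    date_hour,
    timestamp_us,
    "node-0",
    "horovod_timeline.json",
)
# -> .../framework/pevents/2020091312/1600000000000000_node-0_horovod_timeline.json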

smdebug/core/tfevent/timeline_file_writer.py

Lines changed: 71 additions & 15 deletions
@@ -33,7 +33,7 @@
 from smdebug.core.locations import TraceFileLocation
 from smdebug.core.logger import get_logger
 from smdebug.core.utils import ensure_dir, get_node_id
-from smdebug.profiler.profiler_constants import CONVERT_TO_MICROSECS
+from smdebug.profiler.profiler_constants import CONVERT_TO_MICROSECS, PYTHONTIMELINE_SUFFIX

 logger = get_logger()

@@ -88,7 +88,18 @@ def to_json(self):
             "ph": self.phase,
             "ts": self.rel_ts_micros,
         }
-        if self.phase == "X":
+
+        # handle Instant event
+        if self.phase == "i":
+            if self.args:
+                # Instant events have a field unique to them called scope.
+                # scope can be "g" - global, "p" - process, "t" - thread.
+                # parsing this value that is being passed as args.
+                s = self.args["s"] if "s" in self.args else "t"
+                json_dict.update({"s": s})
+                if "s" in self.args:
+                    self.args.pop("s")
+        elif self.phase == "X":
             json_dict.update({"dur": self.duration})

         if self.args:
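In the Chrome trace event format that these records follow, an instant event ("ph": "i") carries a scope field "s" instead of a duration, while a complete event ("ph": "X") carries "dur". A sketch of the two record shapes the branch above produces, with illustrative field values:

# illustrative record shapes only; names and timestamps are made up
instant_event = {"name": "hvd_marker", "pid": 1, "ph": "i", "ts": 1234, "s": "t"}
complete_event = {"name": "forward", "pid": 1, "ph": "X", "ts": 1234, "dur": 56}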
@@ -105,7 +116,7 @@ class TimelineFileWriter:
     and asynchronously writes TimelineRecord to the file.
     """

-    def __init__(self, profiler_config_parser, max_queue=100):
+    def __init__(self, profiler_config_parser, max_queue=100, suffix=PYTHONTIMELINE_SUFFIX):
         """Creates a `TimelineFileWriter` and a trace event file to write to.
         This event file will contain TimelineRecord as JSON strings, which are written to
         disk via the write_record method.
@@ -120,14 +131,34 @@ def __init__(self, profiler_config_parser, max_queue=100):
         self._worker = _TimelineLoggerThread(
             queue=self._event_queue,
             sentinel_event=self._sentinel_event,
-            base_start_time=self.start_time_since_epoch_in_micros,
+            base_start_time_in_us=self.start_time_since_epoch_in_micros,
             profiler_config_parser=self._profiler_config_parser,
+            suffix=suffix,
         )
         self._worker.start()

+    def _update_base_start_time(self, base_start_time_in_us):
+        """
+        Some trace files such as the Horovod trace file may start before this timeline
+        writer is initialized. In such a case, use this function to update the start time
+        since epoch in micros.
+        """
+        if base_start_time_in_us != self.start_time_since_epoch_in_micros:
+            self.start_time_since_epoch_in_micros = base_start_time_in_us
+            self._worker._update_base_start_time(base_start_time_in_us)
+
     def write_trace_events(
         self, timestamp, training_phase="", op_name="", phase="X", duration=0, **kwargs
     ):
+        """
+        Creates a TimelineRecord from the details passed as parameters, and enqueues an event for write.
+        :param timestamp: start time for the event (in seconds)
+        :param training_phase: strings like data_iteration, forward, backward, operations, etc.
+        :param op_name: more details about the phase, e.g. whether dataset or iterator
+        :param phase: phase of the trace event; default is 'X'
+        :param duration: any manually computed duration (in seconds)
+        :param kwargs: other params, such as process id and thread id
+        """
         if not self._worker._healthy or not self._profiler_config_parser.profiling_enabled:
             return
         duration_in_us = int(duration * CONVERT_TO_MICROSECS)  # convert to microseconds
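Taken together, these additions let a reader thread replay Horovod events, which may predate the writer itself, into a rotating trace file. A rough usage sketch, not code from this commit; the suffix value and the surrounding reader are assumptions:

# sketch only: replaying pre-existing Horovod events through a suffix-aware writer
writer = TimelineFileWriter(
    profiler_config_parser=parser,  # an existing ProfilerConfigParser instance
    suffix="horovod_timeline.json",  # assumed value of the Horovod suffix constant
)
# the earliest Horovod event may be older than the writer's own start time
writer._update_base_start_time(first_event_ts_us)
writer.write_trace_events(
    timestamp=event_start_s,  # start time in seconds since epoch
    training_phase="Horovod",  # illustrative phase name
    op_name="allreduce",  # illustrative op name
    phase="X",
    duration=0.002,  # duration in seconds
)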
@@ -167,7 +198,13 @@ class _TimelineLoggerThread(threading.Thread):
     https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/writer/event_file_writer.py#L133"""

     def __init__(
-        self, queue, sentinel_event, base_start_time, profiler_config_parser, verbose=False
+        self,
+        queue,
+        sentinel_event,
+        base_start_time_in_us,
+        profiler_config_parser,
+        verbose=False,
+        suffix=PYTHONTIMELINE_SUFFIX,
     ):
         """Creates a _TimelineLoggerThread."""
         threading.Thread.__init__(self)
@@ -180,14 +217,23 @@ def __init__(
         self.tensor_table = collections.defaultdict(int)
         self.continuous_fail_count = 0
         self.is_first = True
-        self.last_event_end_time_in_us = int(round(base_start_time))
+        self._update_base_start_time(base_start_time_in_us)
+        self._healthy = True
+        self._profiler_config_parser = profiler_config_parser
+        self.node_id = get_node_id()
+        self.suffix = suffix
+
+    def _update_base_start_time(self, base_start_time_in_us):
+        """
+        Some trace files such as the Horovod trace file may start before this timeline
+        writer is initialized. In such a case, use this function to update the start time
+        since epoch in micros.
+        """
+        self.last_event_end_time_in_us = int(round(base_start_time_in_us))
         self.last_file_close_time_in_us = self.last_event_end_time_in_us
         self.cur_hour = datetime.utcfromtimestamp(
             self.last_file_close_time_in_us / CONVERT_TO_MICROSECS
         ).hour
-        self._healthy = True
-        self._profiler_config_parser = profiler_config_parser
-        self.node_id = get_node_id()

     def run(self):
         while True:
@@ -315,13 +361,20 @@ def write_event(self, record):
             json_dict = {"name": "process_sort_index", "ph": "M", "pid": 0, "args": args}
             self._writer.write(json.dumps(json_dict) + ",\n")

-            args = {"name": record.training_phase}
-            json_dict = {"name": "process_name", "ph": "M", "pid": tensor_idx, "args": args}
-            self._writer.write(json.dumps(json_dict) + ",\n")
+            # Instant events don't have a training phase
+            if record.phase != "i":
+                args = {"name": record.training_phase}
+                json_dict = {"name": "process_name", "ph": "M", "pid": tensor_idx, "args": args}
+                self._writer.write(json.dumps(json_dict) + ",\n")

-            args = {"sort_index": tensor_idx}
-            json_dict = {"name": "process_sort_index", "ph": "M", "pid": tensor_idx, "args": args}
-            self._writer.write(json.dumps(json_dict) + ",\n")
+                args = {"sort_index": tensor_idx}
+                json_dict = {
+                    "name": "process_sort_index",
+                    "ph": "M",
+                    "pid": tensor_idx,
+                    "args": args,
+                }
+                self._writer.write(json.dumps(json_dict) + ",\n")

             self.is_first = False

@@ -366,6 +419,7 @@ def close(self):
             new_file_name = TraceFileLocation().get_file_location(
                 base_dir=self._profiler_config_parser.config.local_path,
                 timestamp=self.last_event_end_time_in_us,
+                suffix=self.suffix,
             )
             ensure_dir(new_file_name)
             os.rename(self.name(), new_file_name)
@@ -378,6 +432,8 @@ def name(self):
             self._profiler_config_parser.config.local_path
             + "/framework/"
             + self.node_id
+            + "_"
+            + self.suffix
             + SMDEBUG_TEMP_PATH_SUFFIX
         )
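The suffix now also appears in the writer's temporary file name, so concurrent writers for different timelines (Python, model, Horovod) no longer collide on the same temp file before close() renames it into the hourly pevents directory. Illustratively, with assumed values for the suffix and for SMDEBUG_TEMP_PATH_SUFFIX (neither is shown in this diff):

# temp name while the file is being written (illustrative values):
#   /tmp/profiler/framework/node-0_horovod_timeline.json.tmp
# final name after close() rotates the file:
#   /tmp/profiler/framework/pevents/2020091312/1600000000000000_node-0_horovod_timeline.json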

smdebug/profiler/algorithm_metrics_reader.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from smdebug.profiler.profiler_constants import (
     DEFAULT_PREFIX,
     ENV_TIME_BUFFER,
-    HOROVODTIMELINE_PREFIX,
+    HOROVODTIMELINE_SUFFIX,
     MODELTIMELINE_SUFFIX,
     PYTHONTIMELINE_SUFFIX,
     TENSORBOARDTIMELINE_SUFFIX,
@@ -122,7 +122,7 @@ def _get_event_parser(self, filename):
             return self._SMEventsParser
         if TENSORBOARDTIMELINE_SUFFIX in filename:
             return self._TBEventsParser
-        if HOROVODTIMELINE_PREFIX in filename:
+        if HOROVODTIMELINE_SUFFIX in filename:
             return self._HorovordEventsParser

     def _get_timestamp_from_filename(self, event_file):
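Parser dispatch is a substring match on the file name, which is why the rotated files carry their timeline suffix in the name. A quick sketch with an illustrative file name and an assumed suffix value (the real constant lives in smdebug.profiler.profiler_constants):

HOROVODTIMELINE_SUFFIX = "horovod_timeline.json"  # assumed value
filename = "1600000000000000_node-0_horovod_timeline.json"
assert HOROVODTIMELINE_SUFFIX in filename  # this file is routed to the Horovod events parser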
