Skip to content

Commit fd9155c

Browse files
committed
Add loss inputs for PT functional loss
1 parent 1f2b582 commit fd9155c

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

smdebug/pytorch/hook.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,15 @@ def forward_pre_hook(self, module, inputs):
137137
self.export_collections()
138138
self.exported_collections = True
139139

140-
def record_tensor_value(self, tensor_name: str, tensor_value: torch.Tensor) -> None:
140+
def record_tensor_value(self, tensor_name: str, tensor_value: torch.Tensor, inputs=None) -> None:
141141
"""Used for registering functional directly, such as F.mse_loss()."""
142142
assert isinstance(
143143
tensor_value, torch.Tensor
144144
), f"tensor_value={tensor_value} must be torch.Tensor"
145145

146+
if inputs:
147+
self._write_inputs(tensor_name, inputs)
148+
146149
self._write_outputs(tensor_name, tensor_value)
147150

148151
# This hook is invoked by trainer after running the forward pass.

tests/pytorch/test_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def create_net_and_train(out_dir, n_steps, use_loss_module=False, use_loss_funct
5858
loss = criterion(outputs, labels)
5959
if use_loss_functional:
6060
loss = F.cross_entropy(outputs, labels)
61-
hook.record_tensor_value("nll_loss", tensor_value=loss)
61+
hook.record_tensor_value("nll_loss", tensor_value=loss, inputs=[outputs, labels])
6262
loss.backward()
6363
optimizer.step()
6464

@@ -78,8 +78,8 @@ def test_register_loss_functional(out_dir):
7878
loss_tensor = trial.tensor("nll_loss_output_0")
7979

8080
# Capture ['nll_loss_output_0']
81-
assert len(trial.tensor_names()) == 1
82-
assert len(loss_coll.tensor_names) == 1
81+
assert len(trial.tensor_names()) == 3
82+
assert len(loss_coll.tensor_names) == 3
8383

8484
# Loss should be logged for all the steps since passed `available_steps = range(n_steps)`
8585
assert len(trial.steps()) == n_steps

0 commit comments

Comments
 (0)