Skip to content

Commit bc3e39e

Browse files
author
Andrew Gu
committed
Update on "Refactored activation checkpointing"
[ghstack-poisoned]
2 parents: 87412e7 + 15b8221; commit: bc3e39e

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,9 +46,19 @@ def checkpoint_wrapper(module: torch.nn.Module, ac_config):
4646
raise ValueError(
4747
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
4848
)
49+
4950
if ac_config.mode == "full":
5051
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
51-
elif ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
52+
53+
assert ac_config.mode == "selective", f"{ac_config.mode}"
54+
use_op_sac = ac_config.selective_ac_option == "op"
55+
use_layer_sac = ac_config.selective_ac_option.isdigit()
56+
if not use_op_sac and not use_layer_sac:
57+
raise ValueError(
58+
f"Invalid selective AC option: {ac_config.selective_ac_option}. "
59+
f"Valid options: 'op' or a positive int representing layer frequency"
60+
)
61+
if use_op_sac:
5262
from torch.utils.checkpoint import (
5363
CheckpointPolicy,
5464
create_selective_checkpoint_contexts,
@@ -81,7 +91,7 @@ def selective_checkpointing_context_fn():
8191
context_fn=selective_checkpointing_context_fn,
8292
preserve_rng_state=False,
8393
)
84-
elif ac_config.mode == "selective" and ac_config.selective_ac_option.isdigit():
94+
elif use_layer_sac:
8595
# Checkpoint every `ac_freq` of the modules passed to this function
8696
ac_freq = int(ac_config.selective_ac_option)
8797
if ac_freq <= 0:

0 commit comments

Comments
 (0)