File tree Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -46,9 +46,19 @@ def checkpoint_wrapper(module: torch.nn.Module, ac_config):
4646 raise ValueError (
4747 f"Invalid AC mode: { ac_config .mode } . Valid modes: { valid_ac_modes } "
4848 )
49+
4950 if ac_config .mode == "full" :
5051 return ptd_checkpoint_wrapper (module , preserve_rng_state = False )
51- elif ac_config .mode == "selective" and ac_config .selective_ac_option == "op" :
52+
53+ assert ac_config .mode == "selective" , f"{ ac_config .mode } "
54+ use_op_sac = ac_config .selective_ac_option == "op"
55+ use_layer_sac = ac_config .selective_ac_option .isdigit ()
56+ if not use_op_sac and not use_layer_sac :
57+ raise ValueError (
58+ f"Invalid selective AC option: { ac_config .selective_ac_option } . "
59+ f"Valid options: 'op' or a positive int representing layer frequency"
60+ )
61+ if use_op_sac :
5262 from torch .utils .checkpoint import (
5363 CheckpointPolicy ,
5464 create_selective_checkpoint_contexts ,
@@ -81,7 +91,7 @@ def selective_checkpointing_context_fn():
8191 context_fn = selective_checkpointing_context_fn ,
8292 preserve_rng_state = False ,
8393 )
84- elif ac_config . mode == "selective" and ac_config . selective_ac_option . isdigit () :
94+ elif use_layer_sac :
8595 # Checkpoint every `ac_freq` of the modules passed to this function
8696 ac_freq = int (ac_config .selective_ac_option )
8797 if ac_freq <= 0 :
You can’t perform that action at this time.
0 commit comments