1 file changed: +0 -9 lines changed

@@ -30,19 +30,10 @@ class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
-        # Pre-detect BatchNorm presence for fast path
-        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
-            if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
-                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
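For context, the guard removed above existed because torch.utils.checkpoint discards intermediate activations on the forward pass and reruns the wrapped forward during backward to rebuild them, so any side effect in forward, such as a BatchNorm running-statistics update, fires twice per training step. Below is a minimal standalone sketch (illustrative only, not code from this PR) that makes the double update visible through BatchNorm's num_batches_tracked buffer:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Plain forward + backward: the batch is counted once.
bn_plain = nn.BatchNorm1d(4)  # nn.Module defaults to training mode
x = torch.randn(8, 4, requires_grad=True)
bn_plain(x).sum().backward()
print(bn_plain.num_batches_tracked)  # tensor(1)

# Checkpointed forward + backward: the recomputation during backward
# runs BatchNorm's forward a second time, so the same batch is counted
# twice and the running mean/var are likewise updated twice.
bn_ckpt = nn.BatchNorm1d(4)
y = torch.randn(8, 4, requires_grad=True)
checkpoint(bn_ckpt, y, use_reentrant=False).sum().backward()
print(bn_ckpt.num_batches_tracked)  # tensor(2)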