
Commit 69540ff

Fabio Ferreira committed
fix: revert batch norm changes
1 parent f673ca1 · commit 69540ff

File tree

1 file changed: +0 -9 lines changed


monai/networks/nets/unet.py

Lines changed: 0 additions & 9 deletions
@@ -30,19 +30,10 @@ class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
-        # Pre-detect BatchNorm presence for fast path
-        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
-            if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
-                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
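
For context on the warning text being removed: checkpointing a training-mode subblock that contains BatchNorm runs its forward twice (once on the normal pass, once when the backward pass recomputes activations), so running_mean/running_var receive two momentum updates per step. Below is a minimal standalone sketch of that effect; it is not part of this commit, and the bare nn.BatchNorm1d is purely illustrative:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    bn = nn.BatchNorm1d(4)   # illustrative stand-in for a BatchNorm-containing subblock
    bn.train()
    x = torch.randn(8, 4, requires_grad=True)

    before = bn.running_mean.clone()
    out = checkpoint(bn, x, use_reentrant=False)  # forward #1: updates running stats
    out.sum().backward()                          # recomputation = forward #2: updates them again
    # bn.running_mean now reflects two momentum updates for a single training
    # step, the "double-updating" the reverted warning described.

The reverted code pre-detected BatchNorm and fell back to an ordinary, uncheckpointed forward in that case; this commit restores unconditional checkpointing for all subblocks.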

0 commit comments
