diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index 114f6b0ab4..41a26b62eb 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -758,7 +758,7 @@ def fold_bn_weights_into_conv_node( # since the node refers to a mutating op. Here we still need to call DCE first # to get rid of the unused getitem nodes that consume the BN node. m.graph.eliminate_dead_code() - if len(bn_node.users) == 0: + if not bn_node._erased and len(bn_node.users) == 0: m.graph.erase_node(bn_node)