diff --git a/torchao/utils.py b/torchao/utils.py
index 2a5857460f..d75dfd22fc 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -525,6 +525,8 @@ def _(func, types, args, kwargs):
     )
 
 def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
+    if not (hasattr(self, "tensor_data_names") and hasattr(src, "tensor_data_names")):
+        return False
     _tensor_shape_match = all(
         getattr(self, t_name).shape == getattr(src, t_name).shape
         for t_name in self.tensor_data_names
@@ -564,11 +566,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
 def _(func, types, args, kwargs):
     self = args[0]
     src = args[1]
-    if _same_metadata(self, src):
+    has_self_meta = hasattr(self, "tensor_data_names")
+    has_src_meta = hasattr(src, "tensor_data_names")
+    if has_self_meta and has_src_meta and _same_metadata(self, src):
         self_tensors = self.__tensor_flatten__()[0]
         for tensor_name in self_tensors:
             getattr(self, tensor_name).copy_(getattr(src, tensor_name))
         return
+    if not (has_self_meta and has_src_meta):
+        with torch._C._DisableTorchDispatch():
+            return func(*args, **kwargs)
     raise ValueError(
         f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
     )
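
A minimal sketch (not part of the patch) of the fallback behavior the second hunk adds; the helper name copy_with_fallback and the plain-tensor example are illustrative assumptions. When either argument lacks tensor_data_names, the handler steps outside subclass dispatch and lets the default aten.copy_ run instead of raising.

    import torch

    def copy_with_fallback(dst, src):
        # Only the structured per-inner-tensor path needs matching metadata;
        # otherwise bypass __torch_dispatch__ and let the default copy_
        # kernel handle the mixed case, mirroring the patched handler.
        has_dst_meta = hasattr(dst, "tensor_data_names")
        has_src_meta = hasattr(src, "tensor_data_names")
        if not (has_dst_meta and has_src_meta):
            with torch._C._DisableTorchDispatch():
                return dst.copy_(src)
        # structured copy of inner tensors would go here (see the second hunk)
        raise ValueError("metadata mismatch")

    a, b = torch.zeros(4), torch.arange(4.0)
    copy_with_fallback(a, b)  # plain tensors take the fallback path
    print(a)                  # tensor([0., 1., 2., 3.])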