From b72c5faf706e669549d7580a86d513f5294e49a8 Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Tue, 25 Mar 2025 11:41:53 -0700
Subject: [PATCH] [AFQ] Optimize tensor_flatten for runtime

[ghstack-poisoned]
---
 torchao/dtypes/affine_quantized_tensor.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 24ac56fc7f..758032e4b0 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -165,14 +165,18 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         return dq
 
     def __tensor_flatten__(self):
-        return ["tensor_impl"], [
-            self.block_size,
-            self.shape,
-            self.quant_min,
-            self.quant_max,
-            self.zero_point_domain,
-            self.dtype,
-        ]
+        # This is used at runtime to unwrap AffineQuantizedTensor activations.
+        # AffineQuantizedTensor has __torch_function__ override:
+        # Each getattr will go through it, which is up to 10x slower than default attribute access.
+        with torch._C.DisableTorchFunctionSubclass():
+            return ["tensor_impl"], [
+                self.block_size,
+                self.shape,
+                self.quant_min,
+                self.quant_max,
+                self.zero_point_domain,
+                self.dtype,
+            ]
 
     @classmethod
     def __tensor_unflatten__(