diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 15166aca0d..e7806f07ad 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -411,7 +411,7 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight """ - layout: Layout = PlainLayout() + aq_layout: Layout = PlainLayout() @classmethod def from_float(cls, weight): @@ -442,7 +442,7 @@ def get_weight_block_size(x): target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 - _layout = cls.layout + _layout = cls.aq_layout block_size = get_weight_block_size(weight) weight = to_affine_quantized_intx( @@ -616,12 +616,13 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ group_size: int = 32 - layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8) + # can't override the `layout` attribute + aq_layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8) @classmethod def from_float(cls, weight): group_size = cls.group_size - _layout = cls.layout + _layout = cls.aq_layout if weight.shape[-1] % group_size != 0: return weight @@ -681,7 +682,7 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight( AQInt4G32WeightOnlyQuantizedLinearWeight ): group_size: int = 128 - layout: Layout = MarlinSparseLayout() + aq_layout: Layout = MarlinSparseLayout() class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):