diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index 3551214d7e..30c2a39c26 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -253,16 +253,24 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
     # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale)
     # per channel int8 weight only quantizated mm
-    w_vals_int8_t = weight_tensor.tensor_impl.int_data.t()
+    w_vals_int8 = weight_tensor.tensor_impl.int_data
     scale = weight_tensor.tensor_impl.scale
-    m = torch.mm(
-        input_tensor.reshape(-1, input_tensor.shape[-1]),
-        w_vals_int8_t.to(input_tensor.dtype),
-    )
-    y = m * scale.to(m.dtype)
+    try:
+        y = torch.ops.aten._weight_int8pack_mm(
+            input_tensor.reshape(-1, input_tensor.shape[-1]),
+            w_vals_int8,
+            scale.to(input_tensor.dtype),
+        )
+    except Exception:
+        w_vals_int8_t = w_vals_int8.t()
+        m = torch.mm(
+            input_tensor.reshape(-1, input_tensor.shape[-1]),
+            w_vals_int8_t.to(input_tensor.dtype),
+        )
+        y = m * scale.to(m.dtype)
     y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
     if bias is not None:
-        y += bias.to(m.dtype)
+        y += bias.to(input_tensor.dtype)
     return y
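
A minimal standalone sketch of the pattern the diff introduces: attempt the fused aten._weight_int8pack_mm kernel first, and fall back to a dequantize-then-matmul path when the op is unavailable or rejects the inputs on the current backend. The helper name int8_weight_only_linear and the toy shapes in the usage comment are illustrative assumptions, not part of torchao; the fused op itself ships with recent PyTorch builds.

import torch

def int8_weight_only_linear(x, w_int8, scale, bias=None):
    # x: (..., k) float activations; w_int8: (n, k) int8 weight; scale: (n,) per-channel scales.
    x2d = x.reshape(-1, x.shape[-1])
    try:
        # Fast path: fused int8 weight-only matmul (activation, (n, k) int8 weight,
        # per-channel scales cast to the activation dtype).
        y = torch.ops.aten._weight_int8pack_mm(x2d, w_int8, scale.to(x.dtype))
    except Exception:
        # Fallback: cast the transposed weight up to the activation dtype, run a
        # plain matmul, then apply the per-channel scale.
        y = torch.mm(x2d, w_int8.t().to(x.dtype)) * scale.to(x.dtype)
    y = y.reshape(*x.shape[:-1], y.shape[-1])
    if bias is not None:
        y = y + bias.to(y.dtype)
    return y

# Example usage (toy shapes, assumed for illustration):
# x = torch.randn(2, 8, 16)
# w = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
# s = torch.rand(32)
# out = int8_weight_only_linear(x, w, s)  # shape (2, 8, 32)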