From a260dc8dcdfb13cbd02d7285ad13bb8db0d26e05 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 13 Aug 2025 10:53:14 -0700 Subject: [PATCH 01/14] added marlin sparse to packing format, inital commit --- .../quantize_/common/packing_format.py | 5 + .../int4/int4_marlin_sparse_tensor.py | 533 ++++++++++++++++++ 2 files changed, 538 insertions(+) create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index 77ed2790c5..96a29d2990 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -30,3 +30,8 @@ class PackingFormat(str, Enum): preshuffled is referring to the preshuffled format used by fbgemm kernels """ PRESHUFFLED = "preshuffled" + + """ + marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization + """ + MARLIN_SPARSE = "marlin_sparse" diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py new file mode 100644 index 0000000000..ac6225ac58 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -0,0 +1,533 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, +) + +from torchao.utils import fill_defaults, TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor + + +__all__ = [ + "Int4MarlinSparseTensor", +] + +aten = torch.ops.aten + + +try: + from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4 +except: + int4_row_quantize_zp = None + pack_int4 = None + + +class Int4MarlinSparseTensor(TorchAOBaseTensor): + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size", "shape", "meta", "num_bits"] + + def __new__(cls, qdata, scale, shape): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, qdata, scale, zero_point, meta, block_size, shape, num_bits): + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.meta = meta + self.block_size = block_size + self.shape = shape + self.num_bits = num_bits + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + + @staticmethod + def pre_process(input: torch.Tensor) -> torch.Tensor: + """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
+ - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format + - 2º: tensor is injected with 2:4 sparsity + - 3º: transposes it again because the quantization process will compute the scales for dim=-1 + + Args: + input (torch.Tensor): the input tensor to preprocess + + Returns: + torch.Tensor: the preprocessed tensor + """ + from torchao.sparsity.marlin import inject_24 # avoid circular import + + input_t = input.t() + w_24, _ = inject_24(input_t, *input_t.shape) + return w_24.t() + + @classmethod + def from_plain( + cls, + qdata: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor, + ): + from torchao.sparsity.marlin import const, pack_to_marlin_24 + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w_24 = int_data.t() + # addressing the case when scale has dimension 1, happens when + # weight_shape[-1] == group_size == 128 + if scale.ndim == 1: + scale = scale.reshape(scale.shape[0], -1) + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: Tuple[int], # quantize functions needs it as tuple not list + ): + preprocessed_w = cls.pre_process(w) + # assert ( + # len(block_size) == w.ndim + # ), f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" + # if int4_row_quantize_zp is None: + # raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + + # assert ( + # all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1 + # ), "Only groupwise quant is supported right now" + + # group_size = block_size[-1] + # original_shape = w.shape + + assert ( + block_size == 128 or block_size == w.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {block_size}" + + quant_min = 0 + quant_max = 15 + target_dtype = torch.int4 + + scale, zero_point = choose_qparams_affine( + input=preprocessed_w, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int4, # ??? 
i think its int4 because we wanna convert to int4 idk but in the old version i think its int32 + quant_min=quant_min, + quant_max=quant_max, + eps=1e-6, + # leaving scale dtype and zero point dtype as default for now idk + ) + + wq = quantize_affine( + input=preprocessed_w, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + + scale = scale.to(w.dtype) + zero_point = zero_point.to(w.dtype) + + +implements = Int4MarlinSparseTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" + assert ( + weight_tensor.zero_point.is_contiguous() + ), "Expected zero_point to be contiguous" + + orig_act_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + + input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]) + res = torch.ops.fbgemm.bf16i4bf16_rowwise( + input_tensor, + weight_tensor.qdata, + weight_tensor.scale, + weight_tensor.zero_point, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + if bias is not None: + res = res + bias + return res + + +@implements(torch.bmm) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" + assert ( + weight_tensor.zero_point.is_contiguous() + ), "Expected zero_point to be contiguous" + + orig_act_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( + input_tensor, + weight_tensor.qdata, + weight_tensor.scale, + weight_tensor.zero_point, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + return res + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Only supports slicing for dim == 1 and dim == 2 + qdata has dimension: (N, K/2) + scale and zero_point has dimension: (K/groups, N) + + dim, start, end, step are args that's referring to the original tensor shape + which is (N, K), and we need to map that to the transformed weight shape of qdata, + scale and zero_point + + when dim == 0: we do a slice on qdata dim 0, and on dim 1 of scale and zero_point, + also adjust the start and end indexes based on the ratio between original shape and the shape + of qdata and scale/zero_point + + when dim == 1: we do a slice on qdata dim 1 and dim 0 of scale and zero_point and do the + same adjustment based on ratio + + Note that we need to call slice on the qdata, scale and zero_point directly because slice + is an operation that need to preserve aliasing, see `test_slice_preserves_aliasing` and + `test_slice_and_copy_similar_to_vllm` in `test_int4_tensor` for more details + """ + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + + assert ( + self.qdata.ndim == 2 + ), f"Expected packed weight to have dim 2, got {self.qdata.dim}" + N, K_by_2 = self.qdata.shape + sz_dim0, sz_dim1 = self.scale.shape + + data_len = self.shape[dim] + + if dim == 0: + 
pw_len = N + sz_len = sz_dim1 + else: + pw_len = K_by_2 + sz_len = sz_dim0 + + sz_dim = 1 - dim + if pw_len == 0 or sz_len == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + self.__class__( + self.qdata, + self.scale, + self.zero_point, + block_size=self.block_size, + shape=self.shape, + ), + ) + + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) + + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) + + qdata = aten.slice.Tensor(self.qdata, dim, start_pw, end_pw, step) + scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step) + zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step) + packed_shape0, packed_shape1 = qdata.shape + new_shape = (packed_shape0, packed_shape1 * 2) + new = self.__class__( + qdata, scale, zero_point, block_size=self.block_size, shape=new_shape + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.cat.default) +def _(func, types, args, kwargs): + """Concatenate multiple Int4 quantized tensors + + For Int4Tensor, we need to concatenate qdata, scale, and zero_point tensors. + The concatenation behavior depends on the dimension and block_size configuration. + + If the concatenation dimension is not the same as the packed dimension, then we can just concatenate the + qdata, scale and zero_point directly, note that scale and zero_point has reversed dimension order in 2D + If the concatention dimension is the same as block_size, we'll check that scales from all + tensors are equal and use the first scale + """ + tensors, dim = fill_defaults(args, 2, [[], 0]) + if not tensors: + raise ValueError("Cannot concatenate empty list of tensors") + + tensor_0 = tensors[0] + dim = dim % tensor_0.ndim + + # Validate that all tensors have compatible properties + for i in range(1, len(tensors)): + assert tensor_0.qdata.ndim == tensors[i].qdata.ndim + assert tensor_0.scale.ndim == tensors[i].scale.ndim + assert tensor_0.zero_point.ndim == tensors[i].zero_point.ndim + assert tensor_0.block_size == tensors[i].block_size + + qdatas = [t.qdata for t in tensors] + scales = [t.scale for t in tensors] + zero_points = [t.zero_point for t in tensors] + + # Concatenate the quantized data along the specified dimension + cat_qdata = aten.cat.default(qdatas, dim=dim) + + # if concatenation happens in the non-packed dimension, we need to concatenation + # scale and zero_point + if tensor_0.block_size[dim] == 1: + # For scale and zero_point, the concatenation dimension depends on the dimension + # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, K/group_size, N) for 3D + if cat_qdata.ndim == 2: # 2D case + sz_dim = ( + 1 - dim + ) # If concatenating dim 0 (N), use dim 1 for scale; if dim 1 (K), use dim 0 + else: # 3D case + assert cat_qdata.ndim == 3 + if dim in [1, 2]: + sz_dim = 3 - dim + else: + sz_dim = dim + + cat_scale = aten.cat.default(scales, dim=sz_dim) + cat_zero_point = aten.cat.default(zero_points, dim=sz_dim) + + else: + # if concatenation happens in the packed dimension, we just need to verify + # that all scale and zero_points match + for i in range(1, len(tensors)): + assert torch.equal(tensor_0.scale, tensors[i].scale) + assert torch.equal(tensor_0.zero_point, tensors[i].zero_point) + cat_scale = scales[0] + cat_zero_point = zero_points[0] + + # Calculate new shape based on the concatenated qdata shape + new_shape = list(cat_qdata.shape) + new_shape[-1] *= 2 + new_shape = 
list(new_shape) + + new = Int4Tensor( + cat_qdata, + cat_scale, + cat_zero_point, + tensor_0.block_size, + new_shape, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + self, dim0, dim1 = args + + # Transpose the quantized data + qdata = self.qdata.transpose(dim0, dim1).contiguous() + if self.scale.ndim == 3: + # since scale/zero_point dimension order is different + # (B, K/group_size, N), we'll need to remap the dim + remapped_dim0 = dim0 + if dim0 in [1, 2]: + remapped_dim0 = 3 - dim0 + + remapped_dim1 = dim1 + if dim1 in [1, 2]: + remapped_dim1 = 3 - dim1 + + scale = self.scale.transpose(remapped_dim0, remapped_dim1) + zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1) + else: + assert scale.ndim == 2, f"Only support ndim == 2 or 3, got: {scale.ndim}" + remapped_dim0 = 1 - dim0 + remapped_dim1 = 1 - dim1 + scale = self.scale.transpose(remapped_dim0, remapped_dim1) + zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1) + + # Update block_size by swapping the dimensions + block_size = self.block_size.copy() + block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0] + + # Update shape by swapping the dimensions + new_shape = list(self.shape) + new_shape[dim0], new_shape[dim1] = new_shape[dim1], new_shape[dim0] + + new = Int4Tensor( + qdata, + scale, + zero_point, + block_size, + new_shape, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, size = args + original_shape = self.shape + original_packing_dim = None + for i in range(len(original_shape)): + if original_shape[i] == (self.qdata.shape[i] * 2): + original_packing_dim = i + assert original_packing_dim is not None, "Didn't find a packing_dim" + + if len(original_shape) == 3 and len(size) == 2: + # only support combining the dim 0 and dim1 together + assert ( + original_shape[-1] == size[-1] + ), f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + # the dim that int4 packing happens + if original_packing_dim in [0, 1]: + packing_dim = 0 + else: + packing_dim = 1 + + block_size = self.block_size.copy() + block_size = [block_size[0] * block_size[1], block_size[2]] + + qdata_shape = size.copy() + qdata_shape[packing_dim] //= 2 + qdata = self.qdata.reshape(*qdata_shape) + sz_shape = [] + for i in range(len(size)): + sz_shape.append(size[i] // block_size[i]) + # scale and zero_point have reversed dimensions + sz_shape[0], sz_shape[1] = sz_shape[1], sz_shape[0] + + scale = self.scale.reshape(*sz_shape) + zero_point = self.zero_point.reshape(*sz_shape) + elif len(original_shape) == 2 and len(size) == 3: + # only support extending the dim 0 to 2, `t.unflatten(0, (num_experts, -1))` + assert ( + original_shape[-1] == size[-1] + ), f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + if original_packing_dim == 0: + packing_dim = 1 + else: + # original_packing_dim is 1 + packing_dim = 2 + + block_size = self.block_size.copy() + block_size = [1, block_size[0], block_size[1]] + + qdata_shape = size.copy() + qdata_shape[packing_dim] //= 2 + qdata = self.qdata.reshape(*qdata_shape) + + sz_shape = [] + for i in range(len(size)): + sz_shape.append(size[i] // block_size[i]) + + # scale and zero_point have reversed dimensions + sz_shape[1], sz_shape[2] = sz_shape[2], sz_shape[1] + + scale = 
self.scale.reshape(*sz_shape) + zero_point = self.zero_point.reshape(*sz_shape) + elif len(original_shape) == len(size): + assert all( + x == y or y == -1 for x, y in zip(original_shape, size) + ), f"Only support viewing with match dimensions or -1, got: {original_shape}, {size}" + packing_dim = original_packing_dim + block_size = self.block_size + else: + assert ( + len(original_shape) == 2 and len(size) == 3 + ), f"Only support reshaping from 2D to 3D or from 3D to 2D or between sam ranges, requested: reshaping from {original_shape} to {size}" + + shape = list(qdata.shape) + for i in range(len(shape)): + if i == packing_dim: + shape[i] *= 2 + + new = Int4Tensor( + qdata, + scale, + zero_point, + block_size, + shape, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.squeeze.dim) +def _(func, types, args, kwargs): + self, dim = args + + # Squeeze qdata + qdata = self.qdata.squeeze(dim=dim) + + # For scale and zero_point, we need to squeeze based on the tensor layout + # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, N, K/group_size) for 3D + if self.qdata.ndim == 2: # 2D case + # qdata is (N, K/2), scale/zero_point is (K/group_size, N) + # When squeezing qdata dim, we need to squeeze scale/zero_point in reverse order + sz_dim = 1 - dim + else: # 3D case + # qdata is (B, N, K/2), scale/zero_point is (B, N, K/group_size) + sz_dim = dim + + scale = self.scale.squeeze(dim=sz_dim) + zero_point = self.zero_point.squeeze(dim=sz_dim) + + # Update block_size by removing the squeezed dimension + new_block_size = list(self.block_size) + if len(qdata.shape) < len(new_block_size): + new_block_size.pop(dim) + + # Update shape by removing the squeezed dimension + new_shape = list(self.shape) + if len(qdata.shape) < len(new_shape): + assert new_shape[dim] == 1 + new_shape.pop(dim) + + new = Int4Tensor( + qdata, + scale, + zero_point, + new_block_size, + new_shape, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +Int4Tensor.__module__ = "torchao.quantization" + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with Int4Tensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([Int4Tensor]) From 505e21f4c52a119d25788d367d2d3539b5bda6b7 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 13 Aug 2025 10:57:03 -0700 Subject: [PATCH 02/14] deleting unnecessary functions --- .../int4/int4_marlin_sparse_tensor.py | 356 ------------------ 1 file changed, 356 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index ac6225ac58..0f5c7bed39 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -175,359 +175,3 @@ def _(func, types, args, kwargs): if bias is not None: res = res + bias return res - - -@implements(torch.bmm) -def _(func, types, args, kwargs): - input_tensor, weight_tensor = ( - args[0], - args[1], - ) - assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" - assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" - assert ( - weight_tensor.zero_point.is_contiguous() - ), "Expected zero_point to be contiguous" - - orig_act_size = input_tensor.size() - orig_out_features = weight_tensor.shape[-2] - res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( - input_tensor, - weight_tensor.qdata, - 
weight_tensor.scale, - weight_tensor.zero_point, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - return res - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - """Only supports slicing for dim == 1 and dim == 2 - qdata has dimension: (N, K/2) - scale and zero_point has dimension: (K/groups, N) - - dim, start, end, step are args that's referring to the original tensor shape - which is (N, K), and we need to map that to the transformed weight shape of qdata, - scale and zero_point - - when dim == 0: we do a slice on qdata dim 0, and on dim 1 of scale and zero_point, - also adjust the start and end indexes based on the ratio between original shape and the shape - of qdata and scale/zero_point - - when dim == 1: we do a slice on qdata dim 1 and dim 0 of scale and zero_point and do the - same adjustment based on ratio - - Note that we need to call slice on the qdata, scale and zero_point directly because slice - is an operation that need to preserve aliasing, see `test_slice_preserves_aliasing` and - `test_slice_and_copy_similar_to_vllm` in `test_int4_tensor` for more details - """ - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1 - assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" - if end >= self.shape[dim]: - end = self.shape[dim] - - assert ( - self.qdata.ndim == 2 - ), f"Expected packed weight to have dim 2, got {self.qdata.dim}" - N, K_by_2 = self.qdata.shape - sz_dim0, sz_dim1 = self.scale.shape - - data_len = self.shape[dim] - - if dim == 0: - pw_len = N - sz_len = sz_dim1 - else: - pw_len = K_by_2 - sz_len = sz_dim0 - - sz_dim = 1 - dim - if pw_len == 0 or sz_len == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - self.__class__( - self.qdata, - self.scale, - self.zero_point, - block_size=self.block_size, - shape=self.shape, - ), - ) - - pw_ratio = data_len / pw_len - start_pw = int(start / pw_ratio) - end_pw = int(end / pw_ratio) - - sz_ratio = data_len / sz_len - start_sz = int(start / sz_ratio) - end_sz = int(end / sz_ratio) - - qdata = aten.slice.Tensor(self.qdata, dim, start_pw, end_pw, step) - scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step) - zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step) - packed_shape0, packed_shape1 = qdata.shape - new_shape = (packed_shape0, packed_shape1 * 2) - new = self.__class__( - qdata, scale, zero_point, block_size=self.block_size, shape=new_shape - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -@implements(aten.cat.default) -def _(func, types, args, kwargs): - """Concatenate multiple Int4 quantized tensors - - For Int4Tensor, we need to concatenate qdata, scale, and zero_point tensors. - The concatenation behavior depends on the dimension and block_size configuration. 
- - If the concatenation dimension is not the same as the packed dimension, then we can just concatenate the - qdata, scale and zero_point directly, note that scale and zero_point has reversed dimension order in 2D - If the concatention dimension is the same as block_size, we'll check that scales from all - tensors are equal and use the first scale - """ - tensors, dim = fill_defaults(args, 2, [[], 0]) - if not tensors: - raise ValueError("Cannot concatenate empty list of tensors") - - tensor_0 = tensors[0] - dim = dim % tensor_0.ndim - - # Validate that all tensors have compatible properties - for i in range(1, len(tensors)): - assert tensor_0.qdata.ndim == tensors[i].qdata.ndim - assert tensor_0.scale.ndim == tensors[i].scale.ndim - assert tensor_0.zero_point.ndim == tensors[i].zero_point.ndim - assert tensor_0.block_size == tensors[i].block_size - - qdatas = [t.qdata for t in tensors] - scales = [t.scale for t in tensors] - zero_points = [t.zero_point for t in tensors] - - # Concatenate the quantized data along the specified dimension - cat_qdata = aten.cat.default(qdatas, dim=dim) - - # if concatenation happens in the non-packed dimension, we need to concatenation - # scale and zero_point - if tensor_0.block_size[dim] == 1: - # For scale and zero_point, the concatenation dimension depends on the dimension - # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, K/group_size, N) for 3D - if cat_qdata.ndim == 2: # 2D case - sz_dim = ( - 1 - dim - ) # If concatenating dim 0 (N), use dim 1 for scale; if dim 1 (K), use dim 0 - else: # 3D case - assert cat_qdata.ndim == 3 - if dim in [1, 2]: - sz_dim = 3 - dim - else: - sz_dim = dim - - cat_scale = aten.cat.default(scales, dim=sz_dim) - cat_zero_point = aten.cat.default(zero_points, dim=sz_dim) - - else: - # if concatenation happens in the packed dimension, we just need to verify - # that all scale and zero_points match - for i in range(1, len(tensors)): - assert torch.equal(tensor_0.scale, tensors[i].scale) - assert torch.equal(tensor_0.zero_point, tensors[i].zero_point) - cat_scale = scales[0] - cat_zero_point = zero_points[0] - - # Calculate new shape based on the concatenated qdata shape - new_shape = list(cat_qdata.shape) - new_shape[-1] *= 2 - new_shape = list(new_shape) - - new = Int4Tensor( - cat_qdata, - cat_scale, - cat_zero_point, - tensor_0.block_size, - new_shape, - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -@implements(aten.transpose.int) -def _(func, types, args, kwargs): - self, dim0, dim1 = args - - # Transpose the quantized data - qdata = self.qdata.transpose(dim0, dim1).contiguous() - if self.scale.ndim == 3: - # since scale/zero_point dimension order is different - # (B, K/group_size, N), we'll need to remap the dim - remapped_dim0 = dim0 - if dim0 in [1, 2]: - remapped_dim0 = 3 - dim0 - - remapped_dim1 = dim1 - if dim1 in [1, 2]: - remapped_dim1 = 3 - dim1 - - scale = self.scale.transpose(remapped_dim0, remapped_dim1) - zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1) - else: - assert scale.ndim == 2, f"Only support ndim == 2 or 3, got: {scale.ndim}" - remapped_dim0 = 1 - dim0 - remapped_dim1 = 1 - dim1 - scale = self.scale.transpose(remapped_dim0, remapped_dim1) - zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1) - - # Update block_size by swapping the dimensions - block_size = self.block_size.copy() - block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0] - - # Update shape by swapping the dimensions - new_shape 
= list(self.shape) - new_shape[dim0], new_shape[dim1] = new_shape[dim1], new_shape[dim0] - - new = Int4Tensor( - qdata, - scale, - zero_point, - block_size, - new_shape, - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -@implements(aten.view.default) -def _(func, types, args, kwargs): - self, size = args - original_shape = self.shape - original_packing_dim = None - for i in range(len(original_shape)): - if original_shape[i] == (self.qdata.shape[i] * 2): - original_packing_dim = i - assert original_packing_dim is not None, "Didn't find a packing_dim" - - if len(original_shape) == 3 and len(size) == 2: - # only support combining the dim 0 and dim1 together - assert ( - original_shape[-1] == size[-1] - ), f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" - # the dim that int4 packing happens - if original_packing_dim in [0, 1]: - packing_dim = 0 - else: - packing_dim = 1 - - block_size = self.block_size.copy() - block_size = [block_size[0] * block_size[1], block_size[2]] - - qdata_shape = size.copy() - qdata_shape[packing_dim] //= 2 - qdata = self.qdata.reshape(*qdata_shape) - sz_shape = [] - for i in range(len(size)): - sz_shape.append(size[i] // block_size[i]) - # scale and zero_point have reversed dimensions - sz_shape[0], sz_shape[1] = sz_shape[1], sz_shape[0] - - scale = self.scale.reshape(*sz_shape) - zero_point = self.zero_point.reshape(*sz_shape) - elif len(original_shape) == 2 and len(size) == 3: - # only support extending the dim 0 to 2, `t.unflatten(0, (num_experts, -1))` - assert ( - original_shape[-1] == size[-1] - ), f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" - if original_packing_dim == 0: - packing_dim = 1 - else: - # original_packing_dim is 1 - packing_dim = 2 - - block_size = self.block_size.copy() - block_size = [1, block_size[0], block_size[1]] - - qdata_shape = size.copy() - qdata_shape[packing_dim] //= 2 - qdata = self.qdata.reshape(*qdata_shape) - - sz_shape = [] - for i in range(len(size)): - sz_shape.append(size[i] // block_size[i]) - - # scale and zero_point have reversed dimensions - sz_shape[1], sz_shape[2] = sz_shape[2], sz_shape[1] - - scale = self.scale.reshape(*sz_shape) - zero_point = self.zero_point.reshape(*sz_shape) - elif len(original_shape) == len(size): - assert all( - x == y or y == -1 for x, y in zip(original_shape, size) - ), f"Only support viewing with match dimensions or -1, got: {original_shape}, {size}" - packing_dim = original_packing_dim - block_size = self.block_size - else: - assert ( - len(original_shape) == 2 and len(size) == 3 - ), f"Only support reshaping from 2D to 3D or from 3D to 2D or between sam ranges, requested: reshaping from {original_shape} to {size}" - - shape = list(qdata.shape) - for i in range(len(shape)): - if i == packing_dim: - shape[i] *= 2 - - new = Int4Tensor( - qdata, - scale, - zero_point, - block_size, - shape, - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -@implements(aten.squeeze.dim) -def _(func, types, args, kwargs): - self, dim = args - - # Squeeze qdata - qdata = self.qdata.squeeze(dim=dim) - - # For scale and zero_point, we need to squeeze based on the tensor layout - # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, N, K/group_size) for 3D - if self.qdata.ndim == 2: # 2D case - # qdata is (N, K/2), scale/zero_point is (K/group_size, N) - # When squeezing qdata dim, we need to squeeze scale/zero_point in reverse 
order - sz_dim = 1 - dim - else: # 3D case - # qdata is (B, N, K/2), scale/zero_point is (B, N, K/group_size) - sz_dim = dim - - scale = self.scale.squeeze(dim=sz_dim) - zero_point = self.zero_point.squeeze(dim=sz_dim) - - # Update block_size by removing the squeezed dimension - new_block_size = list(self.block_size) - if len(qdata.shape) < len(new_block_size): - new_block_size.pop(dim) - - # Update shape by removing the squeezed dimension - new_shape = list(self.shape) - if len(qdata.shape) < len(new_shape): - assert new_shape[dim] == 1 - new_shape.pop(dim) - - new = Int4Tensor( - qdata, - scale, - zero_point, - new_block_size, - new_shape, - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -Int4Tensor.__module__ = "torchao.quantization" - -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int4Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int4Tensor]) From ac3e430277db2de3aa98b64913b66ab95b697945 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 13 Aug 2025 11:24:45 -0700 Subject: [PATCH 03/14] packing --- .../int4/int4_marlin_sparse_tensor.py | 69 +++++++++++++++++-- 1 file changed, 64 insertions(+), 5 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index 0f5c7bed39..af282b38c7 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -37,7 +37,7 @@ class Int4MarlinSparseTensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale", "zero_point"] tensor_attribute_names = ["block_size", "shape", "meta", "num_bits"] - def __new__(cls, qdata, scale, shape): + def __new__(cls, qdata, scale, zero_point, meta, block_size, shape, num_bits): kwargs = {} kwargs["device"] = qdata.device kwargs["dtype"] = scale.dtype @@ -80,18 +80,71 @@ def from_plain( cls, qdata: torch.Tensor, scale: torch.Tensor, - zero: torch.Tensor, + zero_point: torch.Tensor, ): - from torchao.sparsity.marlin import const, pack_to_marlin_24 + from torchao.sparsity.marlin import ( + const, + pack_to_marlin_24 + ) - # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # Linear layers are (in_features, out_features) but the qdata that is reaching this point # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. - q_w_24 = int_data.t() + q_w_24 = qdata.t() # addressing the case when scale has dimension 1, happens when # weight_shape[-1] == group_size == 128 if scale.ndim == 1: scale = scale.reshape(scale.shape[0], -1) + scale_t = scale.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + ) + + if q_w_24.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w_24.shape + if in_features % 128 != 0 or out_features != 256 == 0: + raise ValueError( + "`in_features` must be divisible by 64 and `out_features` by 256." + ) + + # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8 + # will require a bit more work to get our current quantization flow to work with it. 
+ # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main + num_bits = 4 if torch.max(q_w_24) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + group_size = in_features // scale_t.shape[0] + if group_size == 0: + group_size = in_features + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin 2:4 format + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) + + return cls( + qdata=marlin_24_q_w_comp, + scale=marlin_24_s, + zero_point=zero_point, + meta=meta, + block_size=group_size, + shape=q_w_24.shape, + num_bits=num_bits, + ) + @classmethod def from_hp( cls, @@ -144,6 +197,12 @@ def from_hp( scale = scale.to(w.dtype) zero_point = zero_point.to(w.dtype) + return cls.from_plain( + qdata=wq, + scale=scale, + zero_point=zero_point + ) + implements = Int4MarlinSparseTensor.implements From a8bfed3b9320bd8a5741d407e9bc485b84fd9b3a Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 13 Aug 2025 11:48:28 -0700 Subject: [PATCH 04/14] linear --- .../int4/int4_marlin_sparse_tensor.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index af282b38c7..02a7343d70 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -152,19 +152,6 @@ def from_hp( block_size: Tuple[int], # quantize functions needs it as tuple not list ): preprocessed_w = cls.pre_process(w) - # assert ( - # len(block_size) == w.ndim - # ), f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" - # if int4_row_quantize_zp is None: - # raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - - # assert ( - # all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1 - # ), "Only groupwise quant is supported right now" - - # group_size = block_size[-1] - # original_shape = w.shape - assert ( block_size == 128 or block_size == w.shape[-1] ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {block_size}" @@ -206,9 +193,11 @@ def from_hp( implements = Int4MarlinSparseTensor.implements - @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): + from torchao.ops import marlin_24_gemm + from torchao.sparsity.marlin import marlin_24_workspace + input_tensor, weight_tensor, bias = ( args[0], args[1], @@ -220,17 +209,35 @@ def _(func, types, args, kwargs): weight_tensor.zero_point.is_contiguous() ), "Expected zero_point to be contiguous" - orig_act_size = input_tensor.size() - orig_out_features = weight_tensor.shape[-2] - - input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]) - res = torch.ops.fbgemm.bf16i4bf16_rowwise( - input_tensor, - weight_tensor.qdata, - weight_tensor.scale, - weight_tensor.zero_point, + sparse_w_int4 = weight_tensor.qdata + scale = weight_tensor.scale + meta = weight_tensor.meta + original_shape = weight_tensor.shape + num_bits = weight_tensor.num_bits + + # Folds batch dimension into 
the first dimension + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) + + size_m = input_2d.shape[0] + size_n = scale.shape[1] + size_k = input_2d.shape[1] + workspace_24 = marlin_24_workspace(original_shape[1]) + + out = marlin_24_gemm( + input_2d, + sparse_w_int4, + meta, + scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) + + # Unfold the batch dimension + out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) + if bias is not None: - res = res + bias - return res + out += bias.to(out.dtype) + return out From b51b091237700be6a293e8ddd17afbb31d7f6a0e Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 13 Aug 2025 14:31:15 -0700 Subject: [PATCH 05/14] add call to from_hp --- torchao/quantization/quant_api.py | 7 +++++++ torchao/quantization/quantize_/workflows/__init__.py | 5 +++++ .../workflows/int4/int4_marlin_sparse_tensor.py | 10 ++++++---- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 5d191a7c0e..a0a4c8cba1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -74,6 +74,7 @@ Float8Tensor, Int4PreshuffledTensor, Int4Tensor, + Int4MarlinSparseTensor, QuantizeTensorToFloat8Kwargs, ) from torchao.quantization.transform_module import ( @@ -1068,6 +1069,12 @@ def _int4_weight_only_quantize_tensor(weight, config): block_size, ) return new_weight + elif packing_format == PackingFormat.MARLIN_SPARSE: + new_weight = Int4MarlinSparseTensor.from_hp( + weight, + block_size, + ) + return new_weight else: raise ValueError(f"Unsupported packing format: {packing_format}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 98480c2db2..7e0d944b2c 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -9,9 +9,14 @@ Int4Tensor, ) +from .int4.int4_marlin_sparse_tensor import ( + Int4MarlinSparseTensor, +) + __all__ = [ "Int4Tensor", "Int4PreshuffledTensor", + "Int4MarlinSparseTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", ] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index 02a7343d70..862ba8de3f 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import List, Tuple import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -149,7 +149,7 @@ def from_plain( def from_hp( cls, w: torch.Tensor, - block_size: Tuple[int], # quantize functions needs it as tuple not list + block_size: List[int], ): preprocessed_w = cls.pre_process(w) assert ( @@ -160,10 +160,12 @@ def from_hp( quant_max = 15 target_dtype = torch.int4 + assert(len(block_size) == 1), f"Expected one block size, got {len(block_size)}" + scale, zero_point = choose_qparams_affine( input=preprocessed_w, mapping_type=MappingType.SYMMETRIC, - block_size=block_size, + block_size=(block_size[0],), target_dtype=torch.int4, # ??? 
i think its int4 because we wanna convert to int4 idk but in the old version i think its int32 quant_min=quant_min, quant_max=quant_max, @@ -173,7 +175,7 @@ def from_hp( wq = quantize_affine( input=preprocessed_w, - block_size=block_size, + block_size=(block_size[0],), scale=scale, zero_point=zero_point, output_dtype=target_dtype, From 641cc71dc7b53faecc806920c505aa41d06a4b25 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 14 Aug 2025 14:27:05 -0700 Subject: [PATCH 06/14] unit test --- .../int4/test_int4_marlin_sparse_tensor.py | 89 +++++++++++++++++++ .../int4/int4_marlin_sparse_tensor.py | 23 +++-- 2 files changed, 99 insertions(+), 13 deletions(-) create mode 100644 test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py new file mode 100644 index 0000000000..81b48dead9 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_8, + is_sm_at_least_90, +) + +BF16_ACT_CONFIG = Int4WeightOnlyConfig( + group_size=128, + packing_format="marlin_sparse", + VERSION=2, +) + +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +class TestInt4MarlinSparseTensor(TestCase): + def setUp(self): + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + @parametrize("config", [BF16_ACT_CONFIG]) + def test_linear(self, config): + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + @parametrize("config", [BF16_ACT_CONFIG]) + def test_to_device(self, config): + for device in self.GPU_DEVICES: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device) + + @parametrize("config", [BF16_ACT_CONFIG]) + def test_module_path(self, config): + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +instantiate_parametrized_tests(TestInt4MarlinSparseTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py 
b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index 862ba8de3f..e9d5d4ea06 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -34,23 +34,22 @@ class Int4MarlinSparseTensor(TorchAOBaseTensor): - tensor_data_names = ["qdata", "scale", "zero_point"] - tensor_attribute_names = ["block_size", "shape", "meta", "num_bits"] + tensor_data_names = ["qdata", "scale", "zero_point", "meta"] # meta is a tensor + tensor_attribute_names = ["block_size", "num_bits", "shape"] - def __new__(cls, qdata, scale, zero_point, meta, block_size, shape, num_bits): + def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape): kwargs = {} kwargs["device"] = qdata.device kwargs["dtype"] = scale.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, qdata, scale, zero_point, meta, block_size, shape, num_bits): + def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape): # args need to match lines 37 and 38 self.qdata = qdata self.scale = scale self.zero_point = zero_point self.meta = meta self.block_size = block_size - self.shape = shape self.num_bits = num_bits def _quantization_type(self): @@ -153,20 +152,18 @@ def from_hp( ): preprocessed_w = cls.pre_process(w) assert ( - block_size == 128 or block_size == w.shape[-1] - ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {block_size}" + block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1] + ), f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}" quant_min = 0 quant_max = 15 - target_dtype = torch.int4 - - assert(len(block_size) == 1), f"Expected one block size, got {len(block_size)}" + target_dtype = torch.int32 scale, zero_point = choose_qparams_affine( input=preprocessed_w, mapping_type=MappingType.SYMMETRIC, - block_size=(block_size[0],), - target_dtype=torch.int4, # ??? i think its int4 because we wanna convert to int4 idk but in the old version i think its int32 + block_size=block_size, + target_dtype=target_dtype, # ??? 
i think its int4 because we wanna convert to int4 idk but in the old version i think its int32 quant_min=quant_min, quant_max=quant_max, eps=1e-6, @@ -175,7 +172,7 @@ def from_hp( wq = quantize_affine( input=preprocessed_w, - block_size=(block_size[0],), + block_size=block_size, scale=scale, zero_point=zero_point, output_dtype=target_dtype, From ae14aa94706d0865e26d65e711cdb5817a95e0d0 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 10:41:51 -0700 Subject: [PATCH 07/14] fix test_linear --- .../int4/test_int4_marlin_sparse_tensor.py | 11 ++++- torchao/quantization/quant_api.py | 2 +- .../quantize_/workflows/__init__.py | 7 ++- .../int4/int4_marlin_sparse_tensor.py | 48 ++++++++----------- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index 81b48dead9..47e9ba4a91 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -15,22 +15,26 @@ run_tests, ) +from torchao.dtypes import MarlinSparseLayout from torchao.quantization import ( Int4WeightOnlyConfig, + int4_weight_only, quantize_, ) from torchao.quantization.utils import compute_error +from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, - is_sm_at_least_90, ) BF16_ACT_CONFIG = Int4WeightOnlyConfig( group_size=128, packing_format="marlin_sparse", + layout=MarlinSparseLayout(), VERSION=2, ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") class TestInt4MarlinSparseTensor(TestCase): @@ -39,15 +43,17 @@ def setUp(self): @parametrize("config", [BF16_ACT_CONFIG]) def test_linear(self, config): - dtype = torch.bfloat16 + dtype = torch.float16 device = "cuda" input = torch.randn(128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + apply_fake_sparsity(linear) original = linear(input) quantize_(linear, config) quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) + @unittest.skip("Fix later") @parametrize("config", [BF16_ACT_CONFIG]) def test_to_device(self, config): for device in self.GPU_DEVICES: @@ -63,6 +69,7 @@ def test_to_device(self, config): quantize_(linear, config) linear.to(device) + @unittest.skip("Fix later") @parametrize("config", [BF16_ACT_CONFIG]) def test_module_path(self, config): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a0a4c8cba1..ed5abb7333 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -72,9 +72,9 @@ ) from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, - Int4MarlinSparseTensor, QuantizeTensorToFloat8Kwargs, ) from torchao.quantization.transform_module import ( diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 7e0d944b2c..8441382243 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -2,6 +2,9 @@ Float8Tensor, QuantizeTensorToFloat8Kwargs, ) +from .int4.int4_marlin_sparse_tensor import ( + Int4MarlinSparseTensor, +) from 
.int4.int4_preshuffled_tensor import ( Int4PreshuffledTensor, ) @@ -9,10 +12,6 @@ Int4Tensor, ) -from .int4.int4_marlin_sparse_tensor import ( - Int4MarlinSparseTensor, -) - __all__ = [ "Int4Tensor", "Int4PreshuffledTensor", diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index e9d5d4ea06..d910c24d3b 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -5,19 +5,16 @@ # LICENSE file in the root directory of this source tree. -from typing import List, Tuple +from typing import List import torch -from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import ( - choose_qparams_affine, MappingType, + choose_qparams_affine, quantize_affine, ) - -from torchao.utils import fill_defaults, TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor - +from torchao.utils import TorchAOBaseTensor __all__ = [ "Int4MarlinSparseTensor", @@ -34,7 +31,7 @@ class Int4MarlinSparseTensor(TorchAOBaseTensor): - tensor_data_names = ["qdata", "scale", "zero_point", "meta"] # meta is a tensor + tensor_data_names = ["qdata", "scale", "zero_point", "meta"] # meta is a tensor tensor_attribute_names = ["block_size", "num_bits", "shape"] def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape): @@ -44,7 +41,9 @@ def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape): kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape): # args need to match lines 37 and 38 + def __init__( + self, qdata, scale, zero_point, meta, block_size, num_bits, shape + ): # args need to match lines 37 and 38 self.qdata = qdata self.scale = scale self.zero_point = zero_point @@ -81,10 +80,7 @@ def from_plain( scale: torch.Tensor, zero_point: torch.Tensor, ): - from torchao.sparsity.marlin import ( - const, - pack_to_marlin_24 - ) + from torchao.sparsity.marlin import const, pack_to_marlin_24 # Linear layers are (in_features, out_features) but the qdata that is reaching this point # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. @@ -120,9 +116,9 @@ def from_plain( group_size = in_features // scale_t.shape[0] if group_size == 0: group_size = in_features - assert ( - group_size <= in_features - ), "Group size must be less than or equal to in_features." + assert group_size <= in_features, ( + "Group size must be less than or equal to in_features." + ) if group_size not in const.SUPPORTED_GROUP_SIZES: raise ValueError( @@ -151,9 +147,9 @@ def from_hp( block_size: List[int], ): preprocessed_w = cls.pre_process(w) - assert ( - block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1] - ), f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}" + assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], ( + f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}" + ) quant_min = 0 quant_max = 15 @@ -163,11 +159,10 @@ def from_hp( input=preprocessed_w, mapping_type=MappingType.SYMMETRIC, block_size=block_size, - target_dtype=target_dtype, # ??? 
i think its int4 because we wanna convert to int4 idk but in the old version i think its int32 + target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=1e-6, - # leaving scale dtype and zero point dtype as default for now idk ) wq = quantize_affine( @@ -183,15 +178,12 @@ def from_hp( scale = scale.to(w.dtype) zero_point = zero_point.to(w.dtype) - return cls.from_plain( - qdata=wq, - scale=scale, - zero_point=zero_point - ) + return cls.from_plain(qdata=wq, scale=scale, zero_point=zero_point) implements = Int4MarlinSparseTensor.implements + @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): from torchao.ops import marlin_24_gemm @@ -204,9 +196,9 @@ def _(func, types, args, kwargs): ) assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" - assert ( - weight_tensor.zero_point.is_contiguous() - ), "Expected zero_point to be contiguous" + assert weight_tensor.zero_point.is_contiguous(), ( + "Expected zero_point to be contiguous" + ) sparse_w_int4 = weight_tensor.qdata scale = weight_tensor.scale From cbd1bae032b0da28c7a4b46df66991e4f64647c6 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 10:43:05 -0700 Subject: [PATCH 08/14] formatting --- .../quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index 47e9ba4a91..1caac8f6a1 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -18,7 +18,6 @@ from torchao.dtypes import MarlinSparseLayout from torchao.quantization import ( Int4WeightOnlyConfig, - int4_weight_only, quantize_, ) from torchao.quantization.utils import compute_error From 9f2ae7ceb5449a5d71b63cc3b9071991f28eaf6b Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 10:49:55 -0700 Subject: [PATCH 09/14] remove comments --- .../quantize_/workflows/int4/int4_marlin_sparse_tensor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index d910c24d3b..6dde91992f 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -31,7 +31,7 @@ class Int4MarlinSparseTensor(TorchAOBaseTensor): - tensor_data_names = ["qdata", "scale", "zero_point", "meta"] # meta is a tensor + tensor_data_names = ["qdata", "scale", "zero_point", "meta"] tensor_attribute_names = ["block_size", "num_bits", "shape"] def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape): @@ -41,9 +41,7 @@ def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape): kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__( - self, qdata, scale, zero_point, meta, block_size, num_bits, shape - ): # args need to match lines 37 and 38 + def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape): self.qdata = qdata self.scale = scale self.zero_point = zero_point From 30b23f3de20050842433e61f439d5dd24ebc3aff 
Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 11:07:25 -0700 Subject: [PATCH 10/14] update VERSION to version --- .../quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index 1caac8f6a1..714bdefbbf 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -30,7 +30,7 @@ group_size=128, packing_format="marlin_sparse", layout=MarlinSparseLayout(), - VERSION=2, + version=2, ) From dffd0e0fa6a3b9e1694d470f59d58591baf9b123 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 11:21:15 -0700 Subject: [PATCH 11/14] fix module path unit test --- .../int4/test_int4_marlin_sparse_tensor.py | 5 +---- torchao/quantization/__init__.py | 2 ++ .../workflows/int4/int4_marlin_sparse_tensor.py | 13 ++++++------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index 714bdefbbf..52df5ba89a 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -15,7 +15,6 @@ run_tests, ) -from torchao.dtypes import MarlinSparseLayout from torchao.quantization import ( Int4WeightOnlyConfig, quantize_, @@ -29,7 +28,6 @@ BF16_ACT_CONFIG = Int4WeightOnlyConfig( group_size=128, packing_format="marlin_sparse", - layout=MarlinSparseLayout(), version=2, ) @@ -68,11 +66,10 @@ def test_to_device(self, config): quantize_(linear, config) linear.to(device) - @unittest.skip("Fix later") @parametrize("config", [BF16_ACT_CONFIG]) def test_module_path(self, config): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, config) + quantize_(linear.cuda(), config) self.assertEqual( str(type(linear.weight)), "", diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 282c18bccb..8e98e55178 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -90,6 +90,7 @@ ) from .quantize_.workflows import ( Float8Tensor, + Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, ) @@ -159,6 +160,7 @@ # tensor subclasses "Int4Tensor", "Int4PreshuffledTensor", + "Int4MarlinSparseTensor", "Float8Tensor", # smooth quant - subject to change "get_scale", diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index 6dde91992f..de51b7e49d 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -23,13 +23,6 @@ aten = torch.ops.aten -try: - from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4 -except: - int4_row_quantize_zp = None - pack_int4 = None - - class Int4MarlinSparseTensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale", "zero_point", "meta"] tensor_attribute_names = ["block_size", "num_bits", "shape"] @@ -230,3 +223,9 @@ def _(func, types, args, kwargs): if bias is not None: out += bias.to(out.dtype) return out + + +Int4MarlinSparseTensor.__module__ = 
"torchao.quantization" + +# Allow a model with Int4MarlinSparseTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4MarlinSparseTensor]) From 45e8a9e14a84628c0f6b37834a10e94392a8f4e7 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 11:51:38 -0700 Subject: [PATCH 12/14] adding sizes to linear unit test --- .../int4/test_int4_marlin_sparse_tensor.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index 52df5ba89a..0147fd5eb5 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -39,11 +39,22 @@ def setUp(self): self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] @parametrize("config", [BF16_ACT_CONFIG]) - def test_linear(self, config): + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 12), + ], + ) + def test_linear(self, config, sizes): dtype = torch.float16 device = "cuda" - input = torch.randn(128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + apply_fake_sparsity(linear) original = linear(input) quantize_(linear, config) From ebbd3ab1ca8dcb75c9c6adc089784242342f08f9 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 12:09:28 -0700 Subject: [PATCH 13/14] move pre_process and from_plain to from_hp --- .../int4/int4_marlin_sparse_tensor.py | 105 ++++++++---------- 1 file changed, 45 insertions(+), 60 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py index de51b7e49d..d4a4f147da 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -45,37 +45,62 @@ def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape): def _quantization_type(self): return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" - @staticmethod - def pre_process(input: torch.Tensor) -> torch.Tensor: + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + from torchao.sparsity.marlin import ( + const, + inject_24, # avoid circular import + pack_to_marlin_24, + ) + """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format - 2º: tensor is injected with 2:4 sparsity - 3º: transposes it again because the quantization process will compute the scales for dim=-1 + """ - Args: - input (torch.Tensor): the input tensor to preprocess + w_t = w.t() + w_24, _ = inject_24(w_t, *w_t.shape) + preprocessed_w = w_24.t() - Returns: - torch.Tensor: the preprocessed tensor - """ - from torchao.sparsity.marlin import inject_24 # avoid circular import + assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], ( + f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}" + ) - input_t = input.t() - w_24, _ = inject_24(input_t, *input_t.shape) - return w_24.t() + quant_min = 0 + quant_max = 15 + target_dtype = torch.int32 - @classmethod - def from_plain( - cls, - qdata: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - ): - from torchao.sparsity.marlin import const, pack_to_marlin_24 + scale, zero_point = choose_qparams_affine( + input=preprocessed_w, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=1e-6, + ) + + wq = quantize_affine( + input=preprocessed_w, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + + scale = scale.to(w.dtype) + zero_point = zero_point.to(w.dtype) # Linear layers are (in_features, out_features) but the qdata that is reaching this point # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. - q_w_24 = qdata.t() + q_w_24 = wq.t() # addressing the case when scale has dimension 1, happens when # weight_shape[-1] == group_size == 128 if scale.ndim == 1: @@ -131,46 +156,6 @@ def from_plain( num_bits=num_bits, ) - @classmethod - def from_hp( - cls, - w: torch.Tensor, - block_size: List[int], - ): - preprocessed_w = cls.pre_process(w) - assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], ( - f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}" - ) - - quant_min = 0 - quant_max = 15 - target_dtype = torch.int32 - - scale, zero_point = choose_qparams_affine( - input=preprocessed_w, - mapping_type=MappingType.SYMMETRIC, - block_size=block_size, - target_dtype=target_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=1e-6, - ) - - wq = quantize_affine( - input=preprocessed_w, - block_size=block_size, - scale=scale, - zero_point=zero_point, - output_dtype=target_dtype, - quant_min=quant_min, - quant_max=quant_max, - ) - - scale = scale.to(w.dtype) - zero_point = zero_point.to(w.dtype) - - return cls.from_plain(qdata=wq, scale=scale, zero_point=zero_point) - implements = Int4MarlinSparseTensor.implements From 7de0b1264c7486c842178f34f05e310b5bf0c85f Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 15 Aug 2025 12:14:21 -0700 Subject: [PATCH 14/14] compile test --- .../workflows/int4/test_int4_marlin_sparse_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index 0147fd5eb5..443a2c149b 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -61,6 +61,10 @@ def 
test_linear(self, config, sizes): quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + @unittest.skip("Fix later") @parametrize("config", [BF16_ACT_CONFIG]) def test_to_device(self, config):
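
---

Usage sketch for reviewers: the snippet below strings together the pieces this patch series adds — the `marlin_sparse` packing format on `Int4WeightOnlyConfig` (version 2), quantization via `quantize_`, the SQNR check from the unit test, and the `torch.compile` path added in the last patch. It is a minimal sketch based on the test file, not part of the diff; the note about pre-pruning mirrors the test's `apply_fake_sparsity` helper, whose import path is not shown here.

```python
# Minimal sketch of the marlin_sparse int4 workflow exercised by the new tests.
import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error

# Mirrors BF16_ACT_CONFIG from the test after the VERSION -> version fix.
config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="marlin_sparse",
    version=2,
)

linear = torch.nn.Linear(128, 256, dtype=torch.float16, device="cuda")
x = torch.randn(32, 128, dtype=torch.float16, device="cuda")
ref = linear(x)

# Int4MarlinSparseTensor.from_hp injects 2:4 sparsity (inject_24) before
# quantizing, so pruning the weight to a 2:4 pattern beforehand (as the
# test's apply_fake_sparsity helper does) keeps the quantization error small.
quantize_(linear, config)
out = linear(x)
print(compute_error(ref, out))  # SQNR; the test expects > 20

# The quantized module also runs under torch.compile, per the compile test.
compiled = torch.compile(linear)
out_compiled = compiled(x)
```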