diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index c3227e211c..c61f973d03 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -218,3 +218,16 @@ def test_narrow_similar_to_vllm(self):
             gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
         )
         self._test_narrow_similar_to_vllm(config)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
+    def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.WEIGHT_ONLY,
+            use_triton_kernel=False,
+            use_dynamic_per_tensor_scale=False,
+        )
+        self._test_quantize_3d_param_similar_to_vllm(config)
diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 7ab37a0dba..773777d400 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -42,6 +42,7 @@
         (torch.float32, (64, 128), False),
         (torch.bfloat16, (128, 256), False),
         (torch.bfloat16, (64, 128), True),
+        (torch.bfloat16, (1, 32, 64), False),
     ],
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -83,14 +84,20 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
     )
 
-    x_nvfp4_t = x_nvfp4.t()
+    if len(x.shape) == 2:
+        x_nvfp4_t = x_nvfp4.t()
+        x_t = x.t()
+    else:
+        x_nvfp4_t = x_nvfp4.transpose(-2, -1)
+        x_t = x.transpose(-2, -1)
+
     x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
-    assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
+    assert_sqnr_gt_threshold(x_t, x_reconstructed_t, 8.0)
 
-    assert x.t().shape == x_reconstructed_t.shape, (
+    assert x_t.shape == x_reconstructed_t.shape, (
         f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
     )
-    assert x.t().dtype == x_reconstructed_t.dtype, (
+    assert x_t.dtype == x_reconstructed_t.dtype, (
         f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
     )
 
@@ -103,6 +110,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         (16, 32),
         (64, 128),
         (384, 128),
+        (1, 32, 64),
     ],
 )
 @pytest.mark.skipif(
@@ -115,8 +123,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     that the _is_swizzled_scales flag is set correctly.
     """
-    M, K = shape
-    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
 
     assert tensor._is_swizzled_scales == is_swizzled_scales
@@ -536,36 +543,43 @@ def test_nvfp4_to_copy():
 @pytest.mark.parametrize("use_triton_kernel", [False, True])
 @pytest.mark.parametrize("is_swizzled_scales", [False, True])
 @pytest.mark.parametrize(
-    "mk",
+    "shape",
     (
         (128, 64),
         (128 + 16, 64),
         (128, 64 + 16),
         (128 + 16, 64 + 16),
+        (1, 128, 64),
     ),
 )
 def test_scale_shape_matches_qdata(
-    transpose, use_triton_kernel, is_swizzled_scales, mk
+    transpose, use_triton_kernel, is_swizzled_scales, shape
 ):
     if use_triton_kernel and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
     if use_triton_kernel and not is_swizzled_scales:
         pytest.skip("triton kernel requires swizzled scales")
 
-    M, K = mk
-
     block_size = 16
 
-    x_hp = torch.randn(M, K, device="cuda")
+    x_hp = torch.randn(*shape, device="cuda")
     x = NVFP4Tensor.to_nvfp4(
         x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
     )
 
-    m_dim, k_dim = 0, 1
-    if transpose:
-        x_hp = x_hp.t()
-        x = x.t()
-        m_dim, k_dim = 1, 0
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1
 
     orig_m = x_hp.shape[m_dim]
     expected_padded_m = orig_m
@@ -587,3 +601,17 @@ def test_scale_shape_matches_qdata(
     assert expected_padded_k == actual_padded_k, (
         f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("dims", ((1, 2), (2, 1), (-1, -2), (-2, -1)))
+@pytest.mark.parametrize("is_swizzled_scales", [True, False])
+def test_3d_transpose(dims, is_swizzled_scales):
+    x_hp = torch.randn(2, 128, 256, device="cuda")
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=is_swizzled_scales)
+    x_hp_t = x_hp.transpose(dims[0], dims[1])
+    x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
+    assert x_hp_t.shape == x_nvfp4_t.shape
diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index d5bf290589..d0f1b04119 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -168,9 +168,9 @@ def _nvfp4_inference_linear_transform(
 
     weight = module.weight
 
-    if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+    if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
         raise RuntimeError(
-            f"NVFP4 only supports weight shape divisible by 16, got {weight.shape}"
+            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
         )
 
     if module.bias is not None and weight.dtype == torch.float32:
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 5811dd9d21..4a8c899d1c 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1391,6 +1391,10 @@ def triton_quantize_nvfp4(
     Since VLLM does not use dyanmo guards we need to make this a custom op to avoid the triton
     kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
+    # reshape to 2d
+    orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
+    x = x.reshape(-1, orig_N)
+
     M, N = x.shape
     # assert M % 128 == 0 and N % 64 == 0
     assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
@@ -1431,6 +1435,10 @@ def triton_quantize_nvfp4(
         MASK_SCALES=MASK_SCALES,
     )
 
+    # reshape back to original shape
+    scales = scales.view(*orig_leading_dims, -1, padded_cols)
+    xq = xq.view(*orig_leading_dims, -1, N // 2)
+
     return scales, xq.view(torch.uint8)
 
 @triton_quantize_nvfp4.register_fake
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 432ef393d2..477834ca19 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -426,7 +426,12 @@ def tensor_size_hp_to_fp4x2(orig_size, is_contiguous):
     if is_contiguous:
         new_size = [*list(new_size[:-1]), new_size[-1] // 2]
     else:
-        new_size = [new_size[0] // 2, *list(new_size[1:])]
+        if len(orig_size) == 2:
+            new_size = [new_size[0] // 2, *list(new_size[1:])]
+        else:
+            assert len(orig_size) == 3, "unsupported"
+            # only supporting dim0, dim1, dim2 and dim0, dim2, dim1 orders
+            new_size = [new_size[0], new_size[2] // 2, new_size[1]]
     return new_size
 
 
@@ -435,10 +440,16 @@ def tensor_size_fp4x2_to_hp(orig_size, is_contiguous):
     if is_contiguous:
         new_size = [*list(new_size[:-1]), new_size[-1] * 2]
     else:
-        new_size = [new_size[0] * 2, *list(new_size[1:])]
+        if len(orig_size) == 2:
+            new_size = [new_size[0] * 2, *list(new_size[1:])]
+        else:
+            assert len(orig_size) == 3, "unsupported"
+            # only supporting dim0, dim1, dim2 and dim0, dim2, dim1 orders
+            new_size = [new_size[0], new_size[2] * 2, new_size[1]]
     return new_size
 
 
+# TODO(future PR): fix this function for rank 3 and add tests
 def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous):
     new_size = orig_size
     if is_contiguous:
@@ -448,6 +459,7 @@ def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous):
     return new_size
 
 
+# TODO(future PR): fix this function for rank 3 and add tests
 def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     new_size = orig_size
     if is_contiguous:
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index d1775d0812..043e1160e0 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import sys
 from dataclasses import dataclass
 from enum import Enum
@@ -112,7 +113,7 @@ def __new__(
 
         new_size = tensor_size_fp4x2_to_hp(
             new_size,
-            qdata.stride(0) > qdata.stride(1),
+            qdata.stride(-2) > qdata.stride(-1),
         )
 
         self = torch.Tensor._make_wrapper_subclass(
@@ -174,13 +175,13 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
-        assert len(data_hp.shape) == 2, "unsupported"
-        M, K = data_hp.shape[0], data_hp.shape[1]
+        assert len(data_hp.shape) in (2, 3), "unsupported"
+        leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
 
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
-            assert data_hp.shape[1] % 16 == 0, (
-                f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
+            assert K % 16 == 0, (
+                f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
@@ -188,7 +189,7 @@ def to_nvfp4(
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
-                scale_shape = (M, K // block_size)
+                scale_shape = (math.prod(leading_dims) * M, K // block_size)
                 blockwise_scales = to_blocked(
                     blockwise_scales.view(scale_shape)
                 ).flatten()
@@ -199,7 +200,7 @@ def to_nvfp4(
             # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
             # scale element
             scale_M, scale_K = M, K // block_size
-            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            blockwise_scales = blockwise_scales.view(*leading_dims, scale_M, scale_K)
 
         return NVFP4Tensor(
             data_lp,
@@ -225,22 +226,26 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
         else:
-            M, K = self.shape[0], self.shape[1]
-        data = self.qdata.t() if is_transposed else self.qdata
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
+        data = self.qdata.transpose(-2, -1) if is_transposed else self.qdata
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
-        data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        data_f32 = data_f32.view(
+            *leading_dims, M, K // self._block_size, self._block_size
+        )
+        scale_e4m3_reshaped = self.get_hp_scales().view(
+            *leading_dims, M, K // self._block_size, 1
+        )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
 
         if is_transposed:
-            result = result.t()
+            result = result.transpose(-2, -1)
 
         return result
 
@@ -250,16 +255,18 @@ def get_hp_scales(self) -> torch.Tensor:
         Returns:
             torch.Tensor: Scales of the NVFP4Tensor
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
-            scale_e4m3 = self._scale_e4m3.t()
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
+            scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
         else:
-            M, K = self.shape[0], self.shape[1]
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
             scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
+            scale_e4m3 = from_blocked(
+                scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
+            )
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -380,6 +387,9 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert len(x.shape) == 2, (
+        f"only rank 2 is supported for slice, got rank {len(x.shape)}"
+    )
 
     M, K = x.shape[0], x.shape[1]
 
@@ -583,6 +593,28 @@ def nvfp4_t(func, types, args, kwargs):
     return new
 
 
+@implements([aten.transpose.int])
+def nvfp4_transpose(func, types, args, kwargs):
+    old, dim0, dim1 = args
+    assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"
+    valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
+    assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
+    new_qdata = func(old.qdata, dim0, dim1, **kwargs)
+    new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
+    new = NVFP4Tensor(
+        new_qdata,
+        new_scale,
+        old._block_size,
+        old._orig_dtype,
+        old._per_tensor_scale,
+        old._act_per_tensor_scale,
+        old._is_swizzled_scales,
+        old.use_triton_kernel,
+        old.act_quant_kwargs,
+    )
+    return new
+
+
 @implements([aten.view.default])
 def nvfp4_view_op(func, types, args, kwargs):
     data = args[0].qdata
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index 5fec85fee6..7f694b56d3 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -625,6 +625,18 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
                 f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}"
             )
 
+    def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
+        # this happens when vLLM loads empty MoE weights and quantizes
+        # them
+
+        dtype = torch.bfloat16
+        with torch.device("meta"):
+            l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        l.weight = torch.nn.Parameter(
+            torch.randn(60, 2816, 2048, device="cuda", dtype=dtype)
+        )
+        quantize_(l, config)
+
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)