diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 62cd1b88ad..7ab37a0dba 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -525,3 +526,64 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_sm_at_least_100():
+        pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x64 unpacked / 128x32 packed qdata tile maps to a 32x16 scale tile
+        expected_padded_m = ceil_div(orig_m, 128) * 32
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x64 unpacked / 128x32 packed qdata tile maps to a 32x16 scale tile
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
+    actual_padded_k = x._scale_e4m3.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k=}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index c22f7793bb..d1775d0812 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -24,7 +24,11 @@
     tensor_size_fp4x2_to_hp,
     tensor_size_hp_to_fp4x2,
 )
-from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.prototype.mx_formats.utils import (
+    from_blocked,
+    hp_data_dims_to_swizzled_scale_dims_nvfp4,
+    to_blocked,
+)
 from torchao.quantization.quantize_.common import (
     QuantizeTensorKwargs,
 )
@@ -170,6 +174,9 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
+        assert len(data_hp.shape) == 2, "only 2D tensors are supported"
+        M, K = data_hp.shape[0], data_hp.shape[1]
+
        if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
             assert data_hp.shape[1] % 16 == 0, (
@@ -181,12 +188,19 @@ def to_nvfp4(
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
-                M, K = data_hp.shape[0], data_hp.shape[1]
                 scale_shape = (M, K // block_size)
                 blockwise_scales = to_blocked(
                     blockwise_scales.view(scale_shape)
                 ).flatten()
 
+        if is_swizzled_scales:
+            scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(M, K)
+        else:
+            # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+            # scale element
+            scale_M, scale_K = M, K // block_size
+        blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+
         return NVFP4Tensor(
             data_lp,
             blockwise_scales,
@@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -369,6 +383,11 @@ def nvfp4_slice(func, types, args, kwargs):
 
     M, K = x.shape[0], x.shape[1]
 
+    # The scale manipulations below assume a flattened scale. For now, we
+    # flatten the scale, go through the calculations below, and then reshape
+    # it back to the format which matches the shape of `qdata`.
+    # TODO(future PR): update this
+
     if x._is_swizzled_scales:
         scale_rows = M
         scale_cols = K // x._block_size
@@ -407,7 +426,9 @@ def nvfp4_slice(func, types, args, kwargs):
                 else None
             )
 
-            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+            sliced_scale = aten.slice.Tensor(
+                x._scale_e4m3.flatten(), 0, start_idx, end_idx, 1
+            )
             sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
     elif dim == 1:
@@ -462,7 +483,7 @@ def nvfp4_slice(func, types, args, kwargs):
                 row_start = row_block * elements_per_row_block
                 col_start = row_start + start_col_block * elements_per_block
                 col_end = row_start + end_col_block * elements_per_block
-                slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+                slices_to_extract.append(x._scale_e4m3.flatten()[col_start:col_end])
 
             # Concatenate all the slices
             sliced_scale = torch.cat(slices_to_extract, dim=0)
@@ -515,6 +536,19 @@ def nvfp4_slice(func, types, args, kwargs):
 
     sliced_scale = sliced_scale.flatten()
 
+    # reshape the flattened scale back to a 2d shape which matches `qdata`
+    sliced_M = sliced_data.shape[0]
+    # multiply by 2 to convert from bytes to num_elements
+    sliced_K = sliced_data.shape[1] * 2
+    if x._is_swizzled_scales:
+        scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(sliced_M, sliced_K)
+    else:
+        # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+        # scale element
+        scale_M = sliced_M
+        scale_K = sliced_K // x._block_size
+    sliced_scale = sliced_scale.view(scale_M, scale_K)
+
     # Create result tensor
     result = NVFP4Tensor(
         sliced_data,
@@ -537,7 +571,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -576,7 +610,9 @@ def _addmm_nvfp4_dispatch(
     The only difference is whether bias is None or not.
     """
     assert a.qdata.is_contiguous()
+    assert a._scale_e4m3.is_contiguous()
     assert b.qdata.t().is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
@@ -591,9 +627,9 @@ def _addmm_nvfp4_dispatch(
         a_scale_blocked = to_blocked(a_scale)
 
     if b._is_swizzled_scales:
-        b_scale_blocked = b._scale_e4m3  # Already swizzled
+        b_scale_blocked = b._scale_e4m3.t()  # Already swizzled
     else:
-        b_scale = b._scale_e4m3.view(N, K // b._block_size)
+        b_scale = b._scale_e4m3.t().view(N, K // b._block_size)
         b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
index 247b17d838..28a8526709 100644
--- a/torchao/prototype/mx_formats/utils.py
+++ b/torchao/prototype/mx_formats/utils.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Tuple
+
 import torch
 from torch.distributed._tensor import DTensor
 
@@ -99,6 +101,22 @@ def from_blocked(
     return padded[:original_rows, :original_cols]
 
 
+def hp_data_dims_to_swizzled_scale_dims_nvfp4(
+    hp_data_M: int,
+    hp_data_K: int,
+) -> Tuple[int, int]:
+    """
+    Given the `M` and `K` dims of a high precision contiguous tensor,
+    returns a 2-tuple containing the dims of the swizzled nvfp4 scale
+    corresponding to that tensor.
+    """
+    # a 128x64 unpacked or 128x32 packed qdata tile corresponds
+    # to a swizzled 32x16 scale tile
+    scale_M = ceil_div(hp_data_M, 128) * 32
+    scale_K = ceil_div(hp_data_K, 64) * 16
+    return scale_M, scale_K
+
+
 def _to_blocked_single(scales: Tensor) -> Tensor:
     """Assume that we have a 128x4 block of scales in K Major order