Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
per_tensor_amax_to_scale,
unpack_uint4,
)
from torchao.prototype.mx_formats.utils import ceil_div
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
Expand Down Expand Up @@ -525,3 +526,64 @@ def test_nvfp4_to_copy():
assert x.act_quant_kwargs == y.act_quant_kwargs
assert x.dtype == torch.float32
assert y.dtype == torch.bfloat16


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@pytest.mark.parametrize("transpose", [False, True])
@pytest.mark.parametrize("use_triton_kernel", [False, True])
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
@pytest.mark.parametrize(
"mk",
(
(128, 64),
(128 + 16, 64),
(128, 64 + 16),
(128 + 16, 64 + 16),
),
)
def test_scale_shape_matches_qdata(
transpose, use_triton_kernel, is_swizzled_scales, mk
):
if use_triton_kernel and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
if use_triton_kernel and not is_swizzled_scales:
pytest.skip("triton kernel requires swizzled scales")

M, K = mk

block_size = 16

x_hp = torch.randn(M, K, device="cuda")
x = NVFP4Tensor.to_nvfp4(
x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
)

m_dim, k_dim = 0, 1
if transpose:
x_hp = x_hp.t()
x = x.t()
m_dim, k_dim = 1, 0

orig_m = x_hp.shape[m_dim]
expected_padded_m = orig_m
if is_swizzled_scales:
# in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
expected_padded_m = ceil_div(orig_m, 128) * 32
actual_padded_m = x._scale_e4m3.shape[m_dim]
assert expected_padded_m == actual_padded_m, (
f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
)

orig_k = x_hp.shape[k_dim]
expected_padded_k = orig_k // block_size
if is_swizzled_scales:
# in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
actual_padded_k = x._scale_e4m3.shape[k_dim]

assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
)
56 changes: 46 additions & 10 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
tensor_size_fp4x2_to_hp,
tensor_size_hp_to_fp4x2,
)
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
from torchao.prototype.mx_formats.utils import (
from_blocked,
hp_data_dims_to_swizzled_scale_dims_nvfp4,
to_blocked,
)
from torchao.quantization.quantize_.common import (
QuantizeTensorKwargs,
)
Expand Down Expand Up @@ -170,6 +174,9 @@ def to_nvfp4(
Returns:
NVFP4Tensor: Quantized tensor in NVFP4 format
"""
assert len(data_hp.shape) == 2, "unsupported"
M, K = data_hp.shape[0], data_hp.shape[1]

if use_triton_kernel:
assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
assert data_hp.shape[1] % 16 == 0, (
Expand All @@ -181,12 +188,19 @@ def to_nvfp4(
data_hp, block_size, per_tensor_scale
)
if is_swizzled_scales:
M, K = data_hp.shape[0], data_hp.shape[1]
scale_shape = (M, K // block_size)
blockwise_scales = to_blocked(
blockwise_scales.view(scale_shape)
).flatten()

if is_swizzled_scales:
scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(M, K)
else:
# a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
# scale element
scale_M, scale_K = M, K // block_size
blockwise_scales = blockwise_scales.view(scale_M, scale_K)

return NVFP4Tensor(
data_lp,
blockwise_scales,
Expand Down Expand Up @@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
if is_transposed:
M, K = self.shape[1], self.shape[0]
scale_e4m3 = self._scale_e4m3.t()
else:
M, K = self.shape[0], self.shape[1]
scale_e4m3 = self._scale_e4m3

if self._is_swizzled_scales:
scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
else:
scale_e4m3 = self._scale_e4m3
scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

return (
scale_e4m3.to(self._orig_dtype)
Expand Down Expand Up @@ -369,6 +383,11 @@ def nvfp4_slice(func, types, args, kwargs):

M, K = x.shape[0], x.shape[1]

# The scale manipulations below assume a flattened scale. For now, we
# flatten the scale, go through the calculations below, and then reshape
# it back to the format which matches the shape of `qdata`.
# TODO(future PR): update this

if x._is_swizzled_scales:
scale_rows = M
scale_cols = K // x._block_size
Expand Down Expand Up @@ -407,7 +426,9 @@ def nvfp4_slice(func, types, args, kwargs):
else None
)

sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
sliced_scale = aten.slice.Tensor(
x._scale_e4m3.flatten(), 0, start_idx, end_idx, 1
)
sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)

elif dim == 1:
Expand Down Expand Up @@ -462,7 +483,7 @@ def nvfp4_slice(func, types, args, kwargs):
row_start = row_block * elements_per_row_block
col_start = row_start + start_col_block * elements_per_block
col_end = row_start + end_col_block * elements_per_block
slices_to_extract.append(x._scale_e4m3[col_start:col_end])
slices_to_extract.append(x._scale_e4m3.flatten()[col_start:col_end])

# Concatenate all the slices
sliced_scale = torch.cat(slices_to_extract, dim=0)
Expand Down Expand Up @@ -515,6 +536,19 @@ def nvfp4_slice(func, types, args, kwargs):

sliced_scale = sliced_scale.flatten()

# reshape at the end
sliced_M = sliced_data.shape[0]
# multiply by 2 to convert from bytes to num_elements
sliced_K = sliced_data.shape[1] * 2
if x._is_swizzled_scales:
scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(sliced_M, sliced_K)
else:
# a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
# scale element
scale_M = sliced_M
scale_K = sliced_K // x._block_size
sliced_scale = sliced_scale.view(scale_M, scale_K)

# Create result tensor
result = NVFP4Tensor(
sliced_data,
Expand All @@ -537,7 +571,7 @@ def nvfp4_t(func, types, args, kwargs):
old = args[0]
new = NVFP4Tensor(
old.qdata.t(),
old._scale_e4m3,
old._scale_e4m3.t(),
old._block_size,
old._orig_dtype,
old._per_tensor_scale,
Expand Down Expand Up @@ -576,7 +610,9 @@ def _addmm_nvfp4_dispatch(
The only difference is whether bias is None or not.
"""
assert a.qdata.is_contiguous()
assert a._scale_e4m3.is_contiguous()
assert b.qdata.t().is_contiguous()
assert b._scale_e4m3.t().is_contiguous()
assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"

Expand All @@ -591,9 +627,9 @@ def _addmm_nvfp4_dispatch(
a_scale_blocked = to_blocked(a_scale)

if b._is_swizzled_scales:
b_scale_blocked = b._scale_e4m3 # Already swizzled
b_scale_blocked = b._scale_e4m3.t() # Already swizzled
else:
b_scale = b._scale_e4m3.view(N, K // b._block_size)
b_scale = b._scale_e4m3.t().view(N, K // b._block_size)
b_scale_blocked = to_blocked(b_scale)

# Merge double quant scales into 1 scale for Scale_In^D
Expand Down
18 changes: 18 additions & 0 deletions torchao/prototype/mx_formats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torch.distributed._tensor import DTensor

Expand Down Expand Up @@ -99,6 +101,22 @@ def from_blocked(
return padded[:original_rows, :original_cols]


def hp_data_dims_to_swizzled_scale_dims_nvfp4(
hp_data_M,
hp_data_K,
) -> Tuple[int, int]:
"""
Given the `M` and `K` dimensions of a high precision contiguous tensor,
returns a 2d tuple of the dims of the swizzled nvfp4 scale corresponding to
that tensor.
"""
# a 128x64 unpacked or 128x32 packed qdata tile corresponds
# to a swizzled 32x16 scale tile
scale_M = ceil_div(hp_data_M, 128) * 32
scale_K = ceil_div(hp_data_K, 64) * 16
return scale_M, scale_K


def _to_blocked_single(scales: Tensor) -> Tensor:
"""Assume that we have a 128x4 block of scales in K Major order

Expand Down
Loading