124 changes: 114 additions & 10 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -21,6 +21,7 @@
per_tensor_amax_to_scale,
unpack_uint4,
)
from torchao.prototype.mx_formats.utils import ceil_div
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
@@ -127,18 +128,18 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
"slice_dim,slice_spec",
[
# Row slicing - must align with 128-row boundaries
pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
# pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
# pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
# Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
# pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
# pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
# Test tensor parallelism patterns (half splits)
pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
# pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
# pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
# Test quarter splits
pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
# pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
# pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -157,21 +158,54 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
M, K = 256, 4096
else:
# For column slicing, need multiples of 64 columns for alignment
M, K = 128, 4096
# M, K = 128, 4096
M, K = 128, 64 * 2

data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
# tensor.to_dtype(torch.bfloat16)
assert tensor._is_swizzled_scales == True

print(
"before",
tensor.shape,
tensor.qdata.shape,
tensor._scale_e4m3.shape,
)
# print(tensor._scale_e4m3[0:128, 0:4])
if slice_dim == 0:
sliced_tensor = tensor[slice_spec, :]
else:
sliced_tensor = tensor[:, slice_spec]
print(
"after",
sliced_tensor.shape,
sliced_tensor.qdata.shape,
sliced_tensor._scale_e4m3.shape,
)
# print(sliced_tensor._scale_e4m3[0:128, 0:4])
# print(sliced_tensor.qdata.float() - tensor.qdata[0:128, 0:32].float())
# print(sliced_tensor._scale_e4m3.float() - tensor._scale_e4m3[0:128, 0:4].float())

# Verify sliced tensor maintains swizzled state
assert sliced_tensor._is_swizzled_scales == True

# this matches sliced_reconstructed, but not original_reconstructed[:, slice_spec]
if False:
sliced_manually = NVFP4Tensor(
tensor.qdata[:, 0:32],
tensor._scale_e4m3[:, 0:4].contiguous(),
tensor._block_size,
tensor._orig_dtype,
tensor._per_tensor_scale,
tensor._act_per_tensor_scale,
tensor._is_swizzled_scales,
tensor.use_triton_kernel,
tensor.act_quant_kwargs,
)
import pdb; pdb.set_trace()

# Verify sliced tensor can be dequantized
sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)

@@ -181,6 +215,11 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
expected = original_reconstructed[slice_spec, :]
else:
expected = original_reconstructed[:, slice_spec]
print('e', expected)
print('s', sliced_reconstructed)
print('e - s', expected - sliced_reconstructed)
print(1, expected.abs().sum())
print(2, sliced_reconstructed.abs().sum())

torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)

@@ -421,7 +460,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
@pytest.mark.parametrize("compile", [False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("use_triton_kernel", [True, False])
# @pytest.mark.parametrize("use_triton_kernel", [True, False])
@pytest.mark.parametrize("use_triton_kernel", [False])
@pytest.mark.parametrize(
"shapes",
[
@@ -525,3 +565,67 @@ def test_nvfp4_to_copy():
assert x.act_quant_kwargs == y.act_quant_kwargs
assert x.dtype == torch.float32
assert y.dtype == torch.bfloat16


@pytest.mark.parametrize("transpose", [False, True])
# @pytest.mark.parametrize("transpose", [True])
# @pytest.mark.parametrize("transpose", [False])
@pytest.mark.parametrize("use_triton_kernel", [False, True])
# @pytest.mark.parametrize("use_triton_kernel", [False])
# @pytest.mark.parametrize("use_triton_kernel", [True])
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
# @pytest.mark.parametrize("is_swizzled_scales", [True])
@pytest.mark.parametrize(
"mk",
(
(128, 64),
(128 + 16, 64),
(128, 64 + 16),
(128 + 16, 64 + 16),
),
)
# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
def test_scale_shape_matches_qdata(
transpose, use_triton_kernel, is_swizzled_scales, mk
):
if use_triton_kernel and not is_swizzled_scales:
pytest.skip("triton kernel requires swizzled scales")

M, K = mk

block_size = 16

# TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
# to test the padding logic
# context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
x_hp = torch.randn(M, K, device="cuda")
x = NVFP4Tensor.to_nvfp4(
x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
)

m_dim, k_dim = 0, 1
if transpose:
x_hp = x_hp.t()
x = x.t()
m_dim, k_dim = 1, 0

orig_m = x_hp.shape[m_dim]
expected_padded_m = orig_m
if is_swizzled_scales:
# in swizzled nvfp4, a 128x64 unpacked (128x32 packed) data tile maps to a 32x16 scale tile
expected_padded_m = ceil_div(orig_m, 128) * 32
actual_padded_m = x._scale_e4m3.shape[m_dim]
assert expected_padded_m == actual_padded_m, (
f"incompatible padded shape for dim {m_dim}: {x.shape} and {x._scale_e4m3.shape}"
)

orig_k = x_hp.shape[k_dim]
expected_padded_k = orig_k // block_size
if is_swizzled_scales:
# in swizzled nvfp4, a 128x64 unpacked (128x32 packed) data tile maps to a 32x16 scale tile
expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
actual_padded_k = x._scale_e4m3.shape[k_dim]

assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {x.shape} and {x._scale_e4m3.shape}"
)
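
As a quick cross-check of the shape arithmetic asserted by the new test_scale_shape_matches_qdata above, here is a minimal standalone sketch. It re-derives ceil_div locally so the snippet runs without torchao; the 128-row and 4-scale-column tile constants follow the cuBLAS block-scaling layout linked in the test, and swizzled_scale_shape is a hypothetical helper name used only for this illustration.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def swizzled_scale_shape(M: int, K: int, block_size: int = 16) -> tuple[int, int]:
    # each 128 data rows contribute 32 swizzled scale rows
    scale_m = ceil_div(M, 128) * 32
    # each group of 4 scale columns (64 unpacked data columns) pads to 16
    scale_k = ceil_div(K // block_size, 4) * 16
    return scale_m, scale_k

# shapes taken from the mk parametrization above
assert swizzled_scale_shape(128, 64) == (32, 16)
assert swizzled_scale_shape(128 + 16, 64) == (64, 16)
assert swizzled_scale_shape(128, 64 + 16) == (32, 32)
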
114 changes: 90 additions & 24 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -176,16 +176,31 @@ def to_nvfp4(
f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
)
blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)

# TODO(before land): share code for scale shape manipulation in the two
# if branches
scale_M = ceil_div(data_hp.shape[0], 128) * 32
scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
blockwise_scales = blockwise_scales.view(scale_M, scale_K)
else:
blockwise_scales, data_lp = nvfp4_quantize(
data_hp, block_size, per_tensor_scale
)
if is_swizzled_scales:
M, K = data_hp.shape[0], data_hp.shape[1]
scale_shape = (M, K // block_size)
blockwise_scales = to_blocked(
blockwise_scales.view(scale_shape)
).flatten()
# print(1, blockwise_scales.shape)
blockwise_scales = blockwise_scales.view(scale_shape)
# print(2, blockwise_scales.shape, blockwise_scales)
blockwise_scales = to_blocked(blockwise_scales)
print(3, blockwise_scales.shape, blockwise_scales)

# match shape of data_hp
# in swizzled nvfp4, a 128x64 unpacked (128x32 packed) data tile maps to a 32x16 scale tile
scale_M = ceil_div(data_hp.shape[0], 128) * 32
scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
blockwise_scales = blockwise_scales.view(scale_M, scale_K)
print(4, blockwise_scales.shape, blockwise_scales)

return NVFP4Tensor(
data_lp,
@@ -212,6 +227,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
torch.Tensor: Dequantized tensor in the target dtype
"""
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
print('is_transposed', is_transposed)
if is_transposed:
M, K = self.shape[1], self.shape[0]
else:
@@ -220,8 +236,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
data_f32 = f4_unpacked_to_f32(data_unpacked)

# next: debug scale shape here
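# (each row of K fp4 values is grouped into K // block_size blocks of
# block_size, so a single e4m3 scale broadcasts over each block in the
# multiply below)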
data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
scales = self.get_hp_scales()
scales_tmp = scales.reshape(32, -1)
print('scales', scales_tmp.shape, scales_tmp[0:8])
scale_e4m3_reshaped = scales.view(M, K // self._block_size, 1)
data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
result = data_scaled.view(M, K).to(target_dtype)

@@ -237,15 +257,17 @@ def get_hp_scales(self) -> torch.Tensor:
torch.Tensor: Scales of the NVFP4Tensor
"""
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
print("is_transposed", is_transposed)
if is_transposed:
M, K = self.shape[1], self.shape[0]
scale_e4m3 = self._scale_e4m3.t()
else:
M, K = self.shape[0], self.shape[1]
scale_e4m3 = self._scale_e4m3

if self._is_swizzled_scales:
scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
else:
scale_e4m3 = self._scale_e4m3
# import pdb; pdb.set_trace()
scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

return (
scale_e4m3.to(self._orig_dtype)
@@ -366,6 +388,7 @@ def nvfp4_slice(func, types, args, kwargs):
raise ValueError("Only support aten.slice with step=1")

assert x.qdata.is_contiguous(), "Only support contiguous data for now"
assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"

M, K = x.shape[0], x.shape[1]

@@ -376,6 +399,22 @@
n_col_blocks = ceil_div(scale_cols, 4)
elements_per_block = 32 * 16 # 512 elements

#
# See https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
# for qdata vs scale layout. Here is a summary specific to nvfp4:
#
# 1. qdata tile shape is (128, 32) packed fp4, which is (128, 64) unpacked fp4
# 2. scale tile shape is (32, 16)
# 3. correspondence of qdata vs scale tiles is as follows, in a 2 by 2 tile example
#
# | tile_idx | qdata_rows | qdata_cols | scale_rows | scale_cols |
# ----------------------------------------------------------------
# | 0 | 0:127 | 0:31 | 0:31 | 0:15 |
# | 1 | 128:255 | 0:31 | 32:63 | 0:15 |
# | 2 | 0:127 | 32:63 | 0:31 | 16:31 |
# | 3 | 128:255 | 32:63 | 32:63 | 16:31 |
#
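#
# As a hedged sketch of the index arithmetic this table implies (for
# illustration only, not used by the code below): data row r lives in row
# tile r // 128 and owns swizzled scale rows r // 128 * 32 to
# r // 128 * 32 + 31; unpacked data column c lives in column tile c // 64
# and owns swizzled scale columns c // 64 * 16 to c // 64 * 16 + 15.
#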

if dim == 0:
# Row slicing
# Handle sys.maxsize (default slice end)
@@ -407,7 +446,9 @@
else None
)

sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
# TODO(before land): this is wrong, it works but need to express in terms of
# properly laid out scale as in the comment block above
sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)

elif dim == 1:
@@ -452,20 +493,43 @@
# Full width - no slicing needed
sliced_scale = x._scale_e4m3
else:
# Extract specific column blocks from each row block
# Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
elements_per_row_block = n_col_blocks * elements_per_block

# Build list of slices to extract
slices_to_extract = []
for row_block in range(n_row_blocks):
row_start = row_block * elements_per_row_block
col_start = row_start + start_col_block * elements_per_block
col_end = row_start + end_col_block * elements_per_block
slices_to_extract.append(x._scale_e4m3[col_start:col_end])

# Concatenate all the slices
sliced_scale = torch.cat(slices_to_extract, dim=0)

scale = x._scale_e4m3

# reshape to expected shape
# TODO(before land): do this when swizzling so we don't have to do it here

# TODO(before land): comment the mul by 2 here
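# (each 128-row block of data contributes 32 = 16 * 2 swizzled scale rows,
# which is where the "* 2" below comes from)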
scale_rows = n_row_blocks * 16 * 2
scale_cols = n_col_blocks * 16
scale = scale.view(scale_rows, scale_cols)

# convert from hp_tensor col to scale col
start_scale_col = 0 if start is None else (start // 128 * 16)
end_scale_col = scale_cols if end is None or end >= K else (end // 16 * 4)
# import pdb; pdb.set_trace()

if False:
# Extract specific column blocks from each row block
# Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
elements_per_row_block = n_col_blocks * elements_per_block

# Build list of slices to extract
slices_to_extract = []
for row_block in range(n_row_blocks):
row_start = row_block * elements_per_row_block
col_start = row_start + start_col_block * elements_per_block
col_end = row_start + end_col_block * elements_per_block
slices_to_extract.append(x._scale_e4m3[col_start:col_end])

# Concatenate all the slices
sliced_scale = torch.cat(slices_to_extract, dim=0)
# import pdb; pdb.set_trace()
sliced_scale = aten.slice.Tensor(
# x._scale_e4m3, dim, start_scale_col, end_scale_col, step
scale, dim, start_scale_col, end_scale_col, step
).contiguous()
# import pdb; pdb.set_trace()

# Slice the data tensor
packed_start = None if start is None else start // 2
@@ -537,7 +601,7 @@ def nvfp4_t(func, types, args, kwargs):
old = args[0]
new = NVFP4Tensor(
old.qdata.t(),
old._scale_e4m3,
old._scale_e4m3.t(),
old._block_size,
old._orig_dtype,
old._per_tensor_scale,
@@ -577,6 +641,8 @@ def _addmm_nvfp4_dispatch(
"""
assert a.qdata.is_contiguous()
assert b.qdata.t().is_contiguous()
assert a._scale_e4m3.is_contiguous()
assert b._scale_e4m3.t().is_contiguous()
assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"

@@ -615,7 +681,7 @@ def _addmm_nvfp4_dispatch(
a.qdata.view(torch.float4_e2m1fn_x2),
b.qdata.view(torch.float4_e2m1fn_x2),
a_scale_blocked.view(torch.float8_e4m3fn),
b_scale_blocked.view(torch.float8_e4m3fn),
b_scale_blocked.t().view(torch.float8_e4m3fn),
bias=None if should_add_bias_separately else bias,
out_dtype=a._orig_dtype,
# scale_result=scale_result, # Not supported yet
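
For reference, a hedged end-to-end sketch of the invariants this PR is driving toward, assuming a CUDA device and the WIP behavior of this branch (2D swizzled scales that transpose together with qdata); it pokes at the private _scale_e4m3 attribute, so treat it as illustration rather than a stable API:

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x_hp = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
x = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=True)

# a (128, 64) input maps to one (32, 16) swizzled scale tile
assert x._scale_e4m3.shape == (32, 16)

# nvfp4_t now transposes the scales together with the qdata
assert x.t()._scale_e4m3.shape == (16, 32)

# round-trip through bfloat16 as a quick sanity check
assert x.to_dtype(torch.bfloat16).shape == x_hp.shape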