Commit 56bf244

[wip] make scale shape 2d and match qdata shape in NVFP4Tensor
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: dc4970e
ghstack-comment-id: 3357503258
Pull-Request: #3108
1 parent 9368b28 commit 56bf244

File tree

3 files changed: +128 -10 lines changed

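Before the per-file diffs, a minimal sketch of the scale-shape rule this commit moves toward, restated in plain Python from the tile math in the hunks below. The local ceil_div mirrors torchao's helper so the snippet runs standalone, and expected_scale_dims is a hypothetical name used only for illustration.

def ceil_div(a: int, b: int) -> int:
    # smallest integer >= a / b, same behavior as torchao's ceil_div helper
    return (a + b - 1) // b

def expected_scale_dims(M: int, K: int, block_size: int = 16, swizzled: bool = False):
    # non-swizzled: one e4m3 scale per 1x16 high precision block -> (M, K // block_size)
    if not swizzled:
        return M, K // block_size
    # swizzled: each 128x64 high precision tile owns a 32x16 scale tile, padded up to full tiles
    return ceil_div(M, 128) * 32, ceil_div(K, 64) * 16

print(expected_scale_dims(144, 80))                 # (144, 5)
print(expected_scale_dims(144, 80, swizzled=True))  # (64, 32)

With this commit, NVFP4Tensor._scale_e4m3 is intended to carry these 2d dims directly instead of a flattened buffer, so transposes and slices can keep the scale and qdata shapes in sync.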

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 66 additions & 0 deletions

@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -525,3 +526,68 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+# @pytest.mark.parametrize("transpose", [True])
+# @pytest.mark.parametrize("transpose", [False])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+# @pytest.mark.parametrize("use_triton_kernel", [False])
+# @pytest.mark.parametrize("use_triton_kernel", [True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+# @pytest.mark.parametrize("is_swizzled_scales", [False])
+# @pytest.mark.parametrize("is_swizzled_scales", [True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    # TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
+    # to test the padding logic
+    # context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
+        expected_padded_m = ceil_div(orig_m, 128) * 32
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
+    actual_padded_k = x._scale_e4m3.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
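As a worked instance of the assertions above for one parametrization (arithmetic only, no GPU needed; ceil_div is restated locally and matches the helper the test imports), take mk=(128 + 16, 64) with is_swizzled_scales=True and transpose=False:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

M, K, block_size = 128 + 16, 64, 16

# dim 0: rows round up to one 128-row data tile, each contributing 32 scale rows
expected_padded_m = ceil_div(M, 128) * 32              # 2 * 32 = 64
# dim 1: K // block_size = 4 scale columns, rounded up to groups of 4, 16 swizzled columns per group
expected_padded_k = ceil_div(K // block_size, 4) * 16  # 1 * 16 = 16

assert (expected_padded_m, expected_padded_k) == (64, 16)

So for this case the test expects x._scale_e4m3.shape to be (64, 16), padded relative to the unswizzled (144, 4) blockwise scale.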

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 44 additions & 10 deletions

@@ -24,7 +24,11 @@
     tensor_size_fp4x2_to_hp,
     tensor_size_hp_to_fp4x2,
 )
-from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.prototype.mx_formats.utils import (
+    from_blocked,
+    hp_data_dims_to_swizzled_scale_dims_nvfp4,
+    to_blocked,
+)
 from torchao.quantization.quantize_.common import (
     QuantizeTensorKwargs,
 )
@@ -170,6 +174,9 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
+        assert len(data_hp.shape) == 2, "unsupported"
+        M, K = data_hp.shape[0], data_hp.shape[1]
+
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
             assert data_hp.shape[1] % 16 == 0, (
@@ -181,12 +188,19 @@
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
-                M, K = data_hp.shape[0], data_hp.shape[1]
                 scale_shape = (M, K // block_size)
                 blockwise_scales = to_blocked(
                     blockwise_scales.view(scale_shape)
                 ).flatten()

+        if is_swizzled_scales:
+            scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(M, K)
+        else:
+            # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+            # scale element
+            scale_M, scale_K = M, K // block_size
+        blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+
         return NVFP4Tensor(
             data_lp,
             blockwise_scales,
@@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3

         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

         return (
             scale_e4m3.to(self._orig_dtype)
@@ -369,6 +383,9 @@ def nvfp4_slice(func, types, args, kwargs):

     M, K = x.shape[0], x.shape[1]

+    # the scale manipulations below assume a flattened scale
+    # TODO(future or this PR): update this
+
     if x._is_swizzled_scales:
         scale_rows = M
         scale_cols = K // x._block_size
@@ -407,7 +424,9 @@ def nvfp4_slice(func, types, args, kwargs):
                else None
            )

-            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+            sliced_scale = aten.slice.Tensor(
+                x._scale_e4m3.flatten(), 0, start_idx, end_idx, 1
+            )
            sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)

    elif dim == 1:
@@ -462,7 +481,7 @@
                row_start = row_block * elements_per_row_block
                col_start = row_start + start_col_block * elements_per_block
                col_end = row_start + end_col_block * elements_per_block
-                slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+                slices_to_extract.append(x._scale_e4m3.flatten()[col_start:col_end])

            # Concatenate all the slices
            sliced_scale = torch.cat(slices_to_extract, dim=0)
@@ -515,6 +534,19 @@

            sliced_scale = sliced_scale.flatten()

+    # reshape at the end
+    sliced_M = sliced_data.shape[0]
+    # multiply by 2 to convert from bytes to num_elements
+    sliced_K = sliced_data.shape[1] * 2
+    if x._is_swizzled_scales:
+        scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(sliced_M, sliced_K)
+    else:
+        # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+        # scale element
+        scale_M = sliced_M
+        scale_K = sliced_K // x._block_size
+    sliced_scale = sliced_scale.view(scale_M, scale_K)
+
     # Create result tensor
     result = NVFP4Tensor(
         sliced_data,
@@ -537,7 +569,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -576,7 +608,9 @@ def _addmm_nvfp4_dispatch(
     The only difference is whether bias is None or not.
     """
     assert a.qdata.is_contiguous()
+    assert a._scale_e4m3.is_contiguous()
     assert b.qdata.t().is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"

@@ -591,9 +625,9 @@
         a_scale_blocked = to_blocked(a_scale)

     if b._is_swizzled_scales:
-        b_scale_blocked = b._scale_e4m3  # Already swizzled
+        b_scale_blocked = b._scale_e4m3.t()  # Already swizzled
     else:
-        b_scale = b._scale_e4m3.view(N, K // b._block_size)
+        b_scale = b._scale_e4m3.t().view(N, K // b._block_size)
         b_scale_blocked = to_blocked(b_scale)

     # Merge double quant scales into 1 scale for Scale_In^D
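A hedged usage sketch of the call pattern these hunks change. It assumes a CUDA build with the torchao NVFP4 prototype available; the private attribute _scale_e4m3 and the printed shapes are inferred from the diff above and may shift while the PR is still [wip].

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x_hp = torch.randn(256, 128, device="cuda")
x = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=True)

# with this commit the blockwise scale is stored as a 2d tensor rather than a
# flattened buffer, and t() transposes the scale together with qdata
print(x._scale_e4m3.shape)      # expected (64, 32): ceil(256/128)*32 by ceil(128/64)*16
print(x.t()._scale_e4m3.shape)  # expected (32, 64)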

torchao/prototype/mx_formats/utils.py

Lines changed: 18 additions & 0 deletions

@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
 import torch
 from torch.distributed._tensor import DTensor

@@ -99,6 +101,22 @@ def from_blocked(
     return padded[:original_rows, :original_cols]


+def hp_data_dims_to_swizzled_scale_dims_nvfp4(
+    hp_data_M,
+    hp_data_K,
+) -> Tuple[int, int]:
+    """
+    Given the `M` and `K` dimensions of a high precision contiguous tensor,
+    returns a 2d tuple of the dims of the swizzled nvfp4 scale corresponding to
+    that tensor.
+    """
+    # a 128x64 unpacked or 128x64 packed qdata tile corresponds
+    # to a swizzled 32x16 scale tile
+    scale_M = ceil_div(hp_data_M, 128) * 32
+    scale_K = ceil_div(hp_data_K, 64) * 16
+    return scale_M, scale_K
+
+
 def _to_blocked_single(scales: Tensor) -> Tensor:
     """Assume that we have a 128x4 block of scales in K Major order
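A standalone sanity check tying the new helper to the flattened output of to_blocked (plain Python; the assumption that to_blocked pads the blockwise scale matrix to 128-row by 4-column tiles comes from the cuBLAS block-scaling layout linked in the test above, not from this diff):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def hp_data_dims_to_swizzled_scale_dims_nvfp4(hp_data_M: int, hp_data_K: int):
    # mirrors the new helper above: each 128x64 high precision tile owns a 32x16 scale tile
    return ceil_div(hp_data_M, 128) * 32, ceil_div(hp_data_K, 64) * 16

def swizzled_scale_numel(M: int, K: int, block_size: int = 16) -> int:
    # assumed element count of the flattened swizzled scale: the (M, K // block_size)
    # scale matrix padded to 128-row and 4-column tiles
    return ceil_div(M, 128) * 128 * ceil_div(K // block_size, 4) * 4

for M, K in [(128, 64), (144, 64), (128, 80), (144, 80)]:
    scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(M, K)
    assert scale_M * scale_K == swizzled_scale_numel(M, K), (M, K)

The element-count match is what allows to_nvfp4 to view the flattened to_blocked output as (scale_M, scale_K) without copying.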