
Commit 4833bde

[wip] enable 3d weights for NVFP4Tensor

Summary: doesn't work yet, stay tuned. This is needed for vLLM, which stitches 2d weights into a 3d weight for MoEs.

ghstack-source-id: a69b96d
ghstack-comment-id: 3357908175
Pull-Request: #3109

1 parent 8955739 commit 4833bde
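
For context, a minimal sketch of the weight stitching that motivates this change (illustrative only, not vLLM's actual loading code): each expert's 2d (out_features, in_features) weight is copied into one preallocated 3d (num_experts, out_features, in_features) parameter, and that 3d parameter is what later gets quantized, hence the need for rank-3 support in NVFP4Tensor.

import torch

# hypothetical MoE shapes, mirroring the new test in torchao/testing/utils.py
num_experts, out_features, in_features = 60, 2816, 2048
w3d = torch.empty(num_experts, out_features, in_features, dtype=torch.bfloat16)

for i in range(num_experts):
    # stand-in for one expert's 2d checkpoint shard
    w2d = torch.randn(out_features, in_features, dtype=torch.bfloat16)
    # stitch the 2d expert weight into one slice of the 3d parameter
    w3d[i].copy_(w2d)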

7 files changed: +141 −40 lines


test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 8 additions & 0 deletions

@@ -218,3 +218,11 @@ def test_narrow_similar_to_vllm(self):
             gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
         )
         self._test_narrow_similar_to_vllm(config)
+
+    def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.WEIGHT_ONLY,
+            use_triton_kernel=False,
+            use_dynamic_per_tensor_scale=False,
+        )
+        self._test_quantize_3d_param_similar_to_vllm(config)

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 45 additions & 16 deletions

@@ -42,6 +42,7 @@
         (torch.float32, (64, 128), False),
         (torch.bfloat16, (128, 256), False),
         (torch.bfloat16, (64, 128), True),
+        (torch.bfloat16, (1, 32, 64), False),
     ],
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -83,14 +84,21 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
     )
 
-    x_nvfp4_t = x_nvfp4.t()
+    if len(x.shape) == 2:
+        x_nvfp4_t = x_nvfp4.t()
+        x_t = x.t()
+    else:
+        # TODO(before land): also test transpose dims (1, 2), (2, 1), (-1, -2)
+        x_nvfp4_t = x_nvfp4.transpose(-2, -1)
+        x_t = x.transpose(-2, -1)
+
     x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
-    assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
+    assert_sqnr_gt_threshold(x_t, x_reconstructed_t, 8.0)
 
-    assert x.t().shape == x_reconstructed_t.shape, (
+    assert x_t.shape == x_reconstructed_t.shape, (
         f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
     )
-    assert x.t().dtype == x_reconstructed_t.dtype, (
+    assert x_t.dtype == x_reconstructed_t.dtype, (
         f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
     )
@@ -103,6 +111,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         (16, 32),
         (64, 128),
         (384, 128),
+        (1, 32, 64),
     ],
 )
 @pytest.mark.skipif(
@@ -115,8 +124,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     that the _is_swizzled_scales flag is set correctly.
     """
 
-    M, K = shape
-    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
 
     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
     assert tensor._is_swizzled_scales == is_swizzled_scales
@@ -536,36 +544,43 @@ def test_nvfp4_to_copy():
 @pytest.mark.parametrize("use_triton_kernel", [False, True])
 @pytest.mark.parametrize("is_swizzled_scales", [False, True])
 @pytest.mark.parametrize(
-    "mk",
+    "shape",
     (
         (128, 64),
         (128 + 16, 64),
         (128, 64 + 16),
         (128 + 16, 64 + 16),
+        (1, 128, 64),
     ),
 )
 def test_scale_shape_matches_qdata(
-    transpose, use_triton_kernel, is_swizzled_scales, mk
+    transpose, use_triton_kernel, is_swizzled_scales, shape
 ):
     if use_triton_kernel and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
     if use_triton_kernel and not is_swizzled_scales:
         pytest.skip("triton kernel requires swizzled scales")
 
-    M, K = mk
-
     block_size = 16
 
-    x_hp = torch.randn(M, K, device="cuda")
+    x_hp = torch.randn(*shape, device="cuda")
     x = NVFP4Tensor.to_nvfp4(
         x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
     )
 
-    m_dim, k_dim = 0, 1
-    if transpose:
-        x_hp = x_hp.t()
-        x = x.t()
-        m_dim, k_dim = 1, 0
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1
 
     orig_m = x_hp.shape[m_dim]
     expected_padded_m = orig_m
@@ -587,3 +602,17 @@ def test_scale_shape_matches_qdata(
     assert expected_padded_k == actual_padded_k, (
         f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("dims", ((1, 2), (2, 1), (-1, -2), (-2, -1)))
+@pytest.mark.parametrize("is_swizzled_scales", [True, False])
+def test_3d_transpose(dims, is_swizzled_scales):
+    x_hp = torch.randn(2, 128, 256, device="cuda")
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=is_swizzled_scales)
+    x_hp_t = x_hp.transpose(dims[0], dims[1])
+    x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
+    assert x_hp_t.shape == x_nvfp4_t.shape
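
For reference, the assert_sqnr_gt_threshold checks above compare the original tensor against its dequantized reconstruction using signal-to-quantization-noise ratio. A minimal sketch of the standard definition (torchao has its own helper for this; the code below is illustrative):

import torch

def sqnr_db(orig: torch.Tensor, reconstructed: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 10 * log10(||x||^2 / ||x - x_hat||^2)
    signal_power = orig.float().pow(2).mean()
    noise_power = (orig.float() - reconstructed.float()).pow(2).mean()
    return 10 * torch.log10(signal_power / noise_power)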

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 2 additions & 2 deletions

@@ -168,9 +168,9 @@ def _nvfp4_inference_linear_transform(
 
     weight = module.weight
 
-    if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+    if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
         raise RuntimeError(
-            f"NVFP4 only supports weight shape divisible by 16, got {weight.shape}"
+            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
         )
 
     if module.bias is not None and weight.dtype == torch.float32:
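
Indexing from the end lets the same divisibility check cover both rank-2 linear weights and rank-3 MoE weights; a quick illustration of the invariant (shapes are hypothetical):

import torch

w2d = torch.empty(2816, 2048)
w3d = torch.empty(60, 2816, 2048)
for w in (w2d, w3d):
    # the last two dims are the (out, in) dims for both ranks
    assert w.shape[-2] % 16 == 0 and w.shape[-1] % 16 == 0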

torchao/prototype/mx_formats/kernels.py

Lines changed: 8 additions & 0 deletions

@@ -1391,6 +1391,10 @@ def triton_quantize_nvfp4(
     Since VLLM does not use dyanmo guards we need to make this a custom op
     to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
+    # reshape to 2d
+    orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
+    x = x.reshape(-1, orig_N)
+
     M, N = x.shape
     # assert M % 128 == 0 and N % 64 == 0
     assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
@@ -1431,6 +1435,10 @@ def triton_quantize_nvfp4(
         MASK_SCALES=MASK_SCALES,
     )
 
+    # reshape back to original shape
+    scales = scales.view(*orig_leading_dims, -1, padded_cols)
+    xq = xq.view(*orig_leading_dims, -1, N // 2)
+
     return scales, xq.view(torch.uint8)
 
 @triton_quantize_nvfp4.register_fake
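
The change above follows the usual recipe for making a 2d kernel batch-friendly: fold all leading dims into the row dim, run the kernel, then view the outputs back to the original rank. A self-contained sketch of just the shape bookkeeping, with a no-op standing in for the triton kernel:

import torch

def run_2d_kernel_on_rank_n(x: torch.Tensor) -> torch.Tensor:
    leading_dims, n = x.shape[:-2], x.shape[-1]
    x2d = x.reshape(-1, n)  # collapse leading dims into rows
    out2d = x2d.clone()  # stand-in for the real 2d kernel
    return out2d.view(*leading_dims, -1, n)  # restore the original rank

x = torch.randn(60, 2816, 2048)
assert run_2d_kernel_on_rank_n(x).shape == x.shape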

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 14 additions & 2 deletions

@@ -426,7 +426,12 @@ def tensor_size_hp_to_fp4x2(orig_size, is_contiguous):
     if is_contiguous:
         new_size = [*list(new_size[:-1]), new_size[-1] // 2]
     else:
-        new_size = [new_size[0] // 2, *list(new_size[1:])]
+        if len(orig_size) == 2:
+            new_size = [new_size[0] // 2, *list(new_size[1:])]
+        else:
+            assert len(orig_size) == 3, "unsupported"
+            # only supporting dim0, dim1, dim2 and dim0, dim2, dim1 orders
+            new_size = [new_size[0], new_size[2] // 2, new_size[1]]
     return new_size
 
 
@@ -435,10 +440,16 @@ def tensor_size_fp4x2_to_hp(orig_size, is_contiguous):
     if is_contiguous:
         new_size = [*list(new_size[:-1]), new_size[-1] * 2]
     else:
-        new_size = [new_size[0] * 2, *list(new_size[1:])]
+        if len(orig_size) == 2:
+            new_size = [new_size[0] * 2, *list(new_size[1:])]
+        else:
+            assert len(orig_size) == 3, "unsupported"
+            # only supporting dim0, dim1, dim2 and dim0, dim2, dim1 orders
+            new_size = [new_size[0], new_size[2] * 2, new_size[1]]
     return new_size
 
 
+# TODO(future PR): fix this function for rank 3 and add tests
 def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous):
     new_size = orig_size
     if is_contiguous:
@@ -448,6 +459,7 @@ def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous):
     return new_size
 
 
+# TODO(future PR): fix this function for rank 3 and add tests
 def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     new_size = orig_size
     if is_contiguous:
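
For the contiguous case these size helpers are already rank-agnostic: packing two fp4 values per byte halves the innermost dim, and unpacking doubles it back. A minimal sketch of the round-trip property the new non-contiguous rank-3 branches will also need to satisfy (the commit summary notes the rank-3 path does not fully work yet):

def hp_to_fp4x2_contiguous(size):
    # two fp4 elements pack into one byte along the last dim
    return [*size[:-1], size[-1] // 2]

def fp4x2_to_hp_contiguous(size):
    # inverse mapping: one byte unpacks into two fp4 elements
    return [*size[:-1], size[-1] * 2]

assert fp4x2_to_hp_contiguous(hp_to_fp4x2_contiguous([2, 128, 256])) == [2, 128, 256]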

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 52 additions & 20 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import sys
 from dataclasses import dataclass
 from enum import Enum
@@ -112,7 +113,7 @@ def __new__(
 
         new_size = tensor_size_fp4x2_to_hp(
             new_size,
-            qdata.stride(0) > qdata.stride(1),
+            qdata.stride(-2) > qdata.stride(-1),
         )
 
         self = torch.Tensor._make_wrapper_subclass(
@@ -174,21 +175,21 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
-        assert len(data_hp.shape) == 2, "unsupported"
-        M, K = data_hp.shape[0], data_hp.shape[1]
+        assert len(data_hp.shape) in (2, 3), "unsupported"
+        leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
 
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
-            assert data_hp.shape[1] % 16 == 0, (
-                f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
+            assert K % 16 == 0, (
+                f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
             )
         if is_swizzled_scales:
-            scale_shape = (M, K // block_size)
+            scale_shape = (math.prod(leading_dims) * M, K // block_size)
             blockwise_scales = to_blocked(
                 blockwise_scales.view(scale_shape)
             ).flatten()
@@ -199,7 +200,7 @@ def to_nvfp4(
             # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
             # scale element
             scale_M, scale_K = M, K // block_size
-            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            blockwise_scales = blockwise_scales.view(*leading_dims, scale_M, scale_K)
 
         return NVFP4Tensor(
             data_lp,
@@ -225,22 +226,26 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
         else:
-            M, K = self.shape[0], self.shape[1]
-        data = self.qdata.t() if is_transposed else self.qdata
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
+        data = self.qdata.transpose(-2, -1) if is_transposed else self.qdata
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
-        data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        data_f32 = data_f32.view(
+            *leading_dims, M, K // self._block_size, self._block_size
+        )
+        scale_e4m3_reshaped = self.get_hp_scales().view(
+            *leading_dims, M, K // self._block_size, 1
+        )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
 
         if is_transposed:
-            result = result.t()
+            result = result.transpose(-2, -1)
 
         return result
 
@@ -250,16 +255,18 @@ def get_hp_scales(self) -> torch.Tensor:
         Returns:
             torch.Tensor: Scales of the NVFP4Tensor
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
        if is_transposed:
-            M, K = self.shape[1], self.shape[0]
-            scale_e4m3 = self._scale_e4m3.t()
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
+            scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
         else:
-            M, K = self.shape[0], self.shape[1]
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
             scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
+            scale_e4m3 = from_blocked(
+                scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
+            )
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -380,6 +387,9 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert len(x.shape) == 2, (
+        f"only rank 2 is supported for slice, got rank {len(x.shape)}"
+    )
 
     M, K = x.shape[0], x.shape[1]
 
@@ -583,6 +593,28 @@ def nvfp4_t(func, types, args, kwargs):
     return new
 
 
+@implements([aten.transpose.int])
+def nvfp4_transpose(func, types, args, kwargs):
+    old, dim0, dim1 = args
+    assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"
+    valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
+    assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
+    new_qdata = func(old.qdata, dim0, dim1, **kwargs)
+    new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
+    new = NVFP4Tensor(
+        new_qdata,
+        new_scale,
+        old._block_size,
+        old._orig_dtype,
+        old._per_tensor_scale,
+        old._act_per_tensor_scale,
+        old._is_swizzled_scales,
+        old.use_triton_kernel,
+        old.act_quant_kwargs,
+    )
+    return new
+
+
 @implements([aten.view.default])
 def nvfp4_view_op(func, types, args, kwargs):
     data = args[0].qdata
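
Note the swizzled-scales bookkeeping in to_nvfp4 and get_hp_scales above: for rank 3, the leading dims are folded into the row count, so a (B, M, K) input is swizzled as a single (B * M, K // block_size) scale grid. A shapes-only sketch of that arithmetic, assuming the 1x16 block size used throughout:

import math
import torch

B, M, K, block_size = 2, 128, 256, 16
x = torch.randn(B, M, K)
leading_dims = x.shape[:-2]

# one e4m3 scale per 1x16 tile along the last dim, with leading dims
# flattened into the row dimension before swizzling
scale_shape = (math.prod(leading_dims) * M, K // block_size)
assert scale_shape == (256, 16)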

torchao/testing/utils.py

Lines changed: 12 additions & 0 deletions

@@ -625,6 +625,18 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
             f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}"
         )
 
+    def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
+        # this happens when vLLM loads empty MoE weights and quantizes
+        # them
+
+        dtype = torch.bfloat16
+        with torch.device("meta"):
+            l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        l.weight = torch.nn.Parameter(
+            torch.randn(60, 2816, 2048, device="cuda", dtype=dtype)
+        )
+        quantize_(l, config)
+
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
