2424 tensor_size_fp4x2_to_hp ,
2525 tensor_size_hp_to_fp4x2 ,
2626)
27- from torchao .prototype .mx_formats .utils import from_blocked , to_blocked
27+ from torchao .prototype .mx_formats .utils import (
28+ from_blocked ,
29+ hp_data_dims_to_swizzled_scale_dims_nvfp4 ,
30+ to_blocked ,
31+ )
2832from torchao .quantization .quantize_ .common import (
2933 QuantizeTensorKwargs ,
3034)
@@ -170,6 +174,9 @@ def to_nvfp4(
170174 Returns:
171175 NVFP4Tensor: Quantized tensor in NVFP4 format
172176 """
177+ assert len (data_hp .shape ) == 2 , "unsupported"
178+ M , K = data_hp .shape [0 ], data_hp .shape [1 ]
179+
173180 if use_triton_kernel :
174181 assert is_swizzled_scales , "Triton kernel only supports swizzled scales"
175182 assert data_hp .shape [1 ] % 16 == 0 , (
@@ -181,12 +188,19 @@ def to_nvfp4(
181188 data_hp , block_size , per_tensor_scale
182189 )
183190 if is_swizzled_scales :
184- M , K = data_hp .shape [0 ], data_hp .shape [1 ]
185191 scale_shape = (M , K // block_size )
186192 blockwise_scales = to_blocked (
187193 blockwise_scales .view (scale_shape )
188194 ).flatten ()
189195
196+ if is_swizzled_scales :
197+ scale_M , scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4 (M , K )
198+ else :
199+ # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
200+ # scale element
201+ scale_M , scale_K = M , K // block_size
202+ blockwise_scales = blockwise_scales .view (scale_M , scale_K )
203+
190204 return NVFP4Tensor (
191205 data_lp ,
192206 blockwise_scales ,
@@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
239253 is_transposed = self .qdata .stride (0 ) < self .qdata .stride (1 )
240254 if is_transposed :
241255 M , K = self .shape [1 ], self .shape [0 ]
256+ scale_e4m3 = self ._scale_e4m3 .t ()
242257 else :
243258 M , K = self .shape [0 ], self .shape [1 ]
259+ scale_e4m3 = self ._scale_e4m3
244260
245261 if self ._is_swizzled_scales :
246- scale_e4m3 = from_blocked (self ._scale_e4m3 , M , K // self ._block_size )
247- else :
248- scale_e4m3 = self ._scale_e4m3
262+ scale_e4m3 = from_blocked (scale_e4m3 , M , K // self ._block_size )
249263
250264 return (
251265 scale_e4m3 .to (self ._orig_dtype )
@@ -369,6 +383,9 @@ def nvfp4_slice(func, types, args, kwargs):
369383
370384 M , K = x .shape [0 ], x .shape [1 ]
371385
386+ # the scale manipulations below assume a flattened scale
387+ # TODO(future or this PR): update this
388+
372389 if x ._is_swizzled_scales :
373390 scale_rows = M
374391 scale_cols = K // x ._block_size
@@ -407,7 +424,9 @@ def nvfp4_slice(func, types, args, kwargs):
407424 else None
408425 )
409426
410- sliced_scale = aten .slice .Tensor (x ._scale_e4m3 , 0 , start_idx , end_idx , 1 )
427+ sliced_scale = aten .slice .Tensor (
428+ x ._scale_e4m3 .flatten (), 0 , start_idx , end_idx , 1
429+ )
411430 sliced_data = aten .slice .Tensor (x .qdata , 0 , start , end , step )
412431
413432 elif dim == 1 :
@@ -462,7 +481,7 @@ def nvfp4_slice(func, types, args, kwargs):
462481 row_start = row_block * elements_per_row_block
463482 col_start = row_start + start_col_block * elements_per_block
464483 col_end = row_start + end_col_block * elements_per_block
465- slices_to_extract .append (x ._scale_e4m3 [col_start :col_end ])
484+ slices_to_extract .append (x ._scale_e4m3 . flatten () [col_start :col_end ])
466485
467486 # Concatenate all the slices
468487 sliced_scale = torch .cat (slices_to_extract , dim = 0 )
@@ -515,6 +534,19 @@ def nvfp4_slice(func, types, args, kwargs):
515534
516535 sliced_scale = sliced_scale .flatten ()
517536
537+ # reshape at the end
538+ sliced_M = sliced_data .shape [0 ]
539+ # multiply by 2 to convert from bytes to num_elements
540+ sliced_K = sliced_data .shape [1 ] * 2
541+ if x ._is_swizzled_scales :
542+ scale_M , scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4 (sliced_M , sliced_K )
543+ else :
544+ # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
545+ # scale element
546+ scale_M = sliced_M
547+ scale_K = sliced_K // x ._block_size
548+ sliced_scale = sliced_scale .view (scale_M , scale_K )
549+
518550 # Create result tensor
519551 result = NVFP4Tensor (
520552 sliced_data ,
@@ -537,7 +569,7 @@ def nvfp4_t(func, types, args, kwargs):
537569 old = args [0 ]
538570 new = NVFP4Tensor (
539571 old .qdata .t (),
540- old ._scale_e4m3 ,
572+ old ._scale_e4m3 . t () ,
541573 old ._block_size ,
542574 old ._orig_dtype ,
543575 old ._per_tensor_scale ,
@@ -576,7 +608,9 @@ def _addmm_nvfp4_dispatch(
576608 The only difference is whether bias is None or not.
577609 """
578610 assert a .qdata .is_contiguous ()
611+ assert a ._scale_e4m3 .is_contiguous ()
579612 assert b .qdata .t ().is_contiguous ()
613+ assert b ._scale_e4m3 .t ().is_contiguous ()
580614 assert a ._block_size == 16 , f"NVFP4 requires block_size=16, got { a ._block_size } "
581615 assert b ._block_size == 16 , f"NVFP4 requires block_size=16, got { b ._block_size } "
582616
@@ -591,9 +625,9 @@ def _addmm_nvfp4_dispatch(
591625 a_scale_blocked = to_blocked (a_scale )
592626
593627 if b ._is_swizzled_scales :
594- b_scale_blocked = b ._scale_e4m3 # Already swizzled
628+ b_scale_blocked = b ._scale_e4m3 . t () # Already swizzled
595629 else :
596- b_scale = b ._scale_e4m3 .view (N , K // b ._block_size )
630+ b_scale = b ._scale_e4m3 .t (). view (N , K // b ._block_size )
597631 b_scale_blocked = to_blocked (b_scale )
598632
599633 # Merge double quant scales into 1 scale for Scale_In^D
0 commit comments