@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import sys
 from dataclasses import dataclass
 from enum import Enum
@@ -112,7 +113,7 @@ def __new__(
 
         new_size = tensor_size_fp4x2_to_hp(
             new_size,
-            qdata.stride(0) > qdata.stride(1),
+            qdata.stride(-2) > qdata.stride(-1),
         )
 
         self = torch.Tensor._make_wrapper_subclass(
@@ -174,21 +175,21 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
-        assert len(data_hp.shape) == 2, "unsupported"
-        M, K = data_hp.shape[0], data_hp.shape[1]
+        assert len(data_hp.shape) in (2, 3), "unsupported"
+        leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
 
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
-            assert data_hp.shape[1] % 16 == 0, (
-                f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
+            assert K % 16 == 0, (
+                f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
             )
         if is_swizzled_scales:
-            scale_shape = (M, K // block_size)
+            scale_shape = (math.prod(leading_dims) * M, K // block_size)
             blockwise_scales = to_blocked(
                 blockwise_scales.view(scale_shape)
             ).flatten()
@@ -199,7 +200,7 @@ def to_nvfp4(
             # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
             # scale element
             scale_M, scale_K = M, K // block_size
-            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            blockwise_scales = blockwise_scales.view(*leading_dims, scale_M, scale_K)
 
         return NVFP4Tensor(
             data_lp,
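A minimal sketch of the swizzled-scale shape math introduced above, using hypothetical values (batch of 2, M=128, K=64, the default block size of 16). `to_blocked` expects a rank-2 row/column layout, so the leading batch dims are folded into the row dimension:

```python
import math

# Hypothetical rank-3 input: leading_dims=(2,), M=128, K=64, block_size=16.
leading_dims, M, K, block_size = (2,), 128, 64, 16

# One fp8 scale per 1x16 block of high-precision elements; folding the
# leading dims into the row dim keeps the 2D layout to_blocked expects.
scale_shape = (math.prod(leading_dims) * M, K // block_size)
assert scale_shape == (256, 4)
```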
@@ -225,22 +226,26 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
         else:
-            M, K = self.shape[0], self.shape[1]
-        data = self.qdata.t() if is_transposed else self.qdata
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
+        data = self.qdata.transpose(-2, -1) if is_transposed else self.qdata
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
-        data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        data_f32 = data_f32.view(
+            *leading_dims, M, K // self._block_size, self._block_size
+        )
+        scale_e4m3_reshaped = self.get_hp_scales().view(
+            *leading_dims, M, K // self._block_size, 1
+        )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
 
         if is_transposed:
-            result = result.t()
+            result = result.transpose(-2, -1)
 
         return result
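For context on the stride check above: switching from `stride(0)`/`stride(1)` to `stride(-2)`/`stride(-1)` lets the same transposed-layout detection work for rank-2 and rank-3 tensors. A quick illustration with plain torch tensors:

```python
import torch

x = torch.zeros(2, 4, 8)   # contiguous: strides (32, 8, 1)
xt = x.transpose(-2, -1)   # lazy transpose: strides (32, 1, 8)

# In a row-major layout the second-to-last stride is larger than the
# last; a lazy transpose of the last two dims flips that relationship.
assert x.stride(-2) > x.stride(-1)
assert xt.stride(-2) < xt.stride(-1)
```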
@@ -250,16 +255,18 @@ def get_hp_scales(self) -> torch.Tensor:
         Returns:
             torch.Tensor: Scales of the NVFP4Tensor
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
-            scale_e4m3 = self._scale_e4m3.t()
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
+            scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
         else:
-            M, K = self.shape[0], self.shape[1]
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
             scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
+            scale_e4m3 = from_blocked(
+                scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
+            )
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -380,6 +387,9 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert len(x.shape) == 2, (
+        f"only rank 2 is supported for slice, got rank {len(x.shape)}"
+    )
 
     M, K = x.shape[0], x.shape[1]
 
@@ -583,6 +593,28 @@ def nvfp4_t(func, types, args, kwargs):
     return new
 
 
+@implements([aten.transpose.int])
+def nvfp4_transpose(func, types, args, kwargs):
+    old, dim0, dim1 = args
+    assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"
+    valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
+    assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
+    new_qdata = func(old.qdata, dim0, dim1, **kwargs)
+    new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
+    new = NVFP4Tensor(
+        new_qdata,
+        new_scale,
+        old._block_size,
+        old._orig_dtype,
+        old._per_tensor_scale,
+        old._act_per_tensor_scale,
+        old._is_swizzled_scales,
+        old.use_triton_kernel,
+        old.act_quant_kwargs,
+    )
+    return new
+
+
 @implements([aten.view.default])
 def nvfp4_view_op(func, types, args, kwargs):
     data = args[0].qdata
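A hedged end-to-end sketch of the new rank-3 path added in this diff: quantize, transpose the last two dims (dispatching to the `aten.transpose.int` override above), then dequantize. The import path and the exact `to_nvfp4` defaults are assumptions based on where this file lives in torchao; treat this as illustrative, not a verbatim test:

```python
import torch

# Assumed import path for the tensor subclass defined in this file.
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# Hypothetical rank-3 input; K=64 is divisible by the block size of 16.
x_hp = torch.randn(2, 128, 64, dtype=torch.bfloat16)
x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp)  # rank 3 is now accepted

# Dispatches to nvfp4_transpose; only the last two dims of a rank-3
# tensor may be swapped per the valid_3d_dims assert.
x_t = x_nvfp4.transpose(-2, -1)

# Dequantize; the transposed logical shape is preserved.
x_back = x_t.to_dtype(torch.bfloat16)
assert x_back.shape == (2, 64, 128)
```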