20
20
(8 , 513 , 64 ), # Non-divisible (native only)
21
21
])
22
22
@pytest .mark .parametrize ("seed" , [42 ])
23
- @pytest .mark .parametrize ("use_ue8m0" , [True , False ])
24
23
@torch .inference_mode ()
25
24
def test_quantfp8_group_functionality (batch_size : int , hidden_dim : int ,
26
- group_size : int , seed : int ,
27
- use_ue8m0 : bool ) -> None :
25
+ group_size : int , seed : int ) -> None :
28
26
"""Test QuantFP8 group quantization with various configurations.
29
27
30
28
Tests both CUDA and native implementations, column-major scales,
@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
40
38
group_shape = GroupShape (1 , group_size )
41
39
quant_op = QuantFP8 (static = False ,
42
40
group_shape = group_shape ,
43
- column_major_scales = False ,
44
- use_ue8m0 = use_ue8m0 )
41
+ column_major_scales = False )
45
42
46
43
# 1. Test native implementation (always available)
47
44
x_quant_native , scales_native = quant_op .forward_native (x .clone ())
@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
51
48
# 2. Test column-major scales configuration
52
49
quant_op_col = QuantFP8 (static = False ,
53
50
group_shape = group_shape ,
54
- column_major_scales = True ,
55
- use_ue8m0 = use_ue8m0 )
51
+ column_major_scales = True )
56
52
_ , scales_col = quant_op_col .forward_native (x .clone ())
57
- assert scales_col .shape == (batch_size , expected_num_groups )
58
- assert scales_col .stride (0 ) == 1
59
- assert scales_col .stride (1 ) == batch_size
60
-
61
- # Test column-major scales consistency
62
- assert torch .allclose (scales_col , scales_native , rtol = 1e-9 , atol = 1e-8 )
53
+ assert scales_col .shape == (expected_num_groups , batch_size )
63
54
64
55
# 3. Test CUDA implementation (only for divisible dimensions)
65
56
if is_divisible :
@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
77
68
78
69
79
70
@pytest .mark .parametrize ("seed" , [42 ])
80
- @pytest .mark .parametrize ("use_ue8m0" , [True , False ])
81
71
@torch .inference_mode ()
82
- def test_quantfp8_group_multidimensional (seed : int , use_ue8m0 : bool ) -> None :
72
+ def test_quantfp8_group_multidimensional (seed : int ) -> None :
83
73
current_platform .seed_everything (seed )
84
74
85
75
group_size = 64
@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
92
82
group_shape = GroupShape (1 , group_size )
93
83
quant_op = QuantFP8 (static = False ,
94
84
group_shape = group_shape ,
95
- column_major_scales = False ,
96
- use_ue8m0 = use_ue8m0 )
85
+ column_major_scales = False )
97
86
98
87
x_quant , scales = quant_op .forward_native (x_3d .clone ())
99
88
assert x_quant .shape == x_3d .shape
@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
102
91
# Test column_major_scales with multi-dim
103
92
quant_op_col = QuantFP8 (static = False ,
104
93
group_shape = group_shape ,
105
- column_major_scales = True ,
106
- use_ue8m0 = use_ue8m0 )
94
+ column_major_scales = True )
107
95
_ , scales_col = quant_op_col .forward_native (x_3d .clone ())
108
96
assert scales_col .shape == (batch1 , hidden_dim // group_size , batch2 )
109
97
0 commit comments