import itertools

import torchao

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode, TORCH_VERSION_AFTER_2_5
from torchao.prototype.quant_llm import from_scaled_tc_fpx
import pytest

if is_fbcode():
    pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")

try:
    import torchao.ops
except RuntimeError:
    pytest.skip("torchao.ops not available")

from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    pack_tinygemm_scales_and_zeros,
    unpack_tinygemm_scales_and_zeros,
)

class TestOps(TestCase):
    def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
        ...

    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
        ...  # (setup elided)
        relative_error = error / gt
        assert relative_error < 1e-3


instantiate_parametrized_tests(TestOps)


## Tests for `tensor_core_tiled_layout` ops
kTileSizeN = 8
kTileSizeK = 16
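# Note (assumption): these match the 8 (N) x 16 (K) tile that `_convert_weight_to_int4pack` packs per tensor-core fragment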

SHAPES = [
    (4096, 4096),
    # Llama 2 GEMM shapes
    (4096, 11008),
    (11008, 4096),
    # Llama 3 GEMM shapes
    (4096, 14336),
    (14336, 4096),
]
INNERKTILES = [2, 4, 8]
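# (Assumption) 2, 4, and 8 are the inner_k_tiles values accepted by `_convert_weight_to_int4pack`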
QGROUP_SIZES = [32, 64, 128, 256]
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
    N, K = shape
    assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
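    # (Assumption) packed_w is a 4-D int32 tensor of shape (N / 8, K / (inner_k_tiles * 16), 32, inner_k_tiles / 2)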
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
    assert torch.equal(t, unpacked)


# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
    ]

    # TODO: Figure out why test fails unless torch >= 2.5
    if TORCH_VERSION_AFTER_2_5:
        test_utils.append("test_aot_dispatch_dynamic")

    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

    opcheck(
        torch.ops.torchao.unpack_tensor_core_tiled_layout,
        (packed_w, inner_k_tiles),
        test_utils=test_utils,
    )


def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
    n, k = q.shape
    assert q.dtype == torch.int

    n_groups = k // group_size
    assert scales.shape[0] == n and scales.shape[1] == n_groups
    assert scales.shape == zeros.shape

    midpoint = 2 ** (nbits - 1)
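    # e.g. nbits=4 gives midpoint=8: unsigned codes 0..15 map to signed values -8..7 after the shift below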

    # Convert from u4 to s4 and upcast to bfloat16
    q = q.sub(midpoint).to(dtype)

    # Dequantize
    q = q.reshape(-1, group_size)
    dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)

    return dq.reshape(n, k)
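
# Hypothetical usage sketch for `dequant_ref` (shapes chosen purely for illustration):
#   q = torch.randint(0, 16, (8, 64), dtype=torch.int)     # 4-bit codes for an 8 x 64 weight
#   scales = torch.rand(8, 2, dtype=torch.bfloat16)        # 64 // 32 = 2 groups per row
#   zeros = torch.randn(8, 2, dtype=torch.bfloat16)
#   dq = dequant_ref(q, scales, zeros, group_size=32)      # -> (8, 64) bfloat16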


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
    n, k = shape
    dtype = torch.bfloat16
    device = "cuda"

    t = torch.randn(n, k, dtype=dtype, device=device)
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

    # Quantize
    q = groupwise_affine_quantize_tensor_from_qparams(
        t, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Pack to tensor core layout
    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
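    # (Assumption) packing interleaves the qparams so that [..., 0] holds scales and [..., 1] holds zeros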
    q_groups = k // group_size
    assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])

    # Dequantize 'ao' ref
    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
        q, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Dequantize by passing in an identity matrix as the activation
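    # (eye(k) @ W_dq.t() == W_dq.t(), so transposing the matmul output recovers the dequantized weight)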
    a_eye = torch.eye(k, device=device, dtype=dtype)
    dq_id = torch.ops.aten._weight_int4pack_mm(
        a_eye,
        packed,
        group_size,
        scales_and_zeros,
    ).t()

    # Actual operation to test
    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

    # Compare results
    diff_ao_id = (dq_id - dq_ao).abs().max()
    diff_op_id = (dq_op - dq_id).abs().max()
    diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix compared to
    # `groupwise_affine_dequantize`. Since the `dequantize_tensor_core_tiled_layout` kernel relies on the same
    # underlying bit-twiddling tricks for fast conversion from u4 -> s4 -> bf16, the identity-matrix dequant hack
    # and `dequantize_tensor_core_tiled_layout` are expected to give identical results, while both will show the
    # same small numerical differences relative to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives the same results as the identity-matrix-based dequant
    assert diff_op_id == 0

    # Test that the `dequant` kernel differs from `groupwise_affine_dequantize` by exactly as much as the
    # identity-matrix dequant does
    assert diff_op_ao == diff_ao_id

    assert diff_op_ao < 1e-1


# This test differs from the one above in that it unpacks with `unpack_tensor_core_tiled_layout` and then dequantizes
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
    n, k = shape
    dtype = torch.bfloat16
    device = "cuda"

    # Quantize and pack
    t = torch.randn(n, k, dtype=dtype, device=device)
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
    q = groupwise_affine_quantize_tensor_from_qparams(
        t, scales, zeros, n_bit=4, groupsize=group_size
    )

    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    # Unpack and dequantize
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
        unpacked, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Dequantize by passing in an identity matrix as the activation
    a_eye = torch.eye(k, device=device, dtype=dtype)
    dq_id = torch.ops.aten._weight_int4pack_mm(
        a_eye,
        packed,
        group_size,
        scales_and_zeros,
    ).t()

    # Actual operation to test
    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

    # Compare results
    diff_ao_id = (dq_id - dq_ao).abs().max()
    diff_op_id = (dq_op - dq_id).abs().max()
    diff_op_ao = (dq_op - dq_ao).abs().max()

    # As above: the identity-matrix dequant and `dequantize_tensor_core_tiled_layout` share the same
    # u4 -> s4 -> bf16 conversion path, so they should agree exactly, while both differ slightly from
    # `groupwise_affine_dequantize`.
    assert diff_op_id == 0
    assert diff_op_ao == diff_ao_id
    assert diff_op_ao < 1e-1


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
    n, k = shape
    device = "cuda"

    q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
    packed_w = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    q_groups = k // group_size
    scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
    zeros = torch.randn_like(scales)
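    # Random qparams are fine here: opcheck exercises schema/dispatch behavior rather than numerical accuracy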
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
    ]
    # TODO: Figure out why test fails unless torch >= 2.5
    if TORCH_VERSION_AFTER_2_5:
        test_utils.append("test_aot_dispatch_dynamic")
    opcheck(
        torch.ops.torchao.dequantize_tensor_core_tiled_layout,
        (packed_w, scales_and_zeros, group_size, inner_k_tiles),
        test_utils=test_utils,
    )


if __name__ == "__main__":
    run_tests()