24
24
int4_weight_only ,
25
25
int8_weight_only ,
26
26
int8_dynamic_activation_int8_weight ,
27
+ int8_dynamic_activation_int4_weight ,
27
28
quantize_ ,
28
29
_replace_with_custom_fn_if_matches_filter ,
29
30
)
@@ -137,6 +138,12 @@ def _int4wo_api(mod):
137
138
else :
138
139
change_linear_weights_to_int4_woqtensors (mod )
139
140
141
+ def _int8da_int4w_api (mod ):
142
+ quantize_ (mod , int8_dynamic_activation_int4_weight (), set_inductor_config = False )
143
+ if not TORCH_VERSION_AT_LEAST_2_5 :
144
+ unwrap_tensor_subclass (mod )
145
+
146
+
140
147
# TODO: use this to reduce the number of tests
141
148
TENSOR_SUBCLASS_APIS = [
142
149
_int8wo_api ,
@@ -781,7 +788,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
781
788
self ._test_lin_weight_subclass_impl (
782
789
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight .from_float , device , 25 , test_dtype = dtype
783
790
)
784
-
791
+
785
792
@parameterized .expand (COMMON_DEVICE_DTYPE )
786
793
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "autoquant+aqt needs newer pytorch" )
787
794
@unittest .skipIf (not is_H100 , "Need H100 to run" )
@@ -973,11 +980,11 @@ def test_weight_only_groupwise_embedding_quant(self):
973
980
group_size = 64
974
981
m = nn .Embedding (4096 , 128 )
975
982
input = torch .randint (0 , 4096 , (1 , 6 ))
976
-
983
+
977
984
quantize_ (m , int8_weight_only (group_size = group_size ), filter_fn = lambda x , * args : isinstance (x , nn .Embedding ))
978
985
y_q = m (input )
979
986
y_ref = m .weight .dequantize ()[input ]
980
-
987
+
981
988
sqnr = compute_error (y_ref , y_q )
982
989
983
990
self .assertGreater (sqnr , 45.0 )
@@ -1486,22 +1493,22 @@ def forward(self, x):
1486
1493
1487
1494
1488
1495
1496
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "requires 2.5+." )
1497
+ @unittest .skipIf (not torch .cuda .is_available (), "requires cuda" )
1498
+ @unittest .skip ("AOTI tests are failing right now" )
1489
1499
class TestAOTI (unittest .TestCase ):
1490
1500
@parameterized .expand (
1491
1501
list (itertools .product (TENSOR_SUBCLASS_APIS , COMMON_DEVICES , COMMON_DTYPES )),
1492
1502
)
1493
- @run_supported_device_dtype
1494
1503
def test_aoti (self , api , test_device , test_dtype ):
1495
- if not TORCH_VERSION_AT_LEAST_2_4 :
1496
- self .skipTest ("aoti compatibility requires 2.4+." )
1497
-
1498
- print (f"TestAOTI: { api } , { test_device } , { test_dtype } " )
1499
- logger .info (f"TestAOTI: { api } , { test_device } , { test_dtype } " )
1500
1504
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda" :
1501
1505
self .skipTest (f"{ api } in { test_device } is not support for aoti compilation yet" )
1502
1506
1503
- if test_dtype != torch .bfloat16 :
1504
- self .skipTest (f"{ api } in { test_dtype } is not support for aoti compilation yet" )
1507
+ if test_device == "cuda" and torch .cuda .is_available () and test_dtype == torch .bfloat16 and torch .cuda .get_device_capability () < (8 , 0 ):
1508
+ self .skipTest ("Need CUDA and SM80+ available." )
1509
+
1510
+
1511
+ logger .info (f"TestAOTI: { api } , { test_device } , { test_dtype } " )
1505
1512
1506
1513
m , k , n = 32 , 64 , 32
1507
1514
@@ -1525,29 +1532,30 @@ def forward(self, x):
1525
1532
ref_f = model (x )
1526
1533
1527
1534
api (model )
1535
+ unwrap_tensor_subclass (model )
1528
1536
1529
1537
# running model
1530
1538
model (x )
1531
1539
1532
1540
# make sure it compiles
1541
+ torch ._inductor .config .mixed_mm_choice = "triton"
1542
+
1533
1543
example_inputs = (x ,)
1534
- torch ._export . aot_compile ( model , example_inputs )
1544
+ torch ._inductor . aoti_compile_and_package ( torch . export . export ( model , example_inputs ) , example_inputs )
1535
1545
1536
1546
1547
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "requires 2.5+." )
1548
+ @unittest .skipIf (not torch .cuda .is_available (), "requires cuda" )
1537
1549
class TestExport (unittest .TestCase ):
1538
1550
@parameterized .expand (
1539
- list (itertools .product (TENSOR_SUBCLASS_APIS , COMMON_DEVICES , COMMON_DTYPES )),
1551
+ list (itertools .product (TENSOR_SUBCLASS_APIS + [ _int8da_int4w_api ] , COMMON_DEVICES , COMMON_DTYPES )),
1540
1552
)
1541
- @run_supported_device_dtype
1542
1553
def test_export (self , api , test_device , test_dtype ):
1543
- if not TORCH_VERSION_AT_LEAST_2_4 :
1544
- self .skipTest ("aoti compatibility requires 2.4+ ." )
1554
+ if test_device == "cuda" and torch . cuda . is_available () and test_dtype == torch . bfloat16 and torch . cuda . get_device_capability () < ( 8 , 0 ) :
1555
+ self .skipTest ("Need CUDA and SM80+ available ." )
1545
1556
1546
1557
logger .info (f"TestExport: { api } , { test_device } , { test_dtype } " )
1547
1558
1548
- if test_dtype != torch .bfloat16 :
1549
- self .skipTest (f"{ api } in { test_dtype } is not support for aoti compilation yet" )
1550
-
1551
1559
m , k , n = 32 , 64 , 32
1552
1560
1553
1561
class test_model (nn .Module ):
@@ -1570,6 +1578,7 @@ def forward(self, x):
1570
1578
ref_f = model (x )
1571
1579
1572
1580
api (model )
1581
+ unwrap_tensor_subclass (model )
1573
1582
1574
1583
# running model
1575
1584
ref = model (x )
@@ -1585,10 +1594,11 @@ def forward(self, x):
1585
1594
model = torch ._export .capture_pre_autograd_graph (model , example_inputs )
1586
1595
after_export = model (x )
1587
1596
self .assertTrue (torch .equal (after_export , ref ))
1588
- if api is _int8da_int8w_api :
1597
+ if api is _int8da_int4w_api :
1589
1598
targets = [n .target for n in model .graph .nodes ]
1590
1599
self .assertTrue (torch .ops .quant .choose_qparams_affine .default in targets )
1591
1600
self .assertTrue (torch .ops .quant .quantize_affine .default in targets )
1601
+ self .assertFalse (torch .ops .aten .narrow .default in targets )
1592
1602
1593
1603
1594
1604
0 commit comments