Skip to content

Commit 1e2a716

Browse files
committed
Test cases
1 parent 0ba6a2c commit 1e2a716

File tree

2 files changed: +15 additions, −2 deletions

test/integration/test_integration.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
AQInt8WeightOnlyQuantizedLinearWeight2,
7373
AQInt8WeightOnlyQuantizedLinearWeight3,
7474
AutoQuantizableLinearWeight,
75-
75+
AQFloat8WeightOnlyQuantizedLinearWeight,
7676
)
7777
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
7878
import os
@@ -98,6 +98,7 @@
9898
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
9999

100100
COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
101+
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
101102

102103
def _int8wo_api(mod):
103104
if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
744745
AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
745746
)
746747

748+
@parameterized.expand(COMMON_DEVICE_DTYPE)
749+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
750+
@unittest.skipIf(not is_H100, "Need H100 to run")
751+
def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
752+
self._test_lin_weight_subclass_impl(
753+
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
754+
)
755+
747756
@parameterized.expand(COMMON_DEVICE_DTYPE)
748757
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
749758
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")

torchao/quantization/autoquant.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,6 @@ def from_float(cls, weight):
501501
# AQInt8WeightOnlyQuantizedLinearWeight3,
502502
# TODO this gets picked in places where it makes perf worse, why?
503503
AQInt8DynamicallyQuantizedLinearWeight,
504-
AQFloat8WeightOnlyQuantizedLinearWeight,
505504
]
506505

507506
DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
@@ -510,6 +509,11 @@ def from_float(cls, weight):
510509
AQInt4G64WeightOnlyQuantizedLinearWeight
511510
]
512511

512+
OTHER_AUTOQUANT_CLASS_LIST = [
513+
AQFloat8WeightOnlyQuantizedLinearWeight,
514+
]
515+
516+
513517
def _change_linears_to_autoquantizable(model, **kwargs):
514518
"""
515519
Converts all linear weight tensors to the

Comments (0)