Skip to content

Commit 90b6c38

Browse files
committed
Add hardware check to fp8 quant
1 parent b714026 commit 90b6c38

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchao/quantization/quant_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
from .utils import _get_per_token_block_size
8383

8484
logger = logging.getLogger(__name__)
85+
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
8586

8687
__all__ = [
8788
"swap_conv2d_1x1_to_linear",
@@ -939,6 +940,7 @@ def float8_dynamic_activation_float8_weight(
939940
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
940941
941942
"""
943+
assert is_cuda_8_9, "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
942944
if mm_config is None:
943945
mm_config = Float8MMConfig(use_fast_accum=True)
944946

@@ -993,6 +995,7 @@ def float8_static_activation_float8_weight(
993995
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn
994996
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
995997
"""
998+
assert is_cuda_8_9, "Float8 static activation quantization is only supported on CUDA 8.9 and above"
996999
if mm_config is None:
9971000
mm_config = Float8MMConfig(use_fast_accum=True)
9981001

0 commit comments

Comments
 (0)