Skip to content

Commit 2423c1d

Browse files
committed
MI300 check
Summary: Test Plan: Tested on AMD Instinct MI300X Reviewers: Subscribers: Tasks: Tags:
1 parent 88b6ba1 commit 2423c1d

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

torchao/quantization/quant_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
TORCH_VERSION_AT_LEAST_2_4,
5252
TORCH_VERSION_AT_LEAST_2_5,
5353
TORCH_VERSION_AT_LEAST_2_6,
54+
is_MI300,
5455
)
5556

5657
from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -941,7 +942,7 @@ def float8_dynamic_activation_float8_weight(
941942
942943
"""
943944
assert (
944-
is_cuda_8_9
945+
is_cuda_8_9 or is_MI300
945946
), "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
946947
if mm_config is None:
947948
mm_config = Float8MMConfig(use_fast_accum=True)
@@ -998,7 +999,7 @@ def float8_static_activation_float8_weight(
998999
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
9991000
"""
10001001
assert (
1001-
is_cuda_8_9
1002+
is_cuda_8_9 or is_MI300
10021003
), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
10031004
if mm_config is None:
10041005
mm_config = Float8MMConfig(use_fast_accum=True)

torchao/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"TORCH_VERSION_AFTER_2_3",
3232
"TORCH_VERSION_AFTER_2_4",
3333
"TORCH_VERSION_AFTER_2_5",
34+
"is_MI300",
3435
]
3536

3637

@@ -586,6 +587,16 @@ def _torch_version_at_least(min_version):
586587
return is_fbcode() or version("torch") >= min_version
587588

588589

590+
def is_MI300():
591+
if torch.cuda.is_available() and torch.version.hip:
592+
mxArchName = ["gfx940", "gfx941", "gfx942"]
593+
archName = torch.cuda.get_device_properties().gcnArchName
594+
for arch in mxArchName:
595+
if arch in archName:
596+
return True
597+
return False
598+
599+
589600
TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
590601
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
591602
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")

0 commit comments

Comments
 (0)