    sparse_cutlass_supported)
from vllm.platforms import current_platform

+ # AITER only supports per-channel/per-channel and per-tensor/per-tensor
+ # INT8 GEMM; it does not support mixed-precision matmul or mixed
+ # quantization schemes.
+ ROCM_AITER_SUPPORTED_INT8_MODEL = [
+     "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+     "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
+ ]
+
+ # TritonScaledMMLinearKernel only supports symmetric quantization.
+ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
+     "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
+     "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
+     "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+     "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
+     "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
+ ]
+

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
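The same ROCm guard recurs in three tests below, keyed off the two model lists above. A minimal sketch of a helper that could express that shared check, assuming the module-level lists defined in this file (the name `_maybe_skip_unsupported_on_rocm` is hypothetical and not part of this diff):

import pytest

from vllm.platforms import current_platform


def _maybe_skip_unsupported_on_rocm(model_path: str, use_aiter: bool = False):
    # Hypothetical helper mirroring the guards added in this diff.
    if (current_platform.is_rocm()
            and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL):
        pytest.skip(f"Skip model {model_path}: not supported on ROCm.")
    if use_aiter and model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path}: not supported by ROCm AITER.")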
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
+
+     if (current_platform.is_rocm()
+             and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL):
+         pytest.skip(f"Skip model {model_path}: not supported on ROCm.")
+
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
@@ -123,14 +145,30 @@ def zp_valid(zp: Optional[torch.Tensor]):
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
+ @pytest.mark.parametrize(
+     "use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
+     use_aiter,
+     monkeypatch,
):
+
+     if (current_platform.is_rocm()
+             and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL):
+         pytest.skip(f"Skip model {model_path}: not supported on ROCm.")
+
+     if use_aiter:
+         if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+             pytest.skip(
+                 f"Skip model {model_path}: not supported by ROCm AITER.")
+         # This also enables the VLLM_ROCM_USE_AITER_LINEAR path.
+         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
    dtype = "bfloat16"

    # skip language translation prompt for the static per-tensor asym model
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
        name_1="vllm",
    )

+     if current_platform.is_rocm():
+         torch.cuda.synchronize()
+


def test_compressed_tensors_no_enforce_eager(vllm_runner):
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
        ),
    ],
)
- def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
+ @pytest.mark.parametrize(
+     "use_aiter", [True, False] if current_platform.is_rocm() else [False])
+ def test_compressed_tensors_w8a8_dynamic_per_token(
+     vllm_runner,
+     model_args,
+     use_aiter,
+     monkeypatch,
+ ):
    model_path, strategy = model_args
+
+     if (current_platform.is_rocm()
+             and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL):
+         pytest.skip(f"Skip model {model_path}: not supported on ROCm.")
+
+     if use_aiter:
+         if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+             pytest.skip(
+                 f"Skip model {model_path}: not supported by ROCm AITER.")
+         # This also enables the VLLM_ROCM_USE_AITER_LINEAR path.
+         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):
@@ -207,6 +267,8 @@ def check_model(model):
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
+ @pytest.mark.skipif(not current_platform.is_cuda(),
+                     reason="This test is skipped on non-CUDA platforms.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    model, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model) as llm:
@@ -231,6 +293,8 @@ def check_model(model):
        assert output


+ @pytest.mark.skipif(not current_platform.is_cuda(),
+                     reason="This test is skipped on non-CUDA platforms.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:
@@ -271,7 +335,7 @@ def check_model(model):

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                assert len(qkv_proj.input_scale.shape) == 0
-                 assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+                 assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

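The assertion above now compares against current_platform.fp8_dtype() instead of hard-coding torch.float8_e4m3fn, since ROCm hardware may use the FNUZ FP8 encoding. A minimal sketch of the idea, assuming the platform knows whether it uses the FNUZ variant (illustrative only, not vLLM's actual implementation):

import torch


def fp8_dtype_for(platform_uses_fnuz: bool) -> torch.dtype:
    # Assumption: platforms with the FNUZ FP8 format report float8_e4m3fnuz,
    # while CUDA platforms use the OCP float8_e4m3fn format.
    return torch.float8_e4m3fnuz if platform_uses_fnuz else torch.float8_e4m3fn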
@@ -281,6 +345,8 @@ def check_model(model):
        assert output


+ @pytest.mark.skipif(not current_platform.is_cuda(),
+                     reason="This test is skipped on non-CUDA platforms.")
def test_compressed_tensors_kv_cache(vllm_runner):
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,


@pytest.mark.skipif(
-     not current_platform.has_device_capability(90),
+     not current_platform.is_cuda()
+     or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
@@ -356,7 +423,8 @@ def check_model(model):


@pytest.mark.skipif(
-     not current_platform.has_device_capability(90),
+     not current_platform.is_cuda()
+     or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(