Commit 1260180

Revert "[Performance] Move apply_w8a8_block_fp8_linear to an op class… (#25607)
Signed-off-by: Tyler Michael Smith <[email protected]>
1 parent af4ee63 commit 1260180

14 files changed: +205 −346 lines


benchmarks/cutlass_benchmarks/w8a8_benchmarks.py: 2 additions & 2 deletions

@@ -17,7 +17,7 @@

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_triton_block_scaled_mm,
+    w8a8_block_fp8_matmul,
 )
 from vllm.utils import FlexibleArgumentParser, cdiv

@@ -158,7 +158,7 @@ def bench_fp8(
         "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
             a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
         ),
-        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
+        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
             a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
         ),
         "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py: 2 additions & 2 deletions

@@ -9,7 +9,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
-    w8a8_triton_block_scaled_mm,
+    w8a8_block_fp8_matmul,
 )
 from vllm.triton_utils import triton
 from vllm.utils.deep_gemm import (
@@ -63,7 +63,7 @@ def deepgemm_gemm():

 # === vLLM Triton Implementation ===
 def vllm_triton_gemm():
-    return w8a8_triton_block_scaled_mm(A_vllm,
+    return w8a8_block_fp8_matmul(A_vllm,
                                  B_vllm,
                                  A_scale_vllm,
                                  B_scale_vllm,

tests/kernels/quantization/test_block_fp8.py: 2 additions & 3 deletions

@@ -11,7 +11,7 @@
     native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
+    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import (fp8_gemm_nt,
@@ -91,8 +91,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):

     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                        out_dtype)
-    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
-                                      out_dtype)
+    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /

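For context, the restored entry point is the one exercised above: w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype). Below is a minimal, hedged sketch of a standalone call on random block-quantized data; the shapes follow the test, but the device, dtypes, and scale magnitudes are my own assumptions and are not part of this commit.

# Illustrative sketch only (not from the diff): calling the restored Triton
# block-FP8 GEMM on synthetic data. Assumes a CUDA device with FP8 support.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_block_fp8_matmul)

M, N, K = 64, 512, 512
block_n, block_k = 128, 128  # block_size = [block_n, block_k], as in the tests

A_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)  # activations [M, K]
B_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)  # weight [N, K]
# Per-token, per-K-block activation scales and per-(N-block, K-block) weight scales.
As = torch.rand(M, K // block_k, device="cuda", dtype=torch.float32)
Bs = torch.rand(N // block_n, K // block_k, device="cuda", dtype=torch.float32)

out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, [block_n, block_k],
                            torch.bfloat16)
assert out.shape == (M, N)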
tests/kernels/quantization/test_fp8_quant_group.py: 7 additions & 19 deletions

@@ -20,11 +20,9 @@
     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
-@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int,
-                                      use_ue8m0: bool) -> None:
+                                      group_size: int, seed: int) -> None:
     """Test QuantFP8 group quantization with various configurations.

     Tests both CUDA and native implementations, column-major scales,
@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+                        column_major_scales=False)

     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+                            column_major_scales=True)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (batch_size, expected_num_groups)
-    assert scales_col.stride(0) == 1
-    assert scales_col.stride(1) == batch_size
-
-    # Test column-major scales consistency
-    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
+    assert scales_col.shape == (expected_num_groups, batch_size)

     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,


 @pytest.mark.parametrize("seed", [42])
-@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
+def test_quantfp8_group_multidimensional(seed: int) -> None:
     current_platform.seed_everything(seed)

     group_size = 64
@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+                        column_major_scales=False)

     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+                            column_major_scales=True)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
     assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

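The hunks above drop the use_ue8m0 parametrization and revert the column-major scale layout check. A minimal sketch of the resulting QuantFP8 usage follows; the import paths for QuantFP8 and GroupShape are taken from elsewhere in this commit and from my reading of the tree, so treat them as assumptions rather than a verified snippet.

# Sketch only: dynamic per-group FP8 quantization as the reverted tests use it.
# Assumes a CUDA device; import paths are my assumption.
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

batch_size, hidden_dim, group_size = 16, 512, 64
x = torch.randn(batch_size, hidden_dim, device="cuda", dtype=torch.bfloat16)

quant_op = QuantFP8(static=False,
                    group_shape=GroupShape(1, group_size),
                    column_major_scales=True)
x_q, scales = quant_op.forward_native(x)

# After the revert, the column-major path hands back scales shaped
# (num_groups, batch_size) rather than a strided (batch_size, num_groups) view.
assert scales.shape == (hidden_dim // group_size, batch_size)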
tests/model_executor/test_enabled_custom_ops.py: 30 additions & 0 deletions

@@ -17,6 +17,8 @@
 from vllm.model_executor.layers.layernorm import (RMSNorm,
                                                   dispatch_rocm_rmsnorm_func,
                                                   fused_add_rms_norm, rms_norm)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform

 RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -109,6 +111,34 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()


+@pytest.mark.skipif(
+    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
+    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
+@pytest.mark.parametrize("use_cutlass", [True, False])
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
+def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
+                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
+                                  monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
+                       use_rocm_aiter_gemm_w8a8_blockscale)
+
+    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
+        int(use_rocm_aiter_gemm_w8a8_blockscale)))
+    block_scale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
+    if use_cutlass:
+        assert block_scale_func == cutlass_scaled_mm
+    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
+            use_rocm_aiter_gemm_w8a8_blockscale):
+        assert block_scale_func == (
+            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
+    else:
+        assert block_scale_func == w8a8_block_fp8_matmul
+
+
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)

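The new test pins down the dispatch order: CUTLASS first, then the ROCm AITER blockscale GEMM when both flags are set, otherwise the Triton w8a8_block_fp8_matmul. A hedged sketch of that contract outside the pytest harness is below; which branches can actually run depends on the platform and build.

# Sketch of the dispatch contract checked above; illustrative, platform-dependent.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)

# Requesting CUTLASS always returns the CUTLASS block-scaled GEMM.
assert dispatch_w8a8_blockscale_func(
    True, use_aiter_and_is_supported=False) == cutlass_scaled_mm

# Without CUTLASS and without a usable AITER path, the Triton kernel is chosen.
assert dispatch_w8a8_blockscale_func(
    False, use_aiter_and_is_supported=False) == w8a8_block_fp8_matmul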
tests/quantization/test_compressed_tensors.py: 0 additions & 35 deletions

@@ -18,9 +18,6 @@
     CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
-from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    W8A8BlockFp8LinearOp)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -745,35 +742,3 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
     perplexity = llm.generate_prompt_perplexity([prompt])[0]
     print(perplexity)
     assert perplexity <= exp_perplexity
-
-
-def test_compressed_tensors_fp8_block_enabled(vllm_runner):
-    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
-    with vllm_runner(model_path) as llm:
-
-        fp8_dtype = current_platform.fp8_dtype()
-
-        def check_model(model):
-            layer = model.model.layers[0]
-
-            qkv_proj = layer.self_attn.qkv_proj
-            assert isinstance(qkv_proj.quant_method,
-                              CompressedTensorsLinearMethod)
-            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
-            assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
-                              W8A8BlockFp8LinearOp)
-
-            assert qkv_proj.weight.dtype is fp8_dtype
-            assert qkv_proj.weight_scale.dtype is torch.float32
-            assert len(qkv_proj.weight.shape) == 2
-            assert len(qkv_proj.weight_scale.shape) == 2
-
-            input_quant_op = \
-                qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
-            assert isinstance(input_quant_op, QuantFP8)
-            assert input_quant_op._forward_method == input_quant_op.forward_cuda
-
-        llm.apply_model(check_model)
-
-        output = llm.generate_greedy("Hello my name is", max_tokens=20)
-        assert output

vllm/config/__init__.py: 0 additions & 17 deletions

@@ -545,23 +545,6 @@ def __post_init__(self):
             # local attention.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True

-        def has_blocked_weights():
-            if self.quant_config is not None:
-                if hasattr(self.quant_config, "weight_block_size"):
-                    return self.quant_config.weight_block_size is not None
-                elif hasattr(self.quant_config, "has_blocked_weights"):
-                    return self.quant_config.has_blocked_weights()
-            return False
-
-        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
-        # On H100 the CUDA kernel is faster than
-        # native implementation
-        # https://github.com/vllm-project/vllm/issues/25094
-        if has_blocked_weights():
-            custom_ops = self.compilation_config.custom_ops
-            if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
-                custom_ops.append("+quant_fp8")
-
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py: 0 additions & 8 deletions

@@ -644,14 +644,6 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         # If no matches, return None
         return None

-    def has_blocked_weights(self) -> bool:
-        for scheme in self.target_scheme_map.values():
-            weight_quant = scheme.get("weights")
-            if (weight_quant is not None
-                    and weight_quant.strategy == QuantizationStrategy.BLOCK):
-                return True
-        return False
-
     @staticmethod
     def supports_cutlass_24(
             weight_quant: Optional[QuantizationArgs],

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py: 11 additions & 26 deletions

@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
+    apply_fp8_block_linear, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@@ -41,30 +41,16 @@ def __init__(self, weight_quant: QuantizationArgs,
         self.strategy = weight_quant.strategy
         self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
+        self.act_q_group_shape = GroupShape.PER_TENSOR \
+            if is_static_input_scheme else GroupShape.PER_TOKEN
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=self.is_static_input_scheme,
+            act_quant_group_shape=self.act_q_group_shape)

         self.weight_block_size = self.weight_quant.block_structure
-        if self.weight_block_size is not None:
-            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
-        else:
-            self.act_q_group_shape = GroupShape.PER_TENSOR \
-                if is_static_input_scheme else GroupShape.PER_TOKEN
-
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

-        if self.weight_block_size is not None:
-            assert not self.is_static_input_scheme
-            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
-                weight_group_shape=GroupShape(*self.weight_block_size),
-                act_quant_group_shape=self.act_q_group_shape,
-                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
-                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
-            )
-        else:
-            self.fp8_linear = Fp8LinearOp(
-                act_quant_static=self.is_static_input_scheme,
-                act_quant_group_shape=self.act_q_group_shape)
-
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -155,14 +141,13 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        if self.weight_block_size is not None:
-            return self.w8a8_block_fp8_linear.apply(
+        if layer.weight_block_size is not None:
+            return apply_fp8_block_linear(
+                layer,
                 input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                input_scale=layer.input_scale,
                 bias=bias,
-            )
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported)

         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,

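In short, the scheme goes back to building an Fp8LinearOp unconditionally and routing block-quantized weights through the free function apply_fp8_block_linear instead of a W8A8BlockFp8LinearOp instance. The sketch below condenses that control flow; the trailing keyword arguments of the per-tensor fp8_linear.apply call are cut off in the hunk above, so their reconstruction here is an assumption.

# Condensed, illustrative view of the apply_weights flow restored by the revert.
from typing import Optional

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_block_linear)


def apply_weights_sketch(scheme, layer: torch.nn.Module, x: torch.Tensor,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    if layer.weight_block_size is not None:
        # The helper reads weight, weight_scale and input_scale off `layer`
        # itself, so only the activation and bias are passed explicitly.
        return apply_fp8_block_linear(
            layer,
            input=x,
            bias=bias,
            cutlass_block_fp8_supported=scheme.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=scheme.use_aiter_and_is_supported)
    # Per-tensor / per-token path; the kwargs past `weight` are assumed.
    return scheme.fp8_linear.apply(input=x,
                                   weight=layer.weight,
                                   weight_scale=layer.weight_scale,
                                   input_scale=layer.input_scale,
                                   bias=bias)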
vllm/model_executor/layers/quantization/deepgemm.py: 5 additions & 5 deletions

@@ -42,7 +42,7 @@ def prepare_block_fp8_matmul_inputs(
     return M, N, K, C


-def w8a8_deepgemm_block_scaled_mm(
+def w8a8_block_fp8_matmul_deepgemm(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -58,7 +58,7 @@ def w8a8_deepgemm_block_scaled_mm(
     return C


-def w8a8_deepgemm_block_scaled_mm_fake(
+def w8a8_block_fp8_matmul_deepgemm_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -72,7 +72,7 @@ def w8a8_deepgemm_block_scaled_mm_fake(


 direct_register_custom_op(
-    op_name="w8a8_deepgemm_block_scaled_mm",
-    op_func=w8a8_deepgemm_block_scaled_mm,
-    fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
+    op_name="w8a8_block_fp8_matmul_deepgemm",
+    op_func=w8a8_block_fp8_matmul_deepgemm,
+    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
 )

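The last hunks only rename the DeepGEMM wrapper and its registration. As a reminder of the pattern, here is a hypothetical toy registration in the same style (a real op plus a shape-only fake impl for tracing); the op name my_block_mm is invented, and the torch.ops.vllm namespace and CUDA-only dispatch are assumptions based on how the ops above are referenced in the tests.

# Hypothetical example of the registration pattern used in deepgemm.py.
import torch

from vllm.utils import direct_register_custom_op


def my_block_mm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Real implementation: a plain matmul stands in for the blockwise kernel.
    return A @ B.t()


def my_block_mm_fake(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Fake impl: returns an empty tensor of the right shape so fake-tensor
    # tracing and torch.compile can propagate metadata without running the op.
    return A.new_empty((A.shape[0], B.shape[0]))


direct_register_custom_op(
    op_name="my_block_mm",
    op_func=my_block_mm,
    fake_impl=my_block_mm_fake,
)

# With the default registration the op appears under torch.ops.vllm and, like
# the DeepGEMM op above, is dispatched for CUDA tensors (assumption).
a = torch.randn(4, 8, device="cuda")
b = torch.randn(16, 8, device="cuda")
out = torch.ops.vllm.my_block_mm(a, b)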