From 21d7d67c4065de986d6cec9d97ef19f1333c5722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 6 Sep 2025 14:35:13 -0700 Subject: [PATCH 01/81] Functionalized patterns in prep for utility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 62 +++++++++++++++----------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index df54e94a03db..71a3153bf0bc 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -112,13 +112,13 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result_rms = torch.empty_like(input) + # TODO: why does empty_like produce a permute but + # empty via shape doesn't? + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at1 = auto_functionalized( RMS_OP, result=result_rms, @@ -133,13 +133,8 @@ def pattern( # result return at2[1] - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -153,8 +148,6 @@ def replacement( return at[1] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight empty_fp32(1, 1), # scale @@ -175,12 +168,14 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( RMS_ADD_OP, input=input, @@ -196,12 +191,12 @@ def pattern( return at1[1], at[2] def replacement( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -216,7 +211,6 @@ def replacement( return at[1], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight @@ -248,13 +242,11 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result_rms = torch.empty_like(input) + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at1 = auto_functionalized( RMS_OP, result=result_rms, @@ -269,13 +261,8 @@ def pattern( # result, scale return at2[1], at2[2] - def replacement( - result: torch.Tensor, - result_rms: 
torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -291,8 +278,6 @@ def replacement( return at[1], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight empty_fp32(1, 1), # scale @@ -324,12 +309,14 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( RMS_ADD_OP, input=input, @@ -345,12 +332,12 @@ def pattern( return at1[1], at[2], at1[2] def replacement( - result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -366,7 +353,6 @@ def replacement( return at[1], at[3], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight From f3b4cf190736f949eab00d1ee3a1846a409770c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 9 Sep 2025 09:48:53 -0700 Subject: [PATCH 02/81] TEMP Mostly working MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 37 +++++++- vllm/_custom_ops.py | 2 +- vllm/compilation/fusion.py | 99 ++++---------------- vllm/compilation/matcher_utils.py | 116 ++++++++++++++++++++++++ vllm/model_executor/layers/layernorm.py | 54 +++++++---- 5 files changed, 204 insertions(+), 104 deletions(-) create mode 100644 vllm/compilation/matcher_utils.py diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 7c2233643229..fb17dfd0dd46 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -8,6 +8,7 @@ from vllm.compilation.fusion import ( FUSED_OPS, QUANT_OPS, + RMS_OP, FusedRMSQuantKey, RMSNormQuantFusionPass, ) @@ -65,6 +66,9 @@ def __init__( act_quant_group_shape=group_shape, ) + self.enable_rms_norm = self.norm[0].enabled() + self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled() + def forward(self, x): resid = torch.sqrt(x) y = self.norm[0](x) @@ -82,7 +86,18 @@ def forward(self, x): return y3 def ops_in_model_before(self): - return [QUANT_OPS[self.key]] + ops = [] + if self.enable_rms_norm: + ops += [RMS_OP] + else: + ops += [torch.ops.aten.rsqrt.default] + + if self.enable_quant_fp8: + ops += [QUANT_OPS[self.key]] + else: + ops += [torch.ops.aten.reciprocal.default] + + return ops def ops_in_model_after(self): return [ @@ -91,11 +106,13 @@ def ops_in_model_after(self): ] -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("enable_rms_norm", [True]) # , False]) +@pytest.mark.parametrize("enable_quant_fp8", [True]) # , False]) # 
cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize( @@ -105,17 +122,29 @@ def ops_in_model_after(self): not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) def test_fusion_rmsnorm_quant( - dtype, hidden_size, num_tokens, eps, static, cuda_force_torch + dtype, + hidden_size, + num_tokens, + eps, + static, + enable_rms_norm, + enable_quant_fp8, + cuda_force_torch, ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + custom_ops = [] + if enable_rms_norm: + custom_ops.append("+rms_norm") + if enable_quant_fp8: + custom_ops.append("+quant_fp8") vllm_config = VllmConfig( compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"], + custom_ops=custom_ops, pass_config=PassConfig(enable_fusion=True, enable_noop=True), ) ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eac0a5009e81..646d8de39a45 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1507,7 +1507,7 @@ def scaled_fp8_quant( output, input, scale, scale_ub ) else: - scale = torch.empty(1, device=input.device, dtype=torch.float32) + scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 71a3153bf0bc..4afb8ba537e7 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -24,6 +24,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuant, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -99,6 +100,9 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuant(key.quant) + class RMSNormStaticQuantPattern(RMSNormQuantPattern): def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): @@ -113,25 +117,8 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - result_rms = torch.empty_like(input) - # TODO: why does empty_like produce a permute but - # empty via shape doesn't? 
- result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale - ) - - # result - return at2[1] + result_rms = self.rmsnorm_matcher(input, weight) + return self.quant_matcher(result_rms, scale) def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) @@ -173,22 +160,10 @@ def pattern( weight: torch.Tensor, scale: torch.Tensor, ): - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale - ) + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result = self.quant_matcher(result_rms, scale) - # result, residual - return at1[1], at[2] + return result, residual def replacement( input: torch.Tensor, @@ -242,27 +217,14 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - result_rms = torch.empty_like(input) - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None - ) - + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) # result, scale - return at2[1], at2[2] + return self.quant_matcher(result_rms) - def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -280,7 +242,6 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): inputs = [ empty_bf16(5, 4), # input empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -308,36 +269,17 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_dtype - ) - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None - ) + def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) - # result, residual, scale - return at1[1], at[2], at1[2] + return result, residual, scale def replacement( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor ): result = torch.empty_like(input, dtype=self.quant_dtype) + scale = 
self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -356,7 +298,6 @@ def replacement( empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale ] pm.register_replacement( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py new file mode 100644 index 000000000000..1200e236bae4 --- /dev/null +++ b/vllm/compilation/matcher_utils.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +from torch._higher_order_ops import auto_functionalized +from torch._ops import OpOverload + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + _normalize_quant_group_shape, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + +QUANT_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 +} + +# TODO +# if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): +# QUANT_OPS[ +# kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + + +class MatcherRMSNorm: + def __init__(self, epsilon: float): + self.epsilon = epsilon + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if residual is None: + result = torch.empty_like(input) + _, result = auto_functionalized( + RMS_OP, + result=result, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + + return result + else: + _, result, residual = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + + return result, residual + + def __call__( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward(input, weight, residual) + + +class MatcherQuant: + def __init__(self, quant_key: QuantKey): + self.quant_key = quant_key + assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" + self.QUANT_OP = QUANT_OPS[quant_key] + + def forward( + self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + # TODO: why does empty_like produce a permute but + # empty via shape doesn't? 
+ result = torch.empty( + input.shape, device=input.device, dtype=self.quant_key.dtype + ) + + if self.quant_key.scale.static: + assert scale is not None + _, result = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale + ) + return result + else: + assert scale is None + scale = self.make_scale(input) + _, result, scale = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None + ) + return result, scale + + def make_scale(self, input: torch.Tensor): + normalized_group_shape = _normalize_quant_group_shape( + input, self.quant_key.scale.group_shape + ) + scale_shape = ( + input.shape[0] // normalized_group_shape[0], + input.shape[1] // normalized_group_shape[1], + ) + + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + + def __call__( + self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward(input, scale) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 6a49ae42ca89..3c58832cad4c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,13 +170,10 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) - self.has_weight = has_weight - if dtype is not None: - self.weight = torch.ones(hidden_size, dtype=dtype) - else: - self.weight = torch.ones(hidden_size) - if self.has_weight: - self.weight = nn.Parameter(self.weight) + self.weight = None + if has_weight: + dtype = dtype or torch.get_default_dtype() + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype)) weight_dtype = self.weight.data.dtype if current_platform.is_rocm(): @@ -187,9 +184,13 @@ def __init__( with_fused_add=True, dtype=weight_dtype ) - def forward_native( - self, + @staticmethod + def forward_static( x: torch.Tensor, + variance_epsilon: float, + hidden_size: int, + variance_size_override: Optional[int], + weight: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" @@ -199,35 +200,48 @@ def forward_native( x = x + residual.to(torch.float32) residual = x.to(orig_dtype) - hidden_size = x.shape[-1] - if hidden_size != self.hidden_size: + if x.shape[-1] != hidden_size: raise ValueError( - "Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}" + f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}" ) - if self.variance_size_override is None: + if variance_size_override is None: x_var = x else: - if hidden_size < self.variance_size_override: + if hidden_size < variance_size_override: raise ValueError( "Expected hidden_size to be at least " - f"{self.variance_size_override}, but found: {hidden_size}" + f"{variance_size_override}, but found: {hidden_size}" ) - x_var = x[:, :, : self.variance_size_override] + x_var = x[:, :, :variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype) - if self.has_weight: - x = x * self.weight + if weight is not None: + x = x * weight if residual is None: return x else: return x, residual + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: 
+ """PyTorch-native implementation equivalent to forward().""" + return self.forward_static( + x, + self.variance_epsilon, + self.hidden_size, + self.variance_size_override, + self.weight.data, + residual, + ) + def forward_cuda( self, x: torch.Tensor, From cdad3c05ea12ddba69805277085e2cf658a065f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 12 Sep 2025 12:11:48 -0700 Subject: [PATCH 03/81] TEMP: fixed rmsnorm issue (TODO assert dtypes in fused norm_quant kernels) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- csrc/layernorm_kernels.cu | 1 + tests/compile/backend.py | 19 +- tests/compile/test_fusion.py | 100 ++++----- vllm/compilation/fusion.py | 277 +++++++++++------------- vllm/compilation/matcher_utils.py | 105 ++++++--- vllm/model_executor/layers/layernorm.py | 4 +- 6 files changed, 261 insertions(+), 245 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 6c3685f6f7cd..b738cdbbdc53 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -380,6 +380,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(input.scalar_type() == residual.scalar_type()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 36bc832a1329..113906af0203 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -4,10 +4,13 @@ import weakref from collections.abc import Sequence from copy import deepcopy +from pathlib import Path from typing import Callable, Union +import depyf from torch import fx from torch._ops import OpOverload +from torch.fx._utils import lazy_format_graph_code from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass @@ -46,11 +49,20 @@ class TestBackend: def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) - compile_config = get_current_vllm_config().compilation_config + vllm_config = get_current_vllm_config() + compile_config = vllm_config.compilation_config self.inductor_config = compile_config.inductor_compile_config self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass + if compile_config.debug_dump_path: + self.debug_dump_path = (Path(compile_config.debug_dump_path) / + f"rank_{vllm_config.parallel_config.rank}") + self.ctx = depyf.prepare_debug(str(self.debug_dump_path)) + self.ctx.__enter__() + else: + self.ctx = None + def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx @@ -60,6 +72,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs): @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + lazy_format_graph_code("graph_pre_pass", graph.owning_module) VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: @@ -69,9 +82,13 @@ def post_pass(self, graph: fx.Graph): VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) + lazy_format_graph_code("graph_post_pass", graph.owning_module) # assign by reference, will reflect the final state of the graph self.final_graph = graph + if self.ctx is not None: + 
self.ctx.__exit__(None, None, None) + def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index fb17dfd0dd46..ac5d9b9c93bf 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,27 +5,17 @@ import torch import vllm.plugins -from vllm.compilation.fusion import ( - FUSED_OPS, - QUANT_OPS, - RMS_OP, - FusedRMSQuantKey, - RMSNormQuantFusionPass, -) +from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, RMS_OP, + FusedRMSQuantKey, RMSNormQuantFusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - QuantKey, - ScaleDesc, -) + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - cutlass_fp8_supported, - maybe_create_device_identity, -) + Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) from vllm.platforms import current_platform from ..utils import override_cutlass_fp8_supported @@ -35,15 +25,9 @@ class TestModel(torch.nn.Module): - def __init__( - self, - hidden_size: int, - eps: float, - static: bool, - cuda_force_torch: bool, - *args, - **kwargs, - ): + + def __init__(self, hidden_size: int, eps: float, static: bool, + cuda_force_torch: bool, *args, **kwargs): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] @@ -70,18 +54,21 @@ def __init__( self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled() def forward(self, x): - resid = torch.sqrt(x) + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + x2 = self.fp8_linear.apply(y, + self.w[0], + self.wscale[0], + input_scale=self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + x3 = self.fp8_linear.apply(y2, + self.w[1], + self.wscale[1], + input_scale=self.scale[1]) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -102,35 +89,26 @@ def ops_in_model_before(self): def ops_in_model_after(self): return [ FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)], + FUSED_OPS[FusedRMSQuantKey(self.key, True)] ] -@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("enable_rms_norm", [True]) # , False]) -@pytest.mark.parametrize("enable_quant_fp8", [True]) # , False]) +@pytest.mark.parametrize("enable_rms_norm", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8", [True, False]) # cuda_force_torch used to test torch code path on platforms that # 
cutlass_fp8_supported() == True. -@pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] -) -@pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" -) -def test_fusion_rmsnorm_quant( - dtype, - hidden_size, - num_tokens, - eps, - static, - enable_rms_norm, - enable_quant_fp8, - cuda_force_torch, -): +@pytest.mark.parametrize("cuda_force_torch", + [True, False] if cutlass_fp8_supported() else [True]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test on CUDA and ROCm") +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, + enable_rms_norm, enable_quant_fp8, + cuda_force_torch): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -141,13 +119,13 @@ def test_fusion_rmsnorm_quant( custom_ops.append("+rms_norm") if enable_quant_fp8: custom_ops.append("+quant_fp8") - vllm_config = VllmConfig( - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - ) - ) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + debug_dump_path=f"/home/luka/git/vllm/._workspace/" + f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", + level=CompilationLevel.PIECEWISE, + custom_ops=custom_ops, + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + )) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) @@ -179,7 +157,7 @@ def test_fusion_rmsnorm_quant( assert fusion_pass.matched_count == 2 # In pre-nodes, fp8 quant should be there and fused kernels should not - backend.check_before_ops(model.ops_in_model_before()) + # backend.check_before_ops(model.ops_in_model_before()) # In post-nodes, fused kernels should be there and fp8 quant should not - backend.check_after_ops(model.ops_in_model_after()) + # backend.check_after_ops(model.ops_in_model_after()) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 4afb8ba537e7..8e3a1de99898 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,15 +12,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - QuantKey, - ScaleDesc, - kFp8DynamicTensorSym, - kFp8DynamicTokenSym, - kFp8StaticTensorSym, - kNvfp4Quant, - kStaticTensorScale, -) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode @@ -48,9 +41,12 @@ def empty_i32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: + torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = 
torch.ops._C.scaled_fp4_quant.default @@ -62,42 +58,38 @@ class FusedRMSQuantKey(NamedTuple): quant: type of quantization fused_add: does the op also perform the residual add """ - quant: QuantKey fused_add: bool def __str__(self): - return ( - f"FusedQuantKey({self.quant}, with" - f"{'' if self.fused_add else 'out'} residual)" - ) + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey( - kFp8StaticTensorSym, False - ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8StaticTensorSym, True - ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8DynamicTokenSym, False - ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8DynamicTokenSym, True - ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8StaticTensorSym, False): + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8StaticTensorSym, True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8DynamicTokenSym, False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey(kFp8DynamicTokenSym, True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } class RMSNormQuantPattern: + def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] self.rmsnorm_matcher = MatcherRMSNorm(epsilon) @@ -105,82 +97,80 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): class RMSNormStaticQuantPattern(RMSNormQuantPattern): - def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): - fused_key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey( - dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric - ), - ) + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) return self.quant_matcher(result_rms, scale) - def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) # result return at[1] inputs = [ empty_bf16(5, 4), # input - 
empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + empty_bf16(4,), # weight + empty_fp32(1, 1) # scale ] + pattern(*inputs) - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): - key = FusedRMSQuantKey( - fused_add=True, - quant=QuantKey( - dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric - ), - ) + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + + def pattern(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher( + input, weight, residual) result = self.quant_matcher(result_rms, scale) return result, residual - def replacement( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) # result, residual return at[1], at[2] @@ -188,8 +178,8 @@ def replacement( inputs = [ empty_bf16(5, 4), # input empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + empty_bf16(4, ), # weight + empty_fp32(1, 1) # scale ] pm.register_replacement( @@ -202,21 +192,21 @@ def replacement( class RMSNormDynamicQuantPattern(RMSNormQuantPattern): - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + scale=scale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) # result, scale @@ -225,23 +215,21 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): def replacement(input: torch.Tensor, weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + 
scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) # result, scale return at[1], at[2] inputs = [ empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight + empty_bf16(4), # weight ] pm.register_replacement( @@ -254,42 +242,41 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=True, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + scale=scale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor): - result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + + def pattern(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher( + input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale - def replacement( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor - ): + def replacement(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized( - self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual, - ) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) # result, residual, scale return at[1], at[3], at[2] @@ -297,7 +284,7 @@ def replacement( inputs = [ empty_bf16(5, 4), # input empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight + empty_bf16(4), # weight ] pm.register_replacement( @@ -320,25 +307,24 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass" - ) + pass_name="rmsnorm_quant_fusion_pass") for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + RMSNormStaticQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + self.patterns) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + RMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + self.patterns) self.dump_patterns(config, self.patterns) @@ -348,11 +334,8 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) def uuid(self) -> Any: - return self.hash_source( - self, - RMSNormQuantPattern, - RMSNormStaticQuantPattern, - RMSNormDynamicQuantPattern, - FusedAddRMSNormStaticQuantPattern, - 
FusedAddRMSNormDynamicQuantPattern, - ) + return self.hash_source(self, RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 1200e236bae4..1b88d2916b0d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -6,21 +6,21 @@ from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - _normalize_quant_group_shape, - kFp8DynamicTensorSym, - kFp8DynamicTokenSym, - kFp8StaticTensorSym, -) + QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, + kFp8DynamicTokenSym, kFp8StaticTensorSym) RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: + torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } # TODO @@ -29,11 +29,18 @@ # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 -class MatcherRMSNorm: - def __init__(self, epsilon: float): +class MatcherRMSNorm: # TODO separate residual and not residual + + def __init__(self, epsilon: float, enabled: Optional[bool] = None): self.epsilon = epsilon - def forward( + if enabled is None: + # TODO either pass config to enabled or set it globally (global during pass init seems reasonable) + enabled = RMSNorm.enabled() + + self.forward = self.forward_custom if enabled else self.forward_native + + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, @@ -51,16 +58,36 @@ def forward( return result else: - _, result, residual = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) + _, result, residual = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) return result, residual + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = input.dtype + x = input.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x + + variance = x.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.epsilon) + x = x.to(orig_dtype) + if weight is not None: + x = x * weight + + return x if residual is None else (x, residual) + + def __call__( self, input: torch.Tensor, @@ -71,46 +98,56 @@ def __call__( class MatcherQuant: + def __init__(self, quant_key: QuantKey): self.quant_key = quant_key - assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" + assert quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] def forward( - self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + self, + input: torch.Tensor, + 
scale: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: why does empty_like produce a permute but # empty via shape doesn't? - result = torch.empty( - input.shape, device=input.device, dtype=self.quant_key.dtype - ) + result = torch.empty(input.shape, + device=input.device, + dtype=self.quant_key.dtype) if self.quant_key.scale.static: assert scale is not None - _, result = auto_functionalized( - self.QUANT_OP, result=result, input=input, scale=scale - ) + _, result = auto_functionalized(self.QUANT_OP, + result=result, + input=input, + scale=scale) return result else: assert scale is None scale = self.make_scale(input) - _, result, scale = auto_functionalized( - self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None - ) + _, result, scale = auto_functionalized(self.QUANT_OP, + result=result, + input=input, + scale=scale, + scale_ub=None) return result, scale def make_scale(self, input: torch.Tensor): normalized_group_shape = _normalize_quant_group_shape( - input, self.quant_key.scale.group_shape - ) + input, self.quant_key.scale.group_shape) scale_shape = ( input.shape[0] // normalized_group_shape[0], input.shape[1] // normalized_group_shape[1], ) - return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + return torch.empty(scale_shape, + device=input.device, + dtype=torch.float32) def __call__( - self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: return self.forward(input, scale) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 3c58832cad4c..976b2e852265 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -189,9 +189,9 @@ def forward_static( x: torch.Tensor, variance_epsilon: float, hidden_size: int, - variance_size_override: Optional[int], weight: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None, + variance_size_override: Optional[int] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype @@ -237,9 +237,9 @@ def forward_native( x, self.variance_epsilon, self.hidden_size, - self.variance_size_override, self.weight.data, residual, + self.variance_size_override, ) def forward_cuda( From 8e4a56f57581e98bfa0e146197cfff860a2f95f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 16 Sep 2025 10:47:13 -0700 Subject: [PATCH 04/81] rms works fully now, had to remove more conversions (and add them in replacements). TODO pass to remove unnecessary conversions? 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- ...fused_layernorm_dynamic_per_token_quant.cu | 4 ++ tests/compile/test_fusion.py | 21 ++++---- vllm/compilation/fusion.py | 53 +++++++++++++------ vllm/compilation/matcher_utils.py | 15 +++--- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 95aa92e25b30..92d6c2f402a2 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant( if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); } + TORCH_CHECK(weight.dtype() == input.dtype()); TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index ac5d9b9c93bf..aea9038a64e3 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -9,8 +9,8 @@ FusedRMSQuantKey, RMSNormQuantFusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, - VllmConfig) +from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig, + PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc) @@ -119,13 +119,16 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, custom_ops.append("+rms_norm") if enable_quant_fp8: custom_ops.append("+quant_fp8") - vllm_config = VllmConfig(compilation_config=CompilationConfig( - debug_dump_path=f"/home/luka/git/vllm/._workspace/" - f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", - level=CompilationLevel.PIECEWISE, - custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - )) + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + debug_dump_path=f"/home/luka/git/vllm/._workspace/" + f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", + level=CompilationLevel.PIECEWISE, + custom_ops=custom_ops, + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ), + ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8e3a1de99898..0efdd7d2d0e4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, @@ -117,6 +117,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, def replacement(input: torch.Tensor, weight: torch.Tensor, 
scale: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized(self.FUSED_OP, result=result, @@ -130,7 +134,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, inputs = [ empty_bf16(5, 4), # input - empty_bf16(4,), # weight + empty_bf16(4, ), # weight empty_fp32(1, 1) # scale ] pattern(*inputs) @@ -163,6 +167,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor, def replacement(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + residual = residual.to(dtype=torch.float16) + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized(self.FUSED_OP, result=result, @@ -176,9 +185,11 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, return at[1], at[2] inputs = [ + # TODO: maybe 32bit for torch impl? + # TODO dtype doesn't seem to matter? empty_bf16(5, 4), # input empty_bf16(5, 4), # residual - empty_bf16(4, ), # weight + empty_bf16(4, ), # weight empty_fp32(1, 1) # scale ] @@ -213,6 +224,10 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): return self.quant_matcher(result_rms) def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=torch.float16) # TODO model dtype + result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) at = auto_functionalized(self.FUSED_OP, @@ -267,6 +282,11 @@ def pattern(input: torch.Tensor, residual: torch.Tensor, def replacement(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. 
+ input = input.to(dtype=torch.float16) # TODO model dtype + residual = residual.to(dtype=torch.float16) + result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) at = auto_functionalized(self.FUSED_OP, @@ -309,22 +329,23 @@ def __init__(self, config: VllmConfig): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rmsnorm_quant_fusion_pass") - for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + with set_current_vllm_config(config, check_compile=False): + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant - FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + # Fuse fused_add_rms_norm + static fp8 quant + FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormDynamicQuantPattern( + epsilon, FP8_DTYPE).register(self.patterns) self.dump_patterns(config, self.patterns) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 1b88d2916b0d..ebb5e26b324c 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -6,6 +6,7 @@ from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm.config import get_current_vllm_config from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, @@ -29,16 +30,18 @@ # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 -class MatcherRMSNorm: # TODO separate residual and not residual +class MatcherRMSNorm: # TODO separate residual and not residual def __init__(self, epsilon: float, enabled: Optional[bool] = None): self.epsilon = epsilon if enabled is None: - # TODO either pass config to enabled or set it globally (global during pass init seems reasonable) + # TODO either pass config to enabled or set it globally + # (global during pass init seems reasonable) enabled = RMSNorm.enabled() self.forward = self.forward_custom if enabled else self.forward_native + self.model_dtype = get_current_vllm_config().model_config.dtype def forward_custom( self, @@ -72,22 +75,20 @@ def forward_native( weight: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = input.dtype - x = input.to(torch.float32) + x = input # .to(torch.float32) if residual is not None: x = x + residual.to(torch.float32) - residual = x + residual = x # conversion to 16-bit is eliminated in full graph variance = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + self.epsilon) - x = x.to(orig_dtype) + x = x.to(self.model_dtype) if weight is not None: x = x * weight return x if residual is None else (x, residual) - def __call__( self, input: torch.Tensor, From 
e151e6d16e1ef6c2c0cddf6ee9fc074f88dc3bd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 16 Sep 2025 11:08:39 -0700 Subject: [PATCH 05/81] quant works except (torch,torch) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 4 ++-- vllm/compilation/matcher_utils.py | 37 +++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 0efdd7d2d0e4..fffe2a6432ec 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -113,7 +113,7 @@ def register(self, pm_pass: PatternMatcherPass): def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) - return self.quant_matcher(result_rms, scale) + return self.quant_matcher(result_rms, scale)[0] def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): @@ -161,7 +161,7 @@ def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result_rms, residual = self.rmsnorm_matcher( input, weight, residual) - result = self.quant_matcher(result_rms, scale) + result, _ = self.quant_matcher(result_rms, scale) return result, residual diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index ebb5e26b324c..51fff7fe0c9e 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -8,6 +8,7 @@ from vllm.config import get_current_vllm_config from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym) @@ -100,17 +101,29 @@ def __call__( class MatcherQuant: - def __init__(self, quant_key: QuantKey): + def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None): + self.quant_key = quant_key assert quant_key in QUANT_OPS, \ f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] - def forward( + assert quant_key.scale2 is None + self.quant_fp8 = QuantFP8(quant_key.scale.static, + quant_key.scale.group_shape) + + if enabled is None: + # TODO either pass config to enabled or set it globally + # (global during pass init seems reasonable) + enabled = self.quant_fp8.enabled() + + self.forward = self.forward_custom if enabled else self.forward_native + + def forward_custom( self, input: torch.Tensor, scale: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor]: # TODO: why does empty_like produce a permute but # empty via shape doesn't? 
result = torch.empty(input.shape, @@ -123,7 +136,7 @@ def forward( result=result, input=input, scale=scale) - return result + return result, scale else: assert scale is None scale = self.make_scale(input) @@ -134,6 +147,13 @@ def forward( scale_ub=None) return result, scale + def forward_native( + self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.quant_fp8(input, scale) + def make_scale(self, input: torch.Tensor): normalized_group_shape = _normalize_quant_group_shape( input, self.quant_key.scale.group_shape) @@ -146,9 +166,8 @@ def make_scale(self, input: torch.Tensor): device=input.device, dtype=torch.float32) - def __call__( - self, - input: torch.Tensor, - scale: Optional[torch.Tensor] = None - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + def __call__(self, + input: torch.Tensor, + scale: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: return self.forward(input, scale) From 14fdc8b9d51ac418c8f09f367c1a81228ee163ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 18 Sep 2025 12:32:27 -0700 Subject: [PATCH 06/81] quant with fix for pure torch, broke others MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 6 ++---- vllm/compilation/fusion.py | 8 ++++---- vllm/compilation/matcher_utils.py | 10 +++++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aea9038a64e3..4a9a497989e8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -147,10 +147,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, model2 = torch.compile(model, backend=backend) result2 = model2(x) - # Higher tol for dynamic, even higher for bfloat16 - if static: - ATOL, RTOL = (1e-3, 1e-3) - elif dtype == torch.float16: + # Higher tol for dynamic bfloat16 + if dtype == torch.float16 or static: ATOL, RTOL = (2e-3, 2e-3) else: ATOL, RTOL = (1e-2, 1e-2) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index fffe2a6432ec..92caf47945ef 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -26,7 +26,7 @@ def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda") def empty_fp32(*args, **kwargs): @@ -133,7 +133,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, return at[1] inputs = [ - empty_bf16(5, 4), # input + empty_fp32(5, 4), # input # TODO: rms_input empty_bf16(4, ), # weight empty_fp32(1, 1) # scale ] @@ -185,8 +185,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, return at[1], at[2] inputs = [ - # TODO: maybe 32bit for torch impl? - # TODO dtype doesn't seem to matter? + # TODO: maybe 32bit for torch impl? yes to resolve bug + # TODO dtype doesn't seem to matter? 
it does matter for what cvts get traced empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(4, ), # weight diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 51fff7fe0c9e..9cde9230211f 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -43,6 +43,10 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None): self.forward = self.forward_custom if enabled else self.forward_native self.model_dtype = get_current_vllm_config().model_config.dtype + print(self.model_dtype) + + def inputs(self): + return def forward_custom( self, @@ -76,10 +80,10 @@ def forward_native( weight: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - x = input # .to(torch.float32) + x = input.to(torch.float32) if residual is not None: - x = x + residual.to(torch.float32) - residual = x # conversion to 16-bit is eliminated in full graph + x = x + residual + residual = x.to(self.model_dtype) variance = x.pow(2).mean(dim=-1, keepdim=True) From 05a65f39a5043315afcad16f9060c21079908ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 18 Sep 2025 13:21:46 -0700 Subject: [PATCH 07/81] ALL WORKS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 47 ++++------- vllm/compilation/matcher_utils.py | 125 ++++++++++++++++++++++-------- 2 files changed, 110 insertions(+), 62 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 92caf47945ef..4e1b569f77e0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -17,7 +17,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherQuant, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -92,7 +92,8 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) if not key.fused_add \ + else MatcherFusedAddRMSNorm(epsilon) self.quant_matcher = MatcherQuant(key.quant) @@ -133,8 +134,8 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, return at[1] inputs = [ - empty_fp32(5, 4), # input # TODO: rms_input - empty_bf16(4, ), # weight + # input, weight + *self.rmsnorm_matcher.inputs(), empty_fp32(1, 1) # scale ] pattern(*inputs) @@ -157,16 +158,16 @@ def __init__(self, def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor): result_rms, residual = self.rmsnorm_matcher( input, weight, residual) result, _ = self.quant_matcher(result_rms, scale) return result, residual - def replacement(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. 
input = input.to(dtype=torch.float16) # TODO model dtype @@ -185,11 +186,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, return at[1], at[2] inputs = [ - # TODO: maybe 32bit for torch impl? yes to resolve bug - # TODO dtype doesn't seem to matter? it does matter for what cvts get traced - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(4, ), # weight + # input, weight, residual + *self.rmsnorm_matcher.inputs(), empty_fp32(1, 1) # scale ] @@ -242,15 +240,10 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): # result, scale return at[1], at[2] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(4), # weight - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) @@ -272,16 +265,16 @@ def __init__(self, def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor): result_rms, residual = self.rmsnorm_matcher( input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale - def replacement(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype @@ -301,16 +294,10 @@ def replacement(input: torch.Tensor, residual: torch.Tensor, # result, residual, scale return at[1], at[3], at[2] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(4), # weight - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 9cde9230211f..a72e7396f526 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from abc import ABC, abstractmethod +from typing import Optional import torch from torch._higher_order_ops import auto_functionalized @@ -31,55 +32,71 @@ # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 -class MatcherRMSNorm: # TODO separate residual and not residual +class MatcherCustomOp(ABC): - def __init__(self, epsilon: float, enabled: Optional[bool] = None): - self.epsilon = epsilon + def __init__(self, enabled: bool): + self.model_dtype = get_current_vllm_config().model_config.dtype + + self.enabled = enabled + self.forward = self.forward_custom if enabled else self.forward_native + + @abstractmethod + def forward_custom(self, *args, **kws): + pass + + @abstractmethod + def forward_native(self, *args, **kws): + pass + def __call__(self, *args, **kws): + return self.forward(*args, **kws) + + def empty(self, *args, **kws): + return torch.empty(*args, dtype=self.model_dtype, device="cuda", **kws) + + def empty_f32(self, *args, **kws): + return torch.empty(*args, dtype=torch.float32, device="cuda", **kws) + + +class MatcherRMSNorm(MatcherCustomOp): + + def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: # TODO either pass config to enabled or set it globally # (global during pass init seems reasonable) enabled = 
RMSNorm.enabled() - self.forward = self.forward_custom if enabled else self.forward_native - self.model_dtype = get_current_vllm_config().model_config.dtype - print(self.model_dtype) + super().__init__(enabled) + self.epsilon = epsilon def inputs(self): - return + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16, ) + return [input, weight] def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if residual is None: - result = torch.empty_like(input) - _, result = auto_functionalized( - RMS_OP, - result=result, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - - return result - else: - _, result, residual = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) + ) -> torch.Tensor: + result = torch.empty_like(input) + _, result = auto_functionalized( + RMS_OP, + result=result, + input=input, + weight=weight, + epsilon=self.epsilon, + ) - return result, residual + return result def forward_native( self, input: torch.Tensor, weight: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor: x = input.to(torch.float32) if residual is not None: x = x + residual @@ -94,13 +111,57 @@ def forward_native( return x if residual is None else (x, residual) - def __call__( + +class MatcherFusedAddRMSNorm(MatcherCustomOp): + + def __init__(self, epsilon: float, enabled: Optional[bool] = None): + if enabled is None: + # TODO either pass config to enabled or set it globally + # (global during pass init seems reasonable) + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self): + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16, ) + residual = self.empty(5, 16) + return [input, weight, residual] + + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - return self.forward(input, weight, residual) + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + _, result, residual = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + + return result, residual + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = input.to(torch.float32) + if residual is not None: + x = x + residual + residual = x.to(self.model_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.epsilon) + x = x.to(self.model_dtype) + if weight is not None: + x = x * weight + + return x if residual is None else (x, residual) class MatcherQuant: From e6b394e28a10cd19e2db7af52686591bc30599a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 19 Sep 2025 19:00:27 -0700 Subject: [PATCH 08/81] Add TODO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a9a497989e8..edda51e2844a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -64,7 +64,8 @@ 
def forward(self, x): input_scale=self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - + # TODO another fp8 linear + rmsnorm to make sure fusion + # works for residual output as well x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1], From d96913a7987a6747eef7cfddc05c45c1220312b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:06:25 -0400 Subject: [PATCH 09/81] Cleanup test_fusion.py, added extra layer of rms/quant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 52 +++++++++++------------------------- 1 file changed, 15 insertions(+), 37 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index edda51e2844a..3b494fce3bae 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,8 +5,7 @@ import torch import vllm.plugins -from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, RMS_OP, - FusedRMSQuantKey, RMSNormQuantFusionPass) +from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig, @@ -30,18 +29,18 @@ def __init__(self, hidden_size: int, eps: float, static: bool, cuda_force_torch: bool, *args, **kwargs): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch - self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN quant_scale = ScaleDesc(torch.float32, static, group_shape) self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: - self.scale = [None for _ in range(2)] + self.scale = [None for _ in range(3)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(2) + for _ in range(3) ] with override_cutlass_fp8_supported(not cuda_force_torch): @@ -64,34 +63,21 @@ def forward(self, x): input_scale=self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - # TODO another fp8 linear + rmsnorm to make sure fusion - # works for residual output as well + x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1], input_scale=self.scale[1]) - y3, resid = self.norm[2](x3, resid) # use resid here - return y3 - def ops_in_model_before(self): - ops = [] - if self.enable_rms_norm: - ops += [RMS_OP] - else: - ops += [torch.ops.aten.rsqrt.default] - - if self.enable_quant_fp8: - ops += [QUANT_OPS[self.key]] - else: - ops += [torch.ops.aten.reciprocal.default] + y3, resid = self.norm[2](x3, resid) # use resid here - return ops + x4 = self.fp8_linear.apply(y3, + self.w[2], + self.wscale[2], + input_scale=self.scale[2]) - def ops_in_model_after(self): - return [ - FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)] - ] + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 @pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) @@ -123,8 +109,6 @@ 
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( - debug_dump_path=f"/home/luka/git/vllm/._workspace/" - f"debug_dump_{enable_rms_norm}_{enable_quant_fp8}", level=CompilationLevel.PIECEWISE, custom_ops=custom_ops, pass_config=PassConfig(enable_fusion=True, enable_noop=True), @@ -156,10 +140,4 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) - assert fusion_pass.matched_count == 2 - - # In pre-nodes, fp8 quant should be there and fused kernels should not - # backend.check_before_ops(model.ops_in_model_before()) - - # In post-nodes, fused kernels should be there and fp8 quant should not - # backend.check_after_ops(model.ops_in_model_after()) + assert fusion_pass.matched_count == 3 From b1727475027f129549da400e41e46bb4e4e045a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 15:02:33 -0700 Subject: [PATCH 10/81] Functionalize attn+quant patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 6 +- tests/compile/test_fusion.py | 77 ++++++--- vllm/compilation/fusion.py | 275 ++++++++++++++++-------------- vllm/compilation/fusion_attn.py | 54 ++++-- vllm/compilation/matcher_utils.py | 83 +++++---- 5 files changed, 281 insertions(+), 214 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 113906af0203..fb92fd7b42a5 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -56,8 +56,10 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.inductor_config["post_grad_custom_post_pass"] = self.post_pass if compile_config.debug_dump_path: - self.debug_dump_path = (Path(compile_config.debug_dump_path) / - f"rank_{vllm_config.parallel_config.rank}") + self.debug_dump_path = ( + Path(compile_config.debug_dump_path) + / f"rank_{vllm_config.parallel_config.rank}" + ) self.ctx = depyf.prepare_debug(str(self.debug_dump_path)) self.ctx.__enter__() else: diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 3b494fce3bae..13cffbe087c6 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -8,13 +8,24 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig, - PassConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc) + GroupShape, + QuantKey, + ScaleDesc, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) + Fp8LinearOp, + cutlass_fp8_supported, + maybe_create_device_identity, +) from vllm.platforms import current_platform from ..utils import override_cutlass_fp8_supported @@ -24,9 +35,15 @@ class TestModel(torch.nn.Module): - - def __init__(self, hidden_size: int, eps: float, static: bool, - cuda_force_torch: bool, *args, **kwargs): + def __init__( + self, + hidden_size: int, + eps: float, + static: 
bool, + cuda_force_torch: bool, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] @@ -57,30 +74,27 @@ def forward(self, x): x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply(y, - self.w[0], - self.wscale[0], - input_scale=self.scale[0]) + x2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply(y2, - self.w[1], - self.wscale[1], - input_scale=self.scale[1]) + x3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear.apply(y3, - self.w[2], - self.wscale[2], - input_scale=self.scale[2]) + x4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) y4, resid = self.norm[3](x4, resid) # use resid here return y4 -@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @@ -89,13 +103,22 @@ def forward(self, x): @pytest.mark.parametrize("enable_quant_fp8", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. -@pytest.mark.parametrize("cuda_force_torch", - [True, False] if cutlass_fp8_supported() else [True]) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test on CUDA and ROCm") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - enable_rms_norm, enable_quant_fp8, - cuda_force_torch): +@pytest.mark.parametrize( + "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" +) +def test_fusion_rmsnorm_quant( + dtype, + hidden_size, + num_tokens, + eps, + static, + enable_rms_norm, + enable_quant_fp8, + cuda_force_torch, +): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 4e1b569f77e0..742e5355d1cf 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,8 +12,15 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, - kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) + GroupShape, + QuantKey, + ScaleDesc, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode @@ -41,12 +48,9 @@ def empty_i32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: - torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: 
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default @@ -58,77 +62,82 @@ class FusedRMSQuantKey(NamedTuple): quant: type of quantization fused_add: does the op also perform the residual add """ + quant: QuantKey fused_add: bool def __str__(self): - return (f"FusedQuantKey({self.quant}, with" - f"{'' if self.fused_add else 'out'} residual)") + return ( + f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)" + ) FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey(kFp8StaticTensorSym, False): - torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8StaticTensorSym, True): - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, False): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, True): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, False + ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, True + ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, False + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, True + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, \ - f"unsupported quantization scheme {key.quant}" + assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, \ - f"unsupported fused rmsnorm+quant op for {key}" + assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) if not key.fused_add \ + self.rmsnorm_matcher = ( + MatcherRMSNorm(epsilon) + if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) + ) self.quant_matcher = MatcherQuant(key.quant) class RMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + fused_key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern(input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) return self.quant_matcher(result_rms, scale)[0] - def replacement(input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): # In case we're matching 
native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result return at[1] @@ -136,51 +145,56 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, inputs = [ # input, weight *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pattern(*inputs) - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, - pm_pass) + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor, scale: torch.Tensor): - result_rms, residual = self.rmsnorm_matcher( - input, weight, residual) + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor, + ): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, _ = self.quant_matcher(result_rms, scale) return result, residual - def replacement(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor, scale: torch.Tensor): + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor, + ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. 
input = input.to(dtype=torch.float16) # TODO model dtype residual = residual.to(dtype=torch.float16) result = torch.empty_like(input, dtype=self.quant_dtype) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result, residual return at[1], at[2] @@ -188,7 +202,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, inputs = [ # input, weight, residual *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -201,21 +215,21 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, class RMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor): result_rms = self.rmsnorm_matcher(input, weight) # result, scale @@ -228,14 +242,16 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + ) # result, scale return at[1], at[2] @@ -250,31 +266,30 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor): - result_rms, residual = self.rmsnorm_matcher( - input, weight, residual) + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale - def replacement(input: torch.Tensor, weight: torch.Tensor, - residual: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: 
torch.Tensor + ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=torch.float16) # TODO model dtype @@ -282,14 +297,16 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + ) # result, residual, scale return at[1], at[3], at[2] @@ -314,25 +331,26 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass") + pass_name="rmsnorm_quant_fusion_pass" + ) with set_current_vllm_config(config, check_compile=False): for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + self.patterns + ) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern( - epsilon, FP8_DTYPE).register(self.patterns) + FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) self.dump_patterns(config, self.patterns) @@ -342,8 +360,11 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) def uuid(self) -> Any: - return self.hash_source(self, RMSNormQuantPattern, - RMSNormStaticQuantPattern, - RMSNormDynamicQuantPattern, - FusedAddRMSNormStaticQuantPattern, - FusedAddRMSNormDynamicQuantPattern) + return self.hash_source( + self, + RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern, + ) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index ae36cef92653..6933442552aa 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from typing import Callable import torch import torch._inductor.pattern_matcher as pm +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -20,7 +22,9 @@ from vllm.utils import round_up from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .fx_utils import is_func from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuant from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -66,9 +70,13 @@ def empty_quant(self, *args, **kwargs): return torch.empty(*args, **kwargs) @staticmethod - def wrap_trace_fn(process_fx, trace_fn): + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): def wrapped(*args, 
**kwargs): - return process_fx(trace_fn(*args, **kwargs)) + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm return wrapped @@ -77,7 +85,20 @@ def fx_view_to_reshape(gm: torch.fx.GraphModule): from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(gm) - return gm + + @staticmethod + def remove_noop_permutes(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if not is_func(node, torch.ops.aten.permute.default): + continue + + dims = node.args[1] + if any(dim != i for i, dim in enumerate(dims)): + continue + + # this is now an identity op, remove + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) def register_if_supported(self, pm_pass: PatternMatcherPass): if self.layer.impl.fused_output_quant_supported(self.quant_key): @@ -108,6 +129,7 @@ def __init__( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) + self.quant_matcher = MatcherQuant(quant_key) def _register(self, pm_pass: PatternMatcherPass): def pattern( @@ -115,7 +137,6 @@ def pattern( k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): at1 = auto_functionalized( @@ -131,6 +152,11 @@ def pattern( attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] ) + output_quant = torch.empty( + attn_out_view.size(), + device=attn_out_view.device, + dtype=self.quant_dtype, + ) at2 = auto_functionalized( self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale ) @@ -141,7 +167,6 @@ def replacement( k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): # attn output in quant_dtype @@ -164,13 +189,10 @@ def replacement( return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v - self.empty( - 5, self.num_heads, self.head_size, dtype=self.dtype - ), # attn_output - self.empty_quant(5, self.num_heads * self.head_size), # quant_output + self.empty(5, self.num_heads, self.head_size), # q + self.empty(5, self.num_heads, self.head_size), # k + self.empty(5, self.num_heads, self.head_size), # v + self.empty(5, self.num_heads, self.head_size), # attn_output empty_fp32(1, 1), # scale ] @@ -179,7 +201,9 @@ def replacement( replacement, inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) @@ -279,7 +303,9 @@ def replacement( replacement, inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index a72e7396f526..d3603372d69f 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -11,19 +11,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, _normalize_quant_group_shape, kFp8DynamicTensorSym, - 
kFp8DynamicTokenSym, kFp8StaticTensorSym) + QuantKey, + _normalize_quant_group_shape, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: - torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } # TODO @@ -33,7 +34,6 @@ class MatcherCustomOp(ABC): - def __init__(self, enabled: bool): self.model_dtype = get_current_vllm_config().model_config.dtype @@ -59,7 +59,6 @@ def empty_f32(self, *args, **kws): class MatcherRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: # TODO either pass config to enabled or set it globally @@ -71,7 +70,9 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty(16, ) + weight = self.empty( + 16, + ) return [input, weight] def forward_custom( @@ -113,7 +114,6 @@ def forward_native( class MatcherFusedAddRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: # TODO either pass config to enabled or set it globally @@ -125,7 +125,9 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty(16, ) + weight = self.empty( + 16, + ) residual = self.empty(5, 16) return [input, weight, residual] @@ -135,11 +137,13 @@ def forward_custom( weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - _, result, residual = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) + _, result, residual = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) return result, residual @@ -165,17 +169,13 @@ def forward_native( class MatcherQuant: - def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None): - self.quant_key = quant_key - assert quant_key in QUANT_OPS, \ - f"unsupported quantization scheme {quant_key}" + assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] assert quant_key.scale2 is None - self.quant_fp8 = QuantFP8(quant_key.scale.static, - quant_key.scale.group_shape) + self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) if enabled is None: # TODO either pass config to enabled or set it globally @@ -191,25 +191,22 @@ def forward_custom( ) -> tuple[torch.Tensor, torch.Tensor]: # TODO: why does empty_like produce a permute but # empty via shape doesn't? 
- result = torch.empty(input.shape, - device=input.device, - dtype=self.quant_key.dtype) + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_key.dtype + ) if self.quant_key.scale.static: assert scale is not None - _, result = auto_functionalized(self.QUANT_OP, - result=result, - input=input, - scale=scale) + _, result = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale + ) return result, scale else: assert scale is None scale = self.make_scale(input) - _, result, scale = auto_functionalized(self.QUANT_OP, - result=result, - input=input, - scale=scale, - scale_ub=None) + _, result, scale = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None + ) return result, scale def forward_native( @@ -221,18 +218,16 @@ def forward_native( def make_scale(self, input: torch.Tensor): normalized_group_shape = _normalize_quant_group_shape( - input, self.quant_key.scale.group_shape) + input, self.quant_key.scale.group_shape + ) scale_shape = ( input.shape[0] // normalized_group_shape[0], input.shape[1] // normalized_group_shape[1], ) - return torch.empty(scale_shape, - device=input.device, - dtype=torch.float32) + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) - def __call__(self, - input: torch.Tensor, - scale: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: + def __call__( + self, input: torch.Tensor, scale: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: return self.forward(input, scale) From 1ae80c6fff346994a199a358a6a89821e3890ce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:02:21 -0700 Subject: [PATCH 11/81] Move global vllm_config to pass manager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/fusion.py | 29 ++++++++++++------------- vllm/compilation/pass_manager.py | 37 +++++++++++++++++--------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 742e5355d1cf..883743b635a8 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -334,23 +334,22 @@ def __init__(self, config: VllmConfig): pass_name="rmsnorm_quant_fusion_pass" ) - with set_current_vllm_config(config, check_compile=False): - for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant - FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + # Fuse fused_add_rms_norm + static fp8 quant + FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse 
fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns - ) + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) self.dump_patterns(config, self.patterns) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e323fa1f7734..3d7c6287fe07 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,7 +5,7 @@ from torch import fx as fx from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import set_env_var @@ -86,27 +86,30 @@ def __call__(self, graph: fx.Graph): def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - if self.pass_config.enable_noop: - self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_sequence_parallelism: - self.passes += [SequenceParallelismPass(config)] - if self.pass_config.enable_async_tp: - self.passes += [AsyncTPPass(config)] + # Set the current vllm config to allow tracing CustomOp instances + with set_current_vllm_config(config, check_compile=False): + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] - if self.pass_config.enable_fusion: - self.passes += [RMSNormQuantFusionPass(config)] - self.passes += [ActivationQuantFusionPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] - if self.pass_config.enable_attn_fusion: - self.passes += [AttnFusionPass(config)] + if self.pass_config.enable_fusion: + self.passes += [RMSNormQuantFusionPass(config)] + self.passes += [ActivationQuantFusionPass(config)] - # needs a functional graph - self.post_cleanup = PostCleanupPass(config) - self.fix_functionalization = FixFunctionalizationPass(config) + if self.pass_config.enable_attn_fusion: + self.passes += [AttnFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) + self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) From 77835fd36531b2b88f591182dee4e61e9cd9639e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:12:11 -0700 Subject: [PATCH 12/81] Attention fusion works with custom ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 80 ++++++++++++++++++------------- vllm/compilation/fusion_attn.py | 11 +---- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 0f2e3bffbd31..5b6b7dcfe8f1 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.compilation.fusion import 
QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -242,26 +241,49 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) +MODELS_FP8 = [] +MODELS_FP4 = [] +HEADS = [] +SPLIT_ATTENTION = [] +BACKENDS: list[_Backend] = [] + if current_platform.is_cuda(): - MODELS = [ + MODELS_FP8 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", TestAttentionFp8StaticQuantPatternModel, - ), - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel, - ), + ) ] HEADS = [(64, 8), (40, 8)] + SPLIT_ATTENTION = [False] + BACKENDS = [] # TODO [_Backend.TRITON_ATTN] + + if current_platform.is_device_capability((10, 0)): + BACKENDS += [_Backend.FLASHINFER] + MODELS_FP4 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel, + ) + ] + elif current_platform.is_rocm(): - MODELS = [ + MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] HEADS = [(32, 8), (40, 8)] + SPLIT_ATTENTION = [False, True] + BACKENDS = [ + _Backend.TRITON_ATTN, + _Backend.ROCM_AITER_UNIFIED_ATTN, + _Backend.ROCM_ATTN, + ] + +# TODO(boyuan/luka): test inductor graph partition on rocm +if is_torch_equal_or_newer("2.9.0.dev") and current_platform.is_cuda(): + USE_INDUCTOR_GRAPH_PARTITION = [False, True] else: - MODELS = [] - HEADS = [] + USE_INDUCTOR_GRAPH_PARTITION = [False] @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @@ -270,35 +292,26 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("model_name, model_class", MODELS) @pytest.mark.parametrize( - "backend", - [_Backend.FLASHINFER] - if current_platform.is_cuda() - else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], -) -# TODO(boyuan): test inductor graph partition on rocm -@pytest.mark.parametrize( - "use_inductor_graph_partition", - [False] if current_platform.is_rocm() else [False, True], + "model_name, model_class, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls + [(*model, c) for model in MODELS_FP8 for c in ["+quant_fp8", "-quant_fp8"]] + # quant_fp4 only has the custom impl + + [(*model, c) for model in MODELS_FP4 for c in [""]], ) +@pytest.mark.parametrize("backend", BACKENDS) +@pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif( - current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)), - reason="On CUDA only test on SM100(Blackwell)", -) -@pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" -) def test_attention_quant_pattern( num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, + custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, @@ -308,8 +321,7 @@ def test_attention_quant_pattern( ): """Test AttentionStaticQuantPattern fusion pass""" - if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 
2.9+") + custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") torch.manual_seed(42) @@ -323,7 +335,7 @@ def test_attention_quant_pattern( scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, - custom_ops=["+quant_fp8"], + custom_ops=custom_ops_list, use_inductor_graph_partition=use_inductor_graph_partition, ), cache_config=CacheConfig(cache_dtype="fp8"), @@ -420,12 +432,12 @@ def test_attention_quant_pattern( layer.impl.fused_output_quant_supported(quant_key) for key, layer in vllm_config.compilation_config.static_forward_context.items() ] - if any(attn_fusion_supported): - # Check quantization ops in the graph before and after fusion - test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) + assert sum(attn_fusion_supported) == len(attn_fusion_supported), ( + "All layers should support attention fusion" + ) # access the underlying `AttnFusionPass` on the `LazyInitPass` - assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) + assert attn_pass.pass_.matched_count == 1 # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 6933442552aa..761acb35834b 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -152,15 +152,8 @@ def pattern( attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] ) - output_quant = torch.empty( - attn_out_view.size(), - device=attn_out_view.device, - dtype=self.quant_dtype, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale - ) - return at2[1] + + return self.quant_matcher(attn_out_view, scale)[0] def replacement( q: torch.Tensor, From 1277999c297cf1fcc784ed3a1f698284e9f63cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 16:12:23 -0700 Subject: [PATCH 13/81] Remove V0 attn fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 5b6b7dcfe8f1..c91e162c8e74 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from typing import Optional import pytest import torch._dynamo @@ -39,10 +38,6 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -# globals needed for string-import custom Dynamo backend field -backend: Optional[TestBackend] = None -backend_unfused: Optional[TestBackend] = None - class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" From d843a67c428ae6e1c4397d24377a206b86cbb6cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 17:02:14 -0700 Subject: [PATCH 14/81] Add triton attn test to attn+quant fusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/compile/test_fusion_attn.py 
b/tests/compile/test_fusion_attn.py index c91e162c8e74..4d6cdabf6a90 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import itertools import pytest import torch._dynamo @@ -99,6 +100,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: num_blocks = batch_size * max_blocks backend = self.attn.backend + # TODO use get_kv_cache_stride_order # Create dummy KV cache for the selected backend if backend == _Backend.ROCM_ATTN: # k/v as 1st dimention @@ -240,7 +242,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): MODELS_FP4 = [] HEADS = [] SPLIT_ATTENTION = [] -BACKENDS: list[_Backend] = [] +BACKENDS_FP8: list[_Backend] = [] +BACKENDS_FP4: list[_Backend] = [] if current_platform.is_cuda(): MODELS_FP8 = [ @@ -251,10 +254,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ] HEADS = [(64, 8), (40, 8)] SPLIT_ATTENTION = [False] - BACKENDS = [] # TODO [_Backend.TRITON_ATTN] + BACKENDS_FP8 = [_Backend.TRITON_ATTN] if current_platform.is_device_capability((10, 0)): - BACKENDS += [_Backend.FLASHINFER] + BACKENDS_FP8 += [_Backend.FLASHINFER] + BACKENDS_FP4 += [_Backend.FLASHINFER] MODELS_FP4 += [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", @@ -288,13 +292,12 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "model_name, model_class, custom_ops", + "backend, model, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls - [(*model, c) for model in MODELS_FP8 for c in ["+quant_fp8", "-quant_fp8"]] + list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) # quant_fp4 only has the custom impl - + [(*model, c) for model in MODELS_FP4 for c in [""]], + + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])), ) -@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" @@ -307,8 +310,7 @@ def test_attention_quant_pattern( batch_size: int, dtype: torch.dtype, custom_ops: str, - model_name: str, - model_class: type[AttentionQuantPatternModel], + model: tuple[str, type[AttentionQuantPatternModel]], backend: _Backend, use_inductor_graph_partition: bool, dist_init, @@ -317,6 +319,7 @@ def test_attention_quant_pattern( """Test AttentionStaticQuantPattern fusion pass""" custom_ops_list = custom_ops.split(",") if custom_ops else [] + model_name, model_class = model device = torch.device("cuda:0") torch.manual_seed(42) From cdd1529b0899cb18495ce46b8b27a0e5a0db5719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 25 Sep 2025 17:18:43 -0700 Subject: [PATCH 15/81] Flat product for better test names/visibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 4d6cdabf6a90..7d672bc343b4 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy 
import itertools +from collections.abc import Iterable +from typing import Any import pytest import torch._dynamo @@ -285,6 +287,13 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): USE_INDUCTOR_GRAPH_PARTITION = [False] +def flat_product(*iterables: Iterable[Any]): + """Flatten lists of tuples into cartesian product.""" + for element in itertools.product(*iterables): + normalized = (e if isinstance(e, tuple) else [e] for e in element) + yield list(itertools.chain(*normalized)) + + @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( @@ -292,11 +301,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "backend, model, custom_ops", - # Test attention+quant_fp8 fusion with custom and torch impls - list(itertools.product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) + "backend, model_name, model_class, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) # quant_fp4 only has the custom impl - + list(itertools.product(BACKENDS_FP4, MODELS_FP4, [""])), + + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) @pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( @@ -310,7 +319,8 @@ def test_attention_quant_pattern( batch_size: int, dtype: torch.dtype, custom_ops: str, - model: tuple[str, type[AttentionQuantPatternModel]], + model_name: str, + model_class: type[AttentionQuantPatternModel], backend: _Backend, use_inductor_graph_partition: bool, dist_init, @@ -319,7 +329,6 @@ def test_attention_quant_pattern( """Test AttentionStaticQuantPattern fusion pass""" custom_ops_list = custom_ops.split(",") if custom_ops else [] - model_name, model_class = model device = torch.device("cuda:0") torch.manual_seed(42) From 141a37eb431da104f4173ea7d1c0c3895354020d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 07:41:41 -0700 Subject: [PATCH 16/81] Fix rmsnorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/model_executor/layers/layernorm.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 976b2e852265..7e15efab379b 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,11 +170,9 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) - self.weight = None - if has_weight: - dtype = dtype or torch.get_default_dtype() - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype)) - weight_dtype = self.weight.data.dtype + weight_dtype = dtype or torch.get_default_dtype() + self.has_weight = has_weight + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( @@ -233,11 +231,12 @@ def forward_native( residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" + return self.forward_static( x, self.variance_epsilon, self.hidden_size, - self.weight.data, + self.weight.data if 
self.has_weight else None, residual, self.variance_size_override, ) From c6d6c3ba7f35105ed5809ac553012db7c1677746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 13:20:52 -0700 Subject: [PATCH 17/81] Refactor E2E attn fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 89 ++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 8ccae4cfb9df..dffa221a9f7f 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -24,23 +24,30 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - { - "dtype": torch.float16, - }, - ), ( "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { "dtype": torch.float16, }, ), - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: + if not current_platform.has_device_capability((10, 0)): + # int8 removed on Blackwell + TEST_MODELS.extend( + [ + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + { + "dtype": torch.float16, + }, + ), + ] + ) + # TODO: figure out why this fails. if False and is_quant_method_supported("gguf"): # noqa: SIM223 TEST_MODELS.append( @@ -85,15 +92,14 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): "optimization_level", [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], ) -@pytest.mark.parametrize("model_info", models_list(all=True)) +@pytest.mark.parametrize("model, model_kwargs", models_list(all=True)) @create_new_process_for_each_test() def test_full_graph( monkeypatch: pytest.MonkeyPatch, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], optimization_level: int, ): - model, model_kwargs = model_info - with monkeypatch.context(): print(f"MODEL={model}") @@ -180,40 +186,55 @@ def test_fp8_kv_scale_compile(optimization_level: int): run_model(optimization_level, model, model_kwargs) -def test_inductor_graph_partition_attn_fusion(caplog_vllm): - if not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") +INDUCTOR_GRAPH_PARTITION = ( + [False, True] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] +) + +@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_default_fusion( + custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +): model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} + backend = _Backend.FLASHINFER + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. 
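+ # Keeping the engine in-process (VLLM_ENABLE_V1_MULTIPROCESSING=0 below) is what lets caplog_vllm see the pass logs.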
+ monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common level=CompilationLevel.PIECEWISE, - use_inductor_graph_partition=True, - cudagraph_mode=CUDAGraphMode.PIECEWISE, - custom_ops=["+quant_fp8"], pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, ) - model_kwargs = { - "kv_cache_dtype": "fp8", - "max_model_len": 1024, - } + with ( caplog_vllm.at_level(logging.DEBUG), - global_force_attn_backend_context_manager(_Backend.FLASHINFER), + global_force_attn_backend_context_manager(backend), ): run_model(compilation_config, model, model_kwargs) - try: - assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, ( - caplog_vllm.text - ) - except AssertionError: - # Note: this message is only triggered when the compilation goes - # through the custom pass. Due to multiple layers of cache on - # PyTorch side, the compilation of a graph may be cached such - # that custom pass directly goes through cache. In this case, - # we go through this branch and assert that the pass is not - # triggered. - assert "Fused quantization" not in caplog_vllm.text + assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text def run_model( From 490ac8610d9e8876dd79619b9b8f72332e5694d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 13:24:01 -0700 Subject: [PATCH 18/81] Add TP=2 test (untested) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 57 +++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index dffa221a9f7f..99b072cfd30f 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -18,7 +18,7 @@ from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer -from ..utils import create_new_process_for_each_test +from ..utils import create_new_process_for_each_test, multi_gpu_test def models_list(*, all: bool = True, keywords: list[str] | None = None): @@ -237,6 +237,61 @@ def test_default_fusion( assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_default_fusion_tp2( + custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +): + model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} + backend = _Backend.FLASHINFER + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. 
+ monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + model_kwargs["tensor_parallel_size"] = 2 + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_fi_allreduce_fusion=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with ( + caplog_vllm.at_level(logging.DEBUG), + global_force_attn_backend_context_manager(backend), + ): + run_model(compilation_config, model, model_kwargs) + + assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text + + # TODO fill in correct number + assert "Replaced 5 patterns" in caplog_vllm.text, caplog_vllm.text + + def run_model( compile_config: Union[int, CompilationConfig], model: str, From d0b1b563b4118afe73b3dfb7431359ef82de1830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 26 Sep 2025 15:39:08 -0700 Subject: [PATCH 19/81] improve tests by adding more cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 151 ++++++++++++++++++++++++------ tests/compile/test_fusion_attn.py | 11 +-- tests/utils.py | 9 ++ 3 files changed, 131 insertions(+), 40 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 99b072cfd30f..2f18488424e7 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,9 +3,11 @@ from __future__ import annotations +import itertools import logging import tempfile -from typing import Any, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import pytest import torch @@ -18,7 +20,7 @@ from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer -from ..utils import create_new_process_for_each_test, multi_gpu_test +from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test def models_list(*, all: bool = True, keywords: list[str] | None = None): @@ -103,7 +105,7 @@ def test_full_graph( with monkeypatch.context(): print(f"MODEL={model}") - run_model(optimization_level, model, model_kwargs) + run_model(optimization_level, model, **model_kwargs) # TODO(luka) add other supported compilation config scenarios here @@ -168,7 +170,49 @@ def test_custom_compile_config( model, model_kwargs = model_info print(f"MODEL={model}") - run_model(compilation_config, model, model_kwargs) + run_model(compilation_config, model, **model_kwargs) + + +MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only + +if current_platform.is_cuda(): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"max_model_len": 1024}, + _Backend.TRITON_ATTN, + ) + ] + + if current_platform.is_device_capability((10, 0)): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS_FP4 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + 
_Backend.FLASHINFER, + ) + ] + + MODELS += [ + ( + "meta-llama/Llama-3.1-8B-Instruct", + {"max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + +elif current_platform.is_rocm(): + MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] @pytest.mark.parametrize( @@ -183,23 +227,34 @@ def test_fp8_kv_scale_compile(optimization_level: int): "calculate_kv_scales": True, "max_model_len": 512, } - run_model(optimization_level, model, model_kwargs) + run_model(optimization_level, model, **model_kwargs) INDUCTOR_GRAPH_PARTITION = ( [False, True] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] ) +# TODO(luka) test both in nightly +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] + -@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) + # quant_fp4 only has the custom impl + + list(flat_product(MODELS_FP4, [""])), +) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -def test_default_fusion( - custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +def test_e2e_fusion_attn_quant( + model_name: str, + model_kwargs: dict[str, Any], + backend: _Backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_vllm, + monkeypatch, ): - model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" - model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} - backend = _Backend.FLASHINFER - custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -232,21 +287,47 @@ def test_default_fusion( caplog_vllm.at_level(logging.DEBUG), global_force_attn_backend_context_manager(backend), ): - run_model(compilation_config, model, model_kwargs) + run_model(compilation_config, model_name, **model_kwargs) assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text +# TODO(luka) test both in nightly +CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] + + +def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: + for op_list in itertools.product(*custom_ops_lists): + yield ",".join(op_list) + + @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("custom_ops", ["+quant_fp8", "-quant_fp8"]) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -def test_default_fusion_tp2( - custom_ops: str, inductor_graph_partition: bool, caplog_vllm, monkeypatch +@pytest.mark.skipif( + not current_platform.is_cuda() + or not current_platform.has_device_capability((10, 0)), + reason="allreduce+rmsnorm fusion only supported on blackwell", +) +def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( + model_name, + model_kwargs, + backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_vllm, + monkeypatch, ): - model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" - model_kwargs = {"kv_cache_dtype": "fp8", "max_model_len": 1024} - backend = _Backend.FLASHINFER - custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -262,7 +343,6 @@ def 
test_default_fusion_tp2( monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - model_kwargs["tensor_parallel_size"] = 2 compilation_config = CompilationConfig( # Testing properties use_inductor_graph_partition=inductor_graph_partition, @@ -284,19 +364,25 @@ def test_default_fusion_tp2( caplog_vllm.at_level(logging.DEBUG), global_force_attn_backend_context_manager(backend), ): - run_model(compilation_config, model, model_kwargs) + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text # TODO fill in correct number - assert "Replaced 5 patterns" in caplog_vllm.text, caplog_vllm.text + assert "Replaced 96 patterns" in caplog_vllm.text, caplog_vllm.text def run_model( - compile_config: Union[int, CompilationConfig], - model: str, - model_kwargs: dict[str, Any], + compile_config: Union[int, CompilationConfig], model: str, **model_kwargs ): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(level=compile_config) + ) + prompts = [ "Hello, my name is", "The president of the United States is", @@ -304,12 +390,17 @@ def run_model( "The future of AI is", ] sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + llm = LLM( model=model, - enforce_eager=True, - tensor_parallel_size=1, - disable_custom_all_reduce=True, - compilation_config=compile_config, + compilation_config=compilation_config, **model_kwargs, ) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 7d672bc343b4..b52b573ec7e5 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -import itertools -from collections.abc import Iterable -from typing import Any import pytest import torch._dynamo from tests.compile.backend import LazyInitPass, TestBackend +from tests.utils import flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata @@ -287,13 +285,6 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): USE_INDUCTOR_GRAPH_PARTITION = [False] -def flat_product(*iterables: Iterable[Any]): - """Flatten lists of tuples into cartesian product.""" - for element in itertools.product(*iterables): - normalized = (e if isinstance(e, tuple) else [e] for e in element) - yield list(itertools.chain(*normalized)) - - @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize( diff --git a/tests/utils.py b/tests/utils.py index b853542c241f..16ef6458cf50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,7 @@ import copy import functools import importlib +import itertools import json import os import random @@ -15,6 +16,7 @@ import tempfile import time import warnings +from collections.abc import Iterable from contextlib 
import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path @@ -1260,3 +1262,10 @@ def check_answers( frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") assert frac_ok >= accept_rate + + +def flat_product(*iterables: Iterable[Any]): + """Flatten lists of tuples into cartesian product.""" + for element in itertools.product(*iterables): + normalized = (e if isinstance(e, tuple) else [e] for e in element) + yield list(itertools.chain(*normalized)) From 47b4688d1cdad8701e2aad381fd2568fc1bce78e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 27 Sep 2025 07:38:52 -0700 Subject: [PATCH 20/81] TEMP working on caplog MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 8 ++++++-- tests/conftest.py | 22 ++++++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2f18488424e7..b282e234572f 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -253,6 +253,7 @@ def test_e2e_fusion_attn_quant( custom_ops: str, inductor_graph_partition: bool, caplog_vllm, + caplog_mp_workaround, monkeypatch, ): custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -268,7 +269,7 @@ def test_e2e_fusion_attn_quant( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = CompilationConfig( # Testing properties @@ -285,6 +286,7 @@ def test_e2e_fusion_attn_quant( with ( caplog_vllm.at_level(logging.DEBUG), + caplog_mp_workaround(), global_force_attn_backend_context_manager(backend), ): run_model(compilation_config, model_name, **model_kwargs) @@ -319,6 +321,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: or not current_platform.has_device_capability((10, 0)), reason="allreduce+rmsnorm fusion only supported on blackwell", ) +@pytest.mark.skip(reason="Still no solution for capturing logs from subprocess") def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( model_name, model_kwargs, @@ -341,7 +344,8 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + # TODO + # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = CompilationConfig( # Testing properties diff --git a/tests/conftest.py b/tests/conftest.py index 4713e1238596..b2fa96f48e8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# ruff: noqa +import contextlib from tblib import pickling_support +# ruff: noqa + # Install support for pickling exceptions so that we can nicely propagate # failures from tests running in a subprocess. # This should be run before any custom exception subclasses are defined. 
@@ -1067,6 +1068,23 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +@pytest.fixture() +def caplog_mp_workaround(): + @contextlib.contextmanager + def ctx(): + import logging.handlers + import multiprocessing as mp + + logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() + logger = logging.getLogger() + logger.addHandler(logging.handlers.QueueHandler(logger_queue)) + yield + while not logger_queue.empty(): + logger.handle(logger_queue.get()) + + return ctx + + @pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context From ae7f56f042876122127b04e199162b985334b747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 30 Sep 2025 12:50:28 -0700 Subject: [PATCH 21/81] Temp MP workaround P2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 89 +++++++++++++++++++++++++++++++++++++++++--- tests/test_logger.py | 17 +++++++++ 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b2fa96f48e8c..bbfe3eeac8f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import pathlib +from copy import deepcopy from tblib import pickling_support @@ -41,7 +43,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs -from vllm import LLM, SamplingParams +from vllm import LLM, SamplingParams, envs from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -1068,8 +1070,25 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog -@pytest.fixture() -def caplog_mp_workaround(): +@pytest.fixture(scope="session") +def caplog_mp_fork(): + """ + This fixture enables capturing logs from a forked MP subprocess. + It should be used in conjunction with caplog_vllm. + + By default, subprocess logs do not go through the parent process. + We instead create a queue listener in the parent process which + forwards logs to the logger's other handlers, and add a QueueHandler + to the root logger. Forked subprocesses will inherit the root logger + and pass their messages to the queue, which the listener will forward + to the root logger, which can be captured by caplog. + + Note that this workaround only works for fork; with spawn, the subprocess + reinitializes logging and does not automatically inherit the queue. + We'd have to manually pass the queue to the subprocess at the spawn point. + See caplog_mp_spawn below. + """ + @contextlib.contextmanager def ctx(): import logging.handlers @@ -1077,10 +1096,70 @@ def ctx(): logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() logger = logging.getLogger() + handlers = logger.handlers + + # The listener works on a background thread, not inherited by the child. 
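+ # The QueueListener re-emits every queued record through the original handlers, so caplog still receives the child's logs.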
+ queue_listener = logging.handlers.QueueListener(logger_queue, *handlers) + queue_listener.start() + + # Add queue handler after creating the listener to avoid cycle logger.addHandler(logging.handlers.QueueHandler(logger_queue)) yield - while not logger_queue.empty(): - logger.handle(logger_queue.get()) + queue_listener.stop() + + return ctx + + +class LogHolder: + def __init__(self): + self.text = None + + +@pytest.fixture(scope="session") +def caplog_mp_spawn(tmp_path, monkeypatch): + """ + This fixture enables capturing logs from a forked MP subprocess. + It does not require caplog_vllm (but it only contains log + + By default, subprocess logs do not go through the parent process. + We instead add a FileHandler to the config so the spawned child process + writes its logs to a temp file and then return the contents. + + Note: this method could be extended to fork by either reconfiguring logging + in the parent or using a SocketHandler: + https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501 + """ + + @contextlib.contextmanager + def ctx(level: int | str): + from vllm.logger import DEFAULT_LOGGING_CONFIG + + config_path = tmp_path / "vllm_logging_config.json" + log_path = tmp_path / "vllm.log" + log_holder = LogHolder() + + config = deepcopy(DEFAULT_LOGGING_CONFIG) + if envs.VLLM_LOGGING_CONFIG_PATH: + path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH) + assert path.exists() + config = json.loads(path.read_text()) + + config["loggers"]["vllm"]["handlers"] += ["vllm_file"] + config["handlers"]["vllm_file"] = { + "class": "logging.FileHandler", + "formatter": "vllm", + "level": level, + "filename": log_path.as_posix(), + } + + config_path.write_text(json.dumps(config)) + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix()) + monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1") + yield log_holder + + log_holder.text = log_path.read_text() return ctx diff --git a/tests/test_logger.py b/tests/test_logger.py index ec368d4897b5..af006f1456b8 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -501,3 +501,20 @@ def test_streaming_complete_logs_full_text_content(): assert call_args[1] == "test-streaming-full-text" assert call_args[2] == " (streaming complete)" assert call_args[5] == "streaming_complete" + + +test_logger = init_logger("vllm.test_logger") +# https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network + + +def mp_function(**kwargs): + # This function runs in a subprocess + + test_logger.warning("This is a subprocess: %s", kwargs.get("a")) + test_logger.error("This is a subprocess error.") + test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b")) + + +def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): + pass + # TODO From eb899a4d34d2b7ca6d77b36decbd46b3eb873e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 30 Sep 2025 12:55:33 -0700 Subject: [PATCH 22/81] Temp MP workaround P3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 3 ++- tests/test_logger.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bbfe3eeac8f7..df34924a6f70 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1123,7 +1123,8 @@ def caplog_mp_spawn(tmp_path, monkeypatch): By default, subprocess logs do not 
go through the parent process. We instead add a FileHandler to the config so the spawned child process - writes its logs to a temp file and then return the contents. + writes its logs to a temp file. + In the parent, we read the file and return the contents. Note: this method could be extended to fork by either reconfiguring logging in the parent or using a SocketHandler: diff --git a/tests/test_logger.py b/tests/test_logger.py index af006f1456b8..22e084991343 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -504,7 +504,6 @@ def test_streaming_complete_logs_full_text_content(): test_logger = init_logger("vllm.test_logger") -# https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network def mp_function(**kwargs): From a2aa9787df6ed5be6fd9f6010e0e330702052036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 1 Oct 2025 11:21:02 -0700 Subject: [PATCH 23/81] Test for caplog utils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 4 ++-- tests/test_logger.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index df34924a6f70..7a907b2ac79f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1070,7 +1070,7 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog -@pytest.fixture(scope="session") +@pytest.fixture() def caplog_mp_fork(): """ This fixture enables capturing logs from a forked MP subprocess. @@ -1115,7 +1115,7 @@ def __init__(self): self.text = None -@pytest.fixture(scope="session") +@pytest.fixture() def caplog_mp_spawn(tmp_path, monkeypatch): """ This fixture enables capturing logs from a forked MP subprocess. 
diff --git a/tests/test_logger.py b/tests/test_logger.py index 22e084991343..f1c31c245475 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -515,5 +515,34 @@ def mp_function(**kwargs): def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): - pass - # TODO + with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork(): + import multiprocessing + + ctx = multiprocessing.get_context("fork") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in caplog_vllm.text + assert "BBBBB" in caplog_vllm.text + + +def test_caplog_mp_spawn(caplog_mp_spawn): + with caplog_mp_spawn(logging.DEBUG) as log_holder: + import multiprocessing + + ctx = multiprocessing.get_context("spawn") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in log_holder.text + assert "BBBBB" in log_holder.text From 21a9f9f42b21f46190838a0336438fe5c091e728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 1 Oct 2025 19:02:24 -0700 Subject: [PATCH 24/81] Fixed tests, passing with 2.8, 2.9 tbd MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 48 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index b282e234572f..b6f7aba6821c 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -10,12 +10,12 @@ from typing import Any, Optional, Union import pytest +import regex as re import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer @@ -235,7 +235,8 @@ def test_fp8_kv_scale_compile(optimization_level: int): ) # TODO(luka) test both in nightly -CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] +# TODO(luka) change to - +CUSTOM_OPS_FP8 = ["+quant_fp8"] # , "+quant_fp8"] @pytest.mark.parametrize( @@ -252,8 +253,7 @@ def test_e2e_fusion_attn_quant( backend: _Backend, custom_ops: str, inductor_graph_partition: bool, - caplog_vllm, - caplog_mp_workaround, + caplog_mp_spawn, monkeypatch, ): custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -269,7 +269,11 @@ def test_e2e_fusion_attn_quant( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. 
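+ # caplog_mp_spawn (used below) relies on the spawned child re-reading VLLM_LOGGING_CONFIG_PATH, which a forked child would not do.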
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) compilation_config = CompilationConfig( # Testing properties @@ -284,18 +288,15 @@ def test_e2e_fusion_attn_quant( inductor_compile_config={"force_disable_caches": True}, ) - with ( - caplog_vllm.at_level(logging.DEBUG), - caplog_mp_workaround(), - global_force_attn_backend_context_manager(backend), - ): + with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text # TODO(luka) test both in nightly -CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] +# TODO(luka) change to - +CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @@ -321,14 +322,13 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: or not current_platform.has_device_capability((10, 0)), reason="allreduce+rmsnorm fusion only supported on blackwell", ) -@pytest.mark.skip(reason="Still no solution for capturing logs from subprocess") def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( model_name, model_kwargs, backend, custom_ops: str, inductor_graph_partition: bool, - caplog_vllm, + caplog_mp_spawn, monkeypatch, ): custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -344,8 +344,11 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( # Otherwise, we can't verify fusion happened through the logs. # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - # TODO - # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. 
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) compilation_config = CompilationConfig( # Testing properties @@ -364,18 +367,17 @@ def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( inductor_compile_config={"force_disable_caches": True}, ) - with ( - caplog_vllm.at_level(logging.DEBUG), - global_force_attn_backend_context_manager(backend), - ): + with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - assert "Fused quant onto 48 attention nodes" in caplog_vllm.text, caplog_vllm.text + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text - # TODO fill in correct number - assert "Replaced 96 patterns" in caplog_vllm.text, caplog_vllm.text + matches = re.findall( + r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text + ) + assert len(matches) == 2, log_holder.text def run_model( From 66a35a90b724f53037395bc96fcef87ea8b3b172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 19:26:42 -0400 Subject: [PATCH 25/81] Update tests/compile/backend.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index fb92fd7b42a5..ac62040287d2 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -4,7 +4,6 @@ import weakref from collections.abc import Sequence from copy import deepcopy -from pathlib import Path from typing import Callable, Union import depyf @@ -55,12 +54,8 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass - if compile_config.debug_dump_path: - self.debug_dump_path = ( - Path(compile_config.debug_dump_path) - / f"rank_{vllm_config.parallel_config.rank}" - ) - self.ctx = depyf.prepare_debug(str(self.debug_dump_path)) + if debug_dump_path := vllm_config.compile_debug_dump_path(): + self.ctx = depyf.prepare_debug(debug_dump_path.as_posix()) self.ctx.__enter__() else: self.ctx = None From 7eb1364457d3ebc65d0a156a4a48ccef122bba5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 19:26:48 -0400 Subject: [PATCH 26/81] Update csrc/layernorm_kernels.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- csrc/layernorm_kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index b738cdbbdc53..b037531cceb5 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -380,6 +380,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); TORCH_CHECK(input.scalar_type() == residual.scalar_type()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); From 5fef1804edebe5d0bb441f0e7620f691164707d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 16:35:31 -0700 Subject: [PATCH 27/81] clean up fullgraph tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 43 ++++++++++++++------------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index b6f7aba6821c..d5d22844a223 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -173,6 +173,21 @@ def test_custom_compile_config( run_model(compilation_config, model, **model_kwargs) +@pytest.mark.parametrize( + "optimization_level", + [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], +) +def test_fp8_kv_scale_compile(optimization_level: int): + model = "Qwen/Qwen2-0.5B" + model_kwargs = { + "quantization": "fp8", + "kv_cache_dtype": "fp8_e4m3", + "calculate_kv_scales": True, + "max_model_len": 512, + } + run_model(optimization_level, model, **model_kwargs) + + MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only @@ -214,29 +229,12 @@ def test_custom_compile_config( elif current_platform.is_rocm(): MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] - -@pytest.mark.parametrize( - "optimization_level", - [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], -) -def test_fp8_kv_scale_compile(optimization_level: int): - model = "Qwen/Qwen2-0.5B" - model_kwargs = { - "quantization": "fp8", - "kv_cache_dtype": "fp8_e4m3", - "calculate_kv_scales": True, - "max_model_len": 512, - } - run_model(optimization_level, model, **model_kwargs) - - INDUCTOR_GRAPH_PARTITION = ( - [False, True] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] + [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] ) # TODO(luka) test both in nightly -# TODO(luka) change to - -CUSTOM_OPS_FP8 = ["+quant_fp8"] # , "+quant_fp8"] +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] @pytest.mark.parametrize( @@ -308,11 +306,8 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @pytest.mark.parametrize( "model_name, model_kwargs, backend, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models - list( - flat_product( - MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) - ) - ) + list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) + # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) From db479ae069e833e0c48186bbbb40ef7173c4485d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 16:51:30 -0700 Subject: [PATCH 28/81] TEMP allreduce fusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 37 +++---- vllm/compilation/collective_fusion.py | 129 +++++++++--------------- 2 files changed, 66 insertions(+), 100 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7e5c460db174..7d63a380d72c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -17,6 +17,7 @@ ModelConfig, PassConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -233,24 +234,26 @@ def all_reduce_fusion_pass_on_test_model( vllm_config.model_config 
= ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) + with set_current_vllm_config(vllm_config): + all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend( + all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass + ) - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - noop_pass = NoOpEliminationPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass) - - token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) - hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) + hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) + residual = torch.randn((token_num, hidden_size), requires_grad=False) - compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states, residual) - assert all_reduce_fusion_pass.matched_count == 1 - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) - backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + assert all_reduce_fusion_pass.matched_count == 1 + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + backend.check_after_ops(model.ops_in_model_after()) + del all_reduce_fusion_pass diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 988a1069cd9e..b41655ffd130 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -18,10 +18,14 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -646,6 +650,19 @@ def get_trtllm_fused_allreduce_kwargs(self): } +class BaseAllReduceRMSNormPattern(BasePattern): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + class AllReduceRMSNormPattern(BasePattern): """ This pattern replaces the allreduce + rms norm (without residual) @@ -663,33 +680,24 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype) - return [input, rms_result, weight] + return [input, weight] def register(self, pm_pass: PatternMatcherPass): - def 
pattern( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + def pattern(input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_OP, - result=rms_result, - input=allreduce_output, - weight=weight, - epsilon=self.epsilon, - ) - # rms_result, allreduce_output - return rms[1], allreduce_output + rms = self.rmsnorm_matcher(allreduce_output, weight) - def replacement( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + return rms, allreduce_output + + def replacement(input: torch.Tensor, weight: torch.Tensor): residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -727,6 +735,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def get_inputs(self): input = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -741,15 +750,8 @@ def get_inputs(self): def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - # input, residual - return rms[1], rms[2] + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + return allreduce_output, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor @@ -793,60 +795,36 @@ def __init__( self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.dtype - ) - quant_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.quant_dtype - ) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + return [input, weight, scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], all_reduce + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce - def replacement( - input: torch.Tensor, - result_rms: torch.Tensor, - quant_result: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( 
flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=result_rms, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, @@ -886,19 +864,18 @@ def __init__( self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) + def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty( - [4, 4], device=self.device, dtype=self.quant_dtype - ) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) return [ - quant_result, residual, input, weight, @@ -906,44 +883,30 @@ def get_inputs(): ] def pattern( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) + rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) + quant, _ = self.quant_matcher(rms, scale) - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + return quant, allreduce_output def replacement( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=None, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, From 54189a9f880335a73c4d5ec8601027c7aeb43bc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 21:24:51 -0400 Subject: [PATCH 29/81] allreduce fusion working (custom ops on) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 8 +++----- vllm/compilation/collective_fusion.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7d63a380d72c..88305c0ed85c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -84,16 +84,13 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.norm = RMSNorm(hidden_size, eps) self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant( - self.output, norm_output.contiguous(), self.scale - ) - return self.output, residual_output + quant_out, _ = self.quant_fp8(norm_output, 
self.scale) + return quant_out, residual_output def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -227,6 +224,7 @@ def all_reduce_fusion_pass_on_test_model( enable_fi_allreduce_fusion=True, enable_noop=True ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + vllm_config.parallel_config.rank = local_rank # Setup rank for debug path # this is a fake model name to construct the model config # in the vllm_config, it's not really used. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b41655ffd130..7d212ef17fb4 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -751,7 +751,7 @@ def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) - return allreduce_output, residual + return rms, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor @@ -892,7 +892,7 @@ def pattern( rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) quant, _ = self.quant_matcher(rms, scale) - return quant, allreduce_output + return quant, res def replacement( residual: torch.Tensor, From b7f52bf2fe31b1215b0dc3c81f801944671989f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 22:12:04 -0400 Subject: [PATCH 30/81] allreduce fusion working with/without custom ops (except fp4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 46 +++++++++++++++++++------ 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 88305c0ed85c..12fa56826840 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -66,8 +66,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) - norm, _ = self.norm(all_reduce, residual) - return norm + norm, res = self.norm(all_reduce, residual) + + return norm, res def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -98,7 +99,9 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default, + torch.ops._C.static_scaled_fp8_quant.default + if self.quant_fp8.enabled() + else torch.ops.aten.reciprocal.default, ] @@ -139,19 +142,21 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model", + "test_model, enable_quant_fp8", [ - TestAllReduceRMSNormModel, - TestAllReduceFusedAddRMSNormModel, - TestAllReduceFusedAddRMSNormStaticQuantFP8Model, + (TestAllReduceRMSNormModel, False), + (TestAllReduceFusedAddRMSNormModel, False), + (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True), + (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False), # TODO: Enable with torch==2.8.0 - # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + # (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) 
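# A minimal, self-contained sketch of how the enable_rms_norm /
# enable_quant_fp8 toggles introduced in this commit map onto the vLLM
# custom_ops override strings used further down ("+" requests the custom CUDA
# op); build_custom_ops is a hypothetical helper name used only for illustration.
def build_custom_ops(enable_rms_norm: bool, enable_quant_fp8: bool) -> list[str]:
    ops: list[str] = []
    if enable_rms_norm:
        ops.append("+rms_norm")
    if enable_quant_fp8:
        ops.append("+quant_fp8")
    return ops

assert build_custom_ops(True, False) == ["+rms_norm"]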
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("enable_rms_norm", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -165,6 +170,8 @@ def test_all_reduce_fusion_pass_replace( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm, + enable_quant_fp8, ): num_processes = 2 if ( @@ -179,7 +186,16 @@ def test_all_reduce_fusion_pass_replace( def run_torch_spawn(fn, nprocs): torch.multiprocessing.spawn( fn, - args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + args=( + num_processes, + test_model, + batch_size, + seq_len, + hidden_size, + dtype, + enable_rms_norm, + enable_quant_fp8, + ), nprocs=nprocs, ) @@ -194,6 +210,8 @@ def all_reduce_fusion_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm, + enable_quant_fp8, ): current_platform.seed_everything(0) @@ -215,9 +233,15 @@ def all_reduce_fusion_pass_on_test_model( init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) + custom_ops = [] + if enable_rms_norm: + custom_ops.append("+rms_norm") + if enable_quant_fp8: + custom_ops.append("+quant_fp8") + vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"] + level=CompilationLevel.PIECEWISE, custom_ops=custom_ops ) ) vllm_config.compilation_config.pass_config = PassConfig( @@ -239,7 +263,7 @@ def all_reduce_fusion_pass_on_test_model( cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend( - all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass + noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass ) token_num = batch_size * seq_len From d09a278fa869c8e625403cca3e578f6806ca5623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 22:16:24 -0400 Subject: [PATCH 31/81] allreduce fusion working with/without custom ops (with fp4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 3 +-- vllm/compilation/collective_fusion.py | 30 ++++++------------------- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 12fa56826840..657ebc4a28a6 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -148,8 +148,7 @@ def ops_in_model_before(self): (TestAllReduceFusedAddRMSNormModel, False), (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True), (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False), - # TODO: Enable with torch==2.8.0 - # (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), + (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @pytest.mark.parametrize("batch_size", [8]) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7d212ef17fb4..d5a3fcde03b6 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -943,6 +943,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -976,18 +977,11 @@ def pattern( output_scale: torch.Tensor, ): 
all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - + rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) @@ -1047,6 +1041,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -1078,28 +1073,17 @@ def pattern( input_global_scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) - - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=fused_add_rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) # quant_out, allreduce_output, output_scale - return ( - quant_out_tuple[1], - fused_add_rmsnorm_out_tuple[2], - quant_out_tuple[2], - ) + return quant_out_tuple[1], residual, quant_out_tuple[2] def replacement( quant_result: torch.Tensor, From c8675ffdbcf5167da76fae8ffa6d3ffdd0e30146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 2 Oct 2025 22:18:24 -0400 Subject: [PATCH 32/81] log depyf folder, fix context for TestBackend, fix pattern dump MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 18 +++++++++++------- vllm/compilation/monitor.py | 1 + vllm/compilation/vllm_inductor_pass.py | 3 ++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ac62040287d2..a16ab9f15c9f 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -3,6 +3,7 @@ import weakref from collections.abc import Sequence +from contextlib import nullcontext from copy import deepcopy from typing import Callable, Union @@ -16,6 +17,9 @@ from vllm.compilation.pass_manager import with_pattern_match_debug from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger("vllm.tests.compile.backend") class LazyInitPass(InductorPass): @@ -55,16 +59,19 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.inductor_config["post_grad_custom_post_pass"] = self.post_pass if debug_dump_path := vllm_config.compile_debug_dump_path(): - self.ctx = depyf.prepare_debug(debug_dump_path.as_posix()) - self.ctx.__enter__() + logger.debug("Dumping depyf output to %s", debug_dump_path) + self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix()) else: - self.ctx = None + self.debug_ctx = nullcontext() def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, example_inputs, config_patches=self.inductor_config) + with self.debug_ctx: + return compile_fx( + graph, example_inputs, config_patches=self.inductor_config + ) 
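# Sketch of the optional-debug-context idiom used above: fall back to a no-op
# context manager when no dump path is configured, so the call site can always
# use a single `with` block. `make_debug_ctx` and `maybe_dump_dir` are
# illustrative names, not part of the patch.
from contextlib import nullcontext

def make_debug_ctx(maybe_dump_dir=None):
    if maybe_dump_dir is None:
        return nullcontext()
    import depyf  # only imported when a dump directory is actually requested
    return depyf.prepare_debug(str(maybe_dump_dir))

with make_debug_ctx():
    pass  # compile_fx(...) runs inside this context in the real backend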
@with_pattern_match_debug def post_pass(self, graph: fx.Graph): @@ -83,9 +90,6 @@ def post_pass(self, graph: fx.Graph): # assign by reference, will reflect the final state of the graph self.final_graph = graph - if self.ctx is not None: - self.ctx.__exit__(None, None, None) - def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index d3c437795fab..f9a189b7c77d 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): import depyf path.mkdir(parents=True, exist_ok=True) + logger.debug("Dumping depyf output to %s", path) global context_manager context_manager = depyf.prepare_debug(path.as_posix()) context_manager.__enter__() diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 5aa08220bc2d..b7b3c98eb4ed 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -115,7 +115,8 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): f" please add to dump_patterns if there are any errors.\n\n" f"from torch._higher_order_ops.auto_functionalize import " f"auto_functionalized as auto_functionalized\n" - f"from torch._inductor.pattern_matcher import *", + f"from torch._inductor.pattern_matcher import *\n" + f"vllm = torch.ops.vllm", file=f, ) From d3f95feda3df0127fb40a33828bd880cee2c9c54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 11:38:39 -0400 Subject: [PATCH 33/81] fullgraph allreduce test update requirements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index d5d22844a223..beaeed30f004 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -19,6 +19,7 @@ from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test @@ -314,8 +315,9 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif( not current_platform.is_cuda() - or not current_platform.has_device_capability((10, 0)), - reason="allreduce+rmsnorm fusion only supported on blackwell", + or not has_flashinfer() + or not current_platform.has_device_capability(90), + reason="allreduce+rmsnorm fusion requires flashinfer", ) def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( model_name, From 4dbfcf7017116e622f7365beb9fc6562076fd53d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 11:49:24 -0400 Subject: [PATCH 34/81] Move e2e tests to new file, add to test pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 13 +- tests/compile/test_full_graph.py | 198 +----------------------- tests/compile/test_fusions_e2e.py | 246 ++++++++++++++++++++++++++++++ 3 files 
changed, 255 insertions(+), 202 deletions(-) create mode 100644 tests/compile/test_fusions_e2e.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ebe0602a1b5d..f734526db130 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -796,8 +796,8 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 38 min - timeout_in_minutes: 60 +- label: Blackwell Test # 48 min + timeout_in_minutes: 70 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -810,8 +810,7 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - - vllm/compilation/fusion.py - - vllm/compilation/fusion_attn.py + - vllm/compilation/ commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -828,6 +827,8 @@ steps: - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py # Fusion @@ -835,8 +836,7 @@ steps: - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py - - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py - - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + - pytest -v -s tests/compile/test_fusions_e2e.py - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 @@ -1109,6 +1109,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index beaeed30f004..402e6499b9d6 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,25 +3,19 @@ from __future__ import annotations -import itertools -import logging import tempfile -from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Union import pytest -import regex as re import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.attention.backends.registry import _Backend from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer -from vllm.utils.flashinfer import has_flashinfer -from ..utils import create_new_process_for_each_test, flat_product, multi_gpu_test +from ..utils import create_new_process_for_each_test def models_list(*, all: bool = True, keywords: list[str] | None = None): @@ -189,194 +183,6 @@ def test_fp8_kv_scale_compile(optimization_level: int): run_model(optimization_level, model, **model_kwargs) -MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS_FP4: 
list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only - -if current_platform.is_cuda(): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"max_model_len": 1024}, - _Backend.TRITON_ATTN, - ) - ] - - if current_platform.is_device_capability((10, 0)): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS_FP4 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS += [ - ( - "meta-llama/Llama-3.1-8B-Instruct", - {"max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - -elif current_platform.is_rocm(): - MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] - -INDUCTOR_GRAPH_PARTITION = ( - [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] -) - -# TODO(luka) test both in nightly -CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] - - -@pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", - # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 - list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) - # quant_fp4 only has the custom impl - + list(flat_product(MODELS_FP4, [""])), -) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -def test_e2e_fusion_attn_quant( - model_name: str, - model_kwargs: dict[str, Any], - backend: _Backend, - custom_ops: str, - inductor_graph_partition: bool, - caplog_mp_spawn, - monkeypatch, -): - custom_ops_list = custom_ops.split(",") if custom_ops else [] - - if inductor_graph_partition: - mode = CUDAGraphMode.FULL_AND_PIECEWISE - splitting_ops: Optional[list[str]] = None - else: - mode = CUDAGraphMode.FULL_DECODE_ONLY - splitting_ops = [] - - # Disable, compile cache to make sure custom passes run. - # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - - # To capture subprocess logs, we need to know whether spawn or fork is used. - # Force spawn as it is more general. 
- monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) - - compilation_config = CompilationConfig( - # Testing properties - custom_ops=custom_ops_list, - use_inductor_graph_partition=inductor_graph_partition, - cudagraph_mode=mode, - splitting_ops=splitting_ops, - # Common - level=CompilationLevel.PIECEWISE, - pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), - # Inductor caches custom passes by default as well via uuid - inductor_compile_config={"force_disable_caches": True}, - ) - - with caplog_mp_spawn(logging.DEBUG) as log_holder: - run_model(compilation_config, model_name, **model_kwargs) - - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text - - -# TODO(luka) test both in nightly -# TODO(luka) change to - -CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] - - -def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: - for op_list in itertools.product(*custom_ops_lists): - yield ",".join(op_list) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", - # Toggle RMSNorm and QuantFP8 for FP8 models - list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) - # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO - # Toggle RMSNorm for FP4 models and unquant models - + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), -) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) -@pytest.mark.skipif( - not current_platform.is_cuda() - or not has_flashinfer() - or not current_platform.has_device_capability(90), - reason="allreduce+rmsnorm fusion requires flashinfer", -) -def test_e2e_fusion_tp2_attn_quant_allreduce_rmsnorm( - model_name, - model_kwargs, - backend, - custom_ops: str, - inductor_graph_partition: bool, - caplog_mp_spawn, - monkeypatch, -): - custom_ops_list = custom_ops.split(",") if custom_ops else [] - - if inductor_graph_partition: - mode = CUDAGraphMode.FULL_AND_PIECEWISE - splitting_ops: Optional[list[str]] = None - else: - mode = CUDAGraphMode.FULL_DECODE_ONLY - splitting_ops = [] - - # Disable, compile cache to make sure custom passes run. - # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - - # To capture subprocess logs, we need to know whether spawn or fork is used. - # Force spawn as it is more general. 
- monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) - - compilation_config = CompilationConfig( - # Testing properties - use_inductor_graph_partition=inductor_graph_partition, - cudagraph_mode=mode, - custom_ops=custom_ops_list, - splitting_ops=splitting_ops, - # Common - level=CompilationLevel.PIECEWISE, - pass_config=PassConfig( - enable_attn_fusion=True, - enable_noop=True, - enable_fi_allreduce_fusion=True, - ), - # Inductor caches custom passes by default as well via uuid - inductor_compile_config={"force_disable_caches": True}, - ) - - with caplog_mp_spawn(logging.DEBUG) as log_holder: - run_model( - compilation_config, model_name, tensor_parallel_size=2, **model_kwargs - ) - - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text - - matches = re.findall( - r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text - ) - assert len(matches) == 2, log_holder.text - - def run_model( compile_config: Union[int, CompilationConfig], model: str, **model_kwargs ): diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py new file mode 100644 index 000000000000..b0700e4e86a4 --- /dev/null +++ b/tests/compile/test_fusions_e2e.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import itertools +import logging +from collections.abc import Iterable +from typing import Any, Optional, Union + +import pytest +import regex as re + +from tests.v1.attention.utils import _Backend +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig +from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer + +from ..utils import flat_product, multi_gpu_test + +MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] +MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only + +if current_platform.is_cuda(): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"max_model_len": 1024}, + _Backend.TRITON_ATTN, + ) + ] + + if current_platform.is_device_capability((10, 0)): + MODELS_FP8 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS_FP4 += [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + {"kv_cache_dtype": "fp8", "max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + + MODELS += [ + ( + "meta-llama/Llama-3.1-8B-Instruct", + {"max_model_len": 1024}, + _Backend.FLASHINFER, + ) + ] + +elif current_platform.is_rocm(): + MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] + +INDUCTOR_GRAPH_PARTITION = ( + [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] +) + +# TODO(luka) test both in nightly +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] + + +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) + # quant_fp4 only has the custom impl + + list(flat_product(MODELS_FP4, [""])), +) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_attn_quant( + model_name: str, + model_kwargs: dict[str, Any], + 
backend: _Backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + + +# TODO(luka) test both in nightly +# TODO(luka) change to - +CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] + + +def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: + for op_list in itertools.product(*custom_ops_lists): + yield ",".join(op_list) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) + # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif( + not current_platform.is_cuda() + or not has_flashinfer() + or not current_platform.has_device_capability(90), + reason="allreduce+rmsnorm fusion requires flashinfer", +) +def test_tp2_attn_quant_allreduce_rmsnorm( + model_name, + model_kwargs, + backend, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: Optional[list[str]] = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + # Log capture also doesn't work with multiprocessing yet. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. 
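# A quick, self-contained illustration of what the custom_ops_product helper
# defined above yields once it is used: the cartesian product of per-op toggle
# lists, joined into comma-separated custom_ops strings. The toggle values are
# the same ones already used in this file; the demo function name is made up.
import itertools

def _custom_ops_product_demo(*custom_ops_lists):
    for op_list in itertools.product(*custom_ops_lists):
        yield ",".join(op_list)

assert list(_custom_ops_product_demo(["-quant_fp8", "+quant_fp8"], ["+rms_norm"])) == [
    "-quant_fp8,+rms_norm",
    "+quant_fp8,+rms_norm",
]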
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_fi_allreduce_fusion=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + + assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + + matches = re.findall( + r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text + ) + assert len(matches) == 2, log_holder.text + + +def run_model( + compile_config: Union[int, CompilationConfig], model: str, **model_kwargs +): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(level=compile_config) + ) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + llm = LLM( + model=model, + compilation_config=compilation_config, + **model_kwargs, + ) + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. 
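# Self-contained illustration of the log-based check performed above, using
# made-up log lines (real vLLM logs carry extra prefixes); one
# "Replaced N patterns" line is expected per tensor-parallel rank.
import re

fake_log = (
    "[collective_fusion.py:123] Replaced 96 patterns\n"
    "[collective_fusion.py:123] Replaced 96 patterns\n"
)
matches = re.findall(r"\[collective_fusion.py:\d+] Replaced 96 patterns", fake_log)
assert len(matches) == 2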
+ for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 31d0127c71e73b6ee257cf564833681da71cff1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 13:01:13 -0400 Subject: [PATCH 35/81] Add e2e fusions to fullgraph test (should work with Triton backend), disable without flashinfer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 5 +++-- tests/compile/test_fusions_e2e.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f734526db130..85616de5b197 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -416,8 +416,8 @@ steps: - pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 20min - timeout_in_minutes: 30 +- label: PyTorch Fullgraph Test # 22min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -425,6 +425,7 @@ steps: - tests/compile commands: - pytest -v -s compile/test_full_graph.py + - pytest -v -s compile/test_fusions_e2e.py - label: Kernels Core Operation Test # 48min timeout_in_minutes: 75 diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index b0700e4e86a4..cbeaa8bcb3f3 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -33,7 +33,7 @@ ) ] - if current_platform.is_device_capability((10, 0)): + if current_platform.is_device_capability((10, 0)) and has_flashinfer(): MODELS_FP8 += [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", @@ -97,7 +97,6 @@ def test_attn_quant( # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") # To capture subprocess logs, we need to know whether spawn or fork is used. @@ -170,7 +169,6 @@ def test_tp2_attn_quant_allreduce_rmsnorm( # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. - # Log capture also doesn't work with multiprocessing yet. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") # To capture subprocess logs, we need to know whether spawn or fork is used. From c653d24a39d32f4e227fd845335b539366ea35e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 3 Oct 2025 23:39:23 -0400 Subject: [PATCH 36/81] Fix spelling, precommit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/conftest.py | 2 +- vllm/compilation/matcher_utils.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7a907b2ac79f..76bbccc29534 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1119,7 +1119,7 @@ def __init__(self): def caplog_mp_spawn(tmp_path, monkeypatch): """ This fixture enables capturing logs from a forked MP subprocess. - It does not require caplog_vllm (but it only contains log + It does not require caplog_vllm (but it only contains logs from the child). By default, subprocess logs do not go through the parent process. 
We instead add a FileHandler to the config so the spawned child process diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index d3603372d69f..55fbeadc22fe 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -57,6 +57,10 @@ def empty(self, *args, **kws): def empty_f32(self, *args, **kws): return torch.empty(*args, dtype=torch.float32, device="cuda", **kws) + def inputs(self) -> list[torch.Tensor]: + """Utility for inputs to the pattern""" + raise NotImplementedError + class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: Optional[bool] = None): From 1756f6755970f14c3cf643cb522c69208f80c3b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 4 Oct 2025 00:06:13 -0400 Subject: [PATCH 37/81] add back fp4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/matcher_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 55fbeadc22fe..fe558b7acac2 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -16,7 +16,9 @@ kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, + kNvfp4Quant, ) +from vllm.platforms import current_platform RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default @@ -27,10 +29,8 @@ kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } -# TODO -# if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): -# QUANT_OPS[ -# kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class MatcherCustomOp(ABC): From 5619bc38bc781cd70f8f3c12124fee3042ff7437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 9 Oct 2025 21:42:34 -0400 Subject: [PATCH 38/81] clean up e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 147 ++++++++++++++++++++---------- 1 file changed, 99 insertions(+), 48 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index cbeaa8bcb3f3..6e4893cd0f66 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -10,6 +10,7 @@ import pytest import regex as re +from black.cache import NamedTuple from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams @@ -20,72 +21,111 @@ from ..utils import flat_product, multi_gpu_test -MODELS_FP8: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS_FP4: list[tuple[str, dict[str, Any], _Backend]] = [] -MODELS: list[tuple[str, dict[str, Any], _Backend]] = [] # tp-only + +class ModelBackendTestCase(NamedTuple): + model_name: str + model_kwargs: dict[str, Any] + backend: _Backend + attention_fusions: int + allreduce_fusions: Optional[int] = None + + +MODELS_FP8: list[ModelBackendTestCase] = [] +MODELS_FP4: list[ModelBackendTestCase] = [] +MODELS: list[ModelBackendTestCase] = [] # tp-only if current_platform.is_cuda(): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"max_model_len": 1024}, - _Backend.TRITON_ATTN, - ) + MODELS_FP8 = [ + ModelBackendTestCase( + 
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=48, + allreduce_fusions=96, + ), + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), ] - if current_platform.is_device_capability((10, 0)) and has_flashinfer(): - MODELS_FP8 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS_FP4 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - {"kv_cache_dtype": "fp8", "max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] - - MODELS += [ - ( - "meta-llama/Llama-3.1-8B-Instruct", - {"max_model_len": 1024}, - _Backend.FLASHINFER, - ) - ] + MODELS_FP4 = [ + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), + ] -elif current_platform.is_rocm(): - MODELS_FP8 += [("amd/Llama-3.1-8B-Instruct-FP8-KV", {}, _Backend.TRITON_ATTN)] + # TP only + MODELS = [ + ModelBackendTestCase( + model_name="meta-llama/Llama-3.1-8B-Instruct", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=64, + ), + ] -INDUCTOR_GRAPH_PARTITION = ( - [True, False] if (is_torch_equal_or_newer("2.9.0.dev")) else [False] -) +elif current_platform.is_rocm(): + MODELS_FP8 = [ + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_AITER_FA, # TODO ROCM_AITER_UNIFIED_ATTN + attention_fusions=32, + ), + ] # TODO(luka) test both in nightly CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] @pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) # quant_fp4 only has the custom impl + list(flat_product(MODELS_FP4, [""])), ) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, monkeypatch, ): + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -120,7 +160,9 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as 
log_holder: run_model(compilation_config, model_name, **model_kwargs) - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( + log_holder.text + ) # TODO(luka) test both in nightly @@ -135,14 +177,15 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model_name, model_kwargs, backend, custom_ops", + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) -@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda() or not has_flashinfer() @@ -150,14 +193,19 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: reason="allreduce+rmsnorm fusion requires flashinfer", ) def test_tp2_attn_quant_allreduce_rmsnorm( - model_name, - model_kwargs, - backend, + model_name: str, + model_kwargs: dict, + backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, monkeypatch, ): + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + custom_ops_list = custom_ops.split(",") if custom_ops else [] if inductor_graph_partition: @@ -198,10 +246,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - assert "Fused quant onto 48 attention nodes" in log_holder.text, log_holder.text + assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( + log_holder.text + ) matches = re.findall( - r"\[collective_fusion.py:\d+] Replaced 96 patterns", log_holder.text + rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns", + log_holder.text, ) assert len(matches) == 2, log_holder.text From 32989d804e03ec0a3e9c97ee84b684e4683d0a67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 13:49:09 -0400 Subject: [PATCH 39/81] add pattern for final allreduce in model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/collective_fusion.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d5a3fcde03b6..d6ad0cd38f94 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -775,6 +775,18 @@ def replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) + # Same pattern, but only return the output and not residual + # (helpful for end of graph where residual is not used again) + first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + + pm.register_replacement( + first_return_only(pattern), + first_return_only(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, + ) + class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): """ From 46ee6267b7b8f78fe038bbcda5650a26a2031133 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 13:51:13 -0400 Subject: [PATCH 40/81] add more comprehensive testing for quantfp8 (-rmsnorm+-quant still failing) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 89 +++++++++++++++++-------- 1 file changed, 62 insertions(+), 27 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 657ebc4a28a6..fa0293497aba 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,8 +26,8 @@ ) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, GroupShape, - QuantFP8, ) from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -43,9 +43,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.eps = eps self.norm = RMSNorm(hidden_size, eps) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) + def forward(self, x): + z = torch.relu(x) + all_reduce = tensor_model_parallel_all_reduce(z) norm = self.norm(all_reduce) return norm @@ -63,9 +63,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.eps = eps self.norm = RMSNorm(hidden_size, eps) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) + def forward(self, hidden_states): + z = residual = torch.relu(hidden_states) + all_reduce = tensor_model_parallel_all_reduce(z) norm, res = self.norm(all_reduce, residual) return norm, res @@ -77,21 +77,53 @@ def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] -class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): +class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) - self.scale = torch.rand(1, dtype=torch.float32) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ + torch.rand(hidden_size, hidden_size) + .to(dtype=current_platform.fp8_dtype()) + .t() + for _ in range(3) + ] - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm_output, residual_output = self.norm(all_reduce, residual) - quant_out, _ = self.quant_fp8(norm_output, self.scale) - return quant_out, residual_output + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + z3 = self.fp8_linear.apply( + y2, self.w[1], 
self.wscale[1], input_scale=self.scale[1] + ) + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + z4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -100,7 +132,7 @@ def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if self.quant_fp8.enabled() + if self.fp8_linear.quant_fp8.enabled() else torch.ops.aten.reciprocal.default, ] @@ -120,11 +152,10 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): rounded_n = round_up(scale_n, 4) self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) + def forward(self, hidden_states): + z = residual = torch.relu(hidden_states) + all_reduce = tensor_model_parallel_all_reduce(z) norm_output, residual_output = self.norm(all_reduce, residual) - norm_output = norm_output.reshape(-1, norm_output.shape[-1]) torch.ops._C.scaled_fp4_quant( self.output, norm_output, self.output_scale, self.scale ) @@ -146,8 +177,8 @@ def ops_in_model_before(self): [ (TestAllReduceRMSNormModel, False), (TestAllReduceFusedAddRMSNormModel, False), - (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True), - (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False), + (TestAllReduceRMSNormStaticQuantFP8Model, True), + (TestAllReduceRMSNormStaticQuantFP8Model, False), (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @@ -269,12 +300,16 @@ def all_reduce_fusion_pass_on_test_model( model = test_model_cls(hidden_size, token_num) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) + compiled_model(hidden_states) - assert all_reduce_fusion_pass.matched_count == 1 + # TODO cleanup + expected = 4 if test_model_cls is TestAllReduceRMSNormStaticQuantFP8Model else 1 + + assert all_reduce_fusion_pass.matched_count == expected, ( + f"{all_reduce_fusion_pass.matched_count=}, {expected=}" + ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass From a1c7fdb32ad96648ebfdf94c79069d448ffbb2a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 16:13:42 -0400 Subject: [PATCH 41/81] add more comprehensive testing for allreduce-rmsnorm, fix fp4 (-rmsnorm still failing) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_all_reduce.py | 97 +++++++++++++++---------- vllm/compilation/collective_fusion.py | 16 +--- 2 files changed, 59 insertions(+), 54 deletions(-) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index fa0293497aba..0c9d584ddf46 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -6,6 +6,7 @@ import torch import vllm.envs as envs +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from 
vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -41,34 +42,30 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] def forward(self, x): + # avoid having graph input be an arg to a pattern directly z = torch.relu(x) - all_reduce = tensor_model_parallel_all_reduce(z) - norm = self.norm(all_reduce) - return norm + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) - def ops_in_model_before(self): - return [torch.ops.vllm.all_reduce.default] + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) - def ops_in_model_after(self): - return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + y2, resid = self.norm[1](x2, resid) + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) -class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): - super().__init__() - self.hidden_size = hidden_size - self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + y3, resid = self.norm[2](x3, resid) - def forward(self, hidden_states): - z = residual = torch.relu(hidden_states) - all_reduce = tensor_model_parallel_all_reduce(z) - norm, res = self.norm(all_reduce, residual) + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) - return norm, res + y4, resid = self.norm[3](x4, resid) + return y4 def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -142,24 +139,48 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)] - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(token_num, 128) - scale_n = hidden_size // 16 - rounded_n = round_up(scale_n, 4) - self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) + wq_gen, wscale_gen = zip( + *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale)) + ) + self.wq, self.wscale = list(wq_gen), list(wscale_gen) + print(f"{self.wq=}, {self.wscale=}") def forward(self, hidden_states): - z = residual = torch.relu(hidden_states) - all_reduce = tensor_model_parallel_all_reduce(z) - norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.scaled_fp4_quant( - self.output, norm_output, self.output_scale, self.scale + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + yq, y_scale = scaled_fp4_quant(y, self.agscale[0]) + z2 = cutlass_scaled_fp4_mm( + yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype + ) + + x2 = tensor_model_parallel_all_reduce(z2) + 
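# Pure-PyTorch reference for the chained norm(x, resid) calls in this test
# model (assumed to match vLLM's fused-add RMSNorm semantics): the residual is
# added first, the sum becomes the new residual, and the RMS-normalized,
# weight-scaled sum is returned as the output.
import torch

def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    z = x + residual
    var = z.float().pow(2).mean(dim=-1, keepdim=True)
    out = (z.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return out, z

x, res = torch.randn(8, 64), torch.randn(8, 64)
out, new_res = fused_add_rms_norm_ref(x, res, torch.ones(64))
assert out.shape == x.shape and torch.allclose(new_res, x + res)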
y2, resid = self.norm[1](x2, resid) + + yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1]) + z3 = cutlass_scaled_fp4_mm( + yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype + ) + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2]) + z4 = cutlass_scaled_fp4_mm( + yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype ) - return self.output, residual_output, self.output_scale + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -176,7 +197,6 @@ def ops_in_model_before(self): "test_model, enable_quant_fp8", [ (TestAllReduceRMSNormModel, False), - (TestAllReduceFusedAddRMSNormModel, False), (TestAllReduceRMSNormStaticQuantFP8Model, True), (TestAllReduceRMSNormStaticQuantFP8Model, False), (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), @@ -184,7 +204,7 @@ def ops_in_model_before(self): ) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) -@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("enable_rms_norm", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @@ -304,11 +324,8 @@ def all_reduce_fusion_pass_on_test_model( compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) - # TODO cleanup - expected = 4 if test_model_cls is TestAllReduceRMSNormStaticQuantFP8Model else 1 - - assert all_reduce_fusion_pass.matched_count == expected, ( - f"{all_reduce_fusion_pass.matched_count=}, {expected=}" + assert all_reduce_fusion_pass.matched_count == 4, ( + f"{all_reduce_fusion_pass.matched_count=}" ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d6ad0cd38f94..cc4f2152e1c5 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -960,10 +960,6 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) - - rmsnorm_result = torch.empty( - [1, 16, 16], device=self.device, dtype=self.dtype - ) quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) input_global_scale = torch.empty( [1, 1], device=self.device, dtype=torch.float32 @@ -971,18 +967,10 @@ def get_inputs(): weight = torch.empty([16], device=self.device, dtype=self.dtype) output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) - return [ - input, - rmsnorm_result, - quant_result, - weight, - input_global_scale, - output_scale, - ] + return [input, quant_result, weight, input_global_scale, output_scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, @@ -1003,13 +991,13 @@ def pattern( def replacement( input: torch.Tensor, - result_rms: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, ): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) allreduce = 
auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, From c3264d849f1ca0d8736e39c0c25f6420930105a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 18:36:15 -0400 Subject: [PATCH 42/81] Fix partial match rmsnorm+quant, fix allreduce+rmsnorm match MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 40 +++++++++++++++++++++++-- vllm/compilation/fusion.py | 15 +++++----- vllm/compilation/fx_utils.py | 16 ++++++++-- vllm/model_executor/layers/layernorm.py | 5 +++- 4 files changed, 62 insertions(+), 14 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 13cffbe087c6..4ab450827609 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,7 +5,9 @@ import torch import vllm.plugins -from vllm.compilation.fusion import RMSNormQuantFusionPass +from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass +from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( @@ -33,6 +35,9 @@ FP8_DTYPE = current_platform.fp8_dtype() +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + class TestModel(torch.nn.Module): def __init__( @@ -50,7 +55,7 @@ def __init__( self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: @@ -93,6 +98,22 @@ def forward(self, x): y4, resid = self.norm[3](x4, resid) # use resid here return y4 + def ops_in_model_after(self): + return [ + FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], + ] + + def ops_in_model_before(self): + return ( + [QUANT_OPS[self.quant_key]] + if self.enable_quant_fp8 + else [torch.ops.aten.reciprocal] + ) + + def ops_in_model_before_partial(self): + return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt] + @pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @@ -164,3 +185,18 @@ def test_fusion_rmsnorm_quant( torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) assert fusion_pass.matched_count == 3 + backend.check_before_ops(model.ops_in_model_before()) + backend.check_before_ops( + model.ops_in_model_before_partial(), fully_replaced=False + ) + backend.check_after_ops(model.ops_in_model_after()) + + # If RMSNorm custom op is disabled (native/torch impl used), + # there's a risk that the fused add doesn't get included in the + # replacement and only the rms part gets fused with quant. + # Hence, we check only 2 add nodes are left (final fused rmsnorm add). 
+ if not enable_rms_norm: + n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) + # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) + assert n_add_nodes(backend.graph_pre_pass) == 7 + assert n_add_nodes(backend.graph_post_pass) == 2 diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 883743b635a8..9ace7a8cf050 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -94,9 +94,6 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" - self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -334,23 +331,25 @@ def __init__(self, config: VllmConfig): pass_name="rmsnorm_quant_fusion_pass" ) + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 114b53c74c48..3209c49eba26 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -3,11 +3,11 @@ import operator from collections.abc import Iterable, Iterator -from typing import Optional +from typing import Optional, Union from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._ops import OpOverload +from torch._ops import OpOverload, OpOverloadPacket def is_func(node: fx.Node, target) -> bool: @@ -67,7 +67,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node: # An auto-functionalization-aware utility for finding nodes with a specific op -def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: +# Also handles op overload packets and finds all overloads +def find_op_nodes( + op: Union[OpOverload, OpOverloadPacket], graph: fx.Graph +) -> Iterator[fx.Node]: + if isinstance(op, OpOverloadPacket): + for overload in op.overloads(): + overload_op = getattr(op, overload) + yield from find_op_nodes(overload_op, graph) + return + + assert isinstance(op, OpOverload) if not op._schema.is_mutable: yield from graph.find_nodes(op="call_function", target=op) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7e15efab379b..b70ea33f2cd2 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -195,7 +195,10 @@ def forward_static( orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: - x = x + residual.to(torch.float32) + # residual promoted f16->f32 automatically, + # otherwise Inductor eliminates the casts to and from f16, 
+ # increasing memory usage (and complicating pattern matching) + x = x + residual residual = x.to(orig_dtype) if x.shape[-1] != hidden_size: From 095277ca89b85a0ae7952218c97f0ababa34f30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 10 Oct 2025 19:03:18 -0400 Subject: [PATCH 43/81] Simplify matcher utils by using RMSNorm.forward_static MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/matcher_utils.py | 38 ++++--------------------- vllm/model_executor/layers/layernorm.py | 3 +- 2 files changed, 8 insertions(+), 33 deletions(-) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index fe558b7acac2..cc5e7ba8310d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -65,8 +65,6 @@ def inputs(self) -> list[torch.Tensor]: class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: - # TODO either pass config to enabled or set it globally - # (global during pass init seems reasonable) enabled = RMSNorm.enabled() super().__init__(enabled) @@ -83,7 +81,6 @@ def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, ) -> torch.Tensor: result = torch.empty_like(input) _, result = auto_functionalized( @@ -100,28 +97,15 @@ def forward_native( self, input: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, ) -> torch.Tensor: - x = input.to(torch.float32) - if residual is not None: - x = x + residual - residual = x.to(self.model_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - - x = x * torch.rsqrt(variance + self.epsilon) - x = x.to(self.model_dtype) - if weight is not None: - x = x * weight - - return x if residual is None else (x, residual) + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight + ) class MatcherFusedAddRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: Optional[bool] = None): if enabled is None: - # TODO either pass config to enabled or set it globally - # (global during pass init seems reasonable) enabled = RMSNorm.enabled() super().__init__(enabled) @@ -157,19 +141,9 @@ def forward_native( weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - x = input.to(torch.float32) - if residual is not None: - x = x + residual - residual = x.to(self.model_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - - x = x * torch.rsqrt(variance + self.epsilon) - x = x.to(self.model_dtype) - if weight is not None: - x = x * weight - - return x if residual is None else (x, residual) + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight, residual + ) class MatcherQuant: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index b70ea33f2cd2..5b9d24c19a3c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -187,12 +187,12 @@ def forward_static( x: torch.Tensor, variance_epsilon: float, hidden_size: int, + orig_dtype: torch.dtype, weight: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None, variance_size_override: Optional[int] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype x = x.to(torch.float32) if residual is 
not None: # residual promoted f16->f32 automatically, @@ -239,6 +239,7 @@ def forward_native( x, self.variance_epsilon, self.hidden_size, + x.dtype, self.weight.data if self.has_weight else None, residual, self.variance_size_override, From 52f78ce6760f9e8754cf6cde299d0448b93411a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 11 Oct 2025 08:38:42 -0400 Subject: [PATCH 44/81] Add allreduce test to 2-gpu test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 85616de5b197..f02fa0c27373 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -812,6 +812,10 @@ steps: - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -833,7 +837,6 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py # Fusion - - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py @@ -1090,7 +1093,7 @@ steps: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 ##### H200 test ##### -- label: Distrubted Tests (H200) # optional +- label: Distributed Tests (H200) # optional gpu: h200 optional: true working_dir: "/vllm-workspace/" @@ -1110,6 +1113,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm ##### RL Integration Tests ##### From 1b1a63eb2e3086a94fd3a350531f8707dcb7be3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 11 Oct 2025 14:33:46 -0400 Subject: [PATCH 45/81] Fix e2e allreduce fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 6e4893cd0f66..f80bdb06bc68 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -69,7 +69,7 @@ class ModelBackendTestCase(NamedTuple): model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, attention_fusions=0, - allreduce_fusions=64, + allreduce_fusions=65, ), ] @@ -166,8 +166,7 @@ def test_attn_quant( # TODO(luka) test both in nightly -# TODO(luka) change to - -CUSTOM_OPS_RMS_NORM = ["+rms_norm"] # , "+rms_norm"] +CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @@ -180,8 +179,11 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: "model_name, model_kwargs, backend, " "attention_fusions, 
allreduce_fusions, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models - list(flat_product(MODELS_FP8, ["+quant_fp8,+rms_norm"])) - # custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))) # TODO + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) # TODO # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) @@ -245,17 +247,26 @@ def test_tp2_attn_quant_allreduce_rmsnorm( run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - - assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( - log_holder.text + matches = re.findall( + r"\[compilation/fusion_attn.py:\d+] " + r"Fused quant onto (\d+) attention nodes", + log_holder.text, ) + assert len(matches) == 2, log_holder.text + + assert int(matches[0]) == attention_fusions + assert int(matches[1]) == attention_fusions matches = re.findall( - rf"\[collective_fusion.py:\d+] Replaced {allreduce_fusions} patterns", + r"\[compilation/collective_fusion.py:\d+] " + r"Replaced (\d+) patterns", log_holder.text, ) assert len(matches) == 2, log_holder.text + assert int(matches[0]) == allreduce_fusions + assert int(matches[1]) == allreduce_fusions + def run_model( compile_config: Union[int, CompilationConfig], model: str, **model_kwargs From 0d6e550bfe3032e374257b1fb495a6e7758b2fab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sun, 12 Oct 2025 10:57:07 -0400 Subject: [PATCH 46/81] fix func test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/backend.py | 3 +- tests/compile/test_functionalization.py | 80 ++++++++++++++----------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index a16ab9f15c9f..5d0e30ea5f39 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -54,7 +54,8 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) vllm_config = get_current_vllm_config() compile_config = vllm_config.compilation_config - self.inductor_config = compile_config.inductor_compile_config + # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig + self.inductor_config = deepcopy(compile_config.inductor_compile_config) self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index ae17bc67b1fb..dd424d7f6ad0 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,13 @@ from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -217,42 +223,48 @@ def ops_not_in_model(self): def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): torch.set_default_device("cuda") - 
vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + custom_ops=["all"], + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), + ), ) - noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - - passes = ( - [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] - if do_fusion - else [noop_pass, cleanup_pass] - ) - func_pass = FixFunctionalizationPass(vllm_config) - backend_func = TestBackend(*passes, func_pass) - backend_no_func = TestBackend(*passes) + with set_current_vllm_config(vllm_config): + assert RMSNorm.enabled() + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) + + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion + else [noop_pass, cleanup_pass] + ) + func_pass = FixFunctionalizationPass(vllm_config) - model = model_class() - torch.compile(model, backend=backend_func)(*model.example_inputs()) - torch.compile(model, backend=backend_no_func)(*model.example_inputs()) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) - # check if the functionalization pass is applied - for op in model.ops_in_model(do_fusion): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + model = model_class() + torch.compile(model, backend=backend_func)(*model.example_inputs()) + torch.compile(model, backend=backend_no_func)(*model.example_inputs()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): - if is_func(node, op): - found[op] = True - for op in model.ops_not_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model(do_fusion)) - assert all(not found.get(op) for op in model.ops_not_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(do_fusion): + if is_func(node, op): + found[op] = True + for op in model.ops_not_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model(do_fusion)) + assert all(not found.get(op) for op in model.ops_not_in_model()) From 26892dfa100a962b4502bf2e89f7610cc81b912d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sun, 12 Oct 2025 11:03:35 -0400 Subject: [PATCH 47/81] fix pass manager test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_pass_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index ac561d2e8f84..1c40c599f748 100644 --- a/tests/compile/test_pass_manager.py +++ 
b/tests/compile/test_pass_manager.py @@ -7,7 +7,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig # dummy custom pass that doesn't inherit @@ -42,7 +42,8 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ], ) def test_pass_manager_uuid(callable): - config = VllmConfig() + # Some passes need dtype to be set + config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) pass_manager = PostGradPassManager() pass_manager.configure(config) From 3547b877ad82ec5f1de52ca4a27aa186a119a50d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sun, 12 Oct 2025 11:11:14 -0400 Subject: [PATCH 48/81] fix sequence parallelism test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_sequence_parallelism.py | 92 ++++++++++++---------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index afb31cb95be0..bca3932ffaf0 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -18,6 +18,7 @@ ModelConfig, PassConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -42,9 +43,7 @@ class TestModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -266,68 +265,77 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( + compilation_config = CompilationConfig( pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, ) ) # NoOp needed for fusion - vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig( + model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) - noop_pass = NoOpEliminationPass(vllm_config) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) + vllm_config = VllmConfig( + model_config=model_config, + device_config=device_config, + compilation_config=compilation_config, + ) - passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass] + with set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - if enable_fusion: - fusion_pass = RMSNormQuantFusionPass(vllm_config) - passes_for_backend.append(fusion_pass) + passes_for_backend: list[VllmInductorPass] = [ + noop_pass, + sequence_parallelism_pass, + ] - passes_for_backend.append(cleanup_pass) + if enable_fusion: + fusion_pass = RMSNormQuantFusionPass(vllm_config) + passes_for_backend.append(fusion_pass) - backend_no_func = TestBackend(*passes_for_backend) - backend_func = TestBackend(*passes_for_backend, func_pass) + passes_for_backend.append(cleanup_pass) - model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) + backend_no_func = TestBackend(*passes_for_backend) + backend_func = TestBackend(*passes_for_backend, func_pass) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + model = test_model_cls(hidden_size, hidden_size * 2) - compiled_model_no_func = torch.compile(model, backend=backend_no_func) - compiled_model_no_func(hidden_states, residual) - compiled_model_func = torch.compile(model, backend=backend_func) - compiled_model_func(hidden_states, residual) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - assert sequence_parallelism_pass.matched_count == 1 + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) - # In pre-nodes, all reduce should be there, - # reduce scatter and all gather should not - backend_no_func.check_before_ops(model.ops_in_model_before()) + assert sequence_parallelism_pass.matched_count == 1 - # In post-nodes, reduce scatter and all gather should be there, - # all reduce should not - backend_no_func.check_after_ops(model.ops_in_model_after()) + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + backend_no_func.check_before_ops(model.ops_in_model_before()) - # check if the functionalization pass is applied - for op in model.ops_in_model(): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + backend_no_func.check_after_ops(model.ops_in_model_after()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op 
in model.ops_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model()) From af1ffa77d5606a30693a8c98d2333a50443ac5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 01:54:18 -0400 Subject: [PATCH 49/81] PR review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/collective_fusion.py | 26 +++++--------------------- vllm/compilation/fusion.py | 4 ++-- vllm/compilation/fusion_attn.py | 4 ++-- vllm/compilation/matcher_utils.py | 21 ++++++++------------- 4 files changed, 17 insertions(+), 38 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cc4f2152e1c5..d0e99497a372 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -25,7 +25,7 @@ from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -46,11 +46,8 @@ logger = init_logger(__name__) -ALLREDUCE_OP = torch.ops.vllm.all_reduce.default -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default -STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default +if hasattr(torch.ops._C, "scaled_fp4_quant"): + STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: @@ -650,19 +647,6 @@ def get_trtllm_fused_allreduce_kwargs(self): } -class BaseAllReduceRMSNormPattern(BasePattern): - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str, - allreduce_params: FlashInferFusedAllReduceParams, - ): - super().__init__(dtype, device) - self.epsilon = epsilon - self.allreduce_params = allreduce_params - - class AllReduceRMSNormPattern(BasePattern): """ This pattern replaces the allreduce + rms norm (without residual) @@ -808,7 +792,7 @@ def __init__( self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn self.rmsnorm_matcher = MatcherRMSNorm(epsilon) - self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -877,7 +861,7 @@ def __init__( self.quant_dtype = torch.float8_e4m3fn self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - self.quant_matcher = MatcherQuant(kFp8StaticTensorSym) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 9ace7a8cf050..d6057e869ae0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -24,7 +24,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from 
.matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -102,7 +102,7 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) ) - self.quant_matcher = MatcherQuant(key.quant) + self.quant_matcher = MatcherQuantFP8(key.quant) class RMSNormStaticQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 761acb35834b..2f3b0963d365 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -24,7 +24,7 @@ from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .fx_utils import is_func from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherQuant +from .matcher_utils import MatcherQuantFP8 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -129,7 +129,7 @@ def __init__( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) - self.quant_matcher = MatcherQuant(quant_key) + self.quant_matcher = MatcherQuantFP8(quant_key) def _register(self, pm_pass: PatternMatcherPass): def pattern( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index cc5e7ba8310d..4b1c714fe4a4 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -146,22 +146,22 @@ def forward_native( ) -class MatcherQuant: +class MatcherQuantFP8(MatcherCustomOp): def __init__(self, quant_key: QuantKey, enabled: Optional[bool] = None): + if enabled is None: + enabled = QuantFP8.enabled() + + super().__init__(enabled) self.quant_key = quant_key assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] + assert quant_key.dtype == current_platform.fp8_dtype(), ( + "Only QuantFP8 supported by" + ) assert quant_key.scale2 is None self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) - if enabled is None: - # TODO either pass config to enabled or set it globally - # (global during pass init seems reasonable) - enabled = self.quant_fp8.enabled() - - self.forward = self.forward_custom if enabled else self.forward_native - def forward_custom( self, input: torch.Tensor, @@ -204,8 +204,3 @@ def make_scale(self, input: torch.Tensor): ) return torch.empty(scale_shape, device=input.device, dtype=torch.float32) - - def __call__( - self, input: torch.Tensor, scale: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.forward(input, scale) From b5f89e5d0291d3c9e6a2152f75b186eaa10862dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 02:29:06 -0400 Subject: [PATCH 50/81] Cleanup test_full_graph.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_full_graph.py | 66 +++++++++++++++++++------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 9a955f4c9d81..fb511dd8f7ca 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile +from pathlib 
import Path from typing import Any import pytest @@ -21,27 +22,21 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): ("facebook/opt-125m", {}), ( "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", - { - "dtype": torch.float16, - }, + {"dtype": torch.float16}, ), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: - if not current_platform.has_device_capability((10, 0)): - # int8 removed on Blackwell - TEST_MODELS.extend( - [ - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), - ( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - { - "dtype": torch.float16, - }, - ), - ] - ) + TEST_MODELS.extend( + [ + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + {"dtype": torch.float16}, + ), + ] + ) # TODO: figure out why this fails. if False and is_quant_method_supported("gguf"): # noqa: SIM223 @@ -95,6 +90,14 @@ def test_full_graph( model_kwargs: dict[str, Any], compilation_mode: int, ): + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") + with monkeypatch.context(): print(f"MODEL={model}") @@ -103,14 +106,14 @@ def test_full_graph( # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( - "compilation_config, model_info", + "compilation_config, model, model_kwargs", [ # additional compile sizes, only some of the models ( CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) ] + [ # RMSNorm + quant fusion, only 8-bit quant models @@ -120,18 +123,19 @@ def test_full_graph( custom_ops=["+rms_norm"], pass_config=PassConfig(enable_fusion=True, enable_noop=True), ), - model, + *model_info, ) - for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) + for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) ] + [ # Test depyf integration works ( CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - debug_dump_path=tempfile.gettempdir(), + debug_dump_path=Path(tempfile.gettempdir()), ), - ("facebook/opt-125m", {}), + "facebook/opt-125m", + {}, ), ] + [ @@ -145,9 +149,9 @@ def test_full_graph( cudagraph_mode=CUDAGraphMode.PIECEWISE, compile_sizes=[1, 2], ), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") ], ) @@ -155,14 +159,22 @@ def test_full_graph( @create_new_process_for_each_test() def test_custom_compile_config( compilation_config: CompilationConfig, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], ): + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") + if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( "2.9.0.dev" ): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - model, model_kwargs = model_info print(f"MODEL={model}") run_model(compilation_config, model, **model_kwargs) From f6429e416de6d5a0623a019f6afda7a9a5b2317a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 02:40:43 -0400 Subject: [PATCH 51/81] Cleanup test_fusion_attn.py MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 54 +++++++++++++++---------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 375796952339..b6d8fc9e28dc 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -238,52 +239,41 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) -MODELS_FP8 = [] -MODELS_FP4 = [] -HEADS = [] -SPLIT_ATTENTION = [] +MODELS_FP8: list[tuple[str, type]] = [] +MODELS_FP4: list[tuple[str, type]] = [] +HEADS: list[tuple[int, int]] = [] +SPLIT_ATTENTION: list[bool] = [] BACKENDS_FP8: list[_Backend] = [] BACKENDS_FP4: list[_Backend] = [] if current_platform.is_cuda(): + HEADS = [(64, 8), (40, 8)] MODELS_FP8 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", TestAttentionFp8StaticQuantPatternModel, ) ] - HEADS = [(64, 8), (40, 8)] - SPLIT_ATTENTION = [False] - BACKENDS_FP8 = [_Backend.TRITON_ATTN] - - if current_platform.is_device_capability((10, 0)): - BACKENDS_FP8 += [_Backend.FLASHINFER] - BACKENDS_FP4 += [_Backend.FLASHINFER] - MODELS_FP4 += [ - ( - "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel, - ) - ] + MODELS_FP4 = [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel, + ) + ] + BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] + BACKENDS_FP4 = [_Backend.FLASHINFER] elif current_platform.is_rocm(): + HEADS = [(32, 8), (40, 8)] MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] - HEADS = [(32, 8), (40, 8)] - SPLIT_ATTENTION = [False, True] BACKENDS = [ - _Backend.TRITON_ATTN, _Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, + _Backend.TRITON_ATTN, ] -# TODO(boyuan/luka): test inductor graph partition on rocm -if is_torch_equal_or_newer("2.9.0.dev") and current_platform.is_cuda(): - USE_INDUCTOR_GRAPH_PARTITION = [False, True] -else: - USE_INDUCTOR_GRAPH_PARTITION = [False] - @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) @@ -298,7 +288,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): # quant_fp4 only has the custom impl + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) -@pytest.mark.parametrize("use_inductor_graph_partition", USE_INDUCTOR_GRAPH_PARTITION) +@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @@ -318,6 +308,14 @@ def test_attention_quant_pattern( caplog_vllm, ): """Test AttentionStaticQuantPattern fusion pass""" + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + + # TODO(boyuan/luka): test inductor graph partition on rocm + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -435,7 
+433,7 @@ def test_attention_quant_pattern( ) # access the underlying `AttnFusionPass` on the `LazyInitPass` - assert attn_pass.pass_.matched_count == 1 + assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) From 8a363d397227d55865e2e66159910aa51a3cd47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 02:43:03 -0400 Subject: [PATCH 52/81] Slight improvement for E2E fusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index b650b48c7d37..f55f3e1d2947 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -160,9 +160,13 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - assert f"Fused quant onto {attention_fusions} attention nodes" in log_holder.text, ( - log_holder.text + matches = re.findall( + r"\[compilation/fusion_attn.py:\d+] " + r"Fused quant onto (\d+) attention nodes", + log_holder.text, ) + assert len(matches) == 1, log_holder.text + assert int(matches[0]) == attention_fusions # TODO(luka) test both in nightly From 12a7c6d5d2f38b874e6933c6a049c873d5c4b441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 03:00:52 -0400 Subject: [PATCH 53/81] Tests & docs for flat_product MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/utils.py | 10 +++++++--- tests/utils_/test_utils.py | 24 +++++++++++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 54c51ed284fa..3042bacd4bb6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1265,7 +1265,11 @@ def check_answers( def flat_product(*iterables: Iterable[Any]): - """Flatten lists of tuples into cartesian product.""" + """ + Flatten lists of tuples of the cartesian product. + Useful when we want to avoid nested tuples to allow + test params to be unpacked directly from the decorator. 
+ """ for element in itertools.product(*iterables): - normalized = (e if isinstance(e, tuple) else [e] for e in element) - yield list(itertools.chain(*normalized)) + normalized = (e if isinstance(e, tuple) else (e,) for e in element) + yield tuple(itertools.chain(*normalized)) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index af5fc758f2c2..a14431681150 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -47,7 +47,7 @@ unique_filepath, ) -from ..utils import create_new_process_for_each_test, error_on_warning +from ..utils import create_new_process_for_each_test, error_on_warning, flat_product @pytest.mark.asyncio @@ -993,3 +993,25 @@ def test_unique_filepath(): paths.add(path) assert len(paths) == 10 assert len(list(Path(temp_dir).glob("*.txt"))) == 10 + + +def test_flat_product(): + # Check regular itertools.product behavior + result1 = list(flat_product([1, 2, 3], ["a", "b"])) + assert result1 == [ + (1, "a"), + (1, "b"), + (2, "a"), + (2, "b"), + (3, "a"), + (3, "b"), + ] + + # check that the tuples get flattened + result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)])) + assert result2 == [ + (1, 2, "a", 5, 6), + (1, 2, "b", 5, 6), + (3, 4, "a", 5, 6), + (3, 4, "b", 5, 6), + ] From 8ffb4744f86e003e08f2191292a7f2bfe731d13e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 03:25:26 -0400 Subject: [PATCH 54/81] Remove/fix TODOs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 2 +- tests/compile/test_fusions_e2e.py | 4 ++-- vllm/compilation/fusion.py | 16 +++++++++------- vllm/compilation/matcher_utils.py | 10 ++++++---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index b6d8fc9e28dc..32b207ed0109 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -101,7 +101,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: num_blocks = batch_size * max_blocks backend = self.attn.backend - # TODO use get_kv_cache_stride_order + # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend if backend == _Backend.ROCM_ATTN: # k/v as 1st dimention diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index f55f3e1d2947..533e0c5867d3 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -90,7 +90,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_AITER_FA, # TODO ROCM_AITER_UNIFIED_ATTN + backend=_Backend.ROCM_AITER_UNIFIED_ATTN, attention_fusions=32, ), ] @@ -187,7 +187,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: flat_product( MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) ) - ) # TODO + ) # Toggle RMSNorm for FP4 models and unquant models + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), ) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index d6057e869ae0..d724eca03e82 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from 
vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -93,6 +93,8 @@ class RMSNormQuantPattern: def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -124,7 +126,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype + input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( @@ -179,8 +181,8 @@ def replacement( ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype - residual = residual.to(dtype=torch.float16) + input = input.to(dtype=self.model_dtype) + residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( @@ -235,7 +237,7 @@ def pattern(input: torch.Tensor, weight: torch.Tensor): def replacement(input: torch.Tensor, weight: torch.Tensor): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. - input = input.to(dtype=torch.float16) # TODO model dtype + input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) @@ -289,8 +291,8 @@ def replacement( ): # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. 
- input = input.to(dtype=torch.float16) # TODO model dtype - residual = residual.to(dtype=torch.float16) + input = input.to(dtype=self.model_dtype) + residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 2fba5bd0cdbe..9b3854d9fb52 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -34,7 +34,9 @@ class MatcherCustomOp(ABC): def __init__(self, enabled: bool): - self.model_dtype = get_current_vllm_config().model_config.dtype + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config else None self.enabled = enabled self.forward = self.forward_custom if enabled else self.forward_native @@ -51,10 +53,10 @@ def __call__(self, *args, **kws): return self.forward(*args, **kws) def empty(self, *args, **kws): - return torch.empty(*args, dtype=self.model_dtype, device="cuda", **kws) + return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) def empty_f32(self, *args, **kws): - return torch.empty(*args, dtype=torch.float32, device="cuda", **kws) + return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) def inputs(self) -> list[torch.Tensor]: """Utility for inputs to the pattern""" @@ -166,7 +168,7 @@ def forward_custom( input: torch.Tensor, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - # TODO: why does empty_like produce a permute but + # TODO(luka): why does empty_like produce a permute but # empty via shape doesn't? result = torch.empty( input.shape, device=input.device, dtype=self.quant_key.dtype From 2a6299c81b0e8de81161ff9efe4af719eed1b381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 04:12:01 -0400 Subject: [PATCH 55/81] Fix e2e test patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 533e0c5867d3..a8ece68d4f0e 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -161,8 +161,7 @@ def test_attn_quant( run_model(compilation_config, model_name, **model_kwargs) matches = re.findall( - r"\[compilation/fusion_attn.py:\d+] " - r"Fused quant onto (\d+) attention nodes", + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) assert len(matches) == 1, log_holder.text @@ -252,8 +251,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) matches = re.findall( - r"\[compilation/fusion_attn.py:\d+] " - r"Fused quant onto (\d+) attention nodes", + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) assert len(matches) == 2, log_holder.text @@ -262,8 +260,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( assert int(matches[1]) == attention_fusions matches = re.findall( - r"\[compilation/collective_fusion.py:\d+] " - r"Replaced (\d+) patterns", + r"collective_fusion.py:\d+] Replaced (\d+) patterns", log_holder.text, ) assert len(matches) == 2, log_holder.text From 465ce583f239e67e8518032d312debeab230cea4 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 09:59:54 -0400 Subject: [PATCH 56/81] Update tests/compile/test_fusion.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 9f7e025a232e..aa37db8022d5 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -115,7 +115,7 @@ def ops_in_model_before_partial(self): return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt] -@pytest.mark.parametrize("dtype", [torch.float16]) # , torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) From bcd95b5f67a0a51580be925073a3c61a5fcb1655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 11:54:47 -0400 Subject: [PATCH 57/81] Fix func test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- csrc/layernorm_quant_kernels.cu | 2 ++ tests/compile/test_functionalization.py | 34 +++++++++++-------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0fc462194fcd..f82ae50ae6dd 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -216,6 +216,8 @@ void fused_add_rms_norm_static_fp8_quant( double epsilon) { TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(residual.scalar_type() == input.scalar_type()); + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); int hidden_size = input.size(-1); int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index dd424d7f6ad0..11ae96e930da 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -54,8 +54,7 @@ def forward(self, x): return y def example_inputs(self, num_tokens=32, hidden_size=128): - dtype = torch.float16 if TEST_FP8 else torch.float32 - return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),) + return (torch.rand(num_tokens, hidden_size * 2),) def ops_in_model(self, do_fusion): if TEST_FP8 and do_fusion: @@ -73,15 +72,11 @@ def __init__(self, hidden_size=16, intermediate_size=32): self.hidden_size = hidden_size self.intermediate_size = intermediate_size - dtype = torch.float16 if TEST_FP8 else torch.float32 - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size), dtype=dtype) + torch.empty((intermediate_size, hidden_size)) ) self.norm = RMSNorm(intermediate_size, 1e-05) - self.norm.weight = torch.nn.Parameter( - torch.ones(intermediate_size, dtype=dtype) - ) + self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size)) torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -118,9 +113,8 @@ def forward(self, hidden_states, residual): return norm_output, residual_output def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): - dtype = torch.float16 if TEST_FP8 else torch.float32 - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + 
hidden_states = torch.randn((batch_size * seq_len, hidden_size)) + residual = torch.randn((batch_size * seq_len, hidden_size)) return (hidden_states, residual) def ops_in_model(self, do_fusion): @@ -151,10 +145,9 @@ def forward(self, positions, q, k): return q_rotated, k_rotated def example_inputs(self, num_tokens=32, head_dim=64): - dtype = torch.float16 positions = torch.arange(num_tokens, dtype=torch.long) - q = torch.randn(num_tokens, head_dim, dtype=dtype) - k = torch.randn(num_tokens, head_dim, dtype=dtype) + q = torch.randn(num_tokens, head_dim) + k = torch.randn(num_tokens, head_dim) return (positions, q, k) def ops_in_model(self, do_fusion): @@ -172,7 +165,7 @@ def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000): self.hidden_size = head_dim * num_heads self.qkv_proj = torch.nn.Linear( - self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 + self.hidden_size, self.hidden_size * 3, bias=False ) self.rotary_emb = get_rope( @@ -196,10 +189,9 @@ def forward(self, positions, hidden_states): return qkv_updated def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): - dtype = torch.float16 hidden_size = head_dim * num_heads positions = torch.arange(num_tokens, dtype=torch.long) - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + hidden_states = torch.randn(num_tokens, hidden_size) return (positions, hidden_states) def ops_in_model(self, do_fusion): @@ -217,14 +209,18 @@ def ops_not_in_model(self): ] +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): +def test_fix_functionalization( + model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype +): torch.set_default_device("cuda") + torch.set_default_dtype(dtype) vllm_config = VllmConfig( - model_config=ModelConfig(dtype=torch.bfloat16), + model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( custom_ops=["all"], pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), From db2b1c76be4bc5ddb5d0a7b3f37b70f88ee2100f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 11:59:35 -0400 Subject: [PATCH 58/81] Smaller model for e2e fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index a8ece68d4f0e..5d5750ca3715 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -37,11 +37,12 @@ class ModelBackendTestCase(NamedTuple): if current_platform.is_cuda(): MODELS_FP8 = [ ModelBackendTestCase( - model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + # Use smaller model for L40s in CI + model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=48, - allreduce_fusions=96, + attention_fusions=32, + allreduce_fusions=65, ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", From a3ebf0a2e47adf60eab806d604c27de4795847c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 12:09:48 -0400 Subject: [PATCH 
59/81] fix fp8 quant tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/kernels/quant_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 9d11a7ef6413..34ce91585520 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant( .clamp(fp8_traits_min, fp8_traits_max) .to(FP8_DTYPE) ) - return ref_out, ref_scale.view((1,)) + return ref_out, ref_scale.view((1, 1)) def native_w8a8_block_matmul( From 3943257943e9a5aa8161190f581ae0592778db0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 12:11:29 -0400 Subject: [PATCH 60/81] Restore original torch.Parameter behavior in RMSNorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/model_executor/layers/layernorm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 46a5dec14327..1e5703f4368c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,7 +170,9 @@ def __init__( ) weight_dtype = dtype or torch.get_default_dtype() self.has_weight = has_weight - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) + self.weight = torch.ones(hidden_size, dtype=weight_dtype) + if self.has_weight: + self.weight = nn.Parameter(self.weight) if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( From 532cbcf134e688bb960357706ed508afc15a17de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 12:56:07 -0400 Subject: [PATCH 61/81] Add comment to test_logger MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/test_logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_logger.py b/tests/test_logger.py index f1c31c245475..01672358902f 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -503,6 +503,7 @@ def test_streaming_complete_logs_full_text_content(): assert call_args[5] == "streaming_complete" +# Add vllm prefix to make sure logs go through the vllm logger test_logger = init_logger("vllm.test_logger") From 7e6f5b3f85763bcc1774647251ffbb521cb350f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:06:19 -0400 Subject: [PATCH 62/81] add flat_product example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 3042bacd4bb6..9aed55b7258b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1269,6 +1269,15 @@ def flat_product(*iterables: Iterable[Any]): Flatten lists of tuples of the cartesian product. Useful when we want to avoid nested tuples to allow test params to be unpacked directly from the decorator. 
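    A minimal decorator-usage sketch (BACKENDS and MODELS here stand in for any
    lists of test cases, mirroring how tests/compile/test_fusion_attn.py consumes
    flat_product; the exact argnames are illustrative):

        @pytest.mark.parametrize(
            "backend, model, custom_ops",
            list(flat_product(BACKENDS, MODELS, ["+quant_fp8", ""])),
        )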
+ + Example: + flat_product([(1, 2), (3, 4)], ["a", "b"]) -> + [ + (1, 2, "a"), + (1, 2, "b"), + (3, 4, "a"), + (3, 4, "b"), + ] """ for element in itertools.product(*iterables): normalized = (e if isinstance(e, tuple) else (e,) for e in element) From 24f1298435681914f58e9f25d211a226583b9a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:08:13 -0400 Subject: [PATCH 63/81] PR comments: cleanup fusion passes, & matching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/compilation/collective_fusion.py | 42 ++++++++++----------------- vllm/compilation/fusion.py | 2 -- vllm/compilation/matcher_utils.py | 4 +-- 3 files changed, 17 insertions(+), 31 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 599a30f72c8f..c1ed058ded70 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -673,10 +673,10 @@ def __init__( self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) + input, weight = self.rmsnorm_matcher.inputs() - return [input, weight] + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): def pattern(input: torch.Tensor, weight: torch.Tensor): @@ -728,14 +728,10 @@ def __init__( self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - return [ - residual, - input, - weight, - ] + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): @@ -802,10 +798,11 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) - scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, weight, scale] + input, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight, scale] def pattern( input: torch.Tensor, @@ -871,18 +868,11 @@ def __init__( def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + input, residual, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - - return [ - residual, - input, - weight, - scale, - ] + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight, scale] def pattern( residual: torch.Tensor, diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index d724eca03e82..606874cc1034 100644 --- 
a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -182,7 +182,6 @@ def replacement( # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) - residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( @@ -292,7 +291,6 @@ def replacement( # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) - residual = residual.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale(input) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 9b3854d9fb52..16d1d86d2b3e 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -73,9 +73,7 @@ def __init__(self, epsilon: float, enabled: bool | None = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty( - 16, - ) + weight = self.empty(16) return [input, weight] def forward_custom( From de7405b851d909dd8bb0241c81ca9c59bf5001bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:08:57 -0400 Subject: [PATCH 64/81] PR comments: add _custom_op suffix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 26 ++++++++++++++----------- tests/compile/test_fusion_all_reduce.py | 20 +++++++++---------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aa37db8022d5..8c388f13002f 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -71,8 +71,8 @@ def __init__( act_quant_group_shape=group_shape, ) - self.enable_rms_norm = self.norm[0].enabled() - self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled() + self.enable_rms_norm_custom_op = self.norm[0].enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -107,12 +107,16 @@ def ops_in_model_after(self): def ops_in_model_before(self): return ( [QUANT_OPS[self.quant_key]] - if self.enable_quant_fp8 + if self.enable_quant_fp8_custom_op else [torch.ops.aten.reciprocal] ) def ops_in_model_before_partial(self): - return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt] + return ( + [RMS_OP, RMS_ADD_OP] + if self.enable_rms_norm_custom_op + else [torch.ops.aten.rsqrt] + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -120,8 +124,8 @@ def ops_in_model_before_partial(self): @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("enable_rms_norm", [True, False]) -@pytest.mark.parametrize("enable_quant_fp8", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
@pytest.mark.parametrize( @@ -136,8 +140,8 @@ def test_fusion_rmsnorm_quant( num_tokens, eps, static, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, cuda_force_torch, ): torch.set_default_device("cuda") @@ -146,9 +150,9 @@ def test_fusion_rmsnorm_quant( maybe_create_device_identity() # needed for certain non-cutlass fp8 paths custom_ops = [] - if enable_rms_norm: + if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") - if enable_quant_fp8: + if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), @@ -195,7 +199,7 @@ def test_fusion_rmsnorm_quant( # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). - if not enable_rms_norm: + if not enable_rms_norm_custom_op: n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4e6ed4446e4c..7688ba3d1b6c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -194,7 +194,7 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model, enable_quant_fp8", + "test_model, enable_quant_fp8_custom_op", [ (TestAllReduceRMSNormModel, False), (TestAllReduceRMSNormStaticQuantFP8Model, True), @@ -206,7 +206,7 @@ def ops_in_model_before(self): @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("enable_rms_norm", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -220,8 +220,8 @@ def test_all_reduce_fusion_pass_replace( seq_len: int, hidden_size: int, dtype: torch.dtype, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): num_processes = 2 if ( @@ -243,8 +243,8 @@ def run_torch_spawn(fn, nprocs): seq_len, hidden_size, dtype, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ), nprocs=nprocs, ) @@ -260,8 +260,8 @@ def all_reduce_fusion_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, - enable_rms_norm, - enable_quant_fp8, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): current_platform.seed_everything(0) @@ -284,9 +284,9 @@ def all_reduce_fusion_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) custom_ops = [] - if enable_rms_norm: + if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") - if enable_quant_fp8: + if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") vllm_config = VllmConfig( From 6253d5bd143a1975213462e7d6c4f8d3a2e1fef7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 13:18:03 -0400 Subject: [PATCH 65/81] Add e2e to L40 distributed, move tests to start of B200 distributed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9d98c5adf6ae..29cce6b398e0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -951,6 +951,7 @@ steps: - vllm/v1/worker/ - tests/compile/test_basic_correctness.py - tests/compile/test_wrapper.py + - tests/compile/test_fusions_e2e.py - tests/distributed/ - tests/entrypoints/llm/test_collective_rpc.py - tests/v1/distributed @@ -964,6 +965,7 @@ steps: - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py + - pytest -v -s ./compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - pytest -v -s distributed/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown @@ -1122,10 +1124,10 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/distributed/test_context_parallel.py - - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min From 876ef22e1e2921ed84b615a55edb442f627b42b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 18:43:48 -0400 Subject: [PATCH 66/81] Fix tests, PR feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 17 +++++++++++------ tests/compile/test_sequence_parallelism.py | 7 +++---- vllm/compilation/fusion.py | 6 +++--- vllm/compilation/matcher_utils.py | 11 ++++++++--- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 8c388f13002f..aa0728e39c94 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -169,24 +169,29 @@ def test_fusion_rmsnorm_quant( cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) - result = model(x) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) - model2 = torch.compile(model, backend=backend) - result2 = model2(x) + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x) - # Higher tol for dynamic bfloat16 - if dtype == torch.float16 or static: + if enable_rms_norm_custom_op and static: + ATOL, RTOL = (1e-5, 1e-5) # up to 1e-8 close + elif dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) + elif static: + ATOL, RTOL = (5e-3, 5e-3) else: ATOL, RTOL = (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) assert fusion_pass.matched_count == 3 backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index ba2178964ff3..24bc88d44f38 100644 --- 
a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -18,6 +18,7 @@ ModelConfig, PassConfig, VllmConfig, + get_current_vllm_config, set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce @@ -94,13 +95,11 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.vllm_config = vllm_config + self.vllm_config = get_current_vllm_config() self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size)), requires_grad=False ) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 606874cc1034..98703ed5f007 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -33,7 +33,7 @@ def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda") + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") def empty_fp32(*args, **kwargs): @@ -144,7 +144,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): inputs = [ # input, weight *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1), # scale + self.quant_matcher.inputs()[1], # scale ] pattern(*inputs) @@ -200,7 +200,7 @@ def replacement( inputs = [ # input, weight, residual *self.rmsnorm_matcher.inputs(), - empty_fp32(1, 1), # scale + self.quant_matcher.inputs()[1], # scale ] pm.register_replacement( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 16d1d86d2b3e..8be4de96ebbf 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -112,9 +112,7 @@ def __init__(self, epsilon: float, enabled: bool | None = None): def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty( - 16, - ) + weight = self.empty(16) residual = self.empty(5, 16) return [input, weight, residual] @@ -203,3 +201,10 @@ def make_scale(self, input: torch.Tensor): ) return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + + def inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 16) + if self.quant_key.scale.static: + return [input, self.empty_f32(1, 1)] + + return [input] From e99a7598260d1a3f33cb85be27eff177a4b28dee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 19:20:47 -0400 Subject: [PATCH 67/81] Break up B200 tests, move allreduce to H200 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 29cce6b398e0..df5b474bb729 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -808,7 +808,7 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 48 min +- label: Blackwell Test # TODO min timeout_in_minutes: 70 working_dir: "/vllm-workspace/" gpu: b200 @@ -822,11 +822,6 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - 
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - - vllm/compilation/ - # can affect pattern matching - - vllm/model_executor/layers/layernorm.py - - vllm/model_executor/layers/activation.py - - vllm/model_executor/layers/quantization/input_quant_fp8.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -847,10 +842,27 @@ steps: - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - # Fusion - - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + +- label: Blackwell Fusion Tests # TODO min + timeout_in_minutes: 70 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusions_e2e.py - label: Blackwell GPT-OSS Eval @@ -951,7 +963,6 @@ steps: - vllm/v1/worker/ - tests/compile/test_basic_correctness.py - tests/compile/test_wrapper.py - - tests/compile/test_fusions_e2e.py - tests/distributed/ - tests/entrypoints/llm/test_collective_rpc.py - tests/v1/distributed @@ -965,7 +976,6 @@ steps: - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - - pytest -v -s ./compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - pytest -v -s distributed/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown @@ -1114,6 +1124,8 @@ steps: commands: - pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_sequence_parallelism.py + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 @@ -1124,8 +1136,6 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/compile/test_fusion_all_reduce.py - - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py From ae581e176787d4fab88438330a8b93add1f5ce48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 20:30:02 -0400 Subject: [PATCH 68/81] Fix attention fusion test numerics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka 
Govedič --- tests/compile/test_fusion_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 691d9256d7be..2498c2d58a31 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -368,8 +368,9 @@ def test_attention_quant_pattern( forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) - # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v) + # Run model directly without fusion + # Still compile so query QuantFP8 has closer numerics + result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( From c03b29bfb520735b4460dd5c4bf1b8ee3d5743cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 15 Oct 2025 20:31:11 -0400 Subject: [PATCH 69/81] Remove inductor graph partition from unit test (included in e2e tests) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 2498c2d58a31..fecb1e2e918f 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -35,7 +35,6 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec @@ -290,7 +289,6 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): # quant_fp4 only has the custom impl + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) -@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @@ -305,7 +303,6 @@ def test_attention_quant_pattern( model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, - use_inductor_graph_partition: bool, dist_init, ): """Test AttentionStaticQuantPattern fusion pass""" @@ -314,10 +311,6 @@ def test_attention_quant_pattern( ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") - # TODO(boyuan/luka): test inductor graph partition on rocm - if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("Inductor graph partition requires torch>=2.9") - custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") @@ -333,7 +326,6 @@ def test_attention_quant_pattern( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops_list, - use_inductor_graph_partition=use_inductor_graph_partition, ), cache_config=CacheConfig(cache_dtype="fp8"), ) From d2e0489da1200b387c09c7867b465e5a18c2275e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 00:31:15 -0400 Subject: [PATCH 70/81] Relax tolerance for L40 fusion test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 
aa0728e39c94..4e42094f73e6 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -182,9 +182,7 @@ def test_fusion_rmsnorm_quant( model_unfused = torch.compile(model, backend=backend2) result_unfused = model_unfused(x) - if enable_rms_norm_custom_op and static: - ATOL, RTOL = (1e-5, 1e-5) # up to 1e-8 close - elif dtype == torch.float16: + if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) elif static: ATOL, RTOL = (5e-3, 5e-3) From d4fe977cdfe5419afd297c90fae45171ac004fb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 00:54:25 -0400 Subject: [PATCH 71/81] Fix NamedTuple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusions_e2e.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 5d5750ca3715..7399abaec542 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -6,11 +6,10 @@ import itertools import logging from collections.abc import Iterable -from typing import Any +from typing import Any, NamedTuple import pytest import regex as re -from black.cache import NamedTuple from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams From 6319e39757784acb19d84cec9a89791dc8939c4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 00:58:59 -0400 Subject: [PATCH 72/81] Update test durations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 236d6d4c8be5..238b6ef98bf2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -808,8 +808,8 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # TODO min - timeout_in_minutes: 70 +- label: Blackwell Test # 21 min + timeout_in_minutes: 30 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -844,8 +844,8 @@ steps: - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py -- label: Blackwell Fusion Tests # TODO min - timeout_in_minutes: 70 +- label: Blackwell Fusion Tests # 30 min + timeout_in_minutes: 40 working_dir: "/vllm-workspace/" gpu: b200 source_file_dependencies: From e34d36d2e13b25d066bd14a111d9cb3db998d34f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 16 Oct 2025 09:33:16 -0400 Subject: [PATCH 73/81] More tweaking of precision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/test_fusion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4e42094f73e6..286f2276367a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -184,8 +184,6 @@ def test_fusion_rmsnorm_quant( if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) - elif static: - ATOL, RTOL = (5e-3, 5e-3) else: ATOL, RTOL = (1e-2, 1e-2) From f72ee4385c014ec68b96c0b72a130f2b6bd94ccd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 04:29:31 -0700 Subject: [PATCH 74/81] Split 
original pr Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 1270 +++++++++++++++++ tests/compile/test_fusion_all_reduce.py | 2 +- vllm/compilation/collective_fusion.py | 98 +- vllm/config/compilation.py | 64 +- 4 files changed, 1381 insertions(+), 53 deletions(-) create mode 100644 benchmarks/kernels/benchmark_fused_collective.py diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 000000000000..ea78875c62cf --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time +from typing import Optional + +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + 
tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + 
rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def standard_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Then RMS norm + if residual is not None: + # Fused add + RMS norm + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + else: + # Just RMS norm + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + + +def standard_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization.""" + if quant_out is None: + quant_out = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then fused RMS norm + FP8 quantization + if residual is not None: + FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, residual, rms_gamma, scale_factor, rms_eps + ) + return quant_out, residual + else: + RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard 
allreduce + rmsnorm + FP4 quantization.""" + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then RMS norm + if residual is not None: + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + quant_input = allreduce_out + residual_out = residual + else: + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + quant_input = norm_out + residual_out = allreduce_out + + # Finally FP4 quantization + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +def standard_allreduce_rmsnorm_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Apply native RMSNorm + if residual is not None: + result = rmsnorm_layer.forward_native(allreduce_out, residual) + return result # Returns (norm_out, residual_out) + else: + result = rmsnorm_layer.forward_native(allreduce_out) + return result # Returns norm_out + + +def standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + residual_out = allreduce_out + + # Apply native FP8 quantization + quant_out, _ = quant_fp8_layer.forward_native(norm_out, scale=scale_factor) + + if residual is not None: + return quant_out, residual_out + else: + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + quant_input = norm_out + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + quant_input = norm_out + residual_out = allreduce_out + + # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +# Compiled versions of native functions +@torch.compile +def standard_allreduce_rmsnorm_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm.""" + return 
standard_allreduce_rmsnorm_native( + input_tensor, residual, rmsnorm_layer, norm_out + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp8_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" + return standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor, + residual, + rmsnorm_layer, + quant_fp8_layer, + scale_factor, + norm_out, + quant_out, + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp4_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" + return standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor, + residual, + rmsnorm_layer, + input_global_scale, + quant_out, + output_scale, + norm_out, + ) + + +def create_test_tensors( + seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + seq_len: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: Optional[FlashInferFusedAllReduceParams], + quant_mode: str = "all", + disable_oneshot: bool = False, +): + """Run all benchmarks 
for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps, dtype=dtype) + rmsnorm_layer.weight.data = rms_gamma + quant_fp8_layer = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + + if quant_mode in ["all", "none"]: + # Standard AllReduce + RMSNorm + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + results["standard_allreduce_rmsnorm"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results["standard_allreduce_rmsnorm"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") + + if quant_mode in ["all", "fp8_only"]: + # Standard AllReduce + RMSNorm + FP8 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + quant_fp8_layer=quant_fp8_layer, + scale_factor=scale_fp8, + norm_out=norm_out, + quant_out=quant_out_fp8, + ) + 
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( + "inf" + ) + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( + "inf" + ) + + if quant_mode in ["all", "fp4_only"]: + # Standard AllReduce + RMSNorm + FP4 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + 
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + 
"operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_mode={quant_mode}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results + +**World Size:** {world_size} +**Hidden Dimension:** {args.hidden_dim} +**Warmup Iterations:** {args.warmup} +**Benchmark Trials:** {args.trials} +**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} + +--- + +""" + + for result in all_results: + seq_len = result["seq_len"] + dtype = result["dtype"] + use_residual = result["use_residual"] + results_dict = result["results"] + + residual_str = "with residual" if use_residual else "no residual" + + markdown += f""" +## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} + +| Operation | Time (ms) | Speedup | +|-----------|-----------|---------| +""" + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + # Format operation name for better readability + formatted_op_name = result["operation"].replace("_", " ").title() + markdown += f"| {formatted_op_name} | {result['time_str']} |" + markdown += f"{result['speedup_str']} |\n" + + markdown += "\n" + + return markdown + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "w") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--seq-lens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Sequence lengths to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + # Quantization mode options (mutually exclusive with --no-quant) + quant_group = parser.add_mutually_exclusive_group() + 
quant_group.add_argument( + "--no-quant", action="store_true", help="Skip all quantization tests" + ) + quant_group.add_argument( + "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" + ) + quant_group.add_argument( + "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" + ) + quant_group.add_argument( + "--quant-all", + action="store_true", + help="Run all quantization tests (default)", + ) + + parser.add_argument( + "--disable-oneshot", + action="store_true", + help="Disable oneshot mode for FlashInfer operations", + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." 
+ ) + + # Determine quantization mode + if args.no_quant: + quant_mode = "none" + elif args.quant_fp8: + quant_mode = "fp8_only" + elif args.quant_fp4: + quant_mode = "fp4_only" + else: # args.quant_all or default + quant_mode = "all" + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization mode: %s", quant_mode) + if flashinfer_comm is not None: + oneshot_status = "enabled" if not args.disable_oneshot else "disabled" + logger.info( + "FlashInfer available - will benchmark fused operations (oneshot: %s)", + oneshot_status, + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + if not args.no_residual: + residual_options.append(False) + + configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for seq_len, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", + seq_len, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + seq_len, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_mode=quant_mode, + disable_oneshot=args.disable_oneshot, + ) + + # Store results for markdown export + if rank == 0: + all_results.append( + { + "seq_len": seq_len, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_mode": quant_mode, + "results": results, + } + ) + + print_results( + results, + seq_len, + args.hidden_dim, + dtype, + use_residual, + quant_mode, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7688ba3d1b6c..4798dbf1df1e 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -329,4 +329,4 @@ def all_reduce_fusion_pass_on_test_model( ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + del all_reduce_fusion_pass \ No newline at end of file diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index c1ed058ded70..c99c63aedc2a 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ 
-9,8 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -454,31 +453,21 @@ def __call__(self, graph: fx.Graph): _FI_WORKSPACE_TENSOR = None MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB + # Max size of the input tensor per world size per device capability + # to use flashinfer one shot fused allreduce + _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { + "9.0": { + 2: 32 * MiB, # 32MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 32 * MiB, # 32MB + 4: 4 * MiB, # 4MB + 8: 1 * MiB, # 1MB + }, } - try: - _FI_MAX_SIZES.update( - { - int(k): int(float(v) * MiB) - for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - } - ) - except Exception as e: - raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) - ) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 - def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -500,15 +489,22 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, - ) - if use_flashinfer: - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) + max_tensor_size = max_token_num * hidden_size * element_size + + if current_tensor_size <= max_tensor_size: + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \ + get(device_capability, {}). 
\ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size + + assert (_FI_WORKSPACE_TENSOR is not None + ), "Flashinfer must be enabled when using flashinfer" if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -532,7 +528,7 @@ def call_trtllm_fused_allreduce_norm( hidden_dim=allreduce_in.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, launch_with_pdl=launch_with_pdl, - use_oneshot=True, + use_oneshot=use_oneshot, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, pattern_code=pattern_code, @@ -545,7 +541,8 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: + if (scale_factor is not None and scale_out is None and + fuse_rms_quant): # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -637,10 +634,9 @@ def __init__( self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True - self.use_oneshot = False self.max_token_num = max_token_num self.fuse_rms_quant = fuse_rms_quant - + def get_trtllm_fused_allreduce_kwargs(self): return { "world_rank": self.rank, @@ -1096,7 +1092,6 @@ def replacement( pattern, replacement, get_inputs(), pm.fwd_only, pm_pass ) - class AllReduceFusionPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig): super().__init__(config) @@ -1119,23 +1114,27 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.\ + pass_config.flashinfer_max_size(self.tp_size) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) - // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, - ) + element_size = 4 if use_fp32_lamport else 2 + max_token_num = (max_size // (self.hidden_dim * element_size)) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + max_token_num = min(max_token_num, + config.scheduler_config.max_num_batched_tokens) + self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1148,11 +1147,10 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, + max_token_num=max_token_num, # fuse rms norm static fp8 quant fused op # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, - ) + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) self.register_patterns() self.dump_patterns(config, self.patterns) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index a34fb0bf920c..84bc5e19c74c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -109,11 +109,66 @@ class PassConfig: """Whether 
to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: dict[int, + float] = field(default_factory=dict) + """The thresholds of the communicated tensor sizes under which + vllm should use flashinfer fused allreduce. Specified as a + dictionary mapping each world size to the threshold in MB + { : } + Unspecified world sizes will fallback to + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + }, where key is the device capability""" # TODO(luka) better pass enabling system. + def flashinfer_max_size(self, world_size: int) -> Optional[int]: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Falls back to + conservative defaults if the world size is not specified in config. + """ + + # import here to avoid circular dependencies + from vllm.platforms import current_platform + MiB = 1024 * 1024 + + # Max size of the input tensor per world size per device capability + # to use flashinfer fused allreduce + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + } + + device_capability = current_platform.get_device_capability( + ).as_version_str() + max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) + max_sizes.update({ + k: int(v * MiB) + for k, v in self.fi_allreduce_fusion_max_size_mb.items() + }) + if world_size not in max_sizes: + # FlashInfer doesn't support other world sizes + return None + return max_sizes[world_size] + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -134,6 +189,11 @@ def __post_init__(self) -> None: "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work" ) + if self.enable_fi_allreduce_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. 
" + "Allreduce + rms norm + quant (fp8) fusion might not work" + ) @config From c4c0215874a0dec0981625c336e4449ee0e88e72 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 5 Sep 2025 05:58:33 -0700 Subject: [PATCH 75/81] Update bench Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index ea78875c62cf..7f012af36a94 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -187,7 +187,7 @@ def flashinfer_fused_allreduce_rmsnorm( allreduce_out=None, quant_out=None, scale_out=None, - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=None, use_oneshot=use_oneshot, **allreduce_params.get_trtllm_fused_allreduce_kwargs(), @@ -962,10 +962,15 @@ def get_fastest_baseline(op_name, results_dict): return prepared_results -def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): +def print_results( + results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode, input_size_mb +): """Print benchmark results in a formatted table.""" print(f"\n{'=' * 80}") - print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"Results: seq_len={seq_len}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) print( f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " f"quant_mode={quant_mode}" @@ -1009,11 +1014,12 @@ def format_results_markdown( dtype = result["dtype"] use_residual = result["use_residual"] results_dict = result["results"] - + input_size_mb = result["input_size_mb"] residual_str = "with residual" if use_residual else "no residual" markdown += f""" ## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} +**Input Size:** {input_size_mb:.2f} MB | Operation | Time (ms) | Speedup | |-----------|-----------|---------| @@ -1234,6 +1240,10 @@ def main(): # Store results for markdown export if rank == 0: + # Calculate input size in MB + input_size_mb = ( + seq_len * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) all_results.append( { "seq_len": seq_len, @@ -1241,6 +1251,7 @@ def main(): "dtype": str(dtype).replace("torch.", ""), "use_residual": use_residual, "quant_mode": quant_mode, + "input_size_mb": input_size_mb, "results": results, } ) @@ -1252,6 +1263,7 @@ def main(): dtype, use_residual, quant_mode, + input_size_mb, ) # Save results to markdown file From 309d79e8a41e1c1360adf8409ca6f38aa226a00c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 8 Sep 2025 04:41:09 -0700 Subject: [PATCH 76/81] Update threshold configuration Signed-off-by: ilmarkov --- vllm/config/compilation.py | 59 ++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 84bc5e19c74c..ff3a092fe538 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -134,40 +134,24 @@ class PassConfig: def flashinfer_max_size(self, world_size: int) -> Optional[int]: """ Returns the max communication size in bytes for flashinfer - allreduce fusion for the given world size. Falls back to - conservative defaults if the world size is not specified in config. + allreduce fusion for the given world size. 
Returns None if world size + is not supported by configs as it's not supported by flashinfer. """ # import here to avoid circular dependencies from vllm.platforms import current_platform MiB = 1024 * 1024 - # Max size of the input tensor per world size per device capability - # to use flashinfer fused allreduce - _FI_ALLREDUCE_MAX_INPUT_SIZES = { - "9.0": { - 2: 64 * MiB, # 64MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB - }, - "10.0": { - 2: 64 * MiB, # 64MB - 4: 32 * MiB, # 32MB - 8: 1 * MiB, # 1MB - }, - } - device_capability = current_platform.get_device_capability( ).as_version_str() - max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) - max_sizes.update({ + fi_allreduce_fusion_max_size_mb = \ + self.fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes = { k: int(v * MiB) - for k, v in self.fi_allreduce_fusion_max_size_mb.items() - }) - if world_size not in max_sizes: - # FlashInfer doesn't support other world sizes - return None - return max_sizes[world_size] + for k, v in fi_allreduce_fusion_max_size_mb.items() + } + # return None if world size is not supported by flashinfer + return max_sizes.get(world_size) def uuid(self): """ @@ -195,6 +179,31 @@ def __post_init__(self) -> None: "Allreduce + rms norm + quant (fp8) fusion might not work" ) + # import here to avoid circular dependencies + from vllm.platforms import current_platform + + # Default tuned max size of the input tensor + # per world size per device capability + # to use flashinfer fused allreduce + fi_allreduce_fusion_max_size_mb = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + } + device_capability = current_platform.get_device_capability( + ).as_version_str() + + max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes.update(self.fi_allreduce_fusion_max_size_mb) + self.fi_allreduce_fusion_max_size_mb[device_capability] = max_sizes + @config @dataclass From afcfd73f5c1b5cf1bfc17d0537ae526ad102eea2 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 8 Sep 2025 05:01:47 -0700 Subject: [PATCH 77/81] Move all_reduce from custom op in fused_moe Signed-off-by: ilmarkov --- vllm/model_executor/layers/fused_moe/layer.py | 82 +++++++++---------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index de4ed58e0cf4..4bd7ab12f9c0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2105,33 +2105,59 @@ def forward_native( mode="constant", value=0.0, ) + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 and not self.quant_method.using_modular_kernel + ) - if self.shared_experts is None: + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + + if self.shared_experts is not None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. 
- fused_output = self.forward_impl(hidden_states, router_logits) - assert not isinstance(fused_output, tuple) + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits + ) else: - fused_output = torch.ops.vllm.moe_forward( + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( hidden_states, router_logits, self.layer_name ) - return fused_output[..., :og_hidden_states] + return ( + reduce_output(shared_output[..., :og_hidden_states], do_combine=False), + reduce_output(fused_output[..., :og_hidden_states]), + ) else: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits - ) + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name ) - return ( - shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states], - ) + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return ( + reduce_output(fused_output[..., :og_hidden_states]) + + zero_expert_result + ) + else: + return reduce_output(fused_output[..., :og_hidden_states]) def forward_cuda( self, @@ -2360,35 +2386,7 @@ def forward_impl( shared_output, final_hidden_states, ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - final_hidden_states, zero_expert_result = final_hidden_states - - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - - return states - - if self.shared_experts is not None: - return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), - ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result - else: - return reduce_output(final_hidden_states) + return final_hidden_states @classmethod def make_expert_params_mapping( From 0248dcdf9e6002b925cf399cb8d39c2e8f5d2214 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 12:51:23 +0000 Subject: [PATCH 78/81] Linter fixes Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 44 +++++++++++++++------------ vllm/config/compilation.py | 20 ++++++------ 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index c99c63aedc2a..01a0ebc993ae 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from 
vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -492,19 +492,22 @@ def call_trtllm_fused_allreduce_norm( max_tensor_size = max_token_num * hidden_size * element_size if current_tensor_size <= max_tensor_size: - device_capability = current_platform.get_device_capability( - ).as_version_str() + device_capability = ( + current_platform.get_device_capability().as_version_str() + ) # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \ - get(device_capability, {}). \ - get(world_size, None) + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get( + device_capability, {} + ).get(world_size, None) # Use one shot if no max size is specified - use_oneshot = max_one_shot_size is None or \ - current_tensor_size <= max_one_shot_size + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size + ) - assert (_FI_WORKSPACE_TENSOR is not None - ), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -541,8 +544,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if (scale_factor is not None and scale_out is None and - fuse_rms_quant): + if scale_factor is not None and scale_out is None and fuse_rms_quant: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -636,7 +638,7 @@ def __init__( self.fp32_acc = True self.max_token_num = max_token_num self.fuse_rms_quant = fuse_rms_quant - + def get_trtllm_fused_allreduce_kwargs(self): return { "world_rank": self.rank, @@ -1092,6 +1094,7 @@ def replacement( pattern, replacement, get_inputs(), pm.fwd_only, pm_pass ) + class AllReduceFusionPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig): super().__init__(config) @@ -1114,8 +1117,9 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - max_size = config.compilation_config.\ - pass_config.flashinfer_max_size(self.tp_size) + max_size = config.compilation_config.pass_config.flashinfer_max_size( + self.tp_size + ) if max_size is None: # Flashinfer doesn't support current world size logger.warning( @@ -1124,11 +1128,12 @@ def __init__(self, config: VllmConfig): ) return element_size = 4 if use_fp32_lamport else 2 - max_token_num = (max_size // (self.hidden_dim * element_size)) + max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways - max_token_num = min(max_token_num, - config.scheduler_config.max_num_batched_tokens) + max_token_num = min( + max_token_num, config.scheduler_config.max_num_batched_tokens + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( @@ -1150,7 +1155,8 @@ def __init__(self, config: VllmConfig): max_token_num=max_token_num, # fuse rms norm static fp8 quant fused op # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + ) self.register_patterns() self.dump_patterns(config, self.patterns) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ff3a092fe538..ee0c40f4ef42 100644 --- 
a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -109,8 +109,7 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_size_mb: dict[int, - float] = field(default_factory=dict) + fi_allreduce_fusion_max_size_mb: dict[int, float] = field(default_factory=dict) """The thresholds of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a dictionary mapping each world size to the threshold in MB @@ -131,7 +130,7 @@ class PassConfig: # TODO(luka) better pass enabling system. - def flashinfer_max_size(self, world_size: int) -> Optional[int]: + def flashinfer_max_size(self, world_size: int) -> int | None: """ Returns the max communication size in bytes for flashinfer allreduce fusion for the given world size. Returns None if world size @@ -140,15 +139,15 @@ def flashinfer_max_size(self, world_size: int) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + MiB = 1024 * 1024 - device_capability = current_platform.get_device_capability( - ).as_version_str() - fi_allreduce_fusion_max_size_mb = \ - self.fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + device_capability = current_platform.get_device_capability().as_version_str() + fi_allreduce_fusion_max_size_mb = self.fi_allreduce_fusion_max_size_mb.get( + device_capability, {} + ) max_sizes = { - k: int(v * MiB) - for k, v in fi_allreduce_fusion_max_size_mb.items() + k: int(v * MiB) for k, v in fi_allreduce_fusion_max_size_mb.items() } # return None if world size is not supported by flashinfer return max_sizes.get(world_size) @@ -197,8 +196,7 @@ def __post_init__(self) -> None: 8: 1, # 1MB }, } - device_capability = current_platform.get_device_capability( - ).as_version_str() + device_capability = current_platform.get_device_capability().as_version_str() max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) max_sizes.update(self.fi_allreduce_fusion_max_size_mb) From 18e477160a207d73d3761c23d35ac78f19372d02 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 13:26:23 +0000 Subject: [PATCH 79/81] Upd Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 59 +++++++++---------- tests/compile/test_fusion_all_reduce.py | 2 +- vllm/config/compilation.py | 25 +++----- 3 files changed, 39 insertions(+), 47 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 7f012af36a94..0d1ec49e3f41 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -17,7 +17,6 @@ import itertools import os import time -from typing import Optional import torch # type: ignore import torch.distributed as dist # type: ignore @@ -156,12 +155,12 @@ def get_trtllm_fused_allreduce_kwargs(self): def flashinfer_fused_allreduce_rmsnorm( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, allreduce_params: "FlashInferFusedAllReduceParams", use_oneshot: bool, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm operation.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: @@ -196,14 +195,14 @@ def flashinfer_fused_allreduce_rmsnorm( def flashinfer_fused_allreduce_rmsnorm_fp8_quant( input_tensor: 
torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, scale_factor: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, use_oneshot: bool = True, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: @@ -238,7 +237,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( def flashinfer_fused_allreduce_rmsnorm_fp4_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, input_global_scale: torch.Tensor, @@ -246,7 +245,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( quant_out: torch.Tensor, use_oneshot: bool, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: @@ -281,10 +280,10 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( def standard_allreduce_rmsnorm( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm operations.""" # All-reduce first @@ -302,12 +301,12 @@ def standard_allreduce_rmsnorm( def standard_allreduce_rmsnorm_fp8_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, scale_factor: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP8 quantization.""" if quant_out is None: @@ -331,13 +330,13 @@ def standard_allreduce_rmsnorm_fp8_quant( def standard_allreduce_rmsnorm_fp4_quant( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rms_gamma: torch.Tensor, rms_eps: float, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP4 quantization.""" @@ -366,9 +365,9 @@ def standard_allreduce_rmsnorm_fp4_quant( def standard_allreduce_rmsnorm_native( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" # All-reduce first @@ -384,12 +383,12 @@ def standard_allreduce_rmsnorm_native( def standard_allreduce_rmsnorm_fp8_quant_native( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, quant_fp8_layer: QuantFP8, scale_factor: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" # All-reduce first @@ -413,12 +412,12 @@ def 
standard_allreduce_rmsnorm_fp8_quant_native( def standard_allreduce_rmsnorm_fp4_quant_native( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" # All-reduce first @@ -446,9 +445,9 @@ def standard_allreduce_rmsnorm_fp4_quant_native( @torch.compile def standard_allreduce_rmsnorm_native_compiled( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Compiled version of standard allreduce + rmsnorm.""" return standard_allreduce_rmsnorm_native( @@ -459,12 +458,12 @@ def standard_allreduce_rmsnorm_native_compiled( @torch.compile def standard_allreduce_rmsnorm_fp8_quant_native_compiled( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, quant_fp8_layer: QuantFP8, scale_factor: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, ): """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" return standard_allreduce_rmsnorm_fp8_quant_native( @@ -481,12 +480,12 @@ def standard_allreduce_rmsnorm_fp8_quant_native_compiled( @torch.compile def standard_allreduce_rmsnorm_fp4_quant_native_compiled( input_tensor: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, rmsnorm_layer: RMSNorm, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, - norm_out: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, ): """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" return standard_allreduce_rmsnorm_fp4_quant_native( @@ -578,7 +577,7 @@ def run_benchmarks( hidden_dim: int, dtype: torch.dtype, use_residual: bool, - allreduce_params: Optional[FlashInferFusedAllReduceParams], + allreduce_params: FlashInferFusedAllReduceParams | None, quant_mode: str = "all", disable_oneshot: bool = False, ): diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4798dbf1df1e..7688ba3d1b6c 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -329,4 +329,4 @@ def all_reduce_fusion_pass_on_test_model( ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass \ No newline at end of file + del all_reduce_fusion_pass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ee0c40f4ef42..2ca1959d10b0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -117,14 +117,14 @@ class PassConfig: Unspecified world sizes will fallback to _FI_ALLREDUCE_MAX_INPUT_SIZES = { "9.0": { - 2: 64 * MiB, # 64MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB }, "10.0": { - 2: 64 * MiB, # 64MB - 4: 32 * MiB, # 32MB - 8: 1 * MiB, # 1MB + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB }, }, where key is the device capability""" @@ -137,18 +137,11 @@ def flashinfer_max_size(self, world_size: int) -> int | None: is 
not supported by configs as it's not supported by flashinfer. """ - # import here to avoid circular dependencies - from vllm.platforms import current_platform - MiB = 1024 * 1024 - - device_capability = current_platform.get_device_capability().as_version_str() - fi_allreduce_fusion_max_size_mb = self.fi_allreduce_fusion_max_size_mb.get( - device_capability, {} - ) max_sizes = { - k: int(v * MiB) for k, v in fi_allreduce_fusion_max_size_mb.items() + k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() } + # return None if world size is not supported by flashinfer return max_sizes.get(world_size) @@ -200,7 +193,7 @@ def __post_init__(self) -> None: max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) max_sizes.update(self.fi_allreduce_fusion_max_size_mb) - self.fi_allreduce_fusion_max_size_mb[device_capability] = max_sizes + self.fi_allreduce_fusion_max_size_mb = max_sizes @config From 9516d2bd3b8910d439985bcc9e1eb7377ce5348c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 21 Oct 2025 13:40:33 +0000 Subject: [PATCH 80/81] Upd after review Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 76 ++++++++++--------- vllm/config/compilation.py | 36 ++++----- vllm/model_executor/layers/fused_moe/layer.py | 39 +++++----- 3 files changed, 72 insertions(+), 79 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 4afa99a38760..056d3f482751 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -449,24 +449,40 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) +# Max size of the input tensor per world size per device capability +# to use flashinfer fused allreduce +FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, +} + +# Max size of the input tensor per world size per device capability +# to use flashinfer one shot fused allreduce +_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB = { + "9.0": { + 2: 32, # 32MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 32, # 32MB + 4: 4, # 4MB + 8: 1, # 1MB + }, +} + + if flashinfer_comm is not None: _FI_WORKSPACE_TENSOR = None - MiB = 1024 * 1024 - # Max size of the input tensor per world size per device capability - # to use flashinfer one shot fused allreduce - _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { - "9.0": { - 2: 32 * MiB, # 32MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB - }, - "10.0": { - 2: 32 * MiB, # 32MB - 4: 4 * MiB, # 4MB - 8: 1 * MiB, # 1MB - }, - } def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, @@ -480,7 +496,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -497,12 +512,13 @@ def call_trtllm_fused_allreduce_norm( ) # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get( + max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( device_capability, {} ).get(world_size, None) - # Use one shot if no max size is specified + # Use one shot if no max size for one shot is specified use_oneshot = ( - max_one_shot_size is None or current_tensor_size <= max_one_shot_size + max_one_shot_size_mb is None + or current_tensor_size <= max_one_shot_size_mb * MiB ) 
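+        # Broadly, the one-shot kernel completes the allreduce in a single
+        # exchange and has the lowest latency for small payloads, while the
+        # two-shot variant scales better once the tensor exceeds the tuned
+        # per-capability threshold above.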
assert _FI_WORKSPACE_TENSOR is not None, ( @@ -544,7 +560,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: + if scale_factor is not None and scale_out is None: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -567,15 +583,10 @@ def call_trtllm_fused_allreduce_norm( norm_out = allreduce_out else: torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None: - if scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - else: - torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor - ) + if scale_factor is not None and scale_out is not None: + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, scale_out, scale_factor + ) if scale_factor is None or norm_out is not None: # we need to return allreduce output # in cases of non quant fused AR + RMS norm @@ -594,7 +605,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -628,7 +638,6 @@ def __init__( world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -637,7 +646,6 @@ def __init__( self.launch_with_pdl = True self.fp32_acc = True self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -647,7 +655,6 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } @@ -1153,9 +1160,6 @@ def __init__(self, config: VllmConfig): world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, max_token_num=max_token_num, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, ) self.register_patterns() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ee3b5cd94870..1ed5fcc8b9a8 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -116,7 +116,7 @@ class PassConfig: dictionary mapping each world size to the threshold in MB { : } Unspecified world sizes will fallback to - _FI_ALLREDUCE_MAX_INPUT_SIZES = { + FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { "9.0": { 2: 64, # 64MB 4: 2, # 2MB @@ -146,6 +146,15 @@ def flashinfer_max_size(self, world_size: int) -> int | None: # return None if world size is not supported by flashinfer return max_sizes.get(world_size) + @staticmethod + def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: + from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB + from vllm.platforms import current_platform + + return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get( + current_platform.get_device_capability().as_version_str(), {} + ) + def uuid(self): """ Produces a hash unique to the pass configuration. 
@@ -172,29 +181,10 @@ def __post_init__(self) -> None: "Allreduce + rms norm + quant (fp8) fusion might not work" ) - # import here to avoid circular dependencies - from vllm.platforms import current_platform - - # Default tuned max size of the input tensor - # per world size per device capability - # to use flashinfer fused allreduce - fi_allreduce_fusion_max_size_mb = { - "9.0": { - 2: 64, # 64MB - 4: 2, # 2MB - 8: 1, # 1MB - }, - "10.0": { - 2: 64, # 64MB - 4: 32, # 32MB - 8: 1, # 1MB - }, + self.fi_allreduce_fusion_max_size_mb = { + **PassConfig.default_fi_allreduce_fusion_max_size_mb(), + **self.fi_allreduce_fusion_max_size_mb, } - device_capability = current_platform.get_device_capability().as_version_str() - - max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) - max_sizes.update(self.fi_allreduce_fusion_max_size_mb) - self.fi_allreduce_fusion_max_size_mb = max_sizes @config diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4a3ebbd74540..68982e37d825 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2148,22 +2148,7 @@ def reduce_output( states = self.maybe_all_reduce_tensor_model_parallel(states) return states - if self.shared_experts is not None: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits - ) - else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name - ) - return ( - reduce_output(shared_output[..., :og_hidden_states], do_combine=False), - reduce_output(fused_output[..., :og_hidden_states]), - ) - else: + if self.shared_experts is None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. @@ -2176,12 +2161,26 @@ def reduce_output( if self.zero_expert_num is not None and self.zero_expert_num > 0: assert isinstance(fused_output, tuple) fused_output, zero_expert_result = fused_output - return ( - reduce_output(fused_output[..., :og_hidden_states]) - + zero_expert_result + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] + else: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. 
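+                # With shared experts configured, forward_impl (and the
+                # moe_forward_shared op below) return a
+                # (shared_output, fused_output) pair; both halves are reduced
+                # separately before trimming the padded hidden dimension.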
+ shared_output, fused_output = self.forward_impl( + hidden_states, router_logits ) else: - return reduce_output(fused_output[..., :og_hidden_states]) + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name + ) + return ( + reduce_output(shared_output, do_combine=False)[..., :og_hidden_states], + reduce_output(fused_output)[..., :og_hidden_states], + ) def forward_cuda( self, From b789044ffe53f1e789cc8bae0f87109389f805e7 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 27 Oct 2025 11:09:43 +0000 Subject: [PATCH 81/81] Update fused_moe Signed-off-by: ilmarkov --- vllm/model_executor/layers/fused_moe/layer.py | 112 +++++++++--------- 1 file changed, 59 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 68982e37d825..2aa42bd61a90 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2133,54 +2133,69 @@ def forward_native( self.dp_size > 1 and not self.quant_method.using_modular_kernel ) - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and not self.use_dp_chunking - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - return states + ctx = get_forward_context() + sp_ctx = ( + ctx.dp_metadata.sp_local_sizes(self.sp_size) + if ctx.dp_metadata + else nullcontext() + ) - if self.shared_experts is None: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - fused_output = self.forward_impl(hidden_states, router_logits) - assert not isinstance(fused_output, tuple) - else: - fused_output = torch.ops.vllm.moe_forward( - hidden_states, router_logits, self.layer_name - ) - if self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(fused_output, tuple) - fused_output, zero_expert_result = fused_output - return (reduce_output(fused_output) + zero_expert_result)[ - ..., :og_hidden_states - ] - else: - return reduce_output(fused_output)[..., :og_hidden_states] - else: - if current_platform.is_tpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - shared_output, fused_output = self.forward_impl( - hidden_states, router_logits + with sp_ctx: + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits, self.is_sequence_parallel ) + + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. 
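+                    # (The TPU path yields a plain tensor; the custom-op path
+                    #  below may return a (output, zero_expert_result) tuple
+                    #  when zero experts are configured, which is unpacked and
+                    #  added back after the reduction.)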
+ fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name + ) + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] else: - shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits + ) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name + ) + return ( + reduce_output(shared_output, do_combine=False)[ + ..., :og_hidden_states + ], + reduce_output(fused_output)[..., :og_hidden_states], ) - return ( - reduce_output(shared_output, do_combine=False)[..., :og_hidden_states], - reduce_output(fused_output)[..., :og_hidden_states], - ) def forward_cuda( self, @@ -2349,10 +2364,6 @@ def forward_impl( if self.use_dp_chunking: return self.forward_impl_chunked(hidden_states, router_logits) - do_naive_dispatch_combine: bool = ( - self.dp_size > 1 and not self.quant_method.using_modular_kernel - ) - # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here if ( @@ -2371,11 +2382,6 @@ def forward_impl( ) with sp_ctx: - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel - ) - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self,