diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a476b377ba3..238b6ef98bf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -416,8 +416,8 @@ steps: - pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 20min - timeout_in_minutes: 30 +- label: PyTorch Fullgraph Test # 22min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -425,6 +425,7 @@ steps: - tests/compile commands: - pytest -v -s compile/test_full_graph.py + - pytest -v -s compile/test_fusions_e2e.py - label: Kernels Core Operation Test # 48min timeout_in_minutes: 75 @@ -807,8 +808,8 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 38 min - timeout_in_minutes: 60 +- label: Blackwell Test # 21 min + timeout_in_minutes: 30 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -821,8 +822,6 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - - vllm/compilation/fusion.py - - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -839,15 +838,32 @@ steps: - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - # Fusion - - pytest -v -s tests/compile/test_fusion_all_reduce.py - - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + +- label: Blackwell Fusion Tests # 30 min + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py - - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py - - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 @@ -1100,7 +1116,7 @@ steps: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 ##### H200 test ##### -- label: Distrubted Tests (H200) # optional +- label: Distributed Tests (H200) # optional gpu: h200 optional: true working_dir: "/vllm-workspace/" @@ -1108,6 +1124,8 @@ steps: commands: - pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_sequence_parallelism.py + - 
pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 7e8ec9937e5..732ed4760f3 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); + TORCH_CHECK(input.scalar_type() == residual.scalar_type()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 64d14429f93..0f7f034ee18 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant( double epsilon) { TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(residual.scalar_type() == input.scalar_type()); + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); int hidden_size = input.size(-1); int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 95aa92e25b3..92d6c2f402a 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant( if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); } + TORCH_CHECK(weight.dtype() == input.dtype()); TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ef1fdd4f9da..fa426190067 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -3,16 +3,22 @@ import weakref from collections.abc import Callable, Sequence +from contextlib import nullcontext from copy import deepcopy +import depyf from torch import fx from torch._ops import OpOverload +from torch.fx._utils import lazy_format_graph_code from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.pass_manager import with_pattern_match_debug from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger("vllm.tests.compile.backend") class LazyInitPass(InductorPass): @@ -45,20 +51,32 @@ class TestBackend: def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]): self.custom_passes = list(passes) - compile_config = get_current_vllm_config().compilation_config - self.inductor_config = compile_config.inductor_compile_config + vllm_config = 
get_current_vllm_config() + compile_config = vllm_config.compilation_config + # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig + self.inductor_config = deepcopy(compile_config.inductor_compile_config) self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass + if debug_dump_path := vllm_config.compile_debug_dump_path(): + logger.debug("Dumping depyf output to %s", debug_dump_path) + self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix()) + else: + self.debug_ctx = nullcontext() + def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, example_inputs, config_patches=self.inductor_config) + with self.debug_ctx: + return compile_fx( + graph, example_inputs, config_patches=self.inductor_config + ) @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + lazy_format_graph_code("graph_pre_pass", graph.owning_module) VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: @@ -68,6 +86,7 @@ def post_pass(self, graph: fx.Graph): VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) + lazy_format_graph_code("graph_post_pass", graph.owning_module) # assign by reference, will reflect the final state of the graph self.final_graph = graph diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2d290771f9a..fb511dd8f7c 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging import tempfile +from pathlib import Path from typing import Any import pytest @@ -10,8 +10,6 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer @@ -22,23 +20,24 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - { - "dtype": torch.float16, - }, - ), ( "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", - { - "dtype": torch.float16, - }, + {"dtype": torch.float16}, ), - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: + TEST_MODELS.extend( + [ + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + {"dtype": torch.float16}, + ), + ] + ) + # TODO: figure out why this fails. 
if False and is_quant_method_supported("gguf"): # noqa: SIM223 TEST_MODELS.append( @@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): "compilation_mode", [CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE], ) -@pytest.mark.parametrize("model_info", models_list(all=True)) +@pytest.mark.parametrize("model, model_kwargs", models_list(all=True)) @create_new_process_for_each_test() def test_full_graph( monkeypatch: pytest.MonkeyPatch, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], compilation_mode: int, ): - model, model_kwargs = model_info + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") with monkeypatch.context(): print(f"MODEL={model}") - run_model(compilation_mode, model, model_kwargs) + run_model(compilation_mode, model, **model_kwargs) # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( - "compilation_config, model_info", + "compilation_config, model, model_kwargs", [ # additional compile sizes, only some of the models ( CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) ] + [ # RMSNorm + quant fusion, only 8-bit quant models @@ -117,18 +123,19 @@ def test_full_graph( custom_ops=["+rms_norm"], pass_config=PassConfig(enable_fusion=True, enable_noop=True), ), - model, + *model_info, ) - for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) + for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) ] + [ # Test depyf integration works ( CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - debug_dump_path=tempfile.gettempdir(), + debug_dump_path=Path(tempfile.gettempdir()), ), - ("facebook/opt-125m", {}), + "facebook/opt-125m", + {}, ), ] + [ @@ -142,9 +149,9 @@ def test_full_graph( cudagraph_mode=CUDAGraphMode.PIECEWISE, compile_sizes=[1, 2], ), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") ], ) @@ -152,16 +159,24 @@ def test_full_graph( @create_new_process_for_each_test() def test_custom_compile_config( compilation_config: CompilationConfig, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], ): + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") + if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( "2.9.0.dev" ): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - model, model_kwargs = model_info print(f"MODEL={model}") - run_model(compilation_config, model, model_kwargs) + run_model(compilation_config, model, **model_kwargs) @pytest.mark.parametrize( @@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int): "calculate_kv_scales": True, "max_model_len": 512, } - run_model(compilation_mode, model, model_kwargs) + run_model(compilation_mode, model, **model_kwargs) -def test_inductor_graph_partition_attn_fusion(caplog_vllm): - if not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - - model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" - 
compilation_config = CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - use_inductor_graph_partition=True, - cudagraph_mode=CUDAGraphMode.PIECEWISE, - custom_ops=["+quant_fp8"], - pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), +def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(level=compile_config) ) - model_kwargs = { - "kv_cache_dtype": "fp8", - "max_model_len": 1024, - } - with ( - caplog_vllm.at_level(logging.DEBUG), - global_force_attn_backend_context_manager(_Backend.FLASHINFER), - ): - run_model(compilation_config, model, model_kwargs) - try: - assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, ( - caplog_vllm.text - ) - except AssertionError: - # Note: this message is only triggered when the compilation goes - # through the custom pass. Due to multiple layers of cache on - # PyTorch side, the compilation of a graph may be cached such - # that custom pass directly goes through cache. In this case, - # we go through this branch and assert that the pass is not - # triggered. - assert "Fused quantization" not in caplog_vllm.text - - -def run_model( - compile_config: int | CompilationConfig, - model: str, - model_kwargs: dict[str, Any], -): prompts = [ "Hello, my name is", "The president of the United States is", @@ -227,12 +208,17 @@ def run_model( "The future of AI is", ] sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + llm = LLM( model=model, - enforce_eager=True, - tensor_parallel_size=1, - disable_custom_all_reduce=True, - compilation_config=compile_config, + compilation_config=compilation_config, **model_kwargs, ) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index ae17bc67b1f..11ae96e930d 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,13 @@ from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -48,8 +54,7 @@ def forward(self, x): return y def example_inputs(self, num_tokens=32, hidden_size=128): - dtype = torch.float16 if TEST_FP8 else torch.float32 - return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),) + return (torch.rand(num_tokens, hidden_size * 2),) def ops_in_model(self, do_fusion): if TEST_FP8 and do_fusion: @@ -67,15 +72,11 @@ def __init__(self, hidden_size=16, intermediate_size=32): self.hidden_size = hidden_size self.intermediate_size = intermediate_size - dtype = torch.float16 if TEST_FP8 else torch.float32 - self.gate_proj = torch.nn.Parameter( - 
torch.empty((intermediate_size, hidden_size), dtype=dtype) + torch.empty((intermediate_size, hidden_size)) ) self.norm = RMSNorm(intermediate_size, 1e-05) - self.norm.weight = torch.nn.Parameter( - torch.ones(intermediate_size, dtype=dtype) - ) + self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size)) torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -112,9 +113,8 @@ def forward(self, hidden_states, residual): return norm_output, residual_output def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): - dtype = torch.float16 if TEST_FP8 else torch.float32 - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size)) + residual = torch.randn((batch_size * seq_len, hidden_size)) return (hidden_states, residual) def ops_in_model(self, do_fusion): @@ -145,10 +145,9 @@ def forward(self, positions, q, k): return q_rotated, k_rotated def example_inputs(self, num_tokens=32, head_dim=64): - dtype = torch.float16 positions = torch.arange(num_tokens, dtype=torch.long) - q = torch.randn(num_tokens, head_dim, dtype=dtype) - k = torch.randn(num_tokens, head_dim, dtype=dtype) + q = torch.randn(num_tokens, head_dim) + k = torch.randn(num_tokens, head_dim) return (positions, q, k) def ops_in_model(self, do_fusion): @@ -166,7 +165,7 @@ def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000): self.hidden_size = head_dim * num_heads self.qkv_proj = torch.nn.Linear( - self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 + self.hidden_size, self.hidden_size * 3, bias=False ) self.rotary_emb = get_rope( @@ -190,10 +189,9 @@ def forward(self, positions, hidden_states): return qkv_updated def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): - dtype = torch.float16 hidden_size = head_dim * num_heads positions = torch.arange(num_tokens, dtype=torch.long) - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + hidden_states = torch.randn(num_tokens, hidden_size) return (positions, hidden_states) def ops_in_model(self, do_fusion): @@ -211,48 +209,58 @@ def ops_not_in_model(self): ] +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): +def test_fix_functionalization( + model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype +): torch.set_default_device("cuda") - - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + torch.set_default_dtype(dtype) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + custom_ops=["all"], + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), + ), ) - noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - - passes = ( - [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] - if do_fusion - else [noop_pass, cleanup_pass] - ) - func_pass = FixFunctionalizationPass(vllm_config) - backend_func = TestBackend(*passes, func_pass) - 
backend_no_func = TestBackend(*passes) + with set_current_vllm_config(vllm_config): + assert RMSNorm.enabled() + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) + + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion + else [noop_pass, cleanup_pass] + ) + func_pass = FixFunctionalizationPass(vllm_config) - model = model_class() - torch.compile(model, backend=backend_func)(*model.example_inputs()) - torch.compile(model, backend=backend_no_func)(*model.example_inputs()) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) - # check if the functionalization pass is applied - for op in model.ops_in_model(do_fusion): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + model = model_class() + torch.compile(model, backend=backend_func)(*model.example_inputs()) + torch.compile(model, backend=backend_no_func)(*model.example_inputs()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): - if is_func(node, op): - found[op] = True - for op in model.ops_not_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model(do_fusion)) - assert all(not found.get(op) for op in model.ops_not_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(do_fusion): + if is_func(node, op): + found[op] = True + for op in model.ops_not_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model(do_fusion)) + assert all(not found.get(op) for op in model.ops_not_in_model()) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 1a5eaf2639b..286f2276367 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,15 +5,18 @@ import torch import vllm.plugins -from vllm.compilation.fusion import ( - FUSED_OPS, - QUANT_OPS, - FusedRMSQuantKey, - RMSNormQuantFusionPass, -) +from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass +from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -32,6 +35,9 @@ FP8_DTYPE = current_platform.fp8_dtype() +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + class TestModel(torch.nn.Module): def __init__( @@ -45,18 +51,18 @@ def __init__( ): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch - self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - 
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: - self.scale = [None for _ in range(2)] + self.scale = [None for _ in range(3)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(2) + for _ in range(3) ] with override_cutlass_fp8_supported(not cuda_force_torch): @@ -65,8 +71,12 @@ def __init__( act_quant_group_shape=group_shape, ) + self.enable_rms_norm_custom_op = self.norm[0].enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + def forward(self, x): - resid = torch.sqrt(x) + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) y = self.norm[0](x) x2 = self.fp8_linear.apply( @@ -78,24 +88,44 @@ def forward(self, x): x3 = self.fp8_linear.apply( y2, self.w[1], self.wscale[1], input_scale=self.scale[1] ) + y3, resid = self.norm[2](x3, resid) # use resid here - return y3 - def ops_in_model_before(self): - return [QUANT_OPS[self.key]] + x4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [ - FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)], + FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], ] + def ops_in_model_before(self): + return ( + [QUANT_OPS[self.quant_key]] + if self.enable_quant_fp8_custom_op + else [torch.ops.aten.reciprocal] + ) + + def ops_in_model_before_partial(self): + return ( + [RMS_OP, RMS_ADD_OP] + if self.enable_rms_norm_custom_op + else [torch.ops.aten.rsqrt] + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
@pytest.mark.parametrize( @@ -105,19 +135,32 @@ def ops_in_model_after(self): not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) def test_fusion_rmsnorm_quant( - dtype, hidden_size, num_tokens, eps, static, cuda_force_torch + dtype, + hidden_size, + num_tokens, + eps, + static, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, + cuda_force_torch, ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + custom_ops = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_quant_fp8_custom_op: + custom_ops.append("+quant_fp8") vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - custom_ops=["+rms_norm", "+quant_fp8"], + custom_ops=custom_ops, pass_config=PassConfig(enable_fusion=True, enable_noop=True), - ) + ), ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work @@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant( cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) - result = model(x) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) - model2 = torch.compile(model, backend=backend) - result2 = model2(x) + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x) - # Higher tol for dynamic, even higher for bfloat16 - if static: - ATOL, RTOL = (1e-3, 1e-3) - elif dtype == torch.float16: + if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) else: ATOL, RTOL = (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) - assert fusion_pass.matched_count == 2 - - # In pre-nodes, fp8 quant should be there and fused kernels should not + assert fusion_pass.matched_count == 3 backend.check_before_ops(model.ops_in_model_before()) - - # In post-nodes, fused kernels should be there and fp8 quant should not + backend.check_before_ops( + model.ops_in_model_before_partial(), fully_replaced=False + ) backend.check_after_ops(model.ops_in_model_after()) + + # If RMSNorm custom op is disabled (native/torch impl used), + # there's a risk that the fused add doesn't get included in the + # replacement and only the rms part gets fused with quant. + # Hence, we check only 2 add nodes are left (final fused rmsnorm add). 
+ if not enable_rms_norm_custom_op: + n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) + # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) + assert n_add_nodes(backend.graph_pre_pass) == 7 + assert n_add_nodes(backend.graph_post_pass) == 2 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index fbcd6c71fb7..7688ba3d1b6 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -6,6 +6,7 @@ import torch import vllm.envs as envs +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -17,6 +18,7 @@ ModelConfig, PassConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -25,8 +27,8 @@ ) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, GroupShape, - QuantFP8, ) from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -40,33 +42,30 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm = self.norm(all_reduce) - return norm + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(x) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) - def ops_in_model_before(self): - return [torch.ops.vllm.all_reduce.default] + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) - def ops_in_model_after(self): - return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + y2, resid = self.norm[1](x2, resid) + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) -class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): - super().__init__() - self.hidden_size = hidden_size - self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + y3, resid = self.norm[2](x3, resid) + + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm, _ = self.norm(all_reduce, residual) - return norm + y4, resid = self.norm[3](x4, resid) + return y4 def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -75,24 +74,53 @@ def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] -class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): +class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) - self.scale = 
torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) - - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant( - self.output, norm_output.contiguous(), self.scale + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ + torch.rand(hidden_size, hidden_size) + .to(dtype=current_platform.fp8_dtype()) + .t() + for _ in range(3) + ] + + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + z3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] ) - return self.output, residual_output + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + z4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -100,7 +128,9 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default, + torch.ops._C.static_scaled_fp8_quant.default + if self.fp8_linear.quant_fp8.enabled() + else torch.ops.aten.reciprocal.default, ] @@ -109,25 +139,48 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) - - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(token_num, 128) - scale_n = hidden_size // 16 - rounded_n = round_up(scale_n, 4) - self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) - - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm_output, residual_output = self.norm(all_reduce, residual) - norm_output = norm_output.reshape(-1, norm_output.shape[-1]) - torch.ops._C.scaled_fp4_quant( - self.output, norm_output, self.output_scale, self.scale + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)] + + wq_gen, wscale_gen = zip( + *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale)) + ) + self.wq, self.wscale = list(wq_gen), list(wscale_gen) + print(f"{self.wq=}, {self.wscale=}") + + def 
forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + yq, y_scale = scaled_fp4_quant(y, self.agscale[0]) + z2 = cutlass_scaled_fp4_mm( + yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1]) + z3 = cutlass_scaled_fp4_mm( + yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype ) - return self.output, residual_output, self.output_scale + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2]) + z4 = cutlass_scaled_fp4_mm( + yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -141,19 +194,19 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model", + "test_model, enable_quant_fp8_custom_op", [ - TestAllReduceRMSNormModel, - TestAllReduceFusedAddRMSNormModel, - TestAllReduceFusedAddRMSNormStaticQuantFP8Model, - # TODO: Enable with torch==2.8.0 - # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + (TestAllReduceRMSNormModel, False), + (TestAllReduceRMSNormStaticQuantFP8Model, True), + (TestAllReduceRMSNormStaticQuantFP8Model, False), + (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) -@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): num_processes = 2 if ( @@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace( def run_torch_spawn(fn, nprocs): torch.multiprocessing.spawn( fn, - args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + args=( + num_processes, + test_model, + batch_size, + seq_len, + hidden_size, + dtype, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, + ), nprocs=nprocs, ) @@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): current_platform.seed_everything(0) @@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model( init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) + custom_ops = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_quant_fp8_custom_op: + custom_ops.append("+quant_fp8") + vllm_config = VllmConfig( compilation_config=CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"] + mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops ) ) vllm_config.compilation_config.pass_config = PassConfig( 
enable_fi_allreduce_fusion=True, enable_noop=True ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + vllm_config.parallel_config.rank = local_rank # Setup rank for debug path # this is a fake model name to construct the model config # in the vllm_config, it's not really used. @@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model( vllm_config.model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) + with set_current_vllm_config(vllm_config): + all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend( + noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass + ) - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - noop_pass = NoOpEliminationPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass) - - token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) - hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) + hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) - assert all_reduce_fusion_pass.matched_count == 1 - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) - backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + assert all_reduce_fusion_pass.matched_count == 4, ( + f"{all_reduce_fusion_pass.matched_count=}" + ) + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + backend.check_after_ops(model.ops_in_model_after()) + del all_reduce_fusion_pass diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 4d6f4b471a3..fecb1e2e918 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -6,14 +6,15 @@ import torch._dynamo from tests.compile.backend import LazyInitPass, TestBackend +from tests.utils import flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( @@ -28,21 +29,18 @@ ) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from 
vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -# globals needed for string-import custom Dynamo backend field -backend: TestBackend | None = None -backend_unfused: TestBackend | None = None - class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" @@ -104,6 +102,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: num_blocks = batch_size * max_blocks backend = self.attn.backend + # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend if backend == _Backend.ROCM_ATTN: # k/v as 1st dimension @@ -241,26 +240,40 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) +MODELS_FP8: list[tuple[str, type]] = [] +MODELS_FP4: list[tuple[str, type]] = [] +HEADS: list[tuple[int, int]] = [] +SPLIT_ATTENTION: list[bool] = [] +BACKENDS_FP8: list[_Backend] = [] +BACKENDS_FP4: list[_Backend] = [] + if current_platform.is_cuda(): - MODELS = [ + HEADS = [(64, 8), (40, 8)] + MODELS_FP8 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", TestAttentionFp8StaticQuantPatternModel, - ), + ) + ] + MODELS_FP4 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", TestAttentionNvfp4QuantPatternModel, - ), + ) ] - HEADS = [(64, 8), (40, 8)] + BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] + BACKENDS_FP4 = [_Backend.FLASHINFER] + elif current_platform.is_rocm(): - MODELS = [ + HEADS = [(32, 8), (40, 8)] + MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] - HEADS = [(32, 8), (40, 8)] -else: - MODELS = [] - HEADS = [] + BACKENDS_FP8 = [ + _Backend.ROCM_AITER_UNIFIED_ATTN, + _Backend.ROCM_ATTN, + _Backend.TRITON_ATTN, + ] @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @@ -269,46 +282,36 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("model_name, model_class", MODELS) @pytest.mark.parametrize( - "backend", - [_Backend.FLASHINFER] - if current_platform.is_cuda() - else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], -) -# TODO(boyuan): test inductor graph partition on rocm -@pytest.mark.parametrize( - "use_inductor_graph_partition", - [False] if current_platform.is_rocm() else [False, True], + "backend, model_name, model_class, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) + # quant_fp4 only has the custom impl + + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif( - current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)), - reason="On CUDA only test on SM100(Blackwell)", -) -@pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" -) def test_attention_quant_pattern( num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, + custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, - use_inductor_graph_partition: bool, 
dist_init, - caplog_vllm, ): """Test AttentionStaticQuantPattern fusion pass""" + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") - if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") torch.manual_seed(42) @@ -322,8 +325,7 @@ def test_attention_quant_pattern( scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - custom_ops=["+quant_fp8"], - use_inductor_graph_partition=use_inductor_graph_partition, + custom_ops=custom_ops_list, ), cache_config=CacheConfig(cache_dtype="fp8"), ) @@ -358,8 +360,9 @@ def test_attention_quant_pattern( forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) - # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v) + # Run model directly without fusion + # Still compile so query QuantFP8 has closer numerics + result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( @@ -414,16 +417,25 @@ def test_attention_quant_pattern( ) # Check attn fusion support - quant_key = model_class.quant_key + quant_key: QuantKey = model_class.quant_key attn_fusion_supported = [ layer.impl.fused_output_quant_supported(quant_key) for key, layer in vllm_config.compilation_config.static_forward_context.items() ] - if any(attn_fusion_supported): - # Check quantization ops in the graph before and after fusion - # Note: fully_replaced=False because query quant ops remain in graph. - # Only output quant ops are fused into attention. - test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False) + assert sum(attn_fusion_supported) == len(attn_fusion_supported), ( + "All layers should support attention fusion" + ) + + # Check quantization ops in the graph before and after fusion + quant_op = ( + torch.ops.aten.reciprocal + if "-quant_fp8" in custom_ops_list + else QUANT_OPS[quant_key] + ) + + # Note: for fp8, fully_replaced=False because query quant ops remain in graph. + # Only output quant ops are fused into attention. 
+ test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant) # access the underlying `AttnFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py new file mode 100644 index 00000000000..7399abaec54 --- /dev/null +++ b/tests/compile/test_fusions_e2e.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import itertools +import logging +from collections.abc import Iterable +from typing import Any, NamedTuple + +import pytest +import regex as re + +from tests.v1.attention.utils import _Backend +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig +from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer + +from ..utils import flat_product, multi_gpu_test + + +class ModelBackendTestCase(NamedTuple): + model_name: str + model_kwargs: dict[str, Any] + backend: _Backend + attention_fusions: int + allreduce_fusions: int | None = None + + +MODELS_FP8: list[ModelBackendTestCase] = [] +MODELS_FP4: list[ModelBackendTestCase] = [] +MODELS: list[ModelBackendTestCase] = [] # tp-only + +if current_platform.is_cuda(): + MODELS_FP8 = [ + ModelBackendTestCase( + # Use smaller model for L40s in CI + model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=32, + allreduce_fusions=65, + ), + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), + ] + + MODELS_FP4 = [ + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), + ] + + # TP only + MODELS = [ + ModelBackendTestCase( + model_name="meta-llama/Llama-3.1-8B-Instruct", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=65, + ), + ] + +elif current_platform.is_rocm(): + MODELS_FP8 = [ + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_AITER_UNIFIED_ATTN, + attention_fusions=32, + ), + ] + +# TODO(luka) test both in nightly +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] + + +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) + # quant_fp4 only has the custom impl + + list(flat_product(MODELS_FP4, [""])), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +def test_attn_quant( + model_name: str, + 
model_kwargs: dict[str, Any], + backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + level=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(matches) == 1, log_holder.text + assert int(matches[0]) == attention_fusions + + +# TODO(luka) test both in nightly +CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] + + +def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: + for op_list in itertools.product(*custom_ops_lists): + yield ",".join(op_list) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +@pytest.mark.skipif( + not current_platform.is_cuda() + or not has_flashinfer() + or not current_platform.has_device_capability(90), + reason="allreduce+rmsnorm fusion requires flashinfer", +) +def test_tp2_attn_quant_allreduce_rmsnorm( + model_name: str, + model_kwargs: dict, + backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + 
# Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_fi_allreduce_fusion=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(matches) == 2, log_holder.text + + assert int(matches[0]) == attention_fusions + assert int(matches[1]) == attention_fusions + + matches = re.findall( + r"collective_fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(matches) == 2, log_holder.text + + assert int(matches[0]) == allreduce_fusions + assert int(matches[1]) == allreduce_fusions + + +def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(level=compile_config) + ) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + llm = LLM( + model=model, + compilation_config=compilation_config, + **model_kwargs, + ) + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. 
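The assertions above count fusions purely from the worker logs captured by caplog_mp_spawn. Below is a small sketch of the same extraction on a synthetic log snippet (the line numbers and counts are made up for illustration):

import re

log_text = (
    "DEBUG ... fusion_attn.py:402] Fused quant onto 32 attention nodes\n"
    "DEBUG ... collective_fusion.py:517] Replaced 65 patterns\n"
)

attn_counts = [
    int(m)
    for m in re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_text
    )
]
allreduce_counts = [
    int(m)
    for m in re.findall(r"collective_fusion.py:\d+] Replaced (\d+) patterns", log_text)
]

assert attn_counts == [32]
assert allreduce_counts == [65]
# with tensor_parallel_size=2 each rank logs once, hence the two matches expected above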
+ for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index ac561d2e8f8..1c40c599f74 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -7,7 +7,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig # dummy custom pass that doesn't inherit @@ -42,7 +42,8 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ], ) def test_pass_manager_uuid(callable): - config = VllmConfig() + # Some passes need dtype to be set + config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) pass_manager = PostGradPassManager() pass_manager.configure(config) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 9969a629c00..31b6ddf3c69 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -18,6 +18,8 @@ ModelConfig, PassConfig, VllmConfig, + get_current_vllm_config, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -42,9 +44,7 @@ class TestModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -95,13 +95,11 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.vllm_config = vllm_config + self.vllm_config = get_current_vllm_config() self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size)), requires_grad=False ) @@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( + compilation_config = CompilationConfig( pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, ) ) # NoOp needed for fusion - vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
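A rough sketch of the config-context pattern the test models now rely on: rather than threading vllm_config through constructors, a module reads the ambient config set by set_current_vllm_config. This assumes a default-constructed VllmConfig is acceptable for the illustration.

from vllm.config import VllmConfig, get_current_vllm_config, set_current_vllm_config


class TinyTestModule:
    def __init__(self):
        # pick up whatever config is active at construction time,
        # mirroring TestQuantModel.__init__ above
        self.vllm_config = get_current_vllm_config()


vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
    module = TinyTestModule()
# module.vllm_config now refers to the config that was active inside the with-block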
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8" - vllm_config.model_config = ModelConfig( + model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) - noop_pass = NoOpEliminationPass(vllm_config) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - assert ( - sequence_parallelism_pass.compilation_config.splitting_ops - == vllm_config.compilation_config.splitting_ops + vllm_config = VllmConfig( + model_config=model_config, + device_config=device_config, + compilation_config=compilation_config, ) - assert ( - sequence_parallelism_pass.compilation_config.use_inductor_graph_partition - == vllm_config.compilation_config.use_inductor_graph_partition - ) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass] - if enable_fusion: - fusion_pass = RMSNormQuantFusionPass(vllm_config) - passes_for_backend.append(fusion_pass) + with set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + assert ( + sequence_parallelism_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + sequence_parallelism_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + passes_for_backend: list[VllmInductorPass] = [ + noop_pass, + sequence_parallelism_pass, + ] - passes_for_backend.append(cleanup_pass) + if enable_fusion: + fusion_pass = RMSNormQuantFusionPass(vllm_config) + passes_for_backend.append(fusion_pass) - backend_no_func = TestBackend(*passes_for_backend) - backend_func = TestBackend(*passes_for_backend, func_pass) + passes_for_backend.append(cleanup_pass) - model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) + backend_no_func = TestBackend(*passes_for_backend) + backend_func = TestBackend(*passes_for_backend, func_pass) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + model = test_model_cls(hidden_size, hidden_size * 2) - compiled_model_no_func = torch.compile(model, backend=backend_no_func) - compiled_model_no_func(hidden_states, residual) - compiled_model_func = torch.compile(model, backend=backend_func) - compiled_model_func(hidden_states, residual) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - assert sequence_parallelism_pass.matched_count == 1 + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) - # In pre-nodes, all reduce should be there, - # reduce scatter and all gather should not - backend_no_func.check_before_ops(model.ops_in_model_before()) + assert sequence_parallelism_pass.matched_count == 1 - # In post-nodes, reduce scatter and all gather should be there, - # all reduce should not - backend_no_func.check_after_ops(model.ops_in_model_after()) + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + 
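For context on what TestBackend does in this test, here is a minimal sketch of a custom torch.compile backend: it is simply a callable that receives the captured FX GraphModule plus example inputs, may inspect or rewrite the graph (this is where the passes run), and returns a callable to execute.

import torch


def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    # a pass-like hook: list the ops in the captured graph, then run it unchanged
    ops = sorted({str(n.target) for n in gm.graph.nodes if n.op == "call_function"})
    print("ops in captured graph:", ops)
    return gm.forward


@torch.compile(backend=inspecting_backend)
def f(x):
    return torch.relu(x) + 1


f(torch.randn(4))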
backend_no_func.check_before_ops(model.ops_in_model_before()) - # check if the functionalization pass is applied - for op in model.ops_in_model(): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + backend_no_func.check_after_ops(model.ops_in_model_after()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model()) diff --git a/tests/conftest.py b/tests/conftest.py index 8acc1f28559..8537e8fd032 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# ruff: noqa +import contextlib +import pathlib +from copy import deepcopy from tblib import pickling_support +# ruff: noqa + # Install support for pickling exceptions so that we can nicely propagate # failures from tests running in a subprocess. # This should be run before any custom exception subclasses are defined. @@ -40,7 +43,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs -from vllm import LLM, SamplingParams +from vllm import LLM, SamplingParams, envs from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +@pytest.fixture() +def caplog_mp_fork(): + """ + This fixture enables capturing logs from a forked MP subprocess. + It should be used in conjunction with caplog_vllm. + + By default, subprocess logs do not go through the parent process. + We instead create a queue listener in the parent process which + forwards logs to the logger's other handlers, and add a QueueHandler + to the root logger. Forked subprocesses will inherit the root logger + and pass their messages to the queue, which the listener will forward + to the root logger, which can be captured by caplog. + + Note that this workaround only works for fork; with spawn, the subprocess + reinitializes logging and does not automatically inherit the queue. + We'd have to manually pass the queue to the subprocess at the spawn point. + See caplog_mp_spawn below. + """ + + @contextlib.contextmanager + def ctx(): + import logging.handlers + import multiprocessing as mp + + logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() + logger = logging.getLogger() + handlers = logger.handlers + + # The listener works on a background thread, not inherited by the child. 
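A standalone, standard-library-only sketch of the queue-based forwarding described above: the forked child inherits the root logger's QueueHandler, and the QueueListener thread in the parent replays the queued records to the original handlers.

import logging
import logging.handlers
import multiprocessing as mp


def child():
    logging.getLogger("demo.child").warning("hello from the forked child")


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    root = logging.getLogger()

    log_queue = mp.Queue()
    listener = logging.handlers.QueueListener(log_queue, *root.handlers)
    listener.start()
    root.addHandler(logging.handlers.QueueHandler(log_queue))

    # fork so the child inherits the root logger with the QueueHandler attached
    p = mp.get_context("fork").Process(target=child)
    p.start()
    p.join()
    listener.stop()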
+ queue_listener = logging.handlers.QueueListener(logger_queue, *handlers) + queue_listener.start() + + # Add queue handler after creating the listener to avoid cycle + logger.addHandler(logging.handlers.QueueHandler(logger_queue)) + yield + queue_listener.stop() + + return ctx + + +class LogHolder: + def __init__(self): + self.text = None + + +@pytest.fixture() +def caplog_mp_spawn(tmp_path, monkeypatch): + """ + This fixture enables capturing logs from a forked MP subprocess. + It does not require caplog_vllm (but it only contains logs from the child). + + By default, subprocess logs do not go through the parent process. + We instead add a FileHandler to the config so the spawned child process + writes its logs to a temp file. + In the parent, we read the file and return the contents. + + Note: this method could be extended to fork by either reconfiguring logging + in the parent or using a SocketHandler: + https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501 + """ + + @contextlib.contextmanager + def ctx(level: int | str): + from vllm.logger import DEFAULT_LOGGING_CONFIG + + config_path = tmp_path / "vllm_logging_config.json" + log_path = tmp_path / "vllm.log" + log_holder = LogHolder() + + config = deepcopy(DEFAULT_LOGGING_CONFIG) + if envs.VLLM_LOGGING_CONFIG_PATH: + path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH) + assert path.exists() + config = json.loads(path.read_text()) + + config["loggers"]["vllm"]["handlers"] += ["vllm_file"] + config["handlers"]["vllm_file"] = { + "class": "logging.FileHandler", + "formatter": "vllm", + "level": level, + "filename": log_path.as_posix(), + } + + config_path.write_text(json.dumps(config)) + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix()) + monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1") + yield log_holder + + log_holder.text = log_path.read_text() + + return ctx + + @pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 9d11a7ef641..34ce9158552 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant( .clamp(fp8_traits_min, fp8_traits_max) .to(FP8_DTYPE) ) - return ref_out, ref_scale.view((1,)) + return ref_out, ref_scale.view((1, 1)) def native_w8a8_block_matmul( diff --git a/tests/test_logger.py b/tests/test_logger.py index ec368d4897b..01672358902 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content(): assert call_args[1] == "test-streaming-full-text" assert call_args[2] == " (streaming complete)" assert call_args[5] == "streaming_complete" + + +# Add vllm prefix to make sure logs go through the vllm logger +test_logger = init_logger("vllm.test_logger") + + +def mp_function(**kwargs): + # This function runs in a subprocess + + test_logger.warning("This is a subprocess: %s", kwargs.get("a")) + test_logger.error("This is a subprocess error.") + test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b")) + + +def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): + with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork(): + import multiprocessing + + ctx = multiprocessing.get_context("fork") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", 
"b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in caplog_vllm.text + assert "BBBBB" in caplog_vllm.text + + +def test_caplog_mp_spawn(caplog_mp_spawn): + with caplog_mp_spawn(logging.DEBUG) as log_holder: + import multiprocessing + + ctx = multiprocessing.get_context("spawn") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in log_holder.text + assert "BBBBB" in log_holder.text diff --git a/tests/utils.py b/tests/utils.py index 5bfdf703390..9aed55b7258 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,7 @@ import copy import functools import importlib +import itertools import json import os import random @@ -15,7 +16,7 @@ import tempfile import time import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterable from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path @@ -1261,3 +1262,23 @@ def check_answers( frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") assert frac_ok >= accept_rate + + +def flat_product(*iterables: Iterable[Any]): + """ + Flatten lists of tuples of the cartesian product. + Useful when we want to avoid nested tuples to allow + test params to be unpacked directly from the decorator. + + Example: + flat_product([(1, 2), (3, 4)], ["a", "b"]) -> + [ + (1, 2, "a"), + (1, 2, "b"), + (3, 4, "a"), + (3, 4, "b"), + ] + """ + for element in itertools.product(*iterables): + normalized = (e if isinstance(e, tuple) else (e,) for e in element) + yield tuple(itertools.chain(*normalized)) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 8b9411975e1..f2a323ad0c4 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -40,7 +40,7 @@ unique_filepath, ) -from ..utils import create_new_process_for_each_test +from ..utils import create_new_process_for_each_test, flat_product def test_get_open_port(monkeypatch: pytest.MonkeyPatch): @@ -771,3 +771,25 @@ def test_unique_filepath(): paths.add(path) assert len(paths) == 10 assert len(list(Path(temp_dir).glob("*.txt"))) == 10 + + +def test_flat_product(): + # Check regular itertools.product behavior + result1 = list(flat_product([1, 2, 3], ["a", "b"])) + assert result1 == [ + (1, "a"), + (1, "b"), + (2, "a"), + (2, "b"), + (3, "a"), + (3, "b"), + ] + + # check that the tuples get flattened + result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)])) + assert result2 == [ + (1, 2, "a", 5, 6), + (1, 2, "b", 5, 6), + (3, 4, "a", 5, 6), + (3, 4, "b", 5, 6), + ] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dbbfc01e3bb..d6866763040 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1507,7 +1507,7 @@ def scaled_fp8_quant( output, input, scale, scale_ub ) else: - scale = torch.empty(1, device=input.device, dtype=torch.float32) + scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7c85c89bcd7..c1ed058ded7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -17,10 +17,14 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from 
vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -41,11 +45,8 @@ logger = init_logger(__name__) -ALLREDUCE_OP = torch.ops.vllm.all_reduce.default -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default -STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default +if hasattr(torch.ops._C, "scaled_fp4_quant"): + STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: @@ -669,33 +670,24 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) + input, weight = self.rmsnorm_matcher.inputs() - return [input, rms_result, weight] + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + def pattern(input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_OP, - result=rms_result, - input=allreduce_output, - weight=weight, - epsilon=self.epsilon, - ) - # rms_result, allreduce_output - return rms[1], allreduce_output + rms = self.rmsnorm_matcher(allreduce_output, weight) - def replacement( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + return rms, allreduce_output + + def replacement(input: torch.Tensor, weight: torch.Tensor): residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -733,29 +725,19 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - return [ - residual, - input, - weight, - ] + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - # input, residual - return rms[1], rms[2] + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + return rms, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor @@ -779,6 +761,18 @@ def replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) + # Same pattern, but only return 
the output and not residual + # (helpful for end of graph where residual is not used again) + first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + + pm.register_replacement( + first_return_only(pattern), + first_return_only(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, + ) + class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): """ @@ -799,60 +793,37 @@ def __init__( self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.dtype - ) - quant_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.quant_dtype - ) - weight = torch.empty([4], device=self.device, dtype=self.dtype) - scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + input, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight, scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], all_reduce + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce - def replacement( - input: torch.Tensor, - result_rms: torch.Tensor, - quant_result: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=result_rms, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, @@ -892,64 +863,42 @@ def __init__( self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) + def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + input, residual, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty( - [4, 4], device=self.device, dtype=self.quant_dtype - ) - scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - - return [ - quant_result, - residual, - input, - weight, - scale, - ] + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight, 
scale] def pattern( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) + rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) + quant, _ = self.quant_matcher(rms, scale) - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + return quant, res def replacement( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=None, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, @@ -986,14 +935,11 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) - - rmsnorm_result = torch.empty( - [1, 16, 16], device=self.device, dtype=self.dtype - ) quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) input_global_scale = torch.empty( [1, 1], device=self.device, dtype=torch.float32 @@ -1001,36 +947,21 @@ def get_inputs(): weight = torch.empty([16], device=self.device, dtype=self.dtype) output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) - return [ - input, - rmsnorm_result, - quant_result, - weight, - input_global_scale, - output_scale, - ] + return [input, quant_result, weight, input_global_scale, output_scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - + rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) @@ -1040,13 +971,13 @@ def pattern( def replacement( input: torch.Tensor, - result_rms: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, ): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1090,6 +1021,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -1121,28 +1053,17 @@ def pattern( input_global_scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) - - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - 
input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=fused_add_rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) # quant_out, allreduce_output, output_scale - return ( - quant_out_tuple[1], - fused_add_rmsnorm_out_tuple[2], - quant_out_tuple[2], - ) + return quant_out_tuple[1], residual, quant_out_tuple[2] def replacement( quant_result: torch.Tensor, diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index df54e94a03d..8f0ad2d69fb 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -24,6 +24,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -92,13 +93,19 @@ class RMSNormQuantPattern: def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - - assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" - self.QUANT_OP = QUANT_OPS[key.quant] + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] + self.rmsnorm_matcher = ( + MatcherRMSNorm(epsilon) + if not key.fused_add + else MatcherFusedAddRMSNorm(epsilon) + ) + self.quant_matcher = MatcherQuantFP8(key.quant) + class RMSNormStaticQuantPattern(RMSNormQuantPattern): def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): @@ -112,34 +119,18 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale - ) + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) + return self.quant_matcher(result_rms, scale)[0] - # result - return at2[1] + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. 
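For reference, the static FP8 quantization being matched here reduces to a saturating cast of input/scale; the sketch below shows the reference math (not the fused kernel), using the same (1, 1) fp32 scale shape as the matchers.

import torch


def static_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # clamp(x / scale) to the fp8 range, then cast to float8_e4m3fn
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)


x = torch.randn(5, 16, dtype=torch.bfloat16)
scale = torch.full((1, 1), 0.5, dtype=torch.float32)
q = static_fp8_quant_ref(x, scale)
assert q.dtype == torch.float8_e4m3fn and q.shape == x.shape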
+ input = input.to(dtype=self.model_dtype) - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( self.FUSED_OP, result=result, @@ -153,12 +144,11 @@ def replacement( return at[1] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + # input, weight + *self.rmsnorm_matcher.inputs(), + self.quant_matcher.inputs()[1], # scale ] + pattern(*inputs) pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) @@ -175,33 +165,27 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor, ): - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale - ) + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, _ = self.quant_matcher(result_rms, scale) - # result, residual - return at1[1], at[2] + return result, residual def replacement( - result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor, ): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -216,11 +200,9 @@ def replacement( return at[1], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + # input, weight, residual + *self.rmsnorm_matcher.inputs(), + self.quant_matcher.inputs()[1], # scale ] pm.register_replacement( @@ -248,34 +230,18 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None - ) - + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) # result, scale - return at2[1], at2[2] + return self.quant_matcher(result_rms) - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. 
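The dynamic variants let the quant op produce the scale itself. Roughly, per-tensor dynamic FP8 quantization derives the scale from the input's maximum magnitude; the sketch below ignores the small lower bound real kernels clamp the scale to, and per-token quantization computes one such scale per row.

import torch


def dynamic_per_tensor_fp8_quant_ref(x: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (x.float().abs().max() / fp8_max).view(1, 1)  # (1, 1), as in make_scale
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale


q, scale = dynamic_per_tensor_fp8_quant_ref(torch.randn(5, 16, dtype=torch.bfloat16))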
+ input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -290,18 +256,10 @@ def replacement( # result, scale return at[1], at[2] - inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) @@ -323,34 +281,21 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None - ) + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) - # result, residual, scale - return at1[1], at[2], at1[2] + return result, residual, scale def replacement( - result: torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor ): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -365,18 +310,10 @@ def replacement( # result, residual, scale return at[1], at[3], at[2] - inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) @@ -396,23 +333,25 @@ def __init__(self, config: VllmConfig): pass_name="rmsnorm_quant_fusion_pass" ) + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index ae36cef9265..aaf19e6d423 100644 --- a/vllm/compilation/fusion_attn.py +++ 
b/vllm/compilation/fusion_attn.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable import torch import torch._inductor.pattern_matcher as pm +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -20,7 +22,9 @@ from vllm.utils import round_up from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .fx_utils import is_func from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuantFP8 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -66,9 +70,13 @@ def empty_quant(self, *args, **kwargs): return torch.empty(*args, **kwargs) @staticmethod - def wrap_trace_fn(process_fx, trace_fn): + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm return wrapped @@ -77,7 +85,20 @@ def fx_view_to_reshape(gm: torch.fx.GraphModule): from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(gm) - return gm + + @staticmethod + def remove_noop_permutes(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if not is_func(node, torch.ops.aten.permute.default): + continue + + dims = node.args[1] + if any(dim != i for i, dim in enumerate(dims)): + continue + + # this is now an identity op, remove + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) def register_if_supported(self, pm_pass: PatternMatcherPass): if self.layer.impl.fused_output_quant_supported(self.quant_key): @@ -108,6 +129,7 @@ def __init__( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) + self.quant_matcher = MatcherQuantFP8(quant_key) def _register(self, pm_pass: PatternMatcherPass): def pattern( @@ -115,7 +137,6 @@ def pattern( k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): at1 = auto_functionalized( @@ -131,17 +152,14 @@ def pattern( attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] ) - at2 = auto_functionalized( - self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale - ) - return at2[1] + + return self.quant_matcher(attn_out_view, scale)[0] def replacement( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): # attn output in quant_dtype @@ -164,13 +182,10 @@ def replacement( return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v - self.empty( - 5, self.num_heads, self.head_size, dtype=self.dtype - ), # attn_output - self.empty_quant(5, self.num_heads * self.head_size), # quant_output + self.empty(5, self.num_heads, self.head_size), # q + self.empty(5, self.num_heads, self.head_size), # k + self.empty(5, self.num_heads, self.head_size), # v + self.empty(5, self.num_heads, self.head_size), # attn_output empty_fp32(1, 1), # scale ] @@ -179,7 +194,9 @@ def replacement( replacement, inputs, 
AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) @@ -279,7 +296,9 @@ def replacement( replacement, inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 45fe88a5f4d..f2497950fc2 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -6,7 +6,7 @@ from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._ops import OpOverload +from torch._ops import OpOverload, OpOverloadPacket def is_func(node: fx.Node, target) -> bool: @@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node: # An auto-functionalization-aware utility for finding nodes with a specific op -def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: +# Also handles op overload packets and finds all overloads +def find_op_nodes( + op: OpOverload | OpOverloadPacket, graph: fx.Graph +) -> Iterator[fx.Node]: + if isinstance(op, OpOverloadPacket): + for overload in op.overloads(): + overload_op = getattr(op, overload) + yield from find_op_nodes(overload_op, graph) + return + + assert isinstance(op, OpOverload) if not op._schema.is_mutable: yield from graph.find_nodes(op="call_function", target=op) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py new file mode 100644 index 00000000000..c4eb463de1d --- /dev/null +++ b/vllm/compilation/matcher_utils.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +import torch +from torch._higher_order_ops import auto_functionalized +from torch._ops import OpOverload + +from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + _normalize_quant_group_shape, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kNvfp4Quant, +) +from vllm.platforms import current_platform + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + +QUANT_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 +} + +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + + +class MatcherCustomOp(ABC): + def __init__(self, enabled: bool): + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config else None + + self.enabled = enabled + self.forward = self.forward_custom if enabled else self.forward_native + + @abstractmethod + def forward_custom(self, *args, **kws): + pass + + @abstractmethod + def forward_native(self, *args, **kws): + pass + + def 
__call__(self, *args, **kws): + return self.forward(*args, **kws) + + def empty(self, *args, **kws): + return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) + + def empty_f32(self, *args, **kws): + return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) + + def inputs(self) -> list[torch.Tensor]: + """Utility for inputs to the pattern""" + raise NotImplementedError + + +class MatcherRMSNorm(MatcherCustomOp): + def __init__(self, epsilon: float, enabled: bool | None = None): + if enabled is None: + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self): + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16) + return [input, weight] + + def forward_custom( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + result = torch.empty_like(input) + _, result = auto_functionalized( + RMS_OP, + result=result, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + + return result + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight + ) + + +class MatcherFusedAddRMSNorm(MatcherCustomOp): + def __init__(self, epsilon: float, enabled: bool | None = None): + if enabled is None: + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self): + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16) + residual = self.empty(5, 16) + return [input, weight, residual] + + def forward_custom( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + _, result, residual = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + + return result, residual + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight, residual + ) + + +class MatcherQuantFP8(MatcherCustomOp): + def __init__(self, quant_key: QuantKey, enabled: bool | None = None): + if enabled is None: + enabled = QuantFP8.enabled() + + super().__init__(enabled) + self.quant_key = quant_key + assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" + self.QUANT_OP = QUANT_OPS[quant_key] + + assert quant_key.dtype == current_platform.fp8_dtype(), ( + "Only QuantFP8 supported by" + ) + assert quant_key.scale2 is None + self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) + + def forward_custom( + self, + input: torch.Tensor, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_key.dtype + ) + + if self.quant_key.scale.static: + assert scale is not None + _, result = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale + ) + return result, scale + else: + assert scale is None + scale = self.make_scale(input) + _, result, scale = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None + ) + return result, scale + + def forward_native( + self, + input: torch.Tensor, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return 
self.quant_fp8(input, scale) + + def make_scale(self, input: torch.Tensor): + normalized_group_shape = _normalize_quant_group_shape( + input, self.quant_key.scale.group_shape + ) + scale_shape = ( + input.shape[0] // normalized_group_shape[0], + input.shape[1] // normalized_group_shape[1], + ) + + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + + def inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 16) + if self.quant_key.scale.static: + return [input, self.empty_f32(1, 1)] + + return [input] diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 1e6d0e79228..d26fa40993d 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): import depyf path.mkdir(parents=True, exist_ok=True) + logger.debug("Dumping depyf output to %s", path) global context_manager context_manager = depyf.prepare_debug(path.as_posix()) context_manager.__enter__() diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 343297e9446..a06956c7d6e 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,7 +5,7 @@ from torch import fx as fx from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import set_env_var @@ -88,27 +88,30 @@ def __call__(self, graph: fx.Graph): def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - if self.pass_config.enable_noop: - self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_sequence_parallelism: - self.passes += [SequenceParallelismPass(config)] - if self.pass_config.enable_async_tp: - self.passes += [AsyncTPPass(config)] + # Set the current vllm config to allow tracing CustomOp instances + with set_current_vllm_config(config, check_compile=False): + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] - if self.pass_config.enable_fusion: - self.passes += [RMSNormQuantFusionPass(config)] - self.passes += [ActivationQuantFusionPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] - if self.pass_config.enable_attn_fusion: - self.passes += [AttnFusionPass(config)] + if self.pass_config.enable_fusion: + self.passes += [RMSNormQuantFusionPass(config)] + self.passes += [ActivationQuantFusionPass(config)] - # needs a functional graph - self.post_cleanup = PostCleanupPass(config) - self.fix_functionalization = FixFunctionalizationPass(config) + if self.pass_config.enable_attn_fusion: + self.passes += [AttnFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) + self.fix_functionalization = FixFunctionalizationPass(config) # [HACK: Bug with Inductor graph partition and torch.compile cache] # In PyTorch 2.9, torch.compile has a bug where the graph diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 7ef2dddcb40..8add14ebcc3 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py 
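The scale shapes produced by make_scale follow directly from the normalized group shape, one scale per group of elements. A quick sketch of the arithmetic, assuming the group shape has already been normalized to concrete row/column sizes:

def scale_shape(num_tokens: int, hidden_size: int, group_rows: int, group_cols: int):
    # one scale value per (group_rows x group_cols) block of the input
    return (num_tokens // group_rows, hidden_size // group_cols)


# per-tensor quant: the group covers the whole (5, 16) input -> a single (1, 1) scale
assert scale_shape(5, 16, 5, 16) == (1, 1)
# per-token quant: one group per row -> one scale per token
assert scale_shape(5, 16, 1, 16) == (5, 1)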
@@ -128,7 +128,8 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): f" please add to dump_patterns if there are any errors.\n\n" f"from torch._higher_order_ops.auto_functionalize import " f"auto_functionalized as auto_functionalized\n" - f"from torch._inductor.pattern_matcher import *", + f"from torch._inductor.pattern_matcher import *\n" + f"vllm = torch.ops.vllm", file=f, ) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 85ead0d8105..75d907521e1 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -178,14 +178,11 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) + weight_dtype = dtype or torch.get_default_dtype() self.has_weight = has_weight - if dtype is not None: - self.weight = torch.ones(hidden_size, dtype=dtype) - else: - self.weight = torch.ones(hidden_size) + self.weight = torch.ones(hidden_size, dtype=weight_dtype) if self.has_weight: self.weight = nn.Parameter(self.weight) - weight_dtype = self.weight.data.dtype if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( @@ -195,47 +192,69 @@ def __init__( with_fused_add=True, dtype=weight_dtype ) - def forward_native( - self, + @staticmethod + def forward_static( x: torch.Tensor, + variance_epsilon: float, + hidden_size: int, + orig_dtype: torch.dtype, + weight: torch.Tensor | None = None, residual: torch.Tensor | None = None, + variance_size_override: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: - x = x + residual.to(torch.float32) + # residual promoted f16->f32 automatically, + # otherwise Inductor eliminates the casts to and from f16, + # increasing memory usage (and complicating pattern matching) + x = x + residual residual = x.to(orig_dtype) - hidden_size = x.shape[-1] - if hidden_size != self.hidden_size: + if x.shape[-1] != hidden_size: raise ValueError( - "Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}" + f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}" ) - if self.variance_size_override is None: + if variance_size_override is None: x_var = x else: - if hidden_size < self.variance_size_override: + if hidden_size < variance_size_override: raise ValueError( "Expected hidden_size to be at least " - f"{self.variance_size_override}, but found: {hidden_size}" + f"{variance_size_override}, but found: {hidden_size}" ) - x_var = x[:, :, : self.variance_size_override] + x_var = x[:, :, :variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype) - if self.has_weight: - x = x * self.weight + if weight is not None: + x = x * weight if residual is None: return x else: return x, residual + def forward_native( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + + return self.forward_static( + x, + self.variance_epsilon, + self.hidden_size, + x.dtype, + self.weight.data if self.has_weight else None, + residual, + self.variance_size_override, + ) + def forward_cuda( self, x: torch.Tensor,
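To make the refactor easier to follow, here is the same RMSNorm math as forward_static, written as a self-contained reference: upcast to fp32, optional residual add, normalize by the root-mean-square, downcast, then apply the optional weight.

import torch


def rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float,
    residual: torch.Tensor | None = None,
):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual  # residual is promoted to fp32 by the add
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = (x * torch.rsqrt(variance + eps)).to(orig_dtype)
    if weight is not None:
        x = x * weight
    return x if residual is None else (x, residual)


out = rms_norm_ref(
    torch.randn(5, 16, dtype=torch.bfloat16),
    torch.ones(16, dtype=torch.bfloat16),
    1e-6,
)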