From dd9156d55bec7d75d679f8fc8161388269c8df56 Mon Sep 17 00:00:00 2001
From: angelayi
Date: Fri, 29 Aug 2025 10:10:04 -0700
Subject: [PATCH] [compile] Fix de-functionalization pass for rotary_embedding

Signed-off-by: angelayi
---
 .buildkite/test-pipeline.yaml             |   1 +
 tests/compile/test_functionalization.py   | 298 +++++++++++++++++-----
 vllm/compilation/fix_functionalization.py |  54 ++--
 3 files changed, 266 insertions(+), 87 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c131192c56fc..9c200a577167 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -397,6 +397,7 @@ steps:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
   - pytest -v -s compile/test_fusion_attn.py
+  - pytest -v -s compile/test_functionalization.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
   - pytest -v -s compile/test_async_tp.py
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py
index 2ee9aa7476be..0c8d610bc9c5 100644
--- a/tests/compile/test_functionalization.py
+++ b/tests/compile/test_functionalization.py
@@ -5,54 +5,237 @@
 import torch
 
 import vllm.envs as envs
-from vllm import LLM, SamplingParams
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
 
 from .backend import TestBackend
 
-OPS_IN_MODEL = [
-    torch.ops._C.rotary_embedding.default,
-    torch.ops._C.fused_add_rms_norm.default,
-]
+TEST_FP8 = current_platform.supports_fp8()
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestSiluMul(torch.nn.Module):
+
+    def __init__(self, hidden_size: int = 128):
+        super().__init__()
+        self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+        if TEST_FP8:
+            self.w = torch.rand(hidden_size,
+                                hidden_size).to(dtype=FP8_DTYPE).t()
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        if TEST_FP8:
+            x2 = self.fp8_linear.apply(y,
+                                       self.w,
+                                       self.wscale,
+                                       input_scale=self.wscale)
+            return x2
+        else:
+            return y
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
+
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.silu_and_mul_quant.default]
+        else:
+            return [torch.ops._C.silu_and_mul.default]
+
+    def ops_not_in_model(self):
+        
return [] + + +class TestFusedAddRMSNorm(torch.nn.Module): + + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + dtype = torch.float16 if TEST_FP8 else torch.float32 + + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size), dtype=dtype)) + self.norm = RMSNorm(intermediate_size, 1e-05) + self.norm.weight = torch.nn.Parameter( + torch.ones(intermediate_size, dtype=dtype)) + + torch.nn.init.normal_(self.gate_proj, std=0.02) + + if TEST_FP8: + self.fp8_linear = Fp8LinearOp(act_quant_static=True) + + self.scale = torch.rand(1, dtype=torch.float32) + self.w = torch.rand(hidden_size, + intermediate_size).to(dtype=FP8_DTYPE).t() + self.wscale = torch.rand(1, dtype=torch.float32) + + def forward(self, hidden_states, residual): + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + # matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # layer normalization + norm_output, residual_output = self.norm(mm, residual) + + if TEST_FP8: + # scaled_mm with static input quantization + fp8_linear_result = self.fp8_linear.apply( + norm_output, + self.w, + self.wscale, + input_scale=self.scale.to(norm_output.device), + ) + + return fp8_linear_result, residual_output + + else: + return norm_output, residual_output + + def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): + dtype = torch.float16 if TEST_FP8 else torch.float32 + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype) + return (hidden_states, residual) -RMS_OP = torch.ops._C.rms_norm.default + def ops_in_model(self, do_fusion): + if TEST_FP8 and do_fusion: + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] + else: + return [torch.ops._C.fused_add_rms_norm.default] -RMS_QUANT_OPS = { - "static_fp8": [ - torch.ops._C.rms_norm_static_fp8_quant.default, - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - ], -} + def ops_not_in_model(self): + return [] -SILU_MUL_OP = torch.ops._C.silu_and_mul.default -SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", +class TestRotaryEmbedding(torch.nn.Module): + + def __init__(self, + head_dim=64, + rotary_dim=None, + max_position=2048, + base=10000): + super().__init__() + self.head_dim = head_dim + self.rotary_dim = rotary_dim or head_dim + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position, + base=base, + ) + + def forward(self, positions, q, k): + q_rotated, k_rotated = self.rotary_emb(positions, q, k) + return q_rotated, k_rotated + + def example_inputs(self, num_tokens=32, head_dim=64): + dtype = torch.float16 + positions = torch.arange(num_tokens, dtype=torch.long) + q = torch.randn(num_tokens, head_dim, dtype=dtype) + k = torch.randn(num_tokens, head_dim, dtype=dtype) + return (positions, q, k) + + def ops_in_model(self, do_fusion): + return [torch.ops._C.rotary_embedding.default] + + def ops_not_in_model(self): + return [] + + +class TestRotaryEmbeddingSliceScatter(torch.nn.Module): + + def __init__(self, + head_dim=64, + num_heads=4, + max_position=2048, + base=10000): + super().__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.hidden_size = head_dim 
* num_heads + + self.qkv_proj = torch.nn.Linear(self.hidden_size, + self.hidden_size * 3, + bias=False, + dtype=torch.float16) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=base, + ) + + def forward(self, positions, hidden_states): + # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding + # -> slice_scatter -> split_with_sizes + + qkv = self.qkv_proj(hidden_states) + split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size] + q, k, v = torch.split(qkv, split_sizes, dim=-1) + + q_rotated, k_rotated = self.rotary_emb(positions, q, k) + + qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1) + return qkv_updated + + def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): + dtype = torch.float16 + hidden_size = head_dim * num_heads + positions = torch.arange(num_tokens, dtype=torch.long) + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (positions, hidden_states) + + def ops_in_model(self, do_fusion): + return [torch.ops._C.rotary_embedding.default] + + def ops_not_in_model(self): + return [torch.ops.aten.slice_scatter.default] + + +MODELS = [ + TestSiluMul, + TestFusedAddRMSNorm, + TestRotaryEmbedding, + TestRotaryEmbeddingSliceScatter, ] -@pytest.mark.parametrize( - "model, quant_key", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e", - kFp8DynamicTokenSym)]) +@pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model: str, quant_key: QuantKey, - do_fusion: bool): +def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): torch.set_default_device("cuda") vllm_config = VllmConfig() @@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, cleanup_pass = PostCleanupPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass - ] if do_fusion else [noop_pass, cleanup_pass] + passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion else [noop_pass, cleanup_pass]) func_pass = FixFunctionalizationPass(vllm_config) + backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) - # instantiate a full engine and manually compile the model 2x - # (with and without FixFunctionalizationPass) - llm = LLM(model=model, enforce_eager=True) - model_runner = llm.llm_engine.model_executor.driver_worker.model_runner - orig_model = model_runner.model - # TODO mark inputs dynamic? (currently torch.compile is triggered 4x) - # Can only do that by using the decorator but then we'd have to instantiate - # 2 LLM instances. - - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - model_runner.model = torch.compile(orig_model, - fullgraph=True, - backend=backend_func) - gen_func = llm.generate(prompts, sampling_params) - - model_runner.model = torch.compile(orig_model, - fullgraph=True, - backend=backend_no_func) - - gen_no_func = llm.generate(prompts, sampling_params) - - for output_func, output_no_func in zip(gen_func, gen_no_func): - assert output_func.outputs[0].text == output_no_func.outputs[0].text - - # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, - # and replaced by fused quantized ops in RMS_QUANT_OPS. 
- rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] - ] if do_fusion else [RMS_OP] - silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \ - quant_key == kFp8StaticTensorSym else [ - SILU_MUL_OP - ] - - ops = OPS_IN_MODEL + rms_ops + silu_mul_ops - - for op in ops: + model = model_class() + torch.compile(model, backend=backend_func)(*model.example_inputs()) + torch.compile(model, backend=backend_no_func)(*model.example_inputs()) + + # check if the functionalization pass is applied + for op in model.ops_in_model(do_fusion): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, - op) is None # noqa: E501 + assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) + is None) # noqa: E501 # make sure the ops were all de-functionalized found = dict() for node in backend_func.graph_post_pass.nodes: - for op in ops: + for op in model.ops_in_model(do_fusion): + if is_func(node, op): + found[op] = True + for op in model.ops_not_in_model(): if is_func(node, op): found[op] = True - assert all(found[op] for op in ops) + assert all(found[op] for op in model.ops_in_model(do_fusion)) + assert all(not found.get(op) for op in model.ops_not_in_model()) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 54403c1f7ca3..ce6db9c1ebca 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -46,23 +46,43 @@ def __call__(self, graph: torch.fx.Graph): if at_target == torch.ops._C.rotary_embedding.default: query = kwargs['query'] - mm_node = query.args[0].args[0] - - # rotary_embedding is a special case: the two mutating inputs - # are query and key, which are slices of mm_node. - # While functionalized, results at[1] and at[2] are scattered - # back into mm_node. After de-functionalization, we can just - # use mm_node directly. - for idx, user in self.getitem_users(node).items(): - for user_of_getitem in user.users: - if is_func(user_of_getitem, - torch.ops.aten.slice_scatter.default): - user_of_getitem.replace_all_uses_with(mm_node) - self._remove(user_of_getitem) - self._remove(user) - - self.insert_defunctionalized(graph, node) - self._remove(node) + key = kwargs['key'] + getitem_nodes = self.getitem_users(node) + + if (is_func(query, operator.getitem) + and is_func(key, operator.getitem) + and query.args[0] == key.args[0] + and is_func(query.args[0], + torch.ops.aten.split_with_sizes.default) + and all( + is_func(user, torch.ops.aten.slice_scatter.default) + for getitem_node in getitem_nodes.values() + for user in getitem_node.users)): + # Pattern where query and key are slices of an mm_node. + # While functionalized, results at [1] and [2] are scattered + # back into mm_node. So after de-functionalization, we can + # just use mm_node directly. + + mm_node = query.args[0].args[0] + for user in getitem_nodes.values(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + else: + # Directly replace the auto_functionalize(rotary_embedding) + # with the inplace rotary_embedding. In theory, we shouldn't + # do this blindly, but in practice in vLLM it's ok. The best + # solution is to use auto_functionalization_v2 and then use + # inductor's builtin defunctionalization (reinplacing) pass. 
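+                # Here, defunctionalize() rewires the getitem users of
+                # outputs 1 and 2 back to the mutated query/key inputs,
+                # re-emits rotary_embedding as an in-place call, and then
+                # drops the auto_functionalized node.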
+                mutated_args = {1: 'query', 2: 'key'}
+                self.defunctionalize(graph, node, mutated_args)
 
             # rms_norm replacements avoid the most copies for LLaMa.
             elif at_target == torch.ops._C.fused_add_rms_norm.default: