diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index c689befdf2da..b56edfc90612 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -6,7 +6,9 @@ import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fusion import FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, PassConfig, VllmConfig) @@ -14,12 +16,15 @@ from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) from vllm.platforms import current_platform from vllm.utils import update_environment_variables from ..utils import multi_gpu_test from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() prompts = [ "Hello, my name is", "The president of the United States is", @@ -30,13 +35,16 @@ class TestModel(torch.nn.Module): - def __init__(self, hidden_size=16, intermediate_size=32): + def __init__(self, + hidden_size=16, + intermediate_size=32, + vllm_config: VllmConfig = None): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size))) - self.norm = RMSNorm(hidden_size, 1e-05) + self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -79,32 +87,138 @@ def ops_in_model(self): return [torch.ops._C.fused_add_rms_norm.default] +class TestQuantModel(torch.nn.Module): + + def __init__(self, + hidden_size=16, + intermediate_size=32, + vllm_config: VllmConfig = None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.vllm_config = vllm_config + self.gate_proj = torch.nn.Parameter(torch.empty( + (intermediate_size, hidden_size)), + requires_grad=False) + self.norm = RMSNorm(intermediate_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True, + use_per_token_if_dynamic=False) + + self.scale = torch.rand(1, dtype=torch.float32) + # Create a weight that is compatible with torch._scaled_mm, + # which expects a column-major layout. + self.w = torch.rand(hidden_size, + intermediate_size).to(dtype=FP8_DTYPE).t() + self.wscale = torch.rand(1, dtype=torch.float32) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + + # for static input quantization + # self.fp8_linear is initialized with use_per_token_if_dynamic=False + fp8_linear_result = self.fp8_linear.apply(norm_output, + self.w, + self.wscale, + input_scale=self.scale.to( + norm_output.device)) + + return fp8_linear_result, residual_output + + def ops_in_model_before(self): + ops_to_remove = [torch.ops.vllm.all_reduce.default + ] # Always removed by SP + # The following are only removed if fusion happens + if self.vllm_config and self.vllm_config.compilation_config \ + .pass_config.enable_fusion: + ops_to_remove.extend([ + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.static_scaled_fp8_quant.default, + ]) + return ops_to_remove + + def ops_in_model_after(self): + ops_to_add = [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default + ] + # The following is only added if fusion happens + if self.vllm_config and self.vllm_config.compilation_config \ + .pass_config.enable_fusion: + ops_to_add.append( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) + return ops_to_add + + def ops_in_model(self): + if self.vllm_config and self.vllm_config.compilation_config \ + .pass_config.enable_fusion: + # If fusion happens, the fused op is the one + # we check for (de)functionalization + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + ] # noqa: E501 + else: + # If no fusion, the original ops are checked + return [ + torch.ops._C.fused_add_rms_norm.default, + # TODO functionalization pass does not handle this yet + # torch.ops._C.static_scaled_fp8_quant.default, + ] + + @multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("enable_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -def test_sequence_parallelism_pass(batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module], + batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype, + enable_fusion: bool): num_processes = 2 def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda torch.multiprocessing.spawn(fn, - args=(num_processes, batch_size, seq_len, - hidden_size, dtype), + args=(num_processes, test_model_cls, + batch_size, seq_len, hidden_size, + dtype, enable_fusion), nprocs=nprocs) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) -def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, - batch_size: int, seq_len: int, - hidden_size: int, - dtype: torch.dtype): +def sequence_parallelism_pass_on_test_model( + local_rank: int, world_size: int, + test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype, enable_fusion: bool): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_sequence_parallelism=True)) + enable_sequence_parallelism=True, + enable_fusion=enable_fusion, + enable_noop=True)) # NoOp needed for fusion vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. - model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model, + model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, task="auto", - tokenizer=model, + tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=True, dtype=dtype, seed=42) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - backend_no_func = TestBackend(sequence_parallelism_pass) + noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) - backend_func = TestBackend(sequence_parallelism_pass, func_pass) - model = TestModel(hidden_size, hidden_size * 2) + passes_for_backend = [noop_pass, sequence_parallelism_pass] + + if enable_fusion: + fusion_pass = FusionPass.instance(vllm_config) + passes_for_backend.append(fusion_pass) + + backend_no_func = TestBackend(*passes_for_backend) + backend_func = TestBackend(*passes_for_backend, func_pass) + + model = test_model_cls(hidden_size, + hidden_size * 2, + vllm_config=vllm_config) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 91a594eac5c4..b2f6a8ab9dd3 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -28,7 +28,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int - sp_enabled: bool + enable_fusion: bool eager_mode: bool chunked_prefill: bool @@ -67,49 +67,18 @@ def detailed( task: TaskOption = "auto", load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True) - ], + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -126,19 +95,44 @@ def fast( multi_node_only: bool = False, load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ + parallel_setups=parallel_setups, + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fp8_quant( + *, + tp_base: int = 2, + pp_base: int = 1, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + parallel_setups = [] + for fusion_val in [False, True]: + parallel_setups.append( ParallelSetup(tp_size=tp_base, pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ], + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False)) + return SPTestSettings( + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -171,7 +165,7 @@ def _compare_sp( ( tp_size, pp_size, - sp_enabled, + enable_fusion, eager_mode, chunked_prefill, ) = parallel_setup @@ -240,9 +234,9 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallelism': sp_enabled, + 'enable_sequence_parallelism': True, + 'enable_fusion': enable_fusion, 'enable_noop': True, - 'enable_fusion': True, }, } @@ -291,12 +285,14 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(), } SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" ] diff --git a/tests/models/registry.py b/tests/models/registry.py index 49510af880cf..4a587e39ad4c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -193,7 +193,8 @@ def check_available_online( extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 - "hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501 + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 + "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501 "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 9d908fcae3df..951a2861e3a4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -345,8 +345,8 @@ def process(self): # 0 is always None fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} self.insert_fused_node(fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - **kwargs) + **kwargs, + epsilon=rms_node.kwargs["epsilon"]) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 28a59905ecf8..3ce00e3610c5 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -51,15 +51,15 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_noop: self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_fusion: - self.passes += [FusionPass.instance(config)] - self.passes += [ActivationQuantFusionPass(config)] - if self.pass_config.enable_sequence_parallelism: self.passes += [SequenceParallelismPass(config)] if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_fusion: + self.passes += [FusionPass.instance(config)] + self.passes += [ActivationQuantFusionPass(config)] + if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index d41093903480..6107046e40dc 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -12,91 +12,142 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.platforms import current_platform from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) -class AllReduceRMSNormPattern: +class _RMSNormAndQuantOpHelper: + """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs): self.epsilon = epsilon self.dtype = dtype self.device = device - - -class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + self.quant_op = quant_op + + def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=result_buffer, + input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, + weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, + quant_result_buffer, input_tensor, + weight_tensor, scale_tensor): + if self.quant_op is None: + raise RuntimeError( + "_RMSNormAndQuantOpHelper was not initialized with a quant_op." + ) + rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer, + input_tensor, + weight_tensor) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple + + def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, + input_tensor, residual_tensor, + weight_tensor, scale_tensor): + if self.quant_op is None: + raise RuntimeError( + "_RMSNormAndQuantOpHelper was not initialized with a quant_op." + ) + fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( + input_tensor, residual_tensor, weight_tensor) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] + + +class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): + """Helper for sequence parallelism patterns.""" + + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs): + super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) + self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return tensor_model_parallel_all_reduce(x) + + def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.reduce_scatter.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + def _all_gather(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + +class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def get_inputs(self): - arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) - mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], - device=self.device, - dtype=torch.long) - unsqueeze = torch.rand([1, 8, 1], device=self.device, \ - dtype=self.dtype) > 0.5 - full_default = torch.zeros([1, 8, 4], device=self.device, \ - dtype=self.dtype) + input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) - return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1] + return [input, permute, arg3_1] def register(self, pm_pass: PatternMatcherPass): def pattern( - arg2_1: torch.Tensor, - mul_6: torch.Tensor, - unsqueeze: torch.Tensor, - full_default: torch.Tensor, + input: torch.Tensor, permute: torch.Tensor, arg3_1: torch.Tensor, ): - embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) - where = torch.ops.aten.where.self(unsqueeze, full_default, - embedding) - all_reduce = tensor_model_parallel_all_reduce(where) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=permute, - input=all_reduce, - weight=arg3_1, - epsilon=self.epsilon, - ) + all_reduce = self._all_reduce(input) + rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) return rmsnorm[1], all_reduce def replacement( - arg2_1: torch.Tensor, - mul_6: torch.Tensor, - unsqueeze: torch.Tensor, - full_default: torch.Tensor, + input: torch.Tensor, permute: torch.Tensor, arg3_1: torch.Tensor, ): - embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) - where = torch.ops.aten.where.self(unsqueeze, full_default, - embedding) - - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - where, dim=0, world_size=tp_size, group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(input) rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=rmsnorm_result, - input=reduce_scatter, - weight=arg3_1, - epsilon=self.epsilon, - ) + rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, + arg3_1) - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter @@ -104,7 +155,7 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -127,16 +178,9 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) return rmsnorm[1], rmsnorm[2] def replacement( @@ -144,32 +188,17 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -192,16 +221,9 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) return rmsnorm[1] def replacement( @@ -209,26 +231,185 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + normalized = self._all_gather(rmsnorm[1]) + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +FP8_DTYPE = current_platform.fp8_dtype() + + +class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [input, rmsnorm_result, quant_result, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + all_reduce = self._all_reduce(input) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, all_reduce, weight, scale) + return static_fp8[1], all_reduce + + def replacement( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + reduce_scatter = self._reduce_scatter(input) + + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=rmsnorm_result.dtype) + quant_result = torch.empty_like( + rmsnorm_result, # Output of RMSNorm + dtype=quant_result.dtype) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(static_fp8[1]) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1], rmsnorm_residual_out + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + all_gather = self._all_gather(static_fp8[1]) + return all_gather, rmsnorm_residual_out + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): - normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1] + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + normalized = self._all_gather(static_fp8[1]) return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -236,21 +417,54 @@ def replacement( class SequenceParallelismPass(VllmInductorPass): + """ + This pass enables sequence parallelism for models. + It identifies patterns where an AllReduce operation is followed by + an RMSNorm (or RMSNorm and then Quantization) operation. + These patterns are replaced with a ReduceScatter operation, followed by + a local RMSNorm/Quantization, and then an AllGather operation. + + The general transformation is: + Input -> AllReduce -> RMSNorm -> Output + becomes + Input -> ReduceScatter -> RMSNorm -> AllGather -> Output + + While this pass itself does not directly yield performance improvements, + it lays the groundwork for subsequent fusion passes, such as + GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can + significantly reduce communication overhead and improve overall model + performance. + """ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: - EmbeddingAllReduceRMSNormPattern( - epsilon, self.model_dtype, self.device).register(self.patterns) + # RMSNorm + Static FP8 quantization patterns + fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default + FirstAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + MiddleAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + LastAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + + # Normal RMSNorm patterns + FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear() diff --git a/vllm/config.py b/vllm/config.py index ce7e2a2929cf..9698103d38ac 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3800,11 +3800,11 @@ class PassConfig: its own stages (before, after, maybe in-between).""" dump_graph_dir: Path = Path(".") """Directory to dump the graphs.""" - enable_fusion: bool = True + enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" enable_attn_fusion: bool = False """Whether to enable the custom attention+quant fusion pass.""" - enable_noop: bool = True + enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) """Whether to enable the custom no-op elimination pass.""" enable_sequence_parallelism: bool = False """Whether to enable sequence parallelism.""" @@ -4449,8 +4449,6 @@ def __post_init__(self): # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph # is set to True, full CUDA graphs will be used. self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.pass_config.enable_fusion = False - self.compilation_config.pass_config.enable_noop = False self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1()