diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index e323fa1f7734..997726747d73 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -21,6 +21,9 @@
 if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass
 
+if current_platform.is_rocm():
+    from .rocm_fusion import ROCmFusionPass
+
 from .fix_functionalization import FixFunctionalizationPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
 from .noop_elimination import NoOpEliminationPass
@@ -100,6 +103,8 @@ def configure(self, config: VllmConfig):
         if self.pass_config.enable_fusion:
             self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]
+            if current_platform.is_rocm():
+                self.passes += [ROCmFusionPass(config)]
 
         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
diff --git a/vllm/compilation/rocm_fusion.py b/vllm/compilation/rocm_fusion.py
new file mode 100644
index 000000000000..197e33146467
--- /dev/null
+++ b/vllm/compilation/rocm_fusion.py
@@ -0,0 +1,175 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import torch
+import torch._inductor.pattern_matcher as pm
+from torch import fx
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._inductor.pattern_matcher import PatternMatcherPass
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+from .inductor_pass import enable_fake_mode
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
+
+logger = init_logger(__name__)
+
+
+def empty_bf16(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
+
+
+def empty_fp32(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
+
+
+def empty_i32(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
+
+
+def empty_fp4(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")
+
+
+class SiluMulMXFP4GemmPattern:
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
+                    input: torch.Tensor, weight: torch.Tensor,
+                    scale: torch.Tensor):
+            at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
+                                      result=result_silu_mul,
+                                      input=input)
+            at2 = auto_functionalized(
+                torch.ops.vllm.gemm_with_dynamic_quant.default,
+                result=result,
+                x=at1[1],
+                weight=weight,
+                weight_scale=scale,
+                x_scales=None)
+            return at2[1]
+
+        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
+                        input: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            at = auto_functionalized(
+                torch.ops.vllm.silu_and_mul_mxfp4_gemm.default,
+                result=result,
+                x=input,
+                weight=weight,
+                weight_scale=scale)
+            return at[1]
+
+        inputs = [
+            empty_bf16(5, 4),  # result
+            empty_bf16(5, 4),  # result_silu_mul
+            empty_bf16(5, 4),  # input
+            empty_fp4(5, 4),  # weight
+            empty_fp4(1, 1),  # scale
+        ]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+        )
+
+
+ADD_RMS_OP = torch.ops._C.fused_add_rms_norm.default
+
+
+class AddRMSNormMXFP4GemmPattern:
+
+    def __init__(self, epsilon: float):
+        self.epsilon = epsilon
+        self.FUSED_OP = torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default
+        self.QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(result: torch.Tensor, input: torch.Tensor,
+                    residual: torch.Tensor, weight_rms: torch.Tensor,
+                    weight_gemm: torch.Tensor, scale: torch.Tensor):
+            at1 = auto_functionalized(ADD_RMS_OP,
+                                      input=input,
+                                      residual=residual,
+                                      weight=weight_rms,
+                                      epsilon=self.epsilon)
+            at2 = auto_functionalized(self.QUANT_F4GEMM_OP,
+                                      result=result,
+                                      x=at1[1],
+                                      weight=weight_gemm,
+                                      weight_scale=scale,
+                                      x_scales=None)
+            return at2[1], at1[2]
+
+        def replacement(result: torch.Tensor, input: torch.Tensor,
+                        residual: torch.Tensor, weight_rms: torch.Tensor,
+                        weight_gemm: torch.Tensor, scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     residual=residual,
+                                     residual_out=residual,
+                                     weight_rms=weight_rms,
+                                     weight_gemm=weight_gemm,
+                                     scale=scale,
+                                     epsilon=self.epsilon)
+            return at[1], at[2]
+
+        inputs = [
+            empty_bf16(4, 4),  # result
+            empty_bf16(4, 4),  # input
+            empty_bf16(4, 4),  # residual
+            empty_bf16(1, 4),  # weight_rms
+            empty_fp4(4, 4),  # weight_gemm
+            empty_fp4(1, 1),  # scale
+        ]
+
+        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
+                                pm_pass)
+
+
+class ROCmFusionPass(VllmPatternMatcherPass):
+    """
+    This pass fuses a pre-defined set of custom ops into fused ops.
+    It uses the torch pattern matcher to find the patterns in the graph
+    and replace them with the fused ops.
+    """
+
+    @enable_fake_mode
+    def __init__(self, config: VllmConfig):
+        super().__init__(config)
+
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="rocm_fusion_pass")
+
+        SiluMulMXFP4GemmPattern().register(self.patterns)
+
+        for epsilon in [1e-5, 1e-6]:
+            AddRMSNormMXFP4GemmPattern(epsilon).register(self.patterns)
+
+        self.dump_patterns(config, self.patterns)
+
+    @VllmInductorPass.time_and_log
+    def __call__(self, graph: fx.Graph):
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
+
+    def uuid(self) -> Any:
+        return self.hash_source(self, SiluMulMXFP4GemmPattern,
+                                AddRMSNormMXFP4GemmPattern)
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index f8628a82277b..a6731d21982b 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -24,24 +24,28 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
 
 try:
+    import triton
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
+    from aiter.ops.triton.fused_mxfp4_quant import _fused_rms_mxfp4_quant_kernel
 
     from vllm.utils import direct_register_custom_op
 
     if is_rocm_aiter_fp4_asm_gemm_enabled():
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip
 
     def gemm_with_dynamic_quant(
+        result: torch.Tensor,
         x: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        x_scales: Optional[torch.Tensor] = None,
         rocm_use_aiter_fp4_asm_gemm: bool = False,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        x_scales: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        M = x.shape[0]
+    ) -> None:
         if rocm_use_aiter_fp4_asm_gemm:
+            M = x.shape[0]
             if x_scales is None:
                 # use hip quant kernel for performance
                 x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
@@ -62,22 +66,17 @@ def gemm_with_dynamic_quant(
                       weight_scale.view(x_s.dtype),
                       y,
                       bpreshuffle=True)
-            return y[:M]
+            result.copy_(y[:M])
         else:
             if x_scales is None:
                 x_q, x_s = dynamic_mxfp4_quant(x)
             else:
                 x_q = x
                 x_s = x_scales
-            y = torch.empty(x_q.shape[0],
-                            weight.shape[0],
-                            device=x_q.device,
-                            dtype=out_dtype)
-
-            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
-            return y
+            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, result)
 
     def gemm_with_dynamic_quant_fake(
+        result: torch.Tensor,
         x: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
@@ -85,18 +84,110 @@
         rocm_use_aiter_fp4_asm_gemm: bool = False,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-    ) -> torch.Tensor:
-        return torch.empty((*x.shape[:-1], weight.shape[0]),
-                           dtype=out_dtype,
-                           device=x.device)
+    ) -> None:
+        return
 
     direct_register_custom_op(
         op_name="gemm_with_dynamic_quant",
         op_func=gemm_with_dynamic_quant,
-        mutates_args=[],
+        mutates_args=['result'],
        fake_impl=gemm_with_dynamic_quant_fake,
        dispatch_key=current_platform.dispatch_key,
    )
 
+    def silu_and_mul_mxfp4_gemm(
+        result: torch.Tensor,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        rocm_use_aiter_fp4_asm_gemm: bool = False,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    ) -> None:
+        # Fused SiLU-and-mul activation + MXFP4 quantization, followed by the
+        # dynamically quantized MXFP4 GEMM writing into `result`.
+        x_fp4, blockscale_e8m0 = act_mul_and_mxfp4_quant(x, 'silu')
+        gemm_with_dynamic_quant(result, x_fp4, weight, weight_scale,
+                                blockscale_e8m0, rocm_use_aiter_fp4_asm_gemm,
+                                out_dtype)
+
+    def silu_and_mul_mxfp4_gemm_fake(
+        result: torch.Tensor,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        rocm_use_aiter_fp4_asm_gemm: bool = False,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="silu_and_mul_mxfp4_gemm",
+        op_func=silu_and_mul_mxfp4_gemm,
+        mutates_args=['result'],
+        fake_impl=silu_and_mul_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    def add_rmsnorm_mxfp4_gemm(
+        result: torch.Tensor,
+        input: torch.Tensor,
+        residual_out: torch.Tensor,
+        residual: torch.Tensor,
+        weight_rms: torch.Tensor,
+        weight_gemm: torch.Tensor,
+        scale: torch.Tensor,
+        epsilon: float,
+        rocm_use_aiter_fp4_asm_gemm: bool = False,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    ) -> None:
+        MXFP4_QUANT_BLOCK_SIZE = 32
+        M, N1 = input.shape
+        BLOCK_SIZE = max(triton.next_power_of_2(N1), MXFP4_QUANT_BLOCK_SIZE)
+        res_row_stride = residual.stride(0)
+        out_res_row_stride = residual_out.stride(0)
+        rms_out_fp4 = torch.empty((M, N1 // 2),
+                                  dtype=torch.uint8,
+                                  device=input.device)
+        rms_out_bs = torch.empty(
+            ((N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE, M),
+            dtype=torch.uint8,
+            device=input.device,
+        ).T
+        # Fused residual-add + RMSNorm + MXFP4 quant: writes the quantized
+        # activation into rms_out_fp4/rms_out_bs and the updated residual
+        # into residual_out.
+        _fused_rms_mxfp4_quant_kernel[(M, )](
+            input,
+            weight_rms,
+            None,
+            None,
+            residual,
+            rms_out_fp4,
+            rms_out_bs,
+            None,
+            residual_out,
+            epsilon,
+            0.0,
+            M,
+            N1,
+            0,
+            input.stride(0),
+            0,
+            res_row_stride,
+            rms_out_fp4.stride(0),
+            *rms_out_bs.stride(),
+            0,
+            out_res_row_stride,
+            BLOCK_SIZE=BLOCK_SIZE,
+            MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE,
+            SKIP_SECOND_INPUT=True,
+            FIRST_INPUT_RES=True,
+        )
+        gemm_with_dynamic_quant(result, rms_out_fp4, weight_gemm, scale,
+                                rms_out_bs, rocm_use_aiter_fp4_asm_gemm,
+                                out_dtype)
+
+    def add_rmsnorm_mxfp4_gemm_fake(
+        result: torch.Tensor,
+        input: torch.Tensor,
+        residual_out: torch.Tensor,
+        residual: torch.Tensor,
+        weight_rms: torch.Tensor,
+        weight_gemm: torch.Tensor,
+        scale: torch.Tensor,
+        epsilon: float,
+        rocm_use_aiter_fp4_asm_gemm: bool = False,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="add_rmsnorm_mxfp4_gemm",
+        op_func=add_rmsnorm_mxfp4_gemm,
+        mutates_args=['result', 'residual_out'],
+        fake_impl=add_rmsnorm_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 except ImportError:
     dynamic_mxfp4_quant = gemm_afp4wfp4 = None
 
@@ -234,6 +325,7 @@ def apply_weights(self,
             x = quant_dequant_mxfp4(x)
             return F.linear(x, dq_w, bias)
         else:
-            return torch.ops.vllm.gemm_with_dynamic_quant(
-                x, layer.weight, layer.weight_scale,
-                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
+            result = torch.empty((*x.shape[:-1], layer.weight.shape[0]),
+                                 dtype=self.out_dtype,
+                                 device=x.device)
+            torch.ops.vllm.gemm_with_dynamic_quant(
+                result, x, layer.weight, layer.weight_scale, None,
+                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
+            return result
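
For reference, the rewrite encoded by SiluMulMXFP4GemmPattern can be exercised directly through the custom ops this patch registers. The sketch below is illustrative only: it assumes a ROCm build of vLLM with aiter installed (so that importing quark_w4a4_mxfp4 registers torch.ops.vllm.gemm_with_dynamic_quant and torch.ops.vllm.silu_and_mul_mxfp4_gemm, and the _C extension provides silu_and_mul); the shapes, random packed-MXFP4 weight, and unit e8m0 scales are placeholders, not a validated configuration.

    import torch

    import vllm._custom_ops  # noqa: F401  # loads the _C extension (silu_and_mul)
    # Importing the scheme module runs the try-block above and registers the
    # vllm.* custom ops (requires aiter on ROCm).
    from vllm.model_executor.layers.quantization.quark.schemes import (  # noqa: F401
        quark_w4a4_mxfp4)

    M, N, K = 32, 256, 512
    x = torch.randn(M, 2 * K, dtype=torch.bfloat16, device="cuda")
    # Packed MXFP4 weight (two 4-bit values per byte) with e8m0 block scales;
    # 127 encodes a scale of 1.0. Random contents, layout illustrative only.
    weight = torch.randint(0, 255, (N, K // 2), dtype=torch.uint8, device="cuda")
    weight_scale = torch.full((N, K // 32), 127, dtype=torch.uint8, device="cuda")

    # Unfused sequence that ROCmFusionPass matches in the fx graph.
    act = torch.empty(M, K, dtype=torch.bfloat16, device="cuda")
    ref = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
    torch.ops._C.silu_and_mul(act, x)
    torch.ops.vllm.gemm_with_dynamic_quant(ref, act, weight, weight_scale)

    # Fused op it is replaced with; the two results should agree up to
    # kernel-level rounding of the MXFP4 quantization.
    out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
    torch.ops.vllm.silu_and_mul_mxfp4_gemm(out, x, weight, weight_scale)
    print((out - ref).abs().max())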