5 changes: 5 additions & 0 deletions vllm/compilation/pass_manager.py
@@ -21,6 +21,9 @@
 if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass
 
+if current_platform.is_rocm():
+    from .rocm_fusion import ROCmFusionPass
+
 from .fix_functionalization import FixFunctionalizationPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
 from .noop_elimination import NoOpEliminationPass
@@ -100,6 +103,8 @@ def configure(self, config: VllmConfig):
         if self.pass_config.enable_fusion:
             self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]
+            if current_platform.is_rocm():
+                self.passes += [ROCmFusionPass(config)]
 
         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
175 changes: 175 additions & 0 deletions vllm/compilation/rocm_fusion.py
@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.logger import init_logger

from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

logger = init_logger(__name__)


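# Note: on ROCm builds, PyTorch exposes HIP devices under the "cuda" device
# string, so device="cuda" below is correct for AMD GPUs.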
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")


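# MXFP4 packs two 4-bit elements per byte, hence uint8 storage for fp4 data.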
def empty_fp4(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")


class SiluMulMXFP4GemmPattern:
def __init__(self):
pass

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
                                      result=result_silu_mul,
                                      input=input)
            at2 = auto_functionalized(
                torch.ops.vllm.gemm_with_dynamic_quant.default,
                result=result,
                x=at1[1],
                weight=weight,
                weight_scale=scale,
                x_scales=None)
            return at2[1]

        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(
                torch.ops.vllm.silu_and_mul_mxfp4_gemm.default,
                result=result,
                x=input,
                weight=weight,
                weight_scale=scale)
            return at[1]

        inputs = [
            empty_bf16(5, 4),  # result
            empty_bf16(5, 4),  # result_silu_mul
            empty_bf16(5, 4),  # input
            empty_fp4(5, 4),  # weight
            empty_fp4(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )


ADD_RMS_OP = torch.ops._C.fused_add_rms_norm.default


class AddRMSNormMXFP4GemmPattern:
def __init__(self, epsilon: float):
self.epsilon = epsilon
self.FUSED_OP = torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default
self.QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight_rms: torch.Tensor,
                    weight_gemm: torch.Tensor, scale: torch.Tensor):
            at1 = auto_functionalized(ADD_RMS_OP,
                                      input=input,
                                      residual=residual,
                                      weight=weight_rms,
                                      epsilon=self.epsilon)
            at2 = auto_functionalized(self.QUANT_F4GEMM_OP,
                                      result=result,
                                      x=at1[1],
                                      weight=weight_gemm,
                                      weight_scale=scale,
                                      x_scales=None)
            return at2[1], at1[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight_rms: torch.Tensor,
                        weight_gemm: torch.Tensor, scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     residual=residual,
                                     residual_out=residual,
                                     weight_rms=weight_rms,
                                     weight_gemm=weight_gemm,
                                     scale=scale,
                                     epsilon=self.epsilon)
            return at[1], at[2]

        inputs = [
            empty_bf16(4, 4),  # result
            empty_bf16(4, 4),  # input
            empty_bf16(4, 4),  # residual
            empty_bf16(1, 4),  # weight_rms
            empty_fp4(4, 4),  # weight_gemm
            empty_fp4(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass)


class ROCmFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
"""

@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)

self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_fusion_pass")

SiluMulMXFP4GemmPattern().register(self.patterns)

for epsilon in [1e-5, 1e-6]:
AddRMSNormMXFP4GemmPattern(epsilon).register(self.patterns)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        return self.hash_source(self, SiluMulMXFP4GemmPattern,
                                AddRMSNormMXFP4GemmPattern)
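
For context on the mechanism ROCmFusionPass relies on, here is a minimal standalone sketch (toy ops only, not part of this PR) of how torch._inductor's pattern matcher rewrites a traced graph, assuming a recent PyTorch where these internal APIs behave as they are used above: a pattern/replacement pair is registered from example inputs, and PatternMatcherPass.apply returns the number of rewritten call sites, just as ROCmFusionPass.__call__ logs its matched_count.

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.fx.experimental.proxy_tensor import make_fx

toy_patterns = PatternMatcherPass(pass_name="toy_fusion_pass")

def pattern(x: torch.Tensor, y: torch.Tensor):
    return x * y + y  # unfused multiply then add

def replacement(x: torch.Tensor, y: torch.Tensor):
    return torch.addcmul(y, x, y)  # single fused kernel, same math

pm.register_replacement(pattern, replacement,
                        [torch.empty(8, 8), torch.empty(8, 8)],
                        pm.fwd_only, toy_patterns)

def model(x, y):
    return (x * y + y).relu()

# Trace to an ATen-level FX graph (the form the matcher sees) and rewrite it.
gm = make_fx(model)(torch.randn(8, 8), torch.randn(8, 8))
matched = toy_patterns.apply(gm.graph)
gm.recompile()
print(f"replaced {matched} pattern(s)")

The MXFP4 patterns above follow the same shape, except that both sides are auto_functionalized custom ops and the example inputs pin the dtypes (bf16 activations, packed-uint8 fp4 weights) the matcher should specialize on.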
@@ -24,24 +24,28 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:


 try:
+    import triton
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
+    from aiter.ops.triton.fused_mxfp4_quant import _fused_rms_mxfp4_quant_kernel
 
     from vllm.utils import direct_register_custom_op
     if is_rocm_aiter_fp4_asm_gemm_enabled():
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip
 
     def gemm_with_dynamic_quant(
+            result: torch.Tensor,
             x: torch.Tensor,
             weight: torch.Tensor,
             weight_scale: torch.Tensor,
+            x_scales: Optional[torch.Tensor] = None,
             rocm_use_aiter_fp4_asm_gemm: bool = False,
             out_dtype: Optional[torch.dtype] = torch.bfloat16,
-            x_scales: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        M = x.shape[0]
+    ) -> None:
         if rocm_use_aiter_fp4_asm_gemm:
+            M = x.shape[0]
             if x_scales is None:
                 # use hip quant kernel for performance
                 x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
@@ -62,41 +66,128 @@ def gemm_with_dynamic_quant(
                       weight_scale.view(x_s.dtype),
                       y,
                       bpreshuffle=True)
-            return y[:M]
+            result.copy_(y[:M])
         else:
             if x_scales is None:
                 x_q, x_s = dynamic_mxfp4_quant(x)
             else:
                 x_q = x
                 x_s = x_scales
             y = torch.empty(x_q.shape[0],
                             weight.shape[0],
                             device=x_q.device,
                             dtype=out_dtype)
 
-            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
-            return y
+            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, result)
 
     def gemm_with_dynamic_quant_fake(
+            result: torch.Tensor,
             x: torch.Tensor,
             weight: torch.Tensor,
             weight_scale: torch.Tensor,
             x_scales: torch.Tensor = None,
             rocm_use_aiter_fp4_asm_gemm: bool = False,
             out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        return torch.empty((*x.shape[:-1], weight.shape[0]),
-                           dtype=out_dtype,
-                           device=x.device)
+        return
Contributor review comment on lines 80 to +87 (severity: high):
The type hints for gemm_with_dynamic_quant_fake are incorrect and do not match the function's implementation or its real counterpart gemm_with_dynamic_quant.

  1. The x_scales parameter is type-hinted as torch.Tensor but has a default value of None. The type hint should be Optional[torch.Tensor] to reflect this.
  2. The function is type-hinted to return torch.Tensor, but it implicitly returns None. The return type hint should be None.

These inconsistencies can cause issues with static type checkers and torch.compile's fake tensor propagation, which relies on correct function signatures for its analysis.

Suggested change:

Before:

    def gemm_with_dynamic_quant_fake(
            result: torch.Tensor,
            x: torch.Tensor,
            weight: torch.Tensor,
            weight_scale: torch.Tensor,
            x_scales: torch.Tensor = None,
            rocm_use_aiter_fp4_asm_gemm: bool = False,
            out_dtype: Optional[torch.dtype] = torch.bfloat16,
    ) -> torch.Tensor:
        return

After:

    def gemm_with_dynamic_quant_fake(
            result: torch.Tensor,
            x: torch.Tensor,
            weight: torch.Tensor,
            weight_scale: torch.Tensor,
            x_scales: Optional[torch.Tensor] = None,
            rocm_use_aiter_fp4_asm_gemm: bool = False,
            out_dtype: Optional[torch.dtype] = torch.bfloat16,
    ) -> None:
        return
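
As an aside, the out-variant convention the review asks for can also be written with PyTorch's public torch.library.custom_op API. This is a standalone toy sketch (it uses neither vLLM's direct_register_custom_op nor the ops in this PR, and assumes torch >= 2.4): the op mutates result and returns None, and the fake impl mirrors the exact signature, Optional default included, so fake-tensor tracing under torch.compile has nothing to guess.

from typing import Optional
import torch

# Toy out-variant op: writes into `result` and returns None.
@torch.library.custom_op("toy::scaled_copy", mutates_args=("result",))
def scaled_copy(result: torch.Tensor, x: torch.Tensor,
                scale: Optional[torch.Tensor] = None) -> None:
    result.copy_(x if scale is None else x * scale)

@scaled_copy.register_fake
def _(result: torch.Tensor, x: torch.Tensor,
      scale: Optional[torch.Tensor] = None) -> None:
    # Nothing to fabricate: output shape and dtype already live in `result`.
    return

out = torch.empty(4)
torch.ops.toy.scaled_copy(out, torch.randn(4))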


     direct_register_custom_op(
         op_name="gemm_with_dynamic_quant",
         op_func=gemm_with_dynamic_quant,
-        mutates_args=[],
+        mutates_args=['result'],
         fake_impl=gemm_with_dynamic_quant_fake,
         dispatch_key=current_platform.dispatch_key,
     )

+    def silu_and_mul_mxfp4_gemm(
+            result: torch.Tensor,
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            weight_scale: torch.Tensor,
+            rocm_use_aiter_fp4_asm_gemm: bool = False,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        x_fp4, blockscale_e8m0 = act_mul_and_mxfp4_quant(x, 'silu')
+        gemm_with_dynamic_quant(result, x_fp4, weight, weight_scale,
+                                blockscale_e8m0, rocm_use_aiter_fp4_asm_gemm,
+                                out_dtype)
+
+    def silu_and_mul_mxfp4_gemm_fake(
+            result: torch.Tensor,
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            weight_scale: torch.Tensor,
+            rocm_use_aiter_fp4_asm_gemm: bool = False,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="silu_and_mul_mxfp4_gemm",
+        op_func=silu_and_mul_mxfp4_gemm,
+        mutates_args=['result'],
+        fake_impl=silu_and_mul_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )

+    def add_rmsnorm_mxfp4_gemm(
+            result: torch.Tensor, input: torch.Tensor,
+            residual_out: torch.Tensor, residual: torch.Tensor,
+            weight_rms: torch.Tensor, weight_gemm: torch.Tensor,
+            scale: torch.Tensor, epsilon: float,
+            rocm_use_aiter_fp4_asm_gemm: bool = False,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        MXFP4_QUANT_BLOCK_SIZE = 32
+        M, N1 = input.shape
+        BLOCK_SIZE = max(triton.next_power_of_2(N1), MXFP4_QUANT_BLOCK_SIZE)
+        res_row_stride = residual.stride(0)
+        out_res_row_stride = residual_out.stride(0)
+        rms_out_fp4 = torch.empty((M, N1 // 2),
+                                  dtype=torch.uint8,
+                                  device=input.device)
+        rms_out_bs = torch.empty(
+            ((N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE, M),
+            dtype=torch.uint8,
+            device=input.device,
+        ).T
+        _fused_rms_mxfp4_quant_kernel[(M, )](
+            input,
+            weight_rms,
+            None,
+            None,
+            residual,
+            rms_out_fp4,
+            rms_out_bs,
+            None,
+            residual_out,
+            epsilon,
+            0.0,
+            M,
+            N1,
+            0,
+            input.stride(0),
+            0,
+            res_row_stride,
+            rms_out_fp4.stride(0),
+            *rms_out_bs.stride(),
+            0,
+            out_res_row_stride,
+            BLOCK_SIZE=BLOCK_SIZE,
+            MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE,
+            SKIP_SECOND_INPUT=True,
+            FIRST_INPUT_RES=True,
+        )
+        gemm_with_dynamic_quant(result, rms_out_fp4, weight_gemm, scale,
+                                rms_out_bs, rocm_use_aiter_fp4_asm_gemm,
+                                out_dtype)
+
+    def add_rmsnorm_mxfp4_gemm_fake(
+            result: torch.Tensor, input: torch.Tensor,
+            residual_out: torch.Tensor, residual: torch.Tensor,
+            weight_rms: torch.Tensor, weight_gemm: torch.Tensor,
+            scale: torch.Tensor, epsilon: float,
+            rocm_use_aiter_fp4_asm_gemm: bool = False,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="add_rmsnorm_mxfp4_gemm",
+        op_func=add_rmsnorm_mxfp4_gemm,
+        mutates_args=['result', 'residual_out'],
+        fake_impl=add_rmsnorm_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )

 except ImportError:
     dynamic_mxfp4_quant = gemm_afp4wfp4 = None

@@ -234,6 +325,7 @@ def apply_weights(self,
             x = quant_dequant_mxfp4(x)
             return F.linear(x, dq_w, bias)
         else:
-            return torch.ops.vllm.gemm_with_dynamic_quant(
-                x, layer.weight, layer.weight_scale,
-                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
+            result = torch.empty((*x.shape[:-1], layer.weight.shape[0]),
+                                 dtype=self.out_dtype,
+                                 device=x.device)
+            torch.ops.vllm.gemm_with_dynamic_quant(
+                result, x, layer.weight, layer.weight_scale, None,
+                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
+            return result