
Commit 6659b99
Add aiter tkw1 kernel for fp8
Signed-off-by: kliuae <[email protected]>
1 parent: fdcb850

File tree: 4 files changed, +184 −2 lines


docker/Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="8970b25b"
+ARG AITER_BRANCH="5a77249"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base

vllm/envs.py

Lines changed: 8 additions & 0 deletions

@@ -78,6 +78,7 @@
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
+    VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True

@@ -553,6 +554,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
      ("true", "1")),

+    # Whether to use aiter channel scaled moe kernel.
+    # By default this is disabled.
+    "VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE":
+    lambda:
+    (os.getenv("VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE", "false").lower() in
+     ("true", "1")),
+
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
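For reference, a minimal sketch (not part of the commit) of enabling the new path from Python before vLLM builds its MoE layers. The new flag name comes from the diff above; the assumption that the general aiter switches must also be on follows from is_rocm_aiter_channel_scaled_moe_enabled() checking is_rocm_aiter_moe_enabled() first (see the next file).

    # Hedged sketch: the flags are read lazily via os.getenv, so set them
    # before the model is constructed.  The first two are pre-existing aiter
    # switches in vLLM and are not touched by this commit.
    import os

    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE"] = "1"  # new flag, off by default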

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 106 additions & 1 deletion

@@ -18,6 +18,50 @@ def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
         envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE


+def is_rocm_aiter_channel_scaled_moe_enabled() -> bool:
+    return is_rocm_aiter_moe_enabled() and \
+        envs.VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE
+
+
+def asm_moe_tkw1_impl(sorted_ids: torch.Tensor,
+                      sorted_weights: torch.Tensor,
+                      sorted_expert_ids: torch.Tensor,
+                      num_valid_ids: torch.Tensor,
+                      moe_buf: torch.Tensor,
+                      hidden_states: torch.Tensor,
+                      w1: torch.Tensor,
+                      w2: torch.Tensor,
+                      topk_weight: torch.Tensor,
+                      topk_ids: torch.Tensor,
+                      fc1_scale: Optional[torch.Tensor] = None,
+                      fc2_scale: Optional[torch.Tensor] = None,
+                      fc1_smooth_scale: Optional[torch.Tensor] = None,
+                      fc2_smooth_scale: Optional[torch.Tensor] = None,
+                      activation_str: str = "silu") -> None:
+    import aiter as rocm_aiter
+
+    if activation_str == "silu":
+        activation = rocm_aiter.ActivationType.Silu
+    elif activation_str == "gelu":
+        activation = rocm_aiter.ActivationType.Gelu
+    else:
+        activation = rocm_aiter.ActivationType.Silu
+
+    E, model_dim, _ = w2.shape
+    M, topk = topk_ids.shape
+    device = topk_ids.device
+
+    a8_type = (w1.dtype if w1.dtype != torch.int32 and w1.dtype != torch.uint32
+               else torch.float8_e4m3fnuz)
+    a8 = torch.empty((M, model_dim), dtype=a8_type, device=device)
+    a8_scale = torch.empty(M, dtype=torch.float, device=device)
+    rocm_aiter.dynamic_per_token_scaled_fp8_quant(a8, hidden_states, a8_scale)
+    fmoe_func = rocm_aiter.fmoe_g1u1_tkw1
+    fmoe_func(moe_buf, a8, w1, w2, sorted_ids, sorted_weights,
+              sorted_expert_ids, num_valid_ids, topk, a8_scale, fc1_scale,
+              fc2_scale, fc2_smooth_scale, activation)
+
+
 def rocm_aiter_fused_experts(
         *,
         hidden_states: torch.Tensor,
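The new asm_moe_tkw1_impl above quantizes the activations to fp8 with one dynamic scale per token before calling the aiter tkw1 kernel. As a rough plain-torch illustration of what that per-token scaled quantization produces (this is an assumption about the aiter op's semantics, and it uses float8_e4m3fn for portability rather than the float8_e4m3fnuz variant targeted on ROCm):

    # Illustrative sketch only, not the aiter kernel: one scale per row is
    # derived from the row's absolute maximum, then the scaled activations are
    # stored in an fp8 tensor, matching the (a8, a8_scale) pair built above.
    import torch

    def per_token_fp8_quant(x: torch.Tensor):
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
        a8 = (x / scale).to(torch.float8_e4m3fn)      # quantized activations
        a8_scale = scale.squeeze(-1).to(torch.float)  # one scale per token
        return a8, a8_scale

    a8, a8_scale = per_token_fp8_quant(torch.randn(4, 16))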
@@ -26,10 +70,12 @@ def rocm_aiter_fused_experts(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         use_fp8_w8a8: bool = False,
+        apply_router_weight_on_input: bool = False,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None,
         expert_mask: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         **kwagrs  # Ignore additional keyword arguments
 ) -> torch.Tensor:

@@ -38,8 +84,22 @@ def rocm_aiter_fused_experts(

     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
+
+    if apply_router_weight_on_input:
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
+    if is_rocm_aiter_block_scaled_moe_enabled() and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for block scaled moe"
+        )

-    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None

@@ -88,8 +148,53 @@ def rocm_aiter_fused_experts(
             None,
         )
         return out_asm
+
+    elif is_rocm_aiter_channel_scaled_moe_enabled() and use_fp8_w8a8:
+        topk_weights = topk_weights.to(torch.float32)
+        topk_ids = topk_ids.to(torch.int32)
+
+        E, model_dim, _ = w2.shape
+        dtype = hidden_states.dtype
+
+        if expert_mask is not None:
+            E = expert_mask.numel()
+
+        (
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            out_asm,
+        ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
+                                               topk_weights,
+                                               E,
+                                               model_dim,
+                                               dtype,
+                                               expert_mask=expert_mask)
+
+        asm_moe_tkw1_impl(
+            sorted_ids=sorted_token_ids,
+            sorted_weights=sorted_weight_buf,
+            sorted_expert_ids=sorted_expert_ids,
+            num_valid_ids=num_valid_ids,
+            moe_buf=out_asm,
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            topk_weight=topk_weights,
+            topk_ids=topk_ids,
+            fc1_scale=w1_scale,
+            fc2_scale=w2_scale,
+            fc1_smooth_scale=None,
+            fc2_smooth_scale=None,
+            activation_str=activation)
+
+        return out_asm

     elif use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for fp8_w8a8")
+
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
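The apply_router_weight_on_input handling added above folds the single routing weight into the activations and then neutralizes the weight tensor, so the kernel does not apply it a second time. A runnable plain-torch sketch of just that transformation (tensor shapes are invented for illustration, no aiter needed):

    # Standalone illustration of the pre-scaling added above.
    import torch

    hidden_states = torch.randn(4, 8)                 # [num_tokens, hidden_dim]
    topk_weights = torch.rand(4, 1)                   # topk must be 1 for this path
    topk_ids = torch.zeros(4, 1, dtype=torch.int64)

    _, topk = topk_weights.shape
    assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"

    # Fold the router weight into the activations up front ...
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    # ... then neutralize the weights so they are not applied again downstream.
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)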

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 69 additions & 0 deletions

@@ -14,6 +14,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_channel_scaled_moe_enabled, rocm_aiter_fused_experts,
+    shuffle_weights)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
 from vllm.model_executor.layers.quantization.utils import replace_parameter

@@ -36,6 +39,7 @@ class GPTQMarlinState(Enum):
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
     "CompressedTensorsWNA16MarlinMoEMethod",
     "CompressedTensorsWNA16MoEMethod",
+    "CompressedTensorsW8A8Fp8MoEAiterMethod",
 ]


@@ -70,6 +74,8 @@ def get_moe_method(
               and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            if is_rocm_aiter_channel_scaled_moe_enabled():
+                return CompressedTensorsW8A8Fp8MoEAiterMethod(quant_config)
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
         else:
             raise RuntimeError(

@@ -302,6 +308,69 @@ def apply(
             a2_scale=layer.w2_input_scale)


+class CompressedTensorsW8A8Fp8MoEAiterMethod(CompressedTensorsW8A8Fp8MoEMethod
+                                             ):
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+        # reshaping weights is required for aiter moe kernel.
+        shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                    layer.w2_weight.data)
+
+        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                              requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+
+        assert activation in ["silu", "gelu"]
+        assert global_num_experts == layer.w13_weight.shape[0]
+        assert expert_map is None
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return rocm_aiter_fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=activation,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input)
+
+
 class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):

     def __init__(
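A condensed, illustrative view of the dispatch added to get_moe_method() above: for an fp8 w8a8 compressed-tensors checkpoint, the new aiter method is chosen only when the channel-scaled flag resolves to true, otherwise the existing Triton-based method is kept. The helper below is not vLLM code; the booleans stand in for quant_config._is_fp8_w8a8() and is_rocm_aiter_channel_scaled_moe_enabled(), and the other quantization branches are elided.

    def pick_fp8_moe_method(is_fp8_w8a8: bool, channel_scaled_moe: bool) -> str:
        # Mirrors only the two lines added in this commit.
        if is_fp8_w8a8 and channel_scaled_moe:
            return "CompressedTensorsW8A8Fp8MoEAiterMethod"
        if is_fp8_w8a8:
            return "CompressedTensorsW8A8Fp8MoEMethod"
        raise RuntimeError("other quantization schemes handled elsewhere")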
