[ROCm] Add aiter tkw1 kernel for Llama4 fp8 #16727
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Signed-off-by: kliuae <[email protected]>
Force-pushed from 88e60fb to 6659b99
Signed-off-by: kliuae <[email protected]>
Co-authored-by: kliuae <[email protected]> Signed-off-by: tjtanaa <[email protected]>
vllm/envs.py (outdated)
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE: bool = False
Can we make the env name align more closely with the kernel name, in this case by including tkw1 in the name?
def is_rocm_aiter_channel_scaled_moe_enabled() -> bool:
    return is_rocm_aiter_moe_enabled() and \
        envs.VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE
Does this tkw1 enablement need to depend on is_rocm_aiter_moe_enabled()?
In this enablement we are following the block_scaled_moe case in using VLLM_ROCM_USE_AITER_MOE as a master switch for enabling MoE ops, to stay consistent with the other aiter kernels.
if activation_str == "silu":
    activation = ActivationType.Silu
elif activation_str == "gelu":
    activation = ActivationType.Gelu
else:
    activation = ActivationType.Silu
Can this be simplified to a one-liner?
Suggested change:
- if activation_str == "silu":
-     activation = ActivationType.Silu
- elif activation_str == "gelu":
-     activation = ActivationType.Gelu
- else:
-     activation = ActivationType.Silu
+ activation = ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
Do we need an additional wrapper for the _tkw1 kernel, given that it's just a kernel call plus an activation type conversion? The activation type conversion could also be used by other branches / kernel calls.
We are wrapping the kernel call because, in a future PR addressing the enablement of torch.compile for aiter MoE kernels, we will be using wrappers to register the aiter ops, so we thought to leave it here for now.
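For context, here is a minimal sketch of the kind of custom-op registration such a wrapper enables for torch.compile; the op name, schema, and placeholder implementation are illustrative assumptions, not the PR's actual code:

```python
import torch

# Hypothetical registration of the wrapped aiter tkw1 MoE call as an opaque
# custom op, so torch.compile can treat it as a single node instead of
# tracing into the kernel call. Names and schema are illustrative only.
torch.library.define(
    "rocm_aiter::fused_moe_tkw1",
    "(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, "
    "Tensor topk_ids, int activation) -> Tensor",
)


@torch.library.impl("rocm_aiter::fused_moe_tkw1", "cuda")
def _fused_moe_tkw1_cuda(hidden_states, w1, w2, topk_weights, topk_ids,
                         activation):
    # Placeholder body: the real wrapper would convert `activation` to
    # aiter's ActivationType and call the aiter tkw1 kernel here.
    return torch.empty_like(hidden_states)
```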
# # All AITER Fused MoE kernels are expecting the following datatypes
# topk_weights = topk_weights.to(torch.float32)
# topk_ids = topk_ids.to(torch.int32)
Suggested change: remove these commented-out lines.
# topk_weights = topk_weights.to(torch.float32)
# topk_ids = topk_ids.to(torch.int32)

return rocm_aiter_asm_moe_tkw1(hidden_states,
Let's assert apply_router_weight_on_input=True, or add an if-branch check, when calling the _tkw1 kernel. Also, we should add some comments to illustrate the difference between the _tkw1 kernel and the other aiter kernels: the difference is whether topk_weights is applied to the output of the first GEMM or the second GEMM.
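A hedged sketch of what that check and comment could look like (the helper name is hypothetical, not from the PR):

```python
def _check_tkw1_preconditions(apply_router_weight_on_input: bool) -> None:
    # The tkw1 path differs from the other aiter fused-MoE kernels in where
    # topk_weights is applied (output of the first GEMM vs. the second), and
    # per the PR discussion it only supports apply_router_weight_on_input=True.
    assert apply_router_weight_on_input, (
        "The aiter tkw1 MoE kernel requires apply_router_weight_on_input=True")
```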
        and layer.activation == "silu" and layer.expert_map is None):
    return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
    if is_rocm_aiter_channel_scaled_moe_enabled():
tkw1 is not general support for FP8 fused-MoE channel / row-wise scaling; it only supports the case where apply_router_weight_on_input=True.
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
…E_AITER_FP8_BLOCK_SCALED_MOE and VLLM_ROCM_USE_AITER_FP8_TKW1_MOE Co-authored-by: kliuae <[email protected]> Co-authored-by: vllmellm <[email protected]> Signed-off-by: tjtanaa <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
)

if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
    # TODO: verify this code path for DeepSeekV3
Can we verify this before landing?
Verified: Will remove the comment.
2025-04-18:10:35:16 INFO [loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated
vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=30000,gpu_memory_utilization=0.8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | 0.9492 | ± 0.006 |
| | | strict-match | 5 | exact_match ↑ | 0.9500 | ± 0.006 |
Signed-off-by: tjtanaa <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks reasonable. Just a few nits.
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                     requires_grad=False)

if self.use_rocm_aiter_moe:
Nit: Can you merge these into one if statement?
Will do. Thanks for pointing this out.
    is_rocm_aiter_moe_enabled)

# Property to determine if AITER is used
self.use_rocm_aiter_moe = is_rocm_aiter_moe_enabled()
Nit: Do you need to store this in the class? It doesn't look like you are using it outside of this function.
You're right. Updated this along with the merged if statement.
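For illustration, a rough sketch of the resulting shape of that code, assuming the weight-shuffling helper implied by the diff above (helper and argument names are placeholders):

```python
import torch


def _maybe_shuffle_for_aiter(layer, shuffle_weights, is_rocm_aiter_moe_enabled):
    # Query the enablement helper locally instead of caching it on the class,
    # and keep a single conditional around the AITER-specific weight shuffle.
    if is_rocm_aiter_moe_enabled():
        shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
                                                    layer.w2_weight)
        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                             requires_grad=False)
```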
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]> Co-authored-by: vllmellm <[email protected]> Signed-off-by: Frieda (Jingying) Huang <[email protected]>
Signed-off-by: kliuae <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]> Co-authored-by: vllmellm <[email protected]>
Signed-off-by: kliuae <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]> Co-authored-by: vllmellm <[email protected]> Signed-off-by: Agata Dobrzyniewicz <[email protected]>
Signed-off-by: kliuae <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]> Co-authored-by: vllmellm <[email protected]> Signed-off-by: Mu Huai <[email protected]>
This PR enables aiter's tkw1 quantized MoE kernel to improve the inference performance of Llama4 quantized to FP8 with compressed-tensors. We have also revamped aiter's MoE kernel dispatching to automatically choose the suitable AITER fused MoE kernel without needing to set flags for kernel selection. Users only need to specify `VLLM_ROCM_USE_AITER=1` and `VLLM_ROCM_USE_AITER_MOE=1` to activate aiter's MoE kernels, and the `VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE` flag is removed.

Note: torch.compile isn't supported in this PR yet, and the performance numbers were obtained with V1 eager mode. The enablement of V1 torch.compile for aiter MoE kernels will be addressed in a separate PR.
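As a usage sketch, the settings below mirror the benchmark configuration in this description (V1 eager mode, TP=4); the exact invocation is illustrative:

```python
import os

# Enable the AITER fused MoE path on ROCm before vLLM is imported.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"

from vllm import LLM, SamplingParams

# Eager mode, matching the note above (torch.compile not yet supported).
llm = LLM(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
          tensor_parallel_size=4,
          enforce_eager=True)
print(llm.generate(["Hello"], SamplingParams(max_tokens=32)))
```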
Llama4 Maverick FP8 throughput benchmarks
Llama4 Maverick FP8 latency benchmarks
Text Generation Response
lm_eval Results
V1 without aiter, eager mode
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=4,max_model_len=30000,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
V1 with aiter, eager mode
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=4,max_model_len=30000,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
Reduce complexity of selecting AITER Fused MoE kernel
As the number of AITER flags has increased, we have revamped the conditions for picking the AITER fused MoE kernel so that no kernel-selection flag is needed, and `VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE` is removed. Users only need to specify `VLLM_ROCM_USE_AITER=1` and `VLLM_ROCM_USE_AITER_MOE=1`. We have validated the code paths of other models with the latest AITER fused MoE selection logic (a simplified sketch of the selection follows the validation results below):
mistralai_Mixtral-8x7B-Instruct-v0.1_V0
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=1,max_model_len=30000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
mistralai_Mixtral-8x7B-Instruct-v0.1_FP8_V0
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=1,max_model_len=30000,quantization=fp8,kv_cache_dtype=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
deepseek-ai_DeepSeek-V3
vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=30000,gpu_memory_utilization=0.8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
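For reference, here is the simplified sketch of the flag-free kernel selection mentioned above; the function, its parameters, and the exact conditions are assumptions for illustration, not the PR's actual dispatch code:

```python
def select_fused_moe_kernel(aiter_moe_enabled: bool,
                            use_fp8_w8a8: bool,
                            block_scaled: bool,
                            per_channel_scaled: bool,
                            apply_router_weight_on_input: bool) -> str:
    # With VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE=1 as the only
    # switches, the kernel is chosen from the quantization properties.
    if not aiter_moe_enabled:
        return "triton_fused_moe"          # non-AITER fallback
    if use_fp8_w8a8 and block_scaled:
        return "aiter_block_scaled_moe"    # e.g. the DeepSeek-V3 FP8 path
    if (use_fp8_w8a8 and per_channel_scaled
            and apply_router_weight_on_input):
        return "aiter_tkw1_moe"            # Llama4 FP8 path added in this PR
    return "aiter_asm_moe"                 # default AITER fused MoE kernel
```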