[ROCm][FEAT] Support AITER RMSNorm quantization fusion pass #26575
Signed-off-by: vllmellm <[email protected]>
… error for other platforms Signed-off-by: vllmellm <[email protected]>
Purpose
This PR adds a fusion pass for ROCm AITER that fuses the AITER RMSNorm ops (+rms_norm) with the vLLM quantization custom ops (+quant_fp8).
Benchmark Result
benchmark setting
vllm bench serve \
  --backend vllm \
  --model "RedHatAI/Qwen3-14B-FP8-dynamic" \
  --dataset-name random \
  --num-prompts 500 \
  --random-input-len 1000 \
  --random-output-len 1000 \
  --endpoint /v1/completions \
  --random-range-ratio 0.9
IMPORTANT NOTE
Use
--compilation-config '{"pass_config": {"enable_fusion": true, "enable_noop": true, "enable-attn-fusion": false}, "custom_ops": ["+rms_norm", "+quant_fp8"]}'
to enable the fusion pass.
Test Plan
vllm/tests/compile/test_rocm_aiter_fusion.py, which verifies accuracy and the replacement of the ops in the CUDA graph.
RedHatAI/Qwen3-14B-FP8-dynamic model
environment setting
Step 1: run vllm serve
VLLM_ROCM_USE_AITER=1 \
VLLM_USE_V1=1 \
vllm serve RedHatAI/Qwen3-14B-FP8-dynamic \
  --compilation-config '{"pass_config": {"enable_fusion": true, "enable_noop": true, "enable-attn-fusion": false}, "custom_ops": ["+rms_norm", "+quant_fp8"], "cudagraph_capture_sizes": [1,2,4,8,16,24,32,256]}' \
  --port 9090 \
  --trust-remote-code --swap-space 16 --distributed-executor-backend mp
Step 2: run lm_eval
lm_eval --model local-completions --tasks gsm8k \
  --model_args model=RedHatAI/Qwen3-14B-FP8-dynamic,base_url=http://localhost:9090/v1/completions \
  --trust_remote_code \
  --num_fewshot 5 \
  --batch_size 128
Test Results
RedHatAI/Qwen3-14B-FP8-dynamic fusion pass
RedHatAI/Qwen3-14B-FP8-dynamic without fusion pass
Unit test result
INFO 10-10 08:39:08 [__init__.py:224] Automatically detected platform rocm.
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.4.2, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /app/norm/vllm
configfile: pyproject.toml
plugins: anyio-4.10.0, asyncio-1.2.0
asyncio: mode=strict, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... WARNING 10-10 08:39:11 [interface.py:518] Current platform cuda does not have 'test' attribute.
WARNING 10-10 08:39:11 [interface.py:518] Current platform cuda does not have 'bases' attribute.
WARNING 10-10 08:39:11 [interface.py:518] Current platform cuda does not have 'test' attribute.
collected 2 items
compile/test_rocm_aiter_fusion.py::test_fusion_rmsnorm_quant[1e-05-257-64-dtype0] Matched count: 2
PASSED
compile/test_rocm_aiter_fusion.py::test_fusion_rmsnorm_quant[1e-06-257-64-dtype0] Matched count: 2
PASSED
======================== 2 passed, 2 warnings in 25.65s ========================
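For readers unfamiliar with what the accuracy check in the unit test is comparing, here is a minimal pure-Python sketch of the numerics. It is not the AITER kernel or the vLLM pass itself. It assumes dynamic per-token quantization (suggested by the model name, not stated in this PR) and an FP8 maximum of 448, which holds for float8_e4m3fn but not for every FP8 variant. Rounding to FP8 is omitted, and all function names are hypothetical.

```python
import math
import random

# Assumption: largest finite value of float8_e4m3fn; other FP8 variants differ.
FP8_MAX = 448.0


def rms_norm(x, weight, eps):
    """Unfused step 1: scale one token's hidden vector by its reciprocal RMS."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def dynamic_quant(y):
    """Unfused step 2: per-token scaling into the FP8 range (no FP8 rounding)."""
    scale = max(max(abs(v) for v in y) / FP8_MAX, 1e-12)
    return [max(-FP8_MAX, min(FP8_MAX, v / scale)) for v in y], scale


def fused_rms_norm_quant(x, weight, eps):
    """Fused sketch: normalize and track the absolute maximum in the same loop."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    normed, amax = [], 0.0
    for v, w in zip(x, weight):
        n = v * inv_rms * w
        normed.append(n)
        amax = max(amax, abs(n))
    scale = max(amax / FP8_MAX, 1e-12)
    return [max(-FP8_MAX, min(FP8_MAX, n / scale)) for n in normed], scale


if __name__ == "__main__":
    rng = random.Random(0)
    hidden = 64  # matches the hidden size in the test IDs above
    x = [rng.uniform(-1.0, 1.0) for _ in range(hidden)]
    w = [rng.uniform(0.5, 1.5) for _ in range(hidden)]
    q_ref, s_ref = dynamic_quant(rms_norm(x, w, 1e-5))
    q_fused, s_fused = fused_rms_norm_quant(x, w, 1e-5)
    print("scales match:", math.isclose(s_ref, s_fused))
    print("values match:", all(math.isclose(a, b) for a, b in zip(q_ref, q_fused)))
```

The sketch only shows that the two paths compute the same values. A fused kernel's benefit comes from not writing the normalized tensor to memory and reading it back between the two ops.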