-
-
Notifications
You must be signed in to change notification settings - Fork 10.7k
[Kernel] Increase precision of GPTQ/AWQ Marlin kernel #6795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
/ready |
Do you have any data showcasing that this fixes the described accuracy issues? I'm guessing you could look at the examples referenced in the issue? And potentially run an evaluation of perplexity? |
@casper-hansen We had run internally a test on A10 for GSM dataset to verify accuracy, and we saw that with the old fp16 reduce the total accuracy is 58% and with the new fp32 reduce the accuracy is 73% (as it is supposed to be). Also, in the unit tests, we saw that the max difference is improved from e-3 to e-6 (almost double basically) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work Alex, I'm glad it didn't require paring down split-k or any drastic changes.
Do you think this should be added to fp8 marlin as well? We could also just wait for @LucasWilkinson 's type refactor
I think it is not critical to add to fp8 marlin, since we did not had so much reports of accuracy issues like we had for GPTQ and especially AWQ (which is even more sensitive than GPTQ). |
Sync with upstream change that improves the precision of the 'global_reduce' algorithm from FP16 to FP32. This solves some reported generation quality issues. Upstream issue/PR: vllm-project/vllm#6795
* upstream/main: (66 commits) [Bugfix] Fix PaliGemma MMP (vllm-project#6930) [TPU] Fix greedy decoding (vllm-project#6933) [Kernel] Tuned int8 kernels for Ada Lovelace (vllm-project#6848) [Kernel] Fix marlin divide-by-zero warnings (vllm-project#6904) [ci] GHA workflow to remove ready label upon "/notready" comment (vllm-project#6921) [Kernel] Remove unused variables in awq/gemm_kernels.cu (vllm-project#6908) [Frontend] New `allowed_token_ids` decoding request parameter (vllm-project#6753) [Bugfix] Allow vllm to still work if triton is not installed. (vllm-project#6786) [TPU] Support tensor parallelism in async llm engine (vllm-project#6891) [Kernel] Fix deprecation function warnings squeezellm quant_cuda_kernel (vllm-project#6901) [Core] Reduce unnecessary compute when logprobs=None (vllm-project#6532) [Kernel] Tuned FP8 Kernels for Ada Lovelace (vllm-project#6677) [Model] Initialize support for InternVL2 series models (vllm-project#6514) [Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (vllm-project#6871) Add Nemotron to PP_SUPPORTED_MODELS (vllm-project#6863) [Kernel] Increase precision of GPTQ/AWQ Marlin kernel (vllm-project#6795) [TPU] Reduce compilation time & Upgrade PyTorch XLA version (vllm-project#6856) [Docs] Add RunLLM chat widget (vllm-project#6857) [Model] Initial support for BLIP-2 (vllm-project#5920) [CI/Build][Doc] Update CI and Doc for VLM example changes (vllm-project#6860) ...
) Signed-off-by: Alvant <[email protected]>
) Signed-off-by: LeiWang1999 <[email protected]>
FIX #5793
FIX #6258
This PR increases the precision of Marlin kernel by modifying its "global_reduce" algorithm to use FP32 full-precision reduction, instead of the FP16 half-precision reduction that was used originally. We were able to implement the new FP32 global reduce efficiently, so that it introduces negligible overhead vs the original FP16 reduce.
The key idea is to introduce a temporary FP32 C buffer for the FP32 reduction (and not use the original FP16 C buffer as before). The new FP32 C buffer is limited in size based on the batch size (M dimension) and the potential "max_par" that can be achieved for each specific execution. Internally, each kernel thread-block detects on which thread-column-block it operates, and based on that accesses the appropriate chunk of the temporary C buffer in a fully thread-aligned way (to avoid any bank conflicts or non-contiguous memory reads/stores).
Here are micro-benchmark results for the gptq_marlin_gemm_fp16 vs gptq_marlin_gemm_fp32 (compared vs pytorch_gemm):
End-to-end performance verification on A100 shows max 5% penalty for 8b llama3 and no-penalty for 70b llama3.