
Conversation

alexm-redhat (Collaborator) commented Jul 25, 2024

FIX #5793
FIX #6258

This PR increases the precision of the Marlin kernel by modifying its "global_reduce" algorithm to use a full-precision FP32 reduction instead of the half-precision FP16 reduction used originally. We were able to implement the new FP32 global reduce efficiently, so it introduces negligible overhead compared to the original FP16 reduce.

The key idea is to introduce a temporary FP32 C buffer for the FP32 reduction (rather than reusing the original FP16 C buffer as before). The size of the new FP32 C buffer is bounded by the batch size (M dimension) and the maximum "max_par" achievable for each specific execution. Internally, each kernel thread-block detects which thread-column-block it operates on and, based on that, accesses the appropriate chunk of the temporary C buffer in a fully thread-aligned way (avoiding bank conflicts and non-contiguous memory reads/stores).
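To make the numerical effect concrete, here is a small, illustrative PyTorch sketch (a simplified analogue, not the actual Marlin CUDA code; the shapes and split-K factor are arbitrary examples). It compares keeping the running sum of split-K partial results in FP16 against accumulating them in a temporary FP32 buffer and rounding to FP16 only once at the end:

```python
# Illustrative PyTorch sketch (not the actual Marlin CUDA kernel): why reducing
# split-K partial results in FP32 and rounding to FP16 once at the end is more
# accurate than repeatedly rounding the running sum to FP16.
import torch

torch.manual_seed(0)
M, N, K, splits = 16, 4096, 4096, 8   # example shapes and split-K factor

A = torch.randn(M, K, dtype=torch.float16)
B = torch.randn(K, N, dtype=torch.float16)
ref = A.float() @ B.float()           # FP32 reference result

# Each "thread-block" contributes a partial product over its K-slice;
# global_reduce sums these partials across blocks.
partials = [A[:, i::splits].float() @ B[i::splits, :].float() for i in range(splits)]

c_fp16 = sum(p.half() for p in partials)   # old path: running sum kept in FP16
c_fp32 = sum(partials).half()              # new path: FP32 temporary C buffer, single final FP16 cast

print("max |err|, FP16 reduce:", (c_fp16.float() - ref).abs().max().item())
print("max |err|, FP32 reduce:", (c_fp32.float() - ref).abs().max().item())
```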

Here are micro-benchmark results for gptq_marlin_gemm_fp16 vs gptq_marlin_gemm_fp32 (compared against pytorch_gemm):

[Image: micro-benchmark results for gptq_marlin_gemm_fp16 and gptq_marlin_gemm_fp32 vs pytorch_gemm]

End-to-end performance verification on A100 shows at most a 5% penalty for Llama 3 8B and no penalty for Llama 3 70B.
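For intuition on why the FP32 reduce can be made essentially free, below is a rough, hedged sketch (CPU timings with arbitrary shapes via torch.utils.benchmark, meant only to show relative orders of magnitude, not the real CUDA numbers): the global reduce touches on the order of M·N·num_partials elements, while the GEMM itself does M·N·K work, so the reduction is a small fraction of the total even when carried out in FP32.

```python
# Rough illustration (CPU, arbitrary shapes): the global reduce is a small
# fraction of the total GEMM work, so doing it in FP32 costs little.
import torch
from torch.utils import benchmark

M, N, K, splits = 16, 4096, 4096, 8
A = torch.randn(M, K)
B = torch.randn(K, N)
partials = [torch.randn(M, N) for _ in range(splits)]  # stand-ins for split-K partial results

t_gemm = benchmark.Timer(stmt="A @ B", globals={"A": A, "B": B}).timeit(20)
t_reduce = benchmark.Timer(stmt="torch.stack(partials).sum(0)",
                           globals={"torch": torch, "partials": partials}).timeit(20)

print(f"GEMM:        {t_gemm.mean * 1e3:.3f} ms")
print(f"FP32 reduce: {t_reduce.mean * 1e3:.3f} ms")
```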


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

alexm-redhat (Collaborator, Author)

/ready

casper-hansen (Contributor)

Do you have any data showcasing that this fixes the described accuracy issues? I'm guessing you could look at the examples referenced in the issue? And potentially run an evaluation of perplexity?

alexm-redhat (Collaborator, Author)

@casper-hansen We ran an internal test on an A10 with the GSM dataset to verify accuracy: with the old FP16 reduce the total accuracy was 58%, and with the new FP32 reduce it was 73% (as it is supposed to be). Also, in the unit tests, the max difference improved from ~1e-3 to ~1e-6 (roughly twice the number of accurate digits).
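For anyone who wants to reproduce this kind of check, below is a hedged sketch of a GSM8K accuracy run through lm-eval-harness's vLLM backend (the model name, the quantization argument, and the lm-eval API shown are assumptions for illustration, not the exact internal setup used for the numbers above):

```python
# Hedged sketch: one way to run a GSM8K accuracy check against a GPTQ model
# served by vLLM via lm-eval-harness (assumes lm-eval >= 0.4 with its vLLM
# backend installed; model name and quantization argument are illustrative).
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=TheBloke/Llama-2-7B-Chat-GPTQ,"
        "quantization=gptq_marlin,dtype=float16"
    ),
    tasks=["gsm8k"],
)
print(results["results"]["gsm8k"])
```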

mgoin (Member) left a comment

Great work Alex, I'm glad it didn't require paring down split-k or any drastic changes.

Do you think this should be added to fp8 marlin as well? We could also just wait for @LucasWilkinson's type refactor.

alexm-redhat (Collaborator, Author)

I think it is not critical to add this to fp8 marlin, since we have not had as many reports of accuracy issues as we had for GPTQ and especially AWQ (which is even more sensitive than GPTQ).

@mgoin mgoin merged commit 75acdaa into vllm-project:main Jul 27, 2024
@mgoin mgoin deleted the marlin_high_precision branch July 27, 2024 21:52
danieldk added a commit to huggingface/text-generation-inference that referenced this pull request Jul 29, 2024
Sync with upstream change that improves the precision of the
'global_reduce' algorithm from FP16 to FP32. This solves some
reported generation quality issues.

Upstream issue/PR:

vllm-project/vllm#6795
tjohnson31415 added a commit to tjohnson31415/vllm that referenced this pull request Jul 30, 2024
* upstream/main: (66 commits)
  [Bugfix] Fix PaliGemma MMP (vllm-project#6930)
  [TPU] Fix greedy decoding (vllm-project#6933)
  [Kernel] Tuned int8 kernels for Ada Lovelace (vllm-project#6848)
  [Kernel] Fix marlin divide-by-zero warnings (vllm-project#6904)
  [ci] GHA workflow to remove ready label upon "/notready" comment (vllm-project#6921)
  [Kernel] Remove unused variables in awq/gemm_kernels.cu (vllm-project#6908)
  [Frontend] New `allowed_token_ids` decoding request parameter (vllm-project#6753)
  [Bugfix] Allow vllm to still work if triton is not installed. (vllm-project#6786)
  [TPU] Support tensor parallelism in async llm engine (vllm-project#6891)
  [Kernel] Fix deprecation function warnings squeezellm quant_cuda_kernel (vllm-project#6901)
  [Core] Reduce unnecessary compute when logprobs=None (vllm-project#6532)
  [Kernel] Tuned FP8 Kernels for Ada Lovelace (vllm-project#6677)
  [Model] Initialize support for InternVL2 series models (vllm-project#6514)
  [Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (vllm-project#6871)
  Add Nemotron to PP_SUPPORTED_MODELS (vllm-project#6863)
  [Kernel] Increase precision of GPTQ/AWQ Marlin kernel (vllm-project#6795)
  [TPU] Reduce compilation time & Upgrade PyTorch XLA version  (vllm-project#6856)
  [Docs] Add RunLLM chat widget (vllm-project#6857)
  [Model] Initial support for BLIP-2 (vllm-project#5920)
  [CI/Build][Doc] Update CI and Doc for VLM example changes (vllm-project#6860)
  ...
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
