
Conversation

@LucasWilkinson (Collaborator) commented Mar 13, 2025

Based on these calculations:

https://docs.google.com/spreadsheets/d/17eoqEbhblvtNsRRlFSjCQnEXZiBxtLgZGKD4IgZUz38/edit?usp=sharing

It's actually better not to materialize the absorbed W_Q_UK and W_UV_O: skipping materialization reduces memory usage (and total FLOPs), and we compute with sequential matmuls instead. One issue is that we do not have an FP8 bmm, which is needed when the absorbed matrices are not materialized (materializing them let us bypass this), so we instead decompress the matrices involved in the bmm to fp16/bf16. This also has the added benefit of dramatically reducing complexity.
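
As a rough illustration, here is a minimal sketch of the sequential-matmul path (shapes and names such as W_UK/W_UV are illustrative, not the exact vLLM implementation):

```python
# Hypothetical sketch: sequential bf16 bmms instead of one materialized absorbed weight.
import torch

B, H = 4, 128                         # decode tokens, attention heads
D_nope, D_latent, D_v = 128, 512, 128

q_nope = torch.randn(B, H, D_nope, dtype=torch.bfloat16)           # after the q up-projection
W_UK   = torch.randn(H, D_nope, D_latent, dtype=torch.bfloat16)    # per-head K up-proj (dequantized from FP8)
W_UV   = torch.randn(H, D_latent, D_v, dtype=torch.bfloat16)       # per-head V up-proj (dequantized from FP8)

# Previously: pre-multiply the q projection with W_UK into one large absorbed
# weight (W_Q_UK); this avoids the bmm but costs extra memory and FLOPs.

# This PR: keep the factors and apply them sequentially as bf16 bmms.
q_latent = torch.bmm(q_nope.transpose(0, 1), W_UK).transpose(0, 1)     # [B, H, D_latent]
# ... attention runs in the compressed latent space, producing attn_out ...
attn_out = torch.randn(B, H, D_latent, dtype=torch.bfloat16)
v_out    = torch.bmm(attn_out.transpose(0, 1), W_UV).transpose(0, 1)   # [B, H, D_v]
```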

This PR is needed for DP attention since, without it, the weight materialization eats up too much GPU memory for DP to be beneficial.
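
For a rough sense of scale, a back-of-the-envelope estimate using DeepSeek-R1 config values (the absorbed layout below is an assumption for illustration, not the exact vLLM shapes):

```python
# Back-of-the-envelope weight-memory estimate; layout details are assumed.
num_layers, num_heads = 61, 128
q_lora_rank, kv_lora_rank, d_nope = 1536, 512, 128
bytes_per_elem = 2  # bf16

# Factored (this PR): per-head W_UQ (nope part) and W_UK kept separate.
factored = num_heads * d_nope * (q_lora_rank + kv_lora_rank)
# Absorbed W_Q_UK: the d_nope dimension is contracted away, leaving a
# q_lora_rank x kv_lora_rank block per head, which is much larger.
absorbed = num_heads * q_lora_rank * kv_lora_rank

to_gib = lambda n: n * num_layers * bytes_per_elem / 2**30
print(f"factored: {to_gib(factored):.1f} GiB")   # ~3.8 GiB
print(f"absorbed: {to_gib(absorbed):.1f} GiB")   # ~11.4 GiB (Q side only; W_UV_O adds more)
```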

This PR (there is a minor regression at short context, but it seems worth it: the saved memory boosts long-context performance and enables DP attention, and the short-context measurements are a bit noisy):

| backend | input_tokens | output_tokens | output_toks/s | req/s | median_itl_ms | median_ttft_ms |
|---------|-------------:|--------------:|--------------:|------:|--------------:|---------------:|
| vllm | 1000 | 1000 | 1323.397915 | 1.323398 | 29.755860 | 2307.676603 |
| vllm | 5000 | 1000 | 1041.455043 | 1.041455 | 33.205620 | 5491.457423 |
| vllm | 10000 | 1000 | 874.563079 | 0.874563 | 36.871404 | 8508.498527 |
| vllm | 32000 | 1000 | 190.195698 | 0.190196 | 35.948243 | 108055.433401 |

Baseline (#14769)

| backend | input_tokens | output_tokens | output_toks/s | req/s | median_itl_ms | median_ttft_ms |
|---------|-------------:|--------------:|--------------:|------:|--------------:|---------------:|
| vllm | 1000 | 1000 | 1380.973383 | 1.380973 | 30.029079 | 2223.776098 |
| vllm | 5000 | 1000 | 1047.717832 | 1.047718 | 33.303024 | 5499.557093 |
| vllm | 10000 | 1000 | 586.460476 | 0.586460 | 36.855329 | 8512.637936 |
| vllm | 32000 | 1000 | 162.816157 | 0.162816 | 42.981104 | 115515.806204 |

correctness tests:

VLLM_USE_V1=1  vllm serve /home/vllm-dev/DeepSeek-R1 --tensor-parallel-size 8 --trust-remote-code --disable-log-requests

lm_eval --model local-completions --tasks gsm8k --model_args model=/home/vllm-dev/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=5,max_retries=3,tokenized_requests=False --limit 100

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.96|±  |0.0197|
|     |       |strict-match    |     5|exact_match|↑  | 0.96|±  |0.0197|
VLLM_USE_V1=0 vllm serve /home/vllm-dev/DeepSeek-R1 --tensor-parallel-size 8 --trust-remote-code --disable-log-requests

lm_eval --model local-completions --tasks gsm8k --model_args model=/home/vllm-dev/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=5,max_retries=3,tokenized_requests=False --limit 100

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.96|±  |0.0197|
|     |       |strict-match    |     5|exact_match|↑  | 0.96|±  |0.0197|

@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to catch errors quickly. You can run the remaining CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 13, 2025
@LucasWilkinson LucasWilkinson changed the title MLA get rid of materialize [Attention] MLA get rid of materialize Mar 13, 2025
Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/get-rid-of-materialize branch from d02b6b7 to 98cdf57 on March 13, 2025 21:37
Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson LucasWilkinson changed the title [Attention] MLA get rid of materialize [Attention] MLA get rid of materialization Mar 13, 2025
@simon-mo (Collaborator) left a comment

Makes sense for the memory savings.

@simon-mo simon-mo enabled auto-merge (squash) March 14, 2025 03:54
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 14, 2025
@simon-mo simon-mo added this to the v0.8.0 milestone Mar 14, 2025
@vllm-bot vllm-bot merged commit 9532c49 into main Mar 14, 2025
51 of 55 checks passed
@vllm-bot vllm-bot deleted the lwilkinson/get-rid-of-materialize branch March 14, 2025 06:39
jikunshang added a commit to jikunshang/vllm that referenced this pull request Mar 14, 2025
richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Mar 14, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
@erictanjn commented

Hello, may I know what the parallel configuration of the baseline and your test was? I am also exploring the comparison before and after enabling matrix-absorption preconditioning.
