
Conversation

666even666
Contributor

@666even666 666even666 commented Aug 29, 2025

Purpose

This PR enables data parallelism (DP) for the InternVL vision encoder.

FIX #23876

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@666even666 666even666 requested a review from hmellor as a code owner August 29, 2025 07:57
@mergify mergify bot added the documentation Improvements or additions to documentation label Aug 29, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces data parallelism for the InternVL vision encoder, a valuable optimization for multi-modal models. The implementation correctly uses a use_data_parallel flag to switch between tensor-parallel and replicated linear layers within the vision transformer's attention and MLP blocks. The forward pass for the data-parallel mode is handled by run_dp_sharded_vision_model, which correctly shards the input batch across the available GPUs. The changes are well-contained, follow existing design patterns within the vLLM codebase, and include the necessary updates to the model configuration and documentation. The implementation appears solid and I have no major concerns.
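For context, a minimal sketch of the use_data_parallel switch described above (a simplified illustration, not the PR's actual diff; it assumes the parallel linear layers in vllm.model_executor.layers.linear, and the real code also threads through quant_config, prefixes, and the attention block):

import torch.nn as nn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)

class InternMLPSketch(nn.Module):
    """Simplified MLP block showing the TP-vs-DP layer switch."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 use_data_parallel: bool = False):
        super().__init__()
        if use_data_parallel:
            # DP mode: every rank keeps the full weights and processes its
            # own shard of the image batch (no per-layer gather/reduce).
            self.fc1 = ReplicatedLinear(hidden_size, intermediate_size)
            self.fc2 = ReplicatedLinear(intermediate_size, hidden_size)
        else:
            # TP mode: weights are sharded across ranks and the parallel
            # layers handle the activation gather/reduce.
            self.fc1 = ColumnParallelLinear(hidden_size, intermediate_size)
            self.fc2 = RowParallelLinear(intermediate_size, hidden_size)
        self.act = nn.GELU()

    def forward(self, x):
        x, _ = self.fc1(x)  # vLLM linear layers return (output, bias)
        x = self.act(x)
        x, _ = self.fc2(x)
        return x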

Signed-off-by: Yiwen Chen <[email protected]>
Signed-off-by: Yiwen Chen <[email protected]>
@DarkLight1337
Member

Thanks, can you attach benchmark results to show the performance improvement?

Signed-off-by: Yiwen Chen <[email protected]>
Signed-off-by: Yiwen Chen <[email protected]>
@666even666
Contributor Author

666even666 commented Aug 31, 2025

Thanks, can you attach benchmark results to show the performance improvement?

Hi @DarkLight1337, I tested with InternVL2-1B and DP size 2 but did not see a significant improvement. Should I try a larger DP size? Let me know if you have any insights :)

Test Plan

TP:

vllm serve OpenGVLab/InternVL2-1B --gpu-memory-utilization 0.9 --trust-remote-code --tensor-parallel-size 2 --host 0.0.0.0 --port 20001

DP:

vllm serve OpenGVLab/InternVL2-1B --gpu-memory-utilization 0.9 --trust-remote-code --tensor-parallel-size 2 --host 0.0.0.0 --port 20001 --mm-encoder-tp-mode "data"

Benchmark:

python3 benchmark_serving.py --backend openai-chat --base-url http://0.0.0.0:20001 --endpoint /v1/chat/completions --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000 --max-concurrency 64 --model OpenGVLab/InternVL2-1B --seed 12345

Test Result

TP

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  593.97    
Total input tokens:                      91122     
Total generated tokens:                  107579    
Request throughput (req/s):              1.68      
Output token throughput (tok/s):         181.12    
Total Token throughput (tok/s):          334.53    
---------------Time to First Token----------------
Mean TTFT (ms):                          36396.88  
Median TTFT (ms):                        36344.09  
P99 TTFT (ms):                           49753.34  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.81      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           48.89     
---------------Inter-token Latency----------------
Mean ITL (ms):                           467.36    
Median ITL (ms):                         0.05      
P99 ITL (ms):                            4658.09   
==================================================

DP:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  593.08    
Total input tokens:                      91122     
Total generated tokens:                  106559    
Request throughput (req/s):              1.69      
Output token throughput (tok/s):         179.67    
Total Token throughput (tok/s):          333.31    
---------------Time to First Token----------------
Mean TTFT (ms):                          36225.52  
Median TTFT (ms):                        35993.62  
P99 TTFT (ms):                           50118.88  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.12      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           56.93     
---------------Inter-token Latency----------------
Mean ITL (ms):                           501.50    
Median ITL (ms):                         0.06      
P99 ITL (ms):                            4504.30   
==================================================


@ZJY0516
Contributor

ZJY0516 commented Aug 31, 2025

@666even666 Could you reduce the max-concurrency value and try again?

@666even666
Contributor Author

@666even666 Could you reduce the max-concurrency value and try again?

I lowered max-concurrency to 32; here are the results:

TP(baseline)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             32        
Benchmark duration (s):                  71.24     
Total input tokens:                      91122     
Total generated tokens:                  107756    
Request throughput (req/s):              14.04     
Output token throughput (tok/s):         1512.65   
Total Token throughput (tok/s):          2791.80   
---------------Time to First Token----------------
Mean TTFT (ms):                          996.54    
Median TTFT (ms):                        645.73    
P99 TTFT (ms):                           5228.27   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.87     
Median TPOT (ms):                        10.95     
P99 TPOT (ms):                           56.80     
---------------Inter-token Latency----------------
Mean ITL (ms):                           46.46     
Median ITL (ms):                         4.09      
P99 ITL (ms):                            489.11    
==================================================

DP

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             32        
Benchmark duration (s):                  68.22     
Total input tokens:                      91122     
Total generated tokens:                  107063    
Request throughput (req/s):              14.66     
Output token throughput (tok/s):         1569.45   
Total Token throughput (tok/s):          2905.22   
---------------Time to First Token----------------
Mean TTFT (ms):                          997.19    
Median TTFT (ms):                        602.48    
P99 TTFT (ms):                           5671.35   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.90     
Median TPOT (ms):                        9.95      
P99 TPOT (ms):                           60.27     
---------------Inter-token Latency----------------
Mean ITL (ms):                           52.24     
Median ITL (ms):                         5.44      
P99 ITL (ms):                            583.22    
==================================================

Total Token throughput slightly improved. I wonder whether a TP size of 2 is too small to see any significant improvement?

@ZJY0516
Contributor

ZJY0516 commented Sep 2, 2025

Here is my benchmark result, and I did not see a significant improvement either. It's a little odd that the Time per Output Token values are so different.

vllm bench serve \
--backend openai-chat \
--endpoint-type openai-chat \
--endpoint /v1/chat/completions \
--model internvl \
--tokenizer /data/datasets/models-hf/InternVL2-1B/ \
--dataset-name hf \
--dataset-path /data/datasets/datasets-hf/VisionArena-Chat/ \
--hf-name lmarena-ai/VisionArena-Chat \
--hf-split train \
--num-prompts 1000 \
--max-concurrency 64 \
--seed 12345

DP

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  95.58     
Total input tokens:                      91122     
Total generated tokens:                  106912    
Request throughput (req/s):              10.46     
Output token throughput (tok/s):         1118.56   
Total Token throughput (tok/s):          2071.92   
---------------Time to First Token----------------
Mean TTFT (ms):                          2852.64   
Median TTFT (ms):                        1984.55   
P99 TTFT (ms):                           10975.64  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          31.94     
Median TPOT (ms):                        33.56     
P99 TPOT (ms):                           83.64     
---------------Inter-token Latency----------------
Mean ITL (ms):                           83.17     
Median ITL (ms):                         45.35     
P99 ITL (ms):                            915.63    
==================================================

TP

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  140.93    
Total input tokens:                      91122     
Total generated tokens:                  107282    
Request throughput (req/s):              7.10      
Output token throughput (tok/s):         761.24    
Total Token throughput (tok/s):          1407.81   
---------------Time to First Token----------------
Mean TTFT (ms):                          2841.10   
Median TTFT (ms):                        1993.46   
P99 TTFT (ms):                           9014.42   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          61.77     
Median TPOT (ms):                        61.12     
P99 TPOT (ms):                           191.64    
---------------Inter-token Latency----------------
Mean ITL (ms):                           101.10    
Median ITL (ms):                         61.68     
P99 ITL (ms):                            727.56    
==================================================

When decreasing the max concurrency, we can see a TTFT improvement.

Maximum request concurrency: 8

DP

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             8         
Benchmark duration (s):                  133.97    
Total input tokens:                      91122     
Total generated tokens:                  107498    
Request throughput (req/s):              7.46      
Output token throughput (tok/s):         802.40    
Total Token throughput (tok/s):          1482.57   
---------------Time to First Token----------------
Mean TTFT (ms):                          371.40    
Median TTFT (ms):                        216.69    
P99 TTFT (ms):                           5281.09   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.89      
Median TPOT (ms):                        6.12      
P99 TPOT (ms):                           20.37     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.56     
Median ITL (ms):                         3.24      
P99 ITL (ms):                            92.97     
==================================================

TP

============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             8         
Benchmark duration (s):                  183.17    
Total input tokens:                      91122     
Total generated tokens:                  107101    
Request throughput (req/s):              5.46      
Output token throughput (tok/s):         584.71    
Total Token throughput (tok/s):          1082.19   
---------------Time to First Token----------------
Mean TTFT (ms):                          506.53    
Median TTFT (ms):                        350.34    
P99 TTFT (ms):                           3524.97   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.89      
Median TPOT (ms):                        8.37      
P99 TPOT (ms):                           56.36     
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.04     
Median ITL (ms):                         3.23      
P99 ITL (ms):                            152.18    
==================================================


mergify bot commented Sep 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @666even666.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 2, 2025
@mergify mergify bot removed the needs-rebase label Sep 3, 2025
@666even666
Contributor Author

666even666 commented Sep 3, 2025

Here is my benchmark result and I did not see a significant improvement either. […]

[Quoted from the previous comment: the benchmark command plus the full DP/TP results at max concurrency 64 and 8.]

@ZJY0516 Thanks a lot for running the test! This looks much better than what I had with a DP size of 2. Looks like the performance improvement is in line with other multimodal models (#23168). @DarkLight1337 If the changes look good to you, can we merge it?

BTW @ZJY0516, what DP size are you using?

@ZJY0516
Contributor

ZJY0516 commented Sep 3, 2025

dp = 2. It seems GitHub's renderer isn't displaying the content inside the blockquotes correctly. @666even666

@ZJY0516
Contributor

ZJY0516 commented Sep 3, 2025

I will try to profile this today.
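(For reference, per-module timings like the ones posted below can be collected with the standard torch.profiler API; this is a generic sketch, not necessarily the exact setup used here.)

import torch
from torch.profiler import ProfilerActivity, profile

def profile_forward(model, inputs):
    # Wrap a single forward pass and record CPU/CUDA activity.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True, with_stack=True) as prof:
        with torch.no_grad():
            model(**inputs)
    # Summary table of the hottest ops; the exported Chrome trace contains
    # per-module Name/Duration entries like those quoted below.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
    prof.export_chrome_trace("trace.json")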

@ZJY0516
Contributor

ZJY0516 commented Sep 3, 2025

Benchmark Settings

vllm serve /data/datasets/models-hf/InternVL2-1B/ --served-model-name internvl --trust-remote-code -tp 2 --mm-encoder-tp-mode "data"

vllm bench serve \
--backend openai-chat \
--endpoint-type openai-chat \
--endpoint /v1/chat/completions \
--model internvl \
--tokenizer /data/datasets/models-hf/InternVL2-1B/ \
--dataset-name hf \
--dataset-path /data/datasets/datasets-hf/VisionArena-Chat/ \
--hf-name lmarena-ai/VisionArena-Chat \
--hf-split train \
--num-prompts 128 \
--max-concurrency 64 \
--seed 12345

Result

encoder dp

============ Serving Benchmark Result ============
Successful requests:                     128       
Maximum request concurrency:             64        
Benchmark duration (s):                  14.96     
Total input tokens:                      12855     
Total generated tokens:                  13979     
Request throughput (req/s):              8.56      
Output token throughput (tok/s):         934.68    
Total Token throughput (tok/s):          1794.20   
---------------Time to First Token----------------
Mean TTFT (ms):                          3863.60   
Median TTFT (ms):                        2970.99   
P99 TTFT (ms):                           10435.49  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          36.93     
Median TPOT (ms):                        29.94     
P99 TPOT (ms):                           161.00    
---------------Inter-token Latency----------------
Mean ITL (ms):                           67.99     
Median ITL (ms):                         8.72      
P99 ITL (ms):                            491.97    
==================================================

normal

============ Serving Benchmark Result ============
Successful requests:                     128       
Maximum request concurrency:             64        
Benchmark duration (s):                  18.23     
Total input tokens:                      12855     
Total generated tokens:                  14216     
Request throughput (req/s):              7.02      
Output token throughput (tok/s):         779.73    
Total Token throughput (tok/s):          1484.81   
---------------Time to First Token----------------
Mean TTFT (ms):                          3850.05   
Median TTFT (ms):                        4559.03   
P99 TTFT (ms):                           9275.80   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          45.68     
Median TPOT (ms):                        43.43     
P99 TPOT (ms):                           121.66    
---------------Inter-token Latency----------------
Mean ITL (ms):                           68.14     
Median ITL (ms):                         25.66     
P99 ITL (ms):                            306.18    
==================================================

The profile result is a little weird.
Encoder DP mode does show a performance improvement and a better median TTFT (2970.99 ms for DP vs 4559.03 ms for TP).
However, it performs worse in P99 TTFT. What I find puzzling is why we observe an improvement at a larger scale (num-prompts = 100).

According to the profile data below, encoder DP mode has better performance in the language model forward pass.

Additionally, my benchmark results indicate that encoder DP mode leads to an end-to-end speedup (i.e., a decrease in benchmark duration), which contradicts the findings from @666even666. Could you verify it, @666even666?

CC @DarkLight1337 @ywang96

vision model part

dp

vllm/model_executor/models/internvl.py(1299): get_multimodal_embeddings | start 00:00:00.314140142 | duration 19ms 575µs 812ns
nn.Module: InternParallelAttention_5 | start 00:00:00.320102836 | duration 331µs 936ns
nn.Module: InternMLP_5 | start 00:00:00.320507377 | duration 101µs 979ns

tp

vllm/model_executor/models/internvl.py(1299): get_multimodal_embeddings | start 00:00:00.366545328 | duration 22ms 963µs 501ns
nn.Module: InternParallelAttention_5 | start 00:00:00.372874457 | duration 420µs 999ns
nn.Module: InternMLP_5 | start 00:00:00.373366766 | duration 173µs 110ns

language part

dp

nn.Module: InternVLChatModel_0 (slow case) | start 00:00:01.825283774 | duration 16ms 528µs 300ns
nn.Module: InternVLChatModel_0 (fast case) | start 00:00:02.160551260 | duration 4ms 665µs 84ns

tp

nn.Module: InternVLChatModel_0 (slow case) | start 00:00:04.275143076 | duration 22ms 845µs 40ns
nn.Module: InternVLChatModel_0 (fast case) | start 00:00:04.625940359 | duration 6ms 606µs 901ns

@DarkLight1337
Member

DarkLight1337 commented Sep 3, 2025

Can you try increasing the number of prompts? Usually I use 500 or 1000. I think 128 might not be enough to get reliable results.

@ZJY0516
Contributor

ZJY0516 commented Sep 3, 2025

The profile data will be quite large, but I'll try tomorrow

@ZJY0516
Contributor

ZJY0516 commented Sep 3, 2025

Random-mm dataset: each request has 3 images.

command

vllm bench serve \
--endpoint-type openai-chat \
--endpoint /v1/chat/completions \
--model internvl \
--dataset-name random-mm \
--random-mm-base-items-per-request 3 \
--tokenizer /data/datasets/models-hf/InternVL2-1B/ --trust-remote-code \
--num-prompts 1024 \
--max-concurrency 64

dp

============ Serving Benchmark Result ============
Successful requests:                     1024      
Maximum request concurrency:             64        
Benchmark duration (s):                  549.43    
Total input tokens:                      1045858   
Total generated tokens:                  57097     
Request throughput (req/s):              1.86      
Output token throughput (tok/s):         103.92    
Total Token throughput (tok/s):          2007.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          33417.88  
Median TTFT (ms):                        34372.67  
P99 TTFT (ms):                           47424.23  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.20      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           1.96      
---------------Inter-token Latency----------------
Mean ITL (ms):                           18.85     
Median ITL (ms):                         0.02      
P99 ITL (ms):                            71.55     
==================================================

tp

============ Serving Benchmark Result ============
Successful requests:                     1024      
Maximum request concurrency:             64        
Benchmark duration (s):                  547.59    
Total input tokens:                      1045858   
Total generated tokens:                  57052     
Request throughput (req/s):              1.87      
Output token throughput (tok/s):         104.19    
Total Token throughput (tok/s):          2014.12   
---------------Time to First Token----------------
Mean TTFT (ms):                          33228.48  
Median TTFT (ms):                        33879.10  
P99 TTFT (ms):                           44226.49  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1.50      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           63.73     
---------------Inter-token Latency----------------
Mean ITL (ms):                           107.25    
Median ITL (ms):                         0.02      
P99 ITL (ms):                            4806.17   
==================================================

@ZJY0516
Contributor

ZJY0516 commented Sep 4, 2025

Can you try increasing the number of prompts? Usually I use 500 or 1000. I think 128 might not be enough to get reliable results.

However, shouldn't the execution time of each operator remain consistent regardless of the number of prompts? The benchmark results I obtained with 128 prompts are similar to those from 1024 prompts.

@ZJY0516
Contributor

ZJY0516 commented Sep 4, 2025

A more detailed benchmark result

lmarena-ai/VisionArena-Chat

vllm bench serve \
--backend openai-chat \
--endpoint-type openai-chat \
--endpoint /v1/chat/completions \
--model internvl \
--tokenizer /data/datasets/models-hf/InternVL2-1B/ \
--dataset-name hf \
--dataset-path /data/datasets/datasets-hf/VisionArena-Chat/ \
--hf-name lmarena-ai/VisionArena-Chat \
--hf-split train \
--num-prompts 1000 \
--max-concurrency 64 \
--seed 12345 \
--percentile-metrics ttft \
--metric-percentiles "10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,99"

encoder dp

Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  98.81     
Total input tokens:                      91122     
Total generated tokens:                  107034    
Request throughput (req/s):              10.12     
Output token throughput (tok/s):         1083.20   
Total Token throughput (tok/s):          2005.37

encoder tp

Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  135.96    
Total input tokens:                      91122     
Total generated tokens:                  106727    
Request throughput (req/s):              7.36      
Output token throughput (tok/s):         784.98    
Total Token throughput (tok/s):          1455.18
[Plot: TTFT percentile comparison, TP vs DP mode, VisionArena-Chat]

random-mm: 3 images

encoder tp

Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  531.70    
Total input tokens:                      1021282   
Total generated tokens:                  55629     
Request throughput (req/s):              1.88      
Output token throughput (tok/s):         104.62    
Total Token throughput (tok/s):          2025.39

encoder dp mode

Successful requests:                     1000      
Maximum request concurrency:             64        
Benchmark duration (s):                  539.03    
Total input tokens:                      1021282   
Total generated tokens:                  55877     
Request throughput (req/s):              1.86      
Output token throughput (tok/s):         103.66    
Total Token throughput (tok/s):          1998.34
[Plot: TTFT percentile comparison, TP vs DP mode, random-mm]

@ywang96
Member

ywang96 commented Sep 4, 2025

@ywang96 Thanks for replying. I am not familiar with the vLLM scheduler policy. How do you determine the max QPS for a vLLM instance?

From the results @ZJY0516
Request throughput (req/s): 1.88

This is not actually related to the scheduler policy, but to the benchmark script. If you set --max-concurrency 64 without --request-rate, the benchmark client simply keeps sending requests to max out 64 concurrent requests on the server; obviously not all of them can be processed at once, so some will wait in the queue.

I suggest running this benchmark without --max-concurrency 64 and instead setting --request-rate to something like 1.8.

@ZJY0516
Contributor

ZJY0516 commented Sep 4, 2025

You are right. Here are the benchmark results at a request rate of 1.8:

encoder dp

============ Serving Benchmark Result ============
Successful requests:                     100       
Request rate configured (RPS):           1.80      
Benchmark duration (s):                  56.22     
Total input tokens:                      11514     
Total generated tokens:                  11068     
Request throughput (req/s):              1.78      
Output token throughput (tok/s):         196.86    
Total Token throughput (tok/s):          401.66    
---------------Time to First Token----------------
Mean TTFT (ms):                          328.13    
Median TTFT (ms):                        130.79    
P99 TTFT (ms):                           2294.19   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          3.59      
Median TPOT (ms):                        3.19      
P99 TPOT (ms):                           7.10      
---------------Inter-token Latency----------------
Mean ITL (ms):                           3.88      
Median ITL (ms):                         2.98      
P99 ITL (ms):                            42.90     
==================================================

encoder tp

============ Serving Benchmark Result ============
Successful requests:                     100       
Request rate configured (RPS):           1.80      
Benchmark duration (s):                  56.41     
Total input tokens:                      11514     
Total generated tokens:                  10901     
Request throughput (req/s):              1.77      
Output token throughput (tok/s):         193.24    
Total Token throughput (tok/s):          397.34    
---------------Time to First Token----------------
Mean TTFT (ms):                          441.91    
Median TTFT (ms):                        186.23    
P99 TTFT (ms):                           2849.08   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.84      
Median TPOT (ms):                        3.35      
P99 TPOT (ms):                           27.16     
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.93      
Median ITL (ms):                         2.94      
P99 ITL (ms):                            56.81     
==================================================

@DarkLight1337
Member

DarkLight1337 commented Sep 4, 2025

Oh, I usually just set infinite QPS for the benchmark.

@ywang96
Member

ywang96 commented Sep 4, 2025

@666even666 Could you run some eval benchmarks to make sure this change does not affect model quality?

@gongshanchong

I am also implementing data parallelism for InternVL. My code is similar to yours, but I found that your code seems to be missing one piece, namely this handling in extract_feature: def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = run_dp_sharded_vision_model(pixel_values, self.vision_model) if self.use_data_parallel else self.vision_model(pixel_values=pixel_values); vit_embeds = vit_embeds[:, 1:, :]. With tp also set to 2, adding this data-sharding step made the performance improvement more obvious (see the cleaned-up sketch below).
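For readability, here is that snippet written out (a sketch based on the comment above; run_dp_sharded_vision_model and self.use_data_parallel come from this PR, while the surrounding class and method names are assumptions):

import torch

def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
    # Sketch of the handling described above; the real method continues with
    # pixel shuffle and the MLP projector after the slicing.
    if self.use_data_parallel:
        # Shard the image batch across DP ranks, run the replicated vision
        # tower on each shard, then gather the embeddings back.
        vit_embeds = run_dp_sharded_vision_model(pixel_values, self.vision_model)
    else:
        vit_embeds = self.vision_model(pixel_values=pixel_values)
    # Drop the [CLS] token.
    vit_embeds = vit_embeds[:, 1:, :]
    return vit_embeds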

@666even666
Contributor Author

I am also implementing data parallelism for InternVL. My code is similar to yours, but I found that your code seems to be missing one piece, namely this handling in extract_feature: def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = run_dp_sharded_vision_model(pixel_values, self.vision_model) if self.use_data_parallel else self.vision_model(pixel_values=pixel_values); vit_embeds = vit_embeds[:, 1:, :]. With tp also set to 2, adding this data-sharding step made the performance improvement more obvious.

Interesting, thanks for letting us know! I put the same logic under forward() of InternVisionModel (https://github.com/vllm-project/vllm/pull/23909/files#diff-2df945bfeea35cb93797dc3b4f8a393a27b06de6d60e2f9b0334920d40721e10R506), which should be executed every time self.vision_model() is called. @DarkLight1337 @ZJY0516 do you think this could cause the performance difference?

@gongshanchong

I tested InternVL3-8B on L20 GPUs using https://github.com/mistralai/mistral-evals, and this is my result:
---------dp2--------------
"explicit_prompt_relaxed_correctness": 0.5711111111111111,
"anywhere_in_answer_relaxed_correctness": 0.5733333333333334

---------tp2--------------
"explicit_prompt_relaxed_correctness": 0.5744444444444444,
"anywhere_in_answer_relaxed_correctness": 0.5766666666666667

@DarkLight1337
Member

Can you run the performance benchmark with infinite QPS and show the output as well?

@666even666
Contributor Author

@666even666 Could you run some eval benchmarks to make sure this change does not affect model quality?

Sorry for the delay! Here is the result of InternVL3-8B tested against ChartQA:

TP
[Screenshot: ChartQA eval result with encoder TP]

DP
[Screenshot: ChartQA eval result with encoder DP]

@gongshanchong

gongshanchong commented Sep 9, 2025

@666even666 Could you run some eval benchmarks to make sure this change does not affect model quality?

Sorry for the delay! Here is the result of InternVL3-8B tested against chartqa

Hello, did you get this result by running that program file?

@666even666
Contributor Author

Yeah, the numbers are exactly the same. I will double-check today.

Signed-off-by: Yiwen Chen <[email protected]>
auto-merge was automatically disabled September 10, 2025 04:47

Head branch was pushed to by a user without write access

@666even666
Contributor Author

666even666 commented Sep 10, 2025

I double-checked that DP was enabled (see the logging screenshot) and still got the same result.
[Screenshot: server log showing DP enabled]
[Screenshot: ChartQA eval result with encoder DP]

@DarkLight1337
Member

Sorry for the delay, let me merge this

@vllm-bot vllm-bot merged commit 52bc9d5 into vllm-project:main Sep 18, 2025
37 of 41 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 18, 2025
…litPR into model_register

* 'model_register' of https://github.com/dsxsteven/vllm_splitPR: (138 commits)
  Retrieve `sliding_window` from text config in Gemma3 MM (vllm-project#25085)
  [Docs] Fix API Reference (vllm-project#25140)
  [Kernel] Better inf handling for grouped topk cu (vllm-project#24886)
  [CLI] Use streaming in CLI chat and completion commands (vllm-project#23769)
  [benchmark] add peak throughput metrics and plot (vllm-project#23867)
  [Spec Decode] Efficient padded speculation (vllm-project#24539)
  [V0 Deprecation] Remove more V0 tests (vllm-project#25117)
  [EPLB] Add EPLB support for hunyuan_v1 (vllm-project#23078)
  [XPU] Whisper model support on XPU Platform (vllm-project#25123)
  Mark prompt logprobs as incompatible with prompt embeds at API level (vllm-project#25077)
  [Model] enable data parallel for InternVL vision encoder (vllm-project#23909)
  [Kernels] Overlap shared experts with combine instead of dispatch (vllm-project#24254)
  [Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models (vllm-project#24960)
  [Core][MM] Cleanup `MultiModalCache` (vllm-project#25006)
  [Docs] Clean up the contributing README (vllm-project#25099)
  [MM Encoder] Apply DP ViT for Qwen3-VL model series (vllm-project#24955)
  [Kernels] Enable DeepGEMM by default (vllm-project#24462)
  [V0 Deprecation] Skip PP test (vllm-project#25128)
  [V0 Deprecation] Remove misc V0 tests (vllm-project#25118)
  [V0 Deprecation] Remove V0 Tracing & Metrics tests (vllm-project#25115)
  ...
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…t#23909)

Signed-off-by: Yiwen Chen <[email protected]>
Signed-off-by: YiwenC <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: charlifu <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…t#23909)

Signed-off-by: Yiwen Chen <[email protected]>
Signed-off-by: YiwenC <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>

Labels

documentation (Improvements or additions to documentation), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MM Encoder] Add Encoder DP to InternVL

6 participants