[Model] Pooling model activation supports per request control by PoolingParams #20538
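This PR lets pooling requests override the model's default output transform (activation, normalization, or softmax) on a per-request basis through `PoolingParams`. A minimal sketch of the offline usage, following the pattern of the tests added below (the model name and flags are taken from those tests; the exact defaults remain the model's own):

```python
from vllm import LLM, PoolingParams

# Sketch based on the tests in this PR: toggle the classifier's output
# activation per request instead of per model.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", enforce_eager=True, seed=0)
prompts = ["The chef prepared a delicious meal."]

# activation=None keeps the model default; True/False force it on or off
# for this request only.
activated = llm.classify(prompts,
                         pooling_params=PoolingParams(activation=True),
                         use_tqdm=False)
raw = llm.classify(prompts,
                   pooling_params=PoolingParams(activation=False),
                   use_tqdm=False)
print(activated[0].outputs.probs)  # activated class probabilities
print(raw[0].outputs.probs)        # raw scores, activation skipped
```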
Status: Merged

Changes from all commits (37 commits, authored by noooop):
42352ba  +test
3f887ab  Merge branch 'vllm-project:main' into pooler_config
f27031c  + using_normalize
593df3e  + using_activation
3473b05  conflicts
b90f8a7  Merge branch 'vllm-project:main' into pooler_config
8c5f744  + pooling_params
e3bc35a  fix
1a370b5  fix
9d75628  + test_pooling_params.py
5a131d1  + merge_default_parameters
5385b76  Remove unnecessary changes
1946596  Remove unnecessary changes
beead6b  fix
32e4533  Merge branch 'vllm-project:main' into pooler_config
889570a  fix
fa1367e  mypy
efcf72e  fix
8526e2a  + test_reward_models_using_softmax
f644bc2  Merge branch 'vllm-project:main' into pooler_config
2ab4d55  Merge branch 'vllm-project:main' into pooler_config
684a2d9  + test_reward
d5e30e6  - default_normalize & default_softmax
cfa1a3d  + JambaForSequenceClassificationConfig
d0488e7  fix
6668c47  Merge branch 'vllm-project:main' into pooler_config
5274e2f  fix
2ffa834  - merge_default_parameters
9e69222  fix
bd83ada  fix
dab5b55  fix
9129093  fix
0973e6b  fix
f0d6190  using tomaarsen/Qwen3-Reranker-0.6B-seq-cls
e55a342  fix
b3624e1  ci bug ?
a42938c  Merge branch 'vllm-project:main' into pooler_config
New test file (classification, 67 added lines; diff header @@ -0,0 +1,67 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch

from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

from ...models.utils import softmax

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"

prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              max_num_batched_tokens=32768,
              tensor_parallel_size=1,
              gpu_memory_utilization=0.75,
              enforce_eager=True,
              seed=0)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):

    def get_outputs(activation):
        outputs = llm.classify(
            prompts,
            pooling_params=PoolingParams(activation=activation),
            use_tqdm=False)
        return torch.tensor([x.outputs.probs for x in outputs])

    default = get_outputs(activation=None)
    w_activation = get_outputs(activation=True)
    wo_activation = get_outputs(activation=False)

    assert torch.allclose(default, w_activation,
                          atol=1e-2), "Default should use activation."
    assert not torch.allclose(
        w_activation, wo_activation,
        atol=1e-2), "wo_activation should not use activation."
    assert torch.allclose(
        softmax(wo_activation), w_activation, atol=1e-2
    ), "w_activation should be close to activation(wo_activation)."
```
  
    
    
  
  
    
New test file (embedding, 56 added lines; diff header @@ -0,0 +1,56 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch
import torch.nn.functional as F

from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "intfloat/multilingual-e5-small"

prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              max_num_batched_tokens=32768,
              tensor_parallel_size=1,
              gpu_memory_utilization=0.75,
              enforce_eager=True,
              seed=0)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):

    def get_outputs(normalize):
        outputs = llm.embed(prompts,
                            pooling_params=PoolingParams(normalize=normalize),
                            use_tqdm=False)
        return torch.tensor([x.outputs.embedding for x in outputs])

    default = get_outputs(normalize=None)
    w_normal = get_outputs(normalize=True)
    wo_normal = get_outputs(normalize=False)

    assert torch.allclose(default, w_normal,
                          atol=1e-2), "Default should use normal."
    assert not torch.allclose(w_normal, wo_normal,
                              atol=1e-2), "wo_normal should not use normal."
    assert torch.allclose(
        w_normal, F.normalize(wo_normal, p=2, dim=-1),
        atol=1e-2), "w_normal should be close to normal(wo_normal)."
```
  
    
    
  
  
    
New test file (reward models, 66 added lines; diff header @@ -0,0 +1,66 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch

from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

from ...models.utils import softmax

MODEL_NAME = "internlm/internlm2-1_8b-reward"

prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              max_num_batched_tokens=32768,
              tensor_parallel_size=1,
              gpu_memory_utilization=0.75,
              enforce_eager=True,
              trust_remote_code=True,
              seed=0)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):

    def get_outputs(softmax):
        outputs = llm.reward(prompts,
                             pooling_params=PoolingParams(softmax=softmax),
                             use_tqdm=False)
        return torch.cat([x.outputs.data for x in outputs])

    default = get_outputs(softmax=None)
    w_softmax = get_outputs(softmax=True)
    wo_softmax = get_outputs(softmax=False)

    assert torch.allclose(default, w_softmax,
                          atol=1e-2), "Default should use softmax."
    assert not torch.allclose(w_softmax, wo_softmax,
                              atol=1e-2), "wo_softmax should not use softmax."
    assert torch.allclose(
        softmax(wo_softmax), w_softmax,
        atol=1e-2), "w_softmax should be close to softmax(wo_softmax)."
```
  
    
    
  
  
    
New test file (cross-encoder scoring, 69 added lines; diff header @@ -0,0 +1,69 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch

from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

from ...models.utils import softmax

MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              max_num_batched_tokens=32768,
              tensor_parallel_size=1,
              gpu_memory_utilization=0.75,
              enforce_eager=True,
              seed=0)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):

    def get_outputs(activation):
        text_1 = "What is the capital of France?"
        text_2 = "The capital of France is Paris."

        outputs = llm.score(
            text_1,
            text_2,
            pooling_params=PoolingParams(activation=activation),
            use_tqdm=False)
        return torch.tensor([x.outputs.score for x in outputs])

    default = get_outputs(activation=None)
    w_activation = get_outputs(activation=True)
    wo_activation = get_outputs(activation=False)

    assert torch.allclose(default, w_activation,
                          atol=1e-2), "Default should use activation."
    assert not torch.allclose(
        w_activation, wo_activation,
        atol=1e-2), "wo_activation should not use activation."
    assert torch.allclose(
        softmax(wo_activation), w_activation, atol=1e-2
    ), "w_activation should be close to activation(wo_activation)."
```
  
    
(Three additional files in this diff did not render in the page capture.)