Skip to content

Conversation

@thechaos16
Copy link
Contributor

@thechaos16 thechaos16 commented Jul 11, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

  • Fix an issue that classify no more available for LlamaForSequenceClassification model.
  • load any LlamaForSequenceClassification model
import torch

from vllm import LLM
from vllm.engine.arg_utils import PoolerConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer

mpath = "Skywork/Skywork-Reward-V2-Llama-3.1-8B"
llm = LLM(model=mpath, task="classify")
  • error message
ERROR 07-11 17:05:32 [core.py:586] EngineCore failed to start.
ERROR 07-11 17:05:32 [core.py:586] Traceback (most recent call last):
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
ERROR 07-11 17:05:32 [core.py:586]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 07-11 17:05:32 [core.py:586]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 404, in __init__
ERROR 07-11 17:05:32 [core.py:586]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 75, in __init__
ERROR 07-11 17:05:32 [core.py:586]     self.model_executor = executor_class(vllm_config)
ERROR 07-11 17:05:32 [core.py:586]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/executor/executor_base.py", line 53, in __init__
ERROR 07-11 17:05:32 [core.py:586]     self._init_executor()
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/executor/uniproc_executor.py", line 48, in _init_executor
ERROR 07-11 17:05:32 [core.py:586]     self.collective_rpc("load_model")
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 07-11 17:05:32 [core.py:586]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 07-11 17:05:32 [core.py:586]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/utils/__init__.py", line 2736, in run_method
ERROR 07-11 17:05:32 [core.py:586]     return func(*args, **kwargs)
ERROR 07-11 17:05:32 [core.py:586]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/gpu_worker.py", line 185, in load_model
ERROR 07-11 17:05:32 [core.py:586]     self.model_runner.load_model()
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1776, in load_model
ERROR 07-11 17:05:32 [core.py:586]     self.model = model_loader.load_model(
ERROR 07-11 17:05:32 [core.py:586]                  ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/model_loader/base_loader.py", line 41, in load_model
ERROR 07-11 17:05:32 [core.py:586]     self.load_weights(model, model_config)
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/model_loader/default_loader.py", line 269, in load_weights
ERROR 07-11 17:05:32 [core.py:586]     loaded_weights = model.load_weights(
ERROR 07-11 17:05:32 [core.py:586]                      ^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/transformers.py", line 509, in load_weights
ERROR 07-11 17:05:32 [core.py:586]     return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
ERROR 07-11 17:05:32 [core.py:586]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/utils.py", line 291, in load_weights
ERROR 07-11 17:05:32 [core.py:586]     autoloaded_weights = set(self._load_module("", self.module, weights))
ERROR 07-11 17:05:32 [core.py:586]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-11 17:05:32 [core.py:586]   File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/utils.py", line 277, in _load_module
ERROR 07-11 17:05:32 [core.py:586]     raise ValueError(msg)
ERROR 07-11 17:05:32 [core.py:586] ValueError: There is no module or parameter named 'score' in TransformersForCausalLM

Test Plan

Test Result

  • After applying these codes
스크린샷 2025-07-11 오후 5 23 34

(Optional) Documentation Update

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @thechaos16, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request extends VLLM's capabilities by introducing full support for LlamaForSequenceClassification models. It resolves an issue where these models could not be properly loaded or used for classification tasks, by adapting the existing Llama causal language model implementation and integrating it into VLLM's model registry. This enhancement allows users to seamlessly load and run Llama-based sequence classification models for their specific needs.

Highlights

  • New Model Support: Added comprehensive support for LlamaForSequenceClassification models, enabling them to be loaded and utilized within the VLLM framework for sequence classification tasks. This addresses a previous limitation where the classify function was not available for these models.
  • Model Adaptation: Implemented an adapter mechanism (as_seq_cls_model) to allow the existing LlamaForCausalLM implementation to function correctly as a LlamaForSequenceClassification model. This reuses the core Llama architecture while providing the necessary classification capabilities.
  • Registry Integration: Registered LlamaForSequenceClassification within VLLM's model registry, making it discoverable and loadable when users specify a Llama-based sequence classification model.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added llama Related to Llama models new-model Requests to new models labels Jul 11, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for LlamaForSequenceClassification by using the as_seq_cls_model adapter, which is a clean and idiomatic way to extend model functionality in this codebase. The changes are straightforward and well-contained. I have one suggestion regarding code style to improve maintainability.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better maintainability and consistency with other model dictionaries in this file (e.g., _TEXT_GENERATION_MODELS), it's recommended to keep the _SEQUENCE_CLASSIFICATION_MODELS dictionary sorted alphabetically by key.

While I can only comment on the changed lines, I'd suggest sorting the entire dictionary. This would make it easier to find models and add new ones in the future.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@DarkLight1337
Copy link
Member

cc @noooop

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please keep this with the other text models. JinaVLForRanking is specially put at the end because it's multimodal, maybe add a code comment like in the other sections

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. I rearranged as you suggested.

@noooop
Copy link
Collaborator

noooop commented Jul 11, 2025

@thechaos16

You need to add LlamaForSequenceClassification in tests/models/registry.py

Fix an issue that classify no more available for LlamaForSequenceClassification model.

Why could it run successfully before? <- working on it

Anyhow, we should automatically support all ForSequenceClassification models, instead of adding them one by one to registry.py

@thechaos16 thechaos16 requested a review from ywang96 as a code owner July 11, 2025 11:35
@thechaos16
Copy link
Contributor Author

@noooop

Thank you for your comment. I've just added LlamaForSequenceClassfication to tests/models/registry.py.

Anyhow, we should automatically support all ForSequenceClassification models, instead of adding them one by one to registry.py

I couldn't agree more. It would be perfect if we could use any kind of SequenceClassification model for classify task.

@noooop
Copy link
Collaborator

noooop commented Jul 12, 2025

@thechaos16 Please fix pre-commit, it is required.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 13, 2025 04:24
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 13, 2025
auto-merge was automatically disabled July 13, 2025 04:36

Head branch was pushed to by a user without write access

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 13, 2025 04:48
@thechaos16
Copy link
Contributor Author

Thank you for the comment. I've fixed pre-commit and squashed all commits to one.

@vllm-bot vllm-bot merged commit bd4c1e6 into vllm-project:main Jul 13, 2025
64 of 69 checks passed
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

llama Related to Llama models new-model Requests to new models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants