Conversation

@DarkLight1337 (Member) commented Nov 13, 2025

Purpose

  • Move default values of SchedulerConfig from utils to SchedulerConfig itself.
  • In normal usage of vLLM, max_num_batched_tokens and max_num_seqs are set from the UsageContext when constructing EngineArgs, so there is no point in making the default values of SchedulerConfig dynamic (the pooling- and multimodal-specific defaults were being overwritten anyway). This simplifies the initialization code considerably.
  • Split up the logic of EngineArgs._set_default_args to be more modular.
  • Make SchedulerConfig.chunked_prefill_enabled a property-based alias of SchedulerConfig.enable_chunked_prefill to avoid having to set both when overriding the config in each platform (see the sketch below).
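
As a rough illustration of the direction described above, here is a minimal sketch of static class-level defaults plus the property-based alias. It uses a plain dataclass and the constant names mentioned later in this thread (DEFAULT_MAX_NUM_BATCHED_TOKENS, DEFAULT_MAX_NUM_SEQS); the exact definitions in vLLM may differ.

from dataclasses import dataclass
from typing import ClassVar


@dataclass
class SchedulerConfig:
    # Static class-level defaults, replacing values previously computed in
    # __post_init__ / utils. The numbers mirror those quoted in this thread.
    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    max_num_batched_tokens: int = DEFAULT_MAX_NUM_BATCHED_TOKENS
    max_num_seqs: int = DEFAULT_MAX_NUM_SEQS
    enable_chunked_prefill: bool = True

    # Backward-compatible alias: reads and writes both go through
    # enable_chunked_prefill, so the two flags can never disagree.
    @property
    def chunked_prefill_enabled(self) -> bool:
        return self.enable_chunked_prefill

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None:
        self.enable_chunked_prefill = value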

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@DarkLight1337 DarkLight1337 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 13, 2025
@mergify mergify bot added the tpu label (Related to Google TPUs) Nov 13, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request provides a solid cleanup of the SchedulerConfig initialization. Moving the default value logic from SchedulerConfig.__post_init__ to EngineArgs and refactoring it into new helper methods significantly improves modularity and readability. The use of class variables for default values in SchedulerConfig is also a good practice. I've identified one minor logging issue that should be addressed.

Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 (Member, Author) commented

/gemini review

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request is a good refactoring that cleans up SchedulerConfig initialization by centralizing default values and improving the modularity of EngineArgs. The use of a property for chunked_prefill_enabled is a clean way to handle aliasing and maintain backward compatibility.

However, I've identified a critical issue: the refactoring appears to have removed the specific default max_num_batched_tokens logic for pooling and multimodal models. This could lead to a performance regression for these model types. I've added a detailed comment regarding this. Once this is addressed, the PR will be a strong improvement.

Comment on lines 196 to 223
 def __post_init__(self, is_encoder_decoder: bool) -> None:
-    if self.max_model_len is None:
-        self.max_model_len = 8192
-
-    if self.max_num_seqs is None:
-        self.max_num_seqs = 128
-
     if is_encoder_decoder:
         # Chunked prefill should be disabled for encoder-decoder models.
         self.disable_chunked_mm_input = True
-        self.chunked_prefill_enabled = False
         self.enable_chunked_prefill = False
         self.long_prefill_token_threshold = 0
         logger.info(
             "Encoder-decoder models do not support chunked prefill nor"
             " prefix caching; disabling both."
         )

-    if self.max_num_batched_tokens is None:
-        if self.enable_chunked_prefill:
-            self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
-        else:
-            # If max_model_len is too short, use
-            # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
-            # for higher throughput.
-            self.max_num_batched_tokens = max(
-                self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
-            )
-
-        if self.runner_type == "pooling":
-            # Choose specific value for higher throughput
-            self.max_num_batched_tokens = max(
-                self.max_num_batched_tokens,
-                POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-            )
-        if self.is_multimodal_model:
-            # The value needs to be at least the number of multimodal tokens
-            self.max_num_batched_tokens = max(
-                self.max_num_batched_tokens,
-                MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-            )
-
-        # When using default settings,
-        # Ensure max_num_batched_tokens does not exceed model limit.
-        # Some models (e.g., Whisper) have embeddings tied to max length.
-        self.max_num_batched_tokens = min(
-            self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
-        )
+    if not self.enable_chunked_prefill:
+        # If max_model_len is too short, use the default for higher throughput.
+        self.max_num_batched_tokens = max(
+            self.max_model_len,
+            self.max_num_batched_tokens,
+        )
+
+    # Ensure max_num_batched_tokens does not exceed model limit.
+    # Some models (e.g., Whisper) have embeddings tied to max length.
+    self.max_num_batched_tokens = min(
+        self.max_num_seqs * self.max_model_len,
+        self.max_num_batched_tokens,
+    )

     self.max_num_encoder_input_tokens = self.max_num_batched_tokens
     self.encoder_cache_size = self.max_num_batched_tokens

@gemini-code-assist bot (Contributor) commented

critical

This refactoring simplifies the initialization, but it seems to have removed the special default logic for max_num_batched_tokens for pooling and multimodal models.

Previously, if max_num_batched_tokens was not set by the user or a UsageContext-specific default, there was fallback logic to increase it for pooling models (to 32768) and multimodal models (to 5120) for better throughput. This logic was triggered if max_num_batched_tokens was None when __post_init__ was called.

This logic has now been removed. The justification in the PR description suggests this was dead code, but it appears it would have been triggered if no UsageContext default was found. The new implementation in EngineArgs ensures max_num_batched_tokens is always set, but the specific, higher defaults for pooling/multimodal models are no longer applied anywhere.

Removing this could lead to a significant performance regression for these model types. Could you please confirm if this change is intended? If it's a mistake, this logic should be restored, perhaps within EngineArgs.get_batch_defaults.
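
For reference, a hedged sketch of what restoring that fallback could look like if it were wanted. get_batch_defaults is only the name suggested above, and this helper is hypothetical rather than part of the PR; the thresholds are the values quoted in this comment.

# Thresholds quoted in the review comment above.
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120


def get_batch_defaults(
    max_num_batched_tokens: int,
    runner_type: str,
    is_multimodal_model: bool,
) -> int:
    # Hypothetical helper: raise the batch-size default for pooling and
    # multimodal models, mirroring the removed __post_init__ logic.
    if runner_type == "pooling":
        max_num_batched_tokens = max(
            max_num_batched_tokens, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
        )
    if is_multimodal_model:
        max_num_batched_tokens = max(
            max_num_batched_tokens, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
        )
    return max_num_batched_tokens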

@DarkLight1337 (Member, Author) commented

The case of no UsageContext is not normal usage of vLLM.

@DarkLight1337 (Member, Author) commented Nov 13, 2025

@njhill @WoosukKwon @robertgshaw2-redhat correct me if I'm wrong about this

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
@hmellor (Member) left a comment

This is a really nice change. I've left a few comments for clarification.

Comment on lines +51 to +56
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).
The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
@hmellor (Member) commented

Could we remove this entirely?

@DarkLight1337 (Member, Author) commented

It is used in some other places like vllm.v1.core.sched.Scheduler. We can try to refactor this in another PR.

@hmellor (Member) commented

Yeah, this would need a small refactor. A follow-up PR sounds good.

Comment on lines +235 to +242
@property
def chunked_prefill_enabled(self) -> bool:
    return self.enable_chunked_prefill

@chunked_prefill_enabled.setter
def chunked_prefill_enabled(self, value: bool):
    self.enable_chunked_prefill = value

@hmellor (Member) commented

Can we just remove this? It used to be init=False, so it's not part of the normal API of SchedulerConfig.

@DarkLight1337 (Member, Author) commented

Same as above

 gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
 kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
-max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens
+max_num_batched_tokens: int | None = None
@hmellor (Member) commented

Could you group these fields which don't copy the defaults from their respective config and add a comment saying why?

@DarkLight1337 (Member, Author) commented

We can do that in the next cleanup

@DarkLight1337 DarkLight1337 merged commit 511a6b6 into vllm-project:main Nov 14, 2025
50 checks passed
@DarkLight1337 DarkLight1337 deleted the clean-sched-defaults branch November 14, 2025 14:41
@ZJY0516 ZJY0516 mentioned this pull request Nov 15, 2025
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
@WoosukKwon (Collaborator) commented

@DarkLight1337 I think this PR changes the default values for max_num_batched_tokens and max_num_seqs unexpectedly.

@WoosukKwon (Collaborator) commented

IIUC, the two if statements here are not executed because self.max_num_batched_tokens and self.max_num_seqs are already set to the default values in SchedulerConfig:

vllm/vllm/engine/arg_utils.py

Lines 1988 to 1998 in 3380ed5

if self.max_num_batched_tokens is None:
    self.max_num_batched_tokens = default_max_num_batched_tokens.get(
        usage_context,
        SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
    )
if self.max_num_seqs is None:
    self.max_num_seqs = default_max_num_seqs.get(
        usage_context,
        SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
    )

@hmellor (Member) commented Nov 17, 2025

If the behaviour has changed, I don't think it's because of the changed defaults in SchedulerConfig. The linked if blocks will be triggered if the EngineArgs values are None, which they still are in the default case:

max_num_batched_tokens: int | None = None

max_num_seqs: int | None = None

@DarkLight1337 (Member, Author) commented

Yeah, tests/v1/engine/test_engine_args.py::test_defaults_with_usage_context ensures that the defaults stay the same.

@WoosukKwon (Collaborator) commented

@hmellor @DarkLight1337 On B200, the correct default values (1024 seqs, 8K tokens) are not used. vllm serve uses 128 seqs and 2K tokens instead, which limits performance a lot.

@DarkLight1337 (Member, Author) commented

Can you run tests/v1/engine/test_engine_args.py::test_defaults_with_usage_context on the B200 and see if the test passes?

@WoosukKwon (Collaborator) commented

@DarkLight1337 It passes the test, but that's probably because the usage context is taken into account in the test (while it isn't in vllm serve). And I've confirmed that this PR caused the issue.

@DarkLight1337 (Member, Author) commented

Hmm, shouldn't UsageContext.OPENAI_API_SERVER be used in vllm serve? Or did I misunderstand how UsageContext works?

@WoosukKwon (Collaborator) commented Nov 17, 2025

@DarkLight1337 vllm serve is supposed to use usage_context too. However, this PR introduces a bug where the usage context is not applied.

 if self.max_num_batched_tokens is None: 
     self.max_num_batched_tokens = default_max_num_batched_tokens.get( 
         usage_context, 
         SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, 
     ) 

Here, self.max_num_batched_tokens is never None after this PR for some reason. It is somehow already set to 2048 (the default value in SchedulerConfig).

@DarkLight1337 (Member, Author) commented

OK, I figured out the issue: the CLI defaults are still using the ones from SchedulerConfig.
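
A self-contained illustration of that failure mode; argparse stands in for vLLM's CLI layer, and the names and the 8K usage-context default are simplified from this thread rather than taken from the actual code.

import argparse

DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # SchedulerConfig's static default
USAGE_CONTEXT_DEFAULTS = {"openai_api_server": 8192}

parser = argparse.ArgumentParser()
# Bug: seeding the CLI default from SchedulerConfig means the parsed value is
# already 2048 even when the user passes nothing on the command line...
parser.add_argument(
    "--max-num-batched-tokens", type=int, default=DEFAULT_MAX_NUM_BATCHED_TOKENS
)
args = parser.parse_args([])

# ...so the usage-context override never fires.
if args.max_num_batched_tokens is None:
    args.max_num_batched_tokens = USAGE_CONTEXT_DEFAULTS.get(
        "openai_api_server", DEFAULT_MAX_NUM_BATCHED_TOKENS
    )

print(args.max_num_batched_tokens)  # 2048, not the 8192 the usage context intends
# Keeping the CLI default as None lets the usage-context lookup run instead.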


Labels

  • ready: ONLY add when PR is ready to merge/full CI is needed
  • tpu: Related to Google TPUs
  • v1
