
Conversation

@baonudesifeizhai (Contributor) commented Sep 28, 2025

Key Changes

  • Preserve user-specified splitting_ops: When use_inductor_graph_partition=True, user-provided splitting_ops are now preserved in _user_specified_splitting_ops instead of being cleared (see the usage sketch after this list)
  • Dynamic partition rules: Implement _setup_dynamic_partition_rules() using PyTorch 2.9+'s register_should_partition_rule API to register custom partition points
  • Robust fallback mechanism: If dynamic rule registration fails for any user-specified operation, the system falls back to traditional splitting behavior with preserved splitting_ops
  • Comprehensive logging: Add detailed debug logging to track partition rule setup, registration success/failure, and fallback behavior
  • PyTorch version compatibility: Update requirements to support PyTorch 2.10.0.dev20250927+cu128 for the new partition rule API
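As a usage-level illustration of the first bullet: with this change, a config like the one below keeps its splitting_ops rather than having them cleared. This is a minimal sketch, not code from the PR; passing compilation_config as a dict to vllm.LLM is assumed from vLLM's offline API.

    # Hypothetical usage sketch: the splitting_ops below now survive
    # use_inductor_graph_partition=True instead of being cleared.
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",
        compilation_config={
            "use_inductor_graph_partition": True,
            "splitting_ops": ["flash_attention", "aten.bmm.default"],
        },
    )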

Technical Implementation

  1. Operation mapping: Map user-provided splitting_ops (including aliases like "flash_attention") to torch._ops.OpOverload objects
  2. Rule registration: Use register_should_partition_rule(op_overload, partition_function) for each resolved operation (see the sketch after this list)
  3. Duplicate prevention: Track registered overloads globally to prevent duplicate registrations
  4. Alias resolution: Handle common operation aliases (e.g., "flash_attention" → "vllm.unified_attention")
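For orientation, here is a compact sketch of how steps 1-4 could fit together. The import path of register_should_partition_rule, its callback signature, and all helper names (_ALIASES, _REGISTERED, _resolve_op, setup_dynamic_partition_rules) are assumptions for illustration; the PR's actual code differs.

    # Hypothetical sketch of the mapping/registration flow, not the PR's code.
    from typing import Optional

    import torch
    # Assumed import path for the PyTorch 2.9+ API (step 2); verify locally.
    from torch._inductor.scheduler import register_should_partition_rule

    # Step 4: alias resolution for user-friendly names.
    _ALIASES = {"flash_attention": "vllm.unified_attention"}
    # Step 3: duplicate prevention across repeated setup calls.
    _REGISTERED: set[torch._ops.OpOverload] = set()

    def _resolve_op(name: str) -> Optional[torch._ops.OpOverload]:
        # Step 1: map a string like "aten.bmm.default" to an OpOverload.
        name = _ALIASES.get(name, name)
        try:
            ns, rest = name.split(".", 1)
            packet_name, _, overload = rest.partition(".")
            packet = getattr(getattr(torch.ops, ns), packet_name)
            return getattr(packet, overload or "default")
        except (AttributeError, ValueError):
            return None

    def setup_dynamic_partition_rules(splitting_ops: list[str]) -> bool:
        # Returns False if any op failed to resolve, so the caller can
        # fall back to traditional splitting with the preserved list.
        ok = True
        for name in splitting_ops:
            op = _resolve_op(name)
            if op is None:
                ok = False
                continue
            if op in _REGISTERED:
                continue
            # Step 2: always partition at this op's call sites. The
            # callback signature (an fx node) is an assumption here.
            register_should_partition_rule(op, lambda node: True)
            _REGISTERED.add(op)
        return ok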

Test Plan

Unit Tests

  • Updated test_splitting_ops_dynamic() to verify splitting_ops preservation behavior
  • Added comprehensive test coverage for dynamic partition rule setup
  • Updated existing compilation tests to reflect new default behavior

Integration Tests

  • Test with various splitting_ops configurations:
    python -m vllm.entrypoints.openai.api_server \
      --model Qwen/Qwen2.5-7B-Instruct \
      --compilation-config '{"use_inductor_graph_partition": true, "splitting_ops": ["flash_attention", "addmm", "aten.bmm.default"]}' \
      --host 0.0.0.0 --port 8000
  • Verify fallback behavior when the PyTorch version doesn't support dynamic rules (a capability probe is sketched after this list)
  • Test with mixed resolvable/unresolvable operations
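One way to drive the fallback check above is a small capability probe. This is a sketch only; the import path is the same assumption as in the registration sketch, and older builds without the API should land in the except branch, exercising the traditional splitting path.

    # Hypothetical probe: choose between dynamic partition rules and the
    # traditional splitting fallback based on API availability.
    def dynamic_partition_rules_available() -> bool:
        try:
            from torch._inductor.scheduler import (  # noqa: F401
                register_should_partition_rule,
            )
            return True
        except ImportError:
            return False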


baonudesifeizhai and others added 19 commits September 26, 2025 20:15
- Add _user_specified_splitting_ops field to store user configuration
- Modify set_splitting_ops_for_inductor_graph_partition to respect user settings
- Add debug logging to track splitting_ops handling
- Addresses issue vllm-project#25691 - partial implementation for dynamic partitioning

This change preserves user-specified splitting_ops when use_inductor_graph_partition=True,
laying groundwork for future PyTorch 2.9+ register_should_partition_rule integration.

- Add _setup_dynamic_partition_rules() method
- Implement register_should_partition_rule integration
- Support both attention ops and user-specified splitting_ops
- Add comprehensive debug logging for partition decisions
- Graceful fallback if PyTorch API not available

This completes the implementation for issue vllm-project#25691
baonudesifeizhai force-pushed the feature/dynamic-inductor-partition-rules branch from 263f0b6 to 56ae27d on October 10, 2025 at 07:37
mergify bot commented Oct 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @baonudesifeizhai.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Oct 10, 2025
mergify bot removed the needs-rebase label Oct 10, 2025
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: baonudesifeizhai <[email protected]>
ProExpertProg enabled auto-merge (squash) October 10, 2025 12:54
ProExpertProg changed the title from "Feature/dynamic inductor partition rules #25691" to "[torch.compile] Make inductor partition rules respect splitting_ops #25691" Oct 10, 2025
ProExpertProg merged commit cddce79 into vllm-project:main Oct 10, 2025
49 checks passed
github-project-automation bot moved this from In review to Done in torch.compile integration Oct 10, 2025
huydhn added a commit to pytorch/pytorch that referenced this pull request Oct 11, 2025
huydhn added a commit to huydhn/pytorch that referenced this pull request Oct 12, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
…llm-project#25691 (vllm-project#25845)

Signed-off-by: baonudesifeizhai <[email protected]>
Signed-off-by: baonudesifeizhai <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: Dhruvil Bhatt <[email protected]>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
…llm-project#25691 (vllm-project#25845)

Signed-off-by: baonudesifeizhai <[email protected]>
Signed-off-by: baonudesifeizhai <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: bbartels <[email protected]>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Oct 24, 2025
### What this PR does / why we need it?
This is step 1 of refactoring the code to stay in sync with vLLM main; this
PR is aligned with
vllm-project/vllm@17c540a

1. Refactor DeepSeek to the latest code architecture as of
vllm-project/vllm@17c540a
 
2. Assorted fixes required by upstream vLLM changes:
- Fix `AscendScheduler` `__post_init__`, caused by
vllm-project/vllm#25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by vllm-project/vllm#26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
vllm-project/vllm#23485
- Fix `MLAAttention` import, caused by
vllm-project/vllm#25103
- Fix `SharedFusedMoE` import, caused by
vllm-project/vllm#26145
- Fix `LazyLoader` import, caused by
vllm-project/vllm#27022
- Fix `vllm.utils.swap_dict_values` import, caused by
vllm-project/vllm#26990
- Fix `Backend` enum import, caused by
vllm-project/vllm#25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by vllm-project/vllm#26355
- Fix fused_moe ops, caused by
vllm-project/vllm#24097
- Fix bert model because of `inputs_embeds`, caused by
vllm-project/vllm#25922
- Fix MRope because of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
vllm-project/vllm#24172
- Fix `splitting_ops` changes introduced by
vllm-project/vllm#25845
- Fix multi-modality changes introduced by
vllm-project/vllm#16229
- Fix lora bias dropping issue introduced by
vllm-project/vllm#25807
- Fix structured output breakage introduced by
vllm-project/vllm#26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing tests.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <[email protected]>
Signed-off-by: Icey <[email protected]>
Co-authored-by: Icey <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…llm-project#25691 (vllm-project#25845)

Signed-off-by: baonudesifeizhai <[email protected]>
Signed-off-by: baonudesifeizhai <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>

Labels

  • llama: Related to Llama models
  • performance: Performance-related issues
  • ready: ONLY add when PR is ready to merge/full CI is needed
  • torch.compile

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Feature]: Inductor partitioning should decide what ops to partition on dynamically
