
Conversation

@MengqingCao (Collaborator) commented May 15, 2025

What this PR does / why we need it?

This PR fixes the CI failure caused by recent upstream vllm changes.

  1. Add moe_config for fused_moe.
  2. Adapt to the KV cache group change from vllm. vllm-ascend does not support this feature yet, so this is just a quick fix for backward compatibility (see the sketch below).

fix: #872
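For context, here is a minimal sketch of what point 2 boils down to. It is an illustration, not the actual vllm-ascend code: the InputBatch argument names come from the diff in this PR and the linked issue, `vllm_version_is` is the helper used in the diff, and `build_input_batch_kwargs` plus the import path are made up for this sketch.

```python
from vllm_ascend.utils import vllm_version_is  # assumed import path for the helper


def build_input_batch_kwargs(model_config, max_num_tokens, device,
                             max_num_blocks_per_req):
    # Arguments accepted by both old and new vllm (names taken from this PR's diff).
    kwargs = dict(
        max_model_len=model_config.max_model_len,
        max_num_batched_tokens=max_num_tokens,
        device=device,
        pin_memory=True,
    )
    if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
        # Older vllm still expects per-request block bookkeeping here; newer vllm
        # moved it into KV cache groups and removed the argument, which is the
        # TypeError reported in issue #872.
        kwargs["max_num_blocks_per_req"] = max_num_blocks_per_req
    return kwargs
```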

Signed-off-by: MengqingCao <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
Collaborator:

pin_memory=True

cache_config.cache_dtype]

self.attn_metadata_builders: list[AscendAttentionMetadataBuilder] = []
self.attn_backends: list[type[AscendAttentionBackend]] = []
Collaborator:

These two lines are unused.

self.scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
self.device = device
self.pin_memory = True
Collaborator:

This is unused.


self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
self.max_model_len = self.model_config.max_model_len
Collaborator:

This is unused; pass self.model_config.max_model_len to InputBatch directly.
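Combined with the pin_memory suggestion above, the suggested cleanup would look roughly like this at the call site. This is a sketch only: just the arguments visible in the hunks above are shown, and the real InputBatch constructor takes additional parameters.

```python
input_batch = InputBatch(
    max_model_len=self.model_config.max_model_len,  # read directly, no extra attribute
    max_num_batched_tokens=self.max_num_tokens,
    device=self.device,
    pin_memory=True,  # literal value instead of the unused self.pin_memory
)
```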

Signed-off-by: MengqingCao <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
@wangxiyuan (Collaborator)

LGTM. Let's merge this to unblock CI once CI passes. Thanks for the fix.

@MengqingCao (Collaborator, Author)

> LGTM. Let's merge this to unblock CI once CI passes. Thanks for the fix.

Thanks, I made a small change in the latest commit, please help review it.

@MengqingCao changed the title from "[Bugfix][Model] Fix deepseek" to "[Bugfix][Model] Fix fusedmoe and make modelrunner_v1 compatible with latest vllm" on May 16, 2025
@Yikun added the ready (read for review) label on May 16, 2025
@Yikun (Collaborator) commented May 16, 2025

@jianzs @ApsarasX Please take a look

self.local_num_experts = self.global_num_experts
self.expert_map = None

if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
Collaborator:

I think this part of the code may not be needed; refer to how this part was changed in PR 863.

However, the most urgent thing right now is to fix CI; this can be revisited later.

Collaborator:

Yes, let's make CI happy first and fix the bug later.

@wangxiyuan merged commit 7a325b2 into vllm-project:main May 16, 2025
16 checks passed
@MengqingCao deleted the fixds branch May 20, 2025 06:36
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Oct 16, 2025
…latest vllm (vllm-project#867)

Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
…latest vllm (vllm-project#867)


Labels

module:ops, ready (read for review)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: InputBatch.__init__() got an unexpected keyword argument 'max_num_blocks_per_req'

4 participants