3 changes: 0 additions & 3 deletions tests/singlecard/spec_decode/test_spec_decode_worker.py
@@ -589,7 +589,6 @@ def test_empty_input_batch(k: int, batch_size: int,

@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
@@ -646,7 +645,6 @@ def test_initialize_cache(acceptance_sampler_method):
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
@@ -685,7 +683,6 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
@pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
4 changes: 1 addition & 3 deletions vllm_ascend/__init__.py
@@ -18,14 +18,12 @@

def register():
    """Register the NPU platform."""

    return "vllm_ascend.platform.NPUPlatform"


def register_model():
    # fix pytorch schema check error, remove this line after pytorch
    # is upgraded to 2.7.0
    import vllm_ascend.patch.worker.patch_common.patch_utils # noqa: F401

    from .models import register_model
    from vllm_ascend.models import register_model
    register_model()
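For context, the two hooks above are not imported by user code; vLLM discovers them through setuptools entry points declared by the plugin package. A minimal sketch of that wiring (the entry-point group names follow vLLM's plugin mechanism; the plugin names and exact packaging shown here are illustrative and not taken from this diff):

# setup.py (sketch only): how vLLM can find register() and register_model().
from setuptools import setup

setup(
    name="vllm_ascend",
    entry_points={
        # vLLM resolves the platform plugin and calls register(), which
        # returns the dotted path of the NPUPlatform class.
        "vllm.platform_plugins": ["ascend = vllm_ascend:register"],
        # General plugins run at startup; register_model() fills ModelRegistry.
        "vllm.general_plugins": [
            "ascend_enhanced_model = vllm_ascend:register_model"
        ],
    },
)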
13 changes: 8 additions & 5 deletions vllm_ascend/models/__init__.py
@@ -2,12 +2,15 @@


def register_model():
    from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
    from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
    from .qwen2_5_vl import \
    from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP # noqa: F401
    from vllm_ascend.models.deepseek_v2 import \
        CustomDeepseekV2ForCausalLM # noqa: F401
    from vllm_ascend.models.deepseek_v2 import \
        CustomDeepseekV3ForCausalLM # noqa: F401
    from vllm_ascend.models.qwen2_5_vl import \
        AscendQwen2_5_VLForConditionalGeneration # noqa: F401
    from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
    from vllm_ascend.models.qwen2_vl import \
        AscendQwen2VLForConditionalGeneration # noqa: F401

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
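As a side note on the registration pattern above: vLLM's ModelRegistry.register_model also accepts a lazy "module:ClassName" string instead of an imported class, which avoids importing the model modules eagerly. A hedged sketch of that variant (the module and class names are taken from this diff, but the string-based form is shown only for illustration and is not necessarily what the PR uses):

from vllm import ModelRegistry


def register_model():
    # Lazy registration: vLLM imports the module only when the
    # architecture is actually requested.
    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
    ModelRegistry.register_model(
        "Qwen2_5_VLForConditionalGeneration",
        "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration")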
2 changes: 1 addition & 1 deletion vllm_ascend/models/deepseek_mtp.py
@@ -37,7 +37,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .deepseek_v2 import CustomDeepseekV2DecoderLayer
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2DecoderLayer


class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
39 changes: 31 additions & 8 deletions vllm_ascend/patch/__init__.py
@@ -88,16 +88,39 @@
#
# * Worker Patch:
# ===============
# ** File: worker/patch_common/patch_utils.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.utils.direct_register_custom_op`
# Why:
# direct_register_custom_op requires pytorch version >= 2.7.0,
# but vllm-ascend only supports pytorch version 2.5.1
# How:
# Convert annotation type to typing type for 2.5.1 backward compatibility
# Related PR (if no, explain why):
# No related PR; this change is internal to vllm-ascend.
# Future Plan:
# Update pytorch and torch-npu to 2.7.0 in the future.
# ** File: worker/patch_common/patch_cache_engine.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.worker.cache_engine.CacheEngine._allocate_kv_cache`
# Why:
# Add graph_mode optimization for kv cache allocation.
# How:
# If graph_mode is enabled, add layer_kv_cache_nope and layer_kv_cache_pe to the kv_cache.
# Related PR (if no, explain why):
# A PR to vllm is needed to fix the issue upstream.
# Future Plan:
# Revert this patch once the related PR is merged into vllm.
# ** File: worker/patch_common/patch_metrics.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
# Why:
# There is CUDA-specific hard-coding (current_platform.is_cuda_alike()) in
# `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# `AsyncMetricsCollector._copy_rejsample_metrics_async`
# How:
# Change to use `current_platform.Event` to determine whether to return None
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# https://github.com/vllm-project/vllm/pull/14411
# Related PR (if no, explain why):
# A PR to vllm is needed to fix the issue upstream.
# Future Plan:
# Revert this patch once the related PR is merged into vllm.
#
@@ -110,7 +133,7 @@
# However, float32 is not supported by the CANN rope op, so we keep this patch
# How:
# Remove the dtype conversion operations in forward
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# Related PR (if no, explain why):
# No; this patch is NPU-only because of the rope op.
# Future Plan:
# Keep this patch in vllm-ascend.
@@ -126,7 +149,7 @@
# - support registering attention metadata classes into the set supported by spec decode
# - offer an API on the platform to determine whether spec decode is supported,
# and deprecate is_cuda_alike in it.
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
@@ -138,7 +161,7 @@
# vLLM `Remove Sampler from Model Code`, so vllm-ascend needs to adapt to this change.
# How:
# Use the vLLM 0.8.4 method to patch it.
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
@@ -153,7 +176,7 @@
# `FlashAttentionMetadata`
# How:
# ditto
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
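The patch_utils entry above says the fix converts annotation types to typing types so that vllm.utils.direct_register_custom_op works on torch 2.5.1, presumably because the older schema inference handles typing.List[int] but not the builtin-generic spelling list[int]. A rough sketch of that idea (the function and mapping names here are assumptions for illustration, not the project's actual code):

# Illustrative sketch: rewrite builtin-generic annotations (list[int],
# dict[str, int], tuple[int, ...]) into their typing.* equivalents before a
# function is handed to torch.library on torch 2.5.1.
import typing

_BUILTIN_TO_TYPING = {list: typing.List, dict: typing.Dict, tuple: typing.Tuple}


def _to_typing(annotation):
    origin = typing.get_origin(annotation)
    if origin in _BUILTIN_TO_TYPING:
        args = tuple(_to_typing(a) for a in typing.get_args(annotation))
        typing_form = _BUILTIN_TO_TYPING[origin]
        return typing_form[args] if len(args) != 1 else typing_form[args[0]]
    return annotation


def convert_annotations(fn):
    """Rewrite fn.__annotations__ in place and return fn."""
    for name, ann in list(fn.__annotations__.items()):
        fn.__annotations__[name] = _to_typing(ann)
    return fn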
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/patch_common/__init__.py
@@ -18,6 +18,7 @@
# patch_utils should be the first import, because it will be used by other
# patch files.
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_cache_engine # noqa
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
@@ -1,7 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
4 changes: 0 additions & 4 deletions vllm_ascend/patch/worker/patch_common/patch_metrics.py
@@ -15,13 +15,9 @@
# limitations under the License.
#

from typing import Callable

import torch
from vllm.spec_decode.metrics import AsyncMetricsCollector

Timer = Callable[[], float]


def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
"""
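For readers tracing the patched metrics path: upstream vLLM's collector copies the sampler's accept/emit counters to CPU on a side stream and returns a CUDA event; the Ascend patch returns a torch.npu.Event instead. A simplified, hedged sketch of that shape (the attribute names mirror upstream vLLM's AsyncMetricsCollector and are assumptions for this sketch; torch_npu must be installed for the torch.npu namespace to be available):

import torch
import torch_npu  # noqa: F401  (registers the NPU backend / torch.npu namespace)


def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
    # Run the device->CPU copies on a dedicated stream so they overlap with
    # ongoing work, then hand back an event the caller can query later.
    self._copy_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(self._copy_stream):
        self._aggregate_num_accepted_tokens.copy_(
            self.spec_decode_sampler.num_accepted_tokens, non_blocking=True)
        self._aggregate_num_emitted_tokens.copy_(
            self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
        self._aggregate_num_draft_tokens = (
            self.spec_decode_sampler.num_draft_tokens)

    aggregate_metrics_ready = torch.npu.Event()
    aggregate_metrics_ready.record(self._copy_stream)
    return aggregate_metrics_ready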
3 changes: 1 addition & 2 deletions vllm_ascend/quantization/quant_config.py
@@ -38,8 +38,7 @@
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod

from .quantizer import AscendQuantizer
from vllm_ascend.quantization.quantizer import AscendQuantizer


@register_quantization_config("ascend")
11 changes: 6 additions & 5 deletions vllm_ascend/quantization/quantizer.py
@@ -22,11 +22,12 @@

from vllm.logger import logger

from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
from .w8a8 import AscendW8A8LinearMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from vllm_ascend.quantization.func_wrapper import (wrapper_load_model,
wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.w8a8_dynamic import (
AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod)

CUSTOMIZED_QUANTIZER_TYPE: List[str] = []

2 changes: 1 addition & 1 deletion vllm_ascend/utils.py
@@ -24,7 +24,7 @@
from packaging.version import InvalidVersion, Version
from vllm.logger import logger

import vllm_ascend.envs as envs
from vllm_ascend import envs

if TYPE_CHECKING:
from vllm.config import VllmConfig
1 change: 0 additions & 1 deletion vllm_ascend/worker/__init__.py
@@ -14,4 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import vllm_ascend.worker.cache_engine # noqa