From ba1c76b4aee7338d0dd1e1c7b253855366dc083e Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Tue, 6 May 2025 15:10:22 +0800
Subject: [PATCH 1/3] fix: fix a typo in setup.py

Signed-off-by: linfeng-yuan <1102311262@qq.com>
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ad468ca34b..631e55c1b6 100644
--- a/setup.py
+++ b/setup.py
@@ -143,7 +143,7 @@ def configure(self, ext: CMakeExtension) -> None:
                                  sys.executable)

         # find PYTHON_INCLUDE_PATH
-        check_or_set_default_env(cmake_args, "PYHTON_INCLUDE_PATH",
+        check_or_set_default_env(cmake_args, "PYTHON_INCLUDE_PATH",
                                  get_paths()["include"])

         # ccache and ninja can not be applied at ascendc kernels now

From c9b78851dad677dc62f30cf2c293ca6ed6b0650b Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Tue, 6 May 2025 18:57:04 +0800
Subject: [PATCH 2/3] fix: fix an accuracy problem for quantized deepseek
 models

Signed-off-by: linfeng-yuan <1102311262@qq.com>
---
 vllm_ascend/quantization/w8a8_dynamic.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index bcd313d22d..136af3037b 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -285,8 +285,10 @@ def fused_experts(hidden_states: torch.Tensor,
         valid_token_mask = torch.arange(
             0, sorted_token_indices.shape[0],
             device=device).unsqueeze(1) < num_valid_tokens
-        down_out_list.mul_(valid_token_mask)
-        final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
+        valid_output = torch.where(
+            valid_token_mask, down_out_list,
+            torch.zeros_like(down_out_list)).to(dtype)
+        final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
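Note (illustrative commentary, not part of the applied diffs): the accuracy fix in
PATCH 2/3 swaps an in-place mul_ between the expert output and a boolean mask for
an out-of-place torch.where plus an explicit cast to the output dtype before
index_add_. The sketch below only mirrors that masking pattern; the shapes and
names (num_sorted_tokens, hidden_size, num_out_tokens) are made-up assumptions,
not values taken from the patch:

    import torch

    # Hypothetical sizes, for illustration only.
    num_sorted_tokens, hidden_size, num_out_tokens = 8, 4, 4
    num_valid_tokens = 5
    dtype = torch.bfloat16

    down_out_list = torch.randn(num_sorted_tokens, hidden_size)
    sorted_token_indices = torch.randint(0, num_out_tokens, (num_sorted_tokens,))
    final_hidden_states = torch.zeros(num_out_tokens, hidden_size, dtype=dtype)

    # Rows at or beyond num_valid_tokens are padding and must not contribute.
    valid_token_mask = torch.arange(
        0, num_sorted_tokens).unsqueeze(1) < num_valid_tokens

    # Out-of-place masking leaves down_out_list untouched and lets the scatter
    # source be cast to the accumulation dtype before index_add_.
    valid_output = torch.where(
        valid_token_mask, down_out_list,
        torch.zeros_like(down_out_list)).to(dtype)
    final_hidden_states.index_add_(0, sorted_token_indices, valid_output)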
From e2e575b8c1f8c1c88f693b5d3db18c0ed6ba7ee7 Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Thu, 15 May 2025 19:32:07 +0800
Subject: [PATCH 3/3] adapt: make the cache engine for npu graph mode
 compatible with original one

Signed-off-by: linfeng-yuan <1102311262@qq.com>
---
 vllm_ascend/attention/attention.py       |  2 +-
 .../kv_transfer/simple_connector.py      | 26 ++++--
 vllm_ascend/worker/cache_engine.py       | 82 -------------------
 vllm_ascend/worker/model_runner.py       | 23 ++++--
 4 files changed, 38 insertions(+), 95 deletions(-)
 delete mode 100644 vllm_ascend/worker/cache_engine.py

diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py
index d598822080..ae69c694d1 100644
--- a/vllm_ascend/attention/attention.py
+++ b/vllm_ascend/attention/attention.py
@@ -1166,7 +1166,7 @@ def forward(

         # TODO: Replace the env with more flexible expressions
         if self.enable_graph_mode:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
+            if kv_cache is not None and len(kv_cache) > 0 and kv_cache[0].numel(
             ) > 0 and attn_metadata.num_prefills > 0:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
diff --git a/vllm_ascend/distributed/kv_transfer/simple_connector.py b/vllm_ascend/distributed/kv_transfer/simple_connector.py
index 7b05052d08..c1a6c79fe2 100644
--- a/vllm_ascend/distributed/kv_transfer/simple_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/simple_connector.py
@@ -302,13 +302,25 @@ def recv_kv_caches_and_hidden_states(
             layer = model_executable.model.layers[i]

             if self.is_deepseek_mla and self.use_mla_opt:
-                layer.self_attn.attn = layer.self_attn.mla_attn
-                key_cache = kv_cache
-                slots = slot_mapping[start_pos:end_pos]
-                sliced_key = keys[i - model_executable.model.start_layer]
-                torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
-                                                      key_cache=key_cache,
-                                                      slot_indices=slots)
+                if self.enable_graph_mode:
+                    num_blocks, block_size, num_head, head_dim = kv_cache.size()
+                    kv_cache = kv_cache.view(-1)
+                    receive_key = keys[i - model_executable.model.start_layer].view(-1)
+                    slice_receive_size = num_tokens * num_head * 512
+                    slice_paged_size = num_blocks * block_size * num_head * 512
+                    receive_nope = receive_key[:slice_receive_size].view(num_tokens, num_head, 512)
+                    receive_rope = receive_key[slice_receive_size:].view(num_tokens, num_head, 64)
+                    paged_nope = kv_cache[:slice_paged_size].view(num_blocks, block_size, num_head, 512)
+                    paged_rope = kv_cache[slice_paged_size:].view(num_blocks, block_size, num_head, 64)
+                    torch_npu._npu_reshape_and_cache(key=receive_nope, value=receive_rope, key_cache=paged_nope, value_cache=paged_rope, slot_indices=slot_mapping[start_pos:end_pos])
+                else:
+                    layer.self_attn.attn = layer.self_attn.mla_attn
+                    key_cache = kv_cache
+                    slots = slot_mapping[start_pos:end_pos]
+                    sliced_key = keys[i - model_executable.model.start_layer]
+                    torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
+                                                          key_cache=key_cache,
+                                                          slot_indices=slots)
             else:
                 key_cache, value_cache = kv_cache[0], kv_cache[1]
                 sliced_key = keys[i - model_executable.model.start_layer]
diff --git a/vllm_ascend/worker/cache_engine.py b/vllm_ascend/worker/cache_engine.py
deleted file mode 100644
index 72de201f1d..0000000000
--- a/vllm_ascend/worker/cache_engine.py
+++ /dev/null
@@ -1,82 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
-# Copyright 2023 The vLLM team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import Any, List
-
-import torch
-from vllm.config import get_current_vllm_config
-from vllm.utils import is_pin_memory_available
-from vllm.worker.cache_engine import CacheEngine
-
-
-def allocate_kv_cache(
-    self,
-    num_blocks: int,
-    device: str,
-) -> List[Any]:
-    """Allocates KV cache on the specified device."""
-    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
-    pin_memory = is_pin_memory_available() if device == "cpu" else False
-    kv_cache: List[Any] = []
-
-    additional_config = get_current_vllm_config().additional_config
-    if additional_config and additional_config.get("enable_graph_mode", False):
-        # Align entries so they are 256 byte aligned for better performance
-        # Primarily targets MLA as this typically only ends up having entries
-        # be 128 byte aligned.
-        alloc_shape = kv_cache_shape
-
-        for _ in range(self.num_attention_layers):
-            # null block in CpuGpuBlockAllocator requires at least that
-            # block to be zeroed-out.
-            # We zero-out everything for simplicity.
-            layer_kv_cache_nope = torch.zeros(
-                alloc_shape[:-1] +
-                (self.model_config.hf_text_config.kv_lora_rank, ),
-                dtype=self.dtype,
-                pin_memory=pin_memory,
-                device=device)
-            layer_kv_cache_pe = torch.zeros(
-                alloc_shape[:-1] +
-                (self.model_config.hf_text_config.qk_rope_head_dim, ),
-                dtype=self.dtype,
-                pin_memory=pin_memory,
-                device=device)
-
-            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
-            # when entry_shape is higher than 1D
-            kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
-    else:
-        for _ in range(self.num_attention_layers):
-            # null block in CpuGpuBlockAllocator requires at least that
-            # block to be zeroed-out.
-            # We zero-out everything for simplicity.
-            layer_kv_cache = torch.zeros(kv_cache_shape,
-                                         dtype=self.dtype,
-                                         pin_memory=pin_memory,
-                                         device=device)
-
-            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
-            # when entry_shape is higher than 1D
-            kv_cache.append(layer_kv_cache)
-    return kv_cache
-
-
-CacheEngine._allocate_kv_cache = allocate_kv_cache
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py
index 49c221e6b5..88d8c3217f 100644
--- a/vllm_ascend/worker/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -1314,10 +1314,6 @@ def execute_model(
             torch._dynamo.mark_static(model_input.input_positions)
             torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
             torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
-            for kv in kv_caches:
-                if isinstance(kv, tuple):
-                    torch._dynamo.mark_static(kv[0])
-                    torch._dynamo.mark_static(kv[1])

         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
@@ -1397,7 +1393,24 @@ def execute_model(
         if model_input.attn_metadata is not None:
             model_input.attn_metadata.input_positions = model_input.input_positions
         if self.enable_graph_mode:
-            model_kwargs["kv_caches"] = kv_caches
+            if kv_caches[0].numel() == 0:
+                pe_caches = None
+            else:
+                pe_caches = []
+                for (i, kv) in enumerate(kv_caches):
+                    if self.enable_graph_mode and len(kv.shape) == 4:
+                        (num_blocks, block_size, num_kv_heads, head_size) = kv.shape
+                        flatten_cache = kv.view(-1)
+                        split_index = num_blocks * block_size * num_kv_heads * 512
+                        nope_cache = flatten_cache[:split_index]
+                        rope_cache = flatten_cache[split_index:]
+                        nope_cache = nope_cache.view(num_blocks, block_size, num_kv_heads, 512)
+                        rope_cache = rope_cache.view(num_blocks, block_size, num_kv_heads, 64)
+                        # kv_caches[i] = (nope_cache, rope_cache)
+                        pe_caches.append((nope_cache, rope_cache))
+                        torch._dynamo.mark_static(pe_caches[i][0])
+                        torch._dynamo.mark_static(pe_caches[i][1])
+            model_kwargs["kv_caches"] = pe_caches
             model_kwargs["attn_metadata"] = model_input.attn_metadata
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
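Note (illustrative commentary, not part of the applied diffs): the simple_connector
and model_runner changes in PATCH 3/3 use the same trick of viewing one flat MLA
cache buffer as a contiguous "nope" region (head dim 512) followed by a "rope"
region (head dim 64), each reshaped back to paged layout; the 512/64 values appear
to correspond to the kv_lora_rank and qk_rope_head_dim used by the deleted
cache_engine.py. The sketch below only mirrors that view/split arithmetic; the
cache geometry (num_blocks, block_size, num_kv_heads) is made up for illustration:

    import torch

    # Hypothetical cache geometry, for illustration only.
    num_blocks, block_size, num_kv_heads = 16, 128, 1
    nope_dim, rope_dim = 512, 64  # hard-coded in the patch

    # One flat buffer holding the nope entries followed by the rope entries,
    # which is how the graph-mode path treats the paged cache layout.
    flatten_cache = torch.zeros(
        num_blocks * block_size * num_kv_heads * (nope_dim + rope_dim))

    split_index = num_blocks * block_size * num_kv_heads * nope_dim
    nope_cache = flatten_cache[:split_index].view(num_blocks, block_size,
                                                  num_kv_heads, nope_dim)
    rope_cache = flatten_cache[split_index:].view(num_blocks, block_size,
                                                  num_kv_heads, rope_dim)

    assert nope_cache.shape == (num_blocks, block_size, num_kv_heads, nope_dim)
    assert rope_cache.shape == (num_blocks, block_size, num_kv_heads, rope_dim)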