2 changes: 2 additions & 0 deletions .github/workflows/nightly_benchmarks.yaml
@@ -89,6 +89,8 @@ jobs:

       - name: Checkout vllm-project/vllm-ascend repo
         uses: actions/checkout@v4
+        with:
+          fetch-depth: 0

       - name: Checkout vllm-project/vllm repo
         uses: actions/checkout@v4
7 changes: 6 additions & 1 deletion .github/workflows/vllm_ascend_test.yaml
@@ -127,7 +127,12 @@ jobs:
             pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
-            pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
+            pytest -sv tests/singlecard/test_camem.py
+            pytest -sv tests/singlecard/ \
+              --ignore=tests/singlecard/test_offline_inference.py \
+              --ignore=tests/singlecard/test_scheduler.py \
+              --ignore=tests/singlecard/test_guided_decoding.py \
+              --ignore=tests/singlecard/test_camem.py
           else
             pytest -sv tests/multicard/test_ilama_lora_tp2.py
             # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.
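
The reworked step runs tests/singlecard/test_camem.py on its own and excludes it, along with the other individually-run files, from the catch-all invocation via --ignore. A quick way to check locally which tests the catch-all run still collects is pytest's --collect-only mode; the sketch below uses pytest.main with the same paths as the flags above and assumes it is run from the repository root:

    import pytest

    # Collect (without executing) what the catch-all CI invocation would run,
    # mirroring the --ignore flags added in the workflow above.
    pytest.main([
        "--collect-only", "-q",
        "tests/singlecard/",
        "--ignore=tests/singlecard/test_offline_inference.py",
        "--ignore=tests/singlecard/test_scheduler.py",
        "--ignore=tests/singlecard/test_guided_decoding.py",
        "--ignore=tests/singlecard/test_camem.py",
    ])
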
2 changes: 2 additions & 0 deletions tests/multicard/test_offline_inference_distributed.py
@@ -22,6 +22,7 @@
 """
 import os

+import pytest
 import vllm  # noqa: F401

 from tests.conftest import VllmRunner
@@ -46,6 +47,7 @@ def test_models_distributed_QwQ():
         vllm_model.generate_greedy(example_prompts, max_tokens)


+@pytest.mark.skipif(True, reason="wait for mla issue fixed on v1")
 def test_models_distributed_DeepSeek():
     example_prompts = [
         "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
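
The pytest.mark.skipif condition is evaluated at collection time, so the DeepSeek distributed test is reported as skipped with the given reason rather than executed; the literal True disables it unconditionally until the MLA issue on v1 is resolved. A minimal, self-contained sketch of the mechanism follows; the test name and the module-level flag are illustrative, not from this PR:

    import pytest

    # Stand-in for the literal True used in the PR; flip to False to re-enable the test.
    MLA_ISSUE_OPEN = True


    @pytest.mark.skipif(MLA_ISSUE_OPEN, reason="wait for mla issue fixed on v1")
    def test_reported_as_skipped():
        # Never executes while the condition above is truthy.
        raise AssertionError("should not run")
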
7 changes: 6 additions & 1 deletion tests/singlecard/test_camem.py
@@ -16,6 +16,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+
 import pytest
 import torch
 from vllm import LLM, SamplingParams
@@ -24,7 +26,11 @@
 from tests.utils import fork_new_process_for_each_test
 from vllm_ascend.device_allocator.camem import CaMemAllocator

+if os.getenv("VLLM_USE_V1") == "1":
+    pytest.skip("Skip in vllm v1", allow_module_level=True)
+
+
 @fork_new_process_for_each_test
 def test_basic_camem():
     # some tensors from default memory pool
     shape = (1024, 1024)
@@ -57,7 +63,6 @@ def test_basic_camem():
     assert torch.allclose(output, torch.ones_like(output) * 3)


-@pytest.mark.skipif(True, reason="test failed, should be fixed later")
 @fork_new_process_for_each_test
 def test_end_to_end():
     free, total = torch.npu.mem_get_info()
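
The new module-level guard in test_camem.py uses pytest.skip(..., allow_module_level=True), which aborts collection of the entire file when VLLM_USE_V1=1 instead of marking individual tests; at the same time the unconditional skip on test_end_to_end is dropped, so that test runs again on v0. A standalone sketch of the pattern, with illustrative file and test names:

    # Hypothetical contents of a v0-only test module using the same pattern.
    import os

    import pytest

    # When the V1 engine is selected, skip every test in this module at collection time.
    if os.getenv("VLLM_USE_V1") == "1":
        pytest.skip("Skip in vllm v1", allow_module_level=True)


    def test_runs_only_on_v0():
        assert True
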
31 changes: 22 additions & 9 deletions vllm_ascend/worker/model_runner_v1.py
@@ -64,6 +64,7 @@
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer

 if TYPE_CHECKING:
@@ -1265,15 +1266,27 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}

-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.model_config.max_model_len,
-            max_num_batched_tokens=self.max_num_tokens,
-            device=self.device,
-            pin_memory=True,
-            vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.cache_config.block_size,
-        )
+        # Remove this after we drop 0.9.0 support
+        if vllm_version_is("0.9.0"):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=self.cache_config.block_size,
+            )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=[self.cache_config.block_size],
+            )

         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
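
The branching in initialize_kv_cache exists because the InputBatch constructor keyword changed after vLLM 0.9.0 from a single block_size to a block_sizes list, so the runner picks the matching signature at runtime; the 0.9.0 branch is marked for removal once that version is no longer supported. A rough sketch of what a helper like vllm_version_is could look like is below; this is an assumption for illustration, and the real implementation in vllm_ascend.utils may differ, for example in how it treats dev or post-release builds:

    # Hypothetical sketch of a version gate; the real vllm_ascend.utils.vllm_version_is
    # may normalise the version string differently.
    import vllm


    def vllm_version_is(target: str) -> bool:
        # True when the installed vLLM release string matches the requested version exactly.
        return vllm.__version__ == target


    # Usage mirroring the diff above:
    # if vllm_version_is("0.9.0"): pass block_size=...; else: pass block_sizes=[...]
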