Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ See [Building and Testing](./CONTRIBUTING.md) for more details, which is a step-
| Feature | Supported | Note |
|---------|-----------|------|
| Chunked Prefill | ✗ | Plan in 2025 Q1 |
| Automatic Prefix Caching | ✅ | Imporve performance in 2025 Q1 |
| Automatic Prefix Caching | ✅ | Improve performance in 2025 Q1 |
| LoRA | ✗ | Plan in 2025 Q1 |
| Prompt adapter | ✅ ||
| Speculative decoding | ✅ | Impore accuracy in 2025 Q1|
| Speculative decoding | ✅ | Improve accuracy in 2025 Q1|
| Pooling | ✗ | Plan in 2025 Q1 |
| Enc-dec | ✗ | Plan in 2025 Q1 |
| Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
Expand Down
56 changes: 53 additions & 3 deletions vllm_ascend/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,62 @@

import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_communicator import \
CommunicatorBase


class NPUCommunicator(CommunicatorBase):
class NPUCommunicator:

def __init__(self, group, unique_name=""):
self.group = group
self.unique_name = unique_name
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(self.group)
self.ranks = dist.get_process_group_ranks(self.group)
global_rank = dist.get_rank()
self.rank_in_group = dist.get_group_rank(self.group, global_rank)

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
dist.all_reduce(x, group=self.group)
return x

def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have any UT to check the functionality?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

communicator test need more than one NPU card which is not supported by current CI. We're working on multi card support for CI system.

In this comment, we need test this PR by hand locally and be careful to merge it.

# NOTE: We assume that the input tensor is on the same device across
# all the ranks.
# NOTE: `dst` is the local rank of the destination rank.
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [
torch.empty_like(input_) for _ in range(self.world_size)
]
else:
gather_list = None
# Gather.
dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
18 changes: 18 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm_ascend.patch import patch_commnicator # noqa
67 changes: 67 additions & 0 deletions vllm_ascend/patch/patch_commnicator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is used to monkey patch communicator in vllm to support ascend.
# Remove this file when vllm support by
# https://github.com/vllm-project/vllm/pull/11324.

from vllm.distributed.parallel_state import GroupCoordinator
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated but just curious: should vllm be a dependency of vllm-ascend as oneline in requriement and pyproject?

Copy link
Collaborator Author

@wangxiyuan wangxiyuan Feb 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emm. Let's have a try. we can add it.

While IMO, it maybe raises error because there is no CPU version of pytorch on pypi.

Once it's added, the install step in the future from my sight is:

  1. install cpu version of Pytorch by hand. (torch==2.5.1+cpu)
  2. pip install vllm-ascend

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no warries, we can do it in followup

from vllm.utils import resolve_obj_by_qualname


class GroupCoordinatorPatch(GroupCoordinator):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

from vllm.platforms import current_platform
device_comm_cls = resolve_obj_by_qualname(
current_platform.get_device_communicator_cls())
# we have checked and ensure that reusing tpu tag here is fine.
use_custom_device = kwargs.get("use_tpu_communicator", False)
if use_custom_device and self.world_size > 1:
self.communicator = device_comm_cls(group=self.device_group,
unique_name=self.unique_name)

def all_reduce(self, input_):
# Bypass the function if we are using only 1 device.
if self.world_size == 1:
return input_

return self.communicator.all_reduce(input_)

def gather(self, input_, dst=0, dim=-1):
# Bypass the function if we are using only 1 device.
if self.world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

return self.communicator.gather(input_, dst, dim)

def all_gather(self, input_, dim=-1):
# Bypass the function if we are using only 1 device.
if self.world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
return self.communicator.all_gather(input_, dim)


GroupCoordinator = GroupCoordinatorPatch
3 changes: 2 additions & 1 deletion vllm_ascend/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def mem_get_info(cls) -> Tuple[int, int]:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# Register ops when setup.
# Register ops and patch when setup.
from vllm_ascend import ops # noqa: F401
from vllm_ascend import patch # noqa: F401

parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
Expand Down