#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file monkey patches the communicator in vllm to support Ascend (NPU)
# devices. Remove this file once the change is supported upstream via
# https://github.com/vllm-project/vllm/pull/11324.

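# A minimal sketch of how a monkey patch like this is typically activated
# (illustrative only; nothing in this comment block is executed here, and the
# exact call site in vllm-ascend is an assumption):
#
#     import vllm.distributed.parallel_state as parallel_state
#     from <this patch module> import GroupCoordinatorPatch  # hypothetical import
#     parallel_state.GroupCoordinator = GroupCoordinatorPatch
#
# Rebinding the name before any parallel groups are created means helpers such
# as init_model_parallel_group() instantiate the NPU-aware class defined below.
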
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.utils import resolve_obj_by_qualname


class GroupCoordinatorPatch(GroupCoordinator):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = torch.device(f"npu:{self.local_rank}")

        # Resolve the platform-specific device communicator class
        # (the NPU communicator when running on Ascend).
        from vllm.platforms import current_platform
        device_comm_cls = resolve_obj_by_qualname(
            current_platform.get_device_communicator_cls())
        # We have checked and ensured that reusing the TPU tag here is fine.
        use_custom_device = kwargs.get("use_tpu_communicator", False)
        if use_custom_device and self.world_size > 1:
            self.communicator = device_comm_cls(group=self.device_group,
                                                unique_name=self.unique_name)

    def all_reduce(self, input_):
        # Bypass the function if we are using only 1 device.
        if self.world_size == 1:
            return input_

        return self.communicator.all_reduce(input_)

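    # Illustrative usage only (not part of this patch): with tensor parallelism
    # enabled, callers reduce partial results through the group coordinator,
    # for example:
    #
    #     from vllm.distributed.parallel_state import get_tp_group
    #     output = get_tp_group().all_reduce(partial_output)
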
    def gather(self, input_, dst=0, dim=-1):
        # Bypass the function if we are using only 1 device.
        if self.world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        return self.communicator.gather(input_, dst, dim)

    def all_gather(self, input_, dim=-1):
        # Bypass the function if we are using only 1 device.
        if self.world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        return self.communicator.all_gather(input_, dim)


GroupCoordinator = GroupCoordinatorPatch
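# Rebinding the module-level name allows patch machinery to import
# `GroupCoordinator` from this module and install it in
# vllm.distributed.parallel_state as a drop-in replacement (see the sketch
# near the top of this file).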