100 changes: 100 additions & 0 deletions tests/neuron/test_comm_ops.py
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
import functools
from typing import Callable
from unittest.mock import patch

import pytest
import torch
import torch_xla.distributed.xla_multiprocessing as xmp
from typing_extensions import ParamSpec

from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.utils import get_distributed_init_method, get_open_port

_P = ParamSpec("_P")


def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to reinitialize the Neuron Runtime before executing a test.
    This is necessary for distributed tests, which need to reallocate Neuron
    Cores to separate subprocesses.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Release the NeuronCores held by the parent process so the spawned
        # workers can claim them, then restore the runtime once the test ends.
        runtime = torch.classes.neuron.Runtime()
        runtime.initialize()
        runtime.unsafe_close()

        f(*args, **kwargs)
        runtime.initialize()

    return wrapper


def all_gather_test_worker(index, tp_degree, distributed_init_method):
    init_distributed_environment(tp_degree,
                                 index,
                                 distributed_init_method,
                                 index,
                                 backend="xla")
    ensure_model_parallel_initialized(tp_degree, 1)

    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s

    # Each rank holds a differently scaled copy of the same tensor; the
    # all-gather result should be their concatenation along the last dim.
    all_gather_dimension = -1
    all_tensors = [
        torch.arange(total_size, dtype=torch.float32,
                     device="xla").reshape(tensor_size) * (r + 1)
        for r in range(tp_degree)
    ]
    expected = torch.cat(all_tensors, dim=all_gather_dimension)
    t = all_tensors[index % tp_degree]
    t = tensor_model_parallel_all_gather(t, all_gather_dimension)
    torch.testing.assert_close(t, expected)


def all_reduce_test_worker(index, tp_degree, distributed_init_method):
    init_distributed_environment(tp_degree,
                                 index,
                                 distributed_init_method,
                                 index,
                                 backend="xla")
    ensure_model_parallel_initialized(tp_degree, 1)

    # Each rank holds arange(num_elements) scaled by (rank + 1); the
    # all-reduce result should be the element-wise sum across all ranks.
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1)
        for r in range(tp_degree)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[index % tp_degree]
    t = tensor_model_parallel_all_reduce(t)
    torch.testing.assert_close(t, expected)


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
@reinitialize_neuron_runtime
def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
test_target):

with patch('torch_xla._XLAC._xla_runtime_is_initialized',
return_value=False):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())

monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
','.join(['1' for _ in range(tp_size)]))

xmp.spawn(test_target, args=(tp_size, distributed_init_method))
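For reference, the expected values in these workers follow from simple arithmetic: rank r contributes a tensor scaled by (r + 1), so the all-reduce result is the base tensor scaled by 1 + 2 + ... + tp_degree. A minimal CPU-only sketch of the same check, assuming no Neuron hardware or distributed setup:

import torch

tp_degree, num_elements = 2, 8
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(tp_degree)
]
# 1 + 2 + ... + tp_degree equals tp_degree * (tp_degree + 1) / 2.
expected = torch.arange(num_elements, dtype=torch.float32) * (
    tp_degree * (tp_degree + 1) / 2)
torch.testing.assert_close(torch.stack(all_tensors).sum(dim=0), expected)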
19 changes: 19 additions & 0 deletions vllm/distributed/device_communicators/neuron_communicator.py
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
import torch

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.platforms import current_platform

if current_platform.is_neuron():
    import torch_xla.core.xla_model as xm


class NeuronCommunicator(DeviceCommunicatorBase):

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "Neuron only supports dim=-1 for all-gather."
        return xm.all_gather(x, dim=dim)
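Both communicator methods delegate directly to torch_xla collectives. A standalone sketch of the same calls, assuming it runs inside an already-initialized XLA replica (for example, an xmp.spawn worker on a Neuron device):

import torch
import torch_xla.core.xla_model as xm


def xla_collectives_demo() -> None:
    # Per-replica tensor on the local XLA (NeuronCore) device.
    x = torch.ones(4, device=xm.xla_device())
    summed = xm.all_reduce(xm.REDUCE_SUM, x)  # element-wise sum across replicas
    gathered = xm.all_gather(x, dim=-1)  # concatenation along the last dim
    xm.mark_step()
    print(summed.shape, gathered.shape)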
8 changes: 8 additions & 0 deletions vllm/platforms/neuron.py
@@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Optional

from vllm import envs
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
@@ -56,6 +57,13 @@ def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")
return False

@classmethod
def get_device_communicator_cls(cls) -> str:
if envs.VLLM_USE_V1:
return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa
else:
return Platform.get_device_communicator_cls()

@classmethod
def use_all_gather(cls) -> bool:
return True
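Note that get_device_communicator_cls returns a dotted import path rather than a class object. A minimal sketch of how such a string could be resolved; the helper below is illustrative, not vLLM's actual resolver:

import importlib


def resolve_cls(path: str):
    # Split "package.module.ClassName" into its module and attribute parts.
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


communicator_cls = resolve_cls(
    "vllm.distributed.device_communicators.neuron_communicator."
    "NeuronCommunicator")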