diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index f5b42bcfb81..ce67d25cc6b 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.9" services: tensorrt_llm-dev: - image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420 + image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792 network_mode: host ipc: host diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index 6f19602efb2..a8a7ff38940 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -72,6 +72,10 @@ RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/ RUN pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir +# Install DeepEP +COPY docker/common/install_deep_ep.sh install_deep_ep.sh +RUN bash ./install_deep_ep.sh && rm install_deep_ep.sh + # WARs against security issues inherited from pytorch:25.04 # * https://github.com/advisories/GHSA-vqfr-h8mv-ghfj # * https://github.com/advisories/GHSA-7cx3-6m66-7c5m diff --git a/docker/common/install_deep_ep.sh b/docker/common/install_deep_ep.sh new file mode 100644 index 00000000000..c6c572eff9e --- /dev/null +++ b/docker/common/install_deep_ep.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +set -euxo pipefail + +GITHUB_URL=${GITHUB_MIRROR:-https://github.com} +DEEP_EP_COMMIT=2b266cf6452134f993ab0fcb3ef2d5de7683c561 + +if [ "$(. /etc/os-release && echo $ID)" == "rocky" ]; then + echo "Skipping DeepEP installation in the Rocky distribution." + exit 0 +fi +libmlx5_dir=$(dirname $(ldconfig -p | grep libmlx5.so.1 | head -n1 | awk '{print $NF}')) + +export NVCC_APPEND_FLAGS="--threads 4" + +# Custom NVSHMEM +curl -fsSL https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz | tar xz +pushd nvshmem_src +curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/raw/$DEEP_EP_COMMIT/third-party/nvshmem.patch | patch -p1 +sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i src/CMakeLists.txt +ln -s libmlx5.so.1 "$libmlx5_dir/libmlx5.so" +cmake -S . -B build \ + -DCMAKE_INSTALL_PREFIX=/opt/custom_nvshmem \ + -DGDRCOPY_HOME=/usr/include \ + -DNVSHMEM_SHMEM_SUPPORT=0 \ + -DNVSHMEM_UCX_SUPPORT=0 \ + -DNVSHMEM_USE_NCCL=0 \ + -DNVSHMEM_MPI_SUPPORT=0 \ + -DNVSHMEM_IBGDA_SUPPORT=1 \ + -DNVSHMEM_PMIX_SUPPORT=0 \ + -DNVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + -DNVSHMEM_USE_GDRCOPY=1 \ + -DCMAKE_CUDA_ARCHITECTURES="90-real;100-real;120-real" \ + -DNVSHMEM_BUILD_TESTS=0 \ + -DNVSHMEM_BUILD_EXAMPLES=0 +cmake --build build -j`nproc` +make -C build install +popd + +# DeepEP +curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/archive/$DEEP_EP_COMMIT.tar.gz | tar xz +TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0" NVSHMEM_DIR=/opt/custom_nvshmem pip install -v --no-cache-dir ./DeepEP-$DEEP_EP_COMMIT + +# Clean up +rm -r nvshmem_src +rm "$libmlx5_dir/libmlx5.so" +rm -r DeepEP-$DEEP_EP_COMMIT diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index 39df9c5fa65..37d20dea4f8 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? 
env.uploadPath : "sw-tensorrt-generic/llm-artifac // Container configuration // available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/ // [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id] -LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420" -LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420" -LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420" -LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420" +LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792" +LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792" +LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506111045-4792" +LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506111045-4792" // TODO: Move common variables to an unified location BUILD_CORES_REQUEST = "8" diff --git a/jenkins/controlCCache.groovy b/jenkins/controlCCache.groovy index c41e09967e6..30788b251ab 100644 --- a/jenkins/controlCCache.groovy +++ b/jenkins/controlCCache.groovy @@ -1,7 +1,7 @@ import java.lang.InterruptedException -DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420" +DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792" def createKubernetesPodConfig(image, arch = "amd64") { diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index f5a3d5f4199..24802ec2e0f 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -84,6 +84,9 @@ class ModelConfig(Generic[TConfig]): # If true, enable min-latency mode. Currently only used for Llama4. enable_min_latency: bool = False + # Allow models to select op according to whether CUDA Graphs are used. 
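+    # Populated from pytorch_backend_config.use_cuda_graph in model_engine._load_model;
+    # CutlassFusedMoE.select_alltoall_method_type uses it to pick an alltoall method that
+    # is compatible with CUDA Graphs.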
+ use_cuda_graph: bool = False + extra_attrs: Dict = field(default_factory=dict, repr=False, init=False) _frozen: bool = field(default=False, init=False, repr=False) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index f5d3417f88f..6659f7c1aed 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -38,7 +38,6 @@ from tqdm import tqdm from transformers import PretrainedConfig -from tensorrt_llm._mnnvl_utils import MnnvlMemory from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.models.modeling_utils import QuantConfig @@ -351,10 +350,6 @@ def __init__(self, config = model_config.pretrained_config self.top_k = top_k self.use_dp = model_config.mapping.enable_attention_dp - self.enable_alltoall = Deepseekv3MoE.should_enable_alltoall( - model_config, top_k) - if self.enable_alltoall: - MnnvlMemory.initialize() self.gate = DeepseekV3Gate( hidden_size, num_experts, @@ -377,7 +372,6 @@ def __init__(self, model_config=model_config, override_quant_config=override_quant_config, aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap], - enable_alltoall=self.enable_alltoall, layer_idx=layer_idx) self.mapping = model_config.mapping @@ -443,25 +437,6 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int, return shared_tp_size, shared_output_scale - @staticmethod - def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool: - if not model_config.mapping.enable_attention_dp: - return False - - if model_config.mapping.tp_size == 1: - return False - - if not MnnvlMemory.supports_mnnvl(): - return False - - if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1": - return False - - if model_config.mapping.moe_ep_size <= top_k: - return False - - return True - def compute_routed_output(self, hidden_states, hidden_states_fp4, all_rank_num_tokens, do_finalize): # max-throughput @@ -469,7 +444,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4, if self.use_dp and self.mapping.tp_size > 1: # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization # to reduce allreduce BW - if disable_fp4_allgather() and not self.enable_alltoall: + if disable_fp4_allgather() and not self.experts.enable_alltoall: hidden_states = allgather(hidden_states, self.mapping, dim=0, diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index 5e6f67a8d42..e63c293c0b1 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -6,8 +6,6 @@ from tqdm import tqdm from transformers import Qwen3MoeConfig -from tensorrt_llm._mnnvl_utils import MnnvlMemory - from ..attention_backend import AttentionMetadata from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, allgather) @@ -91,10 +89,6 @@ def __init__( self.mapping = model_config.mapping self.allreduce = AllReduce(mapping=model_config.mapping, strategy=model_config.allreduce_strategy) - self.enable_alltoall = Qwen3MoE.should_enable_alltoall( - model_config, self.top_k) - if self.enable_alltoall: - MnnvlMemory.initialize() self.gate = Qwen3Gate( hidden_size=self.hidden_dim, @@ -117,25 +111,6 @@ def __init__( model_config=model_config, ) - @staticmethod - def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool: - if not 
model_config.mapping.enable_attention_dp: - return False - - if model_config.mapping.tp_size == 1: - return False - - if not MnnvlMemory.supports_mnnvl(): - return False - - if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1": - return False - - if model_config.mapping.moe_ep_size <= top_k: - return False - - return True - def forward( self, hidden_states: torch.Tensor, @@ -151,7 +126,7 @@ def forward( if self.enable_attention_dp and self.mapping.tp_size > 1: # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization # to reduce allreduce BW - if disable_fp4_allgather() and not self.enable_alltoall: + if disable_fp4_allgather() and not self.experts.enable_alltoall: hidden_states = allgather(hidden_states, self.mapping, dim=0, diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 98f59026dbc..c81f04affb5 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -52,7 +52,6 @@ def create_moe( aux_stream: Optional[torch.cuda.Stream] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, apply_router_weight_on_input: bool = False, - enable_alltoall: bool = False, layer_idx: Optional[int] = None, ) -> MoE: moe_cls = get_moe_cls(model_config, override_quant_config) @@ -63,7 +62,6 @@ def create_moe( if moe_cls == TRTLLMGenFusedMoE: assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE." - assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE." return moe_cls( routing_method=routing_method, @@ -88,12 +86,10 @@ def create_moe( aux_stream=aux_stream, weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, - enable_alltoall=enable_alltoall, layer_idx=layer_idx, ) elif moe_cls == VanillaMoE: assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE." - assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE." return moe_cls( routing_method=routing_method, diff --git a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py new file mode 100644 index 00000000000..e0c7c67748f --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py @@ -0,0 +1,209 @@ +# Adapted from +# https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md +import weakref +from typing import List, Tuple, Union + +import torch + +from tensorrt_llm._utils import local_mpi_size, mpi_comm +from tensorrt_llm.mapping import Mapping + +try: + from deep_ep import Buffer + deep_ep_installed = True +except ModuleNotFoundError: + deep_ep_installed = False + + +class VariableLengthBuffer: + """ A wrapper of deep_ep.Buffer that accepts future size change + """ + + def __init__(self, mapping: Mapping): + self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank) + self.buffer = None + + def __del__(self): + self.comm.Free() + + def reserve(self, hidden_size: int, hidden_dtype: torch.dtype): + """ Ensure the buffer capacity is large enough. 
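+
+        The required NVL/RDMA capacities are derived from DeepEP's dispatch and combine
+        config hints for the current EP world size, so the buffer only grows when needed.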
+ + Reserve is a collective operation that requires all EP ranks to be sync + """ + # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests + num_nvl_bytes, num_rdma_bytes = 0, 0 + hidden_bytes = hidden_size * max(hidden_dtype.itemsize, + torch.bfloat16.itemsize) + world_size = self.comm.Get_size() + for config in (Buffer.get_dispatch_config(world_size), + Buffer.get_combine_config(world_size)): + num_nvl_bytes = max( + config.get_nvl_buffer_size_hint(hidden_bytes, world_size), + num_nvl_bytes) + num_rdma_bytes = max( + config.get_rdma_buffer_size_hint(hidden_bytes, world_size), + num_rdma_bytes) + + # Allocate a buffer if not existed or not enough buffer size + if self.buffer is None or self.buffer.num_nvl_bytes < num_nvl_bytes or self.buffer.num_rdma_bytes < num_rdma_bytes: + if self.buffer is not None: + num_nvl_bytes = max(num_nvl_bytes, self.buffer.num_nvl_bytes) + num_rdma_bytes = max(num_rdma_bytes, self.buffer.num_rdma_bytes) + del self.buffer # Destruct before Construct + self.buffer = Buffer(None, + num_nvl_bytes, + num_rdma_bytes, + num_nvl_peers=local_mpi_size(), + comm=self.comm) + + def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + topk_idx: torch.Tensor, topk_weights: torch.Tensor, + num_experts: int) -> \ + Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]: + # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency + # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please + # refer to the docs of `Buffer.dispatch` + + # Calculate layout before actual dispatch + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \ + self.buffer.get_dispatch_layout(topk_idx, num_experts) + assert event.event is None + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph + # For more advanced usages, please refer to the docs of the `dispatch` function + recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \ + self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights, + num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert) + assert event.event is None + + # For event management, please refer to the docs of the `EventOverlap` class + return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle + + def combine(self, x: torch.Tensor, handle: Tuple) -> torch.Tensor: + # Do MoE combine + # For more advanced usages, please refer to the docs of the `combine` function + combined_x, _, event = self.buffer.combine(x, handle) + assert event.event is None + + # For event management, please refer to the docs of the `EventOverlap` class + return combined_x + + +class VariableLengthLowLatencyBuffer: + """ A wrapper of deep_ep.Buffer that accepts future size change + """ + + def __init__(self, mapping: Mapping): + self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank) + self.buffer = None + self.num_max_dispatch_tokens_per_rank = None + + def __del__(self): + self.comm.Free() + + def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int, + num_experts: int): + """ Ensure the buffer capacity is large enough. 
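+
+        The RDMA capacity is derived from Buffer.get_low_latency_rdma_size_hint for the
+        given num_max_dispatch_tokens_per_rank, hidden size, EP world size, and expert count.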
+ + Reserve is a collective operation that requires all EP ranks to be sync + """ + # NOTES: the low-latency mode will consume much more space than the normal mode + # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256 + world_size = self.comm.Get_size() + num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank, hidden_size, world_size, + num_experts) + + # Allocate a buffer if not existed or not enough buffer size + if self.buffer is None or self.buffer.num_rdma_bytes < num_rdma_bytes: + # NOTES: for best performance, the QP number **must** be equal to the number of the local experts + assert num_experts % world_size == 0 + del self.buffer # Destruct before Construct + self.buffer = Buffer(None, + 0, + num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_experts // world_size, + comm=self.comm) + + def low_latency_dispatch(self, hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int): + if self.num_max_dispatch_tokens_per_rank is None: + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + if num_max_dispatch_tokens_per_rank != self.num_max_dispatch_tokens_per_rank: + raise NotImplementedError( + "There are issues if `low_latency_dispatch` calls use different `num_max_dispatch_tokens_per_rank` values" + ) + + # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay) + recv_hidden_states, recv_expert_count, handle, event, hook = \ + self.buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, use_fp8=False) + assert event.event is None + assert hook is None + + # NOTES: the actual tensor will not be received only if you call `hook()`, + # it is useful for double-batch overlapping, but **without any SM occupation** + # If you don't want to overlap, please set `return_recv_hook=False` + # Later, you can use our GEMM library to do the computation with this specific format + return recv_hidden_states, recv_expert_count, handle + + def low_latency_combine(self, hidden_states: torch.Tensor, + topk_idx: torch.Tensor, topk_weights: torch.Tensor, + handle: Tuple): + # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay) + combined_hidden_states, event, hook = \ + self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle) + assert event.event is None + assert hook is None + + # NOTES: the same behavior as described in the dispatch kernel + return combined_hidden_states + + +class BufferPool: + """ A pool that allocates buffers on demand. + + Although the pool interface allows creating multiple buffers, the + current version of DeepEP supports at most one `deep_ep.Buffer` at a + time. Please ensure that all references to `VariableLengthBuffer` are + released before getting another buffer. 
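+
+    Typical usage (a sketch mirroring how CutlassFusedMoE uses the pool; `run_local_experts`
+    is an illustrative placeholder):
+
+        buffer = buffer_pool.get_buffer(mapping)       # collective across EP ranks
+        buffer.reserve(hidden_size, torch.bfloat16)    # collective; grows capacity if needed
+        recv_x, recv_idx, recv_weights, _, handle = buffer.dispatch(
+            x, topk_idx, topk_weights, num_experts)
+        expert_output = run_local_experts(recv_x)      # local expert computation
+        combined_x = buffer.combine(expert_output, handle)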
+ """ + + def __init__(self): + self.buffers: Map[Mapping, + weakref.ReferenceType[VariableLengthBuffer]] = {} + self.low_latency_buffers: Map[ + Mapping, + weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {} + + def get_buffer(self, mapping: Mapping) -> VariableLengthBuffer: + """ Get_buffer is a collective operation that requires all ranks to be sync + """ + if mapping in self.buffers and self.buffers[mapping]() is not None: + buffer = self.buffers[mapping]() + else: + buffer = VariableLengthBuffer(mapping) + self.buffers[mapping] = weakref.ref(buffer) + return buffer + + def get_low_latency_buffer( + self, mapping: Mapping) -> VariableLengthLowLatencyBuffer: + """ Get_low_latency_buffer is a collective operation that requires all ranks to be sync + """ + if mapping in self.low_latency_buffers and self.low_latency_buffers[ + mapping]() is not None: + buffer = self.low_latency_buffers[mapping]() + else: + buffer = VariableLengthLowLatencyBuffer(mapping) + self.low_latency_buffers[mapping] = weakref.ref(buffer) + return buffer + + +# The default pool +# You may create own pools for better resource management. +buffer_pool = BufferPool() diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index fac302b899f..291cb17fba2 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -1,16 +1,19 @@ import os +from enum import IntEnum from typing import Dict, List, Optional, Tuple, Union import torch -from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo +from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo from tensorrt_llm._utils import logger +from tensorrt_llm.mapping import Mapping from ...distributed import allgather, reducescatter from ...expert_statistic import ExpertStatistic from ...model_config import ModelConfig from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather, reswizzle_sf, swizzle_sf, unswizzle_sf) +from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import MoE from .moe_load_balancer import get_moe_load_balancer from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, @@ -20,6 +23,18 @@ from .routing import BaseMoeRoutingMethod +# The type of alltoall method +class AlltoallMethodType(IntEnum): + # Not available + NotEnabled = 0 + # MNNVL + MNNVL = 1 + # DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode + DeepEP = 2 + # DeepEP low latency: CUDA Graphs are supported, IBGDA is required + DeepEPLowLatency = 3 + + class CutlassFusedMoE(MoE): """ Fused Mixture of Experts (MoE) Layer with performance tuning. @@ -33,7 +48,6 @@ class CutlassFusedMoE(MoE): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. - enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter MoE torch custom op: In min-latency mode: @@ -82,7 +96,6 @@ def __init__( weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. 
VANILLA, apply_router_weight_on_input: bool = False, - enable_alltoall: bool = False, layer_idx: Optional[int] = None, ): @@ -176,7 +189,12 @@ def __init__( self.has_been_profiled = False self.has_been_profiled_min_latency = False - self.enable_alltoall = enable_alltoall + self.alltoall_method_type = self.select_alltoall_method_type( + model_config.mapping, routing_method.experts_per_token, dtype, + model_config.use_cuda_graph) + logger.info_once( + f"CutlassFusedMoE selects alltoall_method_type {self.alltoall_method_type!r}", + key="alltoall_method_type") self.use_postquant_alltoall = False if self.enable_alltoall: assert self.use_dp and self.parallel_size > 1,\ @@ -185,8 +203,25 @@ def __init__( self.use_postquant_alltoall = (os.environ.get( "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1") and qm.has_nvfp4() - self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( - model_config.mapping) if enable_alltoall else None + if self.alltoall_method_type == AlltoallMethodType.MNNVL: + MnnvlMemory.initialize() + self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( + model_config.mapping) + elif self.alltoall_method_type == AlltoallMethodType.DeepEP: + self.deep_ep_buffer = buffer_pool.get_buffer( + model_config.mapping) + self.deep_ep_buffer.reserve(hidden_size, dtype) + elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: + self.deep_ep_max_num_tokens = min(model_config.max_num_tokens, + self.moe_max_num_tokens) + self.deep_ep_buffer = buffer_pool.get_low_latency_buffer( + model_config.mapping) + self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens, + hidden_size, self.num_slots) + else: + raise NotImplementedError( + f"Not available alltoall method type: {alltoall_method_type!r}" + ) # If True, the router weight will be multiplied on the input rather than at the end of FC2 self.apply_router_weight_on_input = apply_router_weight_on_input @@ -215,12 +250,48 @@ def _check_configs(self): f"unsupported quantization mode: {self.quant_config.quant_mode}" ) + @staticmethod + def select_alltoall_method_type(mapping: Mapping, top_k: int, + dtype: torch.dtype, + use_cuda_graph: bool) -> AlltoallMethodType: + if not mapping.enable_attention_dp: + return AlltoallMethodType.NotEnabled + + if mapping.tp_size == 1: + return AlltoallMethodType.NotEnabled + + if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1": + return AlltoallMethodType.NotEnabled + + if mapping.moe_ep_size <= top_k: + return AlltoallMethodType.NotEnabled + + if MnnvlMemory.supports_mnnvl(): + return AlltoallMethodType.MNNVL + + if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1": + if deep_ep_installed and dtype == torch.bfloat16: + if use_cuda_graph: + # Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs. + return AlltoallMethodType.DeepEPLowLatency + else: + # Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster. 
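+                # DeepEPLowLatency would also be valid here, but it is only required when
+                # CUDA Graphs are enabled, so prefer the faster DeepEP path.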
+ return AlltoallMethodType.DeepEP + + return AlltoallMethodType.NotEnabled + @property def has_w4afp8(self): assert self._weights_created return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group( ) + @property + def enable_alltoall(self): + """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter + """ + return self.alltoall_method_type != AlltoallMethodType.NotEnabled + def _get_quant_method(self): if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -311,10 +382,6 @@ def forward_chunk( # TODO: remove this once we have correct fusedmoe kernel ready token_final_scales = None - token_count = x.shape[0] - - alltoall_info = None - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( ) and is_first_call: self.layer_load_balancer.maybe_cudagraph_done_wait() @@ -334,13 +401,58 @@ def forward_chunk( ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) token_selected_experts_for_statistic = token_selected_experts if need_statistic else None + if self.enable_alltoall: - x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \ - self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens, - x, - token_selected_slots, - token_final_scales, - token_selected_experts_for_statistic) + if self.alltoall_method_type == AlltoallMethodType.MNNVL: + token_count = x.shape[0] + alltoall_info = None + x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \ + self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens, + x, + token_selected_slots, + token_final_scales, + token_selected_experts_for_statistic) + elif self.alltoall_method_type == AlltoallMethodType.DeepEP: + if not self.use_postquant_alltoall: + x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \ + self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots) + elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: + if not self.use_postquant_alltoall: + deep_ep_topk_idx = token_selected_slots.to(torch.int64) + deep_ep_topk_weights = token_final_scales + x, recv_expert_count, deep_ep_handle = \ + self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots) + # x shape: [#local experts, #max recv tokens, hidden_size] + # recv_expert_count shape: [#local experts] + + # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP + # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API + mask = torch.arange( + x.shape[1], dtype=torch.int32, device=x.device).expand( + x.shape[0], + x.shape[1]) < recv_expert_count.unsqueeze(1) + token_selected_slots = torch.full( + (x.shape[0], x.shape[1], self.routing_method.top_k), + self.num_slots, + dtype=torch.int32, + device=x.device) + token_selected_slots[:, :, 0] = torch.where( + mask, + torch.arange( + x.shape[0] * self.mapping.moe_ep_rank, + x.shape[0] * (self.mapping.moe_ep_rank + 1), + dtype=torch.int32, + device=x.device).unsqueeze(1), self.num_slots) + x = x.view(x.shape[0] * x.shape[1], x.shape[2]) + token_selected_slots = token_selected_slots.view( + x.shape[0], self.routing_method.top_k) + token_final_scales = torch.ones_like( + token_selected_slots, dtype=token_final_scales.dtype) + else: + raise NotImplementedError( + f"Not available alltoall method type: {alltoall_method_type!r}" + ) + x_sf = None if 
self.has_any_quant: if self.has_fp8_qdq: @@ -414,8 +526,56 @@ def forward_chunk( quant_scales = self.quant_scales if self.use_postquant_alltoall: - x, x_sf = self.alltoall_postquant_dispatch(x, x_sf, x_row, x_col, - alltoall_info) + if self.alltoall_method_type == AlltoallMethodType.MNNVL: + x, x_sf = self.alltoall_postquant_dispatch( + x, x_sf, x_row, x_col, alltoall_info) + elif self.alltoall_method_type == AlltoallMethodType.DeepEP: + if x_sf is not None: + if self.has_nvfp4: + x_sf = unswizzle_sf(x_sf, x_row, x_col, + self.scaling_vector_size) + # Adapter between `x_sf` and DeepEP + # TODO: remove the adapter by adding dtype support to DeepEP + x_sf_dtype = x_sf.dtype + x_sf = x_sf.view(torch.float32) + (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \ + self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots) + if x_sf is not None: + x_sf = x_sf.view(x_sf_dtype) + if self.has_nvfp4: + x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, + self.scaling_vector_size) + elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: + raise NotImplementedError( + "Not implemented postquant for DeepEPLowLatency, please set TRTLLM_MOE_POST_QUANT_ALLTOALLV=0" + ) + else: + raise NotImplementedError( + f"Not available alltoall method type: {alltoall_method_type!r}" + ) + + if self.enable_alltoall: + # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP + # TODO: remove the adapter by changing APIs + if self.alltoall_method_type == AlltoallMethodType.DeepEP: + token_selected_slots = recv_topk_idx.to(torch.int32) + mask = token_selected_slots == -1 + token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank + token_selected_slots[mask] = self.num_slots + num_recv_token_is_zero = x.shape[0] == 0 + if x.shape[0] == 0: + x = torch.zeros((1, x.shape[1]), + dtype=x.dtype, + device=x.device) + token_selected_slots = torch.full( + (1, token_selected_slots.shape[1]), + self.num_slots, + dtype=token_selected_slots.dtype, + device=token_selected_slots.device) + token_final_scales = torch.ones( + (1, token_final_scales.shape[1]), + dtype=token_final_scales.dtype, + device=token_final_scales.device) final_hidden_states = torch.ops.trtllm.fused_moe( x, @@ -452,9 +612,25 @@ def forward_chunk( final_hidden_states = final_hidden_states[0] if self.enable_alltoall: - final_hidden_states = self.alltoall_combine(final_hidden_states, - alltoall_info, - token_count) + if self.alltoall_method_type == AlltoallMethodType.MNNVL: + final_hidden_states = self.alltoall_combine( + final_hidden_states, alltoall_info, token_count) + elif self.alltoall_method_type == AlltoallMethodType.DeepEP: + if num_recv_token_is_zero: + final_hidden_states = final_hidden_states[:0] + final_hidden_states = self.deep_ep_buffer.combine( + final_hidden_states, deep_ep_handle) + elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: + final_hidden_states = self.deep_ep_buffer.low_latency_combine( + final_hidden_states.view( + self.expert_size_per_partition, + self.deep_ep_max_num_tokens * self.mapping.moe_ep_size, + final_hidden_states.shape[1]), deep_ep_topk_idx, + deep_ep_topk_weights, deep_ep_handle) + else: + raise NotImplementedError( + f"Not available alltoall method type: {alltoall_method_type!r}" + ) if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( ) and is_last_call: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py 
b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 616141dfd04..52d0af71d3c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. - enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter MoE torch custom op: Only support min-latency mode now (SM100 Blackwell only). diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py index f87647ce511..3249bac979b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py @@ -89,8 +89,6 @@ def __init__( if model_config.moe_max_num_tokens is not None else max_num_tokens) - self.enable_alltoall = False - self._weights_created = False if not model_config.skip_create_weights_in_init: self.create_weights() @@ -458,7 +456,7 @@ def reducescatter_or_allreduce( use_dp_padding: Optional[bool] = None, ): outputs = inputs - if self.parallel_size > 1 and not self.enable_alltoall: + if self.parallel_size > 1: if self.use_dp: outputs = reducescatter( inputs, diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index d305a3b763e..b90cdbe3001 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -27,7 +27,6 @@ class MoE(nn.Module): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. - enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter """ def __init__( @@ -123,3 +122,9 @@ def has_nvfp4(self): assert self._weights_created return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4( ) + + @property + def enable_alltoall(self): + """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter + """ + return False diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index f6a3d1e420b..acf5f037ec3 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1003,6 +1003,7 @@ def _load_model(self, checkpoint_dir, trust_remote_code=True, enable_min_latency=self.pytorch_backend_config.enable_min_latency, + use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, spec_config=self.spec_config, max_num_tokens=max_num_tokens, moe_max_num_tokens=moe_max_num_tokens, diff --git a/tensorrt_llm/logger.py b/tensorrt_llm/logger.py index 1773682d3d9..1229b7a198d 100644 --- a/tensorrt_llm/logger.py +++ b/tensorrt_llm/logger.py @@ -75,6 +75,9 @@ def __init__(self): self._polygraphy_logger.module_severity = severity_map[ min_severity][2] + # For log_once + self._appeared_keys = set() + if invalid_severity: self.warning( f"Requested log level {environ_severity} is invalid. 
Using '{self.DEFAULT_LEVEL}' instead" @@ -109,23 +112,44 @@ def log(self, severity, *msg): parts.extend(map(str, msg)) self._func_wrapper(severity)(" ".join(parts)) + def log_once(self, severity, *msg, key): + if key not in self._appeared_keys: + self._appeared_keys.add(key) + self.log(severity, *msg) + def critical(self, *msg): self.log(self.INTERNAL_ERROR, *msg) + def critical_once(self, *msg, key): + self.log_once(self.INTERNAL_ERROR, *msg, key=key) + fatal = critical + fatal_once = critical_once def error(self, *msg): self.log(self.ERROR, *msg) + def error_once(self, *msg, key): + self.log_once(self.ERROR, *msg, key=key) + def warning(self, *msg): self.log(self.WARNING, *msg) + def warning_once(self, *msg, key): + self.log_once(self.WARNING, *msg, key=key) + def info(self, *msg): self.log(self.INFO, *msg) + def info_once(self, *msg, key): + self.log_once(self.INFO, *msg, key=key) + def debug(self, *msg): self.log(self.VERBOSE, *msg) + def debug_once(self, *msg, key): + self.log_once(self.VERBOSE, *msg, key=key) + @property def level(self) -> str: return self._min_severity diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index c827b5a5657..1d1785282c2 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -54,6 +54,8 @@ l0_dgx_h100: auto_trigger: deepseek tests: - unittest/_torch/multi_gpu_modeling -k "deepseek" + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP] + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 49aac4ad717..074a25e901a 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -2,6 +2,7 @@ import sys from itertools import product from typing import Dict, List, Optional +from unittest import mock import cloudpickle import pytest @@ -19,6 +20,8 @@ DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod, VanillaMoE) +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import \ + AlltoallMethodType from tensorrt_llm._torch.modules.gated_mlp import GatedMLP from tensorrt_llm._utils import mpi_rank from tensorrt_llm.mapping import Mapping @@ -123,6 +126,111 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size): assert r is None +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize("alltoall_method_type", [ + AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP, + AlltoallMethodType.DeepEPLowLatency +], + ids=lambda s: s.name) +def test_fused_moe_alltoall(alltoall_method_type): + world_size = 4 + dtype = torch.bfloat16 + HIDDEN_SIZE = 2560 + INTERMEDIATE_SIZE = 1536 + NUM_EXPERTS = 72 + TOP_K = 6 + MAX_NUM_TOKENS = 2048 + + def 
per_rank_test_fused_moe_alltoall(job_id): + routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) + mapping = Mapping(world_size=world_size, + rank=mpi_rank(), + tp_size=world_size, + moe_ep_size=world_size, + moe_tp_size=1, + enable_attention_dp=True) + torch.cuda.set_device(mapping.rank) + torch.manual_seed(mapping.rank) + + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype) + w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype) + w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype) + torch.nn.init.xavier_uniform_(w1_weight) + torch.nn.init.xavier_uniform_(w2_weight) + torch.nn.init.xavier_uniform_(w3_weight) + weights[f"{expert_id}.w1.weight"] = w1_weight + weights[f"{expert_id}.w2.weight"] = w2_weight + weights[f"{expert_id}.w3.weight"] = w3_weight + with mock.patch.object(CutlassFusedMoE, + "select_alltoall_method_type", + return_value=alltoall_method_type): + alltoall_model = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS), + ) + alltoall_model.to("cuda") + alltoall_model.load_weights([weights]) + with mock.patch.object(CutlassFusedMoE, + "select_alltoall_method_type", + return_value=AlltoallMethodType.NotEnabled): + ref_model = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS), + ) + ref_model.to("cuda") + ref_model.load_weights([weights]) + + # Evaluate the outputs on a variant sequence length to verify the robustness of alltoall methods + m = MAX_NUM_TOKENS + while m >= 1: + x = torch.randn((m, HIDDEN_SIZE), dtype=dtype).cuda() + router_logits = torch.randn((m, NUM_EXPERTS), dtype=dtype).cuda() + all_rank_num_tokens = [m] * mapping.world_size + + with torch.inference_mode(): + output = alltoall_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=False) + ref_output = ref_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=False) + + # Evaluate outputs + torch.testing.assert_close(output, + ref_output, + rtol=0.05, + atol=0.003) + m //= 2 + + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map(per_rank_test_fused_moe_alltoall, + range(world_size)) + for r in results: + assert r is None + + @skip_pre_hopper @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_moe_fp8(dtype):