From 4f5ae8d9113c3e048cb413d4a86437f7e2535b17 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 24 Sep 2025 21:03:57 +0000 Subject: [PATCH 1/8] Hybrid DeepEP Signed-off-by: Bill Nell --- .../device_communicators/all2all.py | 28 +- .../device_communicators/cuda_communicator.py | 4 + .../model_executor/layers/fused_moe/config.py | 9 + .../deepep_hybrid_prepare_finalze.py | 337 ++++++++++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 32 +- 5 files changed, 406 insertions(+), 4 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 661ed939608a..95c5d2bffb51 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -248,6 +248,32 @@ def set_num_sms(self, num_sms: int): deep_ep.Buffer.set_num_sms(num_sms) +class DeepEPHybridAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP Hybrid kernels. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs(self, kwargs) -> dict[Any, Any]: + extra_kwargs = dict(group=self.cpu_group, + num_of_ranks_per_node = 32, + num_sms_preprocessing_api = 32, + num_sms_dispatch_api = 32, + num_sms_combine_api = 32 + ) + return {**kwargs, **extra_kwargs} + + def get_handle(self, kwargs): + import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.HybridEpBuffer) + return handle + + class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): """ All2All communication based on DeepEP Low-Latency kernels. 
@@ -395,4 +421,4 @@ def cleanup(self): self.workspace_tensor = None self.prepare_workspace_tensor = None self.mapping = None - self.initialized = False \ No newline at end of file + self.initialized = False diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index bab372b722db..27b296f40ebd 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -114,6 +114,10 @@ def __init__(self, from .all2all import DeepEPLLAll2AllManager self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) logger.info("Using DeepEP Low-Latency all2all manager.") + elif all2all_backend == "deepep_hybrid": + from .all2all import DeepEPHybridAll2AllManager + self.all2all_manager = DeepEPHybridAll2AllManager(self.cpu_group) + logger.info("Using DeepEP Hybrid all2all manager.") elif all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager self.all2all_manager = FlashInferAllToAllManager( diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 34bfe1c16aac..cdb314ead8c3 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -618,6 +618,11 @@ def use_deepep_ll_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + @property + def use_deepep_hybrid_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_hybrid") + @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -794,6 +799,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_deepep_hybrid_kernels(self): + return self.moe_parallel_config.use_deepep_hybrid_kernels + @property def use_flashinfer_cutlass_kernels(self): """ diff --git a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py new file mode 100644 index 000000000000..df84b78d2e77 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional, Union + +import deep_ep +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) +from vllm.utils import round_up +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm, + dbo_switch_to_compute, dbo_switch_to_compute_sync, + dbo_yield_and_switch_from_comm_to_compute, + dbo_yield_and_switch_from_compute_to_comm) + + +class DeepEPHybridPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using DeepEP High-Throughput kernels. + """ + + @staticmethod + def maybe_roundup_layer_hidden_size(hidden_size: int, + dtype: torch.dtype) -> int: + # Round up hidden size so it is compatible with DeepEP High Throughput + # kernels. 
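For reference, the hidden-size rounding rule introduced here (and spelled out in the comment that continues in the hunk below) can be checked with a small standalone sketch. It assumes round_up(x, m) rounds x up to the next multiple of m, matching vllm.utils.round_up:

import torch

def round_up(x: int, m: int) -> int:
    # Assumed to match vllm.utils.round_up: smallest multiple of m >= x.
    return ((x + m - 1) // m) * m

def roundup_hidden_size(hidden_size: int, dtype: torch.dtype) -> int:
    hidden_size_bytes = hidden_size * dtype.itemsize
    xfer_atom_size = 512  # 32 (warp size) * 16 bytes (sizeof(int4))
    if hidden_size_bytes % xfer_atom_size == 0:
        return hidden_size
    return round_up(hidden_size_bytes, xfer_atom_size) // dtype.itemsize

# 2880 bf16 elements -> 5760 bytes -> rounded up to 6144 bytes -> 3072 elements
assert roundup_hidden_size(2880, torch.bfloat16) == 3072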
+ # DeepEP intranode kernels make copies in units of, + # 32(warp-size) int4 elements. Round up hidden size to respect this. + # For example, an input hidden size of 2880 with dtype torch.bfloat16 + # will be rounded up to 3072. + hidden_size_bytes = hidden_size * dtype.itemsize + xfer_atom_size = 512 # 32 * 16 (size(int4)) + if hidden_size_bytes % xfer_atom_size == 0: + return hidden_size + + hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size) + return hidden_size_bytes // dtype.itemsize + + def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, + dp_size: int, rank_expert_offset: int): + super().__init__() + self.buffer = buffer + self.num_dispatchers_ = num_dispatchers + self.dp_size = dp_size + self.rank_expert_offset = rank_expert_offset + self.async_prepare = True + + # The dispatch function returns a handle that the combine function + # requires. Under DBO microbatching we must track one handle per + # micro-batch to avoid races between threads. + self.handles = [None, None] + + # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 + self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.int64 + + def _get_dispatch_config(self) -> Optional[deep_ep.Config]: + if self.num_dispatchers_ not in self.available_rank_configs: + return None + return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_) + + def _get_combine_config(self) -> Optional[deep_ep.Config]: + if self.num_dispatchers_ not in self.available_rank_configs: + return None + return deep_ep.Buffer.get_combine_config(self.num_dispatchers_) + + def _do_dispatch( + self, + tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, + num_experts: int, + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> Callable: + + has_scales = token_scales is not None + + # We yield before launching the dispatch kernel since the dispatch + # kernel will block the CPU so we want to queue up all the compute + # for the other ubatch before the dispatch kernel starts. + dbo_yield_and_switch_from_compute_to_comm() + + (num_tokens_per_rank, num_tokens_per_rdma_rank, + dispatch_expert_num_tokens, is_token_in_rank, + event) = self.buffer.get_dispatch_layout( + topk_idx=rank_topk_ids, + num_experts=num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + + # dispatched_token, + # dispatched_probs, + # dispatched_scaling_factor, + # num_of_tokens_for_experts_tensor, + # local_expert_routing_map, + # handle, + + ( + token_data, expert_probs, token_scales, + expert_num_tokens_per_expert, local_expert_routing_map, + handle + ) = self.buffer.dispatch( + tokens=tokens, + scaling_factor=token_scales, + topk_idx=rank_topk_ids, + topk_weights=rank_topk_weights, + routing_map=None, # None = generated dynamically + handle=None, + num_of_tokens_for_experts=-1, #?? 
+ async_mode=self.async_prepare and not dbo_enabled(), + ) + + # record the handle for this ubatch + a2a_idx = dbo_current_ubatch_id() + self.handles[a2a_idx] = handle + + dbo_switch_to_compute_sync() + + return lambda: self._receiver( + event, + token_data, + token_scales, + expert_topk_ids, + num_experts, + expert_num_tokens_per_expert_list, + expert_topk_weights, + a1_scale, + quant_config, + ) + + def _receiver( + self, + event: deep_ep.EventOverlap, + has_scales: bool, + token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], + expert_topk_ids: Optional[torch.Tensor], + num_experts: int, + expert_num_tokens_per_expert_list: list[int], + expert_topk_weights: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if event.event is not None: + event.current_stream_wait() + + if has_scales: + expert_x, expert_x_scale = token_data + else: + expert_x, expert_x_scale = token_data, None + + # The existing MOE kernels assume that all entries of topk_ids are + # valid. To that effect, set the -1s in expert_topk_ids to some expert + # outside this rank so the expert_map can remap it to -1 when safe. + # With Expert Parallel, the experts are divided amongst the rank + # sequentially. For rank 0, set it to num_experts - 1 and for all other + # ranks set it to 0 as we know that expert_map will have a -1 in those + # regions for those ranks. + # + # DeepEP's topk_ids output refers to the local experts directly. Offset + # the topk_ids to move it back to the global experts space so it aligns + # with existing vLLM interfaces. + assert expert_topk_ids is not None + expert_topk_ids = torch.where( + expert_topk_ids == -1, + num_experts - 1 if self.rank_expert_offset == 0 else 0, + expert_topk_ids + self.rank_expert_offset) + + # Makes a GPU-CPU copy. + # TODO (varun): Maybe it is better to re-compute the expert_num_tokens + # on GPU. + expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( + expert_num_tokens_per_expert_list, device=expert_x.device) + + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: + # Quantize after dispatch. 
+ expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) + + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, + expert_topk_weights) + + def supports_async(self) -> bool: + return False # combine async not supported + + def prepare_async( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.ReceiverType: + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + if quant_config.is_block_quantized: + # Quant and Dispatch + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) + a1_post_scale = None + else: + a1q = a1 + a1q_scale = None + a1_post_scale = quant_config.a1_scale + + return self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config) + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts, + expert_map, apply_router_weight_on_input, + quant_config) + return receiver() + + def _finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + do_async: bool, + ) -> Optional[Callable]: + handle = self.handle + assert handle is not None + + # fused_expert_output can have 0 tokens - This happens when none of the + # tokens from the all2all reach this EP rank. + if fused_expert_output.numel() != 0: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + fused_expert_output = weight_and_reduce_impl.apply( + output=None, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + dbo_yield_and_switch_from_compute_to_comm() + combined_x, _, event = self.buffer.combine( + tensor=fused_expert_output, + probs=probs, + handle=handle, + ) + + # TODO(lucas): support this case with the refactored modular kernel + # Respect inplace outputs. + # apply weights??? 
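The "# apply weights???" question above concerns the top-k weight-and-reduce that finalize is expected to perform before (or while) combining. Independent of the DeepEP combine API, the reduction itself is just a router-probability-weighted sum over each token's top-k expert outputs; a minimal dense sketch with made-up shapes (conceptually what the TopKWeightAndReduce implementations compute for the contiguous layout):

import torch

num_tokens, topk, hidden = 4, 2, 8
fused_expert_output = torch.randn(num_tokens, topk, hidden)
topk_weights = torch.softmax(torch.randn(num_tokens, topk), dim=-1)

# Each token's final hidden state is the probability-weighted sum of its
# top-k expert outputs.
reduced = (fused_expert_output * topk_weights.unsqueeze(-1)).sum(dim=1)
assert reduced.shape == (num_tokens, hidden)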
+ output.copy_(combined_x, non_blocking=True) + return None + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + receiver = self._finalize(output, fused_expert_output, topk_weights, + topk_ids, apply_router_weight_on_input, + weight_and_reduce_impl, True) + assert receiver is not None + return receiver + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + self._finalize(output, fused_expert_output, topk_weights, topk_ids, + apply_router_weight_on_input, weight_and_reduce_impl, + False) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b68190e5d1c1..0ea5bd222ba3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -50,6 +50,7 @@ pplx_hidden_dim_scale_bytes) if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize + from .deepep_hybrid_prepare_finalize import DeepEPHybridPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) else: @@ -203,6 +204,27 @@ def _maybe_make_prepare_finalize( num_dispatchers=all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, ) + elif moe.use_deepep_hybrid_kernels: + assert moe.dp_size == all2all_manager.dp_world_size + + use_fp8 = quant_config.use_fp8_w8a8 if quant_config is not None else False + + all_to_all_args = dict( + hidden_dim=moe.hidden_dim, + max_num_of_tokens_per_dp_rank=moe.max_num_tokens, + num_local_experts=(moe.num_experts // all2all_manager.world_size), + num_experts=moe.num_experts, + use_fp8=use_fp8, + ) + + handle = all2all_manager.get_handle(all_to_all_args) + prepare_finalize = DeepEPHybridPrepareAndFinalize( + handle, + num_dispatchers=all2all_manager.world_size, + dp_size=all2all_manager.dp_world_size, + rank_expert_offset=all2all_manager.rank * + moe.num_local_experts, + ) return prepare_finalize @@ -1145,6 +1167,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_deepep_hybrid_kernels(self): + return self.moe_parallel_config.use_deepep_hybrid_kernels + @property def use_flashinfer_cutlass_kernels(self): return (self.moe_quant_config is not None @@ -1693,7 +1719,7 @@ def must_reduce_shared_expert_outputs(self) -> bool: Therefore it is required that we reduce the shared_experts output early. """ - return (self.use_pplx_kernels or self.use_deepep_ht_kernels + return (self.use_pplx_kernels or self.use_deepep_ht_kernels or self.use_deepep_hybrid_kernels or self.use_deepep_ll_kernels) def maybe_all_reduce_tensor_model_parallel( @@ -1701,8 +1727,7 @@ def maybe_all_reduce_tensor_model_parallel( """ The pplx combine kernel reduces across GPU ranks by default. 
""" - if (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels): + if self.must_reduce_shared_expert_outputs(): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -1895,6 +1920,7 @@ def forward_impl( do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_parallel_config.use_deepep_hybrid_kernels and not self.moe_config.use_flashinfer_cutlass_kernels) # If there are shared experts but we are not using a modular kernel, the From af7c901f40ebb2424e0c2e7bd9db5b0bf042e2bb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 25 Sep 2025 19:08:33 -0400 Subject: [PATCH 2/8] wip Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/mk_objects.py | 10 + .../deepep_hybrid_prepare_finalze.py | 251 ++++-------------- .../layers/fused_moe/modular_kernel.py | 16 +- 3 files changed, 71 insertions(+), 206 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 57a1da7b4b1a..38db3268013d 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -189,6 +189,8 @@ def expert_info(kind) -> ExpertInfo: DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_hybrid_prepare_finalize import ( # noqa: E501 + DeepEPHybridPrepareAndFinalize) register_prepare_and_finalize( DeepEPHTPrepareAndFinalize, @@ -206,6 +208,14 @@ def expert_info(kind) -> ExpertInfo: backend="deepep_low_latency", ) + register_prepare_and_finalize( + DeepEPHybridPrepareAndFinalize, + batched_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_hybrid", + ) + if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) diff --git a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py index df84b78d2e77..e0954d12ba3c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py +++ b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py @@ -41,19 +41,14 @@ def maybe_roundup_layer_hidden_size(hidden_size: int, hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size) return hidden_size_bytes // dtype.itemsize - def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, + def __init__(self, buffer: deep_ep.HybridBuffer, num_dispatchers: int, dp_size: int, rank_expert_offset: int): super().__init__() self.buffer = buffer self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset - self.async_prepare = True - - # The dispatch function returns a handle that the combine function - # requires. Under DBO microbatching we must track one handle per - # micro-batch to avoid races between threads. 
- self.handles = [None, None] + self.handle = None # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] @@ -81,138 +76,10 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.num_dispatchers_) - def _do_dispatch( - self, - tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, - num_experts: int, - a1_scale: Optional[torch.Tensor], - quant_config: FusedMoEQuantConfig, - ) -> Callable: - - has_scales = token_scales is not None - - # We yield before launching the dispatch kernel since the dispatch - # kernel will block the CPU so we want to queue up all the compute - # for the other ubatch before the dispatch kernel starts. - dbo_yield_and_switch_from_compute_to_comm() - - (num_tokens_per_rank, num_tokens_per_rdma_rank, - dispatch_expert_num_tokens, is_token_in_rank, - event) = self.buffer.get_dispatch_layout( - topk_idx=rank_topk_ids, - num_experts=num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False) - - # dispatched_token, - # dispatched_probs, - # dispatched_scaling_factor, - # num_of_tokens_for_experts_tensor, - # local_expert_routing_map, - # handle, - - ( - token_data, expert_probs, token_scales, - expert_num_tokens_per_expert, local_expert_routing_map, - handle - ) = self.buffer.dispatch( - tokens=tokens, - scaling_factor=token_scales, - topk_idx=rank_topk_ids, - topk_weights=rank_topk_weights, - routing_map=None, # None = generated dynamically - handle=None, - num_of_tokens_for_experts=-1, #?? - async_mode=self.async_prepare and not dbo_enabled(), - ) - - # record the handle for this ubatch - a2a_idx = dbo_current_ubatch_id() - self.handles[a2a_idx] = handle - - dbo_switch_to_compute_sync() - - return lambda: self._receiver( - event, - token_data, - token_scales, - expert_topk_ids, - num_experts, - expert_num_tokens_per_expert_list, - expert_topk_weights, - a1_scale, - quant_config, - ) - - def _receiver( - self, - event: deep_ep.EventOverlap, - has_scales: bool, - token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], - expert_topk_ids: Optional[torch.Tensor], - num_experts: int, - expert_num_tokens_per_expert_list: list[int], - expert_topk_weights: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - if event.event is not None: - event.current_stream_wait() - - if has_scales: - expert_x, expert_x_scale = token_data - else: - expert_x, expert_x_scale = token_data, None - - # The existing MOE kernels assume that all entries of topk_ids are - # valid. To that effect, set the -1s in expert_topk_ids to some expert - # outside this rank so the expert_map can remap it to -1 when safe. - # With Expert Parallel, the experts are divided amongst the rank - # sequentially. For rank 0, set it to num_experts - 1 and for all other - # ranks set it to 0 as we know that expert_map will have a -1 in those - # regions for those ranks. - # - # DeepEP's topk_ids output refers to the local experts directly. Offset - # the topk_ids to move it back to the global experts space so it aligns - # with existing vLLM interfaces. 
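The removed comment above (the same logic remains in the high-throughput path) describes two fixups applied to DeepEP's per-rank topk ids: invalid entries (-1) are pointed at an expert owned by some other rank so that expert_map can translate them back to -1, and local expert indices are shifted by rank_expert_offset back into the global expert space. A small worked example with made-up numbers:

import torch

num_experts = 8          # global expert count
rank_expert_offset = 4   # this EP rank owns global experts 4..7

# -1 marks a slot whose token was not routed to any local expert.
local_topk_ids = torch.tensor([[0, 3],
                               [-1, 1]])

global_topk_ids = torch.where(
    local_topk_ids == -1,
    # Any expert this rank does NOT own works; rank 0 uses num_experts - 1,
    # every other rank uses 0.
    torch.tensor(num_experts - 1 if rank_expert_offset == 0 else 0),
    local_topk_ids + rank_expert_offset)

# tensor([[4, 7],
#         [0, 5]])
print(global_topk_ids)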
- assert expert_topk_ids is not None - expert_topk_ids = torch.where( - expert_topk_ids == -1, - num_experts - 1 if self.rank_expert_offset == 0 else 0, - expert_topk_ids + self.rank_expert_offset) - - # Makes a GPU-CPU copy. - # TODO (varun): Maybe it is better to re-compute the expert_num_tokens - # on GPU. - expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( - expert_num_tokens_per_expert_list, device=expert_x.device) - - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 and quantize afterwards - if not quant_config.is_block_quantized: - # Quantize after dispatch. - expert_x_scale = None - if expert_x.numel() != 0: - expert_x, expert_x_scale = moe_kernel_quantize_input( - expert_x, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=False, - block_shape=quant_config.block_shape) - - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) - def supports_async(self) -> bool: return False # combine async not supported - def prepare_async( + def prepare( self, a1: torch.Tensor, topk_weights: torch.Tensor, @@ -221,7 +88,7 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.ReceiverType: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -247,30 +114,51 @@ def prepare_async( a1q_scale = None a1_post_scale = quant_config.a1_scale - return self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config) + ( + expert_x, expert_probs, expert_x_scale, + num_tokens_per_expert, local_expert_routing_map, + self.handle + ) = self.buffer.dispatch( + tokens=a1, + scaling_factor=a1q_scale, + topk_idx=topk_ids, + topk_weights=topk_weights, + routing_map=None, # None = generated dynamically + handle=None, + num_of_tokens_for_experts=-1, #?? + async_mode=False, + ) + + # Makes a GPU-CPU copy. + # TODO (varun): Maybe it is better to re-compute the expert_num_tokens + # on GPU. + expert_tokens_meta = mk.ExpertTokensMetadata( + num_tokens_per_expert, + None, #? num_tokens_per_expert.cpu(), + ) - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts, - expert_map, apply_router_weight_on_input, - quant_config) - return receiver() + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: + # Quantize after dispatch. 
+ expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_post_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) + + self.expert_probs = expert_probs + + return (expert_x, expert_x_scale, expert_tokens_meta, + topk_ids[local_expert_routing_map], + expert_probs) - def _finalize( + def finalize( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -278,14 +166,10 @@ def _finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - do_async: bool, - ) -> Optional[Callable]: - handle = self.handle - assert handle is not None - + ) -> None: # fused_expert_output can have 0 tokens - This happens when none of the # tokens from the all2all reach this EP rank. - if fused_expert_output.numel() != 0: + if False and fused_expert_output.numel() != 0: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceContiguous() fused_expert_output = weight_and_reduce_impl.apply( @@ -295,43 +179,14 @@ def _finalize( topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, ) - dbo_yield_and_switch_from_compute_to_comm() - combined_x, _, event = self.buffer.combine( + + combined_x, _ = self.buffer.combine( tensor=fused_expert_output, - probs=probs, - handle=handle, + probs=self.expert_probs, + handle=self.handle, ) # TODO(lucas): support this case with the refactored modular kernel # Respect inplace outputs. # apply weights??? output.copy_(combined_x, non_blocking=True) - return None - - def finalize_async( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> Callable: - receiver = self._finalize(output, fused_expert_output, topk_weights, - topk_ids, apply_router_weight_on_input, - weight_and_reduce_impl, True) - assert receiver is not None - return receiver - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - self._finalize(output, fused_expert_output, topk_weights, topk_ids, - apply_router_weight_on_input, weight_and_reduce_impl, - False) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 4ba14196682a..78e17796c460 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -240,16 +240,16 @@ def prepare_async( - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - Returns a callback or a hook callback pair that when invoked waits for - results from other workers and has the same return signature as + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as `prepare`, if a hook is returned this is more lightweight check that - the recv is complete without doing extra work (used by DBO, will be + the recv is complete without doing extra work (used by DBO, will be refactored in the very near future) - + e.g. ret = obj.prepare_async(...) 
- + if isinstance(ret, tuple): hook, receiver = ret hook() @@ -310,10 +310,10 @@ def finalize_async( - weight_and_reduce_impl: An optional TopKWeightAndReduce implementation. - Returns a callback or a hook callback pair that when invoked waits for - results from other workers and has the same return signature as + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as `finalize`, if a hook is returned this is more lightweight check that - the recv is complete without doing extra work (used by DBO, will be + the recv is complete without doing extra work (used by DBO, will be refactored in the very near future) ret = obj.finalize_async(output, ...) From c7fc685abd650260f9d93a6c25ec2e768a04ca4d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 25 Sep 2025 19:27:12 -0400 Subject: [PATCH 3/8] wip Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/mk_objects.py | 32 +++++++++---------- ...e.py => deepep_hybrid_prepare_finalize.py} | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) rename vllm/model_executor/layers/fused_moe/{deepep_hybrid_prepare_finalze.py => deepep_hybrid_prepare_finalize.py} (98%) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 38db3268013d..ecfeee713223 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -192,21 +192,21 @@ def expert_info(kind) -> ExpertInfo: from vllm.model_executor.layers.fused_moe.deepep_hybrid_prepare_finalize import ( # noqa: E501 DeepEPHybridPrepareAndFinalize) - register_prepare_and_finalize( - DeepEPHTPrepareAndFinalize, - standard_format, - common_float_types, - blocked_quantization_support=True, - backend="deepep_high_throughput", - ) - - register_prepare_and_finalize( - DeepEPLLPrepareAndFinalize, - batched_format, - common_float_types, - blocked_quantization_support=True, - backend="deepep_low_latency", - ) + # register_prepare_and_finalize( + # DeepEPHTPrepareAndFinalize, + # standard_format, + # common_float_types, + # blocked_quantization_support=True, + # backend="deepep_high_throughput", + # ) + + # register_prepare_and_finalize( + # DeepEPLLPrepareAndFinalize, + # batched_format, + # common_float_types, + # blocked_quantization_support=True, + # backend="deepep_low_latency", + # ) register_prepare_and_finalize( DeepEPHybridPrepareAndFinalize, @@ -216,7 +216,7 @@ def expert_info(kind) -> ExpertInfo: backend="deepep_hybrid", ) -if has_pplx(): +if False and has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) register_prepare_and_finalize( diff --git a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py similarity index 98% rename from vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py rename to vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py index e0954d12ba3c..988a75168f56 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalze.py +++ b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py @@ -41,7 +41,7 @@ def maybe_roundup_layer_hidden_size(hidden_size: int, hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size) return hidden_size_bytes // dtype.itemsize - def __init__(self, buffer: deep_ep.HybridBuffer, num_dispatchers: int, + def __init__(self, buffer: 
deep_ep.HybridEpBuffer, num_dispatchers: int, dp_size: int, rank_expert_offset: int): super().__init__() self.buffer = buffer From 4da863f79691e50038c372219025507aa46e167e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 26 Sep 2025 18:25:24 +0000 Subject: [PATCH 4/8] fixes Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/mk_objects.py | 16 ++++++++-------- vllm/distributed/device_communicators/all2all.py | 8 ++++---- vllm/envs.py | 10 ++++++---- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index ecfeee713223..e62e49459753 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -146,13 +146,13 @@ def expert_info(kind) -> ExpertInfo: return info -register_prepare_and_finalize( - MoEPrepareAndFinalizeNoEP, - standard_format, - common_float_types, - blocked_quantization_support=True, - backend=None, -) +# register_prepare_and_finalize( +# MoEPrepareAndFinalizeNoEP, +# standard_format, +# common_float_types, +# blocked_quantization_support=True, +# backend=None, +# ) register_experts( BatchedTritonExperts, @@ -227,7 +227,7 @@ def expert_info(kind) -> ExpertInfo: backend="pplx", ) -if (has_flashinfer_cutlass_fused_moe() +if False and (has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100)): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 FlashInferExperts) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 95c5d2bffb51..17eace5f0d73 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -256,12 +256,12 @@ class DeepEPHybridAll2AllManager(DeepEPAll2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) - def _make_all2all_kwargs(self, kwargs) -> dict[Any, Any]: + def _make_all2all_kwargs(self, **kwargs) -> dict[Any, Any]: extra_kwargs = dict(group=self.cpu_group, - num_of_ranks_per_node = 32, - num_sms_preprocessing_api = 32, num_sms_dispatch_api = 32, - num_sms_combine_api = 32 + num_sms_combine_api = 32, + num_sms_preprocessing_api = 128, + nvlink_domain_size = None, ) return {**kwargs, **extra_kwargs} diff --git a/vllm/envs.py b/vllm/envs.py index 4797d96bb899..220c97268b46 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -156,6 +156,7 @@ VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput", "deepep_low_latency", + "deepep_hybrid", "allgather_reducescatter", "flashinfer_all2allv"] = \ "allgather_reducescatter" @@ -1214,10 +1215,11 @@ def get_vllm_port() -> Optional[int]: "VLLM_ALL2ALL_BACKEND": env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter", ["naive", "pplx", - "deepep_high_throughput", - "deepep_low_latency", - "allgather_reducescatter", - "flashinfer_all2allv"]), + "deepep_high_throughput", + "deepep_low_latency", + "deepep_hybrid", + "allgather_reducescatter", + "flashinfer_all2allv"]), # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. # Both require compute capability 10.0 or above. 
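With the envs.py change above, the hybrid backend is selected the same way as the existing all2all backends, through VLLM_ALL2ALL_BACKEND. A hedged usage sketch; it assumes a DeepEP build that provides HybridEpBuffer and a deployment with expert parallelism enabled (for example data parallelism plus --enable-expert-parallel):

import os

# Must be set before vLLM builds its device communicators; the value is
# validated against the choices registered in vllm/envs.py above.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_hybrid"

# Equivalently, from the shell:
#   VLLM_ALL2ALL_BACKEND=deepep_hybrid vllm serve <model> \
#       --data-parallel-size 2 --enable-expert-parallel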
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0ea5bd222ba3..63c2296beec2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -211,9 +211,9 @@ def _maybe_make_prepare_finalize( all_to_all_args = dict( hidden_dim=moe.hidden_dim, - max_num_of_tokens_per_dp_rank=moe.max_num_tokens, + max_num_of_tokens_per_rank=moe.max_num_tokens, num_local_experts=(moe.num_experts // all2all_manager.world_size), - num_experts=moe.num_experts, + num_of_experts=moe.num_experts, use_fp8=use_fp8, ) From a6eeb4b48e6847cabea470213266e66e14a28732 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 26 Sep 2025 20:18:22 +0000 Subject: [PATCH 5/8] hacking Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/mk_objects.py | 2 +- .../device_communicators/all2all.py | 5 +++-- .../deepep_hybrid_prepare_finalize.py | 22 +++++-------------- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index e62e49459753..85febb8a712a 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -210,7 +210,7 @@ def expert_info(kind) -> ExpertInfo: register_prepare_and_finalize( DeepEPHybridPrepareAndFinalize, - batched_format, + standard_format, common_float_types, blocked_quantization_support=True, backend="deepep_hybrid", diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 17eace5f0d73..4b0e2b676e6b 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -261,16 +261,17 @@ def _make_all2all_kwargs(self, **kwargs) -> dict[Any, Any]: num_sms_dispatch_api = 32, num_sms_combine_api = 32, num_sms_preprocessing_api = 128, - nvlink_domain_size = None, + nvlink_domain_size = 2, # hack for now. dp world_size ) return {**kwargs, **extra_kwargs} def get_handle(self, kwargs): import deep_ep buffer_kwargs = self._make_all2all_kwargs(**kwargs) - logger.debug("DeepEP all2all args %s", buffer_kwargs) + logger.debug("DeepEP Hybrid all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.HybridEpBuffer) + logger.debug("DeepEP Hybrid constructed.") return handle diff --git a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py index 988a75168f56..0be954cb5e48 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py @@ -114,28 +114,20 @@ def prepare( a1q_scale = None a1_post_scale = quant_config.a1_scale + self.handle = None ( - expert_x, expert_probs, expert_x_scale, - num_tokens_per_expert, local_expert_routing_map, - self.handle + expert_x, expert_probs, expert_x_scale, _, _, self.handle ) = self.buffer.dispatch( - tokens=a1, + tensor=a1, scaling_factor=a1q_scale, topk_idx=topk_ids, topk_weights=topk_weights, routing_map=None, # None = generated dynamically handle=None, num_of_tokens_for_experts=-1, #?? - async_mode=False, ) - # Makes a GPU-CPU copy. - # TODO (varun): Maybe it is better to re-compute the expert_num_tokens - # on GPU. - expert_tokens_meta = mk.ExpertTokensMetadata( - num_tokens_per_expert, - None, #? 
num_tokens_per_expert.cpu(), - ) + expert_tokens_meta = None # Dispatch and Quant # DeepEP kernels only support dispatching block-quantized @@ -154,9 +146,7 @@ def prepare( self.expert_probs = expert_probs - return (expert_x, expert_x_scale, expert_tokens_meta, - topk_ids[local_expert_routing_map], - expert_probs) + return (expert_x, expert_x_scale, expert_tokens_meta, None, None) def finalize( self, @@ -182,7 +172,7 @@ def finalize( combined_x, _ = self.buffer.combine( tensor=fused_expert_output, - probs=self.expert_probs, + probs=self.expert_probs, # None? handle=self.handle, ) From 973bdf01b36517d10082c353b4e0b7d387558dda Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 26 Sep 2025 20:47:49 +0000 Subject: [PATCH 6/8] wip Signed-off-by: Bill Nell --- .../layers/fused_moe/deepep_hybrid_prepare_finalize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py index 0be954cb5e48..c59129784b06 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py @@ -49,6 +49,7 @@ def __init__(self, buffer: deep_ep.HybridEpBuffer, num_dispatchers: int, self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset self.handle = None + self.expert_probs = None # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] @@ -114,9 +115,8 @@ def prepare( a1q_scale = None a1_post_scale = quant_config.a1_scale - self.handle = None ( - expert_x, expert_probs, expert_x_scale, _, _, self.handle + expert_x, expert_probs, expert_x_scale, handle ) = self.buffer.dispatch( tensor=a1, scaling_factor=a1q_scale, @@ -126,7 +126,7 @@ def prepare( handle=None, num_of_tokens_for_experts=-1, #?? 
) - + self.handle = handle expert_tokens_meta = None # Dispatch and Quant From b966d04159fb9c4f779ea6451c77f791910d0b7b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 29 Sep 2025 11:51:44 -0400 Subject: [PATCH 7/8] tweak test Signed-off-by: Bill Nell --- tests/kernels/moe/modular_kernel_tools/mk_objects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 85febb8a712a..c2eb64d3052c 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -184,7 +184,7 @@ def expert_info(kind) -> ExpertInfo: ) # Disable on blackwell for now -if has_deep_ep() and not current_platform.has_device_capability(100): +if has_deep_ep(): # and not current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 From 5a681204d43f0c7fdd9de0ee4c26d79c06d463bd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 1 Oct 2025 21:08:16 -0400 Subject: [PATCH 8/8] hacking Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/common.py | 6 +-- .../moe/modular_kernel_tools/mk_objects.py | 3 +- .../fused_moe/deepep_ht_prepare_finalize.py | 7 +++ .../deepep_hybrid_prepare_finalize.py | 48 ++++++++++++++++--- vllm/model_executor/layers/fused_moe/layer.py | 4 ++ .../layers/fused_moe/modular_kernel.py | 4 +- 6 files changed, 60 insertions(+), 12 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index b5fcc4cd70bf..57f5befa36f2 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -522,11 +522,11 @@ def make_modular_kernel( quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEModularKernel: - def next_power_of_2(x): + def next_power_of_2(x) -> int: import math if x == 0: return 1 - return 2**math.ceil(math.log2(x)) + return int(2**math.ceil(math.log2(x))) # make moe config moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( @@ -542,7 +542,7 @@ def next_power_of_2(x): num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, - max_num_tokens=next_power_of_2(config.M), + max_num_tokens=max(128, next_power_of_2(config.M)), ) # make modular kernel diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index c2eb64d3052c..5a3e35540d01 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -76,6 +76,7 @@ class ExpertInfo: common_float_and_int_types = common_float_types + [torch.int8] nvfp4_types = ["nvfp4"] fp8_types = [torch.float8_e4m3fn] +fp8_bf16_types = [torch.float8_e4m3fn, torch.bfloat16] def register_prepare_and_finalize( @@ -211,7 +212,7 @@ def expert_info(kind) -> ExpertInfo: register_prepare_and_finalize( DeepEPHybridPrepareAndFinalize, standard_format, - common_float_types, + fp8_bf16_types, blocked_quantization_support=True, backend="deepep_hybrid", ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 9e9a9afc18a0..e9e85f304f6a 100644 --- 
a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -132,6 +132,13 @@ def _do_dispatch( async_finish=self.async_prepare and not dbo_enabled(), allocate_on_comm_stream=False) + print(f"HT STUFF\n" + f"a1 = {tokens.shape} -> {token_data.shape}\n" + f"topk_ids={expert_topk_ids.shape}\n" + f"probs={expert_topk_weights.shape}\n" + f"lem shape={expert_num_tokens_per_expert_list}\n" + ) + # record the handle for this ubatch a2a_idx = dbo_current_ubatch_id() self.handles[a2a_idx] = handle diff --git a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py index c59129784b06..a87ff7c5489b 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py @@ -4,6 +4,7 @@ import deep_ep import torch +import torch.nn.functional as F import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig @@ -47,7 +48,7 @@ def __init__(self, buffer: deep_ep.HybridEpBuffer, num_dispatchers: int, self.buffer = buffer self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size - self.rank_expert_offset = rank_expert_offset + self.rank_expert_offset = rank_expert_offset #? self.handle = None self.expert_probs = None @@ -112,13 +113,13 @@ def prepare( a1_post_scale = None else: a1q = a1 - a1q_scale = None + a1q_scale = torch.ones(1, device=a1.device, dtype=torch.float32) # hack a1_post_scale = quant_config.a1_scale ( expert_x, expert_probs, expert_x_scale, handle ) = self.buffer.dispatch( - tensor=a1, + tensor=a1q, scaling_factor=a1q_scale, topk_idx=topk_ids, topk_weights=topk_weights, @@ -127,6 +128,39 @@ def prepare( num_of_tokens_for_experts=-1, #?? 
) self.handle = handle + self.expert_probs = expert_probs + + (sparse_to_dense_map, + rdma_to_attn_map, + attn_to_rdma_map, + num_of_tokens_for_experts, + local_expert_routing_map, + num_tokens) = self.handle + + num_of_tokens_for_experts = num_of_tokens_for_experts.cpu() + + print(f"STUFF\n" + f"rank_exp_offset = {self.rank_expert_offset}\n" + f"a={a1q.shape}/{a1q.dtype} -> {expert_x.shape}/{expert_x.dtype}\n" + f"topk_ids={topk_ids.shape}\n" + f"tok_for_exp={num_of_tokens_for_experts}\n" + f"probs={expert_probs.shape}\n" + f"lem shape={local_expert_routing_map.shape}, {local_expert_routing_map[:num_of_tokens_for_experts].shape}\n" + f"lem numel={local_expert_routing_map.nonzero().numel()}\n" + #f"lem={local_expert_routing_map}\n" + f"lem sum={local_expert_routing_map.sum(dim=1).shape}\n" + f"sparse_to_dense_map={sparse_to_dense_map.shape} {sparse_to_dense_map.dtype} {sparse_to_dense_map}\n" + f"rdma_to_attn_map={rdma_to_attn_map.shape} {rdma_to_attn_map.dtype} {rdma_to_attn_map}\n" + f"attn_to_rdma_map={attn_to_rdma_map.shape} {attn_to_rdma_map.dtype}\n" + f"num_tokens={num_tokens}\n" + ) + + local_expert_routing_map = local_expert_routing_map[:num_of_tokens_for_experts.item()] + + # TBD + new_topk_ids = None + + # N/A expert_tokens_meta = None # Dispatch and Quant @@ -144,9 +178,7 @@ def prepare( per_act_token_quant=False, block_shape=quant_config.block_shape) - self.expert_probs = expert_probs - - return (expert_x, expert_x_scale, expert_tokens_meta, None, None) + return (expert_x, expert_x_scale, expert_tokens_meta, new_topk_ids, expert_probs) def finalize( self, @@ -170,12 +202,16 @@ def finalize( apply_router_weight_on_input=apply_router_weight_on_input, ) + print(f"\nCOMBINE START({self.rank_expert_offset})\n") + combined_x, _ = self.buffer.combine( tensor=fused_expert_output, probs=self.expert_probs, # None? handle=self.handle, ) + print(f"\nCOMBINE END({self.rank_expert_offset}) {combined_x.shape}/{combined_x.dtype}\n") + # TODO(lucas): support this case with the refactored modular kernel # Respect inplace outputs. # apply weights??? diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 63c2296beec2..7627aca4b49e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -217,6 +217,8 @@ def _maybe_make_prepare_finalize( use_fp8=use_fp8, ) + print(f"MAX NUM TOKENS = {moe.max_num_tokens}") + handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = DeepEPHybridPrepareAndFinalize( handle, @@ -1045,6 +1047,8 @@ def __init__( max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, has_bias=has_bias, ) + print(f"VLLM_MOE_DP_CHUNK_SIZE={envs.VLLM_MOE_DP_CHUNK_SIZE}") + self.moe_config = moe self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.quant_config = quant_config diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 78e17796c460..35622e595eb9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -81,8 +81,8 @@ def _moe_problem_size( if a1.dim() == 2: # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), \ - f"{topk_ids.size(0)} != {a1.size(0)}" +# assert topk_ids.size(0) == a1.size(0), \ +# f"{topk_ids.size(0)} != {a1.size(0)}" M = a1.size(0) else: assert a1.dim() == 3
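Taken together, the series slots DeepEPHybridPrepareAndFinalize into the usual modular-kernel flow: prepare() quantizes and all2all-dispatches tokens to the EP ranks owning their routed experts, the fused experts run on the dispatched activations, and finalize() combines the results back and applies the top-k routing weights. The sketch below is only a toy restatement of that three-phase contract with stub functions, not the actual FusedMoEModularKernel orchestration:

import torch

def prepare(a1, topk_ids, topk_weights):
    # Real code: quantize (pre- or post-dispatch) + deep_ep dispatch.
    return a1, None                      # (expert_x, expert_x_scale)

def experts(x, x_scale, topk_ids, topk_weights):
    # Real code: fused MoE experts on the dispatched tokens.
    return x.unsqueeze(1).expand(-1, topk_ids.size(1), -1).contiguous()

def finalize(output, fused_out, topk_ids, topk_weights):
    # Real code: deep_ep combine + top-k weight-and-reduce into `output`.
    output.copy_((fused_out * topk_weights.unsqueeze(-1)).sum(dim=1))

M, K, H = 4, 2, 8
a1 = torch.randn(M, H)
topk_weights = torch.softmax(torch.randn(M, K), dim=-1)
topk_ids = torch.randint(0, 16, (M, K))

expert_x, expert_x_scale = prepare(a1, topk_ids, topk_weights)
fused_out = experts(expert_x, expert_x_scale, topk_ids, topk_weights)
output = torch.empty_like(a1)
finalize(output, fused_out, topk_ids, topk_weights)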