48 changes: 29 additions & 19 deletions tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -146,13 +146,13 @@
return info


register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend=None,
)
# register_prepare_and_finalize(
# MoEPrepareAndFinalizeNoEP,
# standard_format,
# common_float_types,
# blocked_quantization_support=True,
# backend=None,
# )

register_experts(
BatchedTritonExperts,
@@ -189,35 +189,45 @@
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_hybrid_prepare_finalize import ( # noqa: E501
DeepEPHybridPrepareAndFinalize)

# register_prepare_and_finalize(
# DeepEPHTPrepareAndFinalize,
# standard_format,
# common_float_types,
# blocked_quantization_support=True,
# backend="deepep_high_throughput",
# )

# register_prepare_and_finalize(
# DeepEPLLPrepareAndFinalize,
# batched_format,
# common_float_types,
# blocked_quantization_support=True,
# backend="deepep_low_latency",
# )

register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
DeepEPHybridPrepareAndFinalize,
standard_format,
common_float_types,
blocked_quantization_support=True,

Check failure on line 215 (GitHub Actions / pre-commit, Ruff SIM223): tests/kernels/moe/modular_kernel_tools/mk_objects.py:215:4: Use `False` instead of `False and ...`
backend="deepep_high_throughput",
)

register_prepare_and_finalize(
DeepEPLLPrepareAndFinalize,
batched_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_low_latency",
backend="deepep_hybrid",
)

if has_pplx():
if False and has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
backend="pplx",

Check failure on line 227 (GitHub Actions / pre-commit, Ruff SIM223): tests/kernels/moe/modular_kernel_tools/mk_objects.py:226:4: Use `False` instead of `False and ...`
)

if (has_flashinfer_cutlass_fused_moe()
if False and (has_flashinfer_cutlass_fused_moe()
and current_platform.has_device_capability(100)):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
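For context on the two SIM223 findings above: `False and <expr>` always short-circuits to `False` without evaluating `<expr>`, so Ruff asks for the guard to be written as a bare `False`. A minimal illustration of the equivalent form it suggests (not part of this PR; the guarded body is elided):

# Equivalent to `if False and has_pplx():` -- the right-hand side is never
# evaluated, so the lint prefers the explicit constant.
if False:  # pplx registration temporarily disabled in this draft
    ...    # the guarded register_prepare_and_finalize(...) call would go here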
29 changes: 28 additions & 1 deletion vllm/distributed/device_communicators/all2all.py
@@ -248,6 +248,33 @@ def set_num_sms(self, num_sms: int):
deep_ep.Buffer.set_num_sms(num_sms)


class DeepEPHybridAll2AllManager(DeepEPAll2AllManagerBase):
"""
All2All communication based on DeepEP Hybrid kernels.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)

def _make_all2all_kwargs(self, **kwargs) -> dict[Any, Any]:
extra_kwargs = dict(group=self.cpu_group,
num_sms_dispatch_api = 32,
num_sms_combine_api = 32,
num_sms_preprocessing_api = 128,
nvlink_domain_size = 2, # hack for now. dp world_size
)
return {**kwargs, **extra_kwargs}

def get_handle(self, kwargs):
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP Hybrid all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.HybridEpBuffer)
logger.debug("DeepEP Hybrid constructed.")
return handle


class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
"""
All2All communication based on DeepEP Low-Latency kernels.
@@ -395,4 +422,4 @@ def cleanup(self):
self.workspace_tensor = None
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False
self.initialized = False
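A minimal usage sketch for the new manager, mirroring how the existing DeepEP managers are driven; the process-group setup and the keys ultimately expected by `deep_ep.HybridEpBuffer` are assumptions, not part of this diff:

import torch.distributed as dist

from vllm.distributed.device_communicators.all2all import (
    DeepEPHybridAll2AllManager)

# Assumed: torch.distributed is already initialized (as vLLM does during
# worker startup); the manager is handed a CPU (gloo) process group.
cpu_group = dist.new_group(backend="gloo")

manager = DeepEPHybridAll2AllManager(cpu_group)

# get_handle() merges the caller's kwargs with the hard-coded SM counts and
# nvlink_domain_size from _make_all2all_kwargs, then builds (or reuses) a
# cached deep_ep.HybridEpBuffer. In vLLM the kwargs come from the fused-MoE
# prepare/finalize layer; an empty dict is used here only for illustration.
handle = manager.get_handle({})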
4 changes: 4 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -114,6 +114,10 @@ def __init__(self,
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "deepep_hybrid":
from .all2all import DeepEPHybridAll2AllManager
self.all2all_manager = DeepEPHybridAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Hybrid all2all manager.")
elif all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(
10 changes: 6 additions & 4 deletions vllm/envs.py
@@ -156,6 +156,7 @@
VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
"deepep_high_throughput",
"deepep_low_latency",
"deepep_hybrid",
"allgather_reducescatter",
"flashinfer_all2allv"] = \
"allgather_reducescatter"
@@ -1214,10 +1215,11 @@ def get_vllm_port() -> Optional[int]:
"VLLM_ALL2ALL_BACKEND":
env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
["naive", "pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv"]),
"deepep_high_throughput",
"deepep_low_latency",
"deepep_hybrid",
"allgather_reducescatter",
"flashinfer_all2allv"]),

# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
# Both require compute capability 10.0 or above.
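With the new value accepted by `env_with_choices`, opting into the hybrid path is a matter of setting the environment variable before vLLM builds its device communicators; a minimal sketch (only the variable name and value come from this PR):

import os

# Must be set before CudaCommunicator is constructed; "deepep_hybrid" makes it
# pick DeepEPHybridAll2AllManager (see the cuda_communicator.py hunk above).
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_hybrid"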
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -618,6 +618,11 @@ def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

@property
def use_deepep_hybrid_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_hybrid")

@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -794,6 +799,10 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels

@property
def use_deepep_hybrid_kernels(self):
return self.moe_parallel_config.use_deepep_hybrid_kernels

@property
def use_flashinfer_cutlass_kernels(self):
"""
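By analogy with the existing `use_deepep_ht_kernels` / `use_deepep_ll_kernels` flags, callers would branch on the new property when choosing a prepare/finalize implementation; the selection logic below is a hypothetical sketch, not part of this diff, and the helper name is an illustration only:

# Hypothetical selection logic in the fused-MoE layer, mirroring the existing
# branches for the high-throughput / low-latency DeepEP backends.
if moe_config.use_deepep_hybrid_kernels:
    # Obtain a deep_ep.HybridEpBuffer handle from the all2all manager and wrap
    # it in DeepEPHybridPrepareAndFinalize (defined in the new file below).
    prepare_finalize = build_deepep_hybrid_prepare_finalize(moe_config)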
182 changes: 182 additions & 0 deletions vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py
@@ -0,0 +1,182 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional, Union

import deep_ep
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
dbo_switch_to_compute, dbo_switch_to_compute_sync,
dbo_yield_and_switch_from_comm_to_compute,
dbo_yield_and_switch_from_compute_to_comm)


class DeepEPHybridPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Prepare/Finalize using DeepEP Hybrid kernels.
"""

@staticmethod
def maybe_roundup_layer_hidden_size(hidden_size: int,
dtype: torch.dtype) -> int:
# Round up hidden size so it is compatible with DeepEP High Throughput
# kernels.
# DeepEP intranode kernels make copies in units of,
# 32(warp-size) int4 elements. Round up hidden size to respect this.
# For example, an input hidden size of 2880 with dtype torch.bfloat16
# will be rounded up to 3072.
hidden_size_bytes = hidden_size * dtype.itemsize
xfer_atom_size = 512 # 32 * 16 (size(int4))
if hidden_size_bytes % xfer_atom_size == 0:
return hidden_size

hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
return hidden_size_bytes // dtype.itemsize

def __init__(self, buffer: deep_ep.HybridEpBuffer, num_dispatchers: int,
dp_size: int, rank_expert_offset: int):
super().__init__()
self.buffer = buffer
self.num_dispatchers_ = num_dispatchers
self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset
self.handle = None
self.expert_probs = None

# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]

def num_dispatchers(self) -> int:
return self.num_dispatchers_

@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard

def max_num_tokens_per_rank(self) -> Optional[int]:
return None

def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int64

def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_)

def _get_combine_config(self) -> Optional[deep_ep.Config]:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_combine_config(self.num_dispatchers_)

def supports_async(self) -> bool:
return False # combine async not supported

def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * topk_weights.to(a1.dtype)

if quant_config.is_block_quantized:
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)

Check failure on line 109 (GitHub Actions / pre-commit, Ruff F841): vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py:109:13: Local variable `a1q` is assigned to but never used
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
a1_post_scale = None
else:
a1q = a1
a1q_scale = None
a1_post_scale = quant_config.a1_scale

(
expert_x, expert_probs, expert_x_scale, handle
) = self.buffer.dispatch(
tensor=a1,
scaling_factor=a1q_scale,
topk_idx=topk_ids,
topk_weights=topk_weights,
routing_map=None, # None = generated dynamically
handle=None,
num_of_tokens_for_experts=-1, #??
)
self.handle = handle
expert_tokens_meta = None

# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_post_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)

self.expert_probs = expert_probs

return (expert_x, expert_x_scale, expert_tokens_meta, None, None)

def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,

Check failure on line 156 (GitHub Actions / pre-commit, Ruff SIM223): vllm/model_executor/layers/fused_moe/deepep_hybrid_prepare_finalize.py:156:12: Use `False` instead of `False and ...`
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if False and fused_expert_output.numel() != 0:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
fused_expert_output = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)

combined_x, _ = self.buffer.combine(
tensor=fused_expert_output,
probs=self.expert_probs, # None?
handle=self.handle,
)

# TODO(lucas): support this case with the refactored modular kernel
# Respect inplace outputs.
# apply weights???
output.copy_(combined_x, non_blocking=True)
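As a sanity check on `maybe_roundup_layer_hidden_size` above, the arithmetic for the example in its comment (hidden size 2880 with torch.bfloat16) can be reproduced standalone; this sketch only restates the math already in the diff:

import torch

from vllm.utils import round_up

hidden_size, dtype = 2880, torch.bfloat16
xfer_atom_size = 512                               # 32 (warp size) * 16 bytes per int4
hidden_size_bytes = hidden_size * dtype.itemsize   # 2880 * 2 = 5760 bytes
# 5760 is not a multiple of 512, so pad up to the next multiple of 512.
padded_bytes = round_up(hidden_size_bytes, xfer_atom_size)   # 6144
print(padded_bytes // dtype.itemsize)              # 3072, matching the comment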