diff --git a/python/sglang/srt/layers/attention/tbo_backend.py b/python/sglang/srt/layers/attention/tbo_backend.py new file mode 100644 index 00000000000..f99387e594e --- /dev/null +++ b/python/sglang/srt/layers/attention/tbo_backend.py @@ -0,0 +1,241 @@ +from typing import TYPE_CHECKING, Callable, List, Optional, Union + +import torch + +from sglang.srt import two_batch_overlap +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + +if TYPE_CHECKING: + from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + + +class TboAttnBackend(AttentionBackend): + def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]): + super().__init__() + self.primary = primary + self.children = children + + @classmethod + def init_new(cls, creator: Callable[[], AttentionBackend]): + return cls( + primary=creator(), + children=[creator() for _ in range(2)], + ) + + def init_forward_metadata(self, forward_batch: "ForwardBatch"): + self.primary.init_forward_metadata(forward_batch=forward_batch) + if forward_batch.tbo_children is not None: + for child, forward_batch_child in zip( + self.children, forward_batch.tbo_children, strict=True + ): + if forward_batch_child.batch_size > 0: + child.init_forward_metadata(forward_batch=forward_batch_child) + + def init_cuda_graph_state(self, max_bs: int): + self.primary.init_cuda_graph_state(max_bs=max_bs) + for item in self.children: + # TODO for children, maybe can provide *smaller* max_bs to optimize + item.init_cuda_graph_state(max_bs=max_bs) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: "ForwardMode", + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + self.primary.init_forward_metadata_capture_cuda_graph( + bs=bs, + num_tokens=num_tokens, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=encoder_lens, + forward_mode=forward_mode, + spec_info=spec_info, + ) + + self._init_forward_metadata_cuda_graph_children( + fn_name="init_forward_metadata_capture_cuda_graph", + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=encoder_lens, + forward_mode=forward_mode, + spec_info=spec_info, + capture_num_tokens=num_tokens, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: "ForwardMode", + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + self.primary.init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_sum=seq_lens_sum, + encoder_lens=encoder_lens, + forward_mode=forward_mode, + spec_info=spec_info, + seq_lens_cpu=seq_lens_cpu, + ) + + self._init_forward_metadata_cuda_graph_children( + fn_name="init_forward_metadata_replay_cuda_graph", + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=encoder_lens, + forward_mode=forward_mode, + spec_info=spec_info, + replay_seq_lens_sum=seq_lens_sum, + replay_seq_lens_cpu=seq_lens_cpu, + ) + + def _init_forward_metadata_cuda_graph_children( + self, + fn_name: str, + # common args + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + 
encoder_lens: Optional[torch.Tensor], + forward_mode: "ForwardMode", + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + # capture args + capture_num_tokens: int = None, + # replay args + replay_seq_lens_sum: int = None, + replay_seq_lens_cpu: Optional[torch.Tensor] = None, + ): + from sglang.srt.model_executor.forward_batch_info import ForwardMode + + if fn_name == "init_forward_metadata_capture_cuda_graph": + assert capture_num_tokens == bs, "Only support num_tokens==bs currently" + num_tokens = bs + + forward_mode_for_tbo_split = ( + forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE + ) + tbo_split_seq_index = two_batch_overlap.compute_split_seq_index( + forward_mode=forward_mode_for_tbo_split, + num_tokens=num_tokens, + extend_lens=None, + ) + tbo_split_token_index = two_batch_overlap.compute_split_token_index( + split_seq_index=tbo_split_seq_index, + forward_mode=forward_mode_for_tbo_split, + extend_seq_lens=None, + ) + + num_tokens_child_left = tbo_split_token_index + num_tokens_child_right = num_tokens - tbo_split_token_index + bs_child_left = num_tokens_child_left + bs_child_right = num_tokens_child_right + + assert ( + num_tokens_child_left > 0 and num_tokens_child_right > 0 + ), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}" + + common_pre_split_args = dict( + fn_name=fn_name, + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=encoder_lens, + forward_mode=forward_mode, + spec_info=spec_info, + capture_num_tokens=capture_num_tokens, + replay_seq_lens_sum=replay_seq_lens_sum, + replay_seq_lens_cpu=replay_seq_lens_cpu, + ) + + args_left = _init_forward_metadata_cuda_graph_split( + output_bs=bs_child_left, + seq_slice=slice(None, tbo_split_seq_index), + **common_pre_split_args, + ) + args_right = _init_forward_metadata_cuda_graph_split( + output_bs=bs_child_right, + seq_slice=slice(tbo_split_seq_index, None), + **common_pre_split_args, + ) + + child_left, child_right = self.children + getattr(child_left, fn_name)(**args_left) + getattr(child_right, fn_name)(**args_right) + + def get_cuda_graph_seq_len_fill_value(self): + ans = self.primary.get_cuda_graph_seq_len_fill_value() + for child in self.children: + assert ans == child.get_cuda_graph_seq_len_fill_value() + return ans + + def forward_extend(self, *args, **kwargs): + return self.primary.forward_extend(*args, **kwargs) + + def forward_decode(self, *args, **kwargs): + return self.primary.forward_decode(*args, **kwargs) + + +def _init_forward_metadata_cuda_graph_split( + fn_name: str, + seq_slice: slice, + output_bs: int, + # common args + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: "ForwardMode", + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + # capture args + capture_num_tokens: int = None, + # replay args + replay_seq_lens_sum: int = None, + replay_seq_lens_cpu: Optional[torch.Tensor] = None, +): + assert encoder_lens is None, "encoder_lens is not supported yet" + assert spec_info is None, "spec_info is not supported yet" + + ans = dict( + bs=output_bs, + req_pool_indices=req_pool_indices[seq_slice], + seq_lens=seq_lens[seq_slice], + # directly forward + forward_mode=forward_mode, + # ignore + encoder_lens=None, + spec_info=None, + ) + + if fn_name == "init_forward_metadata_capture_cuda_graph": + assert capture_num_tokens == bs, "Only support num_tokens==bs currently" + ans.update( + dict( + num_tokens=output_bs, + ) + ) + 
elif fn_name == "init_forward_metadata_replay_cuda_graph": + output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice] + ans.update( + dict( + seq_lens_sum=output_seq_lens_cpu.sum().item(), + seq_lens_cpu=output_seq_lens_cpu, + ) + ) + else: + raise NotImplementedError + + return ans diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py index 0292b21aa2f..ef454bc0fd3 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm.py +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -391,3 +391,16 @@ def __patched_func(self, *args, **kwargs): RuntimeCache.get = __patched_func yield RuntimeCache.get = origin_func + + +@contextmanager +def configure_deep_gemm_num_sms(num_sms): + if num_sms is None: + yield + else: + original_num_sms = deep_gemm.get_num_sms() + deep_gemm.set_num_sms(num_sms) + try: + yield + finally: + deep_gemm.set_num_sms(original_num_sms) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index be6c3e99f2b..ee613c573da 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -78,6 +78,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "enable_deepep_moe": ServerArgs.enable_deepep_moe, "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap, "enable_dp_lm_head": ServerArgs.enable_dp_lm_head, "enable_ep_moe": ServerArgs.enable_ep_moe, "deepep_config": ServerArgs.deepep_config, @@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None can_run_dp_cuda_graph: bool = False + tbo_split_seq_index: Optional[int] = None + global_forward_mode: Optional[ForwardMode] = None # For processing logprobs return_logprob: bool = False @@ -1624,6 +1627,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch: or global_server_args_dict["attention_backend"] == "flashmla" or global_server_args_dict["attention_backend"] == "fa3" or global_server_args_dict["attention_backend"] == "cutlass_mla" + or global_server_args_dict["enable_two_batch_overlap"] ): seq_lens_cpu = self.seq_lens.cpu() else: @@ -1651,6 +1655,8 @@ def get_model_worker_batch(self) -> ModelWorkerBatch: global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, + tbo_split_seq_index=self.tbo_split_seq_index, + global_forward_mode=self.global_forward_mode, seq_lens_cpu=seq_lens_cpu, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, @@ -1729,6 +1735,8 @@ class ModelWorkerBatch: global_num_tokens: Optional[List[int]] global_num_tokens_for_logprob: Optional[List[int]] can_run_dp_cuda_graph: bool + tbo_split_seq_index: Optional[int] + global_forward_mode: Optional[ForwardMode] # For extend extend_num_tokens: Optional[int] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3c449fbea85..e1e7f62df37 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -34,6 +34,7 @@ from torch.distributed import barrier from sglang.global_config import global_config +from sglang.srt import two_batch_overlap from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend from 
sglang.srt.disaggregation.decode import ( @@ -132,7 +133,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.two_batch_overlap import TboDPAttentionPreparer from sglang.srt.utils import ( + DeepEPMode, DynamicGradMode, broadcast_pyobj, configure_logger, @@ -1648,6 +1651,9 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): disable_cuda_graph=self.server_args.disable_cuda_graph, spec_algorithm=self.spec_algorithm, speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, + enable_two_batch_overlap=self.server_args.enable_two_batch_overlap, + enable_deepep_moe=self.server_args.enable_deepep_moe, + deepep_mode=DeepEPMode[self.server_args.deepep_mode], ) @staticmethod @@ -1661,6 +1667,9 @@ def prepare_dp_attn_batch_raw( disable_cuda_graph: bool, spec_algorithm, speculative_num_draft_tokens, + enable_two_batch_overlap: bool, + enable_deepep_moe: bool, + deepep_mode: DeepEPMode, ): # Check if other DP workers have running batches if local_batch is None: @@ -1696,17 +1705,26 @@ def prepare_dp_attn_batch_raw( is_extend_in_batch = ( local_batch.forward_mode.is_extend() if local_batch else False ) + + tbo_preparer = TboDPAttentionPreparer() + local_info = torch.tensor( [ num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch, + *tbo_preparer.prepare_all_gather( + local_batch, + deepep_mode, + enable_deepep_moe, + enable_two_batch_overlap, + ), ], dtype=torch.int64, ) global_info = torch.empty( - (dp_size, attn_tp_size, 4), + (dp_size, attn_tp_size, 6), dtype=torch.int64, ) torch.distributed.all_gather_into_tensor( @@ -1719,6 +1737,10 @@ def prepare_dp_attn_batch_raw( global_num_tokens_for_logprob = global_info[:, 0, 2].tolist() is_extend_in_batch = global_info[:, 0, 3].tolist() + tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output( + global_info[:, :, 4:6] + ) + if local_batch is None and max(global_num_tokens) > 0: local_batch = get_idle_batch() @@ -1732,6 +1754,8 @@ def prepare_dp_attn_batch_raw( local_batch.global_num_tokens_for_logprob = ( global_num_tokens_for_logprob ) + local_batch.tbo_split_seq_index = tbo_split_seq_index + local_batch.global_forward_mode = global_forward_mode # Check forward mode for cuda graph if not disable_cuda_graph: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 308bf92ddd9..74f45fb093d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -24,6 +24,7 @@ import torch import tqdm +from sglang.srt import two_batch_overlap from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture @@ -38,6 +39,10 @@ PPProxyTensors, ) from sglang.srt.patch_torch import monkey_patch_torch_compile +from sglang.srt.two_batch_overlap import ( + TboCudaGraphRunnerUtils, + TboForwardBatchPreparer, +) from sglang.srt.utils import ( get_available_gpu_memory, get_device_memory_capacity, @@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): model_runner.req_to_token_pool.size ] + if server_args.enable_two_batch_overlap: + capture_bs = [bs for bs in capture_bs if bs >= 2] + if server_args.cuda_graph_max_bs: capture_bs = [bs for bs in capture_bs if bs <= 
server_args.cuda_graph_max_bs] if max(capture_bs) < server_args.cuda_graph_max_bs: @@ -349,7 +357,14 @@ def can_run(self, forward_batch: ForwardBatch): if self.is_encoder_decoder else True ) - return is_bs_supported and is_encoder_lens_supported + + is_tbo_supported = ( + forward_batch.can_run_tbo + if self.model_runner.server_args.enable_two_batch_overlap + else True + ) + + return is_bs_supported and is_encoder_lens_supported and is_tbo_supported def capture(self): with graph_capture() as graph_capture_context: @@ -466,7 +481,12 @@ def capture_one_batch_size(self, bs: int, forward: Callable): capture_hidden_mode=self.capture_hidden_mode, lora_paths=lora_paths, num_token_non_padded=self.num_token_non_padded, + tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index( + self, num_tokens + ), + global_forward_mode=self.capture_forward_mode, ) + TboForwardBatchPreparer.prepare(forward_batch) if lora_paths is not None: self.model_runner.lora_manager.prepare_lora_batch(forward_batch) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index ea64199a5d3..de462e45da8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -29,9 +29,10 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import triton @@ -239,6 +240,7 @@ class ForwardBatch: dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime gathered_buffer: Optional[torch.Tensor] = None can_run_dp_cuda_graph: bool = False + global_forward_mode: Optional[ForwardMode] = None # Speculative decoding spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None @@ -252,12 +254,18 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None + tbo_split_seq_index: Optional[int] = None + tbo_parent_token_range: Optional[Tuple[int, int]] = None + tbo_children: Optional[List["ForwardBatch"]] = None + @classmethod def init_new( cls, batch: ModelWorkerBatch, model_runner: ModelRunner, ): + from sglang.srt.two_batch_overlap import TboForwardBatchPreparer + device = model_runner.device extend_input_logprob_token_ids_gpu = None if batch.extend_input_logprob_token_ids is not None: @@ -281,6 +289,7 @@ def init_new( top_logprobs_nums=batch.top_logprobs_nums, token_ids_logprobs=batch.token_ids_logprobs, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, + global_forward_mode=batch.global_forward_mode, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, req_to_token_pool=model_runner.req_to_token_pool, @@ -294,6 +303,7 @@ def init_new( num_token_non_padded=torch.tensor( len(batch.input_ids), dtype=torch.int32 ).to(device, non_blocking=True), + tbo_split_seq_index=batch.tbo_split_seq_index, ) # For DP attention @@ -316,6 +326,7 @@ def init_new( ) if ret.forward_mode.is_idle(): ret.positions = torch.empty((0,), device=device) + TboForwardBatchPreparer.prepare(ret) return ret # Override the positions with spec_info @@ -364,6 +375,8 @@ def init_new( if model_runner.server_args.lora_paths is not None: model_runner.lora_manager.prepare_lora_batch(ret) + TboForwardBatchPreparer.prepare(ret) + return ret def merge_mm_inputs(self) -> Optional[MultimodalInputs]: @@ -588,6 +601,10 @@ def prepare_chunked_prefix_cache_info(self, device: 
torch.device): # Precompute the kv indices for each chunk self.prepare_chunked_kv_indices(device) + @property + def can_run_tbo(self): + return self.tbo_split_seq_index is not None + class PPProxyTensors: # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index acb69fa9a77..3fad97bd6f1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -37,6 +37,7 @@ set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.layers.attention.tbo_backend import TboAttnBackend from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, @@ -198,6 +199,7 @@ def __init__( "disable_radix_cache": server_args.disable_radix_cache, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, + "enable_two_batch_overlap": server_args.enable_two_batch_overlap, "enable_dp_lm_head": server_args.enable_dp_lm_head, "enable_ep_moe": server_args.enable_ep_moe, "enable_deepep_moe": server_args.enable_deepep_moe, @@ -994,6 +996,13 @@ def init_cublas(self): def init_attention_backend(self): """Init attention kernel backend.""" + if self.server_args.enable_two_batch_overlap: + self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend) + else: + self.attn_backend = self._get_attention_backend() + + # TODO unify with 6338 + def _get_attention_backend(self): if self.server_args.attention_backend == "flashinfer": if not self.use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( @@ -1003,17 +1012,17 @@ def init_attention_backend(self): # Init streams if self.server_args.speculative_algorithm == "EAGLE": self.plan_stream_for_flashinfer = torch.cuda.Stream() - self.attn_backend = FlashInferAttnBackend(self) + return FlashInferAttnBackend(self) else: from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAAttnBackend, ) - self.attn_backend = FlashInferMLAAttnBackend(self) + return FlashInferMLAAttnBackend(self) elif self.server_args.attention_backend == "aiter": from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend - self.attn_backend = AiterAttnBackend(self) + return AiterAttnBackend(self) elif self.server_args.attention_backend == "triton": assert self.sliding_window_size is None, ( "Window attention is not supported in the triton attention backend. 
" @@ -1028,21 +1037,21 @@ def init_attention_backend(self): DoubleSparseAttnBackend, ) - self.attn_backend = DoubleSparseAttnBackend(self) + return DoubleSparseAttnBackend(self) else: from sglang.srt.layers.attention.triton_backend import TritonAttnBackend - self.attn_backend = TritonAttnBackend(self) + return TritonAttnBackend(self) elif self.server_args.attention_backend == "torch_native": from sglang.srt.layers.attention.torch_native_backend import ( TorchNativeAttnBackend, ) - self.attn_backend = TorchNativeAttnBackend(self) + return TorchNativeAttnBackend(self) elif self.server_args.attention_backend == "flashmla": from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend - self.attn_backend = FlashMLABackend(self) + return FlashMLABackend(self) elif self.server_args.attention_backend == "fa3": assert ( torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend @@ -1054,13 +1063,13 @@ def init_attention_backend(self): FlashAttentionBackend, ) - self.attn_backend = FlashAttentionBackend(self) + return FlashAttentionBackend(self) elif self.server_args.attention_backend == "cutlass_mla": from sglang.srt.layers.attention.cutlass_mla_backend import ( CutlassMLABackend, ) - self.attn_backend = CutlassMLABackend(self) + return CutlassMLABackend(self) else: raise ValueError( f"Invalid attention backend: {self.server_args.attention_backend}" diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c0088facaf7..7d80794fc6e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -83,8 +83,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.operations import execute_operations -from sglang.srt.operations_strategy import compute_layer_operations +from sglang.srt.two_batch_overlap import ( + MaybeTboDeepEPDispatcher, + model_forward_maybe_tbo, +) from sglang.srt.utils import ( BumpAllocator, DeepEPMode, @@ -226,6 +228,7 @@ def __init__( self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] + self.config = config self.layer_id = layer_id if self.tp_size > config.n_routed_experts: @@ -300,7 +303,7 @@ def __init__( else None ) - self.deepep_dispatcher = DeepEPDispatcher( + self.deepep_dispatcher = MaybeTboDeepEPDispatcher( group=parallel_state.get_tp_group().device_group, router_topk=self.top_k, permute_fusion=True, @@ -309,13 +312,11 @@ def __init__( hidden_size=config.hidden_size, params_dtype=config.torch_dtype, deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], - async_finish=True, # TODO + async_finish=True, return_recv_hook=True, ) - @property - def _enable_deepep_moe(self): - return global_server_args_dict["enable_deepep_moe"] + self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"] def get_moe_weights(self): return [ @@ -423,7 +424,7 @@ def _forward_shared_experts(self, hidden_states): return None def op_gate(self, state): - if (not self._enable_deepep_moe) or is_non_idle_and_non_empty( + if is_non_idle_and_non_empty( state.forward_batch.forward_mode, state.hidden_states_mlp_input ): # router_logits: (num_tokens, n_experts) @@ -432,115 +433,105 @@ def op_gate(self, state): state.router_logits = None def op_shared_experts(self, state): - if 
(self.n_share_experts_fusion == 0) and ( - (not self._enable_deepep_moe) - or is_non_idle_and_non_empty( - state.forward_batch.forward_mode, state.hidden_states_mlp_input - ) + hidden_states_mlp_input = state.pop("hidden_states_mlp_input") + if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty( + state.forward_batch.forward_mode, hidden_states_mlp_input ): - state.shared_output = self.shared_experts(state.hidden_states_mlp_input) + state.shared_output = self.shared_experts(hidden_states_mlp_input) else: state.shared_output = None def op_select_experts(self, state): - router_logits = state.router_logits + router_logits = state.pop("router_logits") hidden_states = state.hidden_states_mlp_input - if self._enable_deepep_moe: - if router_logits is not None: - state.topk_weights_local, state.topk_idx_local = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, - expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( - layer_id=self.layer_id, - ), - ) - else: - state.topk_idx_local = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - state.topk_weights_local = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) + if router_logits is not None: + state.topk_weights_local, state.topk_idx_local = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + correction_bias=self.correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=self.layer_id, + ), + ) + else: + state.topk_idx_local = torch.full( + (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device + ) + state.topk_weights_local = torch.empty( + (0, self.top_k), dtype=torch.float32, device=hidden_states.device + ) def op_dispatch_a(self, state): - if self._enable_deepep_moe and (self.ep_size > 1): + if self.ep_size > 1: # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value self.deepep_dispatcher.dispatch_a( - hidden_states=state.pop("hidden_states_mlp_input"), + hidden_states=state.hidden_states_mlp_input, topk_idx=state.pop("topk_idx_local"), topk_weights=state.pop("topk_weights_local"), forward_mode=state.forward_batch.forward_mode, + tbo_subbatch_index=state.get("tbo_subbatch_index"), ) def op_dispatch_b(self, state): - if self._enable_deepep_moe and (self.ep_size > 1): - ( - state.hidden_states_experts_input, - state.topk_idx_dispatched, - state.topk_weights_dispatched, - state.reorder_topk_ids, - state.num_recv_tokens_per_expert, - state.seg_indptr, - state.masked_m, - state.expected_m, - ) = self.deepep_dispatcher.dispatch_b() + if self.ep_size > 1: + with get_global_expert_distribution_recorder().with_current_layer( + self.layer_id + ): + ( + state.hidden_states_experts_input, + state.topk_idx_dispatched, + state.topk_weights_dispatched, + state.reorder_topk_ids, + state.num_recv_tokens_per_expert, + state.seg_indptr, + state.masked_m, + state.expected_m, + ) = self.deepep_dispatcher.dispatch_b( + tbo_subbatch_index=state.get("tbo_subbatch_index"), + ) def op_experts(self, state): - if 
self._enable_deepep_moe: - state.pop("router_logits") - state.hidden_states_experts_output = self.experts( - hidden_states=state.pop("hidden_states_experts_input"), - topk_idx=state.topk_idx_dispatched, - topk_weights=state.topk_weights_dispatched, - reorder_topk_ids=state.pop("reorder_topk_ids"), - seg_indptr=state.pop("seg_indptr"), - masked_m=state.pop("masked_m"), - expected_m=state.pop("expected_m"), - num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"), - forward_mode=state.forward_batch.forward_mode, - ) - else: - state.hidden_states_experts_output = self.experts( - hidden_states=state.pop("hidden_states_mlp_input"), - router_logits=state.pop("router_logits"), - ) + state.hidden_states_experts_output = self.experts( + hidden_states=state.pop("hidden_states_experts_input"), + topk_idx=state.topk_idx_dispatched, + topk_weights=state.topk_weights_dispatched, + reorder_topk_ids=state.pop("reorder_topk_ids"), + seg_indptr=state.pop("seg_indptr"), + masked_m=state.pop("masked_m"), + expected_m=state.pop("expected_m"), + num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"), + forward_mode=state.forward_batch.forward_mode, + ) def op_combine_a(self, state): - if self._enable_deepep_moe and (self.ep_size > 1): + if self.ep_size > 1: self.deepep_dispatcher.combine_a( - state.pop("hidden_states_experts_output"), + hidden_states=state.pop("hidden_states_experts_output"), topk_idx=state.pop("topk_idx_dispatched"), topk_weights=state.pop("topk_weights_dispatched"), forward_mode=state.forward_batch.forward_mode, + tbo_subbatch_index=state.get("tbo_subbatch_index"), ) def op_combine_b(self, state): - if self._enable_deepep_moe and (self.ep_size > 1): - state.hidden_states_after_combine = self.deepep_dispatcher.combine_b() + if self.ep_size > 1: + state.hidden_states_after_combine = self.deepep_dispatcher.combine_b( + tbo_subbatch_index=state.get("tbo_subbatch_index"), + ) def op_output(self, state): - final_hidden_states = ( - state.pop("hidden_states_after_combine") - if self._enable_deepep_moe - else state.pop("hidden_states_experts_output") - ) - + final_hidden_states = state.pop("hidden_states_after_combine") final_hidden_states *= self.routed_scaling_factor - if (s := state.pop("shared_output")) is not None: final_hidden_states = final_hidden_states + s - if (not self._enable_deepep_moe) and (self.tp_size > 1): - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - state.hidden_states_mlp_output = final_hidden_states @@ -1482,6 +1473,7 @@ def op_comm_prepare_attn( forward_batch: ForwardBatch, residual: Optional[torch.Tensor], zero_allocator: BumpAllocator, + tbo_subbatch_index: Optional[int] = None, ): state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = ( self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch) @@ -1491,6 +1483,7 @@ def op_comm_prepare_attn( forward_batch=forward_batch, positions=positions, zero_allocator=zero_allocator, + tbo_subbatch_index=tbo_subbatch_index, ) ) @@ -1523,8 +1516,24 @@ def op_comm_postprocess_layer(self, state): state.forward_batch, ) - state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"}) - return hidden_states, residual + output = dict( + positions=state.positions, + hidden_states=hidden_states, + residual=residual, + forward_batch=state.forward_batch, + zero_allocator=state.zero_allocator, + tbo_subbatch_index=state.tbo_subbatch_index, + ) + + state.clear( + expect_keys={ + "positions", + "forward_batch", + "zero_allocator", + 
"tbo_subbatch_index", + } + ) + return output class DeepseekV2Model(nn.Module): @@ -1539,6 +1548,7 @@ def __init__( super().__init__() self.padding_id = config.pad_token_id self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace self.embed_tokens = VocabParallelEmbedding( config.vocab_size, @@ -1572,13 +1582,12 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: + total_num_layers = len(self.layers) + device = input_embeds.device if input_embeds is not None else input_ids.device zero_allocator = BumpAllocator( - # TODO for two-batch-overlap, we need a larger buffer size - buffer_size=len(self.layers) * 2, + buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), dtype=torch.float32, - device=( - input_embeds.device if input_embeds is not None else input_ids.device - ), + device=device, ) if input_embeds is None: @@ -1587,12 +1596,30 @@ def forward( hidden_states = input_embeds residual = None - for i in range(len(self.layers)): + + normal_num_layers = ( + self.first_k_dense_replace + if forward_batch.can_run_tbo + else total_num_layers + ) + for i in range(normal_num_layers): with get_global_expert_distribution_recorder().with_current_layer(i): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual, zero_allocator ) + + if normal_num_layers != total_num_layers: + hidden_states, residual = model_forward_maybe_tbo( + layers=self.layers[normal_num_layers:], + enable_tbo=True, + positions=positions, + forward_batch=forward_batch, + hidden_states=hidden_states, + residual=residual, + zero_allocator=zero_allocator, + ) + if not forward_batch.forward_mode.is_idle(): if residual is None: hidden_states = self.norm(hidden_states) @@ -1674,7 +1701,6 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py index 0ef9fb9c404..0a8c118dfe1 100644 --- a/python/sglang/srt/operations.py +++ b/python/sglang/srt/operations.py @@ -12,7 +12,7 @@ def execute_operations(inputs, operations): - stages = _convert_operations_to_stages(decorate_operations(operations)) + stages = _convert_operations_to_stages(operations) executor = _StageExecutor("primary", stages, inputs=inputs) for _ in range(executor.num_stages): executor.next() @@ -20,6 +20,37 @@ def execute_operations(inputs, operations): return executor.output +def execute_overlapped_operations( + inputs_arr: Sequence, + operations_arr: Sequence, + delta_stages: Sequence[int], +) -> Sequence: + # Make it explicit for clarity; if we need multi-batch overlap, this can be generalized + inputs_a, inputs_b = inputs_arr + operations_a, operations_b = operations_arr + delta_stage_a, delta_stage_b = delta_stages + assert delta_stage_a == 0 + delta_stage = delta_stage_b + + stages_a = _convert_operations_to_stages(operations_a) + stages_b = _convert_operations_to_stages(operations_b) + executor_a = _StageExecutor("a", stages_a, inputs=inputs_a) + executor_b = _StageExecutor("b", stages_b, inputs=inputs_b) + + for _ in range(delta_stage): + executor_a.next() + + for _ in range(executor_a.num_stages - delta_stage): + executor_a.next() + executor_b.next() + + for _ in range(delta_stage): + executor_b.next() + + assert executor_a.done and executor_b.done + return [executor_a.output, 
executor_b.output] + + class YieldOperation: pass @@ -109,6 +140,9 @@ def update(self, values: Dict[str, Any]): for k, v in values.items(): setattr(self, k, v) + def get(self, item): + return self._data.get(item) + def clear(self, expect_keys: Sequence[str]): if set(self._data.keys()) != set(expect_keys): raise Exception( @@ -119,6 +153,7 @@ def clear(self, expect_keys: Sequence[str]): def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]: + operations = _decorate_operations(operations) operation_chunks = list( _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation)) ) @@ -140,7 +175,7 @@ def _chunk_by_separator( yield pending_items -def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""): +def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""): return [_decorate_operation(op, debug_name_prefix) for op in operations] diff --git a/python/sglang/srt/operations_strategy.py b/python/sglang/srt/operations_strategy.py index be0577ce295..b8e0eaef0aa 100644 --- a/python/sglang/srt/operations_strategy.py +++ b/python/sglang/srt/operations_strategy.py @@ -1,33 +1,116 @@ +from dataclasses import dataclass +from typing import List, Optional + import torch +from sglang.srt import operations +from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.operations import Operation + + +@dataclass +class OperationsStrategy: + operations: List[Operation] + deep_gemm_num_sms: Optional[int] = None + tbo_delta_stages: Optional[int] = None + + @classmethod + def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy": + return OperationsStrategy( + operations=[x for item in items for x in item.operations], + deep_gemm_num_sms=_assert_all_same( + [item.deep_gemm_num_sms for item in items] + ), + tbo_delta_stages=_assert_all_same( + [item.tbo_delta_stages for item in items] + ), + ) + + @staticmethod + def init_new_tbo( + layers: torch.nn.ModuleList, + forward_mode: ForwardMode, + ) -> "OperationsStrategy": + return OperationsStrategy.concat( + [ + _compute_layer_operations_strategy_tbo(layer, forward_mode) + for layer in layers + ] + ) -def compute_layer_operations( + +def _assert_all_same(items: List): + assert all(item == items[0] for item in items) + return items[0] + + +# TODO can refactor to make it more fancy if we have more complex strategies +def _compute_layer_operations_strategy_tbo( layer: torch.nn.Module, -): - if not layer.is_layer_sparse: - return [ + forward_mode: ForwardMode, +) -> OperationsStrategy: + assert layer.is_layer_sparse, "dense layer TBO not yet implemented" + if forward_mode == ForwardMode.EXTEND: + return _compute_moe_deepseek_blog_prefill(layer) + elif forward_mode == ForwardMode.DECODE: + return _compute_moe_deepseek_blog_decode(layer) + else: + raise NotImplementedError(f"Unsupported {forward_mode=}") + + +def _compute_moe_deepseek_blog_prefill(layer): + device_properties = torch.cuda.get_device_properties(device="cuda") + total_num_sms = device_properties.multi_processor_count + deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms + + return OperationsStrategy( + deep_gemm_num_sms=deep_gemm_num_sms, + tbo_delta_stages=0, + operations=[ + layer.op_comm_prepare_attn, + layer.self_attn.op_prepare, + layer.self_attn.op_core, + layer.op_comm_prepare_mlp, + layer.mlp.op_gate, + layer.mlp.op_select_experts, + layer.mlp.op_dispatch_a, + 
operations.YieldOperation(), + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + operations.YieldOperation(), + layer.mlp.op_shared_experts, + layer.mlp.op_combine_b, + layer.mlp.op_output, + layer.op_comm_postprocess_layer, + ], + ) + + +def _compute_moe_deepseek_blog_decode(layer): + return OperationsStrategy( + deep_gemm_num_sms=None, + tbo_delta_stages=2, + operations=[ layer.op_comm_prepare_attn, layer.self_attn.op_prepare, + operations.YieldOperation(), layer.self_attn.op_core, layer.op_comm_prepare_mlp, - layer.op_mlp, + layer.mlp.op_gate, + layer.mlp.op_select_experts, + operations.YieldOperation(), + layer.mlp.op_dispatch_a, + layer.mlp.op_shared_experts, + operations.YieldOperation(), + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + operations.YieldOperation(), + layer.mlp.op_combine_b, + layer.mlp.op_output, layer.op_comm_postprocess_layer, - ] - - # Will add TBO operation orders here - return [ - layer.op_comm_prepare_attn, - layer.self_attn.op_prepare, - layer.self_attn.op_core, - layer.op_comm_prepare_mlp, - layer.mlp.op_gate, - layer.mlp.op_shared_experts, - layer.mlp.op_select_experts, - layer.mlp.op_dispatch_a, - layer.mlp.op_dispatch_b, - layer.mlp.op_experts, - layer.mlp.op_combine_a, - layer.mlp.op_combine_b, - layer.mlp.op_output, - layer.op_comm_postprocess_layer, - ] + operations.YieldOperation(), + ], + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ab28e5abea5..f00aa11ac5a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -167,6 +167,7 @@ class ServerArgs: enable_mixed_chunk: bool = False enable_dp_attention: bool = False enable_dp_lm_head: bool = False + enable_two_batch_overlap: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" @@ -1144,6 +1145,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling expert parallelism for moe. 
The ep size is equal to the tp size.", ) + parser.add_argument( + "--enable-two-batch-overlap", + action="store_true", + help="Enabling two micro batches to overlap.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py new file mode 100644 index 00000000000..0fbc3c8e73f --- /dev/null +++ b/python/sglang/srt/two_batch_overlap.py @@ -0,0 +1,462 @@ +import dataclasses +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence + +import torch + +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher +from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.operations import execute_operations, execute_overlapped_operations +from sglang.srt.operations_strategy import OperationsStrategy +from sglang.srt.utils import BumpAllocator, DeepEPMode + +if TYPE_CHECKING: + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + + +# -------------------------------- Compute Basic Info --------------------------------------- + + +# TODO: may smartly disable TBO when batch size is too small b/c it will slow down +def compute_split_seq_index( + forward_mode: "ForwardMode", + num_tokens: int, + extend_lens: Optional[Sequence[int]], +) -> Optional[int]: + if forward_mode.is_extend(): + assert extend_lens is not None + return _split_array_by_half_sum(extend_lens) + elif forward_mode.is_decode(): + return num_tokens // 2 + elif forward_mode.is_idle(): + assert num_tokens == 0 + return 0 + else: + raise NotImplementedError + + +def _split_array_by_half_sum(arr: Sequence[int]) -> int: + overall_sum = sum(arr) + accumulator, split_index = 0, 0 + for value in arr[:-1]: + accumulator += value + split_index += 1 + if accumulator >= overall_sum // 2: + break + return split_index + + +def compute_split_token_index( + split_seq_index: int, + forward_mode: "ForwardMode", + extend_seq_lens: Optional[Sequence[int]], +) -> int: + if forward_mode.is_extend(): + assert extend_seq_lens is not None + return sum(extend_seq_lens[:split_seq_index]) + elif forward_mode.is_decode(): + return split_seq_index + elif forward_mode.is_idle(): + assert split_seq_index == 0 + return 0 + else: + raise NotImplementedError + + +# -------------------------------- Preparation --------------------------------------- + + +class TboCudaGraphRunnerUtils: + @staticmethod + def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int): + if that.model_runner.server_args.enable_two_batch_overlap: + tbo_split_seq_index = compute_split_seq_index( + forward_mode=that.capture_forward_mode, + num_tokens=num_tokens, + extend_lens=None, + ) + # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true + assert ( + tbo_split_seq_index is not None + ), f"{that.capture_forward_mode=} {num_tokens=}" + else: + tbo_split_seq_index = None + return tbo_split_seq_index + + +class TboDPAttentionPreparer: + def prepare_all_gather( + self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap + ): + self.enable_two_batch_overlap = enable_two_batch_overlap + + if local_batch is not None: + self.local_tbo_split_seq_index = 
compute_split_seq_index( + forward_mode=local_batch.forward_mode, + num_tokens=local_batch.input_ids.shape[0], + extend_lens=local_batch.extend_lens, + ) + resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode) + local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not ( + local_batch.forward_mode.is_extend() + and enable_deepep_moe + and (resolved_deepep_mode == DeepEPMode.low_latency) + ) + else: + self.local_tbo_split_seq_index = 0 + local_can_run_tbo = True + + local_forward_mode = self._compute_local_forward_mode(local_batch) + + return local_can_run_tbo, local_forward_mode + + def compute_output(self, partial_global_info): + local_can_run_tbo_aggregated = min(partial_global_info[:, 0, 0].tolist()) + forward_modes = partial_global_info[:, 0, 1].tolist() + + global_forward_mode, forward_mode_agree = self._compute_global_forward_mode( + forward_modes + ) + + can_run_tbo = ( + self.enable_two_batch_overlap + and local_can_run_tbo_aggregated + and forward_mode_agree + ) + + tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None + global_forward_mode = global_forward_mode if can_run_tbo else None + return tbo_split_seq_index, global_forward_mode + + @staticmethod + def _compute_local_forward_mode(local_batch): + return ( + local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE + ).value + + @staticmethod + def _compute_global_forward_mode(forward_modes): + converted_forward_modes = [ + ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x + for x in forward_modes + ] + forward_mode_agree = TboDPAttentionPreparer._is_all_same( + converted_forward_modes + ) + global_forward_mode = ( + ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None + ) + return global_forward_mode, forward_mode_agree + + @staticmethod + def _is_all_same(x): + return all(value == x[0] for value in x) + + +class TboForwardBatchPreparer: + @classmethod + def prepare(cls, batch: ForwardBatch): + from sglang.srt.layers.attention.tbo_backend import TboAttnBackend + + if batch.tbo_split_seq_index is None: + return + + tbo_split_token_index = compute_split_token_index( + split_seq_index=batch.tbo_split_seq_index, + forward_mode=batch.forward_mode, + extend_seq_lens=batch.extend_seq_lens_cpu, + ) + + assert isinstance(batch.attn_backend, TboAttnBackend) + attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children + + child_a = cls.filter_batch( + batch, + start_token_index=0, + end_token_index=tbo_split_token_index, + start_seq_index=0, + end_seq_index=batch.tbo_split_seq_index, + output_attn_backend=attn_backend_child_a, + ) + child_b = cls.filter_batch( + batch, + start_token_index=tbo_split_token_index, + end_token_index=batch.input_ids.shape[0], + start_seq_index=batch.tbo_split_seq_index, + end_seq_index=batch.batch_size, + output_attn_backend=attn_backend_child_b, + ) + + assert batch.tbo_children is None + batch.tbo_children = [child_a, child_b] + + @classmethod + def filter_batch( + cls, + batch: ForwardBatch, + *, + start_token_index: int, + end_token_index: int, + start_seq_index: int, + end_seq_index: int, + output_attn_backend: AttentionBackend, + ): + from sglang.srt.managers.schedule_batch import global_server_args_dict + + num_tokens = batch.input_ids.shape[0] + num_seqs = batch.batch_size + + output_dict = dict() + + for key in [ + "input_ids", + "positions", + "out_cache_loc", + ]: + old_value = getattr(batch, key) + assert ( + old_value.shape[0] == num_tokens + ), f"{key=} {old_value=} {num_tokens=} 
{batch=}" + output_dict[key] = old_value[start_token_index:end_token_index] + + for key in [ + "req_pool_indices", + "seq_lens", + "seq_lens_cpu", + "extend_seq_lens", + "extend_prefix_lens", + "extend_start_loc", + "extend_prefix_lens_cpu", + "extend_seq_lens_cpu", + "extend_logprob_start_lens_cpu", + "lora_paths", + ]: + old_value = getattr(batch, key) + if old_value is None: + continue + assert ( + len(old_value) == num_seqs + ), f"{key=} {old_value=} {num_seqs=} {batch=}" + output_dict[key] = old_value[start_seq_index:end_seq_index] + + for key in [ + "forward_mode", + "return_logprob", + "req_to_token_pool", + "token_to_kv_pool", + "can_run_dp_cuda_graph", + "global_forward_mode", + "spec_info", + "spec_algorithm", + "capture_hidden_mode", + "padded_static_len", + "mrope_positions", # only used by qwen2-vl, thus not care + ]: + output_dict[key] = getattr(batch, key) + + assert ( + _compute_extend_num_tokens(batch.input_ids, batch.forward_mode) + == batch.extend_num_tokens + ), f"{batch=}" + extend_num_tokens = _compute_extend_num_tokens( + output_dict["input_ids"], output_dict["forward_mode"] + ) + + # TODO improve, e.g. unify w/ `init_raw` + if global_server_args_dict["moe_dense_tp_size"] == 1: + sum_len = end_token_index - start_token_index + gathered_buffer = torch.zeros( + (sum_len, batch.gathered_buffer.shape[1]), + dtype=batch.gathered_buffer.dtype, + device=batch.gathered_buffer.device, + ) + else: + gathered_buffer = None + + output_dict.update( + dict( + batch_size=end_seq_index - start_seq_index, + seq_lens_sum=( + output_dict["seq_lens_cpu"].sum() + if "seq_lens_cpu" in output_dict + else None + ), + extend_num_tokens=extend_num_tokens, + attn_backend=output_attn_backend, + tbo_split_seq_index=None, + tbo_parent_token_range=(start_token_index, end_token_index), + tbo_children=None, + global_num_tokens_gpu=None, + global_num_tokens_cpu=None, + gathered_buffer=gathered_buffer, + global_num_tokens_for_logprob_gpu=None, + global_num_tokens_for_logprob_cpu=None, + sampling_info=None, + # For logits and logprobs post processing, thus we do not care + temp_scaled_logprobs=False, + temperature=None, + top_p_normalized_logprobs=False, + top_p=None, + mm_inputs=None, + num_token_non_padded=None, + ) + ) + + errors = [] + for field in dataclasses.fields(ForwardBatch): + if getattr(batch, field.name) is not None and field.name not in output_dict: + errors.append( + f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})" + ) + if len(errors) > 0: + raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors)) + + return ForwardBatch(**output_dict) + + +def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode): + if forward_mode.is_extend(): + return input_ids.shape[0] + elif forward_mode.is_decode() or forward_mode.is_idle(): + return None + raise NotImplementedError + + +# -------------------------------- Execution --------------------------------------- + + +def model_forward_maybe_tbo( + layers, + enable_tbo: bool, + positions: torch.Tensor, + forward_batch: ForwardBatch, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + zero_allocator: BumpAllocator, +): + inputs = dict( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + residual=residual, + zero_allocator=zero_allocator, + ) + operations_strategy = OperationsStrategy.init_new_tbo( + layers, forward_batch.global_forward_mode + ) + if enable_tbo: + return _model_forward_tbo(inputs, operations_strategy) + 
else: + return _model_forward_non_tbo(inputs, operations_strategy) + + +def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy): + # The attn_tp_size!=1 case is not yet extracted to master + assert get_attention_tp_size() == 1 + + inputs_arr = _model_forward_tbo_split_inputs(**inputs) + del inputs + + with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms): + outputs_arr = execute_overlapped_operations( + inputs_arr=inputs_arr, + operations_arr=[operations_strategy.operations] * 2, + delta_stages=[0, operations_strategy.tbo_delta_stages], + ) + + return _model_forward_tbo_merge_outputs(*outputs_arr) + + +def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy): + outputs = execute_operations(inputs, operations_strategy.operations) + return outputs["hidden_states"], outputs["residual"] + + +def _model_forward_tbo_split_inputs( + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + zero_allocator: BumpAllocator, +) -> List[Dict]: + return [ + dict( + **_model_forward_filter_inputs( + hidden_states=hidden_states, + residual=residual, + positions=positions, + output_forward_batch=output_forward_batch, + tbo_subbatch_index=tbo_subbatch_index, + ), + zero_allocator=zero_allocator, + ) + for tbo_subbatch_index, output_forward_batch in enumerate( + forward_batch.tbo_children + ) + ] + + +def _model_forward_filter_inputs( + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + output_forward_batch: ForwardBatch, + tbo_subbatch_index: int, +) -> Dict: + token_slice = slice(*output_forward_batch.tbo_parent_token_range) + return dict( + hidden_states=hidden_states[token_slice], + residual=None if residual is None else residual[token_slice], + positions=positions[token_slice], + forward_batch=output_forward_batch, + tbo_subbatch_index=tbo_subbatch_index, + ) + + +def _model_forward_tbo_merge_outputs(output_a, output_b): + def _handle_key(name): + value_a = output_a[name] + value_b = output_b[name] + assert (value_a is None) == (value_b is None) + if value_a is None: + return None + return torch.concat([value_a, value_b], dim=0) + + return _handle_key("hidden_states"), _handle_key("residual") + + +# -------------------------------- Utilities and wrappers --------------------------------------- + + +class MaybeTboDeepEPDispatcher: + def __init__(self, **kwargs): + num_inner_dispatchers = ( + 2 if global_server_args_dict["enable_two_batch_overlap"] else 1 + ) + self._inners = [ + DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) + ] + + def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs): + return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs) + + def dispatch(self, **kwargs): + return self._execute("dispatch", **kwargs) + + def dispatch_a(self, **kwargs): + return self._execute("dispatch_a", **kwargs) + + def dispatch_b(self, **kwargs): + return self._execute("dispatch_b", **kwargs) + + def combine(self, **kwargs): + return self._execute("combine", **kwargs) + + def combine_a(self, **kwargs): + return self._execute("combine_a", **kwargs) + + def combine_b(self, **kwargs): + return self._execute("combine_b", **kwargs) diff --git a/test/srt/test_two_batch_overlap.py b/test/srt/test_two_batch_overlap.py new file mode 100644 index 00000000000..89e793ca62c --- /dev/null +++ b/test/srt/test_two_batch_overlap.py @@ -0,0 +1,72 @@ +import os +import unittest +from types import SimpleNamespace + +import requests + 
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestTwoBatchOverlap(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--dp",
+                "2",
+                "--enable-dp-attention",
+                "--enable-deepep-moe",
+                "--deepep-mode",
+                "normal",
+                "--disable-cuda-graph",  # DeepEP normal mode does not support CUDA Graph
+                "--enable-two-batch-overlap",
+            ],
+            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_generate_single_prompt(self):
+        response = requests.post(
+            self.base_url + "/generate",
+            # Use an uncommon prefix to minimise the chance of an accidental cache hit
+            json={
+                "text": "_ 1+1=2, 1+2=3, 1+3=4, 1+4=",
+                "sampling_params": {"temperature": 0, "max_new_tokens": 8},
+            },
+        )
+        print(f"{response.json()=}")
+        self.assertEqual(response.json()["text"], "5, 1+5=6")
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.5)
+
+
+if __name__ == "__main__":
+    unittest.main()
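
Reviewer note (not part of the patch): the core scheduling idea in execute_overlapped_operations (python/sglang/srt/operations.py) can be illustrated outside SGLang with a small, self-contained sketch. The stage names, the ToyExecutor class, and the printed trace below are made up for illustration only; what mirrors the patch is the staggering pattern, in which micro-batch A runs delta_stages stages ahead of micro-batch B so that, e.g., B's computation can overlap A's DeepEP communication (delta_stages=2 corresponds to tbo_delta_stages in _compute_moe_deepseek_blog_decode).

# Toy illustration (hypothetical names) of the staggered execution pattern
# used by execute_overlapped_operations for two-batch overlap.
from typing import List


class ToyExecutor:
    """Minimal stand-in for _StageExecutor: runs one stage per next() call."""

    def __init__(self, name: str, stages: List[str]):
        self.name = name
        self.stages = stages
        self._index = 0

    @property
    def num_stages(self) -> int:
        return len(self.stages)

    @property
    def done(self) -> bool:
        return self._index == self.num_stages

    def next(self) -> None:
        print(f"run {self.name}.{self.stages[self._index]}")
        self._index += 1


def execute_overlapped(stages_a: List[str], stages_b: List[str], delta_stage: int) -> None:
    # Same control flow as execute_overlapped_operations for two micro-batches:
    # A warms up by delta_stage stages, then A and B advance in lock step,
    # and finally B drains its remaining delta_stage stages.
    a = ToyExecutor("a", stages_a)
    b = ToyExecutor("b", stages_b)

    for _ in range(delta_stage):
        a.next()
    for _ in range(a.num_stages - delta_stage):
        a.next()
        b.next()
    for _ in range(delta_stage):
        b.next()

    assert a.done and b.done


if __name__ == "__main__":
    # Hypothetical decode-like stage list; the real stages are the per-layer
    # operations listed in operations_strategy.py, separated by YieldOperation.
    stages = ["attn", "dispatch", "experts", "combine", "output"]
    execute_overlapped(stages, stages, delta_stage=2)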