 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -97,6 +98,53 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         dist.all_reduce(num_tokens_tensor, group=group)
         return num_tokens_tensor.cpu()
 
+    @staticmethod
+    def should_ubatch_across_dp(
+            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+            padded_num_tokens_per_ubatch: int, dp_size: int,
+            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
+        """
+        1. Decides whether each DP rank is going to microbatch. Either all
+        ranks run with microbatching or none of them do. If this function
+        decides not to run with microbatching, it "aborts": no padding
+        information is returned to the caller, and the result is
+        (False, None).
+
+        2. Determines the total number of tokens that each rank will run.
+        All ranks are padded out so that they run with the same number
+        of tokens.
+
+        Returns: tuple[
+            should_ubatch: Are all DP ranks going to microbatch
+            num_tokens_after_padding: A tensor containing the total number of
+            tokens per-microbatch for each DP rank, including padding. Will
+            be None if should_ubatch is False.
+        ]
+        """
+
+        device = current_platform.device_type
+        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+        tensor[2][dp_rank] = 1 if should_ubatch else 0
+
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(tensor, group=get_dp_group().device_group)
+
+        result: bool = bool(torch.all(tensor[2] == 1).item())
+        if not result:
+            return result, None
+
+        orig_num_tokens_tensor = tensor[0, :]
+        padded_num_tokens_tensor = tensor[1, :]
+
+        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
+        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
+        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
+                         padded_max_num_tokens)
+            return False, None
+        return result, padded_num_tokens_tensor.cpu()
+
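
A single-process sketch of the consensus above, assuming a 2-rank DP group: each rank contributes one column of the 3 x dp_size tensor (original tokens, padded tokens, 0/1 vote), and the all_reduce is simulated by filling the whole tensor at once. is_second_ubatch_empty is stubbed out here, since its real definition lives in vllm.v1.worker.ubatch_utils.

    import torch

    def is_second_ubatch_empty_stub(orig_min: int, padded_max: int) -> bool:
        # Stand-in for vllm.v1.worker.ubatch_utils.is_second_ubatch_empty;
        # always False so the sketch exercises only the voting logic.
        return False

    def simulate_should_ubatch(orig, padded, votes):
        # Rows mirror the 3 x dp_size tensor after the all_reduce: every
        # DP rank has already contributed its column.
        tensor = torch.tensor([orig, padded, votes], dtype=torch.int32)
        if not bool(torch.all(tensor[2] == 1).item()):
            return False, None  # any rank voting "no" aborts for all ranks
        orig_min = int(tensor[0].min().item())
        padded_max = int(tensor[1].max().item())
        if is_second_ubatch_empty_stub(orig_min, padded_max):
            return False, None
        return True, tensor[1]

    print(simulate_should_ubatch([40, 44], [48, 48], [1, 1]))  # (True, tensor([48, 48], ...))
    print(simulate_should_ubatch([40, 44], [48, 48], [1, 0]))  # (False, None)
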
     @staticmethod
     def make(
         parallel_config: ParallelConfig,
@@ -119,14 +167,15 @@ def make(
 
         # If num_tokens_across_dp is None, it will be computed by all_reduce
         # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (num_tokens_across_dp is None
-                or num_tokens_across_dp[dp_rank] == batchsize)
+        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
+                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
         if num_tokens_across_dp is None:
             num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
+                          num_tokens_across_dp)
 
     @contextmanager
     def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
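
A toy illustration of the bookkeeping above (values invented, dp_size = 4): the max across ranks drives padding, the cumulative sum gives each rank's slice boundaries, and the raw per-rank counts are now carried along as the third constructor argument.

    import torch

    num_tokens_across_dp = torch.tensor([8, 3, 5, 2])
    max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)       # tensor(8)
    cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, 0)  # tensor([ 8, 11, 16, 18])
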
@@ -179,9 +228,12 @@ class ForwardContext:
     Type AttentionMetadata for v0,
     Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
-    set dynamically for each forward pass
+    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
+    for each microbatch.
+    Set dynamically for each forward pass
     """
-    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
+                         list[dict[str, "AttentionMetadata"]]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
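
A hypothetical consumer of the widened attn_metadata type: under DBO a layer first picks the dict for its current microbatch, then does the usual per-layer lookup. metadata_for_layer is an invented helper, not part of the PR.

    from vllm.forward_context import get_forward_context

    def metadata_for_layer(layer_name: str, ubatch_id: int = 0):
        # Invented helper: resolve this layer's metadata across the three
        # possible shapes of attn_metadata.
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, list):   # DBO: one dict per microbatch
            return attn_metadata[ubatch_id][layer_name]
        if isinstance(attn_metadata, dict):   # v1: per-layer dict
            return attn_metadata[layer_name]
        return attn_metadata                  # v0: a single AttentionMetadata
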
@@ -191,6 +243,8 @@ class ForwardContext:
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
     batch_descriptor: Optional[BatchDescriptor] = None
 
+    ubatch_slices: Optional[UBatchSlices] = None
+
     def __post_init__(self):
         assert self.cudagraph_runtime_mode in [
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def create_forward_context(
+        attn_metadata: Any,
+        vllm_config: VllmConfig,
+        virtual_engine: int = 0,
+        dp_metadata: Optional[DPMetadata] = None,
+        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
+    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
+                          static_forward_context,
+                          virtual_engine=virtual_engine,
+                          attn_metadata=attn_metadata,
+                          dp_metadata=dp_metadata,
+                          cudagraph_runtime_mode=cudagraph_runtime_mode,
+                          batch_descriptor=batch_descriptor,
+                          ubatch_slices=ubatch_slices)
+
+
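A minimal sketch of calling the new factory directly, assuming a VllmConfig instance named vllm_config is already in hand; every other argument keeps its default.

    ctx = create_forward_context(attn_metadata=None, vllm_config=vllm_config)
    assert ctx.cudagraph_runtime_mode == CUDAGraphMode.NONE  # default
    assert ctx.ubatch_slices is None                         # default
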
+@contextmanager
+def override_forward_context(forward_context: Optional[ForwardContext]):
+    """A context manager that overrides the current forward context.
+    This is used to override the forward context for a specific
+    forward pass.
+    """
+    global _forward_context
+    prev_context = _forward_context
+    _forward_context = forward_context
+    try:
+        yield
+    finally:
+        _forward_context = prev_context
+
+
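Usage sketch: the override is scoped, so the previously active context comes back on exit even if the body raises (ctx is assumed to come from create_forward_context above).

    with override_forward_context(ctx):
        assert get_forward_context() is ctx
    # the previous context (possibly None) is active again here
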
 @contextmanager
 def set_forward_context(
         attn_metadata: Any,
@@ -216,7 +303,8 @@ def set_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -225,27 +313,22 @@ def set_forward_context(
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
+
     dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1 and (
             attn_metadata is not None or num_tokens is not None):
         dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                       attn_metadata, num_tokens or 0,
                                       num_tokens_across_dp)
 
-    global _forward_context
-    prev_context = _forward_context
-    _forward_context = ForwardContext(
-        no_compile_layers=vllm_config.compilation_config.
-        static_forward_context,
-        virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata,
-        cudagraph_runtime_mode=cudagraph_runtime_mode,
-        batch_descriptor=batch_descriptor,
-    )
+    forward_context = create_forward_context(attn_metadata, vllm_config,
+                                             virtual_engine, dp_metadata,
+                                             cudagraph_runtime_mode,
+                                             batch_descriptor, ubatch_slices)
 
     try:
-        yield
+        with override_forward_context(forward_context):
+            yield
     finally:
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
@@ -282,5 +365,3 @@ def set_forward_context(
             logger.info(("Batchsize forward time stats "
                          "(batchsize, count, median_time(ms)): %s"),
                         forward_stats)
-
-        _forward_context = prev_context
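
Taken together, a call site now looks roughly like this; a sketch only, with the runner-side names (num_input_tokens, model inputs) assumed rather than taken from the PR.

    with set_forward_context(attn_metadata,
                             vllm_config,
                             num_tokens=num_input_tokens,
                             ubatch_slices=ubatch_slices):
        hidden_states = model(input_ids, positions)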