@@ -241,6 +241,10 @@ def __init__(self,
241241 self .enable_iter_perf_stats = model_engine .pytorch_backend_config .enable_iter_perf_stats
242242 self .enable_iter_req_stats = model_engine .pytorch_backend_config .enable_iter_req_stats
243243 self .stream_interval = model_engine .pytorch_backend_config .stream_interval
244+ self .use_attention_dp_config = model_engine .pytorch_backend_config .use_attention_dp_config
245+ self .attention_dp_time_out_iters = model_engine .pytorch_backend_config .attention_dp_time_out_iters
246+ self .attention_dp_batching_wait_iters = model_engine .pytorch_backend_config .attention_dp_batching_wait_iters
247+
244248 self .num_fetch_requests_cur_rank = 0
245249 self .num_fetch_requests = 0
246250 self .shutdown_event = threading .Event ()
@@ -287,6 +291,9 @@ def __init__(self,
287291 self .draft_model_engine .warmup (self .resource_manager )
288292
289293 self .is_shutdown = False
294+ self .max_batch_size = max_batch_size
295+ self .adp_ctx_waiting_iters = 0
296+ self .adp_ctx_batching_wait_iters = 0
290297
291298 self .stats_lock = threading .Lock ()
292299 self .stats = []
@@ -1228,7 +1235,15 @@ def _broadcast_new_requests(
12281235 def _fetch_new_requests (self ) -> List [RequestQueueItem ]:
12291236 if self .enable_attention_dp :
12301237 all_ranks_num_active_requests = []
1231- responses_list = self .dist .tp_allgather (len (self .active_requests ))
1238+ num_active_requests = len (self .active_requests )
1239+ responses_list = self .dist .tp_allgather (num_active_requests )
1240+ # Debug check - remove after verification
1241+ if not all (isinstance (x , int ) for x in responses_list ):
1242+ raise RuntimeError (
1243+ f"tp_allgather returned non-integer values: { responses_list } " +
1244+ f"Expected all ranks to return int from { num_active_requests } and { self .active_requests } ."
1245+ )
1246+
12321247 for num_active_requests in responses_list :
12331248 all_ranks_num_active_requests .append (num_active_requests )
12341249 total_num_active_requests = sum (all_ranks_num_active_requests )
@@ -1518,8 +1533,66 @@ def _schedule(self):
15181533 scheduler_output = self .scheduler .schedule_request (
15191534 self .active_requests , self .inflight_req_ids )
15201535 scheduled_requests = ScheduledRequests ()
1536+ context_requests = scheduler_output .context_requests
1537+ if self .enable_attention_dp :
1538+ num_scheduled_context_requests = len (
1539+ scheduler_output .context_requests )
1540+ num_scheduled_generation_requests = len (
1541+ scheduler_output .generation_requests )
1542+ num_scheduled_tokens = sum ([
1543+ len (req .get_tokens (0 )) for req in context_requests
1544+ ]) + num_scheduled_generation_requests
1545+ responses_list = self .dist .tp_allgather ([
1546+ num_scheduled_context_requests ,
1547+ num_scheduled_generation_requests , num_scheduled_tokens
1548+ ])
1549+ all_ranks_num_scheduled_context_requests = [
1550+ response [0 ] for response in responses_list
1551+ ]
1552+ all_ranks_num_scheduled_generation_requests = [
1553+ response [1 ] for response in responses_list
1554+ ]
1555+ all_ranks_num_scheduled_tokens = [
1556+ response [2 ] for response in responses_list
1557+ ]
1558+
1559+ all_ranks_have_free_ctx_slots = all ([
1560+ num_gen < self .max_batch_size
1561+ for num_gen in all_ranks_num_scheduled_generation_requests
1562+ ])
1563+ all_ranks_have_multi_gen = all ([
1564+ num_gen > 1
1565+ for num_gen in all_ranks_num_scheduled_generation_requests
1566+ ])
1567+ all_ranks_have_ctx_requests = all ([
1568+ num_ctx > 0
1569+ for num_ctx in all_ranks_num_scheduled_context_requests
1570+ ])
1571+
1572+ all_ranks_have_gen_requests = all ([
1573+ num_gen > 0
1574+ for num_gen in all_ranks_num_scheduled_generation_requests
1575+ ])
1576+ if self .use_attention_dp_config :
1577+ # wait until all ranks have context requests
1578+ if all_ranks_have_multi_gen :
1579+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests :
1580+ self .adp_ctx_waiting_iters = 0
1581+ else :
1582+ self .adp_ctx_waiting_iters += 1
1583+ context_requests = []
1584+ if self .adp_ctx_waiting_iters >= self .attention_dp_time_out_iters :
1585+ self .adp_ctx_waiting_iters = 0
1586+ context_requests = scheduler_output .context_requests
1587+ # balance number of context requests across ranks
1588+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests and all_ranks_have_gen_requests :
1589+ if self .adp_ctx_batching_wait_iters <= self .attention_dp_batching_wait_iters :
1590+ self .adp_ctx_batching_wait_iters += 1
1591+ context_requests = []
1592+ else :
1593+ self .adp_ctx_batching_wait_iters = 0
15211594
1522- scheduled_requests .context_requests = scheduler_output . context_requests
1595+ scheduled_requests .context_requests = context_requests
15231596 scheduled_requests .generation_requests = scheduler_output .generation_requests
15241597 scheduled_requests .paused_requests = scheduler_output .paused_requests
15251598 return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests
0 commit comments