@@ -241,6 +241,10 @@ def __init__(self,
241241 self .enable_iter_perf_stats = model_engine .pytorch_backend_config .enable_iter_perf_stats
242242 self .enable_iter_req_stats = model_engine .pytorch_backend_config .enable_iter_req_stats
243243 self .stream_interval = model_engine .pytorch_backend_config .stream_interval
244+ self .use_attention_dp_config = model_engine .pytorch_backend_config .use_attention_dp_config
245+ self .attention_dp_time_out_iters = model_engine .pytorch_backend_config .attention_dp_time_out_iters
246+ self .attention_dp_batching_wait_iters = model_engine .pytorch_backend_config .attention_dp_batching_wait_iters
247+
244248 self .num_fetch_requests_cur_rank = 0
245249 self .num_fetch_requests = 0
246250 self .shutdown_event = threading .Event ()
@@ -287,6 +291,9 @@ def __init__(self,
287291 self .draft_model_engine .warmup (self .resource_manager )
288292
289293 self .is_shutdown = False
294+ self .max_batch_size = max_batch_size
295+ self .adp_ctx_waiting_iters = 0
296+ self .adp_ctx_batching_wait_iters = 0
290297
291298 self .stats_lock = threading .Lock ()
292299 self .stats = []
@@ -1228,7 +1235,16 @@ def _broadcast_new_requests(
12281235 def _fetch_new_requests (self ) -> List [RequestQueueItem ]:
12291236 if self .enable_attention_dp :
12301237 all_ranks_num_active_requests = []
1231- responses_list = self .dist .tp_allgather (len (self .active_requests ))
1238+ num_active_requests = len (self .active_requests )
1239+ responses_list = self .dist .tp_allgather (num_active_requests )
1240+ # Debug-only sanity check that every rank returned an int - remove after verification
1241+ if not all (isinstance (x , int ) for x in responses_list ):
1242+ raise RuntimeError (
1243+ f"tp_allgather returned non-integer values: { responses_list } "
1244+ +
1245+ f"Expected all ranks to return int from { num_active_requests } and { self .active_requests } ."
1246+ )
1247+
12321248 for num_active_requests in responses_list :
12331249 all_ranks_num_active_requests .append (num_active_requests )
12341250 total_num_active_requests = sum (all_ranks_num_active_requests )
@@ -1518,8 +1534,66 @@ def _schedule(self):
15181534 scheduler_output = self .scheduler .schedule_request (
15191535 self .active_requests , self .inflight_req_ids )
15201536 scheduled_requests = ScheduledRequests ()
1537+ context_requests = scheduler_output .context_requests
1538+ if self .enable_attention_dp :
1539+ num_scheduled_context_requests = len (
1540+ scheduler_output .context_requests )
1541+ num_scheduled_generation_requests = len (
1542+ scheduler_output .generation_requests )
1543+ num_scheduled_tokens = sum ([
1544+ len (req .get_tokens (0 )) for req in context_requests
1545+ ]) + num_scheduled_generation_requests
1546+ responses_list = self .dist .tp_allgather ([
1547+ num_scheduled_context_requests ,
1548+ num_scheduled_generation_requests , num_scheduled_tokens
1549+ ])
1550+ all_ranks_num_scheduled_context_requests = [
1551+ response [0 ] for response in responses_list
1552+ ]
1553+ all_ranks_num_scheduled_generation_requests = [
1554+ response [1 ] for response in responses_list
1555+ ]
1556+ all_ranks_num_scheduled_tokens = [
1557+ response [2 ] for response in responses_list
1558+ ]
1559+
1560+ all_ranks_have_free_ctx_slots = all ([
1561+ num_gen < self .max_batch_size
1562+ for num_gen in all_ranks_num_scheduled_generation_requests
1563+ ])
1564+ all_ranks_have_multi_gen = all ([
1565+ num_gen > 1
1566+ for num_gen in all_ranks_num_scheduled_generation_requests
1567+ ])
1568+ all_ranks_have_ctx_requests = all ([
1569+ num_ctx > 0
1570+ for num_ctx in all_ranks_num_scheduled_context_requests
1571+ ])
1572+
1573+ all_ranks_have_gen_requests = all ([
1574+ num_gen > 0
1575+ for num_gen in all_ranks_num_scheduled_generation_requests
1576+ ])
1577+ if self .use_attention_dp_config :
1578+ # wait until all ranks have context requests (or the timeout elapses)
1579+ if all_ranks_have_multi_gen :
1580+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests :
1581+ self .adp_ctx_waiting_iters = 0
1582+ else :
1583+ self .adp_ctx_waiting_iters += 1
1584+ context_requests = []
1585+ if self .adp_ctx_waiting_iters >= self .attention_dp_time_out_iters :
1586+ self .adp_ctx_waiting_iters = 0
1587+ context_requests = scheduler_output .context_requests
1588+ # balance the number of context requests across ranks
1589+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests and all_ranks_have_gen_requests :
1590+ if self .adp_ctx_batching_wait_iters <= self .attention_dp_batching_wait_iters :
1591+ self .adp_ctx_batching_wait_iters += 1
1592+ context_requests = []
1593+ else :
1594+ self .adp_ctx_batching_wait_iters = 0
15211595
1522- scheduled_requests .context_requests = scheduler_output . context_requests
1596+ scheduled_requests .context_requests = context_requests
15231597 scheduled_requests .generation_requests = scheduler_output .generation_requests
15241598 scheduled_requests .paused_requests = scheduler_output .paused_requests
15251599 return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests