@@ -177,6 +177,9 @@ def __init__(self,
177177 self .enable_iter_perf_stats = model_engine .pytorch_backend_config .enable_iter_perf_stats
178178 self .enable_iter_req_stats = model_engine .pytorch_backend_config .enable_iter_req_stats
179179 self .stream_interval = model_engine .pytorch_backend_config .stream_interval
180+ self .attention_dp_enable_balance = model_engine .pytorch_backend_config .attention_dp_enable_balance
181+ self .attention_dp_time_out_iters = model_engine .pytorch_backend_config .attention_dp_time_out_iters
182+ self .attention_dp_batching_wait_iters = model_engine .pytorch_backend_config .attention_dp_batching_wait_iters
180183 self .num_fetch_requests_cur_rank = 0
181184 self .num_fetch_requests = 0
182185 self .shutdown_event = threading .Event ()
@@ -215,6 +218,9 @@ def __init__(self,
215218 self .draft_model_engine .warmup (self .resource_manager )
216219
217220 self .is_shutdown = False
221+ self .max_batch_size = max_batch_size
222+ self .adp_ctx_waiting_iters_count = 0
223+ self .adp_ctx_batching_wait_iters_count = 0
218224
219225 # request fetcher initialization
220226 self .executor_request_queue = ExecutorRequestQueue (
@@ -1131,13 +1137,68 @@ def _add_kv_cache_events(self):
11311137 # to be transferred to main thread when user needs them.
11321138 kv_cache_manager .flush_iteration_events ()
11331139
1140+ def _balance_adp_requests (self , context_requests : list [LlmRequest ],
1141+ generation_requests : list [LlmRequest ]):
1142+ balanced_context_requests = context_requests
1143+ num_scheduled_context_requests = len (context_requests )
1144+ num_scheduled_generation_requests = len (generation_requests )
1145+ num_scheduled_tokens = sum (
1146+ [len (req .get_tokens (0 ))
1147+ for req in context_requests ]) + num_scheduled_generation_requests
1148+ responses_list = self .dist .tp_allgather ([
1149+ num_scheduled_context_requests , num_scheduled_generation_requests ,
1150+ num_scheduled_tokens
1151+ ])
1152+ all_ranks_num_scheduled_context_requests = [
1153+ response [0 ] for response in responses_list
1154+ ]
1155+ all_ranks_num_scheduled_generation_requests = [
1156+ response [1 ] for response in responses_list
1157+ ]
1158+ all_ranks_have_free_ctx_slots = all ([
1159+ num_gen < self .max_batch_size
1160+ for num_gen in all_ranks_num_scheduled_generation_requests
1161+ ])
1162+ all_ranks_have_ctx_requests = all ([
1163+ num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests
1164+ ])
1165+ all_ranks_have_gen_requests = all ([
1166+ num_gen > 0
1167+ for num_gen in all_ranks_num_scheduled_generation_requests
1168+ ])
1169+
1170+ if self .attention_dp_enable_balance :
1171+ # wait for all ranks have context requests
1172+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests :
1173+ self .adp_ctx_waiting_iters_count = 0
1174+ # balance number of context requests across ranks
1175+ if all_ranks_have_gen_requests :
1176+ if self .adp_ctx_batching_wait_iters_count < self .attention_dp_batching_wait_iters :
1177+ self .adp_ctx_batching_wait_iters_count += 1
1178+ balanced_context_requests = []
1179+ else :
1180+ self .adp_ctx_batching_wait_iters_count = 0
1181+ else :
1182+ self .adp_ctx_waiting_iters_count += 1
1183+ balanced_context_requests = []
1184+ timeout_reached = self .adp_ctx_waiting_iters_count >= self .attention_dp_time_out_iters
1185+ if timeout_reached or not all_ranks_have_gen_requests :
1186+ self .adp_ctx_waiting_iters_count = 0
1187+ balanced_context_requests = context_requests
1188+ return balanced_context_requests
1189+
11341190 @nvtx_range ("_schedule" )
11351191 def _schedule (self ):
11361192 scheduler_output = self .scheduler .schedule_request (
11371193 self .active_requests , self .inflight_req_ids )
1138- scheduled_requests = ScheduledRequests ()
1194+ scheduled_context_requests = scheduler_output .context_requests
1195+ if self .enable_attention_dp and self .attention_dp_enable_balance :
1196+ scheduled_context_requests = self ._balance_adp_requests (
1197+ scheduler_output .context_requests ,
1198+ scheduler_output .generation_requests )
11391199
1140- scheduled_requests .context_requests = scheduler_output .context_requests
1200+ scheduled_requests = ScheduledRequests ()
1201+ scheduled_requests .context_requests = scheduled_context_requests
11411202 scheduled_requests .generation_requests = scheduler_output .generation_requests
11421203 scheduled_requests .paused_requests = scheduler_output .paused_requests
11431204 return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests
0 commit comments