diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 54ccc556504..128996d93e0 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -687,16 +687,20 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
 
     def _executor_loop_pp(self):
         torch.cuda.set_device(self.device_id)
+        got_finish_signal = False
+        num_dummy_request = 0
         microbatch_id = 0
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
-            while not self.is_shutdown or len(self.active_requests) > 0:
+            while not got_finish_signal or len(self.active_requests) > 0:
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
                 new_requests = self._fetch_new_requests()
-                if self.is_shutdown and len(self.active_requests) == 0:
+                got_finish_signal = self._merge_requests(
+                    new_requests) or got_finish_signal
+                if got_finish_signal and len(self.active_requests) == 0:
                     break
 
                 if self.enable_iter_perf_stats:
@@ -704,7 +708,8 @@ def _executor_loop_pp(self):
                         len(new_requests),
                         self.new_active_requests_queue_latency_ms)
 
-                self._pad_attention_dp_dummy_request()
+                if not got_finish_signal:
+                    num_dummy_request = self._pad_attention_dp_dummy_request()
 
                 scheduled_batch, _, _ = self._schedule()
 
@@ -758,6 +763,9 @@ def _executor_loop_pp(self):
                         microbatch_id=microbatch_id,
                     )
 
+                    if num_dummy_request > 0:
+                        self._finish_dummy_request(
+                            sample_state.scheduled_requests)
                     self.micro_batches[microbatch_id] = batch_state
 
                 # Stage 2: Communicate new tokens for previous batch between ranks
@@ -838,6 +846,8 @@ def _executor_loop_pp(self):
 
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
+        got_finish_signal = False
+        num_dummy_request = 0
         is_ngram = hasattr(
             self.model_engine, "spec_config"
         ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
@@ -845,12 +855,14 @@ def _executor_loop(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
-            while not self.is_shutdown or len(self.active_requests) > 0:
+            while not got_finish_signal or len(self.active_requests) > 0:
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
                 new_requests = self._fetch_new_requests()
-                if self.is_shutdown and len(self.active_requests) == 0:
+                got_finish_signal = self._merge_requests(
+                    new_requests) or got_finish_signal
+                if got_finish_signal and len(self.active_requests) == 0:
                     break
 
                 if self.kv_cache_transceiver:
@@ -861,7 +873,8 @@ def _executor_loop(self):
                         len(new_requests),
                         self.new_active_requests_queue_latency_ms)
 
-                self._pad_attention_dp_dummy_request()
+                if not got_finish_signal:
+                    num_dummy_request = self._pad_attention_dp_dummy_request()
 
                 if self.draft_model_engine is not None or is_ngram:
                     self._prepare_draft_requests()
@@ -936,6 +949,9 @@ def _executor_loop(self):
                     scheduled_batch.context_requests
                 ) if self.kv_cache_transceiver else []
 
+                if num_dummy_request > 0:
+                    self._finish_dummy_request(scheduled_batch)
+
                 if self.kv_cache_transceiver:
                     # For context only req in transmission, we reset the state since sampler might have changed it
                     for req in ctx_transmission_reqs:
@@ -986,15 +1002,19 @@ def _prepare_draft_requests(self):
 
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
+        got_finish_signal = False
+        num_dummy_request = 0
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
-            while not self.is_shutdown or len(self.active_requests) > 0:
+            while not got_finish_signal or len(self.active_requests) > 0:
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
                 new_requests = self._fetch_new_requests()
-                if self.is_shutdown and len(self.active_requests) == 0:
+                got_finish_signal = self._merge_requests(
+                    new_requests) or got_finish_signal
+                if got_finish_signal and len(self.active_requests) == 0:
                     break
 
                 if self.kv_cache_transceiver:
@@ -1005,7 +1025,8 @@ def _executor_loop_overlap(self):
                         len(new_requests),
                         self.new_active_requests_queue_latency_ms)
 
-                self._pad_attention_dp_dummy_request()
+                if not got_finish_signal:
+                    num_dummy_request = self._pad_attention_dp_dummy_request()
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )
@@ -1076,6 +1097,9 @@ def _executor_loop_overlap(self):
                     scheduled_batch.context_requests
                 ) if self.kv_cache_transceiver else []
 
+                if num_dummy_request > 0:
+                    self._finish_dummy_request(scheduled_batch)
+
                 has_previous_batch = self.previous_batch is not None
                 if has_previous_batch:
                     previous_batch_size = self.previous_batch.sample_state.scheduled_requests.batch_size
@@ -1211,16 +1235,6 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
         new_requests, py_request_objects = self._broadcast_new_requests(
             new_requests, py_request_objects)
 
-        # drop requests arriving after shutdown
-        valid_new_requests = []
-        for req_item in new_requests:
-            if req_item.is_shutdown_request():
-                self.is_shutdown = True
-                break
-            else:
-                valid_new_requests.append(req_item)
-        new_requests = valid_new_requests
-
         if py_request_objects and (self.dist.tp_size > 1
                                    or self.dist.has_pp) and self.dist.rank > 0:
             attr_name, req_obj_dict = py_request_objects
@@ -1229,8 +1243,6 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
 
         if not self.enable_attention_dp:
             self._update_new_active_requests_queue_latency(new_requests)
-            new_requests = self._merge_requests(new_requests)
-            self.active_requests.extend(new_requests)
             return new_requests
 
         num_new_requests_all_ranks = len(new_requests)
@@ -1242,7 +1254,8 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
         self.has_context_request = False
         new_requests_cur_rank = []
 
-        if new_requests != [] and self.expected_num_active_requests > all_ranks_num_active_requests[
+        if new_requests != [] and not new_requests[0].is_shutdown_request(
+        ) and self.expected_num_active_requests > all_ranks_num_active_requests[
                 self.dist.tp_rank]:
             # Balance context tokens across ranks
             HeapVal = namedtuple(
@@ -1268,14 +1281,14 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
             new_requests = sorted(new_requests,
                                   key=lambda x: len(x.request.input_token_ids),
                                   reverse=True)
-            for req_item in new_requests:
+            for request_item in new_requests:
                 val = heapq.heappop(all_ranks_new_requests_heap)
                 val = val._replace(
                     num_tokens=val.num_tokens +
-                    len(req_item.request.input_token_ids),
+                    len(request_item.request.input_token_ids),
                     num_requests=val.num_requests - 1,
                 )
-                val.request_list.append(req_item)
+                val.request_list.append(request_item)
                 if val.num_requests > 0:
                     heapq.heappush(all_ranks_new_requests_heap, val)
                 elif val.rank == self.dist.tp_rank:
@@ -1297,8 +1310,8 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
         self.num_fetch_requests_cur_rank = self.num_fetch_requests_cur_rank + len(
             new_requests_cur_rank)
 
-        new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
-        self.active_requests.extend(new_requests_cur_rank)
+        if len(new_requests) == 1 and new_requests[0].is_shutdown_request():
+            new_requests_cur_rank = new_requests
         return new_requests_cur_rank
 
     def _add_kv_cache_events(self):
@@ -1310,6 +1323,29 @@ def _add_kv_cache_events(self):
         # to be transferred to main thread when user needs them.
         kv_cache_manager.flush_iteration_events()
 
+    def _merge_tp_requests(self, new_requests: List[RequestQueueItem]):
+        for req_item in new_requests:
+            if req_item.is_shutdown_request():
+                return True
+        for req_item in new_requests:
+            req = executor_request_to_llm_request(req_item.id, req_item.request)
+            self.active_requests.append(req)
+
+        return False
+
+    def _finish_dummy_request(self, scheduled_requests: ScheduledRequests):
+        for req in scheduled_requests.context_requests:
+            if req.is_attention_dp_dummy:
+                req.state = LlmRequestState.GENERATION_COMPLETE
+        for req in scheduled_requests.generation_requests:
+            if req.is_attention_dp_dummy:
+                req.state = LlmRequestState.GENERATION_COMPLETE
+        for req in self.active_requests[:]:
+            if req.is_attention_dp_dummy:
+                self.inflight_req_ids.erase(req.request_id)
+                self._terminate_request(req)
+                self.active_requests.remove(req)
+
     def _collect_py_objects_from_requests(
             self, requests: list[RequestQueueItem],
             attribute_name: str) -> Optional[tuple[str, dict]]:
@@ -1335,6 +1371,8 @@ def _attach_py_objects_to_requests(self, requests: list[RequestQueueItem],
         to each request.
         """
         for item in requests:
+            if item.is_shutdown_request():
+                continue
             py_obj = py_request_objects.get(item.id)
             if py_obj is not None:
                 setattr(item.request, attribute_name, py_obj)
@@ -1379,7 +1417,9 @@ def _partition_context(self, ctx_ids_list):
 
     def _merge_star_attention_requests(self,
                                        new_requests: list[RequestQueueItem]):
-        result = []
+        for req_item in new_requests:
+            if req_item.is_shutdown_request():
+                return True
         for req_item in new_requests:
             req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
             ctx_len0 = len(exe_req.input_token_ids)
@@ -1429,10 +1469,9 @@ def _merge_star_attention_requests(
             req.ctx_blocks = ctx_blocks
             req.ctx_position_blocks = position_blocks
             req.query_id = query_token_ids
+            self.active_requests.append(req)
 
-            result.append(req)
-
-        return result
+        return False
 
     @nvtx_range("_merge_requests")
     def _merge_requests(self, new_requests: list[RequestQueueItem]):
@@ -1440,18 +1479,14 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
         if 'cp_type' in cp_config:
             cp_type = cp_config['cp_type']
             if cp_type == 'star_attention':
-                return self._merge_star_attention_requests(new_requests)
+                ret = self._merge_star_attention_requests(new_requests)
             elif cp_type == 'ring_attention':
                 raise NotImplementedError("ring attention not implemented yet")
             else:
                 raise NotImplementedError(f'unsupport cp type {cp_type}')
         else:
-            return [
-                executor_request_to_llm_request(
-                    req_item.id, req_item.request,
-                    self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+            ret = self._merge_tp_requests(new_requests)
+
+        return ret
 
     @nvtx_range("_schedule")
     def _schedule(self):
@@ -1485,10 +1520,10 @@ def _check_disagg_gen_transfer_status(self):
 
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
-        Pad with a dummy request, if required, to ensure every attention_dp rank has at least one active request.
+        Pad dummy requests to ensure each attention_dp rank has the same number of active requests
         """
         if not self.enable_attention_dp:
-            return
+            return 0
 
         assert self.expected_num_active_requests >= len(self.active_requests)
@@ -1500,15 +1535,19 @@ def _pad_attention_dp_dummy_request(self):
             for req in self.active_requests
         ])
 
-        if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
-            llm_request = self.kv_cache_manager.add_dummy_requests(
-                request_ids=[0],
+        num_dummy_request = self.expected_num_active_requests - num_active_request
+        if num_dummy_request > 0:
+            llm_request_list = self.kv_cache_manager.add_dummy_requests(
+                request_ids=list(range(num_dummy_request)),
                 is_gen=not self.has_context_request,
                 prepare_resource=not self.has_context_request,
-                max_num_draft_tokens=self.max_draft_tokens,
-            )[0]
-            llm_request.is_attention_dp_dummy = True
-            self.active_requests.append(llm_request)
+                max_num_draft_tokens=0
+                if self.has_context_request else self.max_draft_tokens,
+            )
+            for llm_request in llm_request_list:
+                llm_request.is_attention_dp_dummy = True
+            self.active_requests += llm_request_list
+        return num_dummy_request
 
     @nvtx_range("_prepare_disagg_gen_init")
     def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
@@ -1624,20 +1663,13 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,
         return None
 
     def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
-        # handle potential attention dp dummy request
-        if self.active_requests and self.active_requests[
-                -1].is_attention_dp_dummy:
-            request = self.active_requests[-1]
-            request.state = LlmRequestState.GENERATION_COMPLETE
-            self.inflight_req_ids.erase(request.py_request_id)
-            self._terminate_request(request)
-            self.active_requests.remove(request)
-
         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
                 request.move_to_next_context_chunk()
                 if request.context_remaining_length == 0:
                     request.state = LlmRequestState.GENERATION_IN_PROGRESS
+            if request.is_attention_dp_dummy:
+                request.state = LlmRequestState.GENERATION_COMPLETE
 
     def _update_request_states_star_attention(
             self, scheduled_requests: ScheduledRequests):
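
The standalone sketch below is not part of the patch and does not use the PyExecutor classes; Request and Rank are simplified stand-ins. It only illustrates the padding scheme the diff moves to: every attention-DP rank pads its active list with dummy requests up to the shared expected count before a step, then finishes and removes the dummies right after the step, so all ranks execute the same number of batches without the padding ever producing user-visible output.

# Standalone illustration only -- simplified stand-ins, not the PyExecutor classes.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    request_id: int
    is_dummy: bool = False


@dataclass
class Rank:
    rank_id: int
    active: List[Request] = field(default_factory=list)

    def pad_dummy_requests(self, expected_num_active: int) -> int:
        """Pad with dummies so this rank steps with the same batch size as its peers."""
        num_dummy = max(expected_num_active - len(self.active), 0)
        self.active.extend(
            Request(request_id=-(i + 1), is_dummy=True) for i in range(num_dummy))
        return num_dummy

    def finish_dummy_requests(self) -> None:
        """Retire the dummies right after the step so they never reach the user."""
        self.active = [req for req in self.active if not req.is_dummy]


# Ranks agree on the largest active count, pad up to it, run one step, then clean up.
ranks = [Rank(0, [Request(1), Request(2)]), Rank(1, [Request(3)]), Rank(2)]
expected = max(len(r.active) for r in ranks)
for r in ranks:
    num_dummy = r.pad_dummy_requests(expected)
    assert len(r.active) == expected  # every rank now runs the same batch size
    # ... forward pass + sampling over r.active would happen here ...
    if num_dummy > 0:
        r.finish_dummy_requests()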
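A second standalone sketch, again with made-up names rather than the executor's API, shows the loop shape the patch switches to: a shutdown item only flips a local flag when requests are merged, and the loop keeps iterating until that flag is set and all in-flight work has drained. The real loop blocks while fetching new requests; this toy version just drains a pre-filled queue.

# Standalone illustration only -- the shutdown sentinel and queue are made up.
import queue

SHUTDOWN = object()


def executor_loop(request_queue: queue.Queue) -> list:
    got_finish_signal = False
    active, completed = [], []
    while not got_finish_signal or active:
        # Merge whatever arrived; a shutdown item only flips the flag.
        while not request_queue.empty():
            item = request_queue.get_nowait()
            if item is SHUTDOWN:
                got_finish_signal = True
            else:
                active.append(item)
        if got_finish_signal and not active:
            break
        if active:
            # ... one schedule/forward/sample step; here we just retire one item ...
            completed.append(active.pop(0))
    return completed


q = queue.Queue()
for request_id in (1, 2, 3):
    q.put(request_id)
q.put(SHUTDOWN)
assert executor_loop(q) == [1, 2, 3]  # in-flight work is drained before exit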