78 changes: 33 additions & 45 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -683,28 +683,24 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):

def _executor_loop_pp(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
microbatch_id = 0
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not got_finish_signal or len(self.active_requests) > 0:
while not self.is_shutdown or len(self.active_requests) > 0:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
got_finish_signal = self._merge_requests(
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
if self.is_shutdown and len(self.active_requests) == 0:
break

if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.new_active_requests_queue_latency_ms)

if not got_finish_signal:
self._pad_attention_dp_dummy_request()
self._pad_attention_dp_dummy_request()

scheduled_batch, _, _ = self._schedule()

@@ -839,22 +835,19 @@ def _executor_loop_pp(self):

def _executor_loop(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
is_ngram = hasattr(
self.model_engine, "spec_config"
) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
)
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not got_finish_signal or len(self.active_requests) > 0:
while not self.is_shutdown or len(self.active_requests) > 0:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
got_finish_signal = self._merge_requests(
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
if self.is_shutdown and len(self.active_requests) == 0:
break

if self.kv_cache_transceiver:
@@ -865,8 +858,7 @@ def _executor_loop(self):
len(new_requests),
self.new_active_requests_queue_latency_ms)

if not got_finish_signal:
self._pad_attention_dp_dummy_request()
self._pad_attention_dp_dummy_request()

if self.draft_model_engine is not None or is_ngram:
self._prepare_draft_requests()
@@ -985,18 +977,15 @@ def _prepare_draft_requests(self):

def _executor_loop_overlap(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not got_finish_signal or len(self.active_requests) > 0:
while not self.is_shutdown or len(self.active_requests) > 0:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
got_finish_signal = self._merge_requests(
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
if self.is_shutdown and len(self.active_requests) == 0:
break

if self.kv_cache_transceiver:
@@ -1007,8 +996,7 @@ def _executor_loop_overlap(self):
len(new_requests),
self.new_active_requests_queue_latency_ms)

if not got_finish_signal:
self._pad_attention_dp_dummy_request()
self._pad_attention_dp_dummy_request()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
@@ -1214,6 +1202,16 @@ def _fetch_new_requests(self):
new_requests, py_request_objects = self._broadcast_new_requests(
new_requests, py_request_objects)

# drop requests arriving after shutdown
valid_new_requests = []
for req_item in new_requests:
if req_item.is_shutdown_request():
self.is_shutdown = True
break
else:
valid_new_requests.append(req_item)
new_requests = valid_new_requests

if py_request_objects and (self.dist.tp_size > 1
or self.dist.has_pp) and self.dist.rank > 0:
attr_name, req_obj_dict = py_request_objects
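
The hunk above is the heart of this change: shutdown detection moves out of _merge_requests and into _fetch_new_requests, which sets self.is_shutdown and drops anything queued behind the shutdown request, while the executor loops now spin on self.is_shutdown instead of a got_finish_signal flag. A minimal, self-contained sketch of that filtering logic follows; the Item class is a hypothetical stand-in for RequestQueueItem, not the real type.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Item:
    """Hypothetical stand-in for RequestQueueItem; only the shutdown flag matters here."""
    id: int
    shutdown: bool = False

    def is_shutdown_request(self) -> bool:
        return self.shutdown


def filter_new_requests(new_requests: List[Item]) -> Tuple[List[Item], bool]:
    """Keep only requests queued before any shutdown request and report the flag."""
    valid: List[Item] = []
    is_shutdown = False
    for item in new_requests:
        if item.is_shutdown_request():
            is_shutdown = True
            break  # everything queued after the shutdown request is dropped
        valid.append(item)
    return valid, is_shutdown


# The third item signals shutdown, so the fourth is dropped.
reqs, shutdown = filter_new_requests([Item(1), Item(2), Item(3, shutdown=True), Item(4)])
assert [r.id for r in reqs] == [1, 2] and shutdown
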
@@ -1222,6 +1220,8 @@ def _fetch_new_requests(self):

if not self.enable_attention_dp:
self._update_new_active_requests_queue_latency(new_requests)
new_requests = self._merge_requests(new_requests)
self.active_requests.extend(new_requests)
return new_requests

num_new_requests_all_ranks = len(new_requests)
@@ -1233,8 +1233,7 @@ def _fetch_new_requests(self):

self.has_context_request = False
new_requests_cur_rank = []
if new_requests != [] and not new_requests[0].is_shutdown_request(
) and self.expected_num_active_requests > all_ranks_num_active_requests[
if new_requests != [] and self.expected_num_active_requests > all_ranks_num_active_requests[
self.dist.tp_rank]:
# Balance context tokens across ranks
HeapVal = namedtuple(
@@ -1289,8 +1288,8 @@ def _fetch_new_requests(self):
self.num_fetch_requests_cur_rank = self.num_fetch_requests_cur_rank + len(
new_requests_cur_rank)

if len(new_requests) == 1 and new_requests[0].is_shutdown_request():
new_requests_cur_rank = new_requests
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
self.active_requests.extend(new_requests_cur_rank)
return new_requests_cur_rank

def _add_kv_cache_events(self):
@@ -1302,16 +1301,6 @@ def _add_kv_cache_events(self):
# to be transferred to main thread when user needs them.
kv_cache_manager.flush_iteration_events()

def _merge_tp_requests(self, new_requests: List[RequestQueueItem]):
for req_item in new_requests:
if req_item.is_shutdown_request():
return True
for req_item in new_requests:
req = executor_request_to_llm_request(req_item.id, req_item.request)
self.active_requests.append(req)

return False

def _collect_py_objects_from_requests(
self, requests: list[RequestQueueItem],
attribute_name: str) -> Optional[tuple[str, dict]]:
@@ -1337,8 +1326,6 @@ def _attach_py_objects_to_requests(self, requests: list[RequestQueueItem],
to each request.
"""
for item in requests:
if item.is_shutdown_request():
continue
py_obj = py_request_objects.get(item.id)
if py_obj is not None:
setattr(item.request, attribute_name, py_obj)
@@ -1383,9 +1370,7 @@ def _partition_context(self, ctx_ids_list):

def _merge_star_attention_requests(self,
new_requests: list[RequestQueueItem]):
for req_item in new_requests:
if req_item.is_shutdown_request():
return True
result = []
for req_item in new_requests:
req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
ctx_len0 = len(exe_req.input_token_ids)
@@ -1434,24 +1419,27 @@ def _merge_star_attention_requests(self,
req.ctx_blocks = ctx_blocks
req.ctx_position_blocks = position_blocks
req.query_id = query_token_ids
self.active_requests.append(req)

return False
result.append(req)

return result

@nvtx_range("_merge_requests")
def _merge_requests(self, new_requests: list[RequestQueueItem]):
cp_config = self.dist.cp_config
if 'cp_type' in cp_config:
cp_type = cp_config['cp_type']
if cp_type == 'star_attention':
ret = self._merge_star_attention_requests(new_requests)
return self._merge_star_attention_requests(new_requests)
elif cp_type == 'ring_attention':
raise NotImplementedError("ring attention not implemented yet")
else:
raise NotImplementedError(f'unsupported cp type {cp_type}')
else:
ret = self._merge_tp_requests(new_requests)
return ret
return [
executor_request_to_llm_request(req_item.id, req_item.request)
for req_item in new_requests
]
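
With the shutdown handling hoisted into _fetch_new_requests, _merge_requests becomes a pure conversion step: it no longer returns a finish flag or appends to self.active_requests, it simply maps queue items to LLM requests and hands the list back to the caller, which extends active_requests itself. A rough sketch of that contract; to_llm_request is a hypothetical placeholder for executor_request_to_llm_request and QueueItem for RequestQueueItem.

from collections import namedtuple
from typing import Any, List


def to_llm_request(req_id: int, request: Any) -> dict:
    """Hypothetical placeholder for executor_request_to_llm_request."""
    return {"id": req_id, "request": request}


def merge_requests(new_requests: List[Any]) -> List[dict]:
    """Pure conversion: no shutdown handling, no mutation of active_requests."""
    return [to_llm_request(item.id, item.request) for item in new_requests]


class Caller:
    """Sketch of the caller-side bookkeeping now done in _fetch_new_requests."""

    def __init__(self) -> None:
        self.active_requests: List[dict] = []

    def fetch(self, new_requests: List[Any]) -> List[dict]:
        merged = merge_requests(new_requests)
        self.active_requests.extend(merged)  # caller owns the bookkeeping
        return merged


QueueItem = namedtuple("QueueItem", ["id", "request"])  # hypothetical queue item shape
caller = Caller()
caller.fetch([QueueItem(1, "prompt A"), QueueItem(2, "prompt B")])
assert len(caller.active_requests) == 2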

@nvtx_range("_schedule")
def _schedule(self):