144 changes: 88 additions & 56 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -687,24 +687,29 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):

def _executor_loop_pp(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
num_dummy_request = 0
microbatch_id = 0
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not self.is_shutdown or len(self.active_requests) > 0:
while not got_finish_signal or len(self.active_requests) > 0:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.is_shutdown and len(self.active_requests) == 0:
got_finish_signal = self._merge_requests(
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
break

if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.new_active_requests_queue_latency_ms)

self._pad_attention_dp_dummy_request()
if not got_finish_signal:
num_dummy_request = self._pad_attention_dp_dummy_request()

scheduled_batch, _, _ = self._schedule()

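For reference, a minimal standalone sketch of the finish-signal pattern introduced above: _merge_requests now reports whether a shutdown item was seen, and the loop keeps draining active requests before exiting. All names below (ToyExecutor, _merge, _step, the None sentinel) are illustrative stand-ins, not the executor's real API.

from typing import List, Optional


class ToyExecutor:
    """Simplified loop: run until a finish signal arrives and all work drains."""

    def __init__(self) -> None:
        self.active_requests: List[int] = []

    def run(self, incoming_batches: List[List[Optional[int]]]) -> None:
        got_finish_signal = False
        batches = iter(incoming_batches)
        while not got_finish_signal or len(self.active_requests) > 0:
            new_requests = next(batches, [])
            # Merge reports the shutdown signal; keep it sticky across iterations.
            got_finish_signal = self._merge(new_requests) or got_finish_signal
            if got_finish_signal and len(self.active_requests) == 0:
                break
            self._step()

    def _merge(self, new_requests: List[Optional[int]]) -> bool:
        if any(item is None for item in new_requests):  # None stands in for a shutdown item
            return True
        self.active_requests.extend(new_requests)
        return False

    def _step(self) -> None:
        if self.active_requests:
            self.active_requests.pop()  # pretend one request finishes per iteration


ToyExecutor().run([[1, 2], [3], [None]])  # the last batch carries the shutdown marker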
@@ -758,6 +763,9 @@ def _executor_loop_pp(self):
microbatch_id=microbatch_id,
)

if num_dummy_request > 0:
self._finish_dummy_request(
sample_state.scheduled_requests)
self.micro_batches[microbatch_id] = batch_state

# Stage 2: Communicate new tokens for previous batch between ranks
Expand Down Expand Up @@ -838,19 +846,23 @@ def _executor_loop_pp(self):

def _executor_loop(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
num_dummy_request = 0
is_ngram = hasattr(
self.model_engine, "spec_config"
) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
)
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not self.is_shutdown or len(self.active_requests) > 0:
while not got_finish_signal or len(self.active_requests) > 0:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.is_shutdown and len(self.active_requests) == 0:
got_finish_signal = self._merge_requests(
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
break

if self.kv_cache_transceiver:
@@ -861,7 +873,8 @@ def _executor_loop(self):
len(new_requests),
self.new_active_requests_queue_latency_ms)

self._pad_attention_dp_dummy_request()
if not got_finish_signal:
num_dummy_request = self._pad_attention_dp_dummy_request()

if self.draft_model_engine is not None or is_ngram:
self._prepare_draft_requests()
@@ -936,6 +949,9 @@ def _executor_loop(self):
scheduled_batch.context_requests
) if self.kv_cache_transceiver else []

if num_dummy_request > 0:
self._finish_dummy_request(scheduled_batch)

if self.kv_cache_transceiver:
# For context only req in transmission, we reset the state since sampler might have changed it
for req in ctx_transmission_reqs:
@@ -986,15 +1002,19 @@ def _prepare_draft_requests(self):

def _executor_loop_overlap(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
num_dummy_request = 0
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not self.is_shutdown or len(self.active_requests) > 0:
while not got_finish_signal or len(self.active_requests) > 0:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.is_shutdown and len(self.active_requests) == 0:
got_finish_signal = self._merge_requests(
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
break

if self.kv_cache_transceiver:
@@ -1005,7 +1025,8 @@ def _executor_loop_overlap(self):
len(new_requests),
self.new_active_requests_queue_latency_ms)

self._pad_attention_dp_dummy_request()
if not got_finish_signal:
num_dummy_request = self._pad_attention_dp_dummy_request()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
@@ -1076,6 +1097,9 @@ def _executor_loop_overlap(self):
scheduled_batch.context_requests
) if self.kv_cache_transceiver else []

if num_dummy_request > 0:
self._finish_dummy_request(scheduled_batch)

has_previous_batch = self.previous_batch is not None
if has_previous_batch:
previous_batch_size = self.previous_batch.sample_state.scheduled_requests.batch_size
@@ -1211,16 +1235,6 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
new_requests, py_request_objects = self._broadcast_new_requests(
new_requests, py_request_objects)

# drop requests arriving after shutdown
valid_new_requests = []
for req_item in new_requests:
if req_item.is_shutdown_request():
self.is_shutdown = True
break
else:
valid_new_requests.append(req_item)
new_requests = valid_new_requests

if py_request_objects and (self.dist.tp_size > 1
or self.dist.has_pp) and self.dist.rank > 0:
attr_name, req_obj_dict = py_request_objects
@@ -1229,8 +1243,6 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:

if not self.enable_attention_dp:
self._update_new_active_requests_queue_latency(new_requests)
new_requests = self._merge_requests(new_requests)
self.active_requests.extend(new_requests)
return new_requests

num_new_requests_all_ranks = len(new_requests)
@@ -1242,7 +1254,8 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:

self.has_context_request = False
new_requests_cur_rank = []
if new_requests != [] and self.expected_num_active_requests > all_ranks_num_active_requests[
if new_requests != [] and not new_requests[0].is_shutdown_request(
) and self.expected_num_active_requests > all_ranks_num_active_requests[
self.dist.tp_rank]:
# Balance context tokens across ranks
HeapVal = namedtuple(
@@ -1268,14 +1281,14 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
new_requests = sorted(new_requests,
key=lambda x: len(x.request.input_token_ids),
reverse=True)
for req_item in new_requests:
for request_item in new_requests:
val = heapq.heappop(all_ranks_new_requests_heap)
val = val._replace(
num_tokens=val.num_tokens +
len(req_item.request.input_token_ids),
len(request_item.request.input_token_ids),
num_requests=val.num_requests - 1,
)
val.request_list.append(req_item)
val.request_list.append(request_item)
if val.num_requests > 0:
heapq.heappush(all_ranks_new_requests_heap, val)
elif val.rank == self.dist.tp_rank:
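As context for the balancing loop above, here is a self-contained sketch of the same heap strategy: requests are assigned longest-first to whichever rank currently holds the fewest tokens, until each rank reaches its request budget. The function name, the capacity argument, and the example values are hypothetical; only the HeapVal/heapq pattern mirrors the code.

import heapq
from collections import namedtuple

HeapVal = namedtuple("HeapVal", ["num_tokens", "num_requests", "rank", "request_list"])


def balance_requests(requests, num_ranks, capacity_per_rank):
    """requests: list of token-id lists; returns rank -> assigned requests."""
    heap = [HeapVal(0, capacity_per_rank, rank, []) for rank in range(num_ranks)]
    heapq.heapify(heap)
    assignment = {rank: [] for rank in range(num_ranks)}
    # Longest-first assignment keeps per-rank token totals close together.
    for req in sorted(requests, key=len, reverse=True):
        val = heapq.heappop(heap)  # least-loaded rank that still has capacity
        val = val._replace(num_tokens=val.num_tokens + len(req),
                           num_requests=val.num_requests - 1)
        val.request_list.append(req)
        assignment[val.rank] = val.request_list
        if val.num_requests > 0:
            heapq.heappush(heap, val)
    return assignment


# Six requests of lengths 8..1 spread across two ranks with three slots each.
loads = balance_requests([[0] * n for n in (8, 5, 4, 3, 2, 1)], num_ranks=2, capacity_per_rank=3)
print({rank: sum(map(len, reqs)) for rank, reqs in loads.items()})  # {0: 12, 1: 11}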
@@ -1297,8 +1310,8 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
self.num_fetch_requests_cur_rank = self.num_fetch_requests_cur_rank + len(
new_requests_cur_rank)

new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
self.active_requests.extend(new_requests_cur_rank)
if len(new_requests) == 1 and new_requests[0].is_shutdown_request():
new_requests_cur_rank = new_requests
return new_requests_cur_rank

def _add_kv_cache_events(self):
Expand All @@ -1310,6 +1323,29 @@ def _add_kv_cache_events(self):
# to be transferred to main thread when user needs them.
kv_cache_manager.flush_iteration_events()

def _merge_tp_requests(self, new_requests: List[RequestQueueItem]):
for req_item in new_requests:
if req_item.is_shutdown_request():
return True
for req_item in new_requests:
req = executor_request_to_llm_request(req_item.id, req_item.request)
self.active_requests.append(req)

return False

def _finish_dummy_request(self, scheduled_requests: ScheduledRequests):
for req in scheduled_requests.context_requests:
if req.is_attention_dp_dummy:
req.state = LlmRequestState.GENERATION_COMPLETE
for req in scheduled_requests.generation_requests:
if req.is_attention_dp_dummy:
req.state = LlmRequestState.GENERATION_COMPLETE
for req in self.active_requests[:]:
if req.is_attention_dp_dummy:
self.inflight_req_ids.erase(req.request_id)
self._terminate_request(req)
self.active_requests.remove(req)

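A compact sketch of the dummy-request cleanup this method performs, using a simplified request type: scheduled dummies are marked complete so they never produce responses, and any dummies left in the active list are dropped (the real method also erases them from inflight_req_ids and calls _terminate_request to free resources). Req, COMPLETE, and finish_dummy_requests below are illustrative names only.

from dataclasses import dataclass
from typing import List

COMPLETE, IN_PROGRESS = "GENERATION_COMPLETE", "GENERATION_IN_PROGRESS"


@dataclass
class Req:
    request_id: int
    is_attention_dp_dummy: bool = False
    state: str = IN_PROGRESS


def finish_dummy_requests(scheduled: List[Req], active: List[Req]) -> List[Req]:
    # Mark every scheduled dummy as finished so the sampler/response path skips it.
    for req in scheduled:
        if req.is_attention_dp_dummy:
            req.state = COMPLETE
    # Drop dummies from the active list so they never reach the next iteration;
    # the executor additionally releases their resources at this point.
    return [req for req in active if not req.is_attention_dp_dummy]


active = [Req(0), Req(1, is_attention_dp_dummy=True)]
active = finish_dummy_requests(scheduled=active, active=active)
assert [r.request_id for r in active] == [0]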
def _collect_py_objects_from_requests(
self, requests: list[RequestQueueItem],
attribute_name: str) -> Optional[tuple[str, dict]]:
@@ -1335,6 +1371,8 @@ def _attach_py_objects_to_requests(self, requests: list[RequestQueueItem],
to each request.
"""
for item in requests:
if item.is_shutdown_request():
continue
py_obj = py_request_objects.get(item.id)
if py_obj is not None:
setattr(item.request, attribute_name, py_obj)
Expand Down Expand Up @@ -1379,7 +1417,9 @@ def _partition_context(self, ctx_ids_list):

def _merge_star_attention_requests(self,
new_requests: list[RequestQueueItem]):
result = []
for req_item in new_requests:
if req_item.is_shutdown_request():
return True
for req_item in new_requests:
req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
ctx_len0 = len(exe_req.input_token_ids)
@@ -1429,29 +1469,24 @@ def _merge_star_attention_requests(self,
req.ctx_blocks = ctx_blocks
req.ctx_position_blocks = position_blocks
req.query_id = query_token_ids
self.active_requests.append(req)

result.append(req)

return result
return False

@nvtx_range("_merge_requests")
def _merge_requests(self, new_requests: list[RequestQueueItem]):
cp_config = self.dist.cp_config
if 'cp_type' in cp_config:
cp_type = cp_config['cp_type']
if cp_type == 'star_attention':
return self._merge_star_attention_requests(new_requests)
ret = self._merge_star_attention_requests(new_requests)
elif cp_type == 'ring_attention':
raise NotImplementedError("ring attention not implemented yet")
else:
raise NotImplementedError(f'unsupport cp type {cp_type}')
else:
return [
executor_request_to_llm_request(
req_item.id, req_item.request,
self._should_exclude_last_generation_logits())
for req_item in new_requests
]
ret = self._merge_tp_requests(new_requests)
return ret

@nvtx_range("_schedule")
def _schedule(self):
@@ -1485,10 +1520,10 @@ def _check_disagg_gen_transfer_status(self):
@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
"""
Pad with a dummy request, if required, to ensure every attention_dp rank has at least one active request.
Pad dummy requests to ensure each attention_dp rank has the same number of active requests
"""
if not self.enable_attention_dp:
return
return 0

assert self.expected_num_active_requests >= len(self.active_requests)
if self.kv_cache_transceiver is None:
@@ -1500,15 +1535,19 @@ def _pad_attention_dp_dummy_request(self):
for req in self.active_requests
])

if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
llm_request = self.kv_cache_manager.add_dummy_requests(
request_ids=[0],
num_dummy_request = self.expected_num_active_requests - num_active_request
if num_dummy_request > 0:
llm_request_list = self.kv_cache_manager.add_dummy_requests(
request_ids=list(range(num_dummy_request)),
is_gen=not self.has_context_request,
prepare_resource=not self.has_context_request,
max_num_draft_tokens=self.max_draft_tokens,
)[0]
llm_request.is_attention_dp_dummy = True
self.active_requests.append(llm_request)
max_num_draft_tokens=0
if self.has_context_request else self.max_draft_tokens,
)
for llm_request in llm_request_list:
llm_request.is_attention_dp_dummy = True
self.active_requests += llm_request_list
return num_dummy_request
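To make the padding rule above concrete, a small worked example under assumed numbers: with attention DP and expected_num_active_requests = 4, a rank that currently holds one real request creates three dummies (IDs 0..2), flags each with is_attention_dp_dummy, and returns the count so the caller knows to run _finish_dummy_request after the forward/sample step. All values below are hypothetical; only the subtraction and the returned count mirror the change.

expected_num_active_requests = 4   # per-rank target so collective ops stay in lockstep
num_active_request = 1             # real requests currently on this rank
num_dummy_request = expected_num_active_requests - num_active_request
assert num_dummy_request == 3
dummy_request_ids = list(range(num_dummy_request))   # [0, 1, 2]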

@nvtx_range("_prepare_disagg_gen_init")
def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
@@ -1624,20 +1663,13 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,
return None

def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
# handle potential attention dp dummy request
if self.active_requests and self.active_requests[
-1].is_attention_dp_dummy:
request = self.active_requests[-1]
request.state = LlmRequestState.GENERATION_COMPLETE
self.inflight_req_ids.erase(request.py_request_id)
self._terminate_request(request)
self.active_requests.remove(request)

for request in scheduled_requests.context_requests:
if request.state != LlmRequestState.GENERATION_COMPLETE: # skip failed requests
request.move_to_next_context_chunk()
if request.context_remaining_length == 0:
request.state = LlmRequestState.GENERATION_IN_PROGRESS
if request.is_attention_dp_dummy:
request.state = LlmRequestState.GENERATION_COMPLETE

def _update_request_states_star_attention(
self, scheduled_requests: ScheduledRequests):