Commit f567db4

update logic for IFB
Signed-off-by: Yue Weng <[email protected]>
1 parent: b2e0e4b

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 19 additions & 19 deletions
@@ -1315,7 +1315,7 @@ def previous_seq_slots_device():

     num_tokens = len(input_ids)
     num_draft_tokens = len(draft_tokens)
-    num_requests = len(request_ids)
+    len(request_ids)
     total_num_tokens = len(position_ids)
     assert total_num_tokens <= self.max_num_tokens, (
         "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1358,34 +1358,34 @@ def previous_seq_slots_device():
     previous_pos_indices_host = torch.tensor(previous_pos_indices,
                                              dtype=torch.int,
                                              pin_memory=True)
-    new_tokens_len = new_tokens.shape[0]
     self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
         previous_pos_indices_host, non_blocking=True)

-    # The order of request in Batch: ['requests that do not have previous batch', 'requests that already have previous batch', 'dummy requests']
+    # The order of requests in a batch: [context requests, generation requests]
+    # generation requests: ['requests that do not have previous batch', 'requests that already have previous batch', 'dummy requests']
+    # 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving.
+    # 2) 'requests that already have previous batch': previous iteration's requests.
+    # 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp.
     # Therefore, both of self.previous_pos_id_offsets_cuda and self.previous_kv_lens_offsets_cuda are also 3 segments.
-    # 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving.
-    #    Set these requests' previous_pos_id_offsets and previous_kv_lens_offsets to '0' to skip the value changes in _preprocess_inputs.
-    #    self.previous_pos_id_offsets_cuda[0 : new_tokens_len] *= 0
-    #    self.previous_kv_lens_offsets_cuda[0 : num_requests - previous_batch_len - len(extend_dummy_requests)] *= 0
-    #    Already set to '0' during initialization.
-    # 2) 'requests that already have previous batch': enable overlap scheduler.
-    #    Set their previous_pos_id_offsets and previous_kv_lens_offsets according to new_tokens_lens_device and kv_len_offsets_device.
-    #    self.previous_pos_id_offsets_cuda[new_tokens_len : new_tokens_len + previous_batch_tokens]
-    #    self.previous_kv_lens_offsets_cuda[num_requests - previous_batch_len - len(extend_dummy_requests) : num_requests - len(extend_dummy_requests)]
-    # 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp.
-    #    self.previous_pos_id_offsets_cuda[new_tokens_len + previous_batch_tokens : num_requests * (1 + max_draft_len)]
-    #    self.previous_kv_lens_offsets_cuda[num_requests - len(extend_dummy_requests) : num_requests]
-    #    Already set to '0' during initialization.
+    # For 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving.
+    #    Set these requests' previous_pos_id_offsets and previous_kv_lens_offsets to '0' to skip the value changes in _preprocess_inputs.
+    #    Already set to '0' during initialization.
+    # For 2) 'requests that already have previous batch': enable overlap scheduler.
+    #    Set their previous_pos_id_offsets and previous_kv_lens_offsets according to new_tokens_lens_device and kv_len_offsets_device.
+    # For 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp.
+    #    Already set to '0' during initialization.
     self.previous_pos_id_offsets_cuda[
-        new_tokens_len:new_tokens_len +
+        (len(extend_requests) - len(extend_dummy_requests) -
+         previous_batch_len) * (1 + self.max_draft_len):
+        (len(extend_requests) - len(extend_dummy_requests) -
+         previous_batch_len) * (1 + self.max_draft_len) +
         previous_batch_tokens].copy_(
             new_tokens_lens_device[self.previous_pos_indices_cuda[
                 0:previous_batch_tokens]],
             non_blocking=True)
     self.previous_kv_lens_offsets_cuda[
-        num_requests - previous_batch_len -
-        len(extend_dummy_requests):num_requests -
+        len(extend_requests) - previous_batch_len -
+        len(extend_dummy_requests):len(extend_requests) -
         len(extend_dummy_requests)].copy_(
             kv_len_offsets_device[previous_slots],
             non_blocking=True)
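
The new slice bounds encode the three-segment layout described in the comments: only segment 2 ('requests that already have previous batch') is overwritten, while segments 1 and 3 keep their initialized zeros. Below is a minimal standalone sketch of that index arithmetic, not the actual model_engine.py code: the sizes (num_extend, num_dummy, previous_batch_len, max_draft_len) are made-up stand-ins for len(extend_requests), len(extend_dummy_requests), and self.max_draft_len, and the constant 1 stands in for the values the real code copies from new_tokens_lens_device and kv_len_offsets_device.

    import torch

    # Illustrative sizes only; the real values come from the scheduled batch.
    max_draft_len = 3        # stands in for self.max_draft_len
    num_extend = 8           # stands in for len(extend_requests)
    num_dummy = 2            # stands in for len(extend_dummy_requests)
    previous_batch_len = 4   # requests carried over from the previous iteration
    previous_batch_tokens = previous_batch_len * (1 + max_draft_len)

    # Token-granular buffer: (1 + max_draft_len) entries per request.
    previous_pos_id_offsets = torch.zeros(num_extend * (1 + max_draft_len), dtype=torch.int)
    # Request-granular buffer: one entry per request.
    previous_kv_lens_offsets = torch.zeros(num_extend, dtype=torch.int)

    # Segment 2 ('requests that already have previous batch') is the only part
    # that gets overwritten; segments 1 and 3 stay at zero.
    pos_start = (num_extend - num_dummy - previous_batch_len) * (1 + max_draft_len)
    previous_pos_id_offsets[pos_start:pos_start + previous_batch_tokens] = 1

    kv_start = num_extend - previous_batch_len - num_dummy
    kv_end = num_extend - num_dummy
    previous_kv_lens_offsets[kv_start:kv_end] = 1

    print(previous_pos_id_offsets.tolist())   # 8 zeros, 16 ones, 8 zeros
    print(previous_kv_lens_offsets.tolist())  # [0, 0, 1, 1, 1, 1, 0, 0]

With these stand-in sizes, segment 1 occupies the first (num_extend - num_dummy - previous_batch_len) requests, segment 2 the next previous_batch_len requests, and segment 3 the trailing dummy requests; the commit replaces the old new_tokens_len / num_requests bookkeeping with this arithmetic so the boundaries stay correct when dummy requests are appended for CUDA graph or attention dp.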
