28 changes: 21 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -944,7 +944,7 @@ def wait_on_pp_send_handles(self, microbatch_id):
self.send_handles[microbatch_id].wait()
self.send_handles[microbatch_id] = None

def _prepare_and_schedule_batch(self):
def _prepare_and_schedule_batch(self, overlap_mode: bool = False):
new_requests = self._fetch_and_activate_new_requests()
if self.should_stop_processing:
return None, None
@@ -973,6 +973,16 @@ def _prepare_and_schedule_batch():

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
if overlap_mode:
new_generation_requests = []
for req in scheduled_batch.generation_requests:
# Generation logits are hard to deal with, and it is hard to distinguish between EOS and length-based stopping.
# Currently, the last logits are skipped with overlap, but then we splinter cases based on the stopping condition.
# For now, just treat requests with generation logits as before and let them overschedule.
# Does not account for draft tokens (py_decoding_iter only ever goes up by 1).
if req.is_dummy or req.py_return_generation_logits or req.py_decoding_iter + 1 < req.py_max_new_tokens:
new_generation_requests.append(req)
scheduled_batch.generation_requests = new_generation_requests

if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
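The new overlap_mode filter above keeps a generation request in the scheduled batch only if it is a dummy request, returns generation logits, or still has at least one more decoding step left; requests on their final step are dropped, so the overlap path no longer runs an extra forward pass for them. A minimal standalone sketch of the same predicate, using a hypothetical FakeRequest stand-in (the field names mirror the diff; the values are illustrative):

from dataclasses import dataclass

@dataclass
class FakeRequest:
    # Hypothetical stand-in for a generation request; fields mirror the PR.
    is_dummy: bool = False
    py_return_generation_logits: bool = False
    py_decoding_iter: int = 0
    py_max_new_tokens: int = 8

def keep_in_overlap_batch(req: FakeRequest) -> bool:
    # Keep dummies, requests that return generation logits, and requests
    # that have not yet reached their final decoding iteration.
    return (req.is_dummy
            or req.py_return_generation_logits
            or req.py_decoding_iter + 1 < req.py_max_new_tokens)

assert keep_in_overlap_batch(FakeRequest(py_decoding_iter=3))       # mid-generation: kept
assert not keep_in_overlap_batch(FakeRequest(py_decoding_iter=7))   # final step: filtered out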
@@ -1158,7 +1168,8 @@ def _executor_loop_overlap(self):
if self.enable_iter_perf_stats:
iter_start_time = time.time()

scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
scheduled_batch, iter_stats = self._prepare_and_schedule_batch(
overlap_mode=True)
if scheduled_batch is None:
break

@@ -1199,9 +1210,10 @@ def _executor_loop_overlap(self):
batch_outputs = self._forward_step(scheduled_batch,
previous_tensors_device)

if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)

if scheduled_batch.batch_size > 0:
if self.guided_decoder is not None:
# add_batch must be called again to have updated new tokens.
self.guided_decoder.add_batch(scheduled_batch)
@@ -1217,9 +1229,11 @@ def _executor_loop_overlap(self):
scheduled_batch.context_requests
) if self.kv_cache_transceiver else []

if self.previous_batch is not None:
self._process_previous_batch()
self.previous_batch: Optional[BatchState] = None
if self.previous_batch is not None:
self._process_previous_batch()
self.previous_batch: Optional[BatchState] = None

if scheduled_batch.batch_size > 0:

if self.enable_iter_perf_stats:
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
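The remaining hunks in this file reorder the overlap loop: updating and processing the previous batch move out of the batch-size guard and now run unconditionally, while the per-batch work stays behind two new "if scheduled_batch.batch_size > 0:" checks. A toy model of the resulting ordering, assuming a simplified step() that only logs what it would do (the real method is much larger, and the final "remember previous batch" step is an assumption for illustration):

from typing import Optional

class ToyOverlapLoop:
    # Toy stand-in for one iteration of the reordered _executor_loop_overlap body.
    def __init__(self) -> None:
        self.previous_batch: Optional[str] = None
        self.log: list[str] = []

    def step(self, batch_size: int) -> None:
        self.log.append(f"forward_step(batch_size={batch_size})")

        # Moved out of the batch_size > 0 guard: the previous batch is always updated ...
        if self.previous_batch is not None:
            self.log.append("update_requests(previous_batch.sample_state)")

        if batch_size > 0:
            self.log.append("guided decoding / sampling for the current batch")

        # ... and always processed, even when the current batch is empty.
        if self.previous_batch is not None:
            self.log.append("process_previous_batch()")
        self.previous_batch = None

        if batch_size > 0:
            self.log.append("iteration perf stats")
            self.previous_batch = f"batch[{batch_size}]"  # assumed: remembered for the next iteration

loop = ToyOverlapLoop()
loop.step(2)
loop.step(0)  # empty iteration: the previous batch still gets updated and processed
print("\n".join(loop.log))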
18 changes: 6 additions & 12 deletions tests/integration/defs/llmapi/test_llm_api_connector.py
@@ -87,8 +87,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
assert len(scheduler.update_state_after_alloc.call_args.args[1]) == 1

# With the overlap scheduler, we generate one extra token.
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
assert scheduler.build_connector_meta.call_count == NUM_TOKENS

# We should have a single `SchedulerOutput` per forward pass.
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
@@ -108,8 +107,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
assert len(scheduler_output.cached_requests[0].new_tokens) == 1

# We call `start_load_kv` once at the beginning of each forward pass.
assert worker.start_load_kv.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
assert worker.start_load_kv.call_count == NUM_TOKENS

# Only called once when the request is received.
assert scheduler.get_num_new_matched_tokens.call_count == 1
@@ -118,19 +116,16 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
for call in worker.wait_for_layer_load.call_args_list) + 1

# Called num_layers * num_forward_passes times.
assert worker.wait_for_layer_load.call_count == num_layers * (
NUM_TOKENS + int(use_overlap_scheduler))
assert worker.save_kv_layer.call_count == num_layers * (
NUM_TOKENS + int(use_overlap_scheduler))
assert worker.wait_for_layer_load.call_count == num_layers * (NUM_TOKENS)
assert worker.save_kv_layer.call_count == num_layers * (NUM_TOKENS)

for i, call in enumerate(worker.wait_for_layer_load.call_args_list):
assert call.args[0] == i % num_layers

for i, call in enumerate(worker.save_kv_layer.call_args_list):
assert call.args[0] == i % num_layers

assert worker.wait_for_save.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
assert worker.wait_for_save.call_count == NUM_TOKENS

assert scheduler.request_finished.call_count == 1

@@ -238,8 +233,7 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector,
NUM_INPUT_TOKENS / BLOCK_SIZE)

# Additional token when using the overlap scheduler.
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
assert scheduler.build_connector_meta.call_count == NUM_TOKENS

for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
sched_output = call.args[0]
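With the extra overlap-scheduler forward pass gone from the executor, every connector call count in these tests reduces to a plain multiple of NUM_TOKENS, independent of use_overlap_scheduler. A condensed sketch of the invariant the updated assertions encode (the values are placeholders for the test fixtures, not the real ones):

NUM_TOKENS = 8   # placeholder for the test's NUM_TOKENS
num_layers = 4   # placeholder for the model's layer count

# Expected call counts per run, identical with and without the overlap scheduler.
expected_calls = {
    "build_connector_meta": NUM_TOKENS,
    "start_load_kv": NUM_TOKENS,
    "wait_for_layer_load": num_layers * NUM_TOKENS,
    "save_kv_layer": num_layers * NUM_TOKENS,
    "wait_for_save": NUM_TOKENS,
}

for name, count in expected_calls.items():
    print(f"{name}: expected {count} calls")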
5 changes: 2 additions & 3 deletions tests/unittest/llmapi/test_llm.py
@@ -2204,9 +2204,8 @@ async def task1():
results.append(stats)

assert results
if not use_overlap:
validate_stats(results, pytorch_backend, max_tokens,
enable_iter_req_stats)
validate_stats(results, pytorch_backend, max_tokens,
enable_iter_req_stats)

async def main():
for i in range(2): # test recurrent usage
8 changes: 6 additions & 2 deletions tests/unittest/llmapi/test_llm_pytorch.py
@@ -174,7 +174,10 @@ def test_llm_reward_model():


def test_llm_perf_metrics():
llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
disable_overlap_scheduler = False
llm = LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config,
disable_overlap_scheduler=disable_overlap_scheduler)
sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True)
outputs = llm.generate(prompts, sampling_params)
assert outputs[0].outputs[0].request_perf_metrics is not None
@@ -194,7 +197,8 @@ def test_llm_perf_metrics():
assert kv_cache_metrics.kv_cache_hit_rate == 0

assert perf_metrics.first_iter is not None
assert perf_metrics.iter - perf_metrics.first_iter == sampling_params.max_tokens - 1
assert perf_metrics.iter - perf_metrics.first_iter == sampling_params.max_tokens - (
1 if disable_overlap_scheduler else 2)
assert perf_metrics.last_iter == perf_metrics.iter


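The new assertion captures the iteration arithmetic for both modes: with the overlap scheduler disabled, a request that generates max_tokens tokens spans max_tokens executor iterations (a delta of max_tokens - 1), while with it enabled the final-step filtering added in py_executor.py drops the previously over-scheduled last pass, so the request finishes one iteration sooner (max_tokens - 2). A small sketch of that arithmetic, assuming max_tokens = 10 as in the test:

max_tokens = 10  # matches SamplingParams(max_tokens=10) in the test

def expected_iter_delta(disable_overlap_scheduler: bool) -> int:
    # Without overlap: one executor iteration per generated token.
    # With overlap: one fewer iteration, since the final step is no longer over-scheduled.
    return max_tokens - (1 if disable_overlap_scheduler else 2)

assert expected_iter_delta(True) == 9    # overlap scheduler disabled
assert expected_iter_delta(False) == 8   # overlap scheduler enabled (test default)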