Commit aac743e

[https://nvbugs/5534705][fix] Update inference state for draft model
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent: 2db22fb

File tree

1 file changed: +3 −1 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -719,6 +719,7 @@ def release_batch(result: ScheduledRequests | None):
             logger.info(
                 f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
             )
+            # The draft model has draft_len = 0, so we need to check either draft_len > 0 or is_draft_model.
             self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
             def _update_draft_inference_state(is_first_draft: bool,
@@ -732,7 +733,8 @@ def _update_draft_inference_state(is_first_draft: bool,
                 # Reset the draft tokens for the first draft inference
                 req.py_draft_tokens = []
 
-            _update_draft_inference_state(draft_len > 0, batch)
+            _update_draft_inference_state(self.enable_spec_decode,
+                                          batch)
 
             self.forward(batch,
                          new_tensors_device=None,
```
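
For context on the fix: during generation-only CUDA graph warmup the draft model runs with `draft_len == 0`, so the old call `_update_draft_inference_state(draft_len > 0, batch)` never treated it as a first draft inference and its stale `py_draft_tokens` were not cleared. Passing `self.enable_spec_decode`, which also checks `is_draft_model`, covers the draft model too. Below is a minimal sketch of that flag logic; the `DummyRequest` class and the simplified function bodies are illustrative assumptions, not the actual `model_engine.py` code:

```python
# Minimal sketch of the corrected warmup flag logic.
# DummyRequest and the simplified bodies below are illustrative,
# not the actual TensorRT-LLM implementation.

class DummyRequest:
    def __init__(self):
        # Stale draft tokens left over from a previous inference pass.
        self.py_draft_tokens = ["stale"]

def _update_draft_inference_state(is_first_draft, batch):
    # Mirrors the hunk above: on the first draft inference, reset the
    # draft tokens for every request in the batch.
    if is_first_draft:
        for req in batch:
            req.py_draft_tokens = []

def warmup(batch, draft_len, is_draft_model):
    # The draft model warms up with draft_len == 0, so the old call
    # `_update_draft_inference_state(draft_len > 0, batch)` skipped the
    # reset for it. Also checking `is_draft_model` covers that case.
    enable_spec_decode = draft_len > 0 or is_draft_model
    _update_draft_inference_state(enable_spec_decode, batch)

batch = [DummyRequest()]
warmup(batch, draft_len=0, is_draft_model=True)
assert batch[0].py_draft_tokens == []  # draft state now reset for the draft model
```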
