13 changes: 7 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from dataclasses import dataclass
 
 import torch
@@ -561,7 +562,7 @@ def setup_sampler_step(self, requests):
             self.model_config, self.world_config, self.decoding_config,
             requests, self.store["buffer_manager"], self.logits_datatype,
             self.store["decoder_input_buffers"],
-            self.algs.decoder.decoder_state, self.beam_width,
+            self.algs.decoder.decoder_state, self.beam_width(requests),
             self.store["cuda_stream"])
 
         if len(decoder_requests):
@@ -578,15 +579,15 @@ def setup_sampler_step(self, requests):
             decoder_requests)
 
     @staticmethod
-    def beam_width(scheduled_requests: ScheduledRequests) -> int:
-        for req in scheduled_requests.all_requests:
+    def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
+        for req in scheduled_requests:
             return req.sampling_config.beam_width
-        raise ValueError("No beam width found")
+        return 0
 
     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs) -> SampleStateTRTLLM:
         batch_size = scheduled_requests.batch_size
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
 
         logits = model_outputs["logits"].reshape((batch_size, beam_width, -1))

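Review note: the hunk above changes the contract of `beam_width` in two ways. It now accepts any iterable of requests rather than a `ScheduledRequests` container, and an empty batch yields `0` instead of raising `ValueError`. A minimal sketch of the new contract; the `LlmRequest`/`SamplingConfig` dataclasses below are hypothetical stand-ins reduced to the one field the helper actually reads:

```python
from collections.abc import Iterable
from dataclasses import dataclass

# Hypothetical stand-ins for tensorrt_llm's SamplingConfig / LlmRequest,
# reduced to the single attribute beam_width() touches.
@dataclass
class SamplingConfig:
    beam_width: int

@dataclass
class LlmRequest:
    sampling_config: SamplingConfig

def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
    # Take the beam width of the first request; the sampler assumes all
    # requests in a batch share the same beam width.
    for req in scheduled_requests:
        return req.sampling_config.beam_width
    # Empty batch: return 0 instead of raising, so callers can pass the
    # result through unconditionally.
    return 0

assert beam_width([LlmRequest(SamplingConfig(4))]) == 4
assert beam_width([]) == 0  # previously raised ValueError("No beam width found")
```

Returning `0` also lets `setup_sampler_step` call the helper directly on its plain `requests` argument, which is why the first hunk changes `self.beam_width` (the bound method) to `self.beam_width(requests)` (its result).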
@@ -659,7 +660,7 @@ def update_requests(self, state: SampleStateTRTLLM):
 
         scheduled_requests = state.scheduled_requests
         assert scheduled_requests.batch_size > 0
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
         sampler_event = state.sampler_event
 
         if sampler_event:
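Both call sites in `sample_async` and `update_requests` now unpack the container themselves via `scheduled_requests.all_requests`, and the result feeds the logits reshape in `sample_async`. A hedged sketch of that reshape, under the assumption that the model emits logits flattened over `(request, beam)`; the sizes below are illustrative only:

```python
import torch

batch_size, beam_width, vocab_size = 2, 4, 32000  # illustrative sizes only

# Model output flattened over (request, beam), as sample_async assumes.
flat_logits = torch.randn(batch_size * beam_width, vocab_size)

# Mirrors: logits = model_outputs["logits"].reshape((batch_size, beam_width, -1))
logits = flat_logits.reshape((batch_size, beam_width, -1))
assert logits.shape == (batch_size, beam_width, vocab_size)
```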