6 changes: 6 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -467,6 +467,9 @@ class GenericLlmRequest
initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
}

GenericLlmRequest(GenericLlmRequest&& request) = default;
GenericLlmRequest(GenericLlmRequest const& request) = default;

void setExcludeInputFromOutput(bool exclude)
{
mExcludeInputFromOutput = exclude;
@@ -2318,6 +2321,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
mKvCacheRetentionConfig = request.getKvCacheRetentionConfig();
}

LlmRequest(LlmRequest&& request) = default;
LlmRequest(LlmRequest const& request) = default;

/// @brief Create a Response from the current state of the request
/// @details Note that there is some dependency on the order of operations in this method. Modify with care!
/// @return An optional Response
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -187,6 +187,8 @@ void initBindings(nb::module_& m)
.def_prop_ro("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
.def_prop_ro("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
.def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType)
.def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
.def_prop_ro("is_child", &GenLlmReq::isChild)
.def_prop_ro("multimodal_hashes",
[](GenLlmReq& self)
{
@@ -351,11 +353,13 @@ void initBindings(nb::module_& m)
nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt,
nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt,
nb::arg("context_phase_params") = std::nullopt)
.def(nb::init<tb::LlmRequest const&>())
.def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"),
nb::arg("max_draft_len"), nb::arg("vocab_size_padded"), nb::arg("max_endocer_input_len") = std::nullopt,
nb::arg("enable_kv_cache_reuse") = false)
.def("create_response", &tb::LlmRequest::createResponse, nb::arg("use_fast_logits") = false,
nb::arg("mpi_world_rank") = 0)
.def("create_child_request", &tb::LlmRequest::createChildRequest, nb::arg("child_id"))
.def("create_result", &tb::LlmRequest::createResult, nb::arg("use_fast_logits") = false,
nb::arg("mpi_world_rank") = 0)
.def("create_serialized_result",
6 changes: 5 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -192,6 +192,8 @@ void initBindings(pybind11::module_& m)
.def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
.def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
.def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
.def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
.def_property_readonly("is_child", &GenLlmReq::isChild)
.def_property_readonly("multimodal_hashes",
[](GenLlmReq& self)
{
@@ -254,7 +256,7 @@ void initBindings(pybind11::module_& m)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);

py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init(
.def(py::init<>(
[](tb::LlmRequest::RequestIdType request_id, tb::LlmRequest::SizeType32 max_new_tokens,
std::vector<tb::LlmRequest::TokenIdType> input_tokens, runtime::SamplingConfig sampling_config,
bool is_streaming, std::optional<tb::LlmRequest::SizeType32> end_id,
@@ -357,11 +359,13 @@ void initBindings(pybind11::module_& m)
py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
py::arg("context_phase_params") = std::nullopt)
.def(py::init<tb::LlmRequest const&>())
.def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
py::arg("enable_kv_cache_reuse") = false)
.def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("child_id"))
.def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_serialized_result",
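Together with the defaulted copy and move constructors added in llmRequest.h, these bindings expose request duplication and child-request creation to Python. Below is a minimal sketch of how the new members might be exercised; the module path is an assumption, `parent` is presumed to be an already-constructed `LlmRequest`, and `create_child_request` is assumed to return the newly created child:

```python
# Sketch only: the import path and the way `parent` was constructed are
# assumptions; the members used (create_child_request, is_child,
# parent_request_id, copy constructor) are the ones bound above.
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest

def spawn_child(parent: LlmRequest, child_id: int) -> LlmRequest:
    # create_child_request is assumed to return the newly created child request.
    child = parent.create_child_request(child_id)
    print(child.is_child, child.parent_request_id)  # inspect the new read-only properties
    return child

# The newly exposed copy constructor also allows duplicating a request:
# clone = LlmRequest(parent)
```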
31 changes: 22 additions & 9 deletions examples/llm-api/quickstart_advanced.py
@@ -107,6 +107,8 @@ def add_llm_args(parser):
parser.add_argument("--top_k", type=int, default=None)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument('--load_format', type=str, default='auto')
parser.add_argument('--n', type=int, default=1)
parser.add_argument('--best_of', type=int, default=None)
parser.add_argument('--max_beam_width', type=int, default=1)

# Speculative decoding
@@ -193,6 +195,7 @@ def setup_llm(args, **kwargs):
batch_sizes=args.cuda_graph_batch_sizes,
enable_padding=args.cuda_graph_padding_enabled,
) if args.use_cuda_graph else None

llm = LLM(
model=args.model_dir,
backend='pytorch',
@@ -228,6 +231,15 @@ def setup_llm(args, **kwargs):
**kwargs,
)

use_beam_search = args.max_beam_width > 1
best_of = args.best_of or args.n
if use_beam_search:
if args.n == 1 and args.best_of is None:
args.n = args.max_beam_width
assert best_of <= args.max_beam_width, f"beam width: {best_of} should be less than or equal to max_beam_width: {args.max_beam_width}"

assert best_of >= args.n, f"In sampling mode, best_of value: {best_of} should be greater than or equal to n: {args.n}"

sampling_params = SamplingParams(
max_tokens=args.max_tokens,
temperature=args.temperature,
@@ -236,8 +248,9 @@ def setup_llm(args, **kwargs):
return_context_logits=args.return_context_logits,
return_generation_logits=args.return_generation_logits,
logprobs=args.logprobs,
n=args.max_beam_width,
use_beam_search=args.max_beam_width > 1)
n=args.n,
best_of=best_of,
use_beam_search=use_beam_search)
return llm, sampling_params


@@ -250,23 +263,23 @@ def main():

for i, output in enumerate(outputs):
prompt = output.prompt
for beam_idx, beam in enumerate(output.outputs):
generated_text = beam.text
for sequence_idx, sequence in enumerate(output.outputs):
generated_text = sequence.text
# Skip printing the sequence_idx if only a single sequence was generated
beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
sequence_id_text = f"[{sequence_idx}]" if args.max_beam_width > 1 or args.n > 1 else ""
print(
f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
)
if args.return_context_logits:
print(
f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
)
if args.return_generation_logits:
print(
f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
)
if args.logprobs:
print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")


if __name__ == '__main__':
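For context, the new `--n`/`--best_of` flags map directly onto `SamplingParams`. A minimal programmatic sketch of the same multi-sequence sampling path follows; the model path is a placeholder and the n/best_of values are illustrative, following the resolution logic above:

```python
# Minimal sketch of the quickstart's new multi-sequence sampling path.
# The model path is a placeholder; n/best_of mirror the flags added above.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="/path/to/model")  # placeholder model directory

# Return 2 sequences per prompt, sampled from 4 candidates (best_of >= n).
sampling_params = SamplingParams(max_tokens=64, n=2, best_of=4)

outputs = llm.generate(["The capital of France is"], sampling_params)
for i, output in enumerate(outputs):
    for sequence_idx, sequence in enumerate(output.outputs):
        print(f"[{i}][{sequence_idx}] {sequence.text!r}")
```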
82 changes: 65 additions & 17 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -22,6 +22,7 @@
class RequestQueueItem:
id: int
request: Optional[ExecutorRequest] = None
child_req_ids: Optional[list] = None
is_canceled_request: bool = False
query: Optional[list] = None # only used in `StarAttention`

@@ -83,6 +84,12 @@ def _get_from_request_queue(
pass
return items

@staticmethod
def _get_num_child_requests(request: ExecutorRequest) -> int:
sampling_config = request.sampling_config
return 0 if sampling_config.beam_width > 1 else (
sampling_config.num_return_sequences or 1) - 1

def _get_from_waiting_queue(
self,
waiting_queue: deque[RequestQueueItem],
@@ -111,14 +118,19 @@ def _get_from_waiting_queue(
scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
) if enable_attention_dp else None
while req_count < max_req_count and waiting_queue:
req_item = waiting_queue[0]
num_children = len(
req_item.child_req_ids) if req_item.child_req_ids else 0
if (req_count + 1 + num_children) > max_req_count:
break
req_item = waiting_queue.popleft()
can_process = self._can_process_attention_dp_request(
req_item, scheduling_all_ranks_num_active_requests
) if enable_attention_dp else True

if can_process:
items.append(req_item)
req_count += 1
req_count += 1 + num_children
else:
pending_requests.append(req_item)

@@ -149,17 +161,43 @@ def _can_process_attention_dp_request(

return False

def _get_request_id(self):
# Wrap around within the uint64 range: (next_request_id + 1) mod 2**64
current_id = self.next_request_id
self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
return current_id

def _generate_child_request_ids(
self, request: ExecutorRequest) -> List[int] | None:
""" Generate child request IDs if needed. """
child_req_ids = None
num_children = self._get_num_child_requests(request)
if num_children > 0:
child_req_ids = []
for _ in range(num_children):
child_req_id = self._get_request_id()
if self.enable_iter_perf_stats:
self.start_times[child_req_id] = time.time()
child_req_ids.append(child_req_id)

return child_req_ids

def enqueue_requests(self, requests: List[ExecutorRequest]):
req_ids = []
try:
self.enqueue_lock.acquire()
start_time = time.time()
for request in requests:
self.start_times[self.next_request_id] = start_time
req_id = self._get_request_id()

if self.enable_iter_perf_stats:
self.start_times[req_id] = time.time()

child_req_ids = self._generate_child_request_ids(request)
self.request_queue.put(
RequestQueueItem(self.next_request_id, request))
req_ids.append(self.next_request_id)
self.next_request_id += 1
RequestQueueItem(req_id, request, child_req_ids,
query=None))

req_ids.append(req_id)
finally:
self.enqueue_lock.release()
return req_ids
Expand All @@ -186,15 +224,18 @@ def enqueue_request(self,
try:
self.enqueue_lock.acquire()
assert self.active, "PyExecutor has already been shutdown."
req_id = self.next_request_id
req_id = self._get_request_id()
if self.enable_iter_perf_stats:
self.start_times[req_id] = time.time()

if query is not None:
self.request_queue.put(RequestQueueItem(req_id, request, query))
else:
self.request_queue.put(RequestQueueItem(req_id, request))
self.next_request_id += 1
child_req_ids = self._generate_child_request_ids(request)
self.request_queue.put(
RequestQueueItem(
req_id,
request,
child_req_ids=child_req_ids,
query=query,
))
finally:
self.enqueue_lock.release()

@@ -530,6 +571,10 @@ def _update_new_active_requests_queue_latency(
if req_item.id in self.start_times:
self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
req_item.id)
if req_item.child_req_ids:
for child_id in req_item.child_req_ids:
self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
child_id)

@nvtx_range("_merge_requests")
def _merge_requests(self, new_requests: list[RequestQueueItem]):
@@ -543,12 +588,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
else:
raise NotImplementedError(f'unsupport cp type {cp_type}')
else:
return [
executor_request_to_llm_request(
req_item.id, req_item.request,
req_with_children = []
for req_item in new_requests:
req = executor_request_to_llm_request(
req_item.id, req_item.request, req_item.child_req_ids,
self._should_exclude_last_generation_logits())
for req_item in new_requests
]
req_with_children.append(req)
if req.child_requests:
req_with_children.extend(req.child_requests)
return req_with_children

def _merge_star_attention_requests(self,
new_requests: list[RequestQueueItem]):
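To make the new bookkeeping concrete, here is a small standalone sketch of the id assignment and child-count logic added above. It is an illustration only, not the executor's actual code path; `SimpleNamespace` stands in for the real request's sampling config:

```python
# Standalone sketch of the child-request bookkeeping; SimpleNamespace stands
# in for the real ExecutorRequest/SamplingConfig types.
from types import SimpleNamespace

UINT64_MASK = (1 << 64) - 1

def num_child_requests(sampling_config) -> int:
    # Beam search returns all beams under one request, so it never spawns
    # children; otherwise each extra returned sequence becomes a child.
    if sampling_config.beam_width > 1:
        return 0
    return (sampling_config.num_return_sequences or 1) - 1

def assign_ids(next_id: int, sampling_config):
    # Assign the parent id, then one id per child, wrapping at 2**64.
    parent_id, next_id = next_id, (next_id + 1) & UINT64_MASK
    child_ids = []
    for _ in range(num_child_requests(sampling_config)):
        child_ids.append(next_id)
        next_id = (next_id + 1) & UINT64_MASK
    return parent_id, child_ids, next_id

cfg = SimpleNamespace(beam_width=1, num_return_sequences=3)
print(assign_ids(100, cfg))  # -> (100, [101, 102], 103)
```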