Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ void initBindings(nb::module_& m)
using GenLlmReq = tb::GenericLlmRequest<runtime::ITensor::SharedPtr>;

// Create and register exceptions in module scope
nb::exception<tb::PeftTaskNotCachedException>(m, "PeftTaskNotCachedException");
nb::exception<tr::LoraCacheFullException>(m, "LoraCacheFullException");
static nb::object peft_exc = nb::exception<tb::PeftTaskNotCachedException>(m, "PeftTaskNotCachedException");
static nb::object lora_exc = nb::exception<tr::LoraCacheFullException>(m, "LoraCacheFullException");

// Register with no captures
nb::register_exception_translator(
Expand All @@ -71,11 +71,11 @@ void initBindings(nb::module_& m)
}
catch (const tb::PeftTaskNotCachedException& e)
{
PyErr_SetString(nb::type<tb::PeftTaskNotCachedException>().ptr(), e.what());
PyErr_SetString(peft_exc.ptr(), e.what());
}
catch (const tr::LoraCacheFullException& e)
{
PyErr_SetString(nb::type<tr::LoraCacheFullException>().ptr(), e.what());
PyErr_SetString(lora_exc.ptr(), e.what());
}
});

Expand Down
19 changes: 15 additions & 4 deletions cpp/tensorrt_llm/nanobind/executor/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,21 @@ void initRequestBindings(nb::module_& m)
nb::cast<std::optional<std::vector<tle::AdditionalModelOutput>>>(state[6]));
};
nb::class_<tle::OutputConfig>(m, "OutputConfig")
.def(nb::init<bool, bool, bool, bool, bool, bool, std::optional<std::vector<tle::AdditionalModelOutput>>>(),
nb::arg("return_log_probs").none() = false, nb::arg("return_context_logits") = false,
nb::arg("return_generation_logits") = false, nb::arg("exclude_input_from_output") = false,
nb::arg("return_encoder_output") = false, nb::arg("return_perf_metrics") = false,
.def(
"__init__",
[](tle::OutputConfig& self, std::optional<bool> return_log_probs, std::optional<bool> return_context_logits,
std::optional<bool> return_generation_logits, std::optional<bool> exclude_input_from_output,
std::optional<bool> return_encoder_output, std::optional<bool> return_perf_metrics,
std::optional<std::vector<tle::AdditionalModelOutput>> additional_model_outputs)
{
new (&self) tle::OutputConfig(return_log_probs.value_or(false), return_context_logits.value_or(false),
return_generation_logits.value_or(false), exclude_input_from_output.value_or(false),
return_encoder_output.value_or(false), return_perf_metrics.value_or(false),
additional_model_outputs);
},
nb::arg("return_log_probs") = nb::none(), nb::arg("return_context_logits") = nb::none(),
nb::arg("return_generation_logits") = nb::none(), nb::arg("exclude_input_from_output") = nb::none(),
nb::arg("return_encoder_output") = nb::none(), nb::arg("return_perf_metrics") = nb::none(),
nb::arg("additional_model_outputs") = nb::none())
.def_rw("return_log_probs", &tle::OutputConfig::returnLogProbs)
.def_rw("return_context_logits", &tle::OutputConfig::returnContextLogits)
Expand Down
6 changes: 2 additions & 4 deletions cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ void initConfigBindings(pybind11::module_& m)
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
.def(py::init(
.def("from_string",
[](std::string const& str)
{
if (str == "DEFAULT" || str == "default")
Expand All @@ -436,9 +436,7 @@ void initConfigBindings(pybind11::module_& m)
if (str == "NIXL" || str == "nixl")
return tle::CacheTransceiverConfig::BackendType::NIXL;
throw std::runtime_error("Invalid backend type: " + str);
}));

py::implicitly_convertible<std::string, tle::CacheTransceiverConfig::BackendType>();
});

py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# isort: off
from ..bindings.executor import (
BatchingType as _BatchingType,
CacheTransceiverBackendType as _CacheTransceiverBackendType,
CacheTransceiverConfig as _CacheTransceiverConfig,
CapacitySchedulerPolicy as _CapacitySchedulerPolicy,
ContextChunkingPolicy as _ContextChunkingPolicy,
Expand Down Expand Up @@ -871,7 +872,7 @@ class CacheTransceiverConfig(BaseModel, PybindMirror):

def _to_pybind(self):
return _CacheTransceiverConfig(
backend=self.backend,
backend=_CacheTransceiverBackendType.from_string(self.backend),
max_tokens_in_buffer=self.max_tokens_in_buffer)


Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def to_sampling_params(self) -> SamplingParams:
add_special_tokens=self.add_special_tokens,

# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=self.logprobs,
_return_log_probs=bool(self.logprobs),
)
return sampling_params

Expand Down Expand Up @@ -543,7 +543,7 @@ def to_sampling_params(self) -> SamplingParams:
add_special_tokens=self.add_special_tokens,

# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=self.logprobs,
_return_log_probs=bool(self.logprobs),
)
return sampling_params

Expand Down