Skip to content

Commit 0159d90

Browse files
committed
Adjusting to binding changes
Signed-off-by: Linda-Stadter <[email protected]>
1 parent bc9828b commit 0159d90

File tree

5 files changed

+48
-30
lines changed

5 files changed

+48
-30
lines changed

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,16 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
131131

132132
nb::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
133133
.def(nb::init<>())
134-
.def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("context_requests"),
135-
nb::arg("generation_requests"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"),
136-
nb::arg("model_config"), nb::arg("max_num_sequences"), nb::arg("fused_runtime_buffers") = std::nullopt)
134+
.def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("decoder_input_buffers"),
135+
nb::arg("decoder_state"), nb::arg("model_config"), nb::arg("max_num_sequences"),
136+
nb::arg("fused_runtime_buffers") = std::nullopt)
137137
.def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; });
138138

139139
nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
140140
.def(nb::init<>())
141-
.def("__call__", &LogitsPostProcessor::operator(), nb::arg("context_requests"), nb::arg("generation_requests"),
142-
nb::arg("replicate_logits_post_processor"), nb::arg("decoder_buffers"), nb::arg("world_config"),
143-
nb::arg("runtime"), nb::arg("logits_post_processor_batched") = std::nullopt)
141+
.def("__call__", &LogitsPostProcessor::operator(), nb::arg("decoder_input_buffers"),
142+
nb::arg("replicate_logits_post_processor"), nb::arg("world_config"), nb::arg("stream"),
143+
nb::arg("logits_post_processor_batched") = std::nullopt)
144144
.def("name", [](LogitsPostProcessor const&) { return LogitsPostProcessor::name; });
145145

146146
nb::class_<CreateNewDecoderRequests>(m, CreateNewDecoderRequests::name)

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,16 +387,16 @@ void initBindings(nb::module_& m)
387387
nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"));
388388

389389
nb::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
390-
.def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::SizeType32, tr::BufferManager>(),
391-
nb::arg("max_num_sequences"), nb::arg("max_batch_size"), nb::arg("max_tokens_per_engine_step"),
392-
nb::arg("manager"))
390+
.def(nb::init<runtime::SizeType32, runtime::SizeType32, tr::BufferManager>(), nb::arg("max_batch_size"),
391+
nb::arg("max_tokens_per_engine_step"), nb::arg("manager"))
393392
.def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
394393
.def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
395394
.def_rw("fill_values", &tb::DecoderInputBuffers::fillValues)
396395
.def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
397396
.def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
398397
.def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
399-
.def_rw("logits", &tb::DecoderInputBuffers::logits);
398+
.def_rw("logits", &tb::DecoderInputBuffers::logits)
399+
.def_rw("decoder_requests", &tb::DecoderInputBuffers::decoderRequests);
400400

401401
nb::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
402402
.def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)

cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,27 +84,21 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
8484
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus)
8585
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete);
8686

87-
nb::enum_<tb::CacheTransceiver::CommType>(m, "CommType")
88-
.value("UNKNOWN", tb::CacheTransceiver::CommType::UNKNOWN)
89-
.value("MPI", tb::CacheTransceiver::CommType::MPI)
90-
.value("UCX", tb::CacheTransceiver::CommType::UCX)
91-
.value("NIXL", tb::CacheTransceiver::CommType::NIXL);
92-
9387
nb::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
9488
.value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
9589
.value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA);
9690

9791
nb::class_<tb::CacheTransceiver, tb::BaseCacheTransceiver>(m, "CacheTransceiver")
98-
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, tb::CacheTransceiver::CommType,
99-
std::vector<SizeType32>, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType,
100-
executor::kv_cache::CacheState::AttentionType, std::optional<executor::CacheTransceiverConfig>>(),
101-
nb::arg("cache_manager"), nb::arg("comm_type"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"),
92+
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::vector<SizeType32>, SizeType32, SizeType32,
93+
runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType,
94+
std::optional<executor::CacheTransceiverConfig>>(),
95+
nb::arg("cache_manager"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"),
10296
nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("dtype"), nb::arg("attention_type"),
10397
nb::arg("cache_transceiver_config") = std::nullopt);
10498

10599
nb::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
106100
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), nb::arg("cache_manager"),
107101
nb::arg("max_num_tokens") = std::nullopt)
108102
.def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
109-
nb::arg("max_num_tokens") = std::nullopt);
103+
nb::arg("cache_size_bytes_per_token_per_window"), nb::arg("cache_transceiver_config") = nb::none());
110104
}

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -424,21 +424,44 @@ void initConfigBindings(nb::module_& m)
424424
.def("__getstate__", guidedDecodingConfigGetstate)
425425
.def("__setstate__", guidedDecodingConfigSetstate);
426426

427-
auto cacheTransceiverConfigGetstate
428-
= [](tle::CacheTransceiverConfig const& self) { return nb::make_tuple(self.getMaxNumTokens()); };
427+
auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
428+
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
429429
auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
430430
{
431-
if (state.size() != 1)
431+
if (state.size() != 2)
432432
{
433433
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
434434
}
435-
new (&self) tle::CacheTransceiverConfig(nb::cast<std::optional<size_t>>(state[0]));
435+
new (&self) tle::CacheTransceiverConfig(
436+
nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
436437
};
437438

439+
nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
440+
.value("DEFAULT", tle::CacheTransceiverConfig::BackendType::DEFAULT)
441+
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
442+
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
443+
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
444+
.def("from_string",
445+
[](std::string const& str)
446+
{
447+
if (str == "DEFAULT" || str == "default")
448+
return tle::CacheTransceiverConfig::BackendType::DEFAULT;
449+
if (str == "MPI" || str == "mpi")
450+
return tle::CacheTransceiverConfig::BackendType::MPI;
451+
if (str == "UCX" || str == "ucx")
452+
return tle::CacheTransceiverConfig::BackendType::UCX;
453+
if (str == "NIXL" || str == "nixl")
454+
return tle::CacheTransceiverConfig::BackendType::NIXL;
455+
throw std::runtime_error("Invalid backend type: " + str);
456+
});
457+
438458
nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
439-
.def(nb::init<std::optional<size_t>>(), nb::arg("max_num_tokens") = nb::none())
440-
.def_prop_rw("max_num_tokens", &tle::CacheTransceiverConfig::getMaxNumTokens,
441-
&tle::CacheTransceiverConfig::setMaxNumTokens)
459+
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
460+
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
461+
.def_prop_rw(
462+
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
463+
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
464+
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
442465
.def("__getstate__", cacheTransceiverConfigGetstate)
443466
.def("__setstate__", cacheTransceiverConfigSetstate);
444467

tests/unittest/bindings/test_executor_bindings.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,8 +2478,9 @@ def test_guided_decoding_config_pickle():
24782478

24792479

24802480
def test_cache_transceiver_config_pickle():
2481-
config = trtllm.CacheTransceiverConfig(backend="UCX",
2482-
max_tokens_in_buffer=1024)
2481+
config = trtllm.CacheTransceiverConfig(
2482+
backend=trtllm.CacheTransceiverBackendType.UCX,
2483+
max_tokens_in_buffer=1024)
24832484
config_copy = pickle.loads(pickle.dumps(config))
24842485
assert config_copy.backend == config.backend
24852486
assert config_copy.max_tokens_in_buffer == config.max_tokens_in_buffer

0 commit comments

Comments
 (0)