Skip to content

Commit ac7a987

Browse files
Linda-Stadtertimlee0212
authored andcommitted
feat: nanobind bindings (NVIDIA#6185)
Signed-off-by: Linda-Stadter <[email protected]>
1 parent a3d0a55 commit ac7a987

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+5811
-21
lines changed

cpp/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ set(TRT_LIB TensorRT::NvInfer)
199199
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
200200

201201
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
202-
if(BINDING_TYPE STREQUAL "pybind")
202+
if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
203203
add_subdirectory(${3RDPARTY_DIR}/pybind11
204204
${CMAKE_CURRENT_BINARY_DIR}/pybind11)
205205
endif()
@@ -218,7 +218,7 @@ include_directories(
218218
${3RDPARTY_DIR}/cutlass/tools/util/include
219219
${3RDPARTY_DIR}/NVTX/include
220220
${3RDPARTY_DIR}/json/include)
221-
if(BINDING_TYPE STREQUAL "pybind")
221+
if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
222222
include_directories(${3RDPARTY_DIR}/pybind11/include)
223223
endif()
224224
if(BINDING_TYPE STREQUAL "nanobind")

cpp/tensorrt_llm/nanobind/CMakeLists.txt

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,22 @@ set(TRTLLM_NB_MODULE
33
${TRTLLM_NB_MODULE}
44
PARENT_SCOPE)
55

6-
set(SRCS ../runtime/ipcNvlsMemory.cu bindings.cpp)
6+
set(SRCS
7+
batch_manager/algorithms.cpp
8+
batch_manager/bindings.cpp
9+
batch_manager/cacheTransceiver.cpp
10+
batch_manager/kvCacheManager.cpp
11+
batch_manager/llmRequest.cpp
12+
executor/bindings.cpp
13+
executor/executor.cpp
14+
executor/executorConfig.cpp
15+
executor/request.cpp
16+
runtime/bindings.cpp
17+
testing/modelSpecBinding.cpp
18+
runtime/moeBindings.cpp
19+
userbuffers/bindings.cpp
20+
../runtime/ipcNvlsMemory.cu
21+
bindings.cpp)
722

823
include_directories(${PROJECT_SOURCE_DIR}/include)
924

@@ -14,20 +29,29 @@ set_property(TARGET ${TRTLLM_NB_MODULE} PROPERTY POSITION_INDEPENDENT_CODE ON)
1429
target_link_directories(${TRTLLM_NB_MODULE} PUBLIC
1530
"${TORCH_INSTALL_PREFIX}/lib")
1631

32+
if(ENABLE_NVSHMEM)
33+
target_link_libraries(${TRTLLM_NB_MODULE} PUBLIC nvshmem::nvshmem_host
34+
nvshmem::nvshmem_device)
35+
endif()
36+
1737
target_link_libraries(
1838
${TRTLLM_NB_MODULE}
19-
PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG}
20-
${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
21-
39+
PUBLIC ${SHARED_TARGET}
40+
${UNDEFINED_FLAG}
41+
${NO_AS_NEEDED_FLAG}
42+
${Python3_LIBRARIES}
43+
${TORCH_LIBRARIES}
44+
torch_python
45+
${CUDA_NVML_LIB})
2246
target_compile_definitions(
2347
${TRTLLM_NB_MODULE} PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE}
24-
NB_DETAILED_ERROR_MESSAGES=1)
48+
PYBIND11_DETAILED_ERROR_MESSAGES=1)
2549

2650
if(NOT WIN32)
2751
set_target_properties(
2852
${TRTLLM_NB_MODULE}
2953
PROPERTIES
3054
LINK_FLAGS
31-
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
55+
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
3256
)
3357
endif()
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "algorithms.h"
19+
#include "tensorrt_llm/batch_manager/allocateKvCache.h"
20+
#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
21+
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
22+
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
23+
#include "tensorrt_llm/batch_manager/handleContextLogits.h"
24+
#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
25+
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
26+
#include "tensorrt_llm/batch_manager/llmRequest.h"
27+
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
28+
#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
29+
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
30+
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
31+
#include "tensorrt_llm/batch_manager/pauseRequests.h"
32+
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
33+
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
34+
#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
35+
#include "tensorrt_llm/nanobind/common/customCasters.h"
36+
#include "tensorrt_llm/runtime/decoderState.h"
37+
#include "tensorrt_llm/runtime/torch.h"
38+
#include "tensorrt_llm/runtime/torchView.h"
39+
40+
#include <ATen/core/TensorBody.h>
41+
#include <nanobind/nanobind.h>
42+
#include <nanobind/stl/list.h>
43+
#include <nanobind/stl/shared_ptr.h>
44+
#include <nanobind/stl/tuple.h>
45+
#include <nanobind/stl/vector.h>
46+
#include <torch/extension.h>
47+
48+
#include <optional>
49+
50+
namespace nb = nanobind;
51+
52+
namespace tr = tensorrt_llm::runtime;
53+
using namespace tensorrt_llm::batch_manager;
54+
55+
void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_& m)
56+
{
57+
nb::class_<CapacityScheduler>(m, CapacityScheduler::name)
58+
.def(nb::init<SizeType32, executor::CapacitySchedulerPolicy, bool, bool, LlmRequestState, LlmRequestState>(),
59+
nb::arg("max_num_requests"), nb::arg("capacity_scheduler_policy"), nb::arg("has_kv_cache_manager"),
60+
nb::arg("two_step_lookahead") = false, nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
61+
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
62+
.def("__call__", &CapacityScheduler::operator(), nb::arg("active_requests"),
63+
nb::arg("kv_cache_manager") = nullptr, nb::arg("peft_cache_manager") = nullptr,
64+
nb::arg("cross_kv_cache_manager") = nullptr)
65+
.def("name", [](CapacityScheduler const&) { return CapacityScheduler::name; });
66+
67+
nb::class_<MicroBatchScheduler>(m, MicroBatchScheduler::name)
68+
.def(nb::init<std::optional<batch_scheduler::ContextChunkingConfig>, std::optional<SizeType32>, LlmRequestState,
69+
LlmRequestState>(),
70+
nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
71+
nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
72+
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
73+
.def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
74+
nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime"))
75+
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });
76+
77+
nb::class_<PauseRequests>(m, PauseRequests::name)
78+
.def(nb::init<SizeType32>(), nb::arg("max_input_len"))
79+
.def("__call__", &PauseRequests::operator(), nb::arg("requests_to_pause"), nb::arg("inflight_req_ids"),
80+
nb::arg("req_ids_to_pause"), nb::arg("pause_flagged"), nb::arg("seq_slot_manager"),
81+
nb::arg("kv_cache_manager") = std::nullopt, nb::arg("cross_kv_cache_manager") = std::nullopt,
82+
nb::arg("peft_cache_manager") = std::nullopt)
83+
.def("name", [](PauseRequests const&) { return PauseRequests::name; });
84+
85+
nb::class_<AssignReqSeqSlots>(m, AssignReqSeqSlots::name)
86+
.def(nb::init<>())
87+
.def("__call__", &AssignReqSeqSlots::operator(), nb::arg("seq_slot_manager"), nb::arg("context_requests"),
88+
nb::arg("generation_requests"))
89+
.def("name", [](AssignReqSeqSlots const&) { return AssignReqSeqSlots::name; });
90+
91+
nb::class_<AllocateKvCache>(m, AllocateKvCache::name)
92+
.def(nb::init<>())
93+
.def("__call__", &AllocateKvCache::operator(), nb::arg("kv_cache_manager"), nb::arg("context_requests"),
94+
nb::arg("generation_requests"), nb::arg("model_config"), nb::arg("cross_kv_cache_manager") = std::nullopt)
95+
.def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });
96+
97+
nb::class_<HandleContextLogits>(m, HandleContextLogits::name)
98+
.def(nb::init<>())
99+
.def(
100+
"__call__",
101+
[](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
102+
at::Tensor const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
103+
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
104+
OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
105+
{
106+
return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig,
107+
manager, medusaBuffers);
108+
},
109+
nb::arg("decoder_input_buffers"), nb::arg("context_requests"), nb::arg("logits"),
110+
nb::arg("num_context_logits"), nb::arg("model_config"), nb::arg("buffer_manager"),
111+
nb::arg("medusa_buffers") = std::nullopt)
112+
.def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; });
113+
114+
nb::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
115+
.def(nb::init<>())
116+
.def(
117+
"__call__",
118+
[](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers,
119+
RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex,
120+
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
121+
OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt,
122+
OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
123+
{
124+
self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager,
125+
genRuntimeBuffers, medusaBuffers);
126+
},
127+
nb::arg("decoder_input_buffers"), nb::arg("generation_requests"), nb::arg("logits"),
128+
nb::arg("logits_index"), nb::arg("model_config"), nb::arg("buffer_manager"),
129+
nb::arg("gen_runtime_buffers") = std::nullopt, nb::arg("medusa_buffers") = std::nullopt)
130+
.def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; });
131+
132+
nb::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
133+
.def(nb::init<>())
134+
.def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("decoder_input_buffers"),
135+
nb::arg("decoder_state"), nb::arg("model_config"), nb::arg("max_num_sequences"),
136+
nb::arg("fused_runtime_buffers") = std::nullopt)
137+
.def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; });
138+
139+
nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
140+
.def(nb::init<>())
141+
.def("__call__", &LogitsPostProcessor::operator(), nb::arg("decoder_input_buffers"),
142+
nb::arg("replicate_logits_post_processor"), nb::arg("world_config"), nb::arg("stream"),
143+
nb::arg("logits_post_processor_batched") = std::nullopt)
144+
.def("name", [](LogitsPostProcessor const&) { return LogitsPostProcessor::name; });
145+
146+
nb::class_<CreateNewDecoderRequests>(m, CreateNewDecoderRequests::name)
147+
.def(nb::init<bool, bool, bool>(), nb::arg("speculative_decoding_fast_logits"),
148+
nb::arg("is_leader_in_orch_mode"), nb::arg("is_normalize_log_probs"))
149+
.def(
150+
"__call__",
151+
[](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
152+
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
153+
tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
154+
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
155+
tensorrt_llm::runtime::CudaStream const& runtimeStream,
156+
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
157+
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
158+
{
159+
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
160+
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
161+
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
162+
163+
return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
164+
std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
165+
},
166+
nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
167+
nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
168+
nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
169+
nb::arg("max_sequence_length"), nb::arg("beam_width"), nb::arg("medusa_buffers") = std::nullopt)
170+
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
171+
172+
nb::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
173+
.def(nb::init<>())
174+
.def("__call__", &UpdateDecoderBuffers::operator(), nb::arg("model_config"), nb::arg("decoder_output_buffers"),
175+
nb::arg("copy_buffer_manager"), nb::arg("decoder_state"), nb::arg("return_log_probs"),
176+
nb::arg("decoder_finish_event"))
177+
.def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; });
178+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#pragma once
19+
20+
#include <nanobind/nanobind.h>
21+
22+
namespace nb = nanobind;
23+
24+
namespace tensorrt_llm::nanobind::batch_manager::algorithms
25+
{
26+
27+
void initBindings(nb::module_& m);
28+
29+
}

0 commit comments

Comments
 (0)