2020#include " tensorrt_llm/batch_manager/assignReqSeqSlots.h"
2121#include " tensorrt_llm/batch_manager/capacityScheduler.h"
2222#include " tensorrt_llm/batch_manager/createNewDecoderRequests.h"
23- #include " tensorrt_llm/batch_manager/handleContextLogits.h"
24- #include " tensorrt_llm/batch_manager/handleGenerationLogits.h"
2523#include " tensorrt_llm/batch_manager/kvCacheManager.h"
2624#include " tensorrt_llm/batch_manager/llmRequest.h"
2725#include " tensorrt_llm/batch_manager/logitsPostProcessor.h"
28- #include " tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
2926#include " tensorrt_llm/batch_manager/medusaBuffers.h"
3027#include " tensorrt_llm/batch_manager/microBatchScheduler.h"
3128#include " tensorrt_llm/batch_manager/pauseRequests.h"
3229#include " tensorrt_llm/batch_manager/peftCacheManager.h"
33- #include " tensorrt_llm/batch_manager/runtimeBuffers.h"
34- #include " tensorrt_llm/batch_manager/updateDecoderBuffers.h"
3530#include " tensorrt_llm/nanobind/common/customCasters.h"
3631#include " tensorrt_llm/runtime/decoderState.h"
3732#include " tensorrt_llm/runtime/torch.h"
@@ -94,48 +89,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
9489 nb::arg (" generation_requests" ), nb::arg (" model_config" ), nb::arg (" cross_kv_cache_manager" ) = std::nullopt )
9590 .def (" name" , [](AllocateKvCache const &) { return AllocateKvCache::name; });
9691
97- nb::class_<HandleContextLogits>(m, HandleContextLogits::name)
98- .def (nb::init<>())
99- .def (
100- " __call__" ,
101- [](HandleContextLogits const & self, DecoderInputBuffers& inputBuffers, RequestVector const & contextRequests,
102- at::Tensor const & logits, std::vector<tr::SizeType32> const & numContextLogitsVec,
103- tr::ModelConfig const & modelConfig, tr::BufferManager const & manager,
104- OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt )
105- {
106- return self (inputBuffers, contextRequests, tr::TorchView::of (logits), numContextLogitsVec, modelConfig,
107- manager, medusaBuffers);
108- },
109- nb::arg (" decoder_input_buffers" ), nb::arg (" context_requests" ), nb::arg (" logits" ),
110- nb::arg (" num_context_logits" ), nb::arg (" model_config" ), nb::arg (" buffer_manager" ),
111- nb::arg (" medusa_buffers" ) = std::nullopt )
112- .def (" name" , [](HandleContextLogits const &) { return HandleContextLogits::name; });
113-
114- nb::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
115- .def (nb::init<>())
116- .def (
117- " __call__" ,
118- [](HandleGenerationLogits const & self, DecoderInputBuffers& inputBuffers,
119- RequestVector const & generationRequests, at::Tensor const & logits, tr::SizeType32 logitsIndex,
120- tr::ModelConfig const & modelConfig, tr::BufferManager const & manager,
121- OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt ,
122- OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt )
123- {
124- self (inputBuffers, generationRequests, tr::TorchView::of (logits), logitsIndex, modelConfig, manager,
125- genRuntimeBuffers, medusaBuffers);
126- },
127- nb::arg (" decoder_input_buffers" ), nb::arg (" generation_requests" ), nb::arg (" logits" ),
128- nb::arg (" logits_index" ), nb::arg (" model_config" ), nb::arg (" buffer_manager" ),
129- nb::arg (" gen_runtime_buffers" ) = std::nullopt , nb::arg (" medusa_buffers" ) = std::nullopt )
130- .def (" name" , [](HandleGenerationLogits const &) { return HandleGenerationLogits::name; });
131-
132- nb::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
133- .def (nb::init<>())
134- .def (" __call__" , &MakeDecodingBatchInputOutput::operator (), nb::arg (" decoder_input_buffers" ),
135- nb::arg (" decoder_state" ), nb::arg (" model_config" ), nb::arg (" max_num_sequences" ),
136- nb::arg (" fused_runtime_buffers" ) = std::nullopt )
137- .def (" name" , [](MakeDecodingBatchInputOutput const &) { return MakeDecodingBatchInputOutput::name; });
138-
13992 nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
14093 .def (nb::init<>())
14194 .def (" __call__" , &LogitsPostProcessor::operator (), nb::arg (" decoder_input_buffers" ),
@@ -154,8 +107,9 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
154107 DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
155108 tensorrt_llm::runtime::CudaStream const & runtimeStream,
156109 tensorrt_llm::runtime::CudaStream const & decoderStream, SizeType32 maxSequenceLength,
157- SizeType32 beamWidth, OptionalRef<MedusaBuffers const > medusaBuffers = std:: nullopt )
110+ SizeType32 beamWidth)
158111 {
112+ OptionalRef<MedusaBuffers const > medusaBuffers = std::nullopt ;
159113 auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self (modelConfig,
160114 worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
161115 runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -166,13 +120,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
166120 nb::arg (" model_config" ), nb::arg (" world_config" ), nb::arg (" decoding_config" ), nb::arg (" context_requests" ),
167121 nb::arg (" buffer_manager" ), nb::arg (" logits_type" ), nb::arg (" decoder_input_buffers" ),
168122 nb::arg (" decoder_state" ), nb::arg (" runtime_stream" ), nb::arg (" decoder_stream" ),
169- nb::arg (" max_sequence_length" ), nb::arg (" beam_width" ), nb::arg ( " medusa_buffers " ) = std:: nullopt )
123+ nb::arg (" max_sequence_length" ), nb::arg (" beam_width" ))
170124 .def (" name" , [](CreateNewDecoderRequests const &) { return CreateNewDecoderRequests::name; });
171-
172- nb::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
173- .def (nb::init<>())
174- .def (" __call__" , &UpdateDecoderBuffers::operator (), nb::arg (" model_config" ), nb::arg (" decoder_output_buffers" ),
175- nb::arg (" copy_buffer_manager" ), nb::arg (" decoder_state" ), nb::arg (" return_log_probs" ),
176- nb::arg (" decoder_finish_event" ))
177- .def (" name" , [](UpdateDecoderBuffers const &) { return UpdateDecoderBuffers::name; });
178125}
0 commit comments