diff --git a/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h b/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h index 2af03c0af71..df507cf1001 100644 --- a/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h +++ b/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h @@ -38,6 +38,7 @@ class DecoderInputBuffers public: using SizeType32 = runtime::SizeType32; using TensorPtr = runtime::ITensor::SharedPtr; + using TensorConstPtr = runtime::ITensor::SharedConstPtr; explicit DecoderInputBuffers( SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, runtime::BufferManager const& manager); @@ -60,13 +61,22 @@ class DecoderInputBuffers //! Requests for considered in decoder forward RequestVector decoderRequests; + //! Logits of decoder requests + std::vector<TensorPtr> decoderLogits; + + //! Maximum number of decoding steps of decoder requests. + //! This is only more than 1 for external draft tokens speculative decoding. + SizeType32 maxDecoderSteps{1}; + //! Batch slots for all decoder steps, [maxDecoderSteps][maxBatchSize] std::vector<TensorPtr> forwardBatchSlots; - //! Logits of decoder requests - std::vector<TensorPtr> logits; + //! Logits for requests in forwardBatchSlots (in the same order). + //! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu + std::vector<std::vector<TensorConstPtr>> batchLogits; - //! Logits for speculative decoding (Medusa) + //! Logits for speculative decoding (Medusa). + //! The vector is sparse, only slots in forwardBatchSlots are used. //! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded] std::vector<std::vector<TensorPtr>> predictedDraftLogits; }; diff --git a/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h b/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h index cea23a4e7ec..245f4b4b528 100644 --- a/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h +++ b/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h @@ -40,19 +40,17 @@ class MakeDecodingBatchInputOutput : Algorithm constexpr static auto name{"MakeDecodingBatchInputOutput"}; using SizeType32 = tensorrt_llm::runtime::SizeType32; - using TensorPtr = runtime::decoder_batch::Input::TensorPtr; + using TensorPtr = runtime::ITensor::SharedPtr; template <typename T> using OptionalRef = tensorrt_llm::common::OptionalRef<T>; MakeDecodingBatchInputOutput() = default; - std::unique_ptr<runtime::decoder_batch::Input> operator()(DecoderInputBuffers& inputBuffers, - runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, - SizeType32 maxNumSequences, OptionalRef<RuntimeBuffers const> fusedRuntimeBuffers) const; + void operator()(DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, + runtime::ModelConfig const& modelConfig, OptionalRef<RuntimeBuffers const> fusedRuntimeBuffers) const; - [[nodiscard]] static std::unique_ptr<runtime::decoder_batch::Input> createDecoderBatchInputs( - std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState, - std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots); + static void createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, std::vector<SizeType32> const& activeSlots, + runtime::decoder::DecoderState const& decoderState); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index d0a9e726d13..9fcd3262c8c 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -52,8 +52,9 @@ class GptDecoderBatched : public IGptDecoderBatched void
disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override; - CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override; - void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override; + CudaEvent forwardAsync( + decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) override; + void forward(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) override; //! @brief Gather final beam search results for request `batchSlot`. //! Result will only be available after event returned. @@ -77,7 +78,7 @@ private: //! @brief Calls decoders for tokens per engine step - void forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input); + void forwardDispatch(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input); private: CudaStreamPtr mRuntimeStream; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 606ba3c98a4..ab55b754f9b 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -27,8 +27,9 @@ namespace tensorrt_llm::batch_manager { +class DecoderInputBuffers; class LlmRequest; -} +} // namespace tensorrt_llm::batch_manager namespace tensorrt_llm::runtime { @@ -39,43 +40,6 @@ namespace decoder class DecoderState; } -namespace decoder_batch -{ - -class Input -{ -public: - using TensorConstPtr = ITensor::SharedConstPtr; - using TensorPtr = ITensor::SharedPtr; - - explicit Input(std::vector<std::vector<TensorConstPtr>> const& logits, SizeType32 maxDecoderSteps) - : logits{logits} - , maxDecoderSteps{maxDecoderSteps} - { - TLLM_CHECK_WITH_INFO( - logits.size() == static_cast<size_t>(maxDecoderSteps), "logits vector size does not match maxDecoderSteps"); - } - - explicit Input(std::vector<TensorConstPtr> const& logits) - : Input{{logits}, 1} - { - } - - //! Mandatory parameters - //! Logits - // FIXME: remove first dimension of tensors - //! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu - std::vector<std::vector<TensorConstPtr>> logits; - - //! Maximum number of decoding tokens of active slots - SizeType32 maxDecoderSteps; - - //! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize] - std::vector<TensorPtr> batchSlots; -}; - -} // namespace decoder_batch - //! GPT decoder class with support for in-flight batching class IGptDecoderBatched { @@ -94,10 +58,13 @@ class IGptDecoderBatched virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0; //! @brief Run one step for all requests without blocking the host process and return the token for synchronization. - virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0; + virtual CudaEvent forwardAsync( + decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) + = 0; //! @brief Run one step for all requests and wait for completion on the host. - virtual void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0; + virtual void forward(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) + = 0; //! @brief Gather final beam search results for request `batchIdx`. //!
Result will only be available after event returned diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp index 3a68d03eb69..cb2264ec800 100644 --- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp +++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp @@ -182,7 +182,7 @@ void GuidedDecoder::execute(DecoderInputBuffers const& decoderInputBuffers, Buff { auto const seqSlot = llmReq->mSeqSlot.value(); - auto const& logits = decoderInputBuffers.logits.at(requestIdx); + auto const& logits = decoderInputBuffers.decoderLogits.at(requestIdx); auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot}); // Use void* to unify the code for different mLogitsDtype diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp index df3840c14b4..6f4a541ffcb 100644 --- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp @@ -79,7 +79,7 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re auto& decoderRequests = inputBuffers.decoderRequests; decoderRequests.clear(); decoderRequests.reserve(contextRequests.size()); - auto& allDecoderLogits = inputBuffers.logits; + auto& allDecoderLogits = inputBuffers.decoderLogits; allDecoderLogits.clear(); allDecoderLogits.reserve(contextRequests.size()); diff --git a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp index 5018ae36290..e2a7486b050 100644 --- a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp @@ -85,7 +85,7 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque auto& decoderRequests = inputBuffers.decoderRequests; decoderRequests.reserve(decoderRequests.size() + generationRequests.size()); - auto& allDecoderLogits = inputBuffers.logits; + auto& allDecoderLogits = inputBuffers.decoderLogits; allDecoderLogits.reserve(allDecoderLogits.size() + generationRequests.size()); for (auto const& llmReq : generationRequests) diff --git a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp index dbb90da326a..95b324f0f2e 100644 --- a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp +++ b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp @@ -49,7 +49,7 @@ bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool rep for (size_t batchIdx = 0; batchIdx < inputBuffers.decoderRequests.size(); ++batchIdx) { auto const& llmReq = inputBuffers.decoderRequests.at(batchIdx); - auto& logits = inputBuffers.logits.at(batchIdx); + auto& logits = inputBuffers.decoderLogits.at(batchIdx); // Invoke non-batched processor or collect arguments for batched processor if (llmReq->mLogitsPostProcessor) diff --git a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp index c9b2bb0b937..3e494a6383e 100644 --- a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp +++ b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp @@ -31,9 +31,8 @@ namespace tensorrt_llm::batch_manager using SizeType32 = MakeDecodingBatchInputOutput::SizeType32; using TensorPtr = MakeDecodingBatchInputOutput::TensorPtr; -std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDecoderBatchInputs( - std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState, - std::vector<TensorPtr> const& decoderLogits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots) +void MakeDecodingBatchInputOutput::createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, + std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -42,9 +41,12 @@ auto const& maxDecodingDecoderTokens = decoderState.getMaxDecodingDecoderTokens(); auto const maxDecoderSteps = common::ceilDiv(maxDecodingEngineTokens, maxDecodingDecoderTokens); + auto& batchSlots = inputBuffers.forwardBatchSlots; + auto& decoderLogits = inputBuffers.decoderLogits; + for (SizeType32 step = 0; step < maxDecoderSteps; ++step) { - batchSlots.at(step)->resize(maxNumSequences); + batchSlots.at(step)->resize(activeSlots.size()); } auto constexpr singleRequest = 1; @@ -64,7 +66,7 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDe auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlots.at(step)); batchSlotsRange[batchSizes[step]] = slot; batchSizes[step]++; - TensorPtr logitsSlice = tr::ITensor::slice(logits, step, singleRequest); + auto logitsSlice = tr::ITensor::slice(logits, step, singleRequest); batchLogits[step].emplace_back(std::move(logitsSlice)); } } @@ -75,10 +77,10 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDe } batchLogits.resize(maxActiveDecoderSteps); - auto decodingInput = std::make_unique<tr::decoder_batch::Input>(batchLogits, maxActiveDecoderSteps); - decodingInput->batchSlots = batchSlots; + inputBuffers.maxDecoderSteps = maxActiveDecoderSteps; + inputBuffers.batchLogits = batchLogits; + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return decodingInput; } namespace { @@ -155,16 +157,15 @@ void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntim } // namespace -std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator()(DecoderInputBuffers& inputBuffers, - runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences, +void MakeDecodingBatchInputOutput::operator()(DecoderInputBuffers& inputBuffers, + runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, OptionalRef<RuntimeBuffers const> fusedRuntimeBuffers) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto [activeSlots, generationSteps] = getActiveSlots(inputBuffers.decoderRequests); - auto decodingInput = createDecoderBatchInputs( - activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots); + createDecoderBatchInputs(inputBuffers, activeSlots, decoderState); auto const maxBeamWidth = decoderState.getMaxBeamWidth(); if (maxBeamWidth > 1) @@ -192,7 +193,6 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return decodingInput; } } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index a18b5196aad..f128dbb2ee4 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -1557,8 +1557,6 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const& mOperatingBeamWidth, getMaxSequenceLen(), mRuntime->getBufferManager())); } - mDecodingInputs.resize(mNumMicroBatches); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -2071,11 +2069,9 @@
runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques auto const fusedBufferId = getFusedBufferId(); auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId); - auto& decodingInput = mDecodingInputs.at(mMicroBatchId); - decodingInput = (*mMakeDecodingBatchInputOutput)(mDecoderInputBuffers.at(fusedBufferId), *mDecoderState, - mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers); + (*mMakeDecodingBatchInputOutput)(decoderInputBuffers, *mDecoderState, mModelConfig, *fusedRuntimeBuffers); - auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, *decodingInput); + auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, decoderInputBuffers); auto const returnLogProbs = batchReturnLogProbs(scheduledRequests); auto updateDecoderBuffersEvent = (*mUpdateDecoderBuffers)(mModelConfig, mDecoderOutputBuffers.at(fusedBufferId), diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h index 1478172ddf9..71a5bc9d5f5 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h @@ -568,8 +568,6 @@ class TrtGptModelInflightBatching : public TrtGptModel std::vector> mSlotDecoderBuffers; // PEFT table for each micro batch std::vector mPeftTables; - // Decoder input for each micro batch. - std::vector> mDecodingInputs; /******************** Book keeping ********************/ // List of requests in each micro batch diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt index dc42643f4ed..30decdb4959 100755 --- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt +++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt @@ -6,6 +6,7 @@ set(TRTLLM_NB_MODULE set(SRCS batch_manager/algorithms.cpp batch_manager/bindings.cpp + batch_manager/buffers.cpp batch_manager/cacheTransceiver.cpp batch_manager/kvCacheConnector.cpp batch_manager/kvCacheManager.cpp diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index d6149755e3e..526d0e6ffd9 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -397,39 +397,6 @@ void initBindings(nb::module_& m) .def(nb::init(), nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")); - nb::class_(m, "DecoderInputBuffers") - .def(nb::init(), nb::arg("max_batch_size"), - nb::arg("max_tokens_per_engine_step"), nb::arg("manager")) - .def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots) - .def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice) - .def_rw("fill_values", &tb::DecoderInputBuffers::fillValues) - .def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice) - .def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds) - .def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots) - .def_rw("logits", &tb::DecoderInputBuffers::logits) - .def_rw("decoder_requests", &tb::DecoderInputBuffers::decoderRequests); - - nb::class_(m, "DecoderOutputBuffers") - .def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost) - .def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost) - .def_prop_ro("new_output_tokens_host", - [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); }) - .def_rw("cum_log_probs_host", 
&tb::DecoderOutputBuffers::cumLogProbsHost) - .def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost) - .def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost); - - nb::class_(m, "SlotDecoderBuffers") - .def(nb::init(), - nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager")) - .def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds) - .def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost) - .def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost) - .def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs) - .def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost) - .def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs) - .def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost) - .def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost); - m.def( "add_new_tokens_to_requests", [](std::vector>& requests, @@ -448,10 +415,10 @@ void initBindings(nb::module_& m) m.def( "make_decoding_batch_input", - [](std::vector>& contextRequests, - std::vector>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth, - std::vector const& numContextLogitsPrefixSum, tb::DecoderInputBuffers const& decoderInputBuffers, - runtime::decoder::DecoderState& decoderState, tr::BufferManager const& manager) + [](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState, + std::vector> const& contextRequests, + std::vector> const& genRequests, tr::ITensor::SharedPtr const& logits, + int beamWidth, std::vector const& numContextLogitsPrefixSum, tr::BufferManager const& manager) { std::vector activeSlots; std::vector generationSteps; @@ -509,8 +476,7 @@ void initBindings(nb::module_& m) batchSlotsRange[i] = activeSlots[i]; } - auto decodingInput = std::make_unique(logitsVec, 1); - decodingInput->batchSlots = batchSlots; + decoderInputBuffers.batchLogits = logitsVec; auto const maxBeamWidth = decoderState.getMaxBeamWidth(); if (maxBeamWidth > 1) @@ -518,12 +484,10 @@ void initBindings(nb::module_& m) // For Variable-Beam-Width-Search decoderState.getJointDecodingInput().generationSteps = generationSteps; } - - return decodingInput; }, - nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"), - nb::arg("num_context_logits_prefix_sum"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), - nb::arg("buffer_manager"), "Make decoding batch input."); + nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), nb::arg("context_requests"), + nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"), + nb::arg("num_context_logits_prefix_sum"), nb::arg("buffer_manager"), "Make decoding batch input."); } } // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp new file mode 100644 index 00000000000..9b8e441745c --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp @@ -0,0 +1,74 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "buffers.h" + +#include "tensorrt_llm/batch_manager/decoderBuffers.h" +#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/torch.h" + +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tb = tensorrt_llm::batch_manager; +namespace tr = tensorrt_llm::runtime; + +using tr::SizeType32; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +void Buffers::initBindings(nb::module_& m) +{ + nb::class_(m, "DecoderInputBuffers") + .def(nb::init(), nb::arg("max_batch_size"), + nb::arg("max_tokens_per_engine_step"), nb::arg("manager")) + .def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots) + .def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice) + .def_rw("fill_values", &tb::DecoderInputBuffers::fillValues) + .def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice) + .def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds) + .def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots) + .def_rw("decoder_logits", &tb::DecoderInputBuffers::decoderLogits) + .def_rw("decoder_requests", &tb::DecoderInputBuffers::decoderRequests); + + nb::class_(m, "DecoderOutputBuffers") + .def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost) + .def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost) + .def_prop_ro("new_output_tokens_host", + [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); }) + .def_rw("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost) + .def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost) + .def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost); + + nb::class_(m, "SlotDecoderBuffers") + .def(nb::init(), + nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager")) + .def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds) + .def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost) + .def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost) + .def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs) + .def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost) + .def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs) + .def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost) + .def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost); +} +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h new file mode 100644 index 00000000000..d33570f7d36 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ +class Buffers +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index 8d0585bfc71..d9c2687c7f2 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -34,6 +34,7 @@ #include "tensorrt_llm/common/quantization.h" #include "tensorrt_llm/nanobind/batch_manager/algorithms.h" #include "tensorrt_llm/nanobind/batch_manager/bindings.h" +#include "tensorrt_llm/nanobind/batch_manager/buffers.h" #include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h" @@ -490,6 +491,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def_prop_ro("uvm", &tr::MemoryCounters::getUVM); tensorrt_llm::nanobind::process_group::initBindings(mInternalProcessGroup); + tpb::Buffers::initBindings(mInternalBatchManager); tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime); tensorrt_llm::nanobind::testing::initBindings(mInternalTesting); tpb::initBindings(mInternalBatchManager); diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index 388819b957a..7a698b4eb68 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -124,11 +124,6 @@ void initBindings(nb::module_& m) .def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, nb::arg("tag"), nb::call_guard()); - nb::class_(m, "BufferManager") - .def(nb::init(), nb::arg("stream"), nb::arg("trim_pool") = false, - nb::call_guard()) - .def_prop_ro("stream", &tr::BufferManager::getStream); - nb::class_(m, "TllmRuntime") .def( "__init__", @@ -170,15 +165,6 @@ void initBindings(nb::module_& m) .def_prop_ro("logits_dtype_from_engine", [](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); }); - nb::class_(m, "DecoderBatchInput") - .def(nb::init>, tr::SizeType32>(), nb::arg("logits"), - nb::arg("max_decoding_engine_tokens"), nb::call_guard()) - .def(nb::init>(), nb::arg("logits"), - nb::call_guard()) - .def_rw("logits", &tr::decoder_batch::Input::logits) - .def_rw("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps) - .def_rw("batch_slots", &tr::decoder_batch::Input::batchSlots); - nb::class_(m, "LookaheadDecodingBuffers") .def(nb::init(), nb::arg("max_num_sequences"), nb::arg("max_tokens_per_step"), nb::arg("buffer_manager"), nb::call_guard()) @@ -397,6 +383,11 @@ void initBindings(nb::module_& m) void initBindingsEarly(nb::module_& m) { + nb::class_(m, "BufferManager") + .def(nb::init(), nb::arg("stream"), nb::arg("trim_pool") = false, + nb::call_guard()) + .def_prop_ro("stream", &tr::BufferManager::getStream); + nb::class_(m, "SpeculativeDecodingMode") .def(nb::init(), nb::arg("state")) .def_static("NoneType", &tr::SpeculativeDecodingMode::None) diff --git 
a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt index a8bb99587cf..a85f6793a6a 100755 --- a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -6,6 +6,7 @@ set(TRTLLM_PYBIND_MODULE set(SRCS batch_manager/algorithms.cpp batch_manager/bindings.cpp + batch_manager/buffers.cpp batch_manager/cacheTransceiver.cpp batch_manager/kvCacheConnector.cpp batch_manager/kvCacheManager.cpp diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp index 9361c1bd565..2f573fb7549 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp @@ -20,6 +20,7 @@ #include "tensorrt_llm/batch_manager/assignReqSeqSlots.h" #include "tensorrt_llm/batch_manager/capacityScheduler.h" #include "tensorrt_llm/batch_manager/createNewDecoderRequests.h" +#include "tensorrt_llm/batch_manager/decoderBuffers.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/logitsPostProcessor.h" diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index ecaffdda6aa..20949cabe93 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -403,39 +403,6 @@ void initBindings(pybind11::module_& m) .def(py::init(), py::arg("max_num_sequences"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager")); - py::class_(m, "DecoderInputBuffers") - .def(py::init(), py::arg("max_batch_size"), - py::arg("max_tokens_per_engine_step"), py::arg("manager")) - .def_readwrite("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots) - .def_readwrite("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice) - .def_readwrite("fill_values", &tb::DecoderInputBuffers::fillValues) - .def_readwrite("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice) - .def_readwrite("inputs_ids", &tb::DecoderInputBuffers::inputsIds) - .def_readwrite("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots) - .def_readwrite("logits", &tb::DecoderInputBuffers::logits) - .def_readwrite("decoder_requests", &tb::DecoderInputBuffers::decoderRequests); - - py::class_(m, "DecoderOutputBuffers") - .def_readwrite("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost) - .def_readwrite("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost) - .def_property_readonly("new_output_tokens_host", - [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); }) - .def_readwrite("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost) - .def_readwrite("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost) - .def_readwrite("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost); - - py::class_(m, "SlotDecoderBuffers") - .def(py::init(), - py::arg("max_beam_width"), py::arg("max_seq_len"), py::arg("buffer_manager")) - .def_readwrite("output_ids", &tb::SlotDecoderBuffers::outputIds) - .def_readwrite("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost) - .def_readwrite("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost) - .def_readwrite("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs) - .def_readwrite("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost) - .def_readwrite("log_probs", 
&tb::SlotDecoderBuffers::logProbs) - .def_readwrite("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost) - .def_readwrite("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost); - m.def( "add_new_tokens_to_requests", [](std::vector>& requests, @@ -454,10 +421,10 @@ void initBindings(pybind11::module_& m) m.def( "make_decoding_batch_input", - [](std::vector>& contextRequests, - std::vector>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth, - std::vector const& numContextLogitsPrefixSum, tb::DecoderInputBuffers const& decoderInputBuffers, - runtime::decoder::DecoderState& decoderState, tr::BufferManager const& manager) + [](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState, + std::vector> const& contextRequests, + std::vector> const& genRequests, tr::ITensor::SharedPtr const& logits, + int beamWidth, std::vector const& numContextLogitsPrefixSum, tr::BufferManager const& manager) { std::vector activeSlots; std::vector generationSteps; @@ -515,8 +482,7 @@ void initBindings(pybind11::module_& m) batchSlotsRange[i] = activeSlots[i]; } - auto decodingInput = std::make_unique(logitsVec, 1); - decodingInput->batchSlots = batchSlots; + decoderInputBuffers.batchLogits = logitsVec; auto const maxBeamWidth = decoderState.getMaxBeamWidth(); if (maxBeamWidth > 1) @@ -524,12 +490,10 @@ void initBindings(pybind11::module_& m) // For Variable-Beam-Width-Search decoderState.getJointDecodingInput().generationSteps = generationSteps; } - - return decodingInput; }, - py::arg("context_requests"), py::arg("generation_requests"), py::arg("logits"), py::arg("beam_width"), - py::arg("num_context_logits_prefix_sum"), py::arg("decoder_input_buffers"), py::arg("decoder_state"), - py::arg("buffer_manager"), "Make decoding batch input."); + py::arg("decoder_input_buffers"), py::arg("decoder_state"), py::arg("context_requests"), + py::arg("generation_requests"), py::arg("logits"), py::arg("beam_width"), + py::arg("num_context_logits_prefix_sum"), py::arg("buffer_manager"), "Make decoding batch input."); } } // namespace tensorrt_llm::pybind::batch_manager diff --git a/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp b/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp new file mode 100644 index 00000000000..768d21169b9 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp @@ -0,0 +1,75 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "buffers.h" + +#include "tensorrt_llm/batch_manager/decoderBuffers.h" +#include "tensorrt_llm/runtime/torch.h" + +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +namespace tb = tensorrt_llm::batch_manager; +namespace tr = tensorrt_llm::runtime; + +using tr::SizeType32; + +namespace tensorrt_llm::pybind::batch_manager +{ + +void Buffers::initBindings(pybind11::module_& m) +{ + py::class_(m, "DecoderInputBuffers") + .def(py::init(), py::arg("max_batch_size"), + py::arg("max_tokens_per_engine_step"), py::arg("manager")) + .def_readwrite("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots) + .def_readwrite("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice) + .def_readwrite("fill_values", &tb::DecoderInputBuffers::fillValues) + .def_readwrite("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice) + .def_readwrite("inputs_ids", &tb::DecoderInputBuffers::inputsIds) + .def_readwrite("batch_logits", &tb::DecoderInputBuffers::batchLogits) + .def_readwrite("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots) + .def_readwrite("decoder_logits", &tb::DecoderInputBuffers::decoderLogits) + .def_readwrite("max_decoder_steps", &tb::DecoderInputBuffers::maxDecoderSteps); + + py::class_(m, "DecoderOutputBuffers") + .def_readwrite("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost) + .def_readwrite("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost) + .def_property_readonly("new_output_tokens_host", + [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); }) + .def_readwrite("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost) + .def_readwrite("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost) + .def_readwrite("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost); + + py::class_(m, "SlotDecoderBuffers") + .def(py::init(), + py::arg("max_beam_width"), py::arg("max_seq_len"), py::arg("buffer_manager")) + .def_readwrite("output_ids", &tb::SlotDecoderBuffers::outputIds) + .def_readwrite("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost) + .def_readwrite("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost) + .def_readwrite("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs) + .def_readwrite("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost) + .def_readwrite("log_probs", &tb::SlotDecoderBuffers::logProbs) + .def_readwrite("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost) + .def_readwrite("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost); +} +} // namespace tensorrt_llm::pybind::batch_manager diff --git a/cpp/tensorrt_llm/pybind/batch_manager/buffers.h b/cpp/tensorrt_llm/pybind/batch_manager/buffers.h new file mode 100644 index 00000000000..bfe06c0e8e8 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/batch_manager/buffers.h @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/pybind/common/customCasters.h" +#include + +namespace tensorrt_llm::pybind::batch_manager +{ +class Buffers +{ +public: + static void initBindings(pybind11::module_& m); +}; +} // namespace tensorrt_llm::pybind::batch_manager diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 4b5415afd95..0a04d5ad19d 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -28,6 +28,8 @@ #include "tensorrt_llm/common/quantization.h" #include "tensorrt_llm/pybind/batch_manager/algorithms.h" #include "tensorrt_llm/pybind/batch_manager/bindings.h" +#include "tensorrt_llm/pybind/batch_manager/buffers.h" + #include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h" @@ -478,6 +480,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_property_readonly("uvm", &tr::MemoryCounters::getUVM); tensorrt_llm::pybind::process_group::initBindings(mInternalProcessGroup); + tpb::Buffers::initBindings(mInternalBatchManager); tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime); tensorrt_llm::pybind::testing::initBindings(mInternalTesting); tpb::initBindings(mInternalBatchManager); diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 469aafe6476..ee4303d31b8 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -18,6 +18,7 @@ #include "bindings.h" #include "hostfunc.h" #include "moeBindings.h" +#include "tensorrt_llm/batch_manager/decoderBuffers.h" #include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h" #include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h" #include "tensorrt_llm/kernels/customAllReduceKernels.h" @@ -49,6 +50,7 @@ namespace tr = tensorrt_llm::runtime; namespace te = tensorrt_llm::executor; +namespace tb = tensorrt_llm::batch_manager; class PyITensor : public tensorrt_llm::runtime::ITensor { @@ -221,11 +223,6 @@ void initBindings(pybind11::module_& m) .def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, py::arg("tag"), py::call_guard()); - py::classh(m, "BufferManager") - .def(py::init(), py::arg("stream"), py::arg("trim_pool") = false, - py::call_guard()) - .def_property_readonly("stream", &tr::BufferManager::getStream); - py::classh(m, "TllmRuntime") .def(py::init( [](std::filesystem::path engine_path, float gpu_weights_percent = 1.0f, bool use_shape_inference = true) @@ -262,15 +259,6 @@ void initBindings(pybind11::module_& m) .def_property_readonly("logits_dtype_from_engine", [](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); }); - py::class_(m, "DecoderBatchInput") - .def(py::init>, tr::SizeType32>(), py::arg("logits"), - py::arg("max_decoding_engine_tokens"), py::call_guard()) - .def(py::init>(), py::arg("logits"), - py::call_guard()) - .def_readwrite("logits", &tr::decoder_batch::Input::logits) - .def_readwrite("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps) - .def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots); - py::class_(m, "LookaheadDecodingBuffers") .def(py::init(), py::arg("max_num_sequences"), py::arg("max_tokens_per_step"), py::arg("buffer_manager"), py::call_guard()) @@ 
-492,6 +480,11 @@ void initBindings(pybind11::module_& m) void initBindingsEarly(py::module_& m) { + py::classh(m, "BufferManager") + .def(py::init(), py::arg("stream"), py::arg("trim_pool") = false, + py::call_guard()) + .def_property_readonly("stream", &tr::BufferManager::getStream); + py::class_(m, "SpeculativeDecodingMode") .def(py::init(), py::arg("state")) .def_static("NoneType", &tr::SpeculativeDecodingMode::None) diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 6df7b1634b8..916062d3cd7 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -19,7 +19,7 @@ #include "common.h" #include "decoderState.h" #include "iBuffer.h" -#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h" +#include "tensorrt_llm/batch_manager/decoderBuffers.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/executor/types.h" @@ -33,6 +33,7 @@ #include using namespace tensorrt_llm::runtime; +namespace tb = tensorrt_llm::batch_manager; using TensorPtr = ITensor::SharedPtr; GptDecoderBatched::GptDecoderBatched(GptDecoderBatched::CudaStreamPtr stream) @@ -102,7 +103,7 @@ namespace { //! @brief Prepare Input and Output for decoder step. // TODO: produce new input and output objects -void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, decoder_batch::Input const& input, +void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, tb::DecoderInputBuffers const& input, BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -112,9 +113,9 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, auto& dInput = decoderState.getJointDecodingInput(); auto& dOutput = decoderState.getJointDecodingOutput(); - dInput.batchSlots = input.batchSlots.at(step); + dInput.batchSlots = input.forwardBatchSlots.at(step); dInput.batchSize = static_cast(dInput.batchSlots->getSize()); - dInput.logitsVec = input.logits.at(step); + dInput.logitsVec = input.batchLogits.at(step); if (speculativeDecodingMode.isDraftTokensExternal()) { @@ -139,7 +140,7 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, } // namespace -void GptDecoderBatched::forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) +void GptDecoderBatched::forwardDispatch(decoder::DecoderState const& decoderState, tb::DecoderInputBuffers const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -157,7 +158,8 @@ void GptDecoderBatched::forwardDispatch(decoder::DecoderState const& decoderStat TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -CudaEvent GptDecoderBatched::forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) +CudaEvent GptDecoderBatched::forwardAsync( + decoder::DecoderState const& decoderState, tb::DecoderInputBuffers const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -177,7 +179,7 @@ CudaEvent GptDecoderBatched::forwardAsync(decoder::DecoderState const& decoderSt return eventStop; } -void GptDecoderBatched::forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) +void GptDecoderBatched::forward(decoder::DecoderState const& decoderState, tb::DecoderInputBuffers const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto decoderFinishEvent = forwardAsync(decoderState, input); diff --git 
a/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp b/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp index 8358e987334..a9dc6ff785f 100644 --- a/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp +++ b/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp @@ -137,7 +137,7 @@ class GuidedDecoderTest : public ::testing::Test decoderInputBuffers.decoderRequests.push_back(llmReq); } } - decoderInputBuffers.logits = mLogits; + decoderInputBuffers.decoderLogits = mLogits; // Context phase resetLogits(); diff --git a/cpp/tests/e2e_tests/executor/executorTest.cpp b/cpp/tests/e2e_tests/executor/executorTest.cpp index 6b8c8d7eb9e..0da28c59552 100644 --- a/cpp/tests/e2e_tests/executor/executorTest.cpp +++ b/cpp/tests/e2e_tests/executor/executorTest.cpp @@ -2459,6 +2459,7 @@ void doTokenComparisonChangeBeamWidth(bool enableReuse, SizeType32 maxWaitMs) for (SizeType32 beamWidth : {1, 2}) { + TLLM_LOG_INFO("Running beam width: %d", beamWidth); BeamResult beamResult{beamWidth}; auto const resultsPath = GPT_DATA_PATH / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth)); @@ -2466,8 +2467,10 @@ void doTokenComparisonChangeBeamWidth(bool enableReuse, SizeType32 maxWaitMs) beamResult.contextLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_CONTEXT_LOGITS_FILE(); beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_FILE(); + auto const numReturnSequences = beamWidth; + runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig, - isSpeculativeDecoding, maxWaitMs, false, 1, false, 1); + isSpeculativeDecoding, maxWaitMs, false, numReturnSequences, false, 1); } } diff --git a/cpp/tests/unit_tests/runtime/gptDecoderBatchedTest.cpp b/cpp/tests/unit_tests/runtime/gptDecoderBatchedTest.cpp index 338b974aa0d..15476899fce 100644 --- a/cpp/tests/unit_tests/runtime/gptDecoderBatchedTest.cpp +++ b/cpp/tests/unit_tests/runtime/gptDecoderBatchedTest.cpp @@ -27,7 +27,6 @@ #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/modelConfig.h" -#include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/worldConfig.h" #include @@ -43,16 +42,11 @@ namespace tle = tensorrt_llm::executor; namespace tc = tensorrt_llm::common; namespace tb = tensorrt_llm::batch_manager; -using TensorPtr = decoder_batch::Input::TensorPtr; +using TensorPtr = ITensor::SharedPtr; namespace { -struct DecoderInputs -{ - std::vector logits; -}; - std::shared_ptr createLlmRequest(SizeType32 batchSlot, SizeType32 inputLengths, SizeType32 generatedTokensPerSteps, SizeType32 acceptedTokensPerStep, TokenIdType inputTokenId, TokenIdType expectedTokenId, SizeType32 maxNewTokens, SamplingConfig const& samplingConfig, TokenIdType endId) @@ -130,24 +124,23 @@ void newRequests(std::vector> const& requests, T runtimeStream.wait(event); } -DecoderInputs createDecoderInputs(SizeType32 batchSize, SizeType32 vocabSizePadded, nvinfer1::DataType dataType, - std::vector& samplingConfigs, std::vector const& generatedTokensPerSteps, - bool computeLogProbs, BufferManager& manager) +void createDecoderInputs(tb::DecoderInputBuffers& inputBuffers, SizeType32 batchSize, SizeType32 vocabSizePadded, + nvinfer1::DataType dataType, std::vector& samplingConfigs, + std::vector const& generatedTokensPerSteps, bool computeLogProbs, BufferManager& manager) { - DecoderInputs inputs; - - inputs.logits.reserve(batchSize); + auto& logits = inputBuffers.decoderLogits; + 
logits.reserve(batchSize); for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx) { - auto const beamWidth = samplingConfigs[batchIdx].beamWidth; - samplingConfigs[batchIdx].outputLogProbs = {{computeLogProbs}}; - samplingConfigs[batchIdx].cumLogProbs = {{computeLogProbs}}; - inputs.logits.emplace_back( + auto& samplingConfig = samplingConfigs[batchIdx]; + auto const beamWidth = samplingConfig.beamWidth; + samplingConfig.outputLogProbs = {{computeLogProbs}}; + samplingConfig.cumLogProbs = {{computeLogProbs}}; + + logits.emplace_back( manager.gpu(ITensor::makeShape({generatedTokensPerSteps[batchIdx], beamWidth, vocabSizePadded}), dataType)); - manager.setZero(*inputs.logits.back()); + manager.setZero(*logits.back()); } - - return inputs; } void copySequenceLengths( @@ -325,8 +318,8 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa auto batchSlotsRange = BufferRange(*inputBuffers.setupBatchSlots); std::iota(batchSlotsRange.begin(), batchSlotsRange.end(), 0); - auto decoderInputs = createDecoderInputs( - batchSize, vocabSizePadded, dataType, samplingConfigs, generatedTokensPerSteps, computeLogProbs, manager); + createDecoderInputs(inputBuffers, batchSize, vocabSizePadded, dataType, samplingConfigs, generatedTokensPerSteps, + computeLogProbs, manager); manager.setZero(*decoderState.getCacheIndirectionInput()); copySequenceLengths(tiledInputLengths, *decoderState.getSequenceLengths(), manager); @@ -352,9 +345,8 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa auto activeSlots = std::vector(batchSize); std::iota(activeSlots.begin(), activeSlots.end(), 0); - auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs( - activeSlots, decoderState, decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots); - decoder.forward(decoderState, *inputs); + tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(inputBuffers, activeSlots, decoderState); + decoder.forward(decoderState, inputBuffers); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(false)); @@ -365,14 +357,14 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa // run decoder for 1 step advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), batchSize, maxBeamWidth); - decoder.forward(decoderState, *inputs); + decoder.forward(decoderState, inputBuffers); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(true)); verifyResults(manager, decoderState, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth, maxSeqLength, inputTokenId, expectedTokenId, endId); - EXPECT_NO_THROW(decoder.forward(decoderState, *inputs)); + EXPECT_NO_THROW(decoder.forward(decoderState, inputBuffers)); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); TensorPtr batchSlotsView = ITensor::slice(inputBuffers.setupBatchSlots, 0, 1); @@ -457,8 +449,8 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector(batchIdx + 1); std::iota(activeSlots.begin(), activeSlots.end(), 0); - auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs( - activeSlots, decoderState, decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots); - decoder.forward(decoderState, 
*inputs); + tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(inputBuffers, activeSlots, decoderState); + decoder.forward(decoderState, inputBuffers); advanceSequenceLengths( expectedLengths, acceptedTokensPerStep, samplingConfigs, expectedFinished, batchIdx + 1, maxBeamWidth); @@ -505,9 +496,8 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector(batchSize); std::iota(activeSlots.begin(), activeSlots.end(), 0); - auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs( - activeSlots, decoderState, decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots); - decoder.forward(decoderState, *inputs); + tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(inputBuffers, activeSlots, decoderState); + decoder.forward(decoderState, inputBuffers); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(false)); diff --git a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py index 1f92b87ff16..2d72acfb026 100644 --- a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py +++ b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py @@ -4,7 +4,8 @@ import torch from tensorrt_llm._utils import nvtx_range -from tensorrt_llm.bindings.internal.runtime import DecoderBatchInput +from tensorrt_llm.bindings.internal.batch_manager import DecoderInputBuffers +from tensorrt_llm.bindings.internal.runtime import DecoderState @dataclass @@ -19,21 +20,22 @@ class MakeDecodingBatchInputOutput: @nvtx_range("make_decoding_batch_input_output") def __call__( self, + decoder_input_buffers: DecoderInputBuffers, + decoder_state: DecoderState, scheduled_requests, logits: torch.Tensor, beam_width: int, num_context_logits_prefix_sum: List[int], - ) -> DecoderBatchInput: + ): """Create decoder batch inputs and outputs for the given requests. Args: + decoder_input_buffers: Decoder input buffers + decoder_state: Current decoder state scheduled_requests: Scheduled requests logits: Logits tensor beam_width: Beam width num_context_logits_prefix_sum: Number of context logits prefix sum - - Returns: - DecoderBatchInput """ # In order to make a decoding_input assuming no drafting, we need: # 1. 
logits_vec = [[logits_slice of each active slot]] @@ -61,10 +63,9 @@ def __call__( start=logits_index + i * beam_width, length=beam_width).unsqueeze(0)) - decoding_input = DecoderBatchInput(logits_vec, 1) - decoding_input.generation_steps = generation_steps - decoding_input.batch_slots = [ + decoder_state.generation_steps = generation_steps + decoder_input_buffers.forward_batch_slots = [ torch.tensor(active_slots[0], dtype=torch.int32) ] - - return decoding_input + decoder_input_buffers.logits = logits_vec + decoder_input_buffers.max_decoder_steps = 1 diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 0ca3a27bd7c..1ed014bea86 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1928,7 +1928,6 @@ def _initialize_store(self): dtype=torch.int, ), "decoder_state": DecoderState(), - "decoding_input": [None] * self.num_micro_batches, } self.store["decoder_state"].setup( @@ -2037,19 +2036,20 @@ def sample_async( if beam_width > 1: self._update_cache_indirection_buffer(scheduled_requests) - self.store["decoding_input"][self.micro_batch_idx] = make_decoding_batch_input( + make_decoding_batch_input( + self.store["decoder_input_buffers"][self.micro_batch_idx], + self.store["decoder_state"], scheduled_requests.context_requests, scheduled_requests.generation_requests, model_outputs["logits"], beam_width, num_context_logits_prefix_sum, - self.store["decoder_input_buffers"][self.micro_batch_idx], - self.store["decoder_state"], self.store["buffer_manager"], ) self.algs.decoder.forward_async( - self.store["decoder_state"], self.store["decoding_input"][self.micro_batch_idx] + self.store["decoder_state"], + self.store["decoder_input_buffers"][self.micro_batch_idx], ) finalize_events = {}
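Note: the diff above removes runtime::decoder_batch::Input and routes decoder inputs through batch_manager::DecoderInputBuffers instead. Below is a minimal, self-contained C++ sketch of that data flow using stand-in types and hypothetical function names (not the real TensorRT-LLM API): the batching step fills batchLogits, forwardBatchSlots, and maxDecoderSteps in place, and the decoder then reads those fields directly rather than receiving a separately allocated input object.

// Simplified illustration of the refactored flow; all types here are stand-ins.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

using TensorPtr = std::shared_ptr<std::vector<float>>; // stand-in for ITensor::SharedPtr

struct DecoderInputBuffers // stand-in for tensorrt_llm::batch_manager::DecoderInputBuffers
{
    std::vector<TensorPtr> decoderLogits;                // one logits tensor per decoder request
    std::vector<std::vector<TensorPtr>> batchLogits;     // [maxDecoderSteps][batchSize]
    std::vector<std::vector<int32_t>> forwardBatchSlots; // [maxDecoderSteps][batchSize]
    int32_t maxDecoderSteps{1};
};

// Mirrors the new createDecoderBatchInputs: populate the shared buffers, return nothing.
void makeDecodingBatchInput(DecoderInputBuffers& buffers, std::vector<int32_t> const& activeSlots)
{
    buffers.forwardBatchSlots.assign(1, activeSlots);    // single decoder step in this sketch
    buffers.batchLogits.assign(1, buffers.decoderLogits);
    buffers.maxDecoderSteps = 1;
}

// Mirrors forwardAsync/forward now taking DecoderInputBuffers instead of decoder_batch::Input.
void decoderForward(DecoderInputBuffers const& buffers)
{
    for (int32_t step = 0; step < buffers.maxDecoderSteps; ++step)
    {
        std::cout << "step " << step << ": " << buffers.forwardBatchSlots[step].size()
                  << " active slots, " << buffers.batchLogits[step].size() << " logits tensors\n";
    }
}

int main()
{
    DecoderInputBuffers buffers;
    buffers.decoderLogits = {std::make_shared<std::vector<float>>(32), std::make_shared<std::vector<float>>(32)};
    makeDecodingBatchInput(buffers, {0, 1});
    decoderForward(buffers);
    return 0;
}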