diff --git a/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h b/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h index 926b1349d07..63ab6dbfebb 100644 --- a/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h +++ b/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h @@ -107,32 +107,31 @@ class DecoderBuffers using SizeType32 = runtime::SizeType32; using TensorPtr = runtime::ITensor::SharedPtr; - TensorPtr cacheIndirectionInput; - TensorPtr cacheIndirectionOutput; - DraftBuffers draftBuffers; - DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, - SizeType32 maxTokensPerStep, runtime::BufferManager const& manager, runtime::ModelConfig const& modelConfig, - runtime::WorldConfig const& worldConfig); + DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager, + runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig); }; class DecoderStepAsyncSend { public: using SizeType32 = runtime::SizeType32; - using BufferPtr = runtime::IBuffer::SharedPtr; + using TensorPtr = runtime::ITensor::SharedPtr; - DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers, - bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int peer); + DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers, + TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, + mpi::MpiComm const& commSession, int peer); ~DecoderStepAsyncSend(); - static void recv(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers, - bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int peer); + static void recv(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers, + TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, + mpi::MpiComm const& commSession, int peer); - static void bcast(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers, - bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int root); + static void bcast(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers, + TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, + mpi::MpiComm const& commSession, int root); private: std::unique_ptr mRequest1; diff --git a/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h b/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h index 1c3dfb468db..67c730cf7b1 100644 --- a/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h +++ b/cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h @@ -47,16 +47,15 @@ class MakeDecodingBatchInputOutput : Algorithm MakeDecodingBatchInputOutput() = default; - std::tuple, std::unique_ptr> - operator()(RequestVector const& contextRequests, RequestVector const& generationRequests, - DecoderBuffers& decoderBuffers, DecoderInputBuffers const& inputBuffers, - runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, - SizeType32 maxNumSequences, OptionalRef fusedRuntimeBuffers) const; + std::unique_ptr operator()(RequestVector const& contextRequests, + RequestVector const& generationRequests, 
DecoderBuffers& decoderBuffers, + DecoderInputBuffers const& inputBuffers, runtime::decoder::DecoderState& decoderState, + runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences, + OptionalRef fusedRuntimeBuffers) const; [[nodiscard]] static std::unique_ptr createDecoderBatchInputs( std::vector const& activeSlots, runtime::decoder::DecoderState const& decoderState, - std::vector const& logits, SizeType32 maxNumSequences, std::vector const& batchSlots, - TensorPtr const& cacheIndirectionInput); + std::vector const& logits, SizeType32 maxNumSequences, std::vector const& batchSlots); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h index 9f777c439e5..4b68f6a71ad 100644 --- a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h +++ b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h @@ -279,10 +279,9 @@ class RuntimeBuffers std::tuple prepareStep(RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, - DecoderBuffers& decoderBuffers, runtime::decoder::DecoderState const& decoderState, - kv_cache_manager::BaseKVCacheManager* kvCacheManager, kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, - rnn_state_manager::RnnStateManager* rnnStateManager, PeftTable const& peftTable, - runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig, + runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManager, + kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, rnn_state_manager::RnnStateManager* rnnStateManager, + PeftTable const& peftTable, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits, bool trtOverlap, OptionalRef newOutputTokens = std::nullopt); @@ -314,8 +313,8 @@ class RuntimeBuffers runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits); void setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth, - SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers, - runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr, + SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState, + kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr, kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr, rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig, diff --git a/cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h b/cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h index 17f23c8c129..dd9af60868e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h +++ b/cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h @@ -122,10 +122,6 @@ class TransformerBuffers void copyPositionIds(runtime::TllmRuntime const& runtime, std::vector const& positionIdsHost, bool isChatGlm, TensorPtr const& decoderPositionIds); - void resetCacheIndirection(RequestVector const& contextRequests, SizeType32 maxBeamWidth, - SizeType32 maxAttentionWindow, TensorPtr const& decoderCacheIndirectionInput, - TensorPtr const& decoderCacheIndirectionOutput, runtime::BufferManager const& manager); - void copyKvBlockOffsets(RequestVector const& contextRequests, 
RequestVector const& genRequests, kv_cache_manager::BaseKVCacheManager const* kvCacheManager, kv_cache_manager::BaseKVCacheManager const* crossKvCacheManager, runtime::BufferManager const& manager); diff --git a/cpp/include/tensorrt_llm/runtime/decoderState.h b/cpp/include/tensorrt_llm/runtime/decoderState.h index c4a3aad5185..75045e72efc 100644 --- a/cpp/include/tensorrt_llm/runtime/decoderState.h +++ b/cpp/include/tensorrt_llm/runtime/decoderState.h @@ -54,10 +54,15 @@ class DecoderState void allocateSpeculativeDecodingBuffers( SpeculativeDecodingMode speculativeDecodingMode, nvinfer1::DataType dtype, BufferManager const& bufferManager); + //! @brief Setup buffers for the decoder excluding speculative decoding. void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); + //! @brief Setup buffers for the cache indirection. + //! @details This is used for beam search on pipeline parallel ranks without a decoder. + void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow); + //! @brief Setup buffers for speculative decoding. void setupSpeculativeDecoding(SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig, @@ -174,6 +179,12 @@ class DecoderState //! @brief Workspace for beam search in streaming mode. [[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const; + //! @brief Cache indirection input for beam search. + [[nodiscard]] TensorPtr getCacheIndirectionInput() const; + + //! @brief Cache indirection output for beam search. + [[nodiscard]] TensorPtr getCacheIndirectionOutput() const; + //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. [[nodiscard]] DecodingInput& getJointDecodingInput() const; diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index 2ac3aa94e74..d5dfe9b7b19 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -52,10 +52,8 @@ class GptDecoderBatched : public IGptDecoderBatched void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override; - CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Output& output, - decoder_batch::Input const& input) override; - void forward(decoder::DecoderState const& decoderState, decoder_batch::Output& output, - decoder_batch::Input const& input) override; + CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override; + void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override; //! @brief Gather final beam search results for request `batchSlot`. //! Result will only be available after event returned. @@ -79,8 +77,7 @@ class GptDecoderBatched : public IGptDecoderBatched private: //! 
@brief Calls decoders for tokens per engine step - void forwardDispatch( - decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input); + void forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input); private: CudaStreamPtr mRuntimeStream; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 821e916e10b..ed37c1260e9 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -76,8 +76,6 @@ class Input TensorPtr batchSlotsRequestOrder; //! For Beam Search - //! Indices into KV cache of different rays within one beam, [maxBatchSize, maxBeamWidth, maxSeqLen], on gpu - TensorPtr cacheIndirection; //! The generation step of each request (for Variable-Beam-Width-Search), [batchSize] std::vector generationSteps; @@ -95,17 +93,6 @@ class Input std::optional eagleLastInputs; }; -class Output -{ -public: - using TensorPtr = std::shared_ptr; - - Output() = default; - - //! parameters for beam search, [batchSize, maxBeamWidth, maxSeqLen], on gpu - TensorPtr cacheIndirection; -}; - } // namespace decoder_batch //! GPT decoder class with support for in-flight batching @@ -126,14 +113,10 @@ class IGptDecoderBatched virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0; //! @brief Run one step for all requests without blocking the host process and return the token for synchronization. - virtual CudaEvent forwardAsync( - decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input) - = 0; + virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0; //! @brief Run one step for all requests and wait for completion on the host. - virtual void forward( - decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input) - = 0; + virtual void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0; //! @brief Gather final beam search results for request `batchIdx`. //! 
Result will only be available after event returned diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 37e344b0c94..e2a09a19b41 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -25,7 +25,6 @@ #include "tensorrt_llm/runtime/decoderState.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" -#include "tensorrt_llm/runtime/gptDecoderBatched.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -307,6 +306,12 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder parentIds->reshape(outputIdsShape); manager.setZero(*parentIds); + auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1); + manager.setZero(*cacheIndirectionInput); + + auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1); + manager.setZero(*cacheIndirectionOutput); + auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1); beamHypotheses.init(manager, endId); } diff --git a/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp b/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp index 718991b83fe..b72426e92a5 100644 --- a/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp @@ -89,15 +89,9 @@ void DecoderOutputBuffers::disableLookaheadDecoding(SizeType32 maxNumSequences) TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, - SizeType32 maxTokensPerStep, BufferManager const& manager, ModelConfig const& modelConfig, - WorldConfig const& worldConfig) +DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, BufferManager const& manager, + ModelConfig const& modelConfig, WorldConfig const& worldConfig) { - cacheIndirectionInput = manager.gpu( - ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32); - cacheIndirectionOutput = manager.gpu( - ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32); - if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind() || modelConfig.getSpeculativeDecodingMode().hasDraftLogits() || modelConfig.getSpeculativeDecodingMode().predictsDraftTokens()) @@ -147,8 +141,8 @@ void DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerSte } DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, - DecoderBuffers const& decoderBuffers, bool const returnLogProbs, SizeType32 const maxBeamWidth, - bool const useMedusa, mpi::MpiComm const& commSession, int peer) + DraftBuffers const& draftBuffers, TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, + SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession, int peer) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_LOG_DEBUG("start send outputs of DecoderBuffers to rank %d", peer); @@ -165,14 +159,14 @@ DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOu mRequest5 = returnLogProbs ? 
commSession.sendAsync(*decoderOutputBuffers.logProbsHost, peer, mpi::MpiTag::kDecoderStepLogProbsHost) : nullptr; - mRequest6 = maxBeamWidth > 1 ? commSession.sendAsync( - *decoderBuffers.cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput) - : nullptr; - mRequest7 = useMedusa ? commSession.sendAsync(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, peer, + mRequest6 = maxBeamWidth > 1 + ? commSession.sendAsync(*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput) + : nullptr; + mRequest7 = useMedusa ? commSession.sendAsync(*draftBuffers.acceptedLengthsCumSumDevice, peer, mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice) : nullptr; - mRequest8 = useMedusa ? commSession.sendAsync(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, peer, - mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice) + mRequest8 = useMedusa ? commSession.sendAsync( + *draftBuffers.acceptedPackedPathsDevice, peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice) : nullptr; mRequest9 = commSession.sendAsync( *decoderOutputBuffers.finishReasonsHost, peer, mpi::MpiTag::kDecoderStepFinishReasonsHost); @@ -180,9 +174,9 @@ DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOu TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers, - bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession, - int const peer) +void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers, + TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth, + bool const useMedusa, mpi::MpiComm const& commSession, int const peer) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_LOG_DEBUG("start recv outputs of DecoderBuffers from rank %d", peer); @@ -197,14 +191,14 @@ void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers } if (maxBeamWidth > 1) { - commSession.recv(*decoderBuffers.cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput); + commSession.recv(*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput); } if (useMedusa) { - commSession.recv(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, peer, - mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice); - commSession.recv(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, peer, - mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice); + commSession.recv( + *draftBuffers.acceptedLengthsCumSumDevice, peer, mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice); + commSession.recv( + *draftBuffers.acceptedPackedPathsDevice, peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice); } commSession.recv(*decoderOutputBuffers.finishReasonsHost, peer, mpi::MpiTag::kDecoderStepFinishReasonsHost); @@ -235,9 +229,9 @@ DecoderStepAsyncSend::~DecoderStepAsyncSend() TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers, - bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession, - int const root) +void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers, + TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const 
maxBeamWidth, + bool const useMedusa, mpi::MpiComm const& commSession, int const root) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_LOG_DEBUG("start bcast outputs of DecoderBuffers from rank %d", root); @@ -247,11 +241,9 @@ void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffer auto request3 = commSession.bcastAsync(*decoderOutputBuffers.sequenceLengthsHost, root); auto request4 = returnLogProbs ? commSession.bcastAsync(*decoderOutputBuffers.cumLogProbsHost, root) : nullptr; auto request5 = returnLogProbs ? commSession.bcastAsync(*decoderOutputBuffers.logProbsHost, root) : nullptr; - auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync(*decoderBuffers.cacheIndirectionOutput, root) : nullptr; - auto request7 - = useMedusa ? commSession.bcastAsync(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr; - auto request8 - = useMedusa ? commSession.bcastAsync(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, root) : nullptr; + auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync(*cacheIndirectionOutput, root) : nullptr; + auto request7 = useMedusa ? commSession.bcastAsync(*draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr; + auto request8 = useMedusa ? commSession.bcastAsync(*draftBuffers.acceptedPackedPathsDevice, root) : nullptr; auto request9 = commSession.bcastAsync(*decoderOutputBuffers.finishReasonsHost, root); request1->wait(); diff --git a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp index ba22802e106..4c89cc57514 100644 --- a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp +++ b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp @@ -33,8 +33,7 @@ using TensorPtr = MakeDecodingBatchInputOutput::TensorPtr; std::unique_ptr MakeDecodingBatchInputOutput::createDecoderBatchInputs( std::vector const& activeSlots, runtime::decoder::DecoderState const& decoderState, - std::vector const& logits, SizeType32 maxNumSequences, std::vector const& batchSlots, - TensorPtr const& cacheIndirectionInput) + std::vector const& logits, SizeType32 maxNumSequences, std::vector const& batchSlots) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -83,7 +82,6 @@ std::unique_ptr MakeDecodingBatchInputOutput::createDe auto decodingInput = std::make_unique(logitsVec, maxActiveDecoderSteps); decodingInput->batchSlots = batchSlots; - decodingInput->cacheIndirection = cacheIndirectionInput; TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return decodingInput; } @@ -122,9 +120,8 @@ std::pair, std::vector> getActiveSlots( } // namespace -std::tuple, std::unique_ptr> -MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, RequestVector const& generationRequests, - DecoderBuffers& decoderBuffers, DecoderInputBuffers const& inputBuffers, +std::unique_ptr MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, + RequestVector const& generationRequests, DecoderBuffers& decoderBuffers, DecoderInputBuffers const& inputBuffers, runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences, OptionalRef fusedRuntimeBuffers) const { @@ -132,8 +129,8 @@ MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, R auto [activeSlots, generationSteps] = getActiveSlots(contextRequests, generationRequests); - auto decodingInput = createDecoderBatchInputs(activeSlots, decoderState, inputBuffers.logits, maxNumSequences, - 
inputBuffers.forwardBatchSlots, decoderBuffers.cacheIndirectionInput); + auto decodingInput = createDecoderBatchInputs( + activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots); decodingInput->generationSteps = generationSteps; if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits()) @@ -158,11 +155,8 @@ MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, R decodingInput->eagleLastInputs = fusedRuntimeBuffers->mEagleBuffers->engineInputs; } - auto decodingOutput = std::make_unique(); - decodingOutput->cacheIndirection = decoderBuffers.cacheIndirectionOutput; - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return {std::move(decodingInput), std::move(decodingOutput)}; + return decodingInput; } } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index 65680ff4210..e527d5a7769 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -451,8 +451,8 @@ void RuntimeBuffers::prepareBuffersForCudaGraph(SizeType32 maxSequenceLength) } void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests, - SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers, - runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr, + SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState, + kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr, kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr, rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig, @@ -619,12 +619,6 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request { rnnStateBuffers->fillSlotMappings(contextRequests, rnnStateManagerPtr); } - - if (transformerBuffers && maxBeamWidth > 1) - { - transformerBuffers->resetCacheIndirection(contextRequests, maxBeamWidth, maxAttentionWindow, - decoderBuffers.cacheIndirectionInput, decoderBuffers.cacheIndirectionOutput, manager); - } } // generation preparation loop @@ -729,7 +723,7 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request if (transformerBuffers && maxBeamWidth > 1) { - transformerBuffers->copyCacheIndirection(genRequests, decoderBuffers.cacheIndirectionOutput, stream); + transformerBuffers->copyCacheIndirection(genRequests, decoderState.getCacheIndirectionOutput(), stream); } numSequences = numContextRequests; @@ -925,7 +919,7 @@ void RuntimeBuffers::prepareEagleBuffers(RequestVector const& contextRequests, R std::tuple RuntimeBuffers::prepareStep( RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth, - SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers, runtime::decoder::DecoderState const& decoderState, + SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManager, kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, rnn_state_manager::RnnStateManager* rnnStateManager, PeftTable const& peftTable, TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool gatherGenerationLogits, bool trtOverlap, @@ -937,8 +931,8 @@ std::tuple(*fillValuesAlt), 
numContextRequests, 0); - std::transform(contextRequests.begin(), contextRequests.end(), bufferCast(*seqSlotsAlt), - [](auto const& llmReq) { return llmReq->mSeqSlot.value(); }); - - auto const seqSlotsHostView = ITensor::slice(seqSlotsAlt, 0, numContextRequests); - auto seqSlotsDeviceView = ITensor::slice(seqSlotsAltDevice, 0, numContextRequests); - manager.copy(*seqSlotsHostView, *seqSlotsDeviceView); - manager.copy(*fillValuesAlt, *fillValuesAltDevice); - runtime::kernels::invokeFillBatch(*decoderCacheIndirectionInput, *seqSlotsDeviceView, - static_cast(maxBeamWidth) * maxAttentionWindow, *fillValuesAltDevice, stream); - runtime::kernels::invokeFillBatch(*decoderCacheIndirectionOutput, *seqSlotsDeviceView, - static_cast(maxBeamWidth) * maxAttentionWindow, *fillValuesAltDevice, stream); - - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); -} - void TransformerBuffers::copyKvBlockOffsets(RequestVector const& contextRequests, RequestVector const& genRequests, kv_cache_manager::BaseKVCacheManager const* kvCacheManager, kv_cache_manager::BaseKVCacheManager const* crossKvCacheManager, BufferManager const& manager) diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index b8855af568d..3bef5974f8d 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -1439,6 +1439,13 @@ void TrtGptModelInflightBatching::createDecoder(std::optionalsetupSpeculativeDecoding(mDecoderState->getSpeculativeDecodingMode(), mModelConfig.getMaxDecodingTokens(), mModelConfig, mWorldConfig, mRuntime->getBufferManager()); } + else + { + auto constexpr decoderDummyType = TRTDataType::value; + mDecoderState + = std::make_unique(decoderDummyType, mRuntime->getBufferManager()); + mDecoderState->setupCacheIndirection(getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow()); + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -1467,9 +1474,8 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const& mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager()); } - mDecoderBuffers - = std::make_unique(getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow(), - mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager(), mModelConfig, mWorldConfig); + mDecoderBuffers = std::make_unique(getMaxNumSequences(), mModelConfig.getMaxDecodingTokens(), + mRuntime->getBufferManager(), mModelConfig, mWorldConfig); mSlotDecoderBuffers.clear(); for (SizeType32 i = 0; i < getMaxNumSequences(); ++i) @@ -1595,7 +1601,7 @@ void TrtGptModelInflightBatching::prepareDistGenBufferAndDecoder(RequestVector c auto const bufferId = getFusedBufferId(); auto& runtimeBuffers = *mBuffers[bufferId]; runtimeBuffers.prepareStep(cacheTransCompleteRequests, {}, getMaxBeamWidth(), getMaxAttentionWindow(), - *mDecoderBuffers, *mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), mRnnStateManager.get(), + *mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), mRnnStateManager.get(), mPeftTables[mMicroBatchId], *mRuntime, mModelConfig, mWorldConfig, getGatherGenerationLogits(), isTrtOverlap()); auto const contextBufferId = mCtxGenFusion ? 
getFusedBufferId() : getContextBufferId(); @@ -1660,9 +1666,9 @@ TrtGptModelInflightBatching::prepareBuffers( : std::nullopt; auto [optProfileId, inputMap, outputMap] = runtimeBuffers.prepareStep(contextRequests, generationRequests, - mOperatingBeamWidth, getMaxAttentionWindow(), *mDecoderBuffers, *mDecoderState, mKvCacheManager.get(), - mCrossKvCacheManager.get(), mRnnStateManager.get(), mPeftTables[bufferId], *mRuntime, mModelConfig, - mWorldConfig, getGatherGenerationLogits(), isTrtOverlap(), allNewTokens); + mOperatingBeamWidth, getMaxAttentionWindow(), *mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), + mRnnStateManager.get(), mPeftTables[bufferId], *mRuntime, mModelConfig, mWorldConfig, + getGatherGenerationLogits(), isTrtOverlap(), allNewTokens); // For Variable-Beam-Width-Search mRuntime->setCurrentBeamWidths( @@ -2010,9 +2016,10 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques (*mHandleGenerationLogits)(decoderInputBuffers, scheduledRequests.generationRequests, genRuntimeBuffers->logits, genLogitsIndex, mModelConfig, mRuntime->getBufferManager(), *genRuntimeBuffers, mDecoderBuffers->draftBuffers); - // Copy indirection output into input - // TODO: Could we avoid this by modifying batchDecoder to take a vector of tensors instead? - copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId); + if (mOperatingBeamWidth > 1) + { + copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId); + } mLogitsPostProcessorIsApplied = (*mLogitsPostProcessor)(scheduledRequests.contextRequests, scheduledRequests.generationRequests, @@ -2027,11 +2034,11 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId); auto& decodingInput = mDecodingInputs.at(mMicroBatchId); - std::tie(decodingInput, mDecodingOutput) = (*mMakeDecodingBatchInputOutput)(scheduledRequests.contextRequests, + decodingInput = (*mMakeDecodingBatchInputOutput)(scheduledRequests.contextRequests, scheduledRequests.generationRequests, *mDecoderBuffers, mDecoderInputBuffers.at(fusedBufferId), *mDecoderState, mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers); - auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, *mDecodingOutput, *decodingInput); + auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, *decodingInput); auto const returnLogProbs = batchReturnLogProbs(scheduledRequests); auto updateDecoderBuffersEvent @@ -2054,7 +2061,10 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( auto* copySizesPtr = bufferCast(*genRuntimeBuffers.cacheIndirDecoderIOBatchedCopySizes); // Only `cacheIndirShape.d[2]` is used - auto const& cacheIndirShape = mDecoderBuffers->cacheIndirectionOutput->getShape(); + auto const& cacheIndirShape = mDecoderState->getCacheIndirectionOutput()->getShape(); + auto const maxBeamWidth = cacheIndirShape.d[1]; + auto const maxAttentionWindow = cacheIndirShape.d[2]; + auto const slotOffset = maxBeamWidth * maxAttentionWindow; SizeType32 batchIdx{0}; SizeType64 maxCopySize{0}; @@ -2065,9 +2075,9 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( { auto const reqBeamWidth = llmReq->getBeamWidthByIter(); auto const seqSlot = llmReq->mSeqSlot.value(); - auto const copySize = static_cast(cacheIndirShape.d[2]) * reqBeamWidth; - srcOffsetsPtr[batchIdx] = seqSlot * copySize; - dstOffsetsPtr[batchIdx] = seqSlot * copySize; + auto const copySize = reqBeamWidth * maxAttentionWindow; + 
srcOffsetsPtr[batchIdx] = seqSlot * slotOffset; + dstOffsetsPtr[batchIdx] = seqSlot * slotOffset; copySizesPtr[batchIdx] = copySize; maxCopySize = std::max(maxCopySize, copySize); batchIdx++; @@ -2091,8 +2101,8 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( auto const copySizesDeviceSlice = ITensor::slice(genRuntimeBuffers.mCacheIndirDecoderIOBatchedCopyCopySizesDevice, 0, batchIdx); manager.copy(sizesSlice->data(), *copySizesDeviceSlice); // Explicitly move to device for faster access. - runtime::kernels::invokeCopyBatch(*mDecoderBuffers->cacheIndirectionOutput, - *mDecoderBuffers->cacheIndirectionInput, *srcOffsetsSliceDeviceSlice, *dstOffsetsSliceDeviceSlice, + runtime::kernels::invokeCopyBatch(*mDecoderState->getCacheIndirectionOutput(), + *mDecoderState->getCacheIndirectionInput(), *srcOffsetsSliceDeviceSlice, *dstOffsetsSliceDeviceSlice, *copySizesDeviceSlice, maxCopySize, manager.getStream()); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -2111,30 +2121,34 @@ std::vector> TrtGptModelInflightBatching:: { if (broadcastPostDecoder()) { - DecoderStepAsyncSend::bcast(decoderOutputBuffers, *mDecoderBuffers, returnLogProbs, mOperatingBeamWidth, + DecoderStepAsyncSend::bcast(decoderOutputBuffers, mDecoderBuffers->draftBuffers, + mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommTensorPara, 0); } if (mWorldConfig.isPipelineParallel()) { auto const peerSend = 0; - asyncHandles.emplace_back(std::make_unique(decoderOutputBuffers, *mDecoderBuffers, - returnLogProbs, mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), - *mMpiCommPipelinePara, peerSend)); + asyncHandles.emplace_back( + std::make_unique(decoderOutputBuffers, mDecoderBuffers->draftBuffers, + mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth, + mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerSend)); } } else { auto const peerRecv = mWorldConfig.isFirstPipelineParallelRank() ? 
mWorldConfig.getPipelineParallelism() - 1 : mWorldConfig.getPipelineParallelRank() - 1; - DecoderStepAsyncSend::recv(decoderOutputBuffers, *mDecoderBuffers, returnLogProbs, mOperatingBeamWidth, + DecoderStepAsyncSend::recv(decoderOutputBuffers, mDecoderBuffers->draftBuffers, + mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerRecv); auto const peerSend = mWorldConfig.getPipelineParallelRank() + 1; if (peerSend != mWorldConfig.getPipelineParallelism() - 1) { - asyncHandles.emplace_back(std::make_unique(decoderOutputBuffers, *mDecoderBuffers, - returnLogProbs, mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), - *mMpiCommPipelinePara, peerSend)); + asyncHandles.emplace_back( + std::make_unique(decoderOutputBuffers, mDecoderBuffers->draftBuffers, + mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth, + mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerSend)); } } TLLM_CHECK_WITH_INFO(asyncHandles.size() <= 2, "Up to two decoder step async handles expected"); diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h index 831ada4f510..f7df11b911a 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h @@ -537,7 +537,6 @@ class TrtGptModelInflightBatching : public TrtGptModel std::vector mPeftTables; // Decoder input for each micro batch. std::vector> mDecodingInputs; - std::unique_ptr mDecodingOutput; /******************** Book keeping ********************/ // List of requests in each micro batch diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp index 3116002206c..3598d3422fa 100644 --- a/cpp/tensorrt_llm/executor/executorImpl.cpp +++ b/cpp/tensorrt_llm/executor/executorImpl.cpp @@ -16,7 +16,6 @@ */ #include "tensorrt_llm/executor/executorImpl.h" -#include "tensorrt_llm/batch_manager/decoderBuffers.h" #include "tensorrt_llm/batch_manager/trtEncoderModel.h" #include "tensorrt_llm/batch_manager/trtGptModelFactory.h" #include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 5191c8d1bb8..6a63b50a4bc 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -423,12 +423,10 @@ void initBindings(pybind11::module_& m) py::arg("runtime"), py::arg("model_config")); py::classh(m, "DecoderBuffers") - .def(py::init(), - py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("max_attention_window"), - py::arg("max_tokens_per_step"), py::arg("buffer_manager"), py::arg("model_config"), py::arg("world_config")) - .def_readwrite("cache_indirection_input", &tb::DecoderBuffers::cacheIndirectionInput) - .def_readwrite("cache_indirection_output", &tb::DecoderBuffers::cacheIndirectionOutput) + .def(py::init(), + py::arg("max_num_sequences"), py::arg("max_tokens_per_step"), py::arg("buffer_manager"), + py::arg("model_config"), py::arg("world_config")) .def_readwrite("draft_buffers", &tb::DecoderBuffers::draftBuffers); py::class_(m, "SlotDecoderBuffers") diff --git a/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp b/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp index 
721b12f6872..283c8fb3f9e 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp @@ -52,9 +52,6 @@ void Buffers::initBindings(pybind11::module_& m) py::arg("buffer_manager")) .def("get_buffers", &tb::TransformerBuffers::getBuffers, py::arg("input_buffers"), py::arg("output_buffers"), py::arg("model_config")) - .def("reset_cache_indirection", &tb::TransformerBuffers::resetCacheIndirection, py::arg("context_requests"), - py::arg("max_beam_width"), py::arg("max_attention_window"), py::arg("decoder_cache_indirection_input"), - py::arg("decoder_cache_indirection_output"), py::arg("runtime")) .def("copy_position_ids", &tb::TransformerBuffers::copyPositionIds, py::arg("runtime"), py::arg("position_ids_host"), py::arg("is_chat_glm"), py::arg("decoder_position_ids")) .def("copy_kv_block_offsets", &tb::TransformerBuffers::copyKvBlockOffsets, py::arg("context_requests"), diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 1839b4f1196..dc683d02be4 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -275,14 +275,9 @@ void initBindings(pybind11::module_& m) .def_readwrite("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps) .def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots) .def_readwrite("batch_slots_request_order", &tr::decoder_batch::Input::batchSlotsRequestOrder) - .def_readwrite("cache_indirection", &tr::decoder_batch::Input::cacheIndirection) .def_readwrite("generation_steps", &tr::decoder_batch::Input::generationSteps) .def_readwrite("predicted_draft_logits", &tr::decoder_batch::Input::predictedDraftLogits); - py::class_(m, "DecoderBatchOutput") - .def(py::init()) - .def_readwrite("cache_indirection", &tr::decoder_batch::Output::cacheIndirection); - py::class_(m, "LookaheadDecodingBuffers") .def(py::init(), py::arg("max_num_sequences"), py::arg("max_tokens_per_step"), py::arg("buffer_manager")) @@ -391,8 +386,7 @@ void initBindings(pybind11::module_& m) .def(py::init(), py::arg("stream")) .def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_batch_size"), py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config")) - .def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("output"), - py::arg("input")) + .def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("input")) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference) .def_property_readonly( "decoder_stream", diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index 57c90b43643..dc68126d28c 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -95,6 +95,9 @@ DecoderState::DecoderState(nvinfer1::DataType dtype, BufferManager const& buffer mBeamSearchBuffers = std::make_unique(bufferManager); + dInput->cacheIndirection = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); + dOutput->cacheIndirection = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -249,6 +252,8 @@ void DecoderState::setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeT dOutput.beamHypotheses.reshape(mMaxBatchSize, mMaxBeamWidth, mMaxSequenceLength); mBeamSearchBuffers->reshape(mMaxBeamWidth, mMaxSequenceLength); + 
setupCacheIndirection(mMaxBatchSize, mMaxBeamWidth, maxAttentionWindow); + dOutput.gatheredIds->reshape(maxTotalTokensShape); } else @@ -268,6 +273,15 @@ void DecoderState::setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeT TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +void DecoderState::setupCacheIndirection( + SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow) +{ + mJointDecodingInput->cacheIndirection->reshape( + ITensor::makeShape({maxBatchSize, maxBeamWidth, maxAttentionWindow})); + mJointDecodingOutput->cacheIndirection->reshape( + ITensor::makeShape({maxBatchSize, maxBeamWidth, maxAttentionWindow})); +} + void DecoderState::setupSpeculativeDecoding(SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager) @@ -578,6 +592,16 @@ BeamSearchBuffers const& DecoderState::getBeamSearchBuffers() const return *mBeamSearchBuffers; } +TensorPtr DecoderState::getCacheIndirectionInput() const +{ + return mJointDecodingInput->cacheIndirection; +} + +TensorPtr DecoderState::getCacheIndirectionOutput() const +{ + return mJointDecodingOutput->cacheIndirection; +} + DecodingInput& DecoderState::getJointDecodingInput() const { return *mJointDecodingInput; diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 84656ca3ebe..0cdf72f3980 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -147,8 +147,8 @@ void setEagleInputs(DecodingInput& dInput, decoder_batch::Input const& input) //! @brief Prepare Input and Output for decoder step. // TODO: produce new input and output objects -void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, decoder_batch::Output& output, - decoder_batch::Input const& input, BufferManager const& bufferManager) +void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, decoder_batch::Input const& input, + BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -160,8 +160,6 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, if (maxBeamWidth > 1) { - dInput.cacheIndirection = input.cacheIndirection; - dOutput.cacheIndirection = output.cacheIndirection; dInput.generationSteps = input.generationSteps; // For Variable-Beam-Width-Search } @@ -217,15 +215,14 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, } // namespace -void GptDecoderBatched::forwardDispatch( - decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input) +void GptDecoderBatched::forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); for (SizeType32 step = 0; step < input.maxDecoderSteps; ++step) { BufferManager manager{mDecoderStream}; - prepareForward(decoderState, step, output, input, manager); + prepareForward(decoderState, step, input, manager); if (decoderState.getJointDecodingInput().batchSize > 0) { @@ -236,8 +233,7 @@ void GptDecoderBatched::forwardDispatch( TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -CudaEvent GptDecoderBatched::forwardAsync( - decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input) +CudaEvent GptDecoderBatched::forwardAsync(decoder::DecoderState const& 
decoderState, decoder_batch::Input const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -245,7 +241,7 @@ CudaEvent GptDecoderBatched::forwardAsync( mRuntimeStream->record(eventStart); mDecoderStream->wait(eventStart.get()); - forwardDispatch(decoderState, output, input); + forwardDispatch(decoderState, input); CudaEvent event{}; mDecoderStream->record(event); @@ -257,11 +253,10 @@ CudaEvent GptDecoderBatched::forwardAsync( return eventStop; } -void GptDecoderBatched::forward( - decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input) +void GptDecoderBatched::forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto decoderFinishEvent = forwardAsync(decoderState, output, input); + auto decoderFinishEvent = forwardAsync(decoderState, input); decoderFinishEvent.synchronize(); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } diff --git a/cpp/tests/runtime/gptDecoderBatchedTest.cpp b/cpp/tests/runtime/gptDecoderBatchedTest.cpp index ebc55d8630c..f7f28d343ff 100644 --- a/cpp/tests/runtime/gptDecoderBatchedTest.cpp +++ b/cpp/tests/runtime/gptDecoderBatchedTest.cpp @@ -51,7 +51,6 @@ namespace struct DecoderInputs { std::vector logits; - TensorPtr srcCacheIndirection; }; std::shared_ptr createLlmRequest(SizeType32 batchSlot, SizeType32 inputLengths, @@ -132,9 +131,9 @@ void newRequests(std::vector> const& requests, T runtimeStream.wait(event); } -DecoderInputs createDecoderInputs(SizeType32 batchSize, SizeType32 maxBeamWidth, SizeType32 maxSeqLength, - SizeType32 vocabSizePadded, nvinfer1::DataType dataType, std::vector& samplingConfigs, - std::vector const& generatedTokensPerSteps, bool computeLogProbs, BufferManager& manager) +DecoderInputs createDecoderInputs(SizeType32 batchSize, SizeType32 vocabSizePadded, nvinfer1::DataType dataType, + std::vector& samplingConfigs, std::vector const& generatedTokensPerSteps, + bool computeLogProbs, BufferManager& manager) { DecoderInputs inputs; @@ -149,33 +148,14 @@ DecoderInputs createDecoderInputs(SizeType32 batchSize, SizeType32 maxBeamWidth, manager.setZero(*inputs.logits.back()); } - if (maxBeamWidth > 1) - { - inputs.srcCacheIndirection - = manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType::value); - manager.setZero(*inputs.srcCacheIndirection); - } - return inputs; } -decoder_batch::Output createDecoderOutputs(SizeType32 batchSize, SizeType32 maxBeamWidth, SizeType32 maxSeqLength, +void copySequenceLengths( std::vector const& tiledInputLengths, ITensor& sequenceLengths, BufferManager const& manager) { - decoder_batch::Output outputs{}; - TLLM_CHECK(sequenceLengths.getSize() == tiledInputLengths.size()); manager.copy(tiledInputLengths.data(), sequenceLengths); - - if (maxBeamWidth > 1) - { - auto tgtCacheIndirection - = manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType::value); - manager.setZero(*tgtCacheIndirection); - outputs.cacheIndirection = std::move(tgtCacheIndirection); - } - - return outputs; } [[nodiscard]] std::vector getFinished( @@ -352,10 +332,10 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa auto batchSlotsRange = BufferRange(*inputBuffers.setupBatchSlots); std::iota(batchSlotsRange.begin(), batchSlotsRange.end(), 0); - auto decoderInputs = createDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType, - samplingConfigs, generatedTokensPerSteps, computeLogProbs, manager); - 
auto outputs = createDecoderOutputs( - batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, *decoderState.getSequenceLengths(), manager); + auto decoderInputs = createDecoderInputs( + batchSize, vocabSizePadded, dataType, samplingConfigs, generatedTokensPerSteps, computeLogProbs, manager); + manager.setZero(*decoderState.getCacheIndirectionInput()); + copySequenceLengths(tiledInputLengths, *decoderState.getSequenceLengths(), manager); auto requests = createLlmRequests(inputLengths, generatedTokensPerSteps, acceptedTokensPerStep, inputTokenId, expectedTokenId, inputBuffers.setupBatchSlots, samplingConfigs, maxNewTokens, endId); @@ -379,9 +359,9 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa auto activeSlots = std::vector(batchSize); std::iota(activeSlots.begin(), activeSlots.end(), 0); - auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(activeSlots, decoderState, - decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots, decoderInputs.srcCacheIndirection); - decoder.forward(decoderState, outputs, *inputs); + auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs( + activeSlots, decoderState, decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots); + decoder.forward(decoderState, *inputs); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(false)); @@ -392,14 +372,14 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa // run decoder for 1 step advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), batchSize, maxBeamWidth); - decoder.forward(decoderState, outputs, *inputs); + decoder.forward(decoderState, *inputs); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(true)); verifyResults(manager, decoderState, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth, maxSeqLength, inputTokenId, expectedTokenId, endId); - EXPECT_NO_THROW(decoder.forward(decoderState, outputs, *inputs)); + EXPECT_NO_THROW(decoder.forward(decoderState, *inputs)); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); TensorPtr batchSlotsView = ITensor::slice(inputBuffers.setupBatchSlots, 0, 1); @@ -490,10 +470,10 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector const expectedSteps(batchSize, 0); auto expectedLengths = tiledInputLengths; @@ -515,9 +495,9 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector(batchIdx + 1); std::iota(activeSlots.begin(), activeSlots.end(), 0); - auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(activeSlots, decoderState, - decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots, decoderInputs.srcCacheIndirection); - decoder.forward(decoderState, outputs, *inputs); + auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs( + activeSlots, decoderState, decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots); + decoder.forward(decoderState, *inputs); advanceSequenceLengths( expectedLengths, acceptedTokensPerStep, samplingConfigs, expectedFinished, batchIdx + 1, maxBeamWidth); @@ -538,9 +518,9 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, 
std::vector(*inputBuffers.setupBatchSlots); std::iota(batchSlotsRange.begin(), batchSlotsRange.end(), 0); @@ -675,9 +655,9 @@ void testDecoderDraft(nvinfer1::DataType const dtype, std::vector(batchSize); std::iota(activeSlots.begin(), activeSlots.end(), 0); - auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(activeSlots, decoderState, - decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots, decoderInputs.srcCacheIndirection); - decoder.forward(decoderState, outputs, *inputs); + auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs( + activeSlots, decoderState, decoderInputs.logits, batchSize, inputBuffers.forwardBatchSlots); + decoder.forward(decoderState, *inputs); checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager); EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(false)); diff --git a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py index 8a241693675..d2b94905db7 100644 --- a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py +++ b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py @@ -1,11 +1,10 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List import torch from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest -from tensorrt_llm.bindings.internal.runtime import (DecoderBatchInput, - DecoderBatchOutput) +from tensorrt_llm.bindings.internal.runtime import DecoderBatchInput @dataclass @@ -18,12 +17,12 @@ class MakeDecodingBatchInputOutput: @staticmethod def create_decoder_batch_inputs( + *, active_slots: List[int], decoder_state, logits: List[torch.Tensor], max_num_sequences: int, batch_slots: List[torch.Tensor], - cache_indirection_input: Optional[torch.Tensor] = None ) -> DecoderBatchInput: """Create decoder batch inputs from active slots and logits. @@ -33,7 +32,6 @@ def create_decoder_batch_inputs( logits: List of logit tensors for each slot max_num_sequences: Maximum number of sequences to process batch_slots: List of batch slot tensors for each decoding step - cache_indirection_input: Optional cache indirection input tensor Returns: DecoderBatchInput containing the prepared inputs @@ -83,16 +81,19 @@ def create_decoder_batch_inputs( # Create decoder batch input decoding_input = DecoderBatchInput(logits_vec, max_active_decoder_steps) decoding_input.batch_slots = batch_slots - decoding_input.cache_indirection = cache_indirection_input return decoding_input def __call__( - self, context_requests: List[LlmRequest], - generation_requests: List[LlmRequest], decoder_buffers, - decoder_input_buffers, decoder_state, model_config, - max_num_sequences: int - ) -> Tuple[DecoderBatchInput, DecoderBatchOutput]: + self, + context_requests: List[LlmRequest], + generation_requests: List[LlmRequest], + decoder_buffers, + decoder_input_buffers, + decoder_state, + model_config, + max_num_sequences: int, + ) -> DecoderBatchInput: """Create decoder batch inputs and outputs for the given requests. 
Args: @@ -106,7 +107,7 @@ def __call__( fused_runtime_buffers: Optional fused runtime buffers Returns: - Tuple of (DecoderBatchInput, DecoderBatchOutput) + DecoderBatchInput """ # Get active slots and generation steps active_slots = [] @@ -126,9 +127,12 @@ def __call__( # Create decoder batch inputs decoding_input = self.create_decoder_batch_inputs( - active_slots, decoder_state, decoder_input_buffers.logits, - max_num_sequences, decoder_input_buffers.forward_batch_slots, - decoder_buffers.cache_indirection_input) + active_slots=active_slots, + decoder_state=decoder_state, + logits=decoder_input_buffers.logits, + max_num_sequences=max_num_sequences, + batch_slots=decoder_input_buffers.forward_batch_slots, + ) decoding_input.generation_steps = generation_steps # Handle speculative decoding modes @@ -149,8 +153,4 @@ def __call__( # decoding_input.eagle_inputs = fused_runtime_buffers.eagle_buffers.engine_outputs # decoding_input.eagle_last_inputs = fused_runtime_buffers.eagle_buffers.engine_inputs - # Create decoder batch output - decoding_output = DecoderBatchOutput() - decoding_output.cache_indirection = decoder_buffers.cache_indirection_output - - return decoding_input, decoding_output + return decoding_input diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 4106c5976b4..2c116831964 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -524,9 +524,7 @@ def _initialize_store(self): "buffer_manager": buffer_manager, "decoder_buffers": - DecoderBuffers(self.max_num_sequences, - self.executor_config.max_beam_width, - self.max_attention_window, self.MAX_DECODING_TOKENS, + DecoderBuffers(self.max_num_sequences, self.MAX_DECODING_TOKENS, buffer_manager, self.model_config, self.world_config), "decoder_input_buffers": @@ -626,7 +624,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests, self.store["decoder_input_buffers"].logits = decoder_buffer_logits - decoding_input, self.decoding_output = self.algs.make_decoding_batch_input_output( + decoding_input = self.algs.make_decoding_batch_input_output( scheduled_requests.context_requests, scheduled_requests.generation_requests, self.store["decoder_buffers"], self.store["decoder_input_buffers"], @@ -634,7 +632,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests, self.max_num_sequences) self.algs.decoder.forward_async(self.store["decoder_state"], - self.decoding_output, decoding_input) + decoding_input) # NOTE: The following code prepares a new_tokens_device_tensor in accordance with the # current implementation of model_engine.
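
For orientation, a minimal sketch of how a caller drives one decoding step against the refactored interfaces in this patch. The names `DecoderState::getCacheIndirectionOutput`, `MakeDecodingBatchInputOutput::createDecoderBatchInputs`, `GptDecoderBatched::forward`, and the `DecoderInputBuffers` members come from the patch itself; the helper function, its parameter list, and the container template arguments are assumptions for illustration, not part of the change.

```cpp
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"

#include <vector>

namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;

// Hypothetical helper: run a single decoding step with the post-refactor API.
void decodeOneStep(tr::GptDecoderBatched& decoder, tr::decoder::DecoderState& decoderState,
    tb::DecoderInputBuffers const& inputBuffers, std::vector<tr::SizeType32> const& activeSlots,
    tr::SizeType32 maxNumSequences)
{
    // Cache indirection now lives in DecoderState, so decoder_batch::Input no longer
    // carries it and the former decoder_batch::Output class is removed entirely.
    auto inputs = tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(
        activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots);

    // forward()/forwardAsync() take only the state and the input.
    decoder.forward(decoderState, *inputs);

    // Beam-search consumers read the indirection tensors directly from the state.
    auto const cacheIndirection = decoderState.getCacheIndirectionOutput();
    (void) cacheIndirection;
}
```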
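
Pipeline-parallel ranks that do not instantiate a decoder still need the cache indirection tensors for beam search; the patch covers this via the new `setupCacheIndirection`. A sketch of that path, mirroring the `createDecoder` fallback above (the helper name, the dummy dtype choice, and the extra include are assumptions):

```cpp
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/decoderState.h"

#include <memory>

namespace tr = tensorrt_llm::runtime;

// Hypothetical helper: allocate only the beam-search cache indirection on a rank
// that runs no decoder.
std::unique_ptr<tr::decoder::DecoderState> makeCacheIndirectionOnlyState(tr::BufferManager const& bufferManager,
    tr::SizeType32 maxNumSequences, tr::SizeType32 maxBeamWidth, tr::SizeType32 maxAttentionWindow)
{
    auto constexpr dummyType = nvinfer1::DataType::kFLOAT; // dtype is unused on this path (assumption)
    auto state = std::make_unique<tr::decoder::DecoderState>(dummyType, bufferManager);
    // Reshapes both indirection tensors to [maxNumSequences, maxBeamWidth, maxAttentionWindow].
    state->setupCacheIndirection(maxNumSequences, maxBeamWidth, maxAttentionWindow);
    return state;
}
```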
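
The reworked `copyCacheIndirectionFromOutputsToInputs` separates the per-slot stride (`maxBeamWidth * maxAttentionWindow`) from the per-request copy size (`reqBeamWidth * maxAttentionWindow`), so a request whose beam width is smaller than the maximum still lands at its slot's base offset. The arithmetic in isolation, for a tensor laid out as [maxNumSequences, maxBeamWidth, maxAttentionWindow] (struct and function names are hypothetical):

```cpp
#include <cstdint>

// Offsets and size for the batched cache-indirection copy of one request.
struct CacheIndirCopyArgs  // hypothetical helper, not part of the patch
{
    std::int64_t srcOffset;
    std::int64_t dstOffset;
    std::int64_t copySize;
};

CacheIndirCopyArgs makeCopyArgs(
    std::int64_t seqSlot, std::int64_t reqBeamWidth, std::int64_t maxBeamWidth, std::int64_t maxAttentionWindow)
{
    auto const slotOffset = maxBeamWidth * maxAttentionWindow; // stride between sequence slots
    auto const copySize = reqBeamWidth * maxAttentionWindow;   // only the beams this request uses
    return {seqSlot * slotOffset, seqSlot * slotOffset, copySize};
}
```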