diff --git a/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h b/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h
index 5f3fc5739c4..926b1349d07 100644
--- a/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h
+++ b/cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h
@@ -36,18 +36,28 @@ class DecoderInputBuffers
     using SizeType32 = runtime::SizeType32;
     using TensorPtr = runtime::ITensor::SharedPtr;
 
-    explicit DecoderInputBuffers(
-        SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, runtime::BufferManager const& manager);
+    explicit DecoderInputBuffers(SizeType32 maxNumSequences, SizeType32 maxBatchSize, SizeType32 maxDecoderSteps,
+        runtime::BufferManager const& manager);
 
-    // buffers for setup
+    //! Buffers for decoder setup
+
+    //! Input IDs of new requests, [maxBatchSize]
     TensorPtr inputsIds;
+    //! Batch slots for setup step, [maxBatchSize]
     TensorPtr setupBatchSlots;
     TensorPtr setupBatchSlotsDevice;
+    //! Helper buffer for copying sequence lengths, [maxBatchSize]
     TensorPtr fillValues;
     TensorPtr fillValuesDevice;
 
-    // buffers for forward
+    //! Buffers for decoder forward
+
+    //! Batch slots for all decoder steps, [maxDecoderSteps][maxBatchSize]
     std::vector<TensorPtr> forwardBatchSlots;
+
+    //! Logits for all batch slots, [maxNumSequences]
+    //! The vector is sparse, only slots in forwardBatchSlots are used.
+    std::vector<TensorPtr> logits;
 };
 
 class DecoderOutputBuffers
@@ -70,35 +80,36 @@ class DecoderOutputBuffers
     TensorPtr finishReasonsHost; // [mMaxNumRequests, beamWidth], pinned host tensor
 };
 
-class DecoderBuffers
+class DraftBuffers
 {
 public:
     using SizeType32 = runtime::SizeType32;
     using TensorPtr = runtime::ITensor::SharedPtr;
 
-    std::vector<TensorPtr> logits;
+    TensorPtr nextDraftTokensDevice;        // [mMaxNumRequests, maxTokensPerStep-1]
+    TensorPtr nextDraftTokensHost;          // [mMaxNumRequests, maxTokensPerStep-1]
+    TensorPtr prevDraftTokensLengthsDevice; // [mMaxNumRequests]
+    TensorPtr prevDraftTokensLengthsHost;   // [mMaxNumRequests]
+    TensorPtr nextDraftTokensLengthsDevice; // [mMaxNumRequests]
+    TensorPtr nextDraftTokensLengthsHost;   // [mMaxNumRequests]
+    TensorPtr acceptedLengthsCumSumDevice;  // [mMaxNumRequests+1]
+    TensorPtr acceptedPackedPathsDevice;    // [mMaxNumRequests * maxAcceptedTokens]
+    std::vector<std::vector<TensorPtr>>
+        predictedDraftLogits; // [mMaxNumRequests][mMaxNumHeads][maxDraftTokens + 1, vocabSize]
+
+    void create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager,
+        runtime::ModelConfig const& modelConfig);
+};
+
+class DecoderBuffers
+{
+public:
+    using SizeType32 = runtime::SizeType32;
+    using TensorPtr = runtime::ITensor::SharedPtr;
 
     TensorPtr cacheIndirectionInput;
     TensorPtr cacheIndirectionOutput;
 
-    class DraftBuffers
-    {
-    public:
-        TensorPtr nextDraftTokensDevice;        // [mMaxNumRequests, maxTokensPerStep-1]
-        TensorPtr nextDraftTokensHost;          // [mMaxNumRequests, maxTokensPerStep-1]
-        TensorPtr prevDraftTokensLengthsDevice; // [mMaxNumRequests]
-        TensorPtr prevDraftTokensLengthsHost;   // [mMaxNumRequests]
-        TensorPtr nextDraftTokensLengthsDevice; // [mMaxNumRequests]
-        TensorPtr nextDraftTokensLengthsHost;   // [mMaxNumRequests]
-        TensorPtr acceptedLengthsCumSumDevice;  // [mMaxNumRequests+1]
-        TensorPtr acceptedPackedPathsDevice;    // [mMaxNumRequests * maxAcceptedTokens]
-        std::vector<std::vector<TensorPtr>>
-            predictedDraftLogits; // [mMaxNumRequests][mMaxNumHeads][maxDraftTokens + 1, vocabSize]
-
-        void create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager,
-            runtime::ModelConfig const& modelConfig);
-    };
-
     DraftBuffers draftBuffers;
 
     DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
diff --git a/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h
index cf5235a0570..5e35346da04 100644
--- a/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h
+++ b/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h
@@ -31,8 +31,8 @@ class CudaStream;
 
 namespace tensorrt_llm::batch_manager
 {
-class RuntimeBuffers;
-class DecoderBuffers;
+class DecoderInputBuffers;
+class DraftBuffers;
 class MedusaBuffers;
 
 namespace tr = tensorrt_llm::runtime;
@@ -47,10 +47,10 @@ class HandleContextLogits : Algorithm
 
     HandleContextLogits() = default;
 
-    tr::SizeType32 operator()(RequestVector const& contextRequests,
-        std::vector<tr::SizeType32> const& numContextLogitsVec, tr::ITensor::SharedPtr const& logits,
-        DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-        tensorrt_llm::runtime::CudaStream const& stream, OptionalRef<MedusaBuffers const> medusaBuffers) const;
+    tr::SizeType32 operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
+        tr::ITensor::SharedPtr const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
+        tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, OptionalRef<DraftBuffers> draftBuffers,
+        OptionalRef<MedusaBuffers const> medusaBuffers) const;
 };
 
 } // namespace tensorrt_llm::batch_manager
diff --git a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h
index 64a4ed6892e..33f32006527 100644
--- a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h
+++ b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h
@@ -30,8 +30,9 @@ class BufferManager;
 
 namespace tensorrt_llm::batch_manager
 {
+class DecoderInputBuffers;
+class DraftBuffers;
 class RuntimeBuffers;
-class DecoderBuffers;
 
 namespace tr = tensorrt_llm::runtime;
 
@@ -45,9 +46,10 @@ class HandleGenerationLogits : Algorithm
 
     HandleGenerationLogits() = default;
 
-    void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers,
-        tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, tr::ITensor::SharedPtr const& logits,
-        OptionalRef<RuntimeBuffers> genRuntimeBuffers) const;
+    void operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests,
+        tr::ITensor::SharedPtr const& logits, tr::SizeType32 logitsIndex, tr::ModelConfig const& modelConfig,
+        tr::BufferManager const& manager, OptionalRef<RuntimeBuffers> genRuntimeBuffers,
+        OptionalRef<DraftBuffers> draftBuffers) const;
 };
 
 } // namespace tensorrt_llm::batch_manager
diff --git a/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h b/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h
index a4a8cc035cd..9610b96763b 100644
--- a/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h
+++ b/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h
@@ -30,10 +30,6 @@ class TllmRuntime;
 
 namespace tensorrt_llm::batch_manager
 {
-class DecoderBuffers;
-
-namespace tr = tensorrt_llm::runtime;
-
 class LogitsPostProcessor : Algorithm
 {
 public:
@@ -48,8 +44,8 @@ class LogitsPostProcessor : Algorithm
     LogitsPostProcessor() = default;
 
     bool operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
-        bool replicateLogitsPostProcessor, DecoderBuffers& decoderBuffers, tr::WorldConfig const& worldConfig,
-        tr::TllmRuntime& runtime,
+        bool replicateLogitsPostProcessor, std::vector<runtime::ITensor::SharedPtr>& seqSlotLogits,
+        runtime::WorldConfig const& worldConfig, runtime::TllmRuntime& runtime,
         std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
 };
diff --git a/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp b/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp
index 682ea3ca786..718991b83fe 100644
--- a/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp
+++ b/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp
@@ -30,7 +30,7 @@ namespace tensorrt_llm::batch_manager
 {
 
 DecoderInputBuffers::DecoderInputBuffers(
-    SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
+    SizeType32 maxNumSequences, SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
 {
     auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
     auto const nvSizeType = TRTDataType<SizeType32>::value;
@@ -48,6 +48,8 @@ DecoderInputBuffers::DecoderInputBuffers(
     {
         forwardBatchSlots.emplace_back(BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize}), nvSizeType));
     }
+
+    logits.resize(maxNumSequences);
 }
 
 DecoderOutputBuffers::DecoderOutputBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxSeqLen,
@@ -91,11 +93,6 @@ DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWid
     SizeType32 maxTokensPerStep, BufferManager const& manager, ModelConfig const& modelConfig,
     WorldConfig const& worldConfig)
 {
-    if (worldConfig.isLastPipelineParallelRank())
-    {
-        logits.resize(maxNumSequences);
-    }
-
     cacheIndirectionInput = manager.gpu(
         ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32);
     cacheIndirectionOutput = manager.gpu(
@@ -109,8 +106,8 @@ DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWid
     }
 }
 
-void DecoderBuffers::DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep,
-    BufferManager const& manager, ModelConfig const& modelConfig)
+void DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, BufferManager const& manager,
+    ModelConfig const& modelConfig)
 {
     auto const speculativeDecodingMode = modelConfig.getSpeculativeDecodingMode();
diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp
index dd145cd1b62..fc01214bb0d 100644
--- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp
+++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp
@@ -67,9 +67,9 @@ void setupMedusaLogits(std::vector<TensorPtr>& medusaLogitsHeads, TensorPtr cons
 
 } // namespace
 
-SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
-    std::vector<SizeType32> const& numContextLogitsVec, TensorPtr const& logits, DecoderBuffers& decoderBuffers,
-    tr::ModelConfig const& modelConfig, BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream,
+SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
+    tr::ITensor::SharedPtr const& logits, std::vector<SizeType32> const& numContextLogitsVec,
+    tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, OptionalRef<DraftBuffers> draftBuffers,
     OptionalRef<MedusaBuffers const> medusaBuffers) const
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -114,13 +114,14 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
         // Get the logits from the last context token and draft tokens
         auto const numDecoderLogits = 1 + draftLength;
         auto const seqSlot = llmReq->mSeqSlot.value();
-        auto& decoderLogits = decoderBuffers.logits.at(seqSlot);
+        auto& decoderLogits = inputBuffers.logits.at(seqSlot);
         TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits);
 
         if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
         {
+            TLLM_CHECK(draftBuffers);
+            auto& medusaLogitsHeads = draftBuffers->predictedDraftLogits.at(seqSlot);
             TLLM_CHECK(medusaBuffers);
-            auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot);
             setupMedusaLogits(medusaLogitsHeads, medusaBuffers->medusaLogitsDevice,
                 modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex - numDecoderLogits,
                 numDecoderLogits);
@@ -143,7 +144,7 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
             auto const logitsShape = logitsView->getShape();
             auto const logitsType = logitsView->getDataType();
             decoderLogits = manager.gpu(ITensor::makeShape({reqBeamWidth, logitsShape.d[1]}), logitsType);
-            tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, stream);
+            tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, manager.getStream());
             decoderLogits->unsqueeze(0);
         }
         else
diff --git a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp
index 5b0c28bc79f..9c2543b08e2 100644
--- a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp
+++ b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp
@@ -73,9 +73,10 @@ void setupMedusaLogits(std::vector<TensorPtr>& medusaLogitsHeads, TensorPtr cons
 
 } // namespace
 
-void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector const& generationRequests,
-    DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, BufferManager const& manager,
-    TensorPtr const& logits, OptionalRef<RuntimeBuffers> genRuntimeBuffers) const
+void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests,
+    tr::ITensor::SharedPtr const& logits, tr::SizeType32 logitsIndex, tr::ModelConfig const& modelConfig,
+    tr::BufferManager const& manager, OptionalRef<RuntimeBuffers> genRuntimeBuffers,
+    OptionalRef<DraftBuffers> draftBuffers) const
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(HandleGenerationLogits);
@@ -99,7 +100,7 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co
         TensorPtr logitsView = ITensor::slice(logits, logitsIndex, numLogits);
         TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid(*logitsView, manager, "logits") == false,
             "Found invalid number (NaN or Inf) in logits");
-        auto& decoderLogits = decoderBuffers.logits.at(seqSlot);
+        auto& decoderLogits = inputBuffers.logits.at(seqSlot);
         auto const logitsViewShape = logitsView->getShape();
         if (reqBeamWidth > 1)
         {
@@ -136,8 +137,10 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co
         }
         if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
         {
+            TLLM_CHECK(draftBuffers);
+            auto& medusaLogitsHeads = draftBuffers->predictedDraftLogits.at(seqSlot);
             TLLM_CHECK(genRuntimeBuffers);
-            auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot);
+            TLLM_CHECK(genRuntimeBuffers->mMedusaBuffers);
             setupMedusaLogits(medusaLogitsHeads, genRuntimeBuffers->mMedusaBuffers->medusaLogitsDevice,
                 modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex, draftLength);
         }
diff --git a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp
index 888de98e7bf..640f7ebf220 100644
--- a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp
+++ b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp
@@ -25,6 +25,8 @@
 #include "tensorrt_llm/runtime/tllmRuntime.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
 
+namespace tr = tensorrt_llm::runtime;
+
 namespace tensorrt_llm::batch_manager
 {
 
@@ -34,7 +36,7 @@ using ITensor = runtime::ITensor;
 using SizeType32 = tensorrt_llm::runtime::SizeType32;
 
 bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
-    bool replicateLogitsPostProcessor, DecoderBuffers& decoderBuffers, tr::WorldConfig const& worldConfig,
+    bool replicateLogitsPostProcessor, std::vector<tr::ITensor::SharedPtr>& seqSlotLogits, tr::WorldConfig const& worldConfig,
     tr::TllmRuntime& runtime, std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -59,7 +61,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque
             logitsPostProcessorIsApplied = true;
             if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
             {
-                auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlot.value());
+                auto& logits = seqSlotLogits.at(llmReq->mSeqSlot.value());
                 (*llmReq->mLogitsPostProcessor)(
                     llmReq->mRequestId, logits, llmReq->getTokens(), runtime.getStreamPtr(), llmReq->mClientId);
             }
@@ -68,7 +70,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque
             {
                 reqIdsVec.push_back(llmReq->mRequestId);
 
-                auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlot.value());
+                auto& logits = seqSlotLogits.at(llmReq->mSeqSlot.value());
                 logitsVec.push_back(logits);
 
                 beamTokensVec.emplace_back(llmReq->getTokens());
diff --git a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
index db08c579c9e..ba22802e106 100644
--- a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
+++ b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
@@ -132,7 +132,7 @@ MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, R
 
     auto [activeSlots, generationSteps] = getActiveSlots(contextRequests, generationRequests);
 
-    auto decodingInput = createDecoderBatchInputs(activeSlots, decoderState, decoderBuffers.logits, maxNumSequences,
+    auto decodingInput = createDecoderBatchInputs(activeSlots, decoderState, inputBuffers.logits, maxNumSequences,
         inputBuffers.forwardBatchSlots, decoderBuffers.cacheIndirectionInput);
     decodingInput->generationSteps = generationSteps;
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index b84ef6c48c2..b8855af568d 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -1462,7 +1462,7 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const&
     for (SizeType32 i = 0; i < mNumMicroBatches; ++i)
     {
         mDecoderInputBuffers.emplace_back(
-            getMaxBatchSize(), mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
+            getMaxNumSequences(), getMaxBatchSize(), mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
         mDecoderOutputBuffers.emplace_back(getMaxNumSequences(), mOperatingBeamWidth, getMaxSequenceLen(),
             mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
     }
@@ -1995,17 +1995,20 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(decoderStepAsync);
 
+    auto& decoderInputBuffers = mDecoderInputBuffers.at(getFusedBufferId());
+    auto& seqSlotLogits = decoderInputBuffers.logits;
+
     auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
     auto& contextRuntimeBuffers = mBuffers.at(contextBufferId);
-    auto const logitsIndex = (*mHandleContextLogits)(scheduledRequests.contextRequests,
-        contextRuntimeBuffers->numContextLogits, contextRuntimeBuffers->logits, *mDecoderBuffers, mModelConfig,
-        mRuntime->getBufferManager(), mRuntime->getStream(), contextRuntimeBuffers->mMedusaBuffers);
+    auto const logitsIndex = (*mHandleContextLogits)(decoderInputBuffers, scheduledRequests.contextRequests,
+        contextRuntimeBuffers->logits, contextRuntimeBuffers->numContextLogits, mModelConfig,
+        mRuntime->getBufferManager(), mDecoderBuffers->draftBuffers, contextRuntimeBuffers->mMedusaBuffers);
 
     auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0;
     auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId();
     auto& genRuntimeBuffers = mBuffers.at(genBufferId);
-    (*mHandleGenerationLogits)(genLogitsIndex, scheduledRequests.generationRequests, *mDecoderBuffers, mModelConfig,
-        mRuntime->getBufferManager(), genRuntimeBuffers->logits, *genRuntimeBuffers);
+    (*mHandleGenerationLogits)(decoderInputBuffers, scheduledRequests.generationRequests, genRuntimeBuffers->logits,
+        genLogitsIndex, mModelConfig, mRuntime->getBufferManager(), *genRuntimeBuffers, mDecoderBuffers->draftBuffers);
 
     // Copy indirection output into input
     // TODO: Could we avoid this by modifying batchDecoder to take a vector of tensors instead?
@@ -2013,11 +2016,11 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques
 
     mLogitsPostProcessorIsApplied
         = (*mLogitsPostProcessor)(scheduledRequests.contextRequests, scheduledRequests.generationRequests,
-            mReplicateLogitsPostProcessor, *mDecoderBuffers, mWorldConfig, *mRuntime, mLogitsPostProcessorBatched);
+            mReplicateLogitsPostProcessor, seqSlotLogits, mWorldConfig, *mRuntime, mLogitsPostProcessorBatched);
 
     if (mGuidedDecoder)
     {
-        mGuidedDecoder->execute(scheduledRequests, mRuntime->getBufferManager(), mDecoderBuffers->logits);
+        mGuidedDecoder->execute(scheduledRequests, mRuntime->getBufferManager(), seqSlotLogits);
    }
 
     auto const fusedBufferId = getFusedBufferId();
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
index 9d4bc670e2a..81df3ef584d 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
@@ -100,34 +100,36 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
         .def(py::init<>())
         .def(
             "__call__",
-            [](HandleContextLogits const& self, RequestVector const& contextRequests,
-                std::vector<tr::SizeType32> const& numContextLogitsVec, at::Tensor const& logits,
-                DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-                tensorrt_llm::runtime::CudaStream const& stream,
-                OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
+            [](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
+                at::Tensor const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
+                tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
+                OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt,
+                OptionalRef<DraftBuffers> draftBuffers = std::nullopt)
             {
-                return self(contextRequests, numContextLogitsVec, tr::TorchView::of(logits), decoderBuffers,
-                    modelConfig, manager, stream, medusaBuffers);
+                return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig,
+                    manager, draftBuffers, medusaBuffers);
             },
-            py::arg("context_requests"), py::arg("num_context_logits"), py::arg("logits"), py::arg("decoder_buffers"),
-            py::arg("model_config"), py::arg("buffer_manager"), py::arg("stream"),
-            py::arg("medusa_buffers") = std::nullopt)
+            py::arg("decoder_input_buffers"), py::arg("context_requests"), py::arg("logits"),
+            py::arg("num_context_logits"), py::arg("model_config"), py::arg("buffer_manager"),
+            py::arg("draft_buffers") = std::nullopt, py::arg("medusa_buffers") = std::nullopt)
         .def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; });
 
     py::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
        .def(py::init<>())
        .def(
            "__call__",
-            [](HandleGenerationLogits const& self, tr::SizeType32 logitsIndex, RequestVector const& generationRequests,
-                DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-                at::Tensor const& logits, OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt)
+            [](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers,
+                RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex,
+                tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
+                OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt,
+                OptionalRef<DraftBuffers> draftBuffers = std::nullopt)
             {
-                self(logitsIndex, generationRequests, decoderBuffers, modelConfig, manager, tr::TorchView::of(logits),
-                    genRuntimeBuffers);
+                self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager,
+                    genRuntimeBuffers, draftBuffers);
             },
-            py::arg("logits_index"), py::arg("generation_requests"), py::arg("decoder_buffers"),
-            py::arg("model_config"), py::arg("buffer_manager"), py::arg("logits"),
-            py::arg("gen_runtime_buffers") = std::nullopt)
+            py::arg("decoder_input_buffers"), py::arg("generation_requests"), py::arg("logits"),
+            py::arg("logits_index"), py::arg("model_config"), py::arg("buffer_manager"),
+            py::arg("gen_runtime_buffers") = std::nullopt, py::arg("draft_buffers") = std::nullopt)
         .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; });
 
     py::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
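
The rebound algorithms are keyword-addressable from Python in the order above. A minimal, illustrative sketch of the new call shape; the algorithm instances and all buffer/config objects are assumed to be created elsewhere, and the keyword names are taken from the `py::arg(...)` declarations:

```python
def run_logits_handling(handle_context_logits, handle_generation_logits,
                        decoder_input_buffers, context_requests,
                        generation_requests, logits, num_context_logits,
                        model_config, buffer_manager):
    # HandleContextLogits now takes the decoder input buffers first and
    # returns the running index into the logits tensor.
    logits_index = handle_context_logits(
        decoder_input_buffers=decoder_input_buffers,
        context_requests=context_requests,
        logits=logits,
        num_context_logits=num_context_logits,
        model_config=model_config,
        buffer_manager=buffer_manager)
    # HandleGenerationLogits consumes that index; draft_buffers and
    # gen_runtime_buffers are left at their std::nullopt defaults here.
    handle_generation_logits(
        decoder_input_buffers=decoder_input_buffers,
        generation_requests=generation_requests,
        logits=logits,
        logits_index=logits_index,
        model_config=model_config,
        buffer_manager=buffer_manager)
    return logits_index
```
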
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 35f32a3b128..b9eefc19a5b 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -377,14 +377,16 @@ void initBindings(pybind11::module_& m)
             py::arg("max_num_sequences"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"));
 
     py::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
-        .def(py::init<SizeType32, SizeType32, tr::BufferManager const&>(), py::arg("max_batch_size"),
-            py::arg("max_tokens_per_engine_step"), py::arg("manager"))
+        .def(py::init<SizeType32, SizeType32, SizeType32, tr::BufferManager const&>(),
+            py::arg("max_num_sequences"), py::arg("max_batch_size"), py::arg("max_tokens_per_engine_step"),
+            py::arg("manager"))
         .def_readwrite("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
         .def_readwrite("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
         .def_readwrite("fill_values", &tb::DecoderInputBuffers::fillValues)
         .def_readwrite("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
         .def_readwrite("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
-        .def_readwrite("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots);
+        .def_readwrite("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
+        .def_readwrite("logits", &tb::DecoderInputBuffers::logits);
 
     py::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
         .def_readwrite("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)
@@ -395,29 +397,25 @@ void initBindings(pybind11::module_& m)
         .def_readwrite("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost)
         .def_readwrite("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost);
 
-    py::class_<tb::DecoderBuffers::DraftBuffers>(m, "DraftBuffers")
+    py::class_<tb::DraftBuffers>(m, "DraftBuffers")
         .def(py::init<>())
-        .def_readwrite("next_draft_tokens_device", &tb::DecoderBuffers::DraftBuffers::nextDraftTokensDevice)
-        .def_readwrite("next_draft_tokens_host", &tb::DecoderBuffers::DraftBuffers::nextDraftTokensHost)
-        .def_readwrite(
-            "prev_draft_tokens_lengths_device", &tb::DecoderBuffers::DraftBuffers::prevDraftTokensLengthsDevice)
-        .def_readwrite("prev_draft_tokens_lengths_host", &tb::DecoderBuffers::DraftBuffers::prevDraftTokensLengthsHost)
-        .def_readwrite(
-            "next_draft_tokens_lengths_device", &tb::DecoderBuffers::DraftBuffers::nextDraftTokensLengthsDevice)
-        .def_readwrite("next_draft_tokens_lengths_host", &tb::DecoderBuffers::DraftBuffers::nextDraftTokensLengthsHost)
-        .def_readwrite(
-            "accepted_lengths_cum_sum_device", &tb::DecoderBuffers::DraftBuffers::acceptedLengthsCumSumDevice)
-        .def_readwrite("accepted_packed_paths_device", &tb::DecoderBuffers::DraftBuffers::acceptedPackedPathsDevice)
-        .def_readwrite("predicted_draft_logits", &tb::DecoderBuffers::DraftBuffers::predictedDraftLogits)
-        .def("create", &tb::DecoderBuffers::DraftBuffers::create, py::arg("max_num_sequences"),
-            py::arg("max_tokens_per_step"), py::arg("runtime"), py::arg("model_config"));
+        .def_readwrite("next_draft_tokens_device", &tb::DraftBuffers::nextDraftTokensDevice)
+        .def_readwrite("next_draft_tokens_host", &tb::DraftBuffers::nextDraftTokensHost)
+        .def_readwrite("prev_draft_tokens_lengths_device", &tb::DraftBuffers::prevDraftTokensLengthsDevice)
+        .def_readwrite("prev_draft_tokens_lengths_host", &tb::DraftBuffers::prevDraftTokensLengthsHost)
+        .def_readwrite("next_draft_tokens_lengths_device", &tb::DraftBuffers::nextDraftTokensLengthsDevice)
+        .def_readwrite("next_draft_tokens_lengths_host", &tb::DraftBuffers::nextDraftTokensLengthsHost)
+        .def_readwrite("accepted_lengths_cum_sum_device", &tb::DraftBuffers::acceptedLengthsCumSumDevice)
+        .def_readwrite("accepted_packed_paths_device", &tb::DraftBuffers::acceptedPackedPathsDevice)
+        .def_readwrite("predicted_draft_logits", &tb::DraftBuffers::predictedDraftLogits)
+        .def("create", &tb::DraftBuffers::create, py::arg("max_num_sequences"), py::arg("max_tokens_per_step"),
+            py::arg("runtime"), py::arg("model_config"));
 
     py::classh<tb::DecoderBuffers>(m, "DecoderBuffers")
         .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, tr::BufferManager const&,
                  tr::ModelConfig const&, tr::WorldConfig const&>(),
             py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("max_attention_window"),
             py::arg("max_tokens_per_step"), py::arg("buffer_manager"), py::arg("model_config"),
             py::arg("world_config"))
-        .def_readwrite("logits", &tb::DecoderBuffers::logits)
         .def_readwrite("cache_indirection_input", &tb::DecoderBuffers::cacheIndirectionInput)
         .def_readwrite("cache_indirection_output", &tb::DecoderBuffers::cacheIndirectionOutput)
         .def_readwrite("draft_buffers", &tb::DecoderBuffers::draftBuffers);
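
On the Python side, `DecoderInputBuffers` now takes `max_num_sequences` as its first constructor argument and exposes the per-slot logits list. A small sketch, assuming the class is importable from `tensorrt_llm.bindings.internal.batch_manager` (the module the PyTorch executor imports `DecoderBuffers` from elsewhere in this change) and taking the buffer manager as a parameter rather than constructing it here:

```python
from tensorrt_llm.bindings.internal.batch_manager import \
    DecoderInputBuffers  # assumed module path


def make_decoder_input_buffers(max_num_sequences, max_batch_size,
                               max_tokens_per_engine_step, buffer_manager):
    # max_num_sequences is the new leading constructor argument.
    buffers = DecoderInputBuffers(max_num_sequences, max_batch_size,
                                  max_tokens_per_engine_step, buffer_manager)
    # The constructor resizes the per-slot logits vector to max_num_sequences
    # (see decoderBuffers.cpp above); it is exposed via the new "logits" field.
    return buffers
```
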
diff --git a/cpp/tests/runtime/gptDecoderBatchedTest.cpp b/cpp/tests/runtime/gptDecoderBatchedTest.cpp
index 6fa48e62f45..ebc55d8630c 100644
--- a/cpp/tests/runtime/gptDecoderBatchedTest.cpp
+++ b/cpp/tests/runtime/gptDecoderBatchedTest.cpp
@@ -348,7 +348,7 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig>& sa
         decoderState.getSpeculativeDecodingMode(), maxGeneratedTokensPerStep, modelConfig, worldConfig, manager);
 
     // set up inputs and outputs
-    tb::DecoderInputBuffers inputBuffers(batchSize, maxGeneratedTokensPerStep, manager);
+    tb::DecoderInputBuffers inputBuffers(batchSize, batchSize, maxGeneratedTokensPerStep, manager);
 
     auto batchSlotsRange = BufferRange<SizeType32>(*inputBuffers.setupBatchSlots);
     std::iota(batchSlotsRange.begin(), batchSlotsRange.end(), 0);
@@ -488,7 +488,7 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingConfig>& sa
-    tb::DecoderInputBuffers inputBuffers(batchSize, maxGeneratedTokensPerStep, manager);
+    tb::DecoderInputBuffers inputBuffers(batchSize, batchSize, maxGeneratedTokensPerStep, manager);
diff --git a/tensorrt_llm/_torch/pyexecutor/handle_context_logits.py b/tensorrt_llm/_torch/pyexecutor/handle_context_logits.py
--- a/tensorrt_llm/_torch/pyexecutor/handle_context_logits.py
+++ b/tensorrt_llm/_torch/pyexecutor/handle_context_logits.py
@@ -1,31 +1,35 @@
-from typing import List
+from typing import List, Tuple
 
 import torch
 
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
-from tensorrt_llm.bindings.internal.batch_manager import DecoderBuffers
 
 
 class HandleContextLogits:
 
-    def __call__(self, context_requests: List[LlmRequest],
-                 num_context_logits_vec: List[int], logits: torch.Tensor,
-                 decoder_buffers: DecoderBuffers) -> int:
+    def __call__(
+        self,
+        context_requests: List[LlmRequest],
+        logits: torch.Tensor,
+        num_context_logits_vec: List[int],
+        max_num_sequences: int,
+    ) -> Tuple[List[torch.Tensor], int]:
         """Handle context logits for a batch of requests.
 
         Args:
             context_requests: List of context requests to process
-            num_context_logits_vec: Number of context logits for each request
             logits: Input logits tensor
-            decoder_buffers: Decoder buffers for storing intermediate results
+            num_context_logits_vec: Number of context logits for each request
+            max_num_sequences: Maximum number of sequences to process
 
         Returns:
+            List[torch.Tensor]: List of logits tensors for each request
            int: Index into logits tensor after processing all requests
         """
         logits_index = 0
 
         # Copy logits into decoderBuffers.logits
-        decoder_buffer_logits = [torch.empty(0)] * len(decoder_buffers.logits)
+        decoder_buffer_logits = [torch.empty(0)] * max_num_sequences
         for batch_index, llm_req in enumerate(context_requests):
             num_context_logits = num_context_logits_vec[batch_index]
             draft_length = llm_req.num_draft_tokens if llm_req.is_last_context_chunk else 0
@@ -71,7 +75,4 @@ def __call__(self, context_requests: List[LlmRequest],
             # else:
             #     decoder_buffer_logits[seq_slot] = logits_view[:logits_view.shape[0], :1, :logits_view.shape[1]]
 
-        # Needs to be done in bulk for the copy to work
-        decoder_buffers.logits = decoder_buffer_logits
-
-        return logits_index
+        return decoder_buffer_logits, logits_index
diff --git a/tensorrt_llm/_torch/pyexecutor/handle_generation_logits.py b/tensorrt_llm/_torch/pyexecutor/handle_generation_logits.py
index ad4a15fa100..5cb3da8deae 100644
--- a/tensorrt_llm/_torch/pyexecutor/handle_generation_logits.py
+++ b/tensorrt_llm/_torch/pyexecutor/handle_generation_logits.py
@@ -1,19 +1,19 @@
+from typing import List
+
 import torch
 
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
-from tensorrt_llm.bindings.internal.batch_manager import DecoderBuffers
 
 
 class HandleGenerationLogits:
 
     def __call__(
         self,
-        logits_index: int,
+        decoder_buffer_logits: List[torch.Tensor],
         generation_requests: list[LlmRequest],
-        decoder_buffers: DecoderBuffers,
         logits: torch.Tensor,
+        logits_index: int,
     ):
-        decoder_buffer_logits = decoder_buffers.logits
         for llm_req in generation_requests:
             beam_width = llm_req.get_beam_width_by_iter()
             seq_slot = llm_req.seq_slot
@@ -31,5 +31,4 @@ def __call__(
 
         logits_index += beam_width
 
-        # Needs to be done in bulk for the copy to work
-        decoder_buffers.logits = decoder_buffer_logits
+        return decoder_buffer_logits
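
The pure-Python helpers now return the per-slot logits list instead of mutating `DecoderBuffers` in place. A short sketch of the resulting data flow, mirroring the `sampler.py` changes further down (module paths and call signatures follow the files above):

```python
import torch

from tensorrt_llm._torch.pyexecutor.handle_context_logits import \
    HandleContextLogits
from tensorrt_llm._torch.pyexecutor.handle_generation_logits import \
    HandleGenerationLogits


def compute_seq_slot_logits(context_requests, generation_requests,
                            logits: torch.Tensor, num_context_logits,
                            max_num_sequences: int):
    # Context phase: build the sparse per-slot list and the running index.
    decoder_buffer_logits, logits_index = HandleContextLogits()(
        context_requests, logits, num_context_logits, max_num_sequences)
    # Generation phase: fill further slots and return the same list.
    decoder_buffer_logits = HandleGenerationLogits()(
        decoder_buffer_logits, generation_requests, logits, logits_index)
    return decoder_buffer_logits
```

The caller is now responsible for the single bulk assignment `decoder_input_buffers.logits = decoder_buffer_logits`, as `sampler.py` does below.
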
diff --git a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py
index 805b5d47cc7..8a241693675 100644
--- a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py
+++ b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py
@@ -126,7 +126,7 @@ def __call__(
 
         # Create decoder batch inputs
         decoding_input = self.create_decoder_batch_inputs(
-            active_slots, decoder_state, decoder_buffers.logits,
+            active_slots, decoder_state, decoder_input_buffers.logits,
             max_num_sequences, decoder_input_buffers.forward_batch_slots,
             decoder_buffers.cache_indirection_input)
         decoding_input.generation_steps = generation_steps
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index f5b2cb34cab..4106c5976b4 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -530,7 +530,8 @@ def _initialize_store(self):
                 buffer_manager, self.model_config, self.world_config),
             "decoder_input_buffers":
-            DecoderInputBuffers(self.executor_config.max_batch_size,
+            DecoderInputBuffers(self.max_num_sequences,
+                                self.executor_config.max_batch_size,
                                 self.MAX_DECODING_TOKENS, buffer_manager),
             "new_tokens_device_tensor":
             torch.empty((
@@ -544,9 +545,22 @@ def _initialize_store(self):
                 self.executor_config.max_batch_size,
                 self.executor_config.max_beam_width,
             ),
-                        dtype=torch.int)
+                        dtype=torch.int),
+            "decoder_state":
+            DecoderState(dtype=self.logits_datatype,
+                         buffer_manager=buffer_manager)
         }
 
+        self.store["decoder_state"].setup(
+            max_batch_size=self.executor_config.max_batch_size,
+            max_beam_width=self.executor_config.max_beam_width,
+            max_attention_window=self.max_attention_window,
+            sink_token_length=0,
+            max_sequence_length=self.executor_config.max_seq_len,
+            model_config=self.model_config,
+            world_config=self.world_config,
+            buffer_manager=buffer_manager)
+
     def _instantiate_algorithms(self):
         self.algs = Algorithms()
         self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"])
@@ -557,18 +571,6 @@ def _instantiate_algorithms(self):
             dtype=self.logits_datatype,
             model_config=self.model_config,
             world_config=self.world_config)
-        self.algs.decoder_state = DecoderState(
-            dtype=self.logits_datatype,
-            buffer_manager=self.store["buffer_manager"])
-        self.algs.decoder_state.setup(
-            max_batch_size=self.executor_config.max_batch_size,
-            max_beam_width=self.executor_config.max_beam_width,
-            max_attention_window=self.max_attention_window,
-            sink_token_length=0,
-            max_sequence_length=self.executor_config.max_seq_len,
-            model_config=self.model_config,
-            world_config=self.world_config,
-            buffer_manager=self.store["buffer_manager"])
         self.algs.create_new_decoder_requests = CreateNewDecoderRequests(
             speculative_decoding_fast_logits=False,
             is_leader_in_orch_mode=False,
@@ -582,7 +584,7 @@ def setup_sampler_step(self, requests):
         batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = self.algs.create_new_decoder_requests(
             self.model_config, self.world_config, self.decoding_config, requests,
             self.store["buffer_manager"], self.logits_datatype,
-            self.store["decoder_input_buffers"], self.algs.decoder_state,
+            self.store["decoder_input_buffers"], self.store["decoder_state"],
             self.store["cuda_stream"], self.algs.decoder.decoder_stream,
             self.executor_config.max_seq_len, self.beam_width(requests))
 
@@ -591,7 +593,7 @@ def setup_sampler_step(self, requests):
         sampling_config = make_sampling_config(sampling_configs)
         self.algs.decoder.underlying_decoder().setup(
             sampling_config, local_batch_size, batch_slots,
-            self.algs.decoder_state.joint_decoding_output,
+            self.store["decoder_state"].joint_decoding_output,
             self.model_config.data_type, lookahead_prompt,
             lookahead_algo_configs)
 
@@ -614,21 +616,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
             num_context_logits[
                 batch_index] = request.context_chunk_size if request.py_return_context_logits else 1
 
-        logits_index = self.algs.handle_context_logits(
-            scheduled_requests.context_requests, num_context_logits,
-            model_outputs["logits"], self.store["decoder_buffers"])
+        decoder_buffer_logits, logits_index = self.algs.handle_context_logits(
+            scheduled_requests.context_requests, model_outputs["logits"],
+            num_context_logits, self.max_num_sequences)
 
-        self.algs.handle_generation_logits(
-            logits_index, scheduled_requests.generation_requests,
-            self.store["decoder_buffers"], model_outputs["logits"])
+        decoder_buffer_logits = self.algs.handle_generation_logits(
+            decoder_buffer_logits, scheduled_requests.generation_requests,
+            model_outputs["logits"], logits_index)
+
+        self.store["decoder_input_buffers"].logits = decoder_buffer_logits
 
         decoding_input, self.decoding_output = self.algs.make_decoding_batch_input_output(
             scheduled_requests.context_requests,
             scheduled_requests.generation_requests,
             self.store["decoder_buffers"], self.store["decoder_input_buffers"],
-            self.algs.decoder_state, self.model_config, self.max_num_sequences)
+            self.store["decoder_state"], self.model_config,
+            self.max_num_sequences)
 
-        self.algs.decoder.forward_async(self.algs.decoder_state,
+        self.algs.decoder.forward_async(self.store["decoder_state"],
                                         self.decoding_output, decoding_input)
 
         # NOTE: The following code prepares a new_tokens_device_tensor in accordance with the
@@ -641,26 +646,26 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
             request.seq_slot for request in scheduled_requests.all_requests
         ]
         new_tokens_device_tensor.copy_(
-            self.algs.decoder_state.all_new_tokens[0][seq_slots],
+            self.store["decoder_state"].all_new_tokens[0][seq_slots],
             non_blocking=True)
         new_tokens_device_tensor = new_tokens_device_tensor.view(-1)
 
-        new_output_tokens = self.algs.decoder_state.all_new_tokens.to(
+        new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
             'cpu', non_blocking=True)
-        finished_sum = self.algs.decoder_state.finished_sum.to(
+        finished_sum = self.store["decoder_state"].finished_sum.to(
             'cpu', non_blocking=True)
-        finish_reasons = self.algs.decoder_state.finish_reasons.to(
+        finish_reasons = self.store["decoder_state"].finish_reasons.to(
             'cpu', non_blocking=True)
-        sequence_lengths = self.algs.decoder_state.sequence_lengths.to(
+        sequence_lengths = self.store["decoder_state"].sequence_lengths.to(
             'cpu', non_blocking=True)
 
         log_probs = torch.empty([0], dtype=torch.float, device='cpu')
         cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu')
         if any(request.py_return_log_probs
                for request in scheduled_requests.all_requests):
-            log_probs = self.algs.decoder_state.log_probs.to('cpu',
-                                                             non_blocking=True)
-            cum_log_probs = self.algs.decoder_state.cum_log_probs.to(
+            log_probs = self.store["decoder_state"].log_probs.to(
+                'cpu', non_blocking=True)
+            cum_log_probs = self.store["decoder_state"].cum_log_probs.to(
                 'cpu', non_blocking=True)
 
         device = SampleStateTensors(new_tokens=new_tokens_device_tensor)
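
Since the decoder state now lives in the sampler's store rather than on `self.algs`, downstream reads go through `store["decoder_state"]`. A small sketch under that assumption (attribute names follow the accesses in `sample_async` above):

```python
def read_decoder_outputs(store):
    # DecoderState is created and set up once in _initialize_store and then
    # only read here.
    decoder_state = store["decoder_state"]
    return (decoder_state.all_new_tokens, decoder_state.finished_sum,
            decoder_state.finish_reasons, decoder_state.sequence_lengths)
```
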