25 changes: 12 additions & 13 deletions cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h
@@ -107,32 +107,31 @@ class DecoderBuffers
using SizeType32 = runtime::SizeType32;
using TensorPtr = runtime::ITensor::SharedPtr;

TensorPtr cacheIndirectionInput;
TensorPtr cacheIndirectionOutput;

DraftBuffers draftBuffers;

DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
SizeType32 maxTokensPerStep, runtime::BufferManager const& manager, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig);
DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
};

class DecoderStepAsyncSend
{
public:
using SizeType32 = runtime::SizeType32;
using BufferPtr = runtime::IBuffer::SharedPtr;
using TensorPtr = runtime::ITensor::SharedPtr;

DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int peer);
DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa,
mpi::MpiComm const& commSession, int peer);

~DecoderStepAsyncSend();

static void recv(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int peer);
static void recv(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa,
mpi::MpiComm const& commSession, int peer);

static void bcast(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int root);
static void bcast(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa,
mpi::MpiComm const& commSession, int root);

private:
std::unique_ptr<mpi::MpiRequest> mRequest1;
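A hedged sketch of how a call site might adapt to the new constructor and recv signatures; the surrounding objects (decoderOutputBuffers, decoderBuffers, decoderState, commSession, peer) and flags are assumed for illustration, and the cache indirection tensor comes from the DecoderState accessor added later in this change.

// Sending side (illustrative): draft buffers and the cache indirection output are now
// passed explicitly instead of being reached through DecoderBuffers.
DecoderStepAsyncSend asyncSend(decoderOutputBuffers, decoderBuffers.draftBuffers,
    decoderState.getCacheIndirectionOutput(), returnLogProbs, maxBeamWidth, useMedusa,
    commSession, peer);

// Receiving side (illustrative): the peer rank mirrors the same argument list.
DecoderStepAsyncSend::recv(decoderOutputBuffers, decoderBuffers.draftBuffers,
    decoderState.getCacheIndirectionOutput(), returnLogProbs, maxBeamWidth, useMedusa,
    commSession, peer);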
@@ -47,16 +47,15 @@ class MakeDecodingBatchInputOutput : Algorithm

MakeDecodingBatchInputOutput() = default;

std::tuple<std::unique_ptr<runtime::decoder_batch::Input>, std::unique_ptr<runtime::decoder_batch::Output>>
operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
DecoderBuffers& decoderBuffers, DecoderInputBuffers const& inputBuffers,
runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig,
SizeType32 maxNumSequences, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
std::unique_ptr<runtime::decoder_batch::Input> operator()(RequestVector const& contextRequests,
RequestVector const& generationRequests, DecoderBuffers& decoderBuffers,
DecoderInputBuffers const& inputBuffers, runtime::decoder::DecoderState& decoderState,
runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences,
OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;

[[nodiscard]] static std::unique_ptr<runtime::decoder_batch::Input> createDecoderBatchInputs(
std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots,
TensorPtr const& cacheIndirectionInput);
std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots);
};

} // namespace tensorrt_llm::batch_manager
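A hedged usage sketch; the request vectors, buffers, decoder state, and decoder instance are assumed for illustration. The functor now returns only the decoder_batch::Input, since the former decoder_batch::Output (which carried the cache indirection) is removed in this change.

MakeDecodingBatchInputOutput makeDecodingBatchInputOutput;
// Build the batch input for the decoder step; no Output object is produced anymore.
auto decodingInput = makeDecodingBatchInputOutput(contextRequests, generationRequests,
    decoderBuffers, inputBuffers, decoderState, modelConfig, maxNumSequences,
    fusedRuntimeBuffers);
// The input is passed to the batched decoder together with the decoder state, matching
// the updated IGptDecoderBatched interface below.
decoder.forwardAsync(decoderState, *decodingInput);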
11 changes: 5 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
@@ -279,10 +279,9 @@ class RuntimeBuffers

std::tuple<SizeType32, TensorMap const&, TensorMap&> prepareStep(RequestVector const& contextRequests,
RequestVector const& genRequests, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
DecoderBuffers& decoderBuffers, runtime::decoder::DecoderState const& decoderState,
kv_cache_manager::BaseKVCacheManager* kvCacheManager, kv_cache_manager::BaseKVCacheManager* crossKvCacheManager,
rnn_state_manager::RnnStateManager* rnnStateManager, PeftTable const& peftTable,
runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManager,
kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, rnn_state_manager::RnnStateManager* rnnStateManager,
PeftTable const& peftTable, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits, bool trtOverlap,
OptionalRef<runtime::ITensor const> newOutputTokens = std::nullopt);

@@ -314,8 +313,8 @@ class RuntimeBuffers
runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits);

void setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth,
SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers,
runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState,
kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr,
rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable,
runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
4 changes: 0 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h
@@ -122,10 +122,6 @@ class TransformerBuffers
void copyPositionIds(runtime::TllmRuntime const& runtime, std::vector<SizeType32> const& positionIdsHost,
bool isChatGlm, TensorPtr const& decoderPositionIds);

void resetCacheIndirection(RequestVector const& contextRequests, SizeType32 maxBeamWidth,
SizeType32 maxAttentionWindow, TensorPtr const& decoderCacheIndirectionInput,
TensorPtr const& decoderCacheIndirectionOutput, runtime::BufferManager const& manager);

void copyKvBlockOffsets(RequestVector const& contextRequests, RequestVector const& genRequests,
kv_cache_manager::BaseKVCacheManager const* kvCacheManager,
kv_cache_manager::BaseKVCacheManager const* crossKvCacheManager, runtime::BufferManager const& manager);
11 changes: 11 additions & 0 deletions cpp/include/tensorrt_llm/runtime/decoderState.h
@@ -54,10 +54,15 @@ class DecoderState
void allocateSpeculativeDecodingBuffers(
SpeculativeDecodingMode speculativeDecodingMode, nvinfer1::DataType dtype, BufferManager const& bufferManager);

//! @brief Setup buffers for the decoder excluding speculative decoding.
void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, ModelConfig const& modelConfig,
WorldConfig const& worldConfig, BufferManager const& bufferManager);

//! @brief Setup buffers for the cache indirection.
//! @details This is used for beam search on pipeline parallel ranks without a decoder.
void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow);

//! @brief Setup buffers for speculative decoding.
void setupSpeculativeDecoding(SpeculativeDecodingMode const& speculativeDecodingMode,
SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
@@ -174,6 +179,12 @@ class DecoderState
//! @brief Workspace for beam search in streaming mode.
[[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const;

//! @brief Cache indirection input for beam search.
[[nodiscard]] TensorPtr getCacheIndirectionInput() const;

//! @brief Cache indirection output for beam search.
[[nodiscard]] TensorPtr getCacheIndirectionOutput() const;

//! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
[[nodiscard]] DecodingInput& getJointDecodingInput() const;
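A hedged sketch of the split described in the comments above; the call sites and surrounding variables are assumptions for illustration, not part of this diff.

// Pipeline-parallel rank that joins beam search but hosts no decoder (illustrative):
// only the cache indirection tensors are needed.
decoderState.setupCacheIndirection(maxBatchSize, maxBeamWidth, maxAttentionWindow);
auto cacheIndirectionInput = decoderState.getCacheIndirectionInput();
auto cacheIndirectionOutput = decoderState.getCacheIndirectionOutput();

// Rank that owns the decoder (illustrative): the regular setup is assumed to cover
// these buffers as part of the full decoder state.
decoderState.setup(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength,
    maxSequenceLength, modelConfig, worldConfig, bufferManager);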

9 changes: 3 additions & 6 deletions cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -52,10 +52,8 @@ class GptDecoderBatched : public IGptDecoderBatched

void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;

CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Output& output,
decoder_batch::Input const& input) override;
void forward(decoder::DecoderState const& decoderState, decoder_batch::Output& output,
decoder_batch::Input const& input) override;
CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;

//! @brief Gather final beam search results for request `batchSlot`.
//! Result will only be available after event returned.
@@ -79,8 +77,7 @@

private:
//! @brief Calls decoders for tokens per engine step
void forwardDispatch(
decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input);
void forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);

private:
CudaStreamPtr mRuntimeStream;
21 changes: 2 additions & 19 deletions cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -76,8 +76,6 @@ class Input
TensorPtr batchSlotsRequestOrder;

//! For Beam Search
//! Indices into KV cache of different rays within one beam, [maxBatchSize, maxBeamWidth, maxSeqLen], on gpu
TensorPtr cacheIndirection;
//! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
std::vector<SizeType32> generationSteps;

@@ -95,17 +93,6 @@
std::optional<EagleBuffers::Inputs> eagleLastInputs;
};

class Output
{
public:
using TensorPtr = std::shared_ptr<ITensor>;

Output() = default;

//! parameters for beam search, [batchSize, maxBeamWidth, maxSeqLen], on gpu
TensorPtr cacheIndirection;
};

} // namespace decoder_batch

//! GPT decoder class with support for in-flight batching
@@ -126,14 +113,10 @@ class IGptDecoderBatched
virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0;

//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual CudaEvent forwardAsync(
decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input)
= 0;
virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;

//! @brief Run one step for all requests and wait for completion on the host.
virtual void forward(
decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input)
= 0;
virtual void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;
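An illustrative call sequence, not taken from this diff; the decoder instance and batch input are assumed. With decoder_batch::Output removed, a step is driven from the decoder state and the batch input alone, and the beam-search cache indirection is owned by the DecoderState.

// Asynchronous step followed by an explicit wait; other host work could overlap in between.
auto event = decoder.forwardAsync(decoderState, *decodingInput);
event.synchronize();

// Or, as a blocking variant:
decoder.forward(decoderState, *decodingInput);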

//! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned
7 changes: 6 additions & 1 deletion cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -25,7 +25,6 @@
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -307,6 +306,12 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
parentIds->reshape(outputIdsShape);
manager.setZero(*parentIds);

auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1);
manager.setZero(*cacheIndirectionInput);

auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1);
manager.setZero(*cacheIndirectionOutput);

auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1);
beamHypotheses.init(manager, endId);
}
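A minimal illustration of the slicing used above, assuming the joint cache indirection tensors have shape [maxBatchSize, maxBeamWidth, maxAttentionWindow].

// slice(tensor, batchSlot, 1) along the leading batch dimension yields a
// [1, maxBeamWidth, maxAttentionWindow] view, so setZero clears only this request's
// indirection entries and leaves the other slots untouched.
auto slot = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1);
manager.setZero(*slot);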
56 changes: 24 additions & 32 deletions cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp
@@ -89,15 +89,9 @@ void DecoderOutputBuffers::disableLookaheadDecoding(SizeType32 maxNumSequences)
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
SizeType32 maxTokensPerStep, BufferManager const& manager, ModelConfig const& modelConfig,
WorldConfig const& worldConfig)
DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, BufferManager const& manager,
ModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
cacheIndirectionInput = manager.gpu(
ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32);
cacheIndirectionOutput = manager.gpu(
ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32);

if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind()
|| modelConfig.getSpeculativeDecodingMode().hasDraftLogits()
|| modelConfig.getSpeculativeDecodingMode().predictsDraftTokens())
@@ -147,8 +141,8 @@ void DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerSte
}

DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers,
DecoderBuffers const& decoderBuffers, bool const returnLogProbs, SizeType32 const maxBeamWidth,
bool const useMedusa, mpi::MpiComm const& commSession, int peer)
DraftBuffers const& draftBuffers, TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs,
SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession, int peer)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("start send outputs of DecoderBuffers to rank %d", peer);
@@ -165,24 +159,24 @@ DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOu
mRequest5 = returnLogProbs
? commSession.sendAsync(*decoderOutputBuffers.logProbsHost, peer, mpi::MpiTag::kDecoderStepLogProbsHost)
: nullptr;
mRequest6 = maxBeamWidth > 1 ? commSession.sendAsync(
*decoderBuffers.cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput)
: nullptr;
mRequest7 = useMedusa ? commSession.sendAsync(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, peer,
mRequest6 = maxBeamWidth > 1
? commSession.sendAsync(*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput)
: nullptr;
mRequest7 = useMedusa ? commSession.sendAsync(*draftBuffers.acceptedLengthsCumSumDevice, peer,
mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice)
: nullptr;
mRequest8 = useMedusa ? commSession.sendAsync(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, peer,
mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice)
mRequest8 = useMedusa ? commSession.sendAsync(
*draftBuffers.acceptedPackedPathsDevice, peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice)
: nullptr;
mRequest9 = commSession.sendAsync(
*decoderOutputBuffers.finishReasonsHost, peer, mpi::MpiTag::kDecoderStepFinishReasonsHost);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession,
int const peer)
void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth,
bool const useMedusa, mpi::MpiComm const& commSession, int const peer)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("start recv outputs of DecoderBuffers from rank %d", peer);
@@ -197,14 +191,14 @@ void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers
}
if (maxBeamWidth > 1)
{
commSession.recv(*decoderBuffers.cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput);
commSession.recv(*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput);
}
if (useMedusa)
{
commSession.recv(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, peer,
mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice);
commSession.recv(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, peer,
mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice);
commSession.recv(
*draftBuffers.acceptedLengthsCumSumDevice, peer, mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice);
commSession.recv(
*draftBuffers.acceptedPackedPathsDevice, peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice);
}
commSession.recv(*decoderOutputBuffers.finishReasonsHost, peer, mpi::MpiTag::kDecoderStepFinishReasonsHost);

@@ -235,9 +229,9 @@ DecoderStepAsyncSend::~DecoderStepAsyncSend()
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession,
int const root)
void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth,
bool const useMedusa, mpi::MpiComm const& commSession, int const root)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("start bcast outputs of DecoderBuffers from rank %d", root);
@@ -247,11 +241,9 @@ void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffer
auto request3 = commSession.bcastAsync(*decoderOutputBuffers.sequenceLengthsHost, root);
auto request4 = returnLogProbs ? commSession.bcastAsync(*decoderOutputBuffers.cumLogProbsHost, root) : nullptr;
auto request5 = returnLogProbs ? commSession.bcastAsync(*decoderOutputBuffers.logProbsHost, root) : nullptr;
auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync(*decoderBuffers.cacheIndirectionOutput, root) : nullptr;
auto request7
= useMedusa ? commSession.bcastAsync(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr;
auto request8
= useMedusa ? commSession.bcastAsync(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, root) : nullptr;
auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync(*cacheIndirectionOutput, root) : nullptr;
auto request7 = useMedusa ? commSession.bcastAsync(*draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr;
auto request8 = useMedusa ? commSession.bcastAsync(*draftBuffers.acceptedPackedPathsDevice, root) : nullptr;
auto request9 = commSession.bcastAsync(*decoderOutputBuffers.finishReasonsHost, root);

request1->wait();