16 changes: 13 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h
@@ -38,6 +38,7 @@ class DecoderInputBuffers
public:
using SizeType32 = runtime::SizeType32;
using TensorPtr = runtime::ITensor::SharedPtr;
using TensorConstPtr = runtime::ITensor::SharedConstPtr;

explicit DecoderInputBuffers(
SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, runtime::BufferManager const& manager);
@@ -60,13 +61,22 @@ class DecoderInputBuffers
//! Requests considered in the decoder forward pass
RequestVector decoderRequests;

//! Logits of decoder requests
std::vector<TensorPtr> decoderLogits;

//! Maximum number of decoding steps of decoder requests.
//! This is only more than 1 for external draft tokens speculative decoding.
SizeType32 maxDecoderSteps{1};

//! Batch slots for all decoder steps, [maxDecoderSteps][maxBatchSize]
std::vector<TensorPtr> forwardBatchSlots;

//! Logits of decoder requests
std::vector<TensorPtr> logits;
//! Logits for requests in forwardBatchSlots (in the same order).
//! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
std::vector<std::vector<TensorConstPtr>> batchLogits;

//! Logits for speculative decoding (Medusa)
//! Logits for speculative decoding (Medusa).
//! The vector is sparse, only slots in forwardBatchSlots are used.
//! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
std::vector<std::vector<runtime::ITensor::SharedPtr>> predictedDraftLogits;
};
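Net effect of this header change: `decoderLogits` keeps one logits tensor per request in `decoderRequests`, while the new `batchLogits` holds per-step views of those tensors in the same order as `forwardBatchSlots`, and `maxDecoderSteps` records how many step entries are valid. A minimal sketch of walking the new layout, assuming the buffers were already populated by `MakeDecodingBatchInputOutput` (the function name `inspectBatchLogits` is illustrative only):

```cpp
#include "tensorrt_llm/batch_manager/decoderBuffers.h"

namespace tb = tensorrt_llm::batch_manager;
using SizeType32 = tensorrt_llm::runtime::SizeType32;

// Sketch only: iterate the per-step batched logits views.
void inspectBatchLogits(tb::DecoderInputBuffers const& inputBuffers)
{
    for (SizeType32 step = 0; step < inputBuffers.maxDecoderSteps; ++step)
    {
        // Slot ids for this step; entry i corresponds to batchLogits[step][i].
        auto const& stepSlots = inputBuffers.forwardBatchSlots.at(step);
        auto const& stepLogits = inputBuffers.batchLogits.at(step);
        for (size_t i = 0; i < stepLogits.size(); ++i)
        {
            // Each view has shape [1, beamWidth, vocabSizePadded] and lives on the GPU.
            auto const& logits = stepLogits[i];
            (void) logits;
        }
        (void) stepSlots;
    }
}
```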
@@ -40,19 +40,17 @@ class MakeDecodingBatchInputOutput : Algorithm
constexpr static auto name{"MakeDecodingBatchInputOutput"};

using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TensorPtr = runtime::decoder_batch::Input::TensorPtr;
using TensorPtr = runtime::ITensor::SharedPtr;
template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;

MakeDecodingBatchInputOutput() = default;

std::unique_ptr<runtime::decoder_batch::Input> operator()(DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig,
SizeType32 maxNumSequences, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
void operator()(DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
runtime::ModelConfig const& modelConfig, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;

[[nodiscard]] static std::unique_ptr<runtime::decoder_batch::Input> createDecoderBatchInputs(
std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots);
static void createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, std::vector<SizeType32> const& activeSlots,
runtime::decoder::DecoderState const& decoderState);
};

} // namespace tensorrt_llm::batch_manager
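With this change the algorithm no longer builds and returns a `runtime::decoder_batch::Input`; it writes `maxDecoderSteps`, `forwardBatchSlots` and `batchLogits` directly into the `DecoderInputBuffers` it is given, which is also why the `maxNumSequences` parameter goes away (the slot tensors are now sized from the active requests). A hedged sketch of the new call shape; the local names `inputBuffers`, `decoderState`, `modelConfig`, `fusedRuntimeBuffers` and `decoder` are assumptions for illustration:

```cpp
// Sketch only: the updated call pattern after this interface change.
MakeDecodingBatchInputOutput const makeDecodingBatchInputOutput{};

// Populates inputBuffers.maxDecoderSteps, .forwardBatchSlots and .batchLogits in place;
// there is no returned decoder_batch::Input any more.
makeDecodingBatchInputOutput(inputBuffers, decoderState, modelConfig, fusedRuntimeBuffers);

// The populated buffers are then passed straight to the batched decoder
// (see the IGptDecoderBatched changes below).
auto decoderFinishEvent = decoder.forwardAsync(decoderState, inputBuffers);
```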
7 changes: 4 additions & 3 deletions cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -52,8 +52,9 @@ class GptDecoderBatched : public IGptDecoderBatched

void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;

CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
CudaEvent forwardAsync(
decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) override;
void forward(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input) override;

//! @brief Gather final beam search results for request `batchSlot`.
//! Result will only be available after event returned.
@@ -77,7 +78,7 @@

private:
//! @brief Calls decoders for tokens per engine step
void forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);
void forwardDispatch(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input);

private:
CudaStreamPtr mRuntimeStream;
47 changes: 7 additions & 40 deletions cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -27,8 +27,9 @@

namespace tensorrt_llm::batch_manager
{
class DecoderInputBuffers;
class LlmRequest;
}
} // namespace tensorrt_llm::batch_manager

namespace tensorrt_llm::runtime
{
@@ -39,43 +40,6 @@ namespace decoder
class DecoderState;
}

namespace decoder_batch
{

class Input
{
public:
using TensorConstPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;

explicit Input(std::vector<std::vector<TensorConstPtr>> const& logits, SizeType32 maxDecoderSteps)
: logits{logits}
, maxDecoderSteps{maxDecoderSteps}
{
TLLM_CHECK_WITH_INFO(
logits.size() == static_cast<size_t>(maxDecoderSteps), "logits vector size does not match maxDecoderSteps");
}

explicit Input(std::vector<TensorConstPtr> const& logits)
: Input{{logits}, 1}
{
}

//! Mandatory parameters
//! Logits
// FIXME: remove first dimension of tensors
//! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
std::vector<std::vector<TensorConstPtr>> logits;

//! Maximum number of decoding tokens of active slots
SizeType32 maxDecoderSteps;

//! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
std::vector<TensorPtr> batchSlots;
};

} // namespace decoder_batch

//! GPT decoder class with support for in-flight batching
class IGptDecoderBatched
{
@@ -94,10 +58,13 @@ class IGptDecoderBatched
virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0;

//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;
virtual CudaEvent forwardAsync(
decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input)
= 0;

//! @brief Run one step for all requests and wait for completion on the host.
virtual void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;
virtual void forward(decoder::DecoderState const& decoderState, batch_manager::DecoderInputBuffers const& input)
= 0;

//! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned
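For callers that previously constructed the now-removed `decoder_batch::Input`, the migration is mechanical: the fields it carried (`logits`, `maxDecoderSteps`, `batchSlots`) now live in `DecoderInputBuffers` as `batchLogits`, `maxDecoderSteps` and `forwardBatchSlots`, and the buffers object is passed to the decoder directly. A before/after sketch under that assumption; `decoder`, `decoderState` and `inputBuffers` are illustrative local names:

```cpp
// Before (removed in this PR): wrap per-step logits in decoder_batch::Input.
//   auto input = std::make_unique<decoder_batch::Input>(batchLogits, maxDecoderSteps);
//   input->batchSlots = forwardBatchSlots;
//   auto event = decoder.forwardAsync(decoderState, *input);

// After: the equivalent data already sits in DecoderInputBuffers.
//   inputBuffers.batchLogits       == former Input::logits
//   inputBuffers.maxDecoderSteps   == former Input::maxDecoderSteps
//   inputBuffers.forwardBatchSlots == former Input::batchSlots
auto event = decoder.forwardAsync(decoderState, inputBuffers);
event.synchronize(); // or use decoder.forward(decoderState, inputBuffers) to block on the host
```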
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
@@ -197,7 +197,7 @@ void GuidedDecoder::execute(DecoderInputBuffers const& decoderInputBuffers, Buff
{
auto const seqSlot = llmReq->mSeqSlot.value();

auto const& logits = decoderInputBuffers.logits.at(requestIdx);
auto const& logits = decoderInputBuffers.decoderLogits.at(requestIdx);
auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot});

// Use void* to unify the code for different mLogitsDtype
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp
@@ -79,7 +79,7 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re
auto& decoderRequests = inputBuffers.decoderRequests;
decoderRequests.clear();
decoderRequests.reserve(contextRequests.size());
auto& allDecoderLogits = inputBuffers.logits;
auto& allDecoderLogits = inputBuffers.decoderLogits;
allDecoderLogits.clear();
allDecoderLogits.reserve(contextRequests.size());

2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp
@@ -85,7 +85,7 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque

auto& decoderRequests = inputBuffers.decoderRequests;
decoderRequests.reserve(decoderRequests.size() + generationRequests.size());
auto& allDecoderLogits = inputBuffers.logits;
auto& allDecoderLogits = inputBuffers.decoderLogits;
allDecoderLogits.reserve(allDecoderLogits.size() + generationRequests.size());

for (auto const& llmReq : generationRequests)
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp
@@ -49,7 +49,7 @@ bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool rep
for (size_t batchIdx = 0; batchIdx < inputBuffers.decoderRequests.size(); ++batchIdx)
{
auto const& llmReq = inputBuffers.decoderRequests.at(batchIdx);
auto& logits = inputBuffers.logits.at(batchIdx);
auto& logits = inputBuffers.decoderLogits.at(batchIdx);

// Invoke non-batched processor or collect arguments for batched processor
if (llmReq->mLogitsPostProcessor)
26 changes: 13 additions & 13 deletions cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
@@ -31,9 +31,8 @@ namespace tensorrt_llm::batch_manager
using SizeType32 = MakeDecodingBatchInputOutput::SizeType32;
using TensorPtr = MakeDecodingBatchInputOutput::TensorPtr;

std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDecoderBatchInputs(
std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
std::vector<TensorPtr> const& decoderLogits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots)
void MakeDecodingBatchInputOutput::createDecoderBatchInputs(DecoderInputBuffers& inputBuffers,
std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -42,9 +41,12 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDe
auto const& maxDecodingDecoderTokens = decoderState.getMaxDecodingDecoderTokens();
auto const maxDecoderSteps = common::ceilDiv(maxDecodingEngineTokens, maxDecodingDecoderTokens);

auto& batchSlots = inputBuffers.forwardBatchSlots;
auto& decoderLogits = inputBuffers.decoderLogits;

for (SizeType32 step = 0; step < maxDecoderSteps; ++step)
{
batchSlots.at(step)->resize(maxNumSequences);
batchSlots.at(step)->resize(activeSlots.size());
}

auto constexpr singleRequest = 1;
@@ -64,7 +66,7 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDe
auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlots.at(step));
batchSlotsRange[batchSizes[step]] = slot;
batchSizes[step]++;
TensorPtr logitsSlice = tr::ITensor::slice(logits, step, singleRequest);
auto logitsSlice = tr::ITensor::slice(logits, step, singleRequest);
batchLogits[step].emplace_back(std::move(logitsSlice));
}
}
@@ -75,10 +77,10 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDe
}
batchLogits.resize(maxActiveDecoderSteps);

auto decodingInput = std::make_unique<tr::decoder_batch::Input>(batchLogits, maxActiveDecoderSteps);
decodingInput->batchSlots = batchSlots;
inputBuffers.maxDecoderSteps = maxActiveDecoderSteps;
inputBuffers.batchLogits = batchLogits;

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return decodingInput;
}

namespace
@@ -155,16 +157,15 @@ void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntim

} // namespace

std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator()(DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences,
void MakeDecodingBatchInputOutput::operator()(DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig,
OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto [activeSlots, generationSteps] = getActiveSlots(inputBuffers.decoderRequests);

auto decodingInput = createDecoderBatchInputs(
activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots);
createDecoderBatchInputs(inputBuffers, activeSlots, decoderState);

auto const maxBeamWidth = decoderState.getMaxBeamWidth();
if (maxBeamWidth > 1)
@@ -192,7 +193,6 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return decodingInput;
}

} // namespace tensorrt_llm::batch_manager
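To make the packing in `createDecoderBatchInputs` concrete: assume two active requests in slots 3 and 7, where slot 3 decodes two engine steps (external draft tokens) and slot 7 decodes one. The per-step buffers then end up as sketched below; the concrete slot numbers and step counts are assumptions for illustration, not taken from the PR:

```cpp
// Assumed inputs: activeSlots = {3, 7}; slot 3 runs 2 decoder steps, slot 7 runs 1.
// After createDecoderBatchInputs(inputBuffers, activeSlots, decoderState):
//
//   inputBuffers.maxDecoderSteps        == 2
//   inputBuffers.forwardBatchSlots[0]   leads with slots [3, 7]  (step 0: both requests)
//   inputBuffers.forwardBatchSlots[1]   leads with slot  [3]     (step 1: only the draft-token request)
//   inputBuffers.batchLogits[0]         == { slice(logits3, 0, 1), slice(logits7, 0, 1) }
//   inputBuffers.batchLogits[1]         == { slice(logits3, 1, 1) }
//
// where logits3 / logits7 are the matching entries of inputBuffers.decoderLogits and each
// slice has shape [1, beamWidth, vocabSizePadded].
```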
@@ -1557,8 +1557,6 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const&
mOperatingBeamWidth, getMaxSequenceLen(), mRuntime->getBufferManager()));
}

mDecodingInputs.resize(mNumMicroBatches);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

@@ -2071,11 +2069,9 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques
auto const fusedBufferId = getFusedBufferId();
auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId);

auto& decodingInput = mDecodingInputs.at(mMicroBatchId);
decodingInput = (*mMakeDecodingBatchInputOutput)(mDecoderInputBuffers.at(fusedBufferId), *mDecoderState,
mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers);
(*mMakeDecodingBatchInputOutput)(decoderInputBuffers, *mDecoderState, mModelConfig, *fusedRuntimeBuffers);

auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, *decodingInput);
auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, decoderInputBuffers);

auto const returnLogProbs = batchReturnLogProbs(scheduledRequests);
auto updateDecoderBuffersEvent = (*mUpdateDecoderBuffers)(mModelConfig, mDecoderOutputBuffers.at(fusedBufferId),
2 changes: 0 additions & 2 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
@@ -551,8 +551,6 @@ class TrtGptModelInflightBatching : public TrtGptModel
std::vector<std::unique_ptr<SlotDecoderBuffers>> mSlotDecoderBuffers;
// PEFT table for each micro batch
std::vector<PeftTable> mPeftTables;
// Decoder input for each micro batch.
std::vector<std::unique_ptr<runtime::decoder_batch::Input>> mDecodingInputs;

/******************** Book keeping ********************/
// List of requests in each micro batch
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/nanobind/CMakeLists.txt
@@ -6,6 +6,7 @@ set(TRTLLM_NB_MODULE
set(SRCS
batch_manager/algorithms.cpp
batch_manager/bindings.cpp
batch_manager/buffers.cpp
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheManager.cpp
batch_manager/llmRequest.cpp
52 changes: 8 additions & 44 deletions cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -390,39 +390,6 @@ void initBindings(nb::module_& m)
.def(nb::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"));

nb::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
.def(nb::init<runtime::SizeType32, runtime::SizeType32, tr::BufferManager>(), nb::arg("max_batch_size"),
nb::arg("max_tokens_per_engine_step"), nb::arg("manager"))
.def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
.def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
.def_rw("fill_values", &tb::DecoderInputBuffers::fillValues)
.def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
.def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
.def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
.def_rw("logits", &tb::DecoderInputBuffers::logits)
.def_rw("decoder_requests", &tb::DecoderInputBuffers::decoderRequests);

nb::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
.def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)
.def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost)
.def_prop_ro("new_output_tokens_host",
[](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); })
.def_rw("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost)
.def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost)
.def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost);

nb::class_<tb::SlotDecoderBuffers>(m, "SlotDecoderBuffers")
.def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&>(),
nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"))
.def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds)
.def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost)
.def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost)
.def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs)
.def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost)
.def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs)
.def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
.def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);

m.def(
"add_new_tokens_to_requests",
[](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
@@ -441,10 +408,10 @@ void initBindings(nb::module_& m)

m.def(
"make_decoding_batch_input",
[](std::vector<std::shared_ptr<tb::LlmRequest>>& contextRequests,
std::vector<std::shared_ptr<tb::LlmRequest>>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth,
std::vector<int> const& numContextLogitsPrefixSum, tb::DecoderInputBuffers const& decoderInputBuffers,
runtime::decoder::DecoderState& decoderState, tr::BufferManager const& manager)
[](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState,
std::vector<std::shared_ptr<tb::LlmRequest>> const& contextRequests,
std::vector<std::shared_ptr<tb::LlmRequest>> const& genRequests, tr::ITensor::SharedPtr const& logits,
int beamWidth, std::vector<int> const& numContextLogitsPrefixSum, tr::BufferManager const& manager)
{
std::vector<int> activeSlots;
std::vector<int> generationSteps;
@@ -502,21 +469,18 @@ void initBindings(nb::module_& m)
batchSlotsRange[i] = activeSlots[i];
}

auto decodingInput = std::make_unique<tr::decoder_batch::Input>(logitsVec, 1);
decodingInput->batchSlots = batchSlots;
decoderInputBuffers.batchLogits = logitsVec;

auto const maxBeamWidth = decoderState.getMaxBeamWidth();
if (maxBeamWidth > 1)
{
// For Variable-Beam-Width-Search
decoderState.getJointDecodingInput().generationSteps = generationSteps;
}

return decodingInput;
},
nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"),
nb::arg("num_context_logits_prefix_sum"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"),
nb::arg("buffer_manager"), "Make decoding batch input.");
nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), nb::arg("context_requests"),
nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"),
nb::arg("num_context_logits_prefix_sum"), nb::arg("buffer_manager"), "Make decoding batch input.");
}

} // namespace tensorrt_llm::nanobind::batch_manager