Skip to content

Commit ae27261

Browse files
authored
refactor: decoding inputs (#5679)
Signed-off-by: Robin Kobus <[email protected]>
1 parent d95ae13 commit ae27261

File tree

8 files changed

+92
-118
lines changed

8 files changed

+92
-118
lines changed

cpp/include/tensorrt_llm/runtime/decoderState.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,14 @@ class DecoderState
182182
//! @brief Cache indirection output for beam search.
183183
[[nodiscard]] TensorPtr getCacheIndirectionOutput() const;
184184

185+
//! @brief Get the generation steps for all requests in the batch.
186+
//! @returns The generation steps for all requests in the batch.
187+
[[nodiscard]] std::optional<std::vector<SizeType32>> const& getGenerationSteps() const;
188+
189+
//! @brief Set the generation steps for all requests in the batch.
190+
//! @param generationSteps The generation steps for all requests in the batch.
191+
void setGenerationSteps(std::vector<SizeType32> const& generationSteps);
192+
185193
//! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
186194
[[nodiscard]] DecodingInput& getJointDecodingInput() const;
187195

cpp/include/tensorrt_llm/runtime/decodingInput.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -142,24 +142,6 @@ class DecodingInput
142142

143143
struct EagleInputs
144144
{
145-
EagleInputs(TensorConstPtr nextDraftTokens, TensorConstPtr nextDraftLens, TensorConstPtr nextDraftPaths,
146-
TensorConstPtr lastDraftTokens, TensorConstPtr lastDraftLens, TensorConstPtr lastDraftPaths,
147-
TensorConstPtr acceptedTokens, TensorConstPtr acceptedLens, TensorConstPtr acceptedPathIds,
148-
TensorConstPtr chunkedContextNextTokens, TensorConstPtr seqSlots)
149-
: nextDraftTokens(std::move(nextDraftTokens))
150-
, nextDraftLens(std::move(nextDraftLens))
151-
, nextDraftPaths(std::move(nextDraftPaths))
152-
, lastDraftTokens(std::move(lastDraftTokens))
153-
, lastDraftLens(std::move(lastDraftLens))
154-
, lastDraftPaths(std::move(lastDraftPaths))
155-
, acceptedTokens(std::move(acceptedTokens))
156-
, acceptedLens(std::move(acceptedLens))
157-
, acceptedPathIds(std::move(acceptedPathIds))
158-
, chunkedContextNextTokens(std::move(chunkedContextNextTokens))
159-
, seqSlots(std::move(seqSlots))
160-
{
161-
}
162-
163145
TensorConstPtr nextDraftTokens; // [batchSize, maxDecodingDraftTokens]
164146
TensorConstPtr nextDraftLens; // [batchSize]
165147
TensorConstPtr nextDraftPaths; // [batchSize, maxDecodingTokens, maxPathLen]

cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
#include "tensorrt_llm/runtime/cudaEvent.h"
2020
#include "tensorrt_llm/runtime/cudaStream.h"
21-
#include "tensorrt_llm/runtime/eagleBuffers.h"
22-
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
2321
#include "tensorrt_llm/runtime/iTensor.h"
22+
#include "tensorrt_llm/runtime/modelConfig.h"
23+
#include "tensorrt_llm/runtime/worldConfig.h"
2424

2525
#include <memory>
2626
#include <vector>
@@ -72,25 +72,6 @@ class Input
7272

7373
//! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
7474
std::vector<TensorPtr> batchSlots;
75-
//! Filled with slots in request order, [batchSize]
76-
TensorPtr batchSlotsRequestOrder;
77-
78-
//! For Beam Search
79-
//! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
80-
std::vector<SizeType32> generationSteps;
81-
82-
//! For speculative decoding
83-
//! Logits of draft
84-
//! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
85-
std::vector<std::vector<TensorPtr>> predictedDraftLogits;
86-
87-
//! Explicit draft tokens data
88-
std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
89-
std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
90-
91-
//! Eagle data
92-
std::optional<EagleBuffers::EngineOutputs> eagleInputs;
93-
std::optional<EagleBuffers::Inputs> eagleLastInputs;
9475
};
9576

9677
} // namespace decoder_batch

cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,62 @@ std::pair<std::vector<SizeType32>, std::vector<SizeType32>> getActiveSlots(
118118
return {activeSlots, generationSteps};
119119
}
120120

121+
//! @brief Sets inputs for explicit draft tokens.
122+
void setExplicitDraftTokensInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
123+
{
124+
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
125+
126+
TLLM_CHECK(fusedRuntimeBuffers.mExplicitDraftTokensBuffers);
127+
auto const& explicitDraftTokensInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineOutputs;
128+
auto const& explicitDraftTokensLastInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineInputs;
129+
130+
dInput.explicitDraftTokensInputs = tr::DecodingInput::ExplicitDraftTokensInputs();
131+
dInput.explicitDraftTokensInputs->nextDraftTokens = explicitDraftTokensInputs.nextDraftTokens;
132+
dInput.explicitDraftTokensInputs->nextFlatTokens = explicitDraftTokensInputs.nextFlatTokens;
133+
dInput.explicitDraftTokensInputs->nextDraftIndices = explicitDraftTokensInputs.nextDraftIndices;
134+
dInput.explicitDraftTokensInputs->nextDraftProbs = explicitDraftTokensInputs.nextDraftProbs;
135+
dInput.explicitDraftTokensInputs->lastDraftTokens = explicitDraftTokensLastInputs.draftTokens;
136+
dInput.explicitDraftTokensInputs->lastDraftIndices = explicitDraftTokensLastInputs.draftIndices;
137+
dInput.explicitDraftTokensInputs->lastPositionIdsBase = explicitDraftTokensLastInputs.positionIdsBase;
138+
dInput.explicitDraftTokensInputs->masks = explicitDraftTokensInputs.masks;
139+
dInput.explicitDraftTokensInputs->packedPositionIds = explicitDraftTokensInputs.packedPositionIds;
140+
dInput.explicitDraftTokensInputs->bestPathLengths = explicitDraftTokensInputs.bestPathLengths;
141+
dInput.explicitDraftTokensInputs->bestPathIndices = explicitDraftTokensInputs.bestPathIndices;
142+
dInput.explicitDraftTokensInputs->nextGenerationLengths = explicitDraftTokensInputs.nextGenerationLengths;
143+
dInput.explicitDraftTokensInputs->lastGenerationLengths = explicitDraftTokensLastInputs.generationLengths;
144+
dInput.explicitDraftTokensInputs->maxGenLengthDevice = explicitDraftTokensInputs.maxGenToken;
145+
// Slots in request order
146+
dInput.explicitDraftTokensInputs->seqSlots = fusedRuntimeBuffers.seqSlots;
147+
148+
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
149+
}
150+
151+
//! @brief Sets inputs for eagle decoding.
152+
void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
153+
{
154+
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
155+
156+
TLLM_CHECK(fusedRuntimeBuffers.mEagleBuffers);
157+
auto const& eagleInputs = fusedRuntimeBuffers.mEagleBuffers->engineOutputs;
158+
auto const& eagleLastInputs = fusedRuntimeBuffers.mEagleBuffers->engineInputs;
159+
160+
dInput.eagleInputs = tr::DecodingInput::EagleInputs();
161+
dInput.eagleInputs->nextDraftTokens = eagleInputs.nextDraftTokens;
162+
dInput.eagleInputs->nextDraftLens = eagleInputs.nextDraftLens;
163+
dInput.eagleInputs->nextDraftPaths = eagleInputs.nextDraftPaths;
164+
dInput.eagleInputs->lastDraftTokens = eagleLastInputs.draftTokens;
165+
dInput.eagleInputs->lastDraftLens = eagleLastInputs.draftLens;
166+
dInput.eagleInputs->lastDraftPaths = eagleLastInputs.draftPaths;
167+
dInput.eagleInputs->acceptedTokens = eagleInputs.acceptedTokens;
168+
dInput.eagleInputs->acceptedLens = eagleInputs.acceptedLens;
169+
dInput.eagleInputs->acceptedPathIds = eagleInputs.acceptedPaths;
170+
dInput.eagleInputs->chunkedContextNextTokens = eagleInputs.chunkedContextNextTokens;
171+
// Slots in request order
172+
dInput.eagleInputs->seqSlots = fusedRuntimeBuffers.seqSlots;
173+
174+
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
175+
}
176+
121177
} // namespace
122178

123179
std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests,
@@ -131,28 +187,30 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator
131187

132188
auto decodingInput = createDecoderBatchInputs(
133189
activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots);
134-
decodingInput->generationSteps = generationSteps;
190+
191+
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
192+
if (maxBeamWidth > 1)
193+
{
194+
// For Variable-Beam-Width-Search
195+
decoderState.getJointDecodingInput().generationSteps = generationSteps;
196+
}
135197

136198
if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
137199
{
138-
decodingInput->predictedDraftLogits = inputBuffers.predictedDraftLogits;
200+
decoderState.getJointDecodingInput().medusaInputs->medusaLogits = inputBuffers.predictedDraftLogits;
139201
}
140202

141203
if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
142204
{
143205
TLLM_CHECK(fusedRuntimeBuffers);
144206
// requires mCtxGenFusion == true
145-
decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
146-
decodingInput->explicitDraftTokensInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineOutputs;
147-
decodingInput->explicitDraftTokensLastInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineInputs;
207+
setExplicitDraftTokensInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
148208
}
149209
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
150210
{
151211
TLLM_CHECK(fusedRuntimeBuffers);
152212
// requires mCtxGenFusion == true
153-
decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
154-
decodingInput->eagleInputs = fusedRuntimeBuffers->mEagleBuffers->engineOutputs;
155-
decodingInput->eagleLastInputs = fusedRuntimeBuffers->mEagleBuffers->engineInputs;
213+
setEagleInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
156214
}
157215

158216
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "tensorrt_llm/kernels/delayStream.h"
2424
#include "tensorrt_llm/runtime/cudaEvent.h"
2525
#include "tensorrt_llm/runtime/cudaStream.h"
26+
#include "tensorrt_llm/runtime/decoderState.h"
2627
#include "tensorrt_llm/runtime/decodingInput.h"
2728
#include "tensorrt_llm/runtime/decodingOutput.h"
2829
#include "tensorrt_llm/runtime/gptDecoder.h"
@@ -273,10 +274,7 @@ void initBindings(pybind11::module_& m)
273274
.def(py::init<std::vector<tr::ITensor::SharedConstPtr>>(), py::arg("logits"))
274275
.def_readwrite("logits", &tr::decoder_batch::Input::logits)
275276
.def_readwrite("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps)
276-
.def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots)
277-
.def_readwrite("batch_slots_request_order", &tr::decoder_batch::Input::batchSlotsRequestOrder)
278-
.def_readwrite("generation_steps", &tr::decoder_batch::Input::generationSteps)
279-
.def_readwrite("predicted_draft_logits", &tr::decoder_batch::Input::predictedDraftLogits);
277+
.def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots);
280278

281279
py::class_<tr::LookaheadDecodingBuffers>(m, "LookaheadDecodingBuffers")
282280
.def(py::init<tr::SizeType32, tr::SizeType32, tr::BufferManager const&>(), py::arg("max_num_sequences"),
@@ -382,7 +380,9 @@ void initBindings(pybind11::module_& m)
382380
py::arg("batch_idx"))
383381
.def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens,
384382
py::arg("batch_idx"), py::arg("num_tokens"))
385-
.def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode);
383+
.def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode)
384+
.def_property("generation_steps", &tr::decoder::DecoderState::getGenerationSteps,
385+
&tr::decoder::DecoderState::setGenerationSteps);
386386

387387
py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched")
388388
.def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream"))

cpp/tensorrt_llm/runtime/decoderState.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,16 @@ TensorPtr DecoderState::getCacheIndirectionOutput() const
644644
return mJointDecodingOutput->cacheIndirection;
645645
}
646646

647+
std::optional<std::vector<SizeType32>> const& DecoderState::getGenerationSteps() const
648+
{
649+
return mJointDecodingInput->generationSteps;
650+
}
651+
652+
void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationSteps)
653+
{
654+
mJointDecodingInput->generationSteps = generationSteps;
655+
}
656+
647657
DecodingInput& DecoderState::getJointDecodingInput() const
648658
{
649659
return *mJointDecodingInput;

cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -100,78 +100,18 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
100100

101101
namespace
102102
{
103-
//! @brief Sets inputs for explicit draft tokens.
104-
void setExplicitDraftTokensInputs(DecodingInput& dInput, decoder_batch::Input const& input)
105-
{
106-
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
107-
108-
dInput.explicitDraftTokensInputs = DecodingInput::ExplicitDraftTokensInputs();
109-
TLLM_CHECK(input.explicitDraftTokensInputs.has_value());
110-
TLLM_CHECK(input.explicitDraftTokensLastInputs.has_value());
111-
112-
dInput.explicitDraftTokensInputs->nextDraftTokens = input.explicitDraftTokensInputs->nextDraftTokens;
113-
dInput.explicitDraftTokensInputs->nextFlatTokens = input.explicitDraftTokensInputs->nextFlatTokens;
114-
dInput.explicitDraftTokensInputs->nextDraftIndices = input.explicitDraftTokensInputs->nextDraftIndices;
115-
dInput.explicitDraftTokensInputs->nextDraftProbs = input.explicitDraftTokensInputs->nextDraftProbs;
116-
dInput.explicitDraftTokensInputs->lastDraftTokens = input.explicitDraftTokensLastInputs->draftTokens;
117-
dInput.explicitDraftTokensInputs->lastDraftIndices = input.explicitDraftTokensLastInputs->draftIndices;
118-
dInput.explicitDraftTokensInputs->lastPositionIdsBase = input.explicitDraftTokensLastInputs->positionIdsBase;
119-
dInput.explicitDraftTokensInputs->masks = input.explicitDraftTokensInputs->masks;
120-
dInput.explicitDraftTokensInputs->packedPositionIds = input.explicitDraftTokensInputs->packedPositionIds;
121-
dInput.explicitDraftTokensInputs->bestPathLengths = input.explicitDraftTokensInputs->bestPathLengths;
122-
dInput.explicitDraftTokensInputs->bestPathIndices = input.explicitDraftTokensInputs->bestPathIndices;
123-
dInput.explicitDraftTokensInputs->nextGenerationLengths = input.explicitDraftTokensInputs->nextGenerationLengths;
124-
dInput.explicitDraftTokensInputs->lastGenerationLengths = input.explicitDraftTokensLastInputs->generationLengths;
125-
dInput.explicitDraftTokensInputs->maxGenLengthDevice = input.explicitDraftTokensInputs->maxGenToken;
126-
dInput.explicitDraftTokensInputs->seqSlots = input.batchSlotsRequestOrder;
127-
128-
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
129-
}
130-
131-
//! @brief Sets inputs for eagle decoding.
132-
void setEagleInputs(DecodingInput& dInput, decoder_batch::Input const& input)
133-
{
134-
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
135-
136-
TLLM_CHECK(input.eagleInputs.has_value());
137-
TLLM_CHECK(input.eagleLastInputs.has_value());
138-
139-
dInput.eagleInputs = DecodingInput::EagleInputs(input.eagleInputs->nextDraftTokens,
140-
input.eagleInputs->nextDraftLens, input.eagleInputs->nextDraftPaths, input.eagleLastInputs->draftTokens,
141-
input.eagleLastInputs->draftLens, input.eagleLastInputs->draftPaths, input.eagleInputs->acceptedTokens,
142-
input.eagleInputs->acceptedLens, input.eagleInputs->acceptedPaths, input.eagleInputs->chunkedContextNextTokens,
143-
input.batchSlotsRequestOrder);
144-
145-
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
146-
}
147-
148103
//! @brief Prepare Input and Output for decoder step.
149104
// TODO: produce new input and output objects
150105
void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, decoder_batch::Input const& input,
151106
BufferManager const& bufferManager)
152107
{
153108
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
154109

155-
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
156110
auto const speculativeDecodingMode = decoderState.getSpeculativeDecodingMode();
157111

158112
auto& dInput = decoderState.getJointDecodingInput();
159113
auto& dOutput = decoderState.getJointDecodingOutput();
160114

161-
if (maxBeamWidth > 1)
162-
{
163-
dInput.generationSteps = input.generationSteps; // For Variable-Beam-Width-Search
164-
}
165-
166-
if (speculativeDecodingMode.isExplicitDraftTokens())
167-
{
168-
setExplicitDraftTokensInputs(dInput, input);
169-
}
170-
else if (speculativeDecodingMode.isEagle())
171-
{
172-
setEagleInputs(dInput, input);
173-
}
174-
175115
dInput.batchSlots = input.batchSlots.at(step);
176116
dInput.batchSize = static_cast<SizeType32>(dInput.batchSlots->getSize());
177117
dInput.logitsVec = input.logits.at(step);
@@ -186,11 +126,6 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step,
186126

187127
dInput.finishReasons = finishedStepsInput;
188128

189-
if (speculativeDecodingMode.isMedusa())
190-
{
191-
dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
192-
}
193-
194129
if (speculativeDecodingMode.isDraftTokensExternal())
195130
{
196131
dInput.externalDraftTokensInputs->step = step;

tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __call__(
131131
max_num_sequences=max_num_sequences,
132132
batch_slots=decoder_input_buffers.forward_batch_slots,
133133
)
134-
decoding_input.generation_steps = generation_steps
134+
decoder_state.generation_steps = generation_steps
135135

136136
# TODO: Handle speculative decoding modes.
137137
# fused_runtime_buffers is not created in the pytorch framework.

0 commit comments

Comments (0)