@@ -118,6 +118,62 @@ std::pair<std::vector<SizeType32>, std::vector<SizeType32>> getActiveSlots(
118118 return {activeSlots, generationSteps};
119119}
120120
121+ // ! @brief Sets inputs for explicit draft tokens.
122+ void setExplicitDraftTokensInputs (tr::DecodingInput& dInput, RuntimeBuffers const & fusedRuntimeBuffers)
123+ {
124+ TLLM_LOG_TRACE (" %s start" , __PRETTY_FUNCTION__);
125+
126+ TLLM_CHECK (fusedRuntimeBuffers.mExplicitDraftTokensBuffers );
127+ auto const & explicitDraftTokensInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers ->engineOutputs ;
128+ auto const & explicitDraftTokensLastInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers ->engineInputs ;
129+
130+ dInput.explicitDraftTokensInputs = tr::DecodingInput::ExplicitDraftTokensInputs ();
131+ dInput.explicitDraftTokensInputs ->nextDraftTokens = explicitDraftTokensInputs.nextDraftTokens ;
132+ dInput.explicitDraftTokensInputs ->nextFlatTokens = explicitDraftTokensInputs.nextFlatTokens ;
133+ dInput.explicitDraftTokensInputs ->nextDraftIndices = explicitDraftTokensInputs.nextDraftIndices ;
134+ dInput.explicitDraftTokensInputs ->nextDraftProbs = explicitDraftTokensInputs.nextDraftProbs ;
135+ dInput.explicitDraftTokensInputs ->lastDraftTokens = explicitDraftTokensLastInputs.draftTokens ;
136+ dInput.explicitDraftTokensInputs ->lastDraftIndices = explicitDraftTokensLastInputs.draftIndices ;
137+ dInput.explicitDraftTokensInputs ->lastPositionIdsBase = explicitDraftTokensLastInputs.positionIdsBase ;
138+ dInput.explicitDraftTokensInputs ->masks = explicitDraftTokensInputs.masks ;
139+ dInput.explicitDraftTokensInputs ->packedPositionIds = explicitDraftTokensInputs.packedPositionIds ;
140+ dInput.explicitDraftTokensInputs ->bestPathLengths = explicitDraftTokensInputs.bestPathLengths ;
141+ dInput.explicitDraftTokensInputs ->bestPathIndices = explicitDraftTokensInputs.bestPathIndices ;
142+ dInput.explicitDraftTokensInputs ->nextGenerationLengths = explicitDraftTokensInputs.nextGenerationLengths ;
143+ dInput.explicitDraftTokensInputs ->lastGenerationLengths = explicitDraftTokensLastInputs.generationLengths ;
144+ dInput.explicitDraftTokensInputs ->maxGenLengthDevice = explicitDraftTokensInputs.maxGenToken ;
145+ // Slots in request order
146+ dInput.explicitDraftTokensInputs ->seqSlots = fusedRuntimeBuffers.seqSlots ;
147+
148+ TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
149+ }
150+
151+ // ! @brief Sets inputs for eagle decoding.
152+ void setEagleInputs (tr::DecodingInput& dInput, RuntimeBuffers const & fusedRuntimeBuffers)
153+ {
154+ TLLM_LOG_TRACE (" %s start" , __PRETTY_FUNCTION__);
155+
156+ TLLM_CHECK (fusedRuntimeBuffers.mEagleBuffers );
157+ auto const & eagleInputs = fusedRuntimeBuffers.mEagleBuffers ->engineOutputs ;
158+ auto const & eagleLastInputs = fusedRuntimeBuffers.mEagleBuffers ->engineInputs ;
159+
160+ dInput.eagleInputs = tr::DecodingInput::EagleInputs ();
161+ dInput.eagleInputs ->nextDraftTokens = eagleInputs.nextDraftTokens ;
162+ dInput.eagleInputs ->nextDraftLens = eagleInputs.nextDraftLens ;
163+ dInput.eagleInputs ->nextDraftPaths = eagleInputs.nextDraftPaths ;
164+ dInput.eagleInputs ->lastDraftTokens = eagleLastInputs.draftTokens ;
165+ dInput.eagleInputs ->lastDraftLens = eagleLastInputs.draftLens ;
166+ dInput.eagleInputs ->lastDraftPaths = eagleLastInputs.draftPaths ;
167+ dInput.eagleInputs ->acceptedTokens = eagleInputs.acceptedTokens ;
168+ dInput.eagleInputs ->acceptedLens = eagleInputs.acceptedLens ;
169+ dInput.eagleInputs ->acceptedPathIds = eagleInputs.acceptedPaths ;
170+ dInput.eagleInputs ->chunkedContextNextTokens = eagleInputs.chunkedContextNextTokens ;
171+ // Slots in request order
172+ dInput.eagleInputs ->seqSlots = fusedRuntimeBuffers.seqSlots ;
173+
174+ TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
175+ }
176+
121177} // namespace
122178
123179std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator ()(RequestVector const & contextRequests,
@@ -131,28 +187,30 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator
131187
132188 auto decodingInput = createDecoderBatchInputs (
133189 activeSlots, decoderState, inputBuffers.logits , maxNumSequences, inputBuffers.forwardBatchSlots );
134- decodingInput->generationSteps = generationSteps;
190+
191+ auto const maxBeamWidth = decoderState.getMaxBeamWidth ();
192+ if (maxBeamWidth > 1 )
193+ {
194+ // For Variable-Beam-Width-Search
195+ decoderState.getJointDecodingInput ().generationSteps = generationSteps;
196+ }
135197
136198 if (modelConfig.getSpeculativeDecodingMode ().hasDraftLogits ())
137199 {
138- decodingInput-> predictedDraftLogits = inputBuffers.predictedDraftLogits ;
200+ decoderState. getJointDecodingInput (). medusaInputs -> medusaLogits = inputBuffers.predictedDraftLogits ;
139201 }
140202
141203 if (modelConfig.getSpeculativeDecodingMode ().isExplicitDraftTokens ())
142204 {
143205 TLLM_CHECK (fusedRuntimeBuffers);
144206 // requires mCtxGenFusion == true
145- decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots ;
146- decodingInput->explicitDraftTokensInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers ->engineOutputs ;
147- decodingInput->explicitDraftTokensLastInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers ->engineInputs ;
207+ setExplicitDraftTokensInputs (decoderState.getJointDecodingInput (), *fusedRuntimeBuffers);
148208 }
149209 else if (modelConfig.getSpeculativeDecodingMode ().isEagle ())
150210 {
151211 TLLM_CHECK (fusedRuntimeBuffers);
152212 // requires mCtxGenFusion == true
153- decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots ;
154- decodingInput->eagleInputs = fusedRuntimeBuffers->mEagleBuffers ->engineOutputs ;
155- decodingInput->eagleLastInputs = fusedRuntimeBuffers->mEagleBuffers ->engineInputs ;
213+ setEagleInputs (decoderState.getJointDecodingInput (), *fusedRuntimeBuffers);
156214 }
157215
158216 TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
0 commit comments