@@ -89,15 +89,9 @@ void DecoderOutputBuffers::disableLookaheadDecoding(SizeType32 maxNumSequences)
89
89
TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
90
90
}
91
91
92
- DecoderBuffers::DecoderBuffers (SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
93
- SizeType32 maxTokensPerStep, BufferManager const & manager, ModelConfig const & modelConfig,
94
- WorldConfig const & worldConfig)
92
+ DecoderBuffers::DecoderBuffers (SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, BufferManager const & manager,
93
+ ModelConfig const & modelConfig, WorldConfig const & worldConfig)
95
94
{
96
- cacheIndirectionInput = manager.gpu (
97
- ITensor::makeShape ({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32 );
98
- cacheIndirectionOutput = manager.gpu (
99
- ITensor::makeShape ({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32 );
100
-
101
95
if (modelConfig.getSpeculativeDecodingMode ().needsKVCacheRewind ()
102
96
|| modelConfig.getSpeculativeDecodingMode ().hasDraftLogits ()
103
97
|| modelConfig.getSpeculativeDecodingMode ().predictsDraftTokens ())
@@ -147,8 +141,8 @@ void DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerSte
147
141
}
148
142
149
143
DecoderStepAsyncSend::DecoderStepAsyncSend (DecoderOutputBuffers const & decoderOutputBuffers,
150
- DecoderBuffers const & decoderBuffers, bool const returnLogProbs, SizeType32 const maxBeamWidth ,
151
- bool const useMedusa, mpi::MpiComm const & commSession, int peer)
144
+ DraftBuffers const & draftBuffers, TensorPtr const & cacheIndirectionOutput, bool const returnLogProbs ,
145
+ SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const & commSession, int peer)
152
146
{
153
147
TLLM_LOG_TRACE (" %s start" , __PRETTY_FUNCTION__);
154
148
TLLM_LOG_DEBUG (" start send outputs of DecoderBuffers to rank %d" , peer);
@@ -165,24 +159,24 @@ DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOu
165
159
mRequest5 = returnLogProbs
166
160
? commSession.sendAsync (*decoderOutputBuffers.logProbsHost , peer, mpi::MpiTag::kDecoderStepLogProbsHost )
167
161
: nullptr ;
168
- mRequest6 = maxBeamWidth > 1 ? commSession. sendAsync (
169
- *decoderBuffers. cacheIndirectionOutput , peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput )
170
- : nullptr ;
171
- mRequest7 = useMedusa ? commSession.sendAsync (*decoderBuffers. draftBuffers .acceptedLengthsCumSumDevice , peer,
162
+ mRequest6 = maxBeamWidth > 1
163
+ ? commSession. sendAsync (* cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput )
164
+ : nullptr ;
165
+ mRequest7 = useMedusa ? commSession.sendAsync (*draftBuffers.acceptedLengthsCumSumDevice , peer,
172
166
mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice )
173
167
: nullptr ;
174
- mRequest8 = useMedusa ? commSession.sendAsync (*decoderBuffers. draftBuffers . acceptedPackedPathsDevice , peer,
175
- mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice )
168
+ mRequest8 = useMedusa ? commSession.sendAsync (
169
+ *draftBuffers. acceptedPackedPathsDevice , peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice )
176
170
: nullptr ;
177
171
mRequest9 = commSession.sendAsync (
178
172
*decoderOutputBuffers.finishReasonsHost , peer, mpi::MpiTag::kDecoderStepFinishReasonsHost );
179
173
180
174
TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
181
175
}
182
176
183
- void DecoderStepAsyncSend::recv (DecoderOutputBuffers const & decoderOutputBuffers, DecoderBuffers const & decoderBuffers ,
184
- bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const & commSession ,
185
- int const peer)
177
+ void DecoderStepAsyncSend::recv (DecoderOutputBuffers const & decoderOutputBuffers, DraftBuffers const & draftBuffers ,
178
+ TensorPtr const & cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth ,
179
+ bool const useMedusa, mpi::MpiComm const & commSession, int const peer)
186
180
{
187
181
TLLM_LOG_TRACE (" %s start" , __PRETTY_FUNCTION__);
188
182
TLLM_LOG_DEBUG (" start recv outputs of DecoderBuffers from rank %d" , peer);
@@ -197,14 +191,14 @@ void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers
197
191
}
198
192
if (maxBeamWidth > 1 )
199
193
{
200
- commSession.recv (*decoderBuffers. cacheIndirectionOutput , peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput );
194
+ commSession.recv (*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput );
201
195
}
202
196
if (useMedusa)
203
197
{
204
- commSession.recv (*decoderBuffers. draftBuffers . acceptedLengthsCumSumDevice , peer,
205
- mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice );
206
- commSession.recv (*decoderBuffers. draftBuffers . acceptedPackedPathsDevice , peer,
207
- mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice );
198
+ commSession.recv (
199
+ *draftBuffers. acceptedLengthsCumSumDevice , peer, mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice );
200
+ commSession.recv (
201
+ *draftBuffers. acceptedPackedPathsDevice , peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice );
208
202
}
209
203
commSession.recv (*decoderOutputBuffers.finishReasonsHost , peer, mpi::MpiTag::kDecoderStepFinishReasonsHost );
210
204
@@ -235,9 +229,9 @@ DecoderStepAsyncSend::~DecoderStepAsyncSend()
235
229
TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
236
230
}
237
231
238
- void DecoderStepAsyncSend::bcast (DecoderOutputBuffers const & decoderOutputBuffers, DecoderBuffers const & decoderBuffers ,
239
- bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const & commSession ,
240
- int const root)
232
+ void DecoderStepAsyncSend::bcast (DecoderOutputBuffers const & decoderOutputBuffers, DraftBuffers const & draftBuffers ,
233
+ TensorPtr const & cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth ,
234
+ bool const useMedusa, mpi::MpiComm const & commSession, int const root)
241
235
{
242
236
TLLM_LOG_TRACE (" %s start" , __PRETTY_FUNCTION__);
243
237
TLLM_LOG_DEBUG (" start bcast outputs of DecoderBuffers from rank %d" , root);
@@ -247,11 +241,9 @@ void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffer
247
241
auto request3 = commSession.bcastAsync (*decoderOutputBuffers.sequenceLengthsHost , root);
248
242
auto request4 = returnLogProbs ? commSession.bcastAsync (*decoderOutputBuffers.cumLogProbsHost , root) : nullptr ;
249
243
auto request5 = returnLogProbs ? commSession.bcastAsync (*decoderOutputBuffers.logProbsHost , root) : nullptr ;
250
- auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync (*decoderBuffers.cacheIndirectionOutput , root) : nullptr ;
251
- auto request7
252
- = useMedusa ? commSession.bcastAsync (*decoderBuffers.draftBuffers .acceptedLengthsCumSumDevice , root) : nullptr ;
253
- auto request8
254
- = useMedusa ? commSession.bcastAsync (*decoderBuffers.draftBuffers .acceptedPackedPathsDevice , root) : nullptr ;
244
+ auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync (*cacheIndirectionOutput, root) : nullptr ;
245
+ auto request7 = useMedusa ? commSession.bcastAsync (*draftBuffers.acceptedLengthsCumSumDevice , root) : nullptr ;
246
+ auto request8 = useMedusa ? commSession.bcastAsync (*draftBuffers.acceptedPackedPathsDevice , root) : nullptr ;
255
247
auto request9 = commSession.bcastAsync (*decoderOutputBuffers.finishReasonsHost , root);
256
248
257
249
request1->wait ();
0 commit comments