
Commit dc05299

Funatiq authored and dominicshanshan committed
refactor: manage cache indirection in decoder state (NVIDIA#5315)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 2b9b8fd commit dc05299

22 files changed: +211 −268 lines changed

cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h

Lines changed: 12 additions & 13 deletions
@@ -107,32 +107,31 @@ class DecoderBuffers
     using SizeType32 = runtime::SizeType32;
     using TensorPtr = runtime::ITensor::SharedPtr;

-    TensorPtr cacheIndirectionInput;
-    TensorPtr cacheIndirectionOutput;
-
     DraftBuffers draftBuffers;

-    DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
-        SizeType32 maxTokensPerStep, runtime::BufferManager const& manager, runtime::ModelConfig const& modelConfig,
-        runtime::WorldConfig const& worldConfig);
+    DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager,
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
 };

 class DecoderStepAsyncSend
 {
 public:
     using SizeType32 = runtime::SizeType32;
-    using BufferPtr = runtime::IBuffer::SharedPtr;
+    using TensorPtr = runtime::ITensor::SharedPtr;

-    DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
-        bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int peer);
+    DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
+        TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa,
+        mpi::MpiComm const& commSession, int peer);

     ~DecoderStepAsyncSend();

-    static void recv(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
-        bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int peer);
+    static void recv(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
+        TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa,
+        mpi::MpiComm const& commSession, int peer);

-    static void bcast(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
-        bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa, mpi::MpiComm const& commSession, int root);
+    static void bcast(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
+        TensorPtr const& cacheIndirectionOutput, bool returnLogProbs, SizeType32 maxBeamWidth, bool useMedusa,
+        mpi::MpiComm const& commSession, int root);

 private:
     std::unique_ptr<mpi::MpiRequest> mRequest1;

cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h

Lines changed: 6 additions & 7 deletions
@@ -47,16 +47,15 @@ class MakeDecodingBatchInputOutput : Algorithm

     MakeDecodingBatchInputOutput() = default;

-    std::tuple<std::unique_ptr<runtime::decoder_batch::Input>, std::unique_ptr<runtime::decoder_batch::Output>>
-    operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
-        DecoderBuffers& decoderBuffers, DecoderInputBuffers const& inputBuffers,
-        runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig,
-        SizeType32 maxNumSequences, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
+    std::unique_ptr<runtime::decoder_batch::Input> operator()(RequestVector const& contextRequests,
+        RequestVector const& generationRequests, DecoderBuffers& decoderBuffers,
+        DecoderInputBuffers const& inputBuffers, runtime::decoder::DecoderState& decoderState,
+        runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences,
+        OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;

     [[nodiscard]] static std::unique_ptr<runtime::decoder_batch::Input> createDecoderBatchInputs(
         std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
-        std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots,
-        TensorPtr const& cacheIndirectionInput);
+        std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots);
 };

 } // namespace tensorrt_llm::batch_manager

cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h

Lines changed: 5 additions & 6 deletions
@@ -279,10 +279,9 @@ class RuntimeBuffers

     std::tuple<SizeType32, TensorMap const&, TensorMap&> prepareStep(RequestVector const& contextRequests,
         RequestVector const& genRequests, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
-        DecoderBuffers& decoderBuffers, runtime::decoder::DecoderState const& decoderState,
-        kv_cache_manager::BaseKVCacheManager* kvCacheManager, kv_cache_manager::BaseKVCacheManager* crossKvCacheManager,
-        rnn_state_manager::RnnStateManager* rnnStateManager, PeftTable const& peftTable,
-        runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
+        runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManager,
+        kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, rnn_state_manager::RnnStateManager* rnnStateManager,
+        PeftTable const& peftTable, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits, bool trtOverlap,
         OptionalRef<runtime::ITensor const> newOutputTokens = std::nullopt);

@@ -314,8 +313,8 @@ class RuntimeBuffers
         runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits);

     void setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth,
-        SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers,
-        runtime::decoder::DecoderState const& decoderState, kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
+        SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState,
+        kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
         kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr,
         rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable,
         runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,

cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h

Lines changed: 0 additions & 4 deletions
@@ -122,10 +122,6 @@ class TransformerBuffers
     void copyPositionIds(runtime::TllmRuntime const& runtime, std::vector<SizeType32> const& positionIdsHost,
         bool isChatGlm, TensorPtr const& decoderPositionIds);

-    void resetCacheIndirection(RequestVector const& contextRequests, SizeType32 maxBeamWidth,
-        SizeType32 maxAttentionWindow, TensorPtr const& decoderCacheIndirectionInput,
-        TensorPtr const& decoderCacheIndirectionOutput, runtime::BufferManager const& manager);
-
     void copyKvBlockOffsets(RequestVector const& contextRequests, RequestVector const& genRequests,
         kv_cache_manager::BaseKVCacheManager const* kvCacheManager,
         kv_cache_manager::BaseKVCacheManager const* crossKvCacheManager, runtime::BufferManager const& manager);

cpp/include/tensorrt_llm/runtime/decoderState.h

Lines changed: 11 additions & 0 deletions
@@ -54,10 +54,15 @@ class DecoderState
     void allocateSpeculativeDecodingBuffers(
         SpeculativeDecodingMode speculativeDecodingMode, nvinfer1::DataType dtype, BufferManager const& bufferManager);

+    //! @brief Setup buffers for the decoder excluding speculative decoding.
     void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
         SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, ModelConfig const& modelConfig,
         WorldConfig const& worldConfig, BufferManager const& bufferManager);

+    //! @brief Setup buffers for the cache indirection.
+    //! @details This is used for beam search on pipeline parallel ranks without a decoder.
+    void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow);
+
     //! @brief Setup buffers for speculative decoding.
     void setupSpeculativeDecoding(SpeculativeDecodingMode const& speculativeDecodingMode,
         SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig,

@@ -174,6 +179,12 @@ class DecoderState
     //! @brief Workspace for beam search in streaming mode.
     [[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const;

+    //! @brief Cache indirection input for beam search.
+    [[nodiscard]] TensorPtr getCacheIndirectionInput() const;
+
+    //! @brief Cache indirection output for beam search.
+    [[nodiscard]] TensorPtr getCacheIndirectionOutput() const;
+
     //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
     [[nodiscard]] DecodingInput& getJointDecodingInput() const;
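With these accessors, the cache indirection tensors are owned by DecoderState rather than DecoderBuffers. A minimal usage sketch, assuming a decoderState object and the max* sizes exist in the calling code (the names below are illustrative and not part of this commit):

    // Sketch only: not code from the commit; all variables are assumed context.
    // On pipeline-parallel ranks without a decoder, only the cache indirection needs setup.
    decoderState.setupCacheIndirection(maxBatchSize, maxBeamWidth, maxAttentionWindow);

    // Consumers read the tensors through the accessors instead of DecoderBuffers members.
    auto cacheIndirectionInput = decoderState.getCacheIndirectionInput();
    auto cacheIndirectionOutput = decoderState.getCacheIndirectionOutput();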

cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h

Lines changed: 3 additions & 6 deletions
@@ -52,10 +52,8 @@ class GptDecoderBatched : public IGptDecoderBatched

     void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;

-    CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Output& output,
-        decoder_batch::Input const& input) override;
-    void forward(decoder::DecoderState const& decoderState, decoder_batch::Output& output,
-        decoder_batch::Input const& input) override;
+    CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
+    void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;

     //! @brief Gather final beam search results for request `batchSlot`.
     //! Result will only be available after event returned.

@@ -79,8 +77,7 @@ class GptDecoderBatched : public IGptDecoderBatched

 private:
     //! @brief Calls decoders for tokens per engine step
-    void forwardDispatch(
-        decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input);
+    void forwardDispatch(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);

 private:
     CudaStreamPtr mRuntimeStream;

cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h

Lines changed: 2 additions & 19 deletions
@@ -76,8 +76,6 @@ class Input
     TensorPtr batchSlotsRequestOrder;

     //! For Beam Search
-    //! Indices into KV cache of different rays within one beam, [maxBatchSize, maxBeamWidth, maxSeqLen], on gpu
-    TensorPtr cacheIndirection;
     //! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
     std::vector<SizeType32> generationSteps;

@@ -95,17 +93,6 @@ class Input
     std::optional<EagleBuffers::Inputs> eagleLastInputs;
 };

-class Output
-{
-public:
-    using TensorPtr = std::shared_ptr<ITensor>;
-
-    Output() = default;
-
-    //! parameters for beam search, [batchSize, maxBeamWidth, maxSeqLen], on gpu
-    TensorPtr cacheIndirection;
-};
-
 } // namespace decoder_batch

 //! GPT decoder class with support for in-flight batching

@@ -126,14 +113,10 @@ class IGptDecoderBatched
     virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0;

     //! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
-    virtual CudaEvent forwardAsync(
-        decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input)
-        = 0;
+    virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;

     //! @brief Run one step for all requests and wait for completion on the host.
-    virtual void forward(
-        decoder::DecoderState const& decoderState, decoder_batch::Output& output, decoder_batch::Input const& input)
-        = 0;
+    virtual void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;

     //! @brief Gather final beam search results for request `batchIdx`.
     //! Result will only be available after event returned
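With decoder_batch::Output removed, the forward entry points take only the input; the cache indirection is reached through the DecoderState that is already passed in. A rough sketch of the simplified call site, assuming decoder, decoderState and input objects that are not shown in this diff:

    // Sketch only: decoder, decoderState and input are assumed to exist in the caller.
    auto event = decoder.forwardAsync(decoderState, input); // no decoder_batch::Output argument anymore
    // ... wait on `event` before reading results, or use the blocking variant:
    // decoder.forward(decoderState, input);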

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 6 additions & 0 deletions
@@ -306,6 +306,12 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
     parentIds->reshape(outputIdsShape);
     manager.setZero(*parentIds);

+    auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1);
+    manager.setZero(*cacheIndirectionInput);
+
+    auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1);
+    manager.setZero(*cacheIndirectionOutput);
+
     auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1);
     beamHypotheses.init(manager, endId);
 }

cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp

Lines changed: 24 additions & 32 deletions
@@ -89,15 +89,9 @@ void DecoderOutputBuffers::disableLookaheadDecoding(SizeType32 maxNumSequences)
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
-    SizeType32 maxTokensPerStep, BufferManager const& manager, ModelConfig const& modelConfig,
-    WorldConfig const& worldConfig)
+DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, BufferManager const& manager,
+    ModelConfig const& modelConfig, WorldConfig const& worldConfig)
 {
-    cacheIndirectionInput = manager.gpu(
-        ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32);
-    cacheIndirectionOutput = manager.gpu(
-        ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32);
-
     if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind()
         || modelConfig.getSpeculativeDecodingMode().hasDraftLogits()
         || modelConfig.getSpeculativeDecodingMode().predictsDraftTokens())

@@ -147,8 +141,8 @@ void DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerSte
 }

 DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOutputBuffers,
-    DecoderBuffers const& decoderBuffers, bool const returnLogProbs, SizeType32 const maxBeamWidth,
-    bool const useMedusa, mpi::MpiComm const& commSession, int peer)
+    DraftBuffers const& draftBuffers, TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs,
+    SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession, int peer)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     TLLM_LOG_DEBUG("start send outputs of DecoderBuffers to rank %d", peer);

@@ -165,24 +159,24 @@ DecoderStepAsyncSend::DecoderStepAsyncSend(DecoderOutputBuffers const& decoderOu
     mRequest5 = returnLogProbs
         ? commSession.sendAsync(*decoderOutputBuffers.logProbsHost, peer, mpi::MpiTag::kDecoderStepLogProbsHost)
         : nullptr;
-    mRequest6 = maxBeamWidth > 1 ? commSession.sendAsync(
-                    *decoderBuffers.cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput)
-                                 : nullptr;
-    mRequest7 = useMedusa ? commSession.sendAsync(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, peer,
+    mRequest6 = maxBeamWidth > 1
+        ? commSession.sendAsync(*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput)
+        : nullptr;
+    mRequest7 = useMedusa ? commSession.sendAsync(*draftBuffers.acceptedLengthsCumSumDevice, peer,
                 mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice)
                           : nullptr;
-    mRequest8 = useMedusa ? commSession.sendAsync(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, peer,
-                mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice)
+    mRequest8 = useMedusa ? commSession.sendAsync(
+                    *draftBuffers.acceptedPackedPathsDevice, peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice)
                           : nullptr;
     mRequest9 = commSession.sendAsync(
         *decoderOutputBuffers.finishReasonsHost, peer, mpi::MpiTag::kDecoderStepFinishReasonsHost);

     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
-    bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession,
-    int const peer)
+void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
+    TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth,
+    bool const useMedusa, mpi::MpiComm const& commSession, int const peer)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     TLLM_LOG_DEBUG("start recv outputs of DecoderBuffers from rank %d", peer);

@@ -197,14 +191,14 @@ void DecoderStepAsyncSend::recv(DecoderOutputBuffers const& decoderOutputBuffers
     }
     if (maxBeamWidth > 1)
     {
-        commSession.recv(*decoderBuffers.cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput);
+        commSession.recv(*cacheIndirectionOutput, peer, mpi::MpiTag::kDecoderStepCacheIndirectionOutput);
     }
     if (useMedusa)
     {
-        commSession.recv(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, peer,
-            mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice);
-        commSession.recv(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, peer,
-            mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice);
+        commSession.recv(
+            *draftBuffers.acceptedLengthsCumSumDevice, peer, mpi::MpiTag::kDecoderStepAcceptedLengthsCumSumDevice);
+        commSession.recv(
+            *draftBuffers.acceptedPackedPathsDevice, peer, mpi::MpiTag::kDecoderStepAcceptedPackedPathsDevice);
     }
     commSession.recv(*decoderOutputBuffers.finishReasonsHost, peer, mpi::MpiTag::kDecoderStepFinishReasonsHost);

@@ -235,9 +229,9 @@ DecoderStepAsyncSend::~DecoderStepAsyncSend()
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffers, DecoderBuffers const& decoderBuffers,
-    bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, mpi::MpiComm const& commSession,
-    int const root)
+void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffers, DraftBuffers const& draftBuffers,
+    TensorPtr const& cacheIndirectionOutput, bool const returnLogProbs, SizeType32 const maxBeamWidth,
+    bool const useMedusa, mpi::MpiComm const& commSession, int const root)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     TLLM_LOG_DEBUG("start bcast outputs of DecoderBuffers from rank %d", root);

@@ -247,11 +241,9 @@ void DecoderStepAsyncSend::bcast(DecoderOutputBuffers const& decoderOutputBuffer
     auto request3 = commSession.bcastAsync(*decoderOutputBuffers.sequenceLengthsHost, root);
     auto request4 = returnLogProbs ? commSession.bcastAsync(*decoderOutputBuffers.cumLogProbsHost, root) : nullptr;
     auto request5 = returnLogProbs ? commSession.bcastAsync(*decoderOutputBuffers.logProbsHost, root) : nullptr;
-    auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync(*decoderBuffers.cacheIndirectionOutput, root) : nullptr;
-    auto request7
-        = useMedusa ? commSession.bcastAsync(*decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr;
-    auto request8
-        = useMedusa ? commSession.bcastAsync(*decoderBuffers.draftBuffers.acceptedPackedPathsDevice, root) : nullptr;
+    auto request6 = maxBeamWidth > 1 ? commSession.bcastAsync(*cacheIndirectionOutput, root) : nullptr;
+    auto request7 = useMedusa ? commSession.bcastAsync(*draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr;
+    auto request8 = useMedusa ? commSession.bcastAsync(*draftBuffers.acceptedPackedPathsDevice, root) : nullptr;
     auto request9 = commSession.bcastAsync(*decoderOutputBuffers.finishReasonsHost, root);

     request1->wait();
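Call sites of DecoderStepAsyncSend now pass the draft buffers and the cache indirection output explicitly instead of a whole DecoderBuffers object. A rough sketch of a matching call site, assuming surrounding variables that are not shown in this diff:

    // Sketch only: decoderOutputBuffers, decoderBuffers, decoderState, the flags, commSession and peer are assumed context.
    auto asyncSend = std::make_unique<DecoderStepAsyncSend>(decoderOutputBuffers, decoderBuffers.draftBuffers,
        decoderState.getCacheIndirectionOutput(), returnLogProbs, maxBeamWidth, useMedusa, commSession, peer);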
