-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[TRTLLM-8160][feat] Add draft token tree runtime on CDL #8586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
yweng0828
wants to merge
2
commits into
NVIDIA:main
Choose a base branch
from
yweng0828:yweng/add_draft_token_tree_runtime_on_cdl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,918
−633
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
102 changes: 102 additions & 0 deletions
102
cpp/tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.cu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| /* | ||
| * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. | ||
| * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <cuda_runtime_api.h> | ||
|
|
||
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
|
|
||
| #ifdef ENABLE_FP8 | ||
| #include <cuda_fp8.h> | ||
| #endif | ||
|
|
||
| #include "draftTokenTreeKernels.h" | ||
| #include "tensorrt_llm/common/assert.h" | ||
| #include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" | ||
| #include "tensorrt_llm/common/cudaTypeUtils.cuh" | ||
| #include "tensorrt_llm/common/cudaUtils.h" | ||
|
|
||
| using namespace tensorrt_llm::common; | ||
|
|
||
| namespace tensorrt_llm | ||
| { | ||
| namespace kernels | ||
| { | ||
|
|
||
| __global__ void extractRealDraftTokensKernel(int const curDraftIdx, int const batchSize, int const maxDraftLen, | ||
| int const maxTotalDraftTokens, int const maxTopK, int const numTokensExpandThisLayer, | ||
| int* tokensGatherIdxForDrafterModel, int* topKList, int* draftTokensIndicesCumsum, int64_t* newDraftTokens, | ||
| int64_t* draftTokensBuffer) | ||
| { | ||
| // curDraftIdx: int | ||
| // batchSize: int | ||
| // maxTotalDraftTokens: int | ||
| // maxTopK: int | ||
| // tokensGatherIdxForDrafterModel: int32_t*, indices of the draft tokens that need to be expand this layer | ||
| // shape: [numTokensExpandThisLayer] | ||
| // topKList: int32_t*, top k value for each expandable token | ||
| // shape: [numTokensExpandThisLayer] | ||
| // draftTokensIndicesCumsum: int32_t*, the cumulative sum of the write back indices for each draft layer | ||
| // shape: [maxDraftLen + 1] | ||
| // newDraftTokens: int64_t*, the new draft tokens. We only need to extract this layer's tokens and write back to | ||
| // the draftTokensBuffer shape: [batchSize, maxTotalDraftTokens + 1 if curDraftIdx > 0 else 1, maxTopK] | ||
| // draftTokensBuffer: int64_t*, the buffer to store the real draft tokens | ||
| // shape: [maxBatchSize, maxTotalDraftTokens + 1] | ||
|
|
||
| // Each thread handles one request | ||
| auto const tix = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x); | ||
| auto const isValid{tix < batchSize}; | ||
|
|
||
| if (isValid) | ||
| { | ||
| int newDraftTokensOffset = curDraftIdx == 0 ? 1 : maxTotalDraftTokens + 1; | ||
| auto newDraftTokensStartPtr = newDraftTokens + tix * newDraftTokensOffset * maxTopK; | ||
| auto draftTokensBufferPtr | ||
| = draftTokensBuffer + tix * (maxTotalDraftTokens + 1) + draftTokensIndicesCumsum[curDraftIdx]; | ||
|
|
||
| int cnt = 0; | ||
| for (int i = 0; i < numTokensExpandThisLayer; i++) | ||
| { | ||
| int tokenGatherIdx = tokensGatherIdxForDrafterModel[i]; | ||
| auto newDraftTokenPtr = newDraftTokensStartPtr + tokenGatherIdx * maxTopK; | ||
|
|
||
| int topKValue = topKList[i]; | ||
| for (int j = 0; j < topKValue; j++) | ||
| { | ||
| int64_t newGenDraftToken = newDraftTokenPtr[j]; | ||
| draftTokensBufferPtr[cnt] = newGenDraftToken; | ||
| cnt++; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void invokeExtractRealDraftTokens(ExtractRealDraftTokensParam& params, cudaStream_t const stream) | ||
| { | ||
| int constexpr BLOCK_SIZE = 64; | ||
| int NUM_BLOCKS = divUp(params.batchSize, BLOCK_SIZE); | ||
|
|
||
| extractRealDraftTokensKernel<<<NUM_BLOCKS, BLOCK_SIZE, 0, stream>>>(params.curDraftIdx, params.batchSize, | ||
| params.maxDraftLen, params.maxTotalDraftTokens, params.maxTopK, params.numTokensExpandThisLayer, | ||
| params.tokensGatherIdxForDrafterModel, params.topKList, params.draftTokensIndicesCumsum, params.newDraftTokens, | ||
| params.draftTokensBuffer); | ||
|
|
||
| sync_check_cuda_error(stream); | ||
| } | ||
|
|
||
| } // namespace kernels | ||
| } // namespace tensorrt_llm |
54 changes: 54 additions & 0 deletions
54
cpp/tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /* | ||
| * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. | ||
| * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
|
|
||
| #include "tensorrt_llm/common/assert.h" | ||
| #include "tensorrt_llm/common/cudaUtils.h" | ||
| #include "tensorrt_llm/runtime/common.h" | ||
|
|
||
| namespace tensorrt_llm | ||
| { | ||
| // namespace tensorrt_llm::kernels | ||
| namespace kernels | ||
| { | ||
|
|
||
| //////////////////////////////////////////////////////////////////////////////////////////////////////////// | ||
| // Relaxed acceptance | ||
| struct ExtractRealDraftTokensParam | ||
| { | ||
| int curDraftIdx; | ||
| int batchSize; | ||
| int maxDraftLen; | ||
| int maxTotalDraftTokens; | ||
| int maxTopK; | ||
| int numTokensExpandThisLayer; | ||
| int* tokensGatherIdxForDrafterModel; | ||
| int* topKList; | ||
| int* draftTokensIndicesCumsum; | ||
| int64_t* newDraftTokens; | ||
| int64_t* draftTokensBuffer; | ||
| }; | ||
|
|
||
| void invokeExtractRealDraftTokens(ExtractRealDraftTokensParam& params, cudaStream_t const stream); | ||
|
|
||
| } // namespace kernels | ||
|
|
||
| } // namespace tensorrt_llm |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.