102 changes: 102 additions & 0 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.cu
@@ -0,0 +1,102 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_runtime_api.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#endif

#include "draftTokenTreeKernels.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"

using namespace tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{

__global__ void extractRealDraftTokensKernel(int const curDraftIdx, int const batchSize, int const maxDraftLen,
int const maxTotalDraftTokens, int const maxTopK, int const numTokensExpandThisLayer,
int* tokensGatherIdxForDrafterModel, int* topKList, int* draftTokensIndicesCumsum, int64_t* newDraftTokens,
int64_t* draftTokensBuffer)
{
// curDraftIdx: int
// batchSize: int
// maxDraftLen: int
// maxTotalDraftTokens: int
// maxTopK: int
// numTokensExpandThisLayer: int
// tokensGatherIdxForDrafterModel: int32_t*, indices of the draft tokens that need to be expanded in this layer
// shape: [numTokensExpandThisLayer]
// topKList: int32_t*, top-k value for each expandable token
// shape: [numTokensExpandThisLayer]
// draftTokensIndicesCumsum: int32_t*, the cumulative sum of the write-back indices for each draft layer
// shape: [maxDraftLen + 1]
// newDraftTokens: int64_t*, the new draft tokens. We only need to extract this layer's tokens and write them
// back to the draftTokensBuffer
// shape: [batchSize, maxTotalDraftTokens + 1 if curDraftIdx > 0 else 1, maxTopK]
// draftTokensBuffer: int64_t*, the buffer to store the real draft tokens
// shape: [maxBatchSize, maxTotalDraftTokens + 1]

// Each thread handles one request
auto const tix = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
auto const isValid{tix < batchSize};

if (isValid)
{
int newDraftTokensOffset = curDraftIdx == 0 ? 1 : maxTotalDraftTokens + 1;
auto newDraftTokensStartPtr = newDraftTokens + tix * newDraftTokensOffset * maxTopK;
auto draftTokensBufferPtr
= draftTokensBuffer + tix * (maxTotalDraftTokens + 1) + draftTokensIndicesCumsum[curDraftIdx];

int cnt = 0;
for (int i = 0; i < numTokensExpandThisLayer; i++)
{
int tokenGatherIdx = tokensGatherIdxForDrafterModel[i];
auto newDraftTokenPtr = newDraftTokensStartPtr + tokenGatherIdx * maxTopK;

int topKValue = topKList[i];
for (int j = 0; j < topKValue; j++)
{
int64_t newGenDraftToken = newDraftTokenPtr[j];
draftTokensBufferPtr[cnt] = newGenDraftToken;
cnt++;
}
}
}
}

void invokeExtractRealDraftTokens(ExtractRealDraftTokensParam& params, cudaStream_t const stream)
{
int constexpr BLOCK_SIZE = 64;
int NUM_BLOCKS = divUp(params.batchSize, BLOCK_SIZE);

extractRealDraftTokensKernel<<<NUM_BLOCKS, BLOCK_SIZE, 0, stream>>>(params.curDraftIdx, params.batchSize,
params.maxDraftLen, params.maxTotalDraftTokens, params.maxTopK, params.numTokensExpandThisLayer,
params.tokensGatherIdxForDrafterModel, params.topKList, params.draftTokensIndicesCumsum, params.newDraftTokens,
params.draftTokensBuffer);

sync_check_cuda_error(stream);
}

} // namespace kernels
} // namespace tensorrt_llm
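As a reading aid for the kernel above, here is a minimal host-side PyTorch sketch of the same per-request gather, the sort of reference one might use in a unit test. It is not part of this change; the function name and the plain Python loops are illustrative only, and the shapes follow the comment block at the top of extractRealDraftTokensKernel.

```python
import torch

def extract_real_draft_tokens_ref(
        new_draft_tokens: torch.Tensor,     # [batchSize, 1 or maxTotalDraftTokens + 1, maxTopK], int64
        draft_tokens_buffer: torch.Tensor,  # [maxBatchSize, maxTotalDraftTokens + 1], int64, updated in place
        tokens_gather_idx: torch.Tensor,    # [numTokensExpandThisLayer], int32
        top_k_list: torch.Tensor,           # [numTokensExpandThisLayer], int32
        indices_cumsum: torch.Tensor,       # [maxDraftLen + 1], int32
        cur_draft_idx: int) -> None:
    batch_size = new_draft_tokens.size(0)
    start = int(indices_cumsum[cur_draft_idx])
    for b in range(batch_size):  # the kernel assigns one CUDA thread per request
        cnt = 0
        for gather_idx, top_k in zip(tokens_gather_idx.tolist(), top_k_list.tolist()):
            # Copy the top-k tokens produced for each node expanded at this draft
            # layer into the per-request flat draft-token buffer.
            draft_tokens_buffer[b, start + cnt:start + cnt + top_k] = \
                new_draft_tokens[b, gather_idx, :top_k]
            cnt += top_k
```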
54 changes: 54 additions & 0 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm
{
// namespace tensorrt_llm::kernels
namespace kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Extract real draft tokens
struct ExtractRealDraftTokensParam
{
int curDraftIdx;
int batchSize;
int maxDraftLen;
int maxTotalDraftTokens;
int maxTopK;
int numTokensExpandThisLayer;
int* tokensGatherIdxForDrafterModel;
int* topKList;
int* draftTokensIndicesCumsum;
int64_t* newDraftTokens;
int64_t* draftTokensBuffer;
};

void invokeExtractRealDraftTokens(ExtractRealDraftTokensParam& params, cudaStream_t const stream);

} // namespace kernels

} // namespace tensorrt_llm
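The one field above whose construction is not visible in this diff is draftTokensIndicesCumsum. A hedged way to picture it, assuming each draft layer writes as many tokens as the sum of the top-k values of the nodes it expands:

```python
# Hypothetical illustration only; the drafter-side code that actually builds
# draftTokensIndicesCumsum is not part of this diff.
# Assume a small tree: layer 0 expands 1 node with top-k 2, layer 1 expands
# 2 nodes with top-k 3 and 1, layer 2 expands 3 nodes with top-k 1, 1 and 2.
top_k_per_layer = [[2], [3, 1], [1, 1, 2]]

cumsum = [0]
for layer_top_ks in top_k_per_layer:
    cumsum.append(cumsum[-1] + sum(layer_top_ks))

# cumsum == [0, 2, 6, 10]: draft layer i writes its extracted tokens into
# draftTokensBuffer[:, cumsum[i]:cumsum[i + 1]].
```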
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -100,7 +100,7 @@ add_library(
virtualMemoryAllocator.cpp
weightOnlyQuantGemm.cpp
weightOnlyQuantOp.cpp
mtpOp.cpp
specDecOp.cpp
loraOp.cpp
finegrained_mixed_dtype_gemm_thop.cpp
tinygemm2.cpp
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -546,7 +546,8 @@ class Runner : public RunnerBase
= spec_decoding_tensor_params[1].value().data_ptr<int32_t>();
enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr<int32_t>();
enqueue_params.spec_decoding_is_generation_length_variable = true;
enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
TLLM_CHECK(spec_decoding_tensor_params[1].value().dim() == 2); // [batch_size, max_draft_len + 1]
enqueue_params.spec_decoding_max_generation_length = spec_decoding_tensor_params[1].value().sizes()[1];
}

// Current mlaGeneration will use fmha to do attention, so we don't go into enqueueGeneration
cpp/tensorrt_llm/thop/specDecOp.cpp
@@ -15,10 +15,11 @@
* limitations under the License.
*/

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include "tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"

namespace th = torch;
namespace tl = tensorrt_llm;
Expand Down Expand Up @@ -261,6 +262,78 @@ std::tuple<th::Tensor, th::Tensor> mtp_relaxed_acceptance_op(th::Tensor& reqSlot
return std::make_tuple(acceptedTokens, numAcceptedTokens);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
void extract_real_draft_tokens_op(th::Tensor newDraftTokens, th::Tensor draftTokensBuffer,
th::Tensor tokensGatherIdxForDrafterModel, th::Tensor topKList, th::Tensor draftTokensIndicesCumsum,
int64_t curDraftIdx, int64_t batchSize, int64_t maxDraftLen, int64_t maxTotalDraftTokens, int64_t maxTopK)
{
// Args:
// curDraftIdx: int
// batchSize: int
// maxDraftLen: int
// maxTotalDraftTokens: int
// maxTopK: int
// tokensGatherIdxForDrafterModel: Tensor, int32, indices of the draft tokens that need to be expanded in this layer
// shape: [numTokensExpandThisLayer]
// topKList: Tensor, int32, top-k value for each expandable token
// shape: [numTokensExpandThisLayer]
// draftTokensIndicesCumsum: Tensor, int32, the cumulative sum of the write-back indices for each draft layer
// shape: [maxDraftLen + 1]
// newDraftTokens: Tensor, int64, the new draft tokens. We only need to extract this layer's tokens and write them
// back to the draftTokensBuffer
// shape: [batchSize, maxTotalDraftTokens + 1 if curDraftIdx > 0 else 1, maxTopK]
// draftTokensBuffer: Tensor, int64, the buffer to store the real draft tokens
// shape: [maxBatchSize, maxTotalDraftTokens + 1]

// Check the data types
TLLM_CHECK(tokensGatherIdxForDrafterModel.scalar_type() == torch::kInt32);
TLLM_CHECK(topKList.scalar_type() == torch::kInt32);
TLLM_CHECK(draftTokensIndicesCumsum.scalar_type() == torch::kInt32);
TLLM_CHECK(newDraftTokens.scalar_type() == torch::kInt64);
TLLM_CHECK(draftTokensBuffer.scalar_type() == torch::kInt64);

// Check the shape of 'tokensGatherIdxForDrafterModel' and 'topKList'
auto numTokensExpandThisLayer = tokensGatherIdxForDrafterModel.size(0);
TLLM_CHECK(numTokensExpandThisLayer > 0);
TLLM_CHECK(topKList.size(0) == numTokensExpandThisLayer);

// Check the shape of 'draftTokensIndicesCumsum'
TLLM_CHECK(draftTokensIndicesCumsum.size(0) == maxDraftLen + 1);

// Check the shape of 'newDraftTokens'
TLLM_CHECK(newDraftTokens.size(0) == batchSize);
if (curDraftIdx == 0)
{
TLLM_CHECK(newDraftTokens.size(1) == 1);
TLLM_CHECK(newDraftTokens.size(2) == maxTopK);
}
else
{
TLLM_CHECK(newDraftTokens.size(1) == maxTotalDraftTokens + 1);
TLLM_CHECK(newDraftTokens.size(2) == maxTopK);
}

// Check the shape of 'draftTokensBuffer'
TLLM_CHECK(draftTokensBuffer.size(1) == maxTotalDraftTokens + 1);

auto stream = at::cuda::getCurrentCUDAStream(newDraftTokens.get_device());

// Fill params
tk::ExtractRealDraftTokensParam params;
params.curDraftIdx = curDraftIdx;
params.batchSize = batchSize;
params.maxDraftLen = maxDraftLen;
params.maxTotalDraftTokens = maxTotalDraftTokens;
params.maxTopK = maxTopK;
params.numTokensExpandThisLayer = numTokensExpandThisLayer;
params.tokensGatherIdxForDrafterModel = reinterpret_cast<int32_t*>(tokensGatherIdxForDrafterModel.data_ptr());
params.topKList = reinterpret_cast<int32_t*>(topKList.data_ptr());
params.draftTokensIndicesCumsum = reinterpret_cast<int32_t*>(draftTokensIndicesCumsum.data_ptr());
params.newDraftTokens = reinterpret_cast<int64_t*>(newDraftTokens.data_ptr());
params.draftTokensBuffer = reinterpret_cast<int64_t*>(draftTokensBuffer.data_ptr());

tk::invokeExtractRealDraftTokens(params, stream);
}

} // end namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
@@ -323,3 +396,18 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("mtp_relaxed_acceptance_op", &torch_ext::mtp_relaxed_acceptance_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"extract_real_draft_tokens_op(Tensor newDraftTokens, Tensor draftTokensBuffer, "
"Tensor tokensGatherIdxForDrafterModel, Tensor topKList, Tensor draftTokensIndicesCumsum, "
"int curDraftIdx, int batchSize, int maxDraftLen, int maxTotalDraftTokens, int maxTopK) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("extract_real_draft_tokens_op", &torch_ext::extract_real_draft_tokens_op);
}
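To make the registered schema concrete, here is a hedged Python sketch of driving the op end to end. It assumes the trtllm extension library has been built and loaded so that torch.ops.trtllm is populated; the tensor values are placeholders, and only the shapes and dtypes matter.

```python
import torch

batch_size, max_draft_len, max_total_draft_tokens, max_top_k = 2, 3, 10, 4
cur_draft_idx = 1  # > 0, so newDraftTokens carries max_total_draft_tokens + 1 rows per request

tokens_gather_idx = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
top_k_list = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
indices_cumsum = torch.tensor([0, 4, 7, 9], dtype=torch.int32, device="cuda")  # [max_draft_len + 1]

new_draft_tokens = torch.randint(
    0, 32000, (batch_size, max_total_draft_tokens + 1, max_top_k),
    dtype=torch.int64, device="cuda")
draft_tokens_buffer = torch.zeros(
    (batch_size, max_total_draft_tokens + 1), dtype=torch.int64, device="cuda")

torch.ops.trtllm.extract_real_draft_tokens_op(
    new_draft_tokens, draft_tokens_buffer, tokens_gather_idx, top_k_list,
    indices_cumsum, cur_draft_idx, batch_size, max_draft_len,
    max_total_draft_tokens, max_top_k)

# The three tokens kept at draft layer 1 now sit in draft_tokens_buffer[:, 4:7].
```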
9 changes: 8 additions & 1 deletion tensorrt_llm/_torch/attention_backend/interface.py
@@ -11,6 +11,8 @@

if TYPE_CHECKING:
from ..speculative.utils import SpecDecodingTensor
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager

from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
RotaryScalingType)
@@ -338,10 +340,15 @@ def restore_from_spec_dec(self) -> None:

def update_spec_dec_param(
self,
batch_size,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens,
max_draft_len,
max_total_draft_tokens,
model_is_wrapped: bool = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
"""
Hook to be called when using TRTLLM attention backend in spec-dec mode.
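For orientation, a hedged sketch of how a caller might drive the widened hook above; spec_config, spec_metadata, spec_tree_manager and their attribute names are assumptions, and only the parameter list itself comes from this diff.

```python
# Hedged caller-side sketch; everything except the keyword names is assumed.
attn_metadata.update_spec_dec_param(
    batch_size=batch_size,
    is_spec_decoding_enabled=True,
    is_spec_dec_tree=True,
    is_spec_dec_dynamic_tree=False,
    max_draft_len=spec_config.max_draft_len,
    max_total_draft_tokens=spec_config.max_total_draft_tokens,
    model_is_wrapped=False,
    spec_metadata=spec_metadata,
    spec_tree_manager=spec_tree_manager,
)
```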
17 changes: 12 additions & 5 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -477,17 +477,24 @@ def create_expanded_buffers(self, capture_graph=False):
# TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
def update_spec_dec_param(
self,
batch_size,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens,
max_draft_len,
max_total_draft_tokens,
model_is_wrapped: bool = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
):
super().update_spec_dec_param(is_spec_decoding_enabled,
super().update_spec_dec_param(batch_size, is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens, spec_decoding_tensor)
self.max_draft_tokens = max_draft_tokens
is_spec_dec_dynamic_tree, max_draft_len,
max_total_draft_tokens, model_is_wrapped,
spec_metadata, spec_tree_manager,
spec_decoding_tensor)
self.max_draft_tokens = max_draft_len
init_shape = self.kv_lens_expanded_host.shape[0]
if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
capture_graph = torch.cuda.is_current_stream_capturing()