diff --git a/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu new file mode 100644 index 00000000000..eb3b958eb2d --- /dev/null +++ b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moeTopKFuncs.cuh" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/archCondition.h" +#include "tensorrt_llm/kernels/customMoeRoutingKernels.h" +#include // For INT_MAX +#include +#include +#include +#include // For numeric_limits +#include + +namespace cg = cooperative_groups; +using namespace tensorrt_llm::common; + +namespace tensorrt_llm::kernels +{ + +static constexpr int BLOCK_SIZE = 1024; +static constexpr int WARP_SIZE = 32; +static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ T calcSoftmax(cg::thread_block_tile const& warp, T score, int32_t laneIdx, int32_t NumTopExperts) +{ + T maxScore = T{-INFINITY}; + if (laneIdx < NumTopExperts) + { + maxScore = score >= maxScore ? score : maxScore; + } + maxScore = cg::reduce(warp, maxScore, cg::greater()); + + float sumScore{0.f}; + float newScore; + // Get the summation of scores for each token + if (laneIdx < NumTopExperts) + { + newScore = static_cast(score) - static_cast(maxScore); + newScore = static_cast(exp(newScore)); + sumScore += newScore; + } + sumScore = cg::reduce(warp, sumScore, cg::plus()); + + if (laneIdx < NumTopExperts) + { + score = static_cast(newScore / sumScore); + } + + return score; +} + +template +__device__ void calcSoftmax(cg::thread_block_tile const& warp, DataType (&scores)[VecSize]) +{ + DataType maxScore = DataType{-INFINITY}; + DataType sumScore = DataType{0.f}; + + // Get the max score for each token +#pragma unroll + for (int i = 0; i < VecSize; ++i) + { + maxScore = scores[i] >= maxScore ? 
scores[i] : maxScore; + } + maxScore = cg::reduce(warp, maxScore, cg::greater()); + + // Get the summation of scores for each token +#pragma unroll + for (int i = 0; i < VecSize; ++i) + { + scores[i] = static_cast(exp(scores[i] - maxScore)); + sumScore += scores[i]; + } + sumScore = cg::reduce(warp, sumScore, cg::plus()); + + // Normalize the scores +#pragma unroll + for (int i = 0; i < VecSize; ++i) + { + scores[i] = static_cast(scores[i] / sumScore); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, + int32_t const numTokens, int32_t const numExperts, int32_t const topK) +{ + using BaseType = std::conditional_t; + uint32_t const blockRank = blockIdx.x; + uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x; + uint32_t const warpIdx = tIdx / WARP_SIZE; + uint32_t const laneIdx = tIdx % WARP_SIZE; + uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + BaseType minScore = BaseType{-INFINITY}; + for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum) + { + auto scoreOffset = tokenId * numExperts; + auto outputOffset = tokenId * topK; + + BaseType inputScore[MaxNumExperts / WARP_SIZE]; + IdxT inputIndex[MaxNumExperts / WARP_SIZE]; + + BaseType warpTopKScore[MaxNumTopExperts]; + IdxT warpTopKExpertIdx[MaxNumTopExperts]; + + // Load scores and indices for this warp + for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i) + { + auto expertIdx = i * WARP_SIZE + laneIdx; + inputScore[i] + = expertIdx < numExperts ? static_cast(routerLogits[scoreOffset + expertIdx]) : minScore; + inputIndex[i] = expertIdx; + } + + if constexpr (DoSoftmaxBeforeTopK) + { + calcSoftmax(warp, inputScore); + } + // Reduce topK scores and indices for this warp + reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); + + // Normalize the scores + if constexpr (DoSoftmaxBeforeTopK) + { + if (laneIdx < topK) + { + topkValues[outputOffset + laneIdx] = static_cast(warpTopKScore[laneIdx]); + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; + } + } + else + { + auto softmaxScore = calcSoftmax(warp, + laneIdx < topK ? 
static_cast(warpTopKScore[laneIdx]) : static_cast(minScore), laneIdx, + topK); + if (laneIdx < topK) + { + topkValues[outputOffset + laneIdx] = static_cast(softmaxScore); + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; + } + } + } // end for tokenId +} + +int nextPowerOfTwo(int num) +{ + if (num <= 0) + { + return 1; // Handle invalid input + } + int power = 1; + while (power < num) + { + // Check for overflow before shifting + if (power > INT_MAX / 2) + { + return power; + } + power <<= 1; + } + return power; +} + +#define CASE(MAX_NUM_EXPERTS) \ + case MAX_NUM_EXPERTS: \ + switch (maxNumTopExperts) \ + { \ + case 1: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + case 2: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + case 4: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + case 8: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + default: kernelInstance = nullptr; break; \ + } \ + break; + +template +void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, + int64_t const numExperts, int64_t const topK, cudaStream_t const stream) +{ + + const uint32_t maxNumBlocks = 1024; + const uint32_t numBlocks = std::min(static_cast((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks); + + uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts); + uint32_t maxNumTopExperts = nextPowerOfTwo(topK); + + auto* kernelInstance = &customMoeRoutingKernel; + + switch (maxNumExperts) + { + CASE(32) + CASE(64) + CASE(96) + CASE(128) + default: kernelInstance = nullptr; break; + } + + if (kernelInstance == nullptr) + { + TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance."); + } + + dim3 renormMoeRoutingGridDim(numBlocks); + dim3 renormMoeRoutingBlockDim(BLOCK_SIZE); + cudaLaunchConfig_t config; + config.gridDim = renormMoeRoutingGridDim; + config.blockDim = renormMoeRoutingBlockDim; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast(numTokens), + static_cast(numExperts), static_cast(topK)); + sync_check_cuda_error(stream); +} + +#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) \ + template void invokeRenormMoeRouting(InputT * routerLogits, \ + OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, \ + int64_t const topK, cudaStream_t const stream); + +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false); +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false); +#ifdef ENABLE_BF16 +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false); +#endif + +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true); +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true); +#ifdef ENABLE_BF16 +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true); +#endif + +} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h similarity index 86% rename from cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h rename to cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h index 
1e9b001f658..cfe0ae8f15e 100644 --- a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h +++ b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace tensorrt_llm::kernels { -template +template void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, cudaStream_t const stream); } // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh new file mode 100644 index 00000000000..933b599dbdd --- /dev/null +++ b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh @@ -0,0 +1,205 @@ + +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifndef TRTLLM_MOETOPKFUNCS_CUH_H +#define TRTLLM_MOETOPKFUNCS_CUH_H + +#include +#include +#include + +#include "tensorrt_llm/kernels/archCondition.h" + +namespace tensorrt_llm::kernels +{ + +namespace reduce_topk +{ +namespace cg = cooperative_groups; +static constexpr int kWARP_SIZE = 32; +static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>; + +template +struct TopKRedType +{ + using T = T_; + static_assert(std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v, + "Top K reduction only implemented for int, float, float16 and bfloat16"); + + using TypeCmp = std::conditional_t; + using IdxT = std::conditional_t; + + static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16; + static constexpr int kMaxIdx = 65535; + TypeCmp compValIdx; + + static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) + { + auto valueBits = cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); + TypeCmp compactTmp = reinterpret_cast(valueBits); + compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx)); + // Use 65535 minus idx to give higher priority to elements with smaller indices. 
+ return compactTmp; + } + + static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) + { + // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits + index = kMaxIdx - static_cast((cmp & 0xFFFF)); + + auto compactTmp = cmp >> kMoveBits; + auto valueBits + = cub::Traits::TwiddleOut(reinterpret_cast::UnsignedBits&>(compactTmp)); + value = reinterpret_cast(valueBits); + } + + __host__ __device__ TopKRedType() = default; + + __host__ __device__ TopKRedType(T val, int32_t idx) + : compValIdx(makeCmpVal(val, idx)) + { + } + + __host__ __device__ operator TypeCmp() const noexcept + { + return compValIdx; + } + + __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) + { + if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8) + { + return cg::reduce(warp, compValIdx, cg::greater{}); + } + else + { + TypeCmp result; + asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); + return result; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TopKIdx +{ + // by default, empty +}; + +template +struct TopKIdx +{ + static constexpr int K = K_; + int32_t val[K]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ + auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ + topK[I].compValIdx = pairMax; \ + topK[J].compValIdx = pairMin; \ + } + +template +struct Sort; + +template +struct Sort<1, RedType> +{ + static __device__ void run(RedType* topK) {} +}; + +template +struct Sort<2, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<3, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 1); + TOPK_SWAP(1, 2); + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<4, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 2); + TOPK_SWAP(1, 3); + TOPK_SWAP(0, 1); + TOPK_SWAP(2, 3); + TOPK_SWAP(1, 2); + } +}; + +template +__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], + Type (&value)[N], int32_t (&idx)[N], Type minValue) +{ + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N < 5, "Only support candidates number less than or equal to 128"); + using RedType = TopKRedType; + RedType topK[N]; +#pragma unroll + for (int nn = 0; nn < N; ++nn) + { + topK[nn] = RedType{value[nn], idx[nn]}; + } + + if constexpr (!IsSorted) + { + Sort::run(topK); + } + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < K; ++kk) + { + bool update = kk > 0 && packedMax == topK[0].compValIdx; +#pragma unroll + for (int nn = 0; nn < N; ++nn) + { + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? 
topK[nn + 1] : topK[nn]; + } + // get the next largest value + packedMax = topK[0].reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +#undef TOPK_SWAP + +} // namespace reduce_topk +} // namespace tensorrt_llm::kernels +#endif // TRTLLM_MOETOPKFUNCS_CUH_H diff --git a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu b/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu deleted file mode 100644 index 1b4239e48c8..00000000000 --- a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu +++ /dev/null @@ -1,376 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cudaTypeUtils.cuh" -#include "tensorrt_llm/common/envUtils.h" -#include "tensorrt_llm/kernels/archCondition.h" -#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h" -#include // For INT_MAX -#include -#include -#include -#include // For numeric_limits -#include - -namespace cg = cooperative_groups; -using namespace tensorrt_llm::common; - -namespace tensorrt_llm::kernels -{ - -static constexpr int BLOCK_SIZE = 1024; -static constexpr int WARP_SIZE = 32; -static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; - -namespace reduce_topk -{ - -static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>; - -template -struct TopKRedType -{ - using T = T_; - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "Top K reduction only implemented for float, float16 and bfloat16"); - - using TypeCmp = std::conditional_t; - using IdxT = std::conditional_t; - static constexpr int moveBits = (sizeof(T) == 4) ? 32 : 16; - static constexpr int maxIdx = 65535; - TypeCmp compValIdx; - - static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) - { - auto valueBits = cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); - TypeCmp compactTmp = reinterpret_cast(valueBits); - compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); - // Use 65535 minus idx to give higher priority to elements with smaller indices. 
- return compactTmp; - } - - static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) - { - // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits - index = maxIdx - static_cast((cmp & 0xFFFF)); - - auto compactTmp = cmp >> moveBits; - auto valueBits - = cub::Traits::TwiddleOut(reinterpret_cast::UnsignedBits&>(compactTmp)); - value = reinterpret_cast(valueBits); - } - - __host__ __device__ TopKRedType() = default; - - __host__ __device__ TopKRedType(T val, int32_t idx) - : compValIdx(makeCmpVal(val, idx)) - { - } - - __host__ __device__ operator TypeCmp() const noexcept - { - return compValIdx; - } - - __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) - { - if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8) - { - return cg::reduce(warp, compValIdx, cg::greater{}); - } - else - { - TypeCmp result; - asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); - return result; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TopKIdx -{ - // by default, empty -}; - -template -struct TopKIdx -{ - static constexpr int K = K_; - int32_t val[K]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define TOPK_SWAP(I, J) \ - { \ - auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ - auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ - topK[I].compValIdx = pairMax; \ - topK[J].compValIdx = pairMin; \ - } - -template -struct Sort; - -template -struct Sort<1, RedType> -{ - static __device__ void run(RedType* topK) {} -}; - -template -struct Sort<2, RedType> -{ - static __device__ void run(RedType* topK) - { - TOPK_SWAP(0, 1); - } -}; - -template -struct Sort<3, RedType> -{ - static __device__ void run(RedType* topK) - { - TOPK_SWAP(0, 1); - TOPK_SWAP(1, 2); - TOPK_SWAP(0, 1); - } -}; - -template -struct Sort<4, RedType> -{ - static __device__ void run(RedType* topK) - { - TOPK_SWAP(0, 2); - TOPK_SWAP(1, 3); - TOPK_SWAP(0, 1); - TOPK_SWAP(2, 3); - TOPK_SWAP(1, 2); - } -}; - -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WARP_SIZE, "Top K must have K < WARP_SIZE"); - static_assert(N > 0, "Top K must have N > 0"); - static_assert(N < 5, "Only support candidates number less than or equal to 128"); - using RedType = TopKRedType; - RedType topK[N]; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = RedType{value[nn], idx[nn]}; - } - - if constexpr (!IsSorted) - { - Sort::run(topK); - } - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) - { - bool update = kk > 0 && packedMax == topK[0].compValIdx; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? 
topK[nn + 1] : topK[nn]; - } - // get the next largest value - packedMax = topK[0].reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); - } -}; - -#undef TOPK_SWAP - -} // end of namespace reduce_topk - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ T calcSoftmax(cg::thread_block_tile const& warp, T score, int32_t laneIdx, int32_t NumTopExperts) -{ - T maxScore = T{-INFINITY}; - if (laneIdx < NumTopExperts) - { - maxScore = score >= maxScore ? score : maxScore; - } - maxScore = cg::reduce(warp, maxScore, cg::greater()); - - float sumScore = float{0.f}; - float newScore; - // Get the summation of scores for each token - if (laneIdx < NumTopExperts) - { - newScore = static_cast(score) - static_cast(maxScore); - newScore = static_cast(exp(newScore)); - sumScore += newScore; - } - sumScore = cg::reduce(warp, sumScore, cg::plus()); - - if (laneIdx < NumTopExperts) - { - score = static_cast(newScore / sumScore); - } - - return score; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void renormMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, - int32_t const numTokens, int32_t const numExperts, int32_t const topK) -{ - - uint32_t const blockRank = blockIdx.x; - uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x; - uint32_t const warpIdx = tIdx / WARP_SIZE; - uint32_t const laneIdx = tIdx % WARP_SIZE; - uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK; - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - InputT minScore = InputT{-INFINITY}; - for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum) - { - auto scoreOffset = tokenId * numExperts; - auto outputOffset = tokenId * topK; - InputT inputScore[MaxNumExperts / WARP_SIZE]; - IdxT inputIndex[MaxNumExperts / WARP_SIZE]; - - InputT warpTopKScore[MaxNumTopExperts]; - IdxT warpTopKExpertIdx[MaxNumTopExperts]; - - // Load scores and indices for this warp - for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i) - { - auto expertIdx = i * WARP_SIZE + laneIdx; - inputScore[i] - = expertIdx < numExperts ? static_cast(routerLogits[scoreOffset + expertIdx]) : minScore; - inputIndex[i] = expertIdx; - } - - // Reduce topK scores and indices for this warp - reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); - - // Perform softmax on topK scores - auto score = calcSoftmax(warp, - laneIdx < topK ? 
static_cast(warpTopKScore[laneIdx]) : static_cast(minScore), laneIdx, topK); - if (laneIdx < topK) - { - topkValues[outputOffset + laneIdx] = static_cast(score); - topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; - } - } // end for tokenId -} - -int nextPowerOfTwo(int num) -{ - if (num <= 0) - { - return 1; // Handle invalid input - } - int power = 1; - while (power < num) - { - // Check for overflow before shifting - if (power > INT_MAX / 2) - { - return power; - } - power <<= 1; - } - return power; -} - -#define CASE(MAX_NUM_EXPERTS) \ - case MAX_NUM_EXPERTS: \ - switch (maxNumTopExperts) \ - { \ - case 1: kernelInstance = &renormMoeRoutingKernel; break; \ - case 2: kernelInstance = &renormMoeRoutingKernel; break; \ - case 4: kernelInstance = &renormMoeRoutingKernel; break; \ - case 8: kernelInstance = &renormMoeRoutingKernel; break; \ - default: kernelInstance = nullptr; break; \ - } \ - break; - -template -void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, - int64_t const numExperts, int64_t const topK, cudaStream_t const stream) -{ - - const uint32_t maxNumBlocks = 1024; - const uint32_t numBlocks = std::min(static_cast((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks); - - uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts); - uint32_t maxNumTopExperts = nextPowerOfTwo(topK); - - auto* kernelInstance = &renormMoeRoutingKernel; - - switch (maxNumExperts) - { - CASE(32) - CASE(64) - CASE(96) - CASE(128) - default: kernelInstance = nullptr; break; - } - - if (kernelInstance == nullptr) - { - TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance."); - } - - dim3 renormMoeRoutingGridDim(numBlocks); - dim3 renormMoeRoutingBlockDim(BLOCK_SIZE); - cudaLaunchConfig_t config; - config.gridDim = renormMoeRoutingGridDim; - config.blockDim = renormMoeRoutingBlockDim; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast(numTokens), - static_cast(numExperts), static_cast(topK)); - sync_check_cuda_error(stream); -} - -#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT) \ - template void invokeRenormMoeRouting(InputT * routerLogits, OutputT * topkValues, \ - IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, \ - cudaStream_t const stream); - -INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t); -INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t); -#ifdef ENABLE_BF16 -INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t); -#endif - -} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/topkLastDim.cu b/cpp/tensorrt_llm/kernels/topkLastDim.cu index 74e2838822c..e6e4e82c92b 100644 --- a/cpp/tensorrt_llm/kernels/topkLastDim.cu +++ b/cpp/tensorrt_llm/kernels/topkLastDim.cu @@ -22,9 +22,17 @@ */ #include +#include "moeTopKFuncs.cuh" #include "topkLastDim.h" +#include +#include #include #include +#include +#include +#include +#include +#include namespace tensorrt_llm { @@ -201,12 +209,12 @@ __host__ __device__ IdxT calc_buf_len(IdxT len) * @param len the number of elements to read * @param f the lambda taking two arguments (T x, 
IdxT idx) */ -template -__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, idxT len, Func f) +template +__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, IdxT len, Func f) { if constexpr (sizeof(T) >= sizeof(WideT)) { - for (idxT i = thread_rank; i < len; i += num_threads) + for (IdxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } @@ -231,12 +239,12 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con skip_cnt = len; } WideT const* in_cast = reinterpret_cast(in + skip_cnt); - const idxT len_cast = (len - skip_cnt) / items_per_scalar; + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - for (idxT i = thread_rank; i < len_cast; i += num_threads) + for (IdxT i = thread_rank; i < len_cast; i += num_threads) { wide.scalar = in_cast[i]; - const idxT real_i = skip_cnt + i * items_per_scalar; + const IdxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { @@ -256,7 +264,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con // and so // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE // no need to use loop - const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); @@ -265,14 +273,14 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con } // sync_width should >= WARP_SIZE -template -__device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width) +template +__device__ void vectorized_process(T const* in, IdxT len, Func f, int sync_width) { - const idxT stride = blockDim.x * gridDim.x; - const idxT tid = blockIdx.x * blockDim.x + threadIdx.x; + const IdxT stride = blockDim.x * gridDim.x; + const IdxT tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= sizeof(WideT)) { - for (idxT i = tid; i < len; i += stride) + for (IdxT i = tid; i < len; i += stride) { f(in[i], i, true); } @@ -296,17 +304,17 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width skip_cnt = len; } WideT const* in_cast = reinterpret_cast(in + skip_cnt); - const idxT len_cast = (len - skip_cnt) / items_per_scalar; + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; - for (idxT i = tid; i < len_cast_for_sync; i += stride) + const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (IdxT i = tid; i < len_cast_for_sync; i += stride) { bool valid = i < len_cast; if (valid) { wide.scalar = in_cast[i]; } - const idxT real_i = skip_cnt + i * items_per_scalar; + const IdxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { @@ -323,7 +331,7 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width T value = valid ? in[tid] : T(); f(value, tid, valid); - const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; valid = remain_i < len; value = valid ? 
in[remain_i] : T(); f(value, remain_i, valid); @@ -1164,6 +1172,77 @@ __global__ void radix_topk_one_block_kernel(T const* in, IdxT const* in_idx, con } // namespace air_topk_stable //} +namespace moe_topk +{ +namespace cg = cooperative_groups; +static constexpr int kBLOCK_SIZE = 1024; +static constexpr int kWARP_SIZE = 32; +static constexpr int kWARPS_PER_BLOCK = kBLOCK_SIZE / kWARP_SIZE; + +template +__device__ T negativeInfinity() +{ + return -INFINITY; +} + +template <> +__device__ half negativeInfinity() +{ + return -CUDART_INF_FP16; +} + +template <> +__device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>() +{ + return -CUDART_INF_BF16; +} + +/****************TopK kernel for candidate number<= 128 and K <= 8 **************** */ +template +__global__ void moe_topk_kernel( + InputT const* in, OutputT* out, IdxT* outIdx, int32_t const batchSize, int32_t const len, int32_t const topK) +{ + + uint32_t const blockRank = blockIdx.x; + uint32_t const tIdx = kBLOCK_SIZE * blockRank + threadIdx.x; + uint32_t const warpIdx = tIdx / kWARP_SIZE; + uint32_t const laneIdx = tIdx % kWARP_SIZE; + uint32_t const warpNum = gridDim.x * kWARPS_PER_BLOCK; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + InputT minScore = negativeInfinity(); + + for (uint32_t tokenId = warpIdx; tokenId < batchSize; tokenId += warpNum) + { + auto scoreOffset = tokenId * len; + auto outputOffset = tokenId * topK; + InputT inputScore[MaxLen / kWARP_SIZE]; + IdxT inputIndex[MaxLen / kWARP_SIZE]; + + InputT warpTopKScore[MaxTopK]; + IdxT warpTopKExpertIdx[MaxTopK]; + + // Load scores and indices for this warp + for (uint32_t i = 0; i < MaxLen / kWARP_SIZE; ++i) + { + auto expertIdx = i * kWARP_SIZE + laneIdx; + inputScore[i] = expertIdx < len ? static_cast(in[scoreOffset + expertIdx]) : minScore; + inputIndex[i] = expertIdx; + } + + // Reduce topK scores and indices for this warp + tensorrt_llm::kernels::reduce_topk::reduceTopK( + warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); + + if (laneIdx < topK) + { + out[outputOffset + laneIdx] = static_cast(warpTopKScore[laneIdx]); + outIdx[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; + } + } // end for tokenId +} +} // namespace moe_topk /***************Runtime API****************/ @@ -1221,9 +1300,11 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx IdxT* sort_in_idx = nullptr; air_topk_stable::ComputeOffset computeoffset(k); - cub::CountingInputIterator counting_iter(0); - cub::TransformInputIterator, cub::CountingInputIterator> - transform_iter(counting_iter, computeoffset); + + thrust::counting_iterator counting_iter(0); + thrust::transform_iterator, thrust::counting_iterator> transform_iter( + counting_iter, computeoffset); + cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); if (sorted) @@ -1348,9 +1429,9 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons const IdxT buf_len = air_topk_stable::calc_buf_len(len); air_topk_stable::ComputeOffset computeoffset(k); - cub::CountingInputIterator counting_iter(0); - cub::TransformInputIterator, cub::CountingInputIterator> - transform_iter(counting_iter, computeoffset); + thrust::counting_iterator counting_iter(0); + thrust::transform_iterator, thrust::counting_iterator> transform_iter( + counting_iter, computeoffset); cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, 
out_idx, out, out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); @@ -1421,36 +1502,120 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons } } -template -void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, idxT len, idxT k, T* out, - idxT* out_idx, bool greater, cudaStream_t stream = 0) +template +void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, IdxT len, IdxT k, T* out, + IdxT* out_idx, bool greater, cudaStream_t stream = 0) { constexpr int items_per_thread = 32; constexpr int block_dim = 512; constexpr bool fused_last_filter = false; if (len <= block_dim * items_per_thread) { - standalone_stable_radix_topk_one_block_( - buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); + standalone_stable_radix_topk_one_block_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); } else { int sm_cnt = tensorrt_llm::common::getMultiProcessorCount(); - unsigned grid_dim = air_topk_stable::calc_grid_dim(batch_size, len, sm_cnt); + unsigned grid_dim = air_topk_stable::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - standalone_stable_radix_topk_one_block_(buf, buf_size, in, - static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); + standalone_stable_radix_topk_one_block_(buf, buf_size, in, + static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); } else { - standalone_stable_radix_topk_(buf, buf_size, in, static_cast(nullptr), + standalone_stable_radix_topk_(buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, fused_last_filter, grid_dim, stream, sorted); } } } +int nextPowerOfTwo(int num) +{ + if (num <= 0) + { + return 1; // Handle invalid input + } + int power = 1; + while (power < num) + { + // Check for overflow before shifting + if (power > INT_MAX / 2) + { + return power; + } + power <<= 1; + } + return power; +} + +template +void moe_reduce_topk( + T const* in, int batch_size, IdxT len, IdxT k, T* out, IdxT* out_idx, bool greater, cudaStream_t stream = 0) +{ + using InputT = T; + using OutputT = T; + const uint32_t max_num_blocks = 1024; + const uint32_t num_blocks + = std::min(static_cast((batch_size - 1) / moe_topk::kWARPS_PER_BLOCK + 1), max_num_blocks); + + uint32_t max_len = nextPowerOfTwo(len) < 32 ? 
32 : nextPowerOfTwo(len); + uint32_t moe_topk = nextPowerOfTwo(k); + + auto* kernel_instance = &moe_topk::moe_topk_kernel; + + switch (max_len) + { + case 32: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + case 64: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + case 96: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + case 128: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + default: kernel_instance = nullptr; break; + } + + dim3 moe_topk_grid_dim(num_blocks); + dim3 moe_topk_block_dim(moe_topk::kBLOCK_SIZE); + + kernel_instance<<>>(in, out, out_idx, batch_size, len, k); +} #endif /////////////// @@ -1459,22 +1624,22 @@ template size_t invokeComputeTopkLastDimWorkspaceSize( SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest) { - using idxT = SizeType32; + using IdxT = SizeType32; size_t buf_size = 0; void* workspace = nullptr; T const* in = nullptr; T* out_val = nullptr; - idxT* out_idx = nullptr; + IdxT* out_idx = nullptr; constexpr int block_dim = 512; constexpr bool fused_last_filter = false; constexpr bool sorted = true; int sm_cnt = tensorrt_llm::common::getMultiProcessorCount(); - unsigned grid_dim = air_topk_stable::calc_grid_dim(batchSize, inputLength, sm_cnt); + unsigned grid_dim = air_topk_stable::calc_grid_dim(batchSize, inputLength, sm_cnt); - standalone_stable_radix_topk_(workspace, buf_size, in, static_cast(nullptr), + standalone_stable_radix_topk_(workspace, buf_size, in, static_cast(nullptr), batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted); return buf_size; } @@ -1504,8 +1669,17 @@ void invokeTopkLastDim(SizeType32 batchSize, SizeType32 inputLength, SizeType32 T const* in = reinterpret_cast(input); T* out_val_ = reinterpret_cast(out_val); SizeType32* out_idx_ = reinterpret_cast(out_idx); - standalone_stable_radix_11bits( - workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream); + if (inputLength <= 128 && k <= 8 && is_largest == true) + { + // This method does not require a buffer, but since the implementation may vary in different cases, + // we still allocate the buffer in case AIR TopK is used instead. 
+ moe_reduce_topk(in, batchSize, inputLength, k, out_val_, out_idx_, !is_largest, stream); + } + else + { + standalone_stable_radix_11bits( + workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream); + } } #define INSTANTIATE_TOPK_LastDim_DATA_TYPE(T) \ diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh index 750658fad72..92d020fd19a 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh @@ -378,7 +378,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens // TODO: this is not sufficient to ensure visibility in the next kernel! -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); @@ -757,15 +757,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -// Trigger secondary kernel. -// Note: this does not guarantee the visibility of prior writes unless the consumer executes a -// dependency sync. -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 + // Trigger secondary kernel. + // Note: this does not guarantee the visibility of prior writes unless the consumer executes a + // dependency sync. if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu index f1f60abdc22..5c398920390 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu @@ -227,13 +227,11 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 // we can trigger the next kernel at this point if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif // at this point, all values for offsets are ready, except the final offsets diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index a9d0d4009f9..5d9ea4e2626 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -81,7 +81,7 @@ add_library( reducescatterOp.cpp relativeAttentionBiasOp.cpp dsv3RouterGemmOp.cpp - renormMoeRoutingOp.cpp + customMoeRoutingOp.cpp selectiveScanOp.cpp userbuffersFinalizeOp.cpp userbuffersTensor.cpp diff --git a/cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp b/cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp similarity index 75% rename from cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp rename to cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp index 616cf3bb7ec..814fdf87c3b 100644 --- a/cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp +++ b/cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp @@ -15,7 +15,7 @@ */ #include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h" +#include 
"tensorrt_llm/kernels/customMoeRoutingKernels.h" #include "tensorrt_llm/runtime/torchUtils.h" namespace th = torch; @@ -25,7 +25,8 @@ namespace tk = tensorrt_llm::kernels; namespace torch_ext { -std::tuple renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk) +template +std::tuple custom_moe_routing_op(th::Tensor const& router_logits, int64_t topk) { auto data_type = router_logits.scalar_type(); auto input_size = router_logits.sizes(); @@ -44,20 +45,22 @@ std::tuple renorm_moe_routing_op(th::Tensor const& route { case torch::kFloat32: // Handle Float32 - tk::invokeRenormMoeRouting(reinterpret_cast(router_logits.mutable_data_ptr()), + tk::invokeRenormMoeRouting( + reinterpret_cast(router_logits.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream); break; case torch::kBFloat16: // Handle BFloat16 - tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t>( + tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t, DoSoftmaxBeforeTopK>( reinterpret_cast<__nv_bfloat16*>(router_logits.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream); break; case torch::kHalf: // Handle Half - tk::invokeRenormMoeRouting(reinterpret_cast(router_logits.mutable_data_ptr()), + tk::invokeRenormMoeRouting( + reinterpret_cast(router_logits.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream); break; @@ -69,6 +72,15 @@ std::tuple renorm_moe_routing_op(th::Tensor const& route return {topk_indices, topk_values}; } +std::tuple renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk) +{ + return custom_moe_routing_op(router_logits, topk); +} + +std::tuple default_moe_routing_op(th::Tensor const& router_logits, int64_t topk) +{ + return custom_moe_routing_op(router_logits, topk); +} } // namespace torch_ext TORCH_LIBRARY_FRAGMENT(trtllm, m) @@ -82,3 +94,15 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("renorm_moe_routing_op", &torch_ext::renorm_moe_routing_op); } + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "default_moe_routing_op(Tensor router_logits, SymInt topk" + ") -> (Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("default_moe_routing_op", &torch_ext::default_moe_routing_op); +} diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 5e001d9a48c..b6bde4e5134 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -531,3 +531,11 @@ def _(router_logits, topk): return router_logits.new_empty( sz, dtype=torch.int32), router_logits.new_empty(sz, dtype=torch.float32) + + @torch.library.register_fake("trtllm::default_moe_routing_op") + def _(router_logits, topk): + num_tokens = router_logits.shape[0] + sz = (num_tokens, topk) + return router_logits.new_empty( + sz, dtype=torch.int32), router_logits.new_empty(sz, + dtype=torch.float32) diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 793240c2add..fc79372e857 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -51,18 +51,28 @@ def routing_method_type(self): class DefaultMoeRoutingMethod(BaseMoeRoutingMethod): - def __init__(self, 
top_k: int): + def __init__(self, top_k: int, force_enable_pytorch_op: bool = False): super().__init__() self.top_k = top_k + self.force_enable_pytorch_op = force_enable_pytorch_op - def apply(self, - router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + def apply_pytorch( + self, router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): topk_values, topk_indices = torch.topk(torch.nn.functional.softmax( router_logits.float(), dim=-1), k=self.top_k, dim=-1) return topk_indices.to(torch.int32), topk_values + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + num_experts = router_logits.shape[-1] + if self.force_enable_pytorch_op or num_experts > 128 or self.top_k > 8: + return self.apply_pytorch(router_logits) + else: + return torch.ops.trtllm.default_moe_routing_op( + router_logits, self.top_k) + @property def routing_method_type(self): return RoutingMethodType.Default
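
For reviewers, here is a minimal PyTorch reference sketch (not part of this patch) of the two routing variants the kernels above implement: softmax-before-top-k (the `DoSoftmaxBeforeTopK = true` path exposed as `default_moe_routing_op`, matching the existing `apply_pytorch` fallback) and top-k-before-softmax renormalization (the `DoSoftmaxBeforeTopK = false` path exposed as `renorm_moe_routing_op`). The function names `reference_default_routing` and `reference_renorm_routing` are hypothetical helpers for illustration only, e.g. for a unit test comparing against the fused CUDA ops.

```python
import torch

def reference_default_routing(router_logits: torch.Tensor, top_k: int):
    # Softmax over all experts first, then take the top-k probabilities
    # (mirrors DoSoftmaxBeforeTopK == true / default_moe_routing_op).
    probs = torch.nn.functional.softmax(router_logits.float(), dim=-1)
    topk_values, topk_indices = torch.topk(probs, k=top_k, dim=-1)
    return topk_indices.to(torch.int32), topk_values

def reference_renorm_routing(router_logits: torch.Tensor, top_k: int):
    # Take the top-k raw logits first, then softmax-renormalize only the
    # selected scores (mirrors DoSoftmaxBeforeTopK == false /
    # renorm_moe_routing_op).
    topk_logits, topk_indices = torch.topk(router_logits.float(), k=top_k, dim=-1)
    topk_values = torch.nn.functional.softmax(topk_logits, dim=-1)
    return topk_indices.to(torch.int32), topk_values
```

As a usage note, such a reference would only be expected to match the fused kernels when `num_experts <= 128` and `top_k <= 8`, since that is the range the warp-level `reduceTopK` path handles; outside it, `DefaultMoeRoutingMethod.apply` in this patch already falls back to the pure PyTorch implementation.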