diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp
index de5f1f650dd..1bceceae806 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp
@@ -193,7 +193,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
     std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
-    std::optional<int32_t> configIndex)
+    int32_t configIndex) const
 {
     BatchedGemmData gemmData;
     gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -212,14 +212,8 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
 
     auto const configs = bmm.getBatchedGemmConfigs();
 
-    if (!configIndex.has_value())
-    {
-        mSelectedConfigIndex
-            = getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
-        configIndex = mSelectedConfigIndex;
-    }
+    auto const& config = configs[configIndex];
 
-    auto const& config = configs[configIndex.value()];
     return bmm.getWorkspaceSizeInBytes(config, gemmData);
 }
 
@@ -228,7 +222,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
     float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
     int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
-    void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
+    void* workspace, CUstream stream, int device, int32_t configIndex)
 {
     auto bmm = BatchedGemmInterface();
 
@@ -236,14 +230,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
 
     auto const configs = bmm.getBatchedGemmConfigs();
 
-    if (!configIndex.has_value())
-    {
-        TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");
-
-        configIndex = mSelectedConfigIndex;
-    }
-
-    auto const& config = configs[configIndex.value()];
+    auto const& config = configs[configIndex];
 
     TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
 
     if (!mOptions.staticBatch)
@@ -315,7 +302,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
 
 void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
     void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
-    CUstream stream, int device, std::optional<int32_t> configIndex)
+    CUstream stream, int device, int32_t configIndex)
 {
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
@@ -328,7 +315,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
 
 void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
     void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
-    CUstream stream, int device, std::optional<int32_t> configIndex)
+    CUstream stream, int device, int32_t configIndex)
 {
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
@@ -415,5 +402,31 @@ int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_
     return validConfigIndices[0];
 }
 
+bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, int32_t k,
+    std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+    int32_t maxNumCtasInBatchDim) const
+{
+    auto const bmm = BatchedGemmInterface();
+    auto const configs = bmm.getBatchedGemmConfigs();
+
+    BatchedGemmData gemmData;
+    // Dims
+    gemmData.mProblemDimensions.mNumBatches = numBatches;
+    gemmData.mProblemDimensions.mNumTokens = numTokens;
+    gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
+    gemmData.mProblemDimensions.mBatchedM = mOptions.transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
+    gemmData.mProblemDimensions.mBatchedN = mOptions.transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
+    gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
+    gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
+    gemmData.mProblemDimensions.mK = k;
+    gemmData.mProblemDimensions.mRank = 0;
+    gemmData.mProblemDimensions.mWorldSize = 1;
+    gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
+
+    auto const& config = configs[configIndex];
+
+    return bmm.isValidConfig(config, gemmData);
+}
+
 } // namespace kernels
 } // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
index 7fe892511b6..6c87de22fd8 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
@@ -18,7 +18,6 @@
 #include
 #include
-#include <optional>
 #include
 
 #include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
 
@@ -48,7 +47,7 @@ class TrtllmGenBatchedGemmRunner
 
     [[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
         std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
-        std::optional<int32_t> configIndex = std::nullopt);
+        int32_t configIndex) const;
 
     // Generic GEMM interface
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
@@ -56,17 +55,17 @@ class TrtllmGenBatchedGemmRunner
         void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
         float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
         int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
-        void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);
+        void* workspace, CUstream stream, int device, int32_t configIndex);
 
     // NVFP4 per-block scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
         void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
-        std::optional<int32_t> configIndex = std::nullopt);
+        int32_t configIndex);
 
     // FP8 per-tensor scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
         float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
-        std::optional<int32_t> configIndex = std::nullopt);
+        int32_t configIndex);
 
     // Get the list of configs that passed the validation based on the constructor options
     [[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const
@@ -85,6 +84,10 @@ class TrtllmGenBatchedGemmRunner
         std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
         int32_t maxNumCtasInBatchDim) const;
 
+    [[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, int32_t k,
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+        int32_t maxNumCtasInBatchDim) const;
+
 private:
     void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
         int32_t numBatches, int32_t maxNumCtasInBatchDim);
@@ -92,7 +95,6 @@ class TrtllmGenBatchedGemmRunner
 private:
     TrtllmGenBatchedGemmRunnerOptions mOptions;
     std::vector<int64_t> mPassingConfigIndices;
-    std::optional<int32_t> mSelectedConfigIndex;
 };
 } // namespace kernels
 } // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
index 04c494b3ecf..0020d3cbb6d 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
@@ -235,22 +235,48 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void*
     float* outputScalesScalar, float* outputScalesGateScalar, void* output, void* outputScale, int32_t topK,
     int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t* permutedIdxToTokenIdx,
     int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
-    int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, int device, cudaStream_t stream)
+    int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, int device, cudaStream_t stream,
+    int32_t configIndex)
 {
     auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
     mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim,
         hiddenState, hiddenStateScale, weights, weightsScale, expertWeights, /* perTokensSfB */ nullptr,
         outputScalesScalar, outputScalesGateScalar, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens,
-        ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device);
+        ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex);
 }
 
-size_t Runner::getWorkspaceSizeInBytes(
-    int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens)
+size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts,
+    int32_t numTokens, int32_t configIndex) const
 {
     auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
     return mRunner.getWorkspaceSizeInBytes(
+        numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex);
+}
+
+int32_t Runner::getDefaultValidConfigIndex(
+    int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const
+{
+    auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+    return mRunner.getDefaultValidConfigIndex(
         numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
 }
+
+bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
+    int32_t numExperts, int32_t numTokens) const
+{
+    auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+
+    auto const isValid = mRunner.isValidConfigIndex(
+        configIndex, numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
+
+    return isValid;
+}
+
+std::vector<int64_t> Runner::getPassingConfigIndices() const
+{
+    return mRunner.getPassingConfigIndices();
+}
+
 } // namespace PermuteGemm1
 
 namespace Gemm2
@@ -283,23 +309,49 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void*
     float* outputScalesScalar, void* output, void* outputScale, int32_t topK, int32_t hiddenSize,
     int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas,
     int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit,
-    void* bmm2Workspace, int device, cudaStream_t stream)
+    void* bmm2Workspace, int device, cudaStream_t stream, int32_t configIndex)
 {
     auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
     mRunner.run(numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim,
         permutedHiddenState, permutedHiddenStateScale, weights, weightsScale, /* perTokensSfA */ nullptr,
         /* perTokensSfB */ nullptr, outputScalesScalar, /* outputScalesGateScalar */ nullptr, output, outputScale,
         /* permutedIdxToTokenIdx */ nullptr, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit,
-        ptrNumNonExitingCtas, bmm2Workspace, stream, device);
+        ptrNumNonExitingCtas, bmm2Workspace, stream, device, configIndex);
 }
 
-size_t Runner::getWorkspaceSizeInBytes(
-    int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens)
+size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts,
+    int32_t numTokens, int32_t configIndex) const
 {
     auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
     return mRunner.getWorkspaceSizeInBytes(
+        numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex);
+}
+
+int32_t Runner::getDefaultValidConfigIndex(
+    int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const
+{
+    auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+    return mRunner.getDefaultValidConfigIndex(
         numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
 }
+
+bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
+    int32_t numExperts, int32_t numTokens) const
+{
+
+    auto const maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+
+    auto const isValid = mRunner.isValidConfigIndex(
+        configIndex, numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
+
+    return isValid;
+}
+
+std::vector<int64_t> Runner::getPassingConfigIndices() const
+{
+    return mRunner.getPassingConfigIndices();
+}
+
 } // namespace Gemm2
 
 namespace MoE
@@ -308,6 +360,22 @@ Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim)
     : mPermuteGemm1(PermuteGemm1::Runner(dtypeElt, useDeepSeekFp8, tileTokensDim))
     , mGemm2(Gemm2::Runner(dtypeElt, btg::Dtype::Bfloat16, useDeepSeekFp8, tileTokensDim))
 {
+
+    auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices();
+    auto const& gemm2PassingIndices = mGemm2.getPassingConfigIndices();
+
+    auto const totalPassingIndices = gemm1PassingIndices.size() * gemm2PassingIndices.size();
+    mPassingConfigs.reserve(totalPassingIndices);
+
+    for (auto const& indexGemm1 : gemm1PassingIndices)
+    {
+        for (auto const& indexGemm2 : gemm2PassingIndices)
+        {
+            mPassingConfigs.push_back(MoEConfig{indexGemm1, indexGemm2});
+        }
+    }
+
+    TLLM_CHECK_WITH_INFO(!mPassingConfigs.empty(), "No compatible configs found for the fp8 block scale MoE runner.");
 }
 
 void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
@@ -366,16 +434,48 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
     }
 }
 
-std::tuple Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args)
+std::tuple Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args, int64_t configIndex) const
 {
-    auto workspace_size_fc1 = static_cast(mPermuteGemm1.getWorkspaceSizeInBytes(
-        args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens));
-    auto workspace_size_fc2 = static_cast(mGemm2.getWorkspaceSizeInBytes(
-        args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens));
+    auto const& config = mPassingConfigs[configIndex];
+
+    auto workspace_size_fc1 = static_cast(mPermuteGemm1.getWorkspaceSizeInBytes(args.top_k, args.hidden_size,
+        args.intermediate_size, args.local_num_experts, args.num_tokens, config.gemm1Config));
+    auto workspace_size_fc2 = static_cast(mGemm2.getWorkspaceSizeInBytes(args.top_k, args.hidden_size,
+        args.intermediate_size, args.local_num_experts, args.num_tokens, config.gemm2Config));
 
     return std::make_tuple(workspace_size_fc1, workspace_size_fc2);
 }
 
-void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream)
+std::vector<int64_t> Runner::getValidConfigIndices(
+    int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const
+{
+    std::vector<int64_t> validIndices;
+
+    for (int i = 0; i < mPassingConfigs.size(); ++i)
+    {
+        auto const& config = mPassingConfigs[i];
+
+        if (mPermuteGemm1.isValidConfigIndex(
+                config.gemm1Config, topK, hiddenSize, intermediateSize, numLocalExperts, numTokens)
+            && mGemm2.isValidConfigIndex(
+                config.gemm2Config, topK, hiddenSize, intermediateSize, numLocalExperts, numTokens))
+        {
+            validIndices.push_back(i);
+        }
+    }
+
+    return validIndices;
+}
+
+int64_t Runner::getDefaultValidConfigIndex(
+    int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const
+{
+    auto const validIndices = getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
+
+    return validIndices[0];
+}
+
+void Runner::run(
+    MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream, int64_t configIndex)
 {
     // Setup all operation data
     moe::dev::activation::Data activationData;
@@ -386,12 +486,14 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
 
     void* hidden_states_scale_linear{args.hidden_states_scale};
 
+    auto const& config = mPassingConfigs[configIndex];
+
     mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale,
         workspace.expert_weights, args.output1_scales_scalar, args.output1_scales_gate_scalar, workspace.gemm1_output,
         workspace.gemm1_output_scale, args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts,
         args.num_tokens, workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
         workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit,
-        workspace.bmm1_workspace, args.mUseRoutingScalesOnInput, device, stream);
+        workspace.bmm1_workspace, args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config);
 
     // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint.
     void* gemm2_input = workspace.gemm1_output;
@@ -409,7 +511,8 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
     mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, args.output2_scales_scalar,
         workspace.gemm2_output, workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
         args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens,
-        workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream);
+        workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
+        config.gemm2Config);
 
     // Run finalize
     if (args.do_finalize)
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
index ce58765bee6..7799a706d0e 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
@@ -122,15 +122,23 @@ class Runner
 public:
     explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim);
 
-    size_t getWorkspaceSizeInBytes(
-        int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens);
+    size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts,
+        int32_t numTokens, int32_t configIndex) const;
+
+    [[nodiscard]] int32_t getDefaultValidConfigIndex(
+        int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const;
+
+    [[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize,
+        int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const;
+
+    [[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const;
 
     void run(void* hiddenState, void* hiddenStateScale, void* weight, void* weightScale, void* expertWeights,
         float* outputScalesScalar, float* outputScalesGateScalar, void* output, void* outputScale, int32_t topK,
         int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens,
         int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens,
         int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace,
-        bool useRoutingScalesOnInput, int device, cudaStream_t stream);
+        bool useRoutingScalesOnInput, int device, cudaStream_t stream, int32_t configIndex);
 
 private:
     batchedGemm::trtllm::gen::Dtype mDtypeElt;
@@ -147,14 +155,22 @@ class Runner
     explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype outputDtype,
         bool useDeepSeekFp8, int tileTokensDim);
 
-    size_t getWorkspaceSizeInBytes(
-        int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens);
+    size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts,
+        int32_t numTokens, int32_t configIndex) const;
+
+    [[nodiscard]] int32_t getDefaultValidConfigIndex(
+        int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const;
+
+    [[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize,
+        int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const;
+
+    [[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const;
 
     void run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weight, void* weightScale,
         float* outputScalesScalar, void* output, void* outputScale, int32_t topK, int32_t hiddenSize,
         int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas,
         int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit,
-        void* bmm2Workspace, int device, cudaStream_t stream);
+        void* bmm2Workspace, int device, cudaStream_t stream, int32_t configIndex);
 
 private:
     batchedGemm::trtllm::gen::Dtype mDtypeElt;
@@ -263,15 +279,30 @@ struct MoEWorkspace
     void* bmm2_workspace = nullptr;
 };
 
+// Config indices to be used with Batched GEMM runners
+struct MoEConfig
+{
+    int64_t gemm1Config;
+    int64_t gemm2Config;
+};
+
 class Runner
 {
 public:
     // FIXME: tileTokensDim is hardcoded for now
    Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim = 8);
 
-    void run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream);
+    void run(
+        MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream, int64_t configIndex);
 
-    std::tuple getWorkspaceSizeInBytes(MoERunnerArgs const& args);
+    [[nodiscard]] std::tuple getWorkspaceSizeInBytes(
+        MoERunnerArgs const& args, int64_t configIndex) const;
+
+    [[nodiscard]] std::vector<int64_t> getValidConfigIndices(
+        int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const;
+
+    [[nodiscard]] int64_t getDefaultValidConfigIndex(
+        int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const;
 
 private:
     void setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace, moe::dev::convertsf::Data& convertSfData,
@@ -280,6 +311,10 @@ class Runner
 private:
     PermuteGemm1::Runner mPermuteGemm1;
     Gemm2::Runner mGemm2;
+
+    // This will be the cartesian product of the passing configs for gemm1 and gemm2
+    // This allows us to autotune the MoE as one operation instead of tuning gemm1 and gemm2 separately
+    std::vector<MoEConfig> mPassingConfigs;
 };
 
 } // namespace MoE
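Note: the flattened MoE config index used above is simply a row-major index into the cartesian product that the MoE::Runner constructor builds over the two GEMMs' passing configs. A small sketch, not part of the patch, with placeholder index values:

# Sketch only: how a flattened MoE config index maps back to per-GEMM config
# indices, mirroring the nested loops in the MoE::Runner constructor.
gemm1_passing = [0, 3, 7]  # placeholder gemm1 config indices
gemm2_passing = [1, 4]     # placeholder gemm2 config indices

passing_configs = [(g1, g2) for g1 in gemm1_passing for g2 in gemm2_passing]

moe_config_index = 4
print(passing_configs[moe_config_index])  # (7, 1)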
diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp
index 25e37bb18d6..b9b232255d2 100644
--- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp
+++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp
@@ -262,7 +262,12 @@ std::vector fp4_block_scale_moe_runner(torch::Tensor const& routi
 
     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner moe_runner(
         args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim);
-    auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args);
+
+    auto const moeConfigIndex = moe_runner.getDefaultValidConfigIndex(
+        args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens);
+
+    auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
+
     at::Tensor workspace_fc1 = at::detail::empty_cuda(
         {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
     at::Tensor workspace_fc2 = at::detail::empty_cuda(
@@ -270,12 +275,13 @@ std::vector fp4_block_scale_moe_runner(torch::Tensor const& routi
     workspace.bmm1_workspace = workspace_fc1.data_ptr();
     workspace.bmm2_workspace = workspace_fc2.data_ptr();
     auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
-    moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream);
+    moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex);
 
     if (!do_finalize)
     {
         return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx};
     }
+
     return {output};
 }
diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
index 229a1fd6307..476afa928e2 100644
--- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
+++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
@@ -15,23 +15,28 @@
  */
 
 #include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h"
-#include "tensorrt_llm/runtime/torchUtils.h"
-#include "tensorrt_llm/thop/thUtils.h"
+
+#include
+#include
 #include
+#include
+
+#include
 
 namespace torch_ext
 {
 
 namespace btg = batchedGemm::trtllm::gen;
 using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
+using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
 
-torch::Tensor fp8_block_scale_moe_runner(torch::Tensor const& routing_logits, torch::Tensor const& routing_bias,
-    torch::Tensor const& hidden_states, torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, int64_t const num_experts, int64_t const top_k, int64_t const n_group,
-    int64_t const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type)
+at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor const& routing_bias,
+    at::Tensor const& hidden_states, at::Tensor const& hidden_states_scale, at::Tensor const& gemm1_weights,
+    at::Tensor const& gemm1_weights_scale, at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale,
+    int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts,
+    double const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    MoeRunnerType& moe_runner, int64_t moeConfigIndex)
 {
     auto const sm = tensorrt_llm::common::getSMVersion();
     TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
@@ -201,9 +206,7 @@ torch::Tensor fp8_block_scale_moe_runner(torch::Tensor const& routing_logits, to
     args.output = output.data_ptr();
     args.output_scale = nullptr;
 
-    tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner moe_runner(
-        args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim);
-    auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args);
+    auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
     at::Tensor workspace_fc1 = at::detail::empty_cuda(
         {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
     at::Tensor workspace_fc2 = at::detail::empty_cuda(
@@ -212,36 +215,68 @@ torch::Tensor fp8_block_scale_moe_runner(torch::Tensor const& routing_logits, to
     workspace.bmm2_workspace = workspace_fc2.data_ptr();
     auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
-    moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream);
+    moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex);
     return output;
 }
 
-} // namespace torch_ext
 
-TORCH_LIBRARY_FRAGMENT(trtllm, m)
+// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
+// use with the torch workflow autotuner class.
+class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
 {
-    m.def(
-        "fp8_block_scale_moe_runner("
-        "Tensor routing_logits,"
-        "Tensor routing_bias,"
-        "Tensor hidden_states,"
-        "Tensor hidden_states_scale,"
-        "Tensor gemm1_weights,"
-        "Tensor gemm1_weights_scale,"
-        "Tensor gemm2_weights,"
-        "Tensor gemm2_weights_scale,"
-        "int num_experts,"
-        "int top_k,"
-        "int n_group,"
-        "int topk_group,"
-        "int intermediate_size,"
-        "int local_expert_offset,"
-        "int local_num_experts,"
-        "float routed_scaling_factor,"
-        "int tile_tokens_dim,"
-        "int routing_method_type) -> Tensor");
-}
 
-TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+public:
+    explicit FP8BlockScaleMoeRunner(int64_t tileTokensDim)
+        : mTileTokensDim(tileTokensDim)
+    {
+        mRunner = std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, mTileTokensDim);
+    }
+
+    [[nodiscard]] std::vector<int64_t> getValidConfigs(
+        int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
+    {
+        return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
+    }
+
+    [[nodiscard]] at::Tensor run(at::Tensor const& routing_logits, at::Tensor const& routing_bias,
+        at::Tensor const& hidden_states, at::Tensor const& hidden_states_scale, at::Tensor const& gemm1_weights,
+        at::Tensor const& gemm1_weights_scale, at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale,
+        int64_t num_experts, int64_t top_k, int64_t n_group, int64_t topk_group, int64_t intermediate_size,
+        int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor,
+        int64_t routing_method_type, int64_t moeConfigIndex)
+    {
+
+        // Autotuner has requested a default or 'fallback' config index
+        if (moeConfigIndex == -1)
+        {
+            auto const num_tokens = hidden_states.sizes()[0];
+            auto const hidden_size = hidden_states.sizes()[1];
+
+            moeConfigIndex = mRunner->getDefaultValidConfigIndex(
+                top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+        }
+
+        return run_fp8_block_scale_moe(routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
+            gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group,
+            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, mTileTokensDim,
+            routing_method_type, *mRunner, moeConfigIndex);
+    }
+
+private:
+    using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
+
+    std::unique_ptr<RunnerType> mRunner;
+
+    btg::Dtype mDtypeElt{btg::Dtype::E4m3}; // FP8 runner so hard-coded
+    bool mUseDeepSeekFp8{true};             // Always true for BlockScaleMoe
+    int64_t mTileTokensDim;
+};
+
+} // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
-    m.impl("fp8_block_scale_moe_runner", &torch_ext::fp8_block_scale_moe_runner);
+    m.class_<torch_ext::FP8BlockScaleMoeRunner>("FP8BlockScaleMoERunner")
+        .def(torch::init<int64_t>())
+        .def("get_valid_configs", &torch_ext::FP8BlockScaleMoeRunner::getValidConfigs)
+        .def("run_moe", &torch_ext::FP8BlockScaleMoeRunner::run);
 }
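Illustrative sketch (not part of the patch) of driving the custom class registered above from Python; it assumes the TensorRT-LLM torch extension containing this registration has already been loaded:

import torch

# Construct the runner with tile_tokens_dim = 8 and ask which flattened
# (gemm1, gemm2) config indices are valid for a given problem size. Any of
# these indices can be passed to run_moe as moeConfigIndex; -1 selects the
# runner's default config.
runner = torch.classes.trtllm.FP8BlockScaleMoERunner(8)
valid_configs = runner.get_valid_configs(
    8,    # top_k
    512,  # hidden_size
    512,  # intermediate_size
    256,  # local_num_experts
    128,  # num_tokens
)
print(valid_configs)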
diff --git a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp
index caf9d243cc7..395c4320b2b 100644
--- a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp
+++ b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp
@@ -206,7 +206,11 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit
 
     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner moe_runner(
         args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim);
-    auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args);
+
+    auto const moeConfigIndex = moe_runner.getDefaultValidConfigIndex(
+        args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens);
+
+    auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
     at::Tensor workspace_fc1 = at::detail::empty_cuda(
         {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
     at::Tensor workspace_fc2 = at::detail::empty_cuda(
@@ -214,7 +218,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit
     workspace.bmm1_workspace = workspace_fc1.data_ptr();
     workspace.bmm2_workspace = workspace_fc2.data_ptr();
     auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
-    moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream);
+    moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex);
     return output;
 }
 } // namespace torch_ext
diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 20f9c3b120c..090bcd42b2f 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -341,11 +341,7 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
         Runner authors are suggested to provide a fallback implementation for each runner
         to avoid potential issues.
         """
-        # Treat None tensors as size zero
-        # This allows the tuner to handle TRT-LLM-Gen torch ops that have optional tensor
-        # arguments, such as block scaling factors.
-        input_shapes = tuple(
-            (t.shape if t is not None else torch.Size((0, ))) for t in inputs)
+        input_shapes = tuple(self._get_input_sizes(inputs))
 
         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
@@ -398,11 +394,7 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
                         time_measured = self._profile_single_kernel(
                             r, tensors, tac, **kwargs)
                     except Exception as e:
-                        # Handle None tensors for optional inputs
-                        shapes = [
-                            t.size() if t is not None else torch.Size((0, ))
-                            for t in tensors
-                        ]
+                        shapes = self._get_input_sizes(tensors)
                         logger.error(
                            f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
                        )
@@ -444,6 +436,16 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
 
         return runners[runner_id], tactic
 
+    def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:
+
+        # Handle None tensors for optional inputs and non-Tensor scalar values
+        sizes = [
+            input.size() if isinstance(input, torch.Tensor) else torch.Size(
+                (0, )) for input in inputs
+        ]
+
+        return sizes
+
     def _profile_single_kernel(self, runner: TunableRunner,
                                inputs: List[torch.Tensor], tactic: int,
                                **kwargs) -> float:
@@ -483,10 +485,7 @@ def _profile_single_kernel(self, runner: TunableRunner,
 
         avg_time = start.elapsed_time(end) / self.repeat
 
-        # Handle None tensors for optional inputs
-        shapes = [
-            t.size() if t is not None else torch.Size((0, )) for t in inputs
-        ]
+        shapes = self._get_input_sizes(inputs)
         logger.debug(
             f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
         )
@@ -512,10 +511,10 @@ def _optimization_profiles(
         # every dimension created from the concrete input tensor shape
 
         # generate some dynamic dimension description based on the dynamic_tensors
-        # Zero handles the case where a TRTLLM op has optional inputs.
+        # Zero handles the case where a TRTLLM op has optional or scalar inputs.
         base_profile = OptimizationProfile(
-            [[StaticDim(x)
-              for x in t.size()] if t is not None else [StaticDim(0)]
+            [[StaticDim(x) for x in t.size()]
+             if isinstance(t, torch.Tensor) else [StaticDim(0)]
             for t in inputs])
 
         generated_profiles: List[OptimizationProfile] = []
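A quick sketch (not part of the patch) of what the new _get_input_sizes helper returns for mixed inputs: tensors keep their shapes, while None and scalar arguments collapse to torch.Size([0]):

import torch

def get_input_sizes(inputs):
    # Mirrors AutoTuner._get_input_sizes above.
    return [
        x.size() if isinstance(x, torch.Tensor) else torch.Size((0, ))
        for x in inputs
    ]

print(get_input_sizes([torch.empty(4, 16), None, 3.0]))
# [torch.Size([4, 16]), torch.Size([0]), torch.Size([0])]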
diff --git a/tensorrt_llm/_torch/custom_ops/__init__.py b/tensorrt_llm/_torch/custom_ops/__init__.py
index 7b2a804fb89..8a81d1123a4 100644
--- a/tensorrt_llm/_torch/custom_ops/__init__.py
+++ b/tensorrt_llm/_torch/custom_ops/__init__.py
@@ -1,12 +1,14 @@
 from .cpp_custom_ops import _register_fake
 from .flashinfer_custom_ops import IS_FLASHINFER_AVAILABLE
 from .torch_custom_ops import bmm_out
+from .trtllm_gen_custom_ops import fp8_block_scale_moe_runner
 from .userbuffers_custom_ops import add_to_ub, copy_to_userbuffers, matmul_to_ub
 
 __all__ = [
     'IS_FLASHINFER_AVAILABLE',
     '_register_fake',
     'bmm_out',
+    'fp8_block_scale_moe_runner',
     'add_to_ub',
     'copy_to_userbuffers',
     'matmul_to_ub',
diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
new file mode 100644
index 00000000000..b37c4e017f7
--- /dev/null
+++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -0,0 +1,194 @@
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import List, Tuple
+
+import torch
+
+from tensorrt_llm._torch.utils import last_positive_power_of_2
+
+from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
+                         OptimizationProfile, TunableRunner, TuningConfig)
+
+
+@dataclass(frozen=True)
+class FP8BlockScaleMoEInputs:
+
+    routing_logits: torch.Tensor
+    routing_bias: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    gemm1_weights: torch.Tensor
+    gemm1_weights_scale: torch.Tensor
+    gemm2_weights: torch.Tensor
+    gemm2_weights_scale: torch.Tensor
+
+
+class FP8BlockScaleMoERunner(TunableRunner):
+
+    runner_dict = dict()
+    tuning_config = None
+
+    def __init__(self, num_experts: int, top_k: int, n_group: int,
+                 topk_group: int, intermediate_size: int,
+                 local_expert_offset: int, local_num_experts: int,
+                 routed_scaling_factor: float, tile_tokens_dim: int,
+                 routing_method_type: int):
+
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.intermediate_size = intermediate_size
+        self.local_expert_offset = local_expert_offset
+        self.local_num_experts = local_num_experts
+        self.routed_scaling_factor = routed_scaling_factor
+        self.tile_tokens_dim = tile_tokens_dim
+        self.routing_method_type = routing_method_type
+
+        FP8BlockScaleMoERunner.tuning_config = FP8BlockScaleMoERunner.get_tuning_config(
+        )
+
+        instance_key = (
+            self.top_k,
+            self.intermediate_size,
+            self.local_num_experts,
+            self.tile_tokens_dim,
+        )
+
+        if instance_key not in FP8BlockScaleMoERunner.runner_dict:
+            FP8BlockScaleMoERunner.runner_dict[
+                instance_key] = torch.classes.trtllm.FP8BlockScaleMoERunner(
+                    tile_tokens_dim)
+
+        self.kernel_runner = FP8BlockScaleMoERunner.runner_dict[instance_key]
+
+    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
+    # type does not matter
+    def __hash__(self):
+        return hash((
+            self.top_k,
+            self.intermediate_size,
+            self.local_num_experts,
+            self.tile_tokens_dim,
+        ))
+
+    # __eq__ and __hash__ must agree
+    def __eq__(self, other):
+        if not isinstance(other, FP8BlockScaleMoERunner):
+            return False
+
+        return (self.top_k == other.top_k
+                and self.intermediate_size == other.intermediate_size
+                and self.local_num_experts == other.local_num_experts
+                and self.tile_tokens_dim == other.tile_tokens_dim)
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+
+        args = FP8BlockScaleMoEInputs(*inputs)
+
+        return self.kernel_runner.run_moe(
+            args.routing_logits, args.routing_bias, args.hidden_states,
+            args.hidden_states_scale, args.gemm1_weights,
+            args.gemm1_weights_scale, args.gemm2_weights,
+            args.gemm2_weights_scale, self.num_experts, self.top_k,
+            self.n_group, self.topk_group, self.intermediate_size,
+            self.local_expert_offset, self.local_num_experts,
+            self.routed_scaling_factor, self.routing_method_type, tactic)
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+
+        args = FP8BlockScaleMoEInputs(*inputs)
+
+        num_tokens = args.hidden_states.shape[0]
+        hidden_size = args.hidden_states.shape[1]
+
+        tactics = self.kernel_runner.get_valid_configs(self.top_k, hidden_size,
+                                                       self.intermediate_size,
+                                                       self.local_num_experts,
+                                                       num_tokens)
+
+        return tactics
+
+    @classmethod
+    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+        HIDDEN_STATES_IDX = 2
+        TUNED_DIM = 0
+
+        m_values = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096)
+        round_rule = lambda x: last_positive_power_of_2(x)
+
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )
+
+        return specs
+
+    @classmethod
+    def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]:
+        return ()
+
+    @classmethod
+    @lru_cache(maxsize=None)
+    def get_tuning_config(cls) -> TuningConfig:
+
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        constraint_specs = cls.get_constraint_specs()
+
+        tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
+                                     constraint_specs=constraint_specs)
+
+        return tuning_config
+
+
+@torch.library.custom_op("trtllm::fp8_block_scale_moe_runner", mutates_args=())
+def fp8_block_scale_moe_runner(routing_logits: torch.Tensor,
+                               routing_bias: torch.Tensor,
+                               hidden_states: torch.Tensor,
+                               hidden_states_scale: torch.Tensor,
+                               gemm1_weights: torch.Tensor,
+                               gemm1_weights_scale: torch.Tensor,
+                               gemm2_weights: torch.Tensor,
+                               gemm2_weights_scale: torch.Tensor,
+                               num_experts: int, top_k: int, n_group: int,
+                               topk_group: int, intermediate_size: int,
+                               local_expert_offset: int, local_num_experts: int,
+                               routed_scaling_factor: float,
+                               tile_tokens_dim: int,
+                               routing_method_type: int) -> torch.Tensor:
+
+    tuner = AutoTuner.get()
+
+    kernel_runner = FP8BlockScaleMoERunner(num_experts, top_k, n_group,
+                                           topk_group, intermediate_size,
+                                           local_expert_offset,
+                                           local_num_experts,
+                                           routed_scaling_factor,
+                                           tile_tokens_dim, routing_method_type)
+
+    inputs = [
+        routing_logits,
+        routing_bias,
+        hidden_states,
+        hidden_states_scale,
+        gemm1_weights,
+        gemm1_weights_scale,
+        gemm2_weights,
+        gemm2_weights_scale,
+    ]
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_block_scale_moe_runner",
+        [kernel_runner],
+        kernel_runner.tuning_config,
+        inputs,
+    )
+
+    return kernel_runner(inputs, tactic=best_tactic)
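The num_tokens bucketing used by get_dynamic_tensor_specs above relies on last_positive_power_of_2 from tensorrt_llm._torch.utils; the sketch below (not part of the patch) assumes it rounds down to the nearest power of two, which is how the tuning buckets are interpreted here:

from tensorrt_llm._torch.utils import last_positive_power_of_2

# Assumed rounding behaviour for mapping a runtime num_tokens onto the tuned buckets.
for num_tokens in (1, 3, 100, 1000, 5000):
    print(num_tokens, "->", last_positive_power_of_2(num_tokens))
# Expected: 1 -> 1, 3 -> 2, 100 -> 64, 1000 -> 512, 5000 -> 4096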
diff --git a/tests/unittest/_torch/thop/test_moe.py b/tests/unittest/_torch/thop/test_moe.py
index dcae98cac58..a4148b72313 100644
--- a/tests/unittest/_torch/thop/test_moe.py
+++ b/tests/unittest/_torch/thop/test_moe.py
@@ -23,6 +23,7 @@
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 from utils.util import getSMVersion
 
+from tensorrt_llm._torch.autotuner import autotune
 from tensorrt_llm._torch.modules.fused_moe import RoutingMethodType
 from tensorrt_llm.quantization.utils.fp4_utils import (
     reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a)
@@ -574,7 +575,11 @@ def quant_dequant_per_tensor_fp8(a):
                                          (72, 1, 1, 6), (256, 8, 4, 8)])
 @pytest.mark.parametrize("hidden_size", [512])
 @pytest.mark.parametrize("intermediate_size", [512])
-def test_moe_fp8(num_tokens, expert_info, hidden_size, intermediate_size):
+@pytest.mark.parametrize("use_autotune", [True, False],
+                         ids=["autotune", "no_autotune"])
+def test_moe_fp8(num_tokens, expert_info, hidden_size, intermediate_size,
+                 use_autotune):
+
     torch.random.manual_seed(0)
 
     #
@@ -625,11 +630,12 @@ def test_moe_fp8(num_tokens, expert_info, hidden_size, intermediate_size):
         scores, gemm1_weights, gemm1_scales, None, gemm2_weights, gemm2_scales,
         None, permute_info, False)
 
-    output = torch.ops.trtllm.fp8_block_scale_moe_runner(
-        expert_logits, routing_bias, hidden_states, hidden_states_scale,
-        gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales, num_experts,
-        top_k, n_groups, top_k_groups, intermediate_size, 0, num_experts,
-        routed_scaling, tile_tokens_dim, routing_method_type)
+    with autotune(use_autotune):
+        output = torch.ops.trtllm.fp8_block_scale_moe_runner(
+            expert_logits, routing_bias, hidden_states, hidden_states_scale,
+            gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
+            num_experts, top_k, n_groups, top_k_groups, intermediate_size, 0,
+            num_experts, routed_scaling, tile_tokens_dim, routing_method_type)
 
     output_dequant_actual = output.to(torch.float)
 
     #