@@ -193,7 +193,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne

size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex)
int32_t configIndex) const
{
BatchedGemmData gemmData;
gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -212,14 +212,8 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,

auto const configs = bmm.getBatchedGemmConfigs();

if (!configIndex.has_value())
{
mSelectedConfigIndex
= getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
configIndex = mSelectedConfigIndex;
}
auto const& config = configs[configIndex];

auto const& config = configs[configIndex.value()];
return bmm.getWorkspaceSizeInBytes(config, gemmData);
}

@@ -228,22 +222,15 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
void* workspace, CUstream stream, int device, int32_t configIndex)
{
auto bmm = BatchedGemmInterface();

BatchedGemmData gemmData;

auto const configs = bmm.getBatchedGemmConfigs();

if (!configIndex.has_value())
{
TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");

configIndex = mSelectedConfigIndex;
}

auto const& config = configs[configIndex.value()];
auto const& config = configs[configIndex];

TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
if (!mOptions.staticBatch)
@@ -315,7 +302,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto

void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
CUstream stream, int device, std::optional<int32_t> configIndex)
CUstream stream, int device, int32_t configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
@@ -328,7 +315,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto

void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
CUstream stream, int device, std::optional<int32_t> configIndex)
CUstream stream, int device, int32_t configIndex)
{
// Dispatch with per-tensor scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
@@ -415,5 +402,31 @@ int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_
return validConfigIndices[0];
}

bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const bmm = BatchedGemmInterface();
auto const configs = bmm.getBatchedGemmConfigs();

BatchedGemmData gemmData;
// Dims
gemmData.mProblemDimensions.mNumBatches = numBatches;
gemmData.mProblemDimensions.mNumTokens = numTokens;
gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
gemmData.mProblemDimensions.mBatchedM = mOptions.transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
gemmData.mProblemDimensions.mBatchedN = mOptions.transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
gemmData.mProblemDimensions.mK = k;
gemmData.mProblemDimensions.mRank = 0;
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

auto const& config = configs[configIndex];

return bmm.isValidConfig(config, gemmData);
}

} // namespace kernels
} // namespace tensorrt_llm
@@ -18,7 +18,6 @@

#include <cstdint>
#include <cuda.h>
#include <optional>
#include <vector>

#include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
@@ -48,25 +47,25 @@ class TrtllmGenBatchedGemmRunner

[[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex = std::nullopt);
int32_t configIndex) const;

// Generic GEMM interface
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);
void* workspace, CUstream stream, int device, int32_t configIndex);

// NVFP4 per-block scaling GEMM
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);
int32_t configIndex);

// FP8 per-tensor scaling GEMM
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);
int32_t configIndex);

// Get the list of configs that passed the validation based on the constructor options
[[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const
@@ -85,14 +84,17 @@
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;

[[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;

private:
void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim);

private:
TrtllmGenBatchedGemmRunnerOptions mOptions;
std::vector<int64_t> mPassingConfigIndices;
std::optional<int64_t> mSelectedConfigIndex;
};
} // namespace kernels
} // namespace tensorrt_llm
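
Reviewer note, not part of the diff: with configIndex now a required int32_t on getWorkspaceSizeInBytes and the run overloads, callers resolve a config before sizing the workspace or launching. Below is a minimal sketch of the intended call sequence for the NVFP4 per-block scaling overload, assuming an already constructed runner; the dimensions, the function name exampleBatchedGemmUsage, and the variable names are placeholders, not values from this PR.

// Sketch only: resolve the config index once, then reuse it for the workspace
// query and the launch. All dimensions here are illustrative placeholders.
void exampleBatchedGemmUsage(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner,
    std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA, void const* b, void const* sfB,
    void* c, void* outSfC, void* workspace, CUstream stream, int device)
{
    int32_t const m = 128, n = 256, k = 512;
    int32_t const numTokens = 0; // static batching path
    int32_t const numBatches = static_cast<int32_t>(batchedTokens.size());
    int32_t const maxNumCtasInBatchDim = 0;

    // Pick a valid config up front (a tuned index from getPassingConfigIndices() works as well).
    auto const configIndex = static_cast<int32_t>(
        runner.getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim));

    // Size the workspace for that config; the caller allocates at least this many bytes.
    size_t const workspaceSizeInBytes = runner.getWorkspaceSizeInBytes(
        m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim, configIndex);
    (void) workspaceSizeInBytes;

    // Launch with the same explicit config index (NVFP4 per-block scaling overload).
    runner.run(m, n, k, batchedTokens, a, sfA, b, sfB, c, outSfC, workspace, stream, device, configIndex);
}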
135 changes: 119 additions & 16 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
@@ -235,22 +235,48 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void*
float* outputScalesScalar, float* outputScalesGateScalar, void* output, void* outputScale, int32_t topK,
int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t* permutedIdxToTokenIdx,
int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, int device, cudaStream_t stream)
int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, int device, cudaStream_t stream,
int32_t configIndex)
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim,
hiddenState, hiddenStateScale, weights, weightsScale, expertWeights, /* perTokensSfB */ nullptr,
outputScalesScalar, outputScalesGateScalar, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens,
ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device);
ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex);
}

size_t Runner::getWorkspaceSizeInBytes(
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens)
size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts,
int32_t numTokens, int32_t configIndex) const
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
return mRunner.getWorkspaceSizeInBytes(
numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex);
}

int32_t Runner::getDefaultValidConfigIndex(
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
return mRunner.getDefaultValidConfigIndex(
numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
}

bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
int32_t numExperts, int32_t numTokens) const
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);

auto const isValid = mRunner.isValidConfigIndex(
configIndex, numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);

return isValid;
}

std::vector<int64_t> Runner::getPassingConfigIndices() const
{
return mRunner.getPassingConfigIndices();
}

} // namespace PermuteGemm1

namespace Gemm2
@@ -283,23 +309,49 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void
float* outputScalesScalar, void* output, void* outputScale, int32_t topK, int32_t hiddenSize,
int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas,
int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit,
void* bmm2Workspace, int device, cudaStream_t stream)
void* bmm2Workspace, int device, cudaStream_t stream, int32_t configIndex)
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
mRunner.run(numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim,
permutedHiddenState, permutedHiddenStateScale, weights, weightsScale, /* perTokensSfA */ nullptr,
/* perTokensSfB */ nullptr, outputScalesScalar, /* outputScalesGateScalar */ nullptr, output, outputScale,
/* permutedIdxToTokenIdx */ nullptr, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit,
ptrNumNonExitingCtas, bmm2Workspace, stream, device);
ptrNumNonExitingCtas, bmm2Workspace, stream, device, configIndex);
}

size_t Runner::getWorkspaceSizeInBytes(
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens)
size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts,
int32_t numTokens, int32_t configIndex) const
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
return mRunner.getWorkspaceSizeInBytes(
numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex);
}

int32_t Runner::getDefaultValidConfigIndex(
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const
{
auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
return mRunner.getDefaultValidConfigIndex(
numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
}

bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
int32_t numExperts, int32_t numTokens) const
{

auto const maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);

auto const isValid = mRunner.isValidConfigIndex(
configIndex, numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);

return isValid;
}

std::vector<int64_t> Runner::getPassingConfigIndices() const
{
return mRunner.getPassingConfigIndices();
}

} // namespace Gemm2

namespace MoE
@@ -308,6 +360,22 @@ Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim)
: mPermuteGemm1(PermuteGemm1::Runner(dtypeElt, useDeepSeekFp8, tileTokensDim))
, mGemm2(Gemm2::Runner(dtypeElt, btg::Dtype::Bfloat16, useDeepSeekFp8, tileTokensDim))
{

auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices();
auto const& gemm2PassingIndices = mGemm2.getPassingConfigIndices();

auto const totalPassingIndices = gemm1PassingIndices.size() * gemm2PassingIndices.size();
mPassingConfigs.reserve(totalPassingIndices);

for (auto const& indexGemm1 : gemm1PassingIndices)
{
for (auto const& indexGemm2 : gemm2PassingIndices)
{
mPassingConfigs.push_back(MoEConfig{indexGemm1, indexGemm2});
}
}

TLLM_CHECK_WITH_INFO(!mPassingConfigs.empty(), "No compatible configs found for the fp8 block scale MoE runner.");
}
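
Reviewer note, not part of the diff: the constructor above enumerates the cartesian product of the two sub-runners' passing configs, with the GEMM2 index varying fastest, so a combined MoE config index decomposes into its per-GEMM slots as sketched below. The helper name decomposeCombinedConfigIndex and its parameters are illustrative only.

#include <cstdint>
#include <utility>

// Illustration only: given the loop order above (gemm1 outer, gemm2 inner),
// a combined index into mPassingConfigs maps back to the positions within
// gemm1PassingIndices and gemm2PassingIndices. `numGemm2Passing` stands for
// gemm2PassingIndices.size().
std::pair<int64_t, int64_t> decomposeCombinedConfigIndex(int64_t combinedIndex, int64_t numGemm2Passing)
{
    int64_t const gemm1Slot = combinedIndex / numGemm2Passing; // position in gemm1PassingIndices
    int64_t const gemm2Slot = combinedIndex % numGemm2Passing; // position in gemm2PassingIndices
    return {gemm1Slot, gemm2Slot};
}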

void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
@@ -366,16 +434,48 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
}
}

std::tuple<int32_t, int32_t> Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args)
std::tuple<int32_t, int32_t> Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args, int64_t configIndex) const
{
auto workspace_size_fc1 = static_cast<int32_t>(mPermuteGemm1.getWorkspaceSizeInBytes(
args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens));
auto workspace_size_fc2 = static_cast<int32_t>(mGemm2.getWorkspaceSizeInBytes(
args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens));
auto const& config = mPassingConfigs[configIndex];

auto workspace_size_fc1 = static_cast<int32_t>(mPermuteGemm1.getWorkspaceSizeInBytes(args.top_k, args.hidden_size,
args.intermediate_size, args.local_num_experts, args.num_tokens, config.gemm1Config));
auto workspace_size_fc2 = static_cast<int32_t>(mGemm2.getWorkspaceSizeInBytes(args.top_k, args.hidden_size,
args.intermediate_size, args.local_num_experts, args.num_tokens, config.gemm2Config));
return std::make_tuple(workspace_size_fc1, workspace_size_fc2);
}

void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream)
std::vector<int64_t> Runner::getValidConfigIndices(
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const
{
std::vector<int64_t> validIndices;

for (int i = 0; i < mPassingConfigs.size(); ++i)
{
auto const& config = mPassingConfigs[i];

if (mPermuteGemm1.isValidConfigIndex(
config.gemm1Config, topK, hiddenSize, intermediateSize, numLocalExperts, numTokens)
&& mGemm2.isValidConfigIndex(
config.gemm2Config, topK, hiddenSize, intermediateSize, numLocalExperts, numTokens))
{
validIndices.push_back(i);
}
}

return validIndices;
}

int64_t Runner::getDefaultValidConfigIndex(
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const
{
auto const validIndices = getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);

return validIndices[0];
}
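
Reviewer note, not part of the diff: a minimal sketch of how a caller is expected to drive the new index-based MoE API; an autotuner would iterate over getValidConfigIndices instead of taking the default. `moeRunner`, `args`, `workspace`, `device`, and `stream` are assumed to already exist in the caller.

// Sketch only: select a combined config, size both GEMM workspaces for it,
// then run with the same index. `moeRunner` is a constructed MoE::Runner.
auto const configIndex = moeRunner.getDefaultValidConfigIndex(
    args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens);

auto const [fc1WorkspaceBytes, fc2WorkspaceBytes] = moeRunner.getWorkspaceSizeInBytes(args, configIndex);
// ... allocate workspace.bmm1_workspace / workspace.bmm2_workspace of at least these sizes ...

moeRunner.run(args, workspace, device, stream, configIndex);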

void Runner::run(
MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream, int64_t configIndex)
{
// Setup all operation data
moe::dev::activation::Data activationData;
@@ -386,12 +486,14 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d

void* hidden_states_scale_linear{args.hidden_states_scale};

auto const& config = mPassingConfigs[configIndex];

mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale,
workspace.expert_weights, args.output1_scales_scalar, args.output1_scales_gate_scalar, workspace.gemm1_output,
workspace.gemm1_output_scale, args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts,
args.num_tokens, workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit,
workspace.bmm1_workspace, args.mUseRoutingScalesOnInput, device, stream);
workspace.bmm1_workspace, args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config);

// We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint.
void* gemm2_input = workspace.gemm1_output;
@@ -409,7 +511,8 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, args.output2_scales_scalar,
workspace.gemm2_output, workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens,
workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream);
workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
config.gemm2Config);

// Run finalize
if (args.do_finalize)