39 changes: 33 additions & 6 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp
@@ -18,6 +18,7 @@

#include "KernelRunner.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/envUtils.h"
#include "trtllmGen_gemm_export/GemmInterface.h"
#include "trtllmGen_gemm_export/GemmOptions.h"
#include "trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h"
@@ -46,9 +47,10 @@ TrtllmGenGemmRunner::TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& optio
auto const options = configs[i].mOptions;

// When we include low-latency kernels we can set transposeMmaOutput via constructor
if (options.mDtypeA == mOptions.eltType && options.mDtypeC == mOptions.outputType
if (options.mDtypeA == mOptions.eltTypeA && options.mDtypeC == mOptions.outputType
&& options.mUseDeepSeekFp8 == mOptions.deepSeekFp8
&& options.mTransposeMmaOutput == mOptions.transposeMmaOutput)
&& options.mTransposeMmaOutput == mOptions.transposeMmaOutput
&& (mOptions.eltTypeB == gemm::trtllm::gen::Dtype::Void || options.mDtypeB == mOptions.eltTypeB))
{
mPassingConfigIndices.push_back(i);
}
@@ -113,8 +115,8 @@ void TrtllmGenGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, fl
// FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));

auto const err = gemm.run(
config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount, globalTrtllmGenGemmModuleCache);
auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
tensorrt_llm::common::getEnvEnablePDL(), globalTrtllmGenGemmModuleCache);

TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
}
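
For context, a minimal sketch of what the getEnvEnablePDL() toggle used above might look like; the environment-variable name and the caching behavior are assumptions for illustration, not taken from this diff:

#include <cstdlib>
#include <string>

namespace tensorrt_llm::common
{
// Hypothetical re-implementation for illustration only; the real helper is
// declared in tensorrt_llm/common/envUtils.h.
inline bool getEnvEnablePDL()
{
    // Read the environment once and cache the result (assumed behavior).
    static bool const enablePdl = []
    {
        char const* env = std::getenv("TRTLLM_ENABLE_PDL"); // assumed variable name
        return env != nullptr && std::string(env) == "1";
    }();
    return enablePdl;
}
} // namespace tensorrt_llm::common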
@@ -141,12 +143,30 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)

std::vector<int32_t> sortedIndices = mPassingConfigIndices;
std::sort(sortedIndices.begin(), sortedIndices.end(),
[&configs](int32_t idx0, int32_t idx1)
[&configs, &gemmData](int32_t idx0, int32_t idx1)
{
auto const& optionsA = configs[idx0].mOptions;
auto const& optionsB = configs[idx1].mOptions;

// Sort by tileK sizes first
// Choose the tileN that is closest to the problem N. Also, if one tileN is larger than N and the other is
// smaller, prefer the larger one. This is the batch size dimension for the low-latency (transposeMmaOutput) case.
if (optionsA.mTileN != optionsB.mTileN)
{
auto const N = gemmData.mProblemDimensions.mN;
auto const tileA = optionsA.mTileN;
auto const tileB = optionsB.mTileN;

// If one tile is larger than N and one is smaller, prefer the larger one
if ((tileA >= N) != (tileB >= N))
{
return tileA > tileB;
}

// Otherwise, choose the closest to N
return abs(N - tileA) < abs(N - tileB);
}

// Sort by tileK sizes
if (optionsA.mTileK != optionsB.mTileK)
{
return optionsA.mTileK > optionsB.mTileK;
@@ -158,6 +178,13 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)
return optionsA.mUseUnrollLoop2xForMma;
}

// Sort by tileM sizes
// This is the batch size dimension for the throughput (non-transposeMmaOutput) case.
if (optionsA.mTileM != optionsB.mTileM)
{
return optionsA.mTileM > optionsB.mTileM;
}

// Then by splitK sizes
if (optionsA.mNumSlicesForSplitK != optionsB.mNumSlicesForSplitK)
{
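
To see the new tileN rule of the comparator in isolation, here is a small self-contained sketch; the pickTileN helper and the sample values are made up for the example, and only the ordering rule mirrors the lambda above:

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Returns the preferred tileN for a problem size N: tiles that cover N beat
// tiles that do not, and ties are broken by closeness to N.
int32_t pickTileN(std::vector<int32_t> tiles, int32_t N)
{
    std::sort(tiles.begin(), tiles.end(),
        [N](int32_t tileA, int32_t tileB)
        {
            // If one tile is larger than N and one is smaller, prefer the larger one.
            if ((tileA >= N) != (tileB >= N))
            {
                return tileA > tileB;
            }
            // Otherwise, choose the tile closest to N.
            return std::abs(N - tileA) < std::abs(N - tileB);
        });
    return tiles.front();
}

// Example: for N = 100 and candidate tiles {32, 64, 128, 256}, the rule picks 128:
// it covers N (unlike 32 and 64) and is closer to N than 256.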
@@ -28,7 +28,8 @@ namespace kernels

struct TrtllmGenGemmRunnerOptions
{
gemm::trtllm::gen::Dtype eltType;
gemm::trtllm::gen::Dtype eltTypeA;
gemm::trtllm::gen::Dtype eltTypeB{gemm::trtllm::gen::Dtype::Void};
gemm::trtllm::gen::Dtype outputType;
bool deepSeekFp8{false};
bool transposeMmaOutput{false};
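
A hedged usage sketch of the new fields; the Dtype enumerators below are assumptions for illustration, and the remaining option fields are omitted because the struct is truncated in this diff:

#include "KernelRunner.h" // repo header declaring TrtllmGenGemmRunnerOptions

tensorrt_llm::kernels::TrtllmGenGemmRunnerOptions makeMixedDtypeOptions()
{
    tensorrt_llm::kernels::TrtllmGenGemmRunnerOptions options{};
    options.eltTypeA = gemm::trtllm::gen::Dtype::E4m3;       // assumed enumerator (FP8 A)
    options.eltTypeB = gemm::trtllm::gen::Dtype::E2m1;       // assumed enumerator (FP4 B)
    options.outputType = gemm::trtllm::gen::Dtype::Bfloat16; // assumed enumerator
    options.deepSeekFp8 = false;
    options.transposeMmaOutput = true;
    // Leaving eltTypeB at its default (Dtype::Void) makes the constructor's config
    // filter skip the mDtypeB check, preserving the previous single-dtype behavior.
    return options;
}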
@@ -39,6 +39,31 @@ enum class AllReduceAlgo : uint32_t

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class MatrixLayout
{
// K-major layout (default). [Mn, K]
MajorK = 0,
// M-major for A and N-major for B. [K, Mn]
MajorMn,
// Layout is blocked along the K dimension as seen in the diagram below. [K / blockK, Mn, blockK]
// where blockK is fixed at 128B
//
// ├────────────── K ──────────────┤
// ┬ ┬ ├──── K block ───┤
// │ │ │ 0 1 2 3 ║ 32 33 34 35 │
// │ CTA0 │ 4 5 6 7 ║ 36 37 38 39 │
// │ │ │ 8 9 10 11 ║ 40 41 42 43 │
// │ ┴ │ 12 13 14 15 ║ 44 45 46 47 │
// M ┬ ├────────────────║────────────────┤
// │ │ │ 16 17 18 19 ║ 48 49 50 51 │
// │ CTA1 │ 20 21 22 23 ║ 52 53 54 55 │
// │ │ │ 24 25 26 27 ║ 56 57 58 59 │
// ┴ ┴ │ 28 29 30 31 ║ 60 61 62 63 │
BlockMajorK
};
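
As a sketch of how an element would be addressed under BlockMajorK, assuming the [K / blockK, Mn, blockK] storage described above with the rightmost dimension contiguous (blockKElems is the number of elements in one 128B block for the given dtype):

#include <cstdint>

// Linear offset (in elements) of logical element (mn, k) in a tensor of shape
// [Mn, K] stored as [K / blockKElems, Mn, blockKElems].
inline int64_t blockMajorKOffset(int64_t mn, int64_t k, int64_t numRowsMn, int64_t blockKElems)
{
    int64_t const kBlock = k / blockKElems;   // which K block
    int64_t const kInBlock = k % blockKElems; // position inside the block
    return (kBlock * numRowsMn + mn) * blockKElems + kInBlock;
}

// Example: with 4-byte elements, blockKElems = 128B / 4B = 32, so element
// (mn = 1, k = 40) lands in K block 1 at offset (1 * numRowsMn + 1) * 32 + 8.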

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class SplitK : uint32_t
{
// No split-k is needed. I.e. mNumSlicesForSplitK == 1.
@@ -54,6 +79,20 @@ enum class SplitK : uint32_t

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class BiasType : uint32_t
{
// No bias.
None = 0,
// One bias value per row M of the output tensor.
M = 1,
// One bias value per column N of the output tensor.
N = 2,
// One bias value for each element of the output tensor.
Mn = 3,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class TileScheduler
{
// Static scheduler (Non-persistent).
@@ -80,6 +119,23 @@ SPLIT_K_FUNCTION(Dsmem)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the Bias type.

#define BIAS_TYPE_FUNCTION(Mode) \
inline bool isBiasType##Mode(BiasType type) \
{ \
return (type == BiasType::Mode); \
}

BIAS_TYPE_FUNCTION(None)
BIAS_TYPE_FUNCTION(N)
BIAS_TYPE_FUNCTION(M)
BIAS_TYPE_FUNCTION(Mn)

#undef BIAS_TYPE_FUNCTION
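
For reference, each BIAS_TYPE_FUNCTION instantiation above expands to a trivial predicate; for example, BIAS_TYPE_FUNCTION(N) generates:

inline bool isBiasTypeN(BiasType type)
{
    return (type == BiasType::N);
}

// A hypothetical call site (the mBiasType field name is an assumption):
// if (isBiasTypeN(options.mBiasType)) { /* the bias pointer has shape [N] */ }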

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm

} // namespace gemm
@@ -63,8 +63,10 @@ struct GemmData
{
// The matrix A. The data type is controlled by options.mDtypeA.
//
// When transposeMatrixA is false, the shape is [M, K].
// Otherwise, the shape is [K, M].
// When layoutA is MatrixLayout::MajorK, the shape is [M, K].
// When layoutA is MatrixLayout::MajorMn, the shape is [K, M].
// When layoutA is MatrixLayout::BlockMajorK, the shape is [K / blockK, M, blockK] where blockK
// is 128B.
// The rightmost dimension is contiguous in memory.
void const* mPtrA{nullptr};

@@ -100,8 +102,10 @@

// The matrix B. The data type is controlled by options.mDtypeB.
//
// When transposeMatrixB is true, the shape is [N, K].
// Otherwise, the shape is [K, N].
// When layoutB is MatrixLayout::MajorK, the shape is [N, K].
// When layoutB is MatrixLayout::MajorMn, the shape is [K, N].
// When layoutB is MatrixLayout::BlockMajorK, the shape is [K / blockK, N, blockK] where blockK
// is 128B.
// The rightmost dimension is contiguous in memory.
void const* mPtrB{nullptr};

@@ -142,8 +146,33 @@
// The shape is [N]
void const* mPtrPerTokenSfB{nullptr};

// The output tensor scaling factor for MxFp{4,8}, Fp8, NvFp4 and DeepSeek FP8 quantization.
// The bias applied after the GEMM.
// The bias is applied before applying the global scaling factor. I.e.
// C' = (A * B + bias') * scaleC
// scaleC = dequantA * dequantB * quantC
// Thus, bias' = bias / (dequantA * dequantB), where bias is the original (unscaled) bias.
//
// if BiasType is N, the shape is [N].
// The bias is broadcasted along the M dimension.
//
// if BiasType is M, the shape is [M].
// The bias is broadcasted along the N dimension.
//
// The dtype is float32.
void const* mPtrBias{nullptr};

// The output tensor scaling factor for Fp8 (not DeepSeek FP8) and NvFp4 quantization.
// TensorRT-LLM API requires a scaling factor on the device.
// scaleC = dequantA * dequantB * quantC,
// where dequantA is global dequantization scaling factor of A
// if dtypeA is FP8, it transforms the range from [-448, 448] to [-amaxA, amaxA]
// if dtypeA is NvFp4, it transforms the range from [-448 * 6, 448 * 6] to [-amaxA, amaxA],
// otherwise it is 1.
// dequantB is defined similarly to dequantA.
// quantC is the quantization scaling factor of C.
// if dtypeC is FP8, it transforms the range from [-amaxC, amaxC] to [-448, 448]
// if dtypeC is NvFp4, it transforms the range from [-amaxC, amaxC] to [-448 * 6, 448 * 6],
// otherwise it is 1.
// Shape is [1].
void* mPtrScaleC{nullptr};
};
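
A minimal host-side sketch of how the values described above could be prepared before being copied to mPtrBias and mPtrScaleC; all names here are hypothetical, and only the formulas bias' = bias / (dequantA * dequantB) and scaleC = dequantA * dequantB * quantC come from the comments above:

#include <vector>

struct PreparedScales
{
    std::vector<float> biasPrime; // values to copy to the device buffer behind mPtrBias
    float scaleC;                 // value to copy to the device buffer behind mPtrScaleC (shape [1])
};

inline PreparedScales prepareScales(
    std::vector<float> const& bias, float dequantA, float dequantB, float quantC)
{
    PreparedScales out;
    out.biasPrime.reserve(bias.size());
    for (float b : bias)
    {
        // The kernel applies the bias before the global output scale, so the
        // original bias must be pre-divided by the input dequantization scales.
        out.biasPrime.push_back(b / (dequantA * dequantB));
    }
    out.scaleC = dequantA * dequantB * quantC;
    return out;
}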
@@ -230,7 +259,7 @@ class GemmInterface
// Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
// Provided config must be validated with isValidConfig before the call.
int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream,
int32_t multiProcessorCount,
int32_t multiProcessorCount, bool usePdl = true,
std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;

// Initializes the buffers before the world sync. Must be called before run.
@@ -378,7 +407,7 @@ bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data
auto options = getOptionsFromConfigAndData(config, data);

// Is Blackwell?
bool isBlackwell = config.mSm == SmVersion::Sm100a;
bool isBlackwell = isSmVersionBlackwell(config.mSm);

// Check options without modifications.
return checkAndUpdateGemmOptions(options, isBlackwell, data.mProblemDimensions.mWorldSize,
@@ -388,8 +417,11 @@ bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data
////////////////////////////////////////////////////////////////////////////////////////////////////

int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream,
int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
int32_t multiProcessorCount, bool usePdl, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
{
// Might be used.
(void) usePdl;
(void) moduleCache;
// Get options from config and data.
auto options = getOptionsFromConfigAndData(config, data);

@@ -417,15 +449,14 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
int numTilesN = gemm::divUp(options.mN, options.mTileN);

// Create kernel params.
auto kernelParams = gemm::KernelParams::setKernelParams(options, data.mInputBuffers.mPtrA,
auto kernelParams = gemm::KernelParamsSetup::setKernelParams(options, data.mInputBuffers.mPtrA,
data.mInputBuffers.mPtrSfA, data.mInputBuffers.mPtrPerTokenSfA, data.mInputBuffers.mPtrB,
data.mInputBuffers.mPtrSfB, data.mInputBuffers.mPtrPerTokenSfB, data.mOutputBuffers.mPtrC,
data.mOutputBuffers.mPtrSfC, data.mOutputBuffers.mPtrMultiMemC, (float*) data.mInputBuffers.mPtrScaleC,
dSplitKSlices, data.mAllReduceBuffers.mPtrTileBars, data.mAllReduceBuffers.mPtrMultiMemTileBars,
data.mAllReduceBuffers.mPtrCompletionBars, data.mAllReduceBuffers.mPtrMultiMemCompletionBars,
dPtrSplitKCompletionBars,
data.mInputBuffers.mPtrSfB, data.mInputBuffers.mPtrPerTokenSfB, data.mInputBuffers.mPtrBias,
data.mOutputBuffers.mPtrC, data.mOutputBuffers.mPtrSfC, data.mOutputBuffers.mPtrMultiMemC,
(float*) data.mInputBuffers.mPtrScaleC, dSplitKSlices, data.mAllReduceBuffers.mPtrTileBars,
data.mAllReduceBuffers.mPtrMultiMemTileBars, data.mAllReduceBuffers.mPtrCompletionBars,
data.mAllReduceBuffers.mPtrMultiMemCompletionBars, dPtrSplitKCompletionBars,
/* dPtrNumNonExitingCtas */ nullptr, data.mProblemDimensions.mRank, data.mProblemDimensions.mWorldSize);

// The size of the grid.
std::vector<int32_t> grid{numTilesM, numTilesN, options.mNumSlicesForSplitK};

@@ -443,26 +474,26 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
#ifdef TLLM_GEN_EXPORT_INTERFACE
CUmodule cuModule;
CUfunction cuFunction;

if (moduleCache.has_value())
{
ModuleCache& moduleCacheRef = moduleCache.value().get();

// Modules are associated with a specific context so include the ctxId in the key
// Modules are associated with a specific context, so the context is included in the key
CUcontext ctx;
unsigned long long ctxId;
cuCtxGetCurrent(&ctx);
cuCtxGetId(ctx, &ctxId);

// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
// representation.
// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a
// string in decimal representation.
std::string const ctxName
= std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
std::string const funcName = std::string(config.mFunctionName);
// As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
auto const moduleKey = ctxName + funcName;
auto module = moduleCacheRef.find(moduleKey);

// Check if module exists in cache. Otherwise, load it
// Use cache if module is found, otherwise load and insert into cache
if (module != moduleCacheRef.end())
{
cuFunction = std::get<1>(module->second);
Expand Down Expand Up @@ -492,17 +523,18 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
// Run the kernel.
auto result = trtllm::gen::launchKernel((void*) &kernelParams, cudaStream, config.mSharedMemSize, cuFunction,
block3, grid3, cluster3,
config.mOptions.mGridWaitForPrimaryEarlyExit | config.mOptions.mGridWaitForPrimaryA
| config.mOptions.mGridWaitForPrimaryB);
if (result != CUDA_SUCCESS)
{
return -1;
}
usePdl
&& (config.mOptions.mGridWaitForPrimaryEarlyExit | config.mOptions.mGridWaitForPrimaryA
| config.mOptions.mGridWaitForPrimaryB));
// If a module cache has not been given, unload the module to avoid leaking
if (!moduleCache.has_value())
{
cuModuleUnload(cuModule);
}
if (result != CUDA_SUCCESS)
{
return -1;
}
#else
config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
#endif