Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
}
}

TLLM_CHECK_WITH_INFO(mPassingConfigIndices.size() != 0, "No kernel found for the given output type");
TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
}

size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex)
{
BatchedGemmData gemmData;
gemmData.mProblemDimensions.mNumBatches = numBatches;
Expand All @@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

selectGemmConfig(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);

auto bmm = BatchedGemmInterface();

auto const configs = bmm.getBatchedGemmConfigs();
TLLM_CHECK_WITH_INFO(
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
auto const& config = configs[mSelectedConfigIndex.value()];

if (!configIndex.has_value())
{
mSelectedConfigIndex
= getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
configIndex = mSelectedConfigIndex;
}

auto const& config = configs[configIndex.value()];
return bmm.getWorkspaceSizeInBytes(config, gemmData);
}

Expand All @@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device)
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
{
auto bmm = BatchedGemmInterface();

BatchedGemmData gemmData;

auto const configs = bmm.getBatchedGemmConfigs();
TLLM_CHECK_WITH_INFO(
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
auto const& config = configs[mSelectedConfigIndex.value()];

if (!configIndex.has_value())
{
TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");

configIndex = mSelectedConfigIndex;
}

auto const& config = configs[configIndex.value()];

TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
if (!mOptions.staticBatch)
Expand Down Expand Up @@ -170,32 +182,33 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto

void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
CUstream stream, int device)
CUstream stream, int device, std::optional<int32_t> configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
/* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
/* scaleC */ nullptr, /* scaleGateC */ nullptr, c, outSfC,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device);
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}

void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
CUstream stream, int device)
CUstream stream, int device, std::optional<int32_t> configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
/* sfA */ nullptr, b, /* sfB */ nullptr, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, scaleC,
scaleGateC, c, /* outSfC */ nullptr,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device);
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}

void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const bmm = BatchedGemmInterface();
auto const configs = bmm.getBatchedGemmConfigs();
Expand Down Expand Up @@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
return optionsA.mTileM > optionsB.mTileM;
});

std::vector<int64_t> validConfigIndices;
for (auto const& configIndex : sortedIndices)
{
auto const& config = configs[configIndex];
auto isValidConfig = bmm.isValidConfig(config, gemmData);
if (isValidConfig)
{
mSelectedConfigIndex = configIndex;
return;
validConfigIndices.push_back(configIndex);
}
}

TLLM_CHECK_WITH_INFO(!validConfigIndices.empty(), "No valid config found for the given problem shape");

return validConfigIndices;
}

int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const validConfigIndices
= getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);

return validConfigIndices[0];
}

} // namespace kernels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

#pragma once

#include <cstdint>
#include <cuda.h>
#include <optional>
#include <vector>

#include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"

Expand Down Expand Up @@ -45,29 +47,49 @@ class TrtllmGenBatchedGemmRunner
explicit TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options);

[[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim);
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex = std::nullopt);

void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device);
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);

void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device);
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);

void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device);
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);

// Returns, by value, the indices of the configs that passed validation against the constructor options.
[[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
{
    std::vector<int32_t> passingIndices = mPassingConfigIndices;
    return passingIndices;
}

// Get the list of config indices that are valid for the given problem shape.
// The implementation keeps every candidate config for which BatchedGemmInterface::isValidConfig()
// returns true, in the order of its sorted candidate list, and fails a TLLM_CHECK_WITH_INFO
// ("No valid config found for the given problem shape") when no candidate is valid, so the returned
// vector is never empty.
// NOTE(review): indices are returned as int64_t, while the configIndex parameters of run() and
// getWorkspaceSizeInBytes() are std::optional<int32_t> -- confirm the implicit narrowing is intended.
[[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;

// Get a default config index that is valid for the given problem shape.
// This will be used as the fallback config if using auto-tuning.
// The value is the first entry of getValidConfigIndices() for the same arguments, so the same
// "no valid config" check applies. getWorkspaceSizeInBytes() calls this when no configIndex is supplied
// and stores the result as the selected config.
[[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;

private:
void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim);

private:
TrtllmGenBatchedGemmRunnerOptions mOptions;
std::optional<int> mSelectedConfigIndex;
std::vector<int32_t> mPassingConfigIndices;
std::optional<int32_t> mSelectedConfigIndex;
};
} // namespace kernels
} // namespace tensorrt_llm
Loading