26 changes: 13 additions & 13 deletions cpp/tensorrt_llm/thop/fp8BatchedGemmTrtllmGen.cpp
@@ -224,7 +224,19 @@ class FP8BatchedGemmRunner : public torch::CustomClassHolder
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, int64_t configIndex)
{

// If configIndex is not provided, use the default valid config index
if (configIndex == -1)
{
int64_t b = mat1.size(0);
int64_t m = mat1.size(1);
int64_t n = mat2.size(1);
int64_t k = mat1.size(2);
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;
std::vector<int32_t> const batchedTokens(b, m);
configIndex
= mRunner->getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, b, maxNumCtasInBatchDim);
}
return fp8_batched_gemm_sm100(mat1, mat2, mTileSize, mUseDeepSeekFp8, mLowLatencyKernel, mEpilogueTileM,
dDqSfsA, dDqSfsB, scaleC, mOutDtypeArg, *mRunner, configIndex);
}
@@ -240,17 +252,6 @@ class FP8BatchedGemmRunner : public torch::CustomClassHolder
return mRunner->getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
}

int64_t getDefaultValidConfigIndex(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
{
// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;

std::vector<int32_t> const batchedTokens(numBatches, m);

return mRunner->getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
}

private:
using RunnerType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner;
using RunnerOptionsType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions;
@@ -271,6 +272,5 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
m.class_<torch_ext::FP8BatchedGemmRunner>("FP8BatchedGemmRunner")
.def(torch::init<at::ScalarType, bool, bool, int64_t, int64_t>())
.def("get_valid_configs", &torch_ext::FP8BatchedGemmRunner::getValidConfigs)
.def("get_default_valid_config", &torch_ext::FP8BatchedGemmRunner::getDefaultValidConfigIndex)
.def("run_batched_gemm", &torch_ext::FP8BatchedGemmRunner::runBatchedGemm);
}
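Note on the C++ change above: the `get_default_valid_config` binding is removed, and `run_batched_gemm` now resolves the default config itself whenever `configIndex == -1`. Below is a minimal sketch of driving the rebound custom class directly from Python; it is not taken from this PR and assumes the TensorRT-LLM torch op library is already loaded, an FP8-capable CUDA device, and illustrative shapes, dtypes, and scale tensors.

```python
# Minimal usage sketch, not taken from this PR. Assumes the TensorRT-LLM torch
# op library is already loaded (so torch.classes.trtllm is populated), an
# FP8-capable CUDA device, and illustrative shapes/dtypes/scales.
import torch

b, m, n, k = 4, 64, 128, 256           # batch, M, N, K (illustrative)
tile_size, epilogue_tile_m = 8, 64     # illustrative tile parameters

runner = torch.classes.trtllm.FP8BatchedGemmRunner(
    torch.half,   # output_dtype
    False,        # use_deep_seek_fp8
    False,        # low_latency_kernel
    tile_size,
    epilogue_tile_m,
)

mat1 = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
mat2 = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn)
scale_c = torch.ones(b, device="cuda", dtype=torch.float32)  # assumed layout

# Explicit tactic selection is still available through get_valid_configs ...
tactics = runner.get_valid_configs(b, m, n, k)

# ... but a config index of -1 is now resolved to the default valid config
# inside the C++ runner (the old get_default_valid_config binding is gone).
out_tensors = runner.run_batched_gemm(mat1, mat2, None, None, scale_c, -1)
```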
68 changes: 27 additions & 41 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -290,7 +290,6 @@ def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
do_preparation: bool = False,
) -> torch.Tensor:
mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
return self.fp4_gemm_runner.run_gemm(
@@ -348,8 +347,8 @@ def _(


class FP8BatchedGemmRunner(TunableRunner):

_runner_dict = dict()
runner_dict = dict()
tuning_config = None

def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
low_latency_kernel: bool, tile_size: int,
@@ -360,40 +359,37 @@ def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
self.low_latency_kernel = low_latency_kernel
self.tile_size = tile_size
self.epilogue_tile_m = epilogue_tile_m
self.tuning_config = self.get_tuning_config()
FP8BatchedGemmRunner.tuning_config = FP8BatchedGemmRunner.get_tuning_config(
use_deep_seek_fp8, tile_size)

instance_key = (output_dtype, use_deep_seek_fp8, low_latency_kernel,
tile_size, epilogue_tile_m)

if instance_key not in FP8BatchedGemmRunner._runner_dict:
FP8BatchedGemmRunner._runner_dict[
if instance_key not in FP8BatchedGemmRunner.runner_dict:
FP8BatchedGemmRunner.runner_dict[
instance_key] = torch.classes.trtllm.FP8BatchedGemmRunner(
output_dtype, use_deep_seek_fp8, low_latency_kernel,
tile_size, epilogue_tile_m)

self._kernel_runner = FP8BatchedGemmRunner._runner_dict[instance_key]
self.kernel_runner = FP8BatchedGemmRunner.runner_dict[instance_key]

def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
do_preparation: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Run the batched GEMM operation with the given inputs and tactic.
"""

mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c = inputs

chosen_tactic = self.get_default_valid_tactic(
inputs) if tactic == -1 else tactic

out_tensors = self._kernel_runner.run_batched_gemm(
out_tensors = self.kernel_runner.run_batched_gemm(
mat1,
mat2,
dq_sfs_a,
dq_sfs_b,
scale_c,
chosen_tactic,
tactic,
)

return out_tensors
@@ -411,28 +407,12 @@ def get_valid_tactics(
n = mat2.shape[1]
k = mat1.shape[2]

tactics = self._kernel_runner.get_valid_configs(b, m, n, k)
tactics = self.kernel_runner.get_valid_configs(b, m, n, k)

return tactics

def get_default_valid_tactic(
self,
inputs: List[torch.Tensor],
) -> int:

mat1, mat2, _, _, _ = inputs

b = mat1.shape[0]
m = mat1.shape[1]
n = mat2.shape[1]
k = mat1.shape[2]

default_tactic = self._kernel_runner.get_default_valid_config(
b, m, n, k)

return default_tactic

def get_dynamic_tensor_specs(self) -> Tuple[DynamicTensorSpec, ...]:
@classmethod
def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
"""Get the dynamic tensor specs for use with the AutoTuner."""

# These indices correspond to the 0th input tensor and its first dimension
@@ -443,29 +423,31 @@ def get_dynamic_tensor_specs(self) -> Tuple[DynamicTensorSpec, ...]:

# Starting at 8 as M % tile size == 0 is required
m_values = (8, 16, 32, 64, 128, 256, 512, 1024, 2048)
round_rule = lambda x: last_positive_power_of_2(x)
round_rule = last_positive_power_of_2

specs = (DynamicTensorSpec(MAT1_IDX, TUNED_DIM, m_values, round_rule), )

return specs

def get_constraint_specs(self) -> Tuple[ConstraintSpec, ...]:
@classmethod
def get_constraint_specs(cls, use_deep_seek_fp8: bool,
tile_size: int) -> Tuple[ConstraintSpec, ...]:
"""Get the constraint specs for the dynamic tensors for use with the AutoTuner.
"""

# When using deepseek fp8, the dq_sfs_a and dq_sfs_b tensors are expected to
# have specific dimensions. As we are only tuning M, we need only constrain
# dimension 1 of dq_sfs_a
if not self.use_deep_seek_fp8:
if not use_deep_seek_fp8:
constraint_dq_sfs_a = ()
else:

def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:
b = shapes[0][0]
m = shapes[0][1]

m_padded = (m + self.tile_size - 1) // self.tile_size
result = m_padded * self.tile_size * b
m_padded = (m + tile_size - 1) // tile_size
result = m_padded * tile_size * b

return result

@@ -477,11 +459,15 @@ def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:

return constraint_dq_sfs_a

def get_tuning_config(self) -> TuningConfig:
@classmethod
@lru_cache(maxsize=None)
def get_tuning_config(cls, use_deep_seek_fp8: bool,
tile_size: int) -> TuningConfig:
"""Get the tuning configuration for the AutoTuner."""

dynamic_tensor_specs = self.get_dynamic_tensor_specs()
constraint_specs = self.get_constraint_specs()
dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
constraint_specs = cls.get_constraint_specs(use_deep_seek_fp8,
tile_size)

tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
constraint_specs=constraint_specs)
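The deepseek-fp8 constraint above pins only dimension 1 of `dq_sfs_a`: M rounded up to a multiple of the tile size, multiplied by the batch size. A standalone restatement with illustrative numbers (not part of the diff):

```python
# Standalone restatement of the dq_sfs_a dim-1 constraint (deepseek-fp8 path),
# with illustrative numbers; tile_size comes from the runner configuration.
def constrain_dq_sfs_a_dim1(b: int, m: int, tile_size: int) -> int:
    m_padded = (m + tile_size - 1) // tile_size   # number of tiles covering M
    return m_padded * tile_size * b               # M rounded up to a tile, times batch

assert constrain_dq_sfs_a_dim1(b=4, m=100, tile_size=8) == 416   # 13 tiles * 8 * 4
assert constrain_dq_sfs_a_dim1(b=2, m=64, tile_size=8) == 128    # already tile-aligned
```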
@@ -516,7 +502,7 @@ def fp8_batched_gemm_trtllmgen(
_, best_tactic = tuner.choose_one(
"trtllm::fp8_batched_gemm_trtllmgen::batched_gemm",
[kernel_runner],
kernel_runner.tuning_config,
FP8BatchedGemmRunner.tuning_config,
inputs,
)

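For context, a condensed sketch of the tuning flow this last hunk belongs to: construct the `FP8BatchedGemmRunner` wrapper, let the AutoTuner pick `best_tactic` against the class-level `tuning_config`, then call `forward` with that tactic (a tactic of -1 now falls back to the default config inside the C++ runner). The `AutoTuner` import path and `AutoTuner.get()` accessor are assumptions based on the surrounding module and are not shown in this diff.

```python
# Condensed sketch of the tuning flow, under stated assumptions: it lives next
# to (or imports from) tensorrt_llm/_torch/custom_ops/torch_custom_ops.py, and
# the AutoTuner import path / AutoTuner.get() accessor match that module.
from tensorrt_llm._torch.autotuner import AutoTuner  # assumed import path


def run_tuned_fp8_batched_gemm(inputs, output_dtype, use_deep_seek_fp8,
                               low_latency_kernel, tile_size, epilogue_tile_m):
    # inputs follows the ordering expected by FP8BatchedGemmRunner.forward:
    # (mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c)
    kernel_runner = FP8BatchedGemmRunner(output_dtype, use_deep_seek_fp8,
                                         low_latency_kernel, tile_size,
                                         epilogue_tile_m)

    # tuning_config is now a class attribute, built once per
    # (use_deep_seek_fp8, tile_size) via the lru_cache on get_tuning_config.
    tuner = AutoTuner.get()
    _, best_tactic = tuner.choose_one(
        "trtllm::fp8_batched_gemm_trtllmgen::batched_gemm",
        [kernel_runner],
        FP8BatchedGemmRunner.tuning_config,
        inputs,
    )

    # A tactic of -1 is resolved to the default config inside the C++ runner.
    return kernel_runner.forward(inputs, tactic=best_tactic)
```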