93 changes: 74 additions & 19 deletions cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp
@@ -44,58 +44,115 @@
namespace torch_ext
{

W4A16GemmRunner::W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode)
finegrainedMixedDtypeGemmRunner::finegrainedMixedDtypeGemmRunner(
at::ScalarType activationDtype, at::ScalarType outputDtype, int64_t quant_mode)
: mActivationDtype(activationDtype)
, mOutputDtype(outputDtype)
{
if (quant_mode == 0)
{
if (activationDtype == at::ScalarType::Half)
{
TORCH_CHECK(
outputDtype == activationDtype, "Activation dtype needs to match output dtype", activationDtype);
mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
}
else if (activationDtype == at::ScalarType::BFloat16)
{
TORCH_CHECK(
outputDtype == activationDtype, "Activation dtype needs to match output dtype", activationDtype);
mGemmRunner = std::make_shared<
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>();
}

else if (activationDtype == at::ScalarType::Float8_e4m3fn)
{
if (outputDtype == at::ScalarType::BFloat16)
{
mGemmRunner = std::make_shared<
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, __nv_bfloat16, __nv_bfloat16>>();
}
else if (outputDtype == at::ScalarType::Half)
{
mGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
}
else
{
TORCH_CHECK(false, "Unsupported output dtype for Float8_e4m3fn activation", outputDtype);
}
}
else
{
TORCH_CHECK(false, "Unsupported activation dtype", activationDtype);
}
}

else if (quant_mode == 1)
{
if (activationDtype == at::ScalarType::Half)
{
TORCH_CHECK(
outputDtype == activationDtype, "Activation dtype needs to match output dtype", activationDtype);
mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
}
else if (activationDtype == at::ScalarType::BFloat16)
{
TORCH_CHECK(
outputDtype == activationDtype, "Activation dtype needs to match output dtype", activationDtype);
mGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, __nv_bfloat16,
__nv_bfloat16, __nv_bfloat16>>();
}
else if (activationDtype == at::ScalarType::Float8_e4m3fn)
{
if (outputDtype == at::ScalarType::BFloat16)
{
mGemmRunner = std::make_shared<
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, __nv_bfloat16, __nv_bfloat16>>();
}
else if (outputDtype == at::ScalarType::Half)
{
mGemmRunner = std::make_shared<
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
}
else
{
TORCH_CHECK(false, "Unsupported output dtype for Float8_e4m3fn activation", outputDtype);
}
}
}
else
{
TORCH_CHECK(false, "Unsupported quant mode for W4A16GemmRunner: ", quant_mode);
TORCH_CHECK(false, "Unsupported quant mode for finegrainedMixedDtypeGemmRunner: ", quant_mode);
}

TORCH_CHECK(mGemmRunner, "Failed to create W4A16 GEMM runner for activation type ", c10::toString(activationDtype));
TORCH_CHECK(mGemmRunner, "Failed to create finegrained Mixed Dtype GEMM runner for activation type ",
c10::toString(activationDtype));
mConfigs = mGemmRunner->getConfigs(); // Get configs via the interface
TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for W4A16 GEMM with activation type ",
TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for finegrainedMixedDtype GEMM with activation type ",
c10::toString(activationDtype));
}

at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias, std::optional<at::Tensor> zeros) const
at::Tensor finegrainedMixedDtypeGemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed,
at::Tensor const& scales, int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias,
std::optional<at::Tensor> zeros, double alpha) const
{
TORCH_CHECK(A.is_cuda() && B_packed.is_cuda() && scales.is_cuda(), "All input tensors must be on CUDA");
TORCH_CHECK(A.scalar_type() == mActivationDtype, "Activation tensor A's dtype ", c10::toString(A.scalar_type()),
" does not match runner's expected dtype ", c10::toString(mActivationDtype));
TORCH_CHECK(B_packed.scalar_type() == torch::kQUInt4x2 || B_packed.scalar_type() == torch::kInt8
|| B_packed.scalar_type() == torch::kUInt8,
"B_packed must be quint4x2, int8, or uint8 (view of quantized data)");

TORCH_CHECK(A.is_contiguous() && B_packed.is_contiguous() && scales.is_contiguous(),
"All input tensors (A, B_packed, scales) must be contiguous");

@@ -156,19 +213,18 @@ at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_pac
output_shape_vec.push_back(N_orig);
}

// Set output dtype based on activation dtype
torch::ScalarType output_dtype;
if (mActivationDtype == at::ScalarType::Half)
if (mOutputDtype == at::ScalarType::Half)
{
output_dtype = torch::kFloat16;
}
else if (mActivationDtype == at::ScalarType::BFloat16)
else if (mOutputDtype == at::ScalarType::BFloat16)
{
output_dtype = torch::kBFloat16;
}
else
{
TORCH_CHECK(false, "Unsupported activation type for output dtype determination");
TORCH_CHECK(false, "Unsupported output dtype");
}

torch::Tensor C_tensor = torch::empty(output_shape_vec, A.options().dtype(output_dtype));
@@ -201,25 +257,24 @@ at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_pac

cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index());

mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr,
1.0f, // alpha
C_ptr, M, N_orig, K, group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream);
mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr, static_cast<float>(alpha), C_ptr, M, N_orig, K,
group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream);

return C_tensor;
}

int64_t W4A16GemmRunner::getNumConfigs() const
int64_t finegrainedMixedDtypeGemmRunner::getNumConfigs() const
{
TORCH_CHECK(mGemmRunner, "W4A16GemmRunner not initialized properly.");
TORCH_CHECK(mGemmRunner, "finegrainedMixedDtypeGemmRunner not initialized properly.");
return static_cast<int64_t>(mConfigs.size());
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.class_<torch_ext::W4A16GemmRunner>("W4A16GemmRunner")
.def(torch::init<at::ScalarType, int64_t>())
.def("run_gemm", &torch_ext::W4A16GemmRunner::runGemm)
.def("get_num_configs", &torch_ext::W4A16GemmRunner::getNumConfigs);
m.class_<torch_ext::finegrainedMixedDtypeGemmRunner>("finegrainedMixedDtypeGemmRunner")
.def(torch::init<at::ScalarType, at::ScalarType, int64_t>())
.def("run_gemm", &torch_ext::finegrainedMixedDtypeGemmRunner::runGemm)
.def("get_num_configs", &torch_ext::finegrainedMixedDtypeGemmRunner::getNumConfigs);
}
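
Note (not part of the diff): a minimal usage sketch of the runner class registered above, driven directly through torch.classes. Only the constructor and method signatures follow the registration; the tensor shapes, the packed int4 weight layout, and the scale layout are illustrative assumptions, and the TensorRT-LLM torch extension is assumed to be loaded already.

import torch

# quant_mode 0 -> FINEGRAINED_SCALE_ONLY, 1 -> FINEGRAINED_SCALE_AND_ZEROS
runner = torch.classes.trtllm.finegrainedMixedDtypeGemmRunner(
    torch.float16,  # activation dtype
    torch.float16,  # output dtype
    0)              # quant_mode

M, K, N, group_size = 16, 4096, 4096, 128  # illustrative sizes
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
# int4 weights packed two per byte and viewed as uint8 (layout is an assumption here)
B_packed = torch.randint(0, 256, (K, N // 2), dtype=torch.uint8, device="cuda")
scales = torch.randn(K // group_size, N, dtype=torch.float16, device="cuda")

# configIdx = -1 lets the runner pick a default CUTLASS config;
# bias and zeros stay None on the scale-only path, alpha defaults to 1.0.
C = runner.run_gemm(A, B_packed, scales, group_size, -1, None, None, 1.0)
print(C.shape, C.dtype)  # (M, N), torch.float16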
8 changes: 5 additions & 3 deletions cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h
@@ -24,21 +24,23 @@
namespace torch_ext
{

class W4A16GemmRunner : public torch::CustomClassHolder
class finegrainedMixedDtypeGemmRunner : public torch::CustomClassHolder
{
public:
explicit W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode = 0);
explicit finegrainedMixedDtypeGemmRunner(
at::ScalarType activationDtype, at::ScalarType outputDtype, int64_t quant_mode = 0);

at::Tensor runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
int64_t group_size_long, int64_t configIdx = -1, std::optional<at::Tensor> bias = std::nullopt,
std::optional<at::Tensor> zeros = std::nullopt) const;
std::optional<at::Tensor> zeros = std::nullopt, double alpha = 1.0) const;

int64_t getNumConfigs() const;

private:
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface> mGemmRunner;
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mConfigs;
at::ScalarType mActivationDtype;
at::ScalarType mOutputDtype;
};

} // namespace torch_ext
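
Note (not part of the diff): because getNumConfigs() and the configIdx argument are exposed, the CUTLASS configs can also be swept manually instead of going through the AutoTuner. A hedged sketch, under the same layout assumptions as the sketch above:

import torch

runner = torch.classes.trtllm.finegrainedMixedDtypeGemmRunner(torch.bfloat16, torch.bfloat16, 0)
M, K, N, group_size = 16, 4096, 4096, 128
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B_packed = torch.randint(0, 256, (K, N // 2), dtype=torch.uint8, device="cuda")
scales = torch.randn(K // group_size, N, dtype=torch.bfloat16, device="cuda")

# Time each exposed config by passing its index as configIdx.
best_idx, best_ms = -1, float("inf")
for idx in range(runner.get_num_configs()):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    runner.run_gemm(A, B_packed, scales, group_size, idx, None, None, 1.0)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end)
    if ms < best_ms:
        best_idx, best_ms = idx, ms
print(f"fastest config index: {best_idx} ({best_ms:.3f} ms)")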
79 changes: 45 additions & 34 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -675,24 +675,27 @@ def _(
dtype=output_dtype)


class W4A16GemmRunner(TunableRunner):
class FinegrainedMixedDtypeGemm(TunableRunner):
_runner_dict = dict()
MAX_SUPPORTED_SM_VERSION = 90

def __init__(self, activation_dtype: torch.dtype, quant_mode: int):
instance_key = (activation_dtype, quant_mode)
if instance_key not in W4A16GemmRunner._runner_dict:
W4A16GemmRunner._runner_dict[
instance_key] = torch.classes.trtllm.W4A16GemmRunner(
activation_dtype, quant_mode)
self._w4a16_gemm_runner = W4A16GemmRunner._runner_dict[instance_key]
def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
quant_mode: int):
instance_key = (activation_dtype, output_dtype, quant_mode)
if instance_key not in FinegrainedMixedDtypeGemm._runner_dict:
FinegrainedMixedDtypeGemm._runner_dict[
instance_key] = torch.classes.trtllm.finegrainedMixedDtypeGemmRunner(
activation_dtype, output_dtype, quant_mode)
self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[
instance_key]

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return list(range(self._w4a16_gemm_runner.get_num_configs()))
return list(
range(self._finegrained_mixed_dtype_gemm_runner.get_num_configs()))

def forward(self,
inputs: List[torch.Tensor],
@@ -707,25 +710,25 @@ def forward(self,

activation, weights_packed, scales = inputs

return self._w4a16_gemm_runner.run_gemm(
activation,
weights_packed,
scales,
kwargs["group_size"],
tactic,
kwargs["bias"],
kwargs["zeros"],
)
alpha = 1.0 if kwargs.get("alpha") is None else kwargs["alpha"]

return self._finegrained_mixed_dtype_gemm_runner.run_gemm(
activation, weights_packed, scales, kwargs["group_size"], tactic,
kwargs["bias"], kwargs["zeros"], alpha)

@torch.library.custom_op("trtllm::w4a16_gemm", mutates_args=())
def w4a16_gemm(input: torch.Tensor,
weight: torch.Tensor,
scales: torch.Tensor,
group_size: int,
has_zero_point: bool,
bias: Optional[torch.Tensor] = None,
zeros: Optional[torch.Tensor] = None) -> torch.Tensor:

@torch.library.custom_op("trtllm::finegrained_mixed_dtype_gemm",
mutates_args=())
def finegrained_mixed_dtype_gemm(
input: torch.Tensor,
weight: torch.Tensor,
scales: torch.Tensor,
group_size: int,
has_zero_point: bool,
output_dtype: torch.dtype,
alpha: Optional[float] = None,
bias: Optional[torch.Tensor] = None,
zeros: Optional[torch.Tensor] = None) -> torch.Tensor:

assert not has_zero_point or zeros is not None, "Expected 'zeros' tensor when has_zero_point is True"

Expand All @@ -741,16 +744,24 @@ def w4a16_gemm(input: torch.Tensor,
if quant_mode == 0:
assert zeros is None, "When quant_mode is 0 (FINEGRAINED_SCALE_ONLY), zeros must be None"

w4a16_gemm_runner = W4A16GemmRunner(input.dtype, quant_mode)
finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm(
input.dtype, output_dtype, quant_mode)

kwargs = {
"group_size": group_size,
"zeros": zeros,
"bias": bias,
"alpha": alpha
}

kwargs = {"group_size": group_size, "zeros": zeros, "bias": bias}
_, best_tactic = tuner.choose_one("trtllm::w4a16_gemm::gemm",
[w4a16_gemm_runner], tuning_config,
[input, weight, scales], **kwargs)
_, best_tactic = tuner.choose_one(
"trtllm::finegrained_mixed_dtype_gemm::gemm",
[finegrained_mixed_dtype_gemm_runner], tuning_config,
[input, weight, scales], **kwargs)

return w4a16_gemm_runner(inputs=[input, weight, scales],
tactic=best_tactic,
**kwargs)
return finegrained_mixed_dtype_gemm_runner(inputs=[input, weight, scales],
tactic=best_tactic,
**kwargs)
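
Note (not part of the diff): a hedged usage sketch of the custom op registered above, here on the zero-point path (has_zero_point=True, which presumably selects quant_mode 1 / FINEGRAINED_SCALE_AND_ZEROS). Shapes and packed layouts are illustrative assumptions, and the op is assumed to be registered by importing the TensorRT-LLM extension.

import torch

M, K, N, group_size = 16, 4096, 4096, 128  # illustrative sizes
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w_packed = torch.randint(0, 256, (K, N // 2), dtype=torch.uint8, device="cuda")  # assumed int4-packed layout
scales = torch.randn(K // group_size, N, dtype=torch.float16, device="cuda")
zeros = torch.randn(K // group_size, N, dtype=torch.float16, device="cuda")

# has_zero_point=True requires a zeros tensor; output_dtype selects the runner variant.
y = torch.ops.trtllm.finegrained_mixed_dtype_gemm(
    x, w_packed, scales, group_size,
    True,            # has_zero_point
    torch.float16,   # output_dtype
    alpha=None, bias=None, zeros=zeros)
print(y.shape, y.dtype)  # (M, N), torch.float16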


@torch.library.custom_op("trtllm::attention", mutates_args=())