48 changes: 36 additions & 12 deletions cpp/include/tensorrt_llm/common/quantization.h
@@ -117,6 +117,11 @@ class QuantMode
         return QuantMode(BaseType(1u) << 13);
     }
 
+    static constexpr QuantMode w4a8Mxfp4Fp8() noexcept
+    {
+        return QuantMode(BaseType(1u) << 14);
+    }
+
     constexpr BaseType value() const noexcept
     {
         return mValue;
@@ -192,15 +197,19 @@ class QuantMode
         return isSet(nvfp4());
     }
 
+    constexpr bool hasW4a8Mxfp4Fp8() const noexcept
+    {
+        return isSet(w4a8Mxfp4Fp8());
+    }
+
     constexpr bool hasKvCacheQuant() const noexcept
     {
         return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
     }
 
-    static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
-        bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
-        bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
-        bool useW4a8QServe = false, bool useFp4Quant = false, bool useFp8BlockScales = false)
+    static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
+        bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
+        bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8)
     {
         QuantMode quantMode{};
         if (quantizeWeights)
@@ -264,22 +273,30 @@ class QuantMode
             quantMode += nvfp4();
         }
 
+        if (useW4a8Mxfp4Fp8)
+        {
+            quantMode += w4a8Mxfp4Fp8();
+        }
+
         return quantMode;
     }
 
     static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
     {
-        return fromDescription(true, true, perToken, perChannel);
+        return fromDescription(
+            true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false);
     }
 
     static constexpr QuantMode useQServe(bool perGroup)
     {
-        return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
+        return fromDescription(
+            true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false);
     }
 
     static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
     {
-        return fromDescription(true, false, false, false, perGroup, useInt4Weights);
+        return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
+            false, false, false);
     }
 
     static QuantMode const fromQuantAlgo(
@@ -336,21 +353,28 @@
         }
         else if (quantAlgo == "FP8")
         {
-            quantMode = fromDescription(false, false, false, false, false, false, false, false, true);
+            quantMode = fromDescription(
+                false, false, false, false, false, false, false, false, true, false, false, false, false, false);
         }
         else if (quantAlgo == "FP8_ROWWISE")
         {
-            quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true);
+            quantMode = fromDescription(
+                false, false, true, true, false, false, false, false, false, true, false, false, false, false);
         }
         else if (quantAlgo == "FP4")
         {
-            quantMode
-                = fromDescription(false, false, false, false, false, false, false, false, false, false, false, true);
+            quantMode = fromDescription(
+                false, false, false, false, false, false, false, false, false, false, false, true, false, false);
         }
         else if (quantAlgo == "FP8_BLOCK_SCALES")
         {
             quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, false, true);
+                false, false, false, false, false, false, false, false, false, false, false, false, true, false);
         }
+        else if (quantAlgo == "W4A8_MXFP4_FP8")
+        {
+            quantMode = fromDescription(
+                false, false, false, false, false, false, false, false, false, false, false, false, false, true);
+        }
 
         if (kvCacheQuantAlgo == "INT8")
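Taken together, the header changes add a single new bit and remove the defaulted parameters, so every caller now spells out all fourteen flags. A minimal usage sketch of the new mode (not part of the diff; it assumes only the header above and that BaseType compares as an unsigned integer):

#include "tensorrt_llm/common/quantization.h"

using tensorrt_llm::common::QuantMode;

int main()
{
    // w4a8Mxfp4Fp8() sets bit 14, one past fp8BlockScales() at bit 13.
    constexpr QuantMode mode = QuantMode::w4a8Mxfp4Fp8();
    static_assert(mode.hasW4a8Mxfp4Fp8());
    static_assert(mode.value() == 1u << 14);

    // With the defaults gone, fromDescription() takes all fourteen flags
    // explicitly; the final argument selects the new W4A8 MXFP4 FP8 mode.
    constexpr QuantMode same = QuantMode::fromDescription(false, false, false, false, false, false, false, false,
        false, false, false, false, false, /*useW4a8Mxfp4Fp8=*/true);
    static_assert(same.value() == mode.value());
    return 0;
}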
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2030,7 +2030,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     int const max_distance = mMaxDistance;
     bool const* finished = nullptr;
 
-    auto const quant_option = tc::QuantMode::fromDescription();
+    auto const quant_option = tc::QuantMode{};
     float const* qkv_scale_out = nullptr;
 
     int const* ia3_tasks = nullptr;
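With the defaults removed, the old zero-argument fromDescription() call no longer compiles, so sites that want an empty mode switch to value initialization. QuantMode{} has no flags set; it is the same starting state that fromDescription itself builds on (QuantMode quantMode{}; above), and the same substitution repeats in the three plugin creators below. As a one-line sanity check (a sketch, not part of the diff):

#include "tensorrt_llm/common/quantization.h"

// A value-initialized QuantMode carries no quantization flags, matching the
// behavior of the removed zero-argument fromDescription() call.
static_assert(tensorrt_llm::common::QuantMode{}.value() == 0);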
@@ -389,7 +389,7 @@ IPluginV2* Fp8RowwiseGemmPluginCreator::createPlugin(char const* name, PluginFie
     // Fp8RowwiseGemmPluginCreator is unique and shared for an engine generation
     // Create plugin profiler with shared tactics map
     auto pluginProfiler = mGemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false);
-    QuantMode quantMode = QuantMode::fromDescription();
+    QuantMode quantMode = QuantMode{};
     auto* obj = new Fp8RowwiseGemmPlugin(quantMode, type, pluginProfiler);
     obj->setPluginNamespace(mNamespace.c_str());
     return obj;
@@ -413,7 +413,7 @@ IPluginV2* GemmSwigluPluginCreator::createPlugin(char const* name, PluginFieldCo
     // GemmSwigluPluginCreator is unique and shared for an engine generation
     // Create plugin profiler with shared tactics map
     auto pluginProfiler = mGemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false);
-    QuantMode quantMode = QuantMode::fromDescription();
+    QuantMode quantMode = QuantMode{};
     auto* obj = new GemmSwigluPlugin(quantMode, type, hasBias, scale_d0, scale_d1, scale_output, pluginProfiler);
     obj->setPluginNamespace(mNamespace.c_str());
     return obj;
@@ -397,7 +397,8 @@ IPluginV2* SmoothQuantGemmPluginCreator::createPlugin(char const* name, PluginFi
     // SmoothQuantGemmPluginCreator is unique and shared for an engine generation
     // Create plugin profiler with shared tactics map
     auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false);
-    QuantMode quantMode = QuantMode::fromDescription(true, true, perTokenScaling, perChannelScaling);
+    QuantMode quantMode = QuantMode::fromDescription(true, true, perTokenScaling, perChannelScaling, false, false,
+        false, false, false, false, false, false, false, false);
     auto* obj = new SmoothQuantGemmPlugin(quantMode, type, pluginProfiler);
     obj->setPluginNamespace(mNamespace.c_str());
     return obj;
11 changes: 6 additions & 5 deletions cpp/tensorrt_llm/pybind/bindings.cpp
@@ -298,12 +298,13 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
         .def_property_readonly("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache)
         .def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq)
         .def_property_readonly("has_nvfp4", &tc::QuantMode::hasNvfp4)
+        .def_property_readonly("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8)
         .def_property_readonly("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant)
-        .def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights") = false,
-            py::arg("quantize_activations") = false, py::arg("per_token") = false, py::arg("per_channel") = false,
-            py::arg("per_group") = false, py::arg("use_int4_weights") = false, py::arg("use_int8_kv_cache") = false,
-            py::arg("use_fp8_kv_kache") = false, py::arg("use_fp8_qdq") = false, py::arg("use_fp8_rowwise") = false,
-            py::arg("use_w4a8_qserve") = false, py::arg("use_nvfp4") = false, py::arg("use_fp8_block_scales") = false)
+        .def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights"),
+            py::arg("quantize_activations"), py::arg("per_token"), py::arg("per_channel"), py::arg("per_group"),
+            py::arg("use_int4_weights"), py::arg("use_int8_kv_cache"), py::arg("use_fp8_kv_kache"),
+            py::arg("use_fp8_qdq"), py::arg("use_fp8_rowwise"), py::arg("use_w4a8_qserve"), py::arg("use_nvfp4"),
+            py::arg("use_fp8_block_scales"), py::arg("use_w4a8_mxfp4_fp8"))
         .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, py::arg("per_token") = false,
             py::arg("per_channel") = false)
         .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, py::arg("use_int4_weights") = false,
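Dropping the = false initializers from each py::arg makes every flag a required argument on the Python side as well, which is why the binding test further down gains six extra False values. The mechanism in a standalone pybind11 sketch (hypothetical module and function names, for illustration only):

#include <pybind11/pybind11.h>

namespace py = pybind11;

int scale(int x, bool doubled)
{
    return doubled ? 2 * x : x;
}

PYBIND11_MODULE(demo, m) // hypothetical module
{
    // With a default, Python may omit the flag: demo.scale_opt(3) returns 3.
    m.def("scale_opt", &scale, py::arg("x"), py::arg("doubled") = false);
    // Without one, demo.scale_req(3) raises TypeError: 'doubled' is required.
    m.def("scale_req", &scale, py::arg("x"), py::arg("doubled"));
}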
@@ -285,10 +285,14 @@ TEST(Kernel, WeightOnly)
     std::vector<int> ns{2048, 4096};
     std::vector<int> ks{2048, 4096};
     std::vector<tensorrt_llm::common::QuantMode> quant_modes(4);
-    quant_modes[0] = tensorrt_llm::common::QuantMode::fromDescription(false, false, false, false);
-    quant_modes[1] = tensorrt_llm::common::QuantMode::fromDescription(false, false, true, false);
-    quant_modes[2] = tensorrt_llm::common::QuantMode::fromDescription(false, false, false, true);
-    quant_modes[3] = tensorrt_llm::common::QuantMode::fromDescription(false, false, true, true);
+    quant_modes[0] = tensorrt_llm::common::QuantMode::fromDescription(
+        false, false, false, false, false, false, false, false, false, false, false, false, false, false);
+    quant_modes[1] = tensorrt_llm::common::QuantMode::fromDescription(
+        false, false, true, false, false, false, false, false, false, false, false, false, false, false);
+    quant_modes[2] = tensorrt_llm::common::QuantMode::fromDescription(
+        false, false, false, true, false, false, false, false, false, false, false, false, false, false);
+    quant_modes[3] = tensorrt_llm::common::QuantMode::fromDescription(
+        false, false, true, true, false, false, false, false, false, false, false, false, false, false);
     for (auto m : ms)
     {
         for (auto n : ns)
22 changes: 19 additions & 3 deletions tensorrt_llm/quantization/mode.py
@@ -40,6 +40,7 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
     INT8 = auto()
     MIXED_PRECISION = auto()
     NVFP4 = auto()
+    W4A8_MXFP4_FP8 = auto()
     NO_QUANT = auto()
 
 
@@ -87,6 +88,8 @@ class QuantMode(IntFlag):
     # FP4
     NVFP4 = auto()
     NVFP4_KV_CACHE = auto()
+    # W4A8 MXFP4
+    W4A8_MXFP4_FP8 = auto()
 
     # The smallest power-of-two that is not used by a flag. Do not call auto() after that line.
     COUNT = auto()
@@ -172,6 +175,9 @@ def has_fp8_rowwise(self):
     def has_nvfp4(self):
         return self._any(self.NVFP4)
 
+    def has_w4a8_mxfp4_fp8(self):
+        return self._any(self.W4A8_MXFP4_FP8)
+
     def has_weight_quant(self):
         return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS)
 
@@ -182,7 +188,8 @@ def has_any_quant(self, exclude_kv_cache: bool = False):
                      | self.FP8_QDQ | self.FP8_ROWWISE
                      | self.W4A8_QSERVE
                      | self.FP8_1x128_128x128
-                     | self.NVFP4)
+                     | self.NVFP4
+                     | self.W4A8_MXFP4_FP8)
         if exclude_kv_cache:
             return has_quant
 
@@ -217,7 +224,8 @@ def from_description(quantize_weights=False,
                          use_fp8_block_scales=False,
                          use_fp8_rowwise=False,
                          use_nvfp4=False,
-                         use_w4a8_qserve=False):
+                         use_w4a8_qserve=False,
+                         use_w4a8_mxfp4_fp8=False):
 
         def raise_error():
             raise ValueError(f"Unsupported combination of QuantMode args: "
@@ -233,7 +241,8 @@ def raise_error():
                              f"{use_fp8_block_scales=}, "
                              f"{use_fp8_rowwise=}, "
                              f"{use_nvfp4=}, "
-                             f"{use_w4a8_qserve=}")
+                             f"{use_w4a8_qserve=}, "
+                             f"{use_w4a8_mxfp4_fp8=}")
 
         # We must quantize weights when we quantize activations.
         if quantize_activations and not quantize_weights:
@@ -288,6 +297,9 @@ def raise_error():
         if use_w4a8_qserve:
             mode = mode | QuantMode.W4A8_QSERVE
 
+        if use_w4a8_mxfp4_fp8:
+            mode = mode | QuantMode.W4A8_MXFP4_FP8
+
         return mode
 
     @staticmethod
@@ -361,6 +373,8 @@ def from_quant_algo(
             quant_mode = QuantMode.from_description(use_fp8_block_scales=True)
         elif quant_algo == QuantAlgo.NVFP4:
             quant_mode = QuantMode.from_description(use_nvfp4=True)
+        elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8:
+            quant_mode = QuantMode.from_description(use_w4a8_mxfp4_fp8=True)
         else:
             quant_mode = QuantMode(0)
 
@@ -393,6 +407,8 @@ def to_dict(self):
                 self.has_fp8_block_scales(),
             'enable_nvfp4':
                 self.has_nvfp4(),
+            'enable_w4a8_mxfp4_fp8':
+                self.has_w4a8_mxfp4_fp8(),
             'fp8_kv_cache':
                 self.has_fp8_kv_cache(),
             'use_weight_only':
3 changes: 2 additions & 1 deletion tests/unittest/bindings/test_bindings_ut.py
@@ -25,7 +25,8 @@ def test_quant_mode():
     assert _tb.QuantMode.fp8_qdq().has_fp8_qdq
 
     quant_mode = _tb.QuantMode.from_description(True, True, True, True, True,
-                                                True, True, True)
+                                                True, True, True, False, False,
+                                                False, False, False, False)
     assert quant_mode.has_int4_weights
     quant_mode -= _tb.QuantMode.int4_weights()
     assert not quant_mode.has_int4_weights
2 changes: 1 addition & 1 deletion tests/unittest/trt/quantization/test_mode.py
@@ -44,7 +44,7 @@ def test_any(self):
 
     def test_count(self):
         # Make sure the COUNT value is as expected - change that test if you add a new flag.
-        self.assertEqual(QuantMode.COUNT.value, 1 << 14)
+        self.assertEqual(QuantMode.COUNT.value, 1 << 15)
 
     def test_from_description(self):
         # Test weight only.
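COUNT is a sentinel equal to the first power of two that no flag occupies, so inserting W4A8_MXFP4_FP8 via auto() at bit 14 pushes the sentinel from 1 << 14 to 1 << 15. The same convention written out as a C++ enum (a sketch, not code from the PR):

enum QuantFlagBits : unsigned
{
    // ... existing flags occupy bits 0 through 13 ...
    W4A8_MXFP4_FP8 = 1u << 14, // the new flag takes the bit COUNT used to hold
    COUNT = 1u << 15,          // first unused power of two; moves with each new flag
};

static_assert(COUNT == W4A8_MXFP4_FP8 << 1);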