Merged

60 commits
310d791
extend quark to support mixed-precision quantization model
xuebwang-amd Sep 4, 2025
cb060fc
use an environment variable to support mxfp4 quantize regardless of p…
xuebwang-amd Sep 4, 2025
4f0949e
add a test for quark mixed precision models
xuebwang-amd Sep 4, 2025
adfd0f9
fix pre-commit issues
xuebwang-amd Sep 4, 2025
5b89354
add one section about mixed-precision usage in the Quark document
xuebwang-amd Sep 4, 2025
5fb4031
tiny update AMP document
xuebwang-amd Sep 5, 2025
cc0352e
refactor test script and add a new model
xuebwang-amd Sep 5, 2025
c9c7567
simplify layer_quant_configs matching
xuebwang-amd Sep 5, 2025
1a9699e
update AMP section in the Quark document
xuebwang-amd Sep 5, 2025
6e97f5e
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
fe7fe79
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
f6ae690
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
a3d48e2
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
752a6ea
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
b154d68
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
cf810e3
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
6ced976
remove VLLM_QUARK_EMU_MEM_OPT and use_v0
xuebwang-amd Sep 17, 2025
5f5eb76
correct and simplify layer_quant_config parsing in QuarkConfig
xuebwang-amd Sep 17, 2025
2917483
update one pre-commit issue
xuebwang-amd Sep 17, 2025
28c4e43
fix markdownlint issue
xuebwang-amd Sep 18, 2025
91308a0
update excepted accuracy numbers since the amp models in hf have been…
xuebwang-amd Sep 25, 2025
9760017
reduce test models to be one
xuebwang-amd Sep 26, 2025
0e5eebb
remove HF_HUB_AMD_ORG_ACCESS since model is public
xuebwang-amd Sep 28, 2025
56c0565
fix pre-commit issue in test_mixed_precision.py
xuebwang-amd Sep 29, 2025
d7e47ef
update with fixing conflictions
xuebwang-amd Oct 9, 2025
2bb8520
add a non-mixed-precision (PTQ) model as a reference for pipeline com…
xuebwang-amd Oct 17, 2025
5d6fb0a
use a quark_format model as reference
xuebwang-amd Oct 18, 2025
17e68bd
keep both glob-style wildcard matching and simple substring containme…
xuebwang-amd Oct 24, 2025
4b60f95
reset file mode for vllm/model_executor/layers/fused_moe/utils.py
xuebwang-amd Oct 24, 2025
69c9be9
reset file mode for docs/features/quantization/quark.md
xuebwang-amd Oct 24, 2025
b738f39
reset quark.py to 644
xuebwang-amd Oct 27, 2025
db962b6
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 27, 2025
38b4a1e
make layer_quant_config parsing more efficient
xuebwang-amd Oct 27, 2025
36fca3c
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 27, 2025
bb83007
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 27, 2025
25ec47d
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
03cde7c
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
7d98368
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
66ba5a8
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
e0d618b
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
d7a584d
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
bb2d70b
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
cdc2c54
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 28, 2025
c2e6de3
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 29, 2025
b0fe61d
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 29, 2025
7c83065
Merge branch 'main' into xuebin/upstream_amd_quark_layerwise_mixed_pr…
DarkLight1337 Oct 29, 2025
51e9dfc
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 29, 2025
6f9358f
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 29, 2025
ef3fde7
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 30, 2025
49fbe40
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 30, 2025
b4ab91a
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 31, 2025
cf093b1
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 31, 2025
905633d
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 2, 2025
536a369
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 3, 2025
4d78818
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 4, 2025
800b50b
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 4, 2025
4fb82ea
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 4, 2025
1488e0c
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 5, 2025
9859b96
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 5, 2025
bf8a88d
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Nov 5, 2025
34 changes: 33 additions & 1 deletion docs/features/quantization/quark.md
@@ -281,4 +281,36 @@ python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--group_size 32
```

The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights.
The current integration supports [all combinations of FP4, FP6_E3M2 and FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations.

## Using Quark-Quantized Layerwise Auto Mixed Precision (AMP) Models

vLLM also supports loading layerwise mixed-precision models quantized using AMD Quark. Currently, the mixed scheme {MXFP4, FP8} is supported, where FP8 denotes the FP8 per-tensor scheme. More mixed-precision schemes are planned for the near future, including:

- Unquantized Linear and/or MoE layers as an option for each layer, i.e., a mix of {MXFP4, FP8, BF16/FP16}
- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}

Although one can maximize serving throughput by using the lowest precision supported on a given device (e.g., MXFP4 for AMD Instinct MI355, FP8 for AMD Instinct MI300), these aggressive schemes can be detrimental to accuracy recovery on target tasks. Mixed precision strikes a balance between maximizing accuracy and throughput.

There are two steps to generate and deploy a mixed precision model quantized with AMD Quark, as shown below.

### 1. Quantize a model using mixed precision in AMD Quark

First, a layerwise mixed-precision configuration is searched for the given LLM, and the model is then quantized accordingly using AMD Quark. A detailed tutorial covering the Quark APIs will be provided later.
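
For reference, a checkpoint produced this way describes its per-layer schemes in the `layer_quant_config` section of its quantization config, keyed by layer-name patterns (glob-style wildcards or plain substrings) that vLLM matches against each layer at load time. The sketch below is illustrative only: aside from `layer_quant_config`, `kv_cache_group`, `weight`, `input_tensors`, and `output_tensors`, the field names and dtype strings are simplified placeholders rather than the exact Quark export format.

```python
# Illustrative only: mirrors the structure vLLM's QuarkConfig parses
# (layer-name patterns as keys of layer_quant_config, plus kv_cache_group);
# per-layer fields and dtype strings are simplified placeholders.
quantization_config = {
    "quant_method": "quark",
    "layer_quant_config": {
        # Glob-style pattern: MXFP4 for most MLP projections.
        "*mlp.*_proj": {
            "weight": {"dtype": "fp4"},
            "input_tensors": {"dtype": "fp4"},
        },
        # FP8 per-tensor for layers that are more sensitive to quantization.
        "*self_attn.o_proj": {
            "weight": {"dtype": "fp8_e4m3"},
            "input_tensors": {"dtype": "fp8_e4m3"},
        },
    },
    # Layers whose output_tensors config also defines the KV-cache scheme.
    "kv_cache_group": ["*k_proj", "*v_proj"],
}
```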

As examples, we provide some ready-to-use quantized mixed-precision models to show the usage in vLLM and the accuracy benefits:

- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8

### 2. Run inference on the quantized mixed-precision model in vLLM

Models quantized with AMD Quark using mixed precision can be natively loaded in vLLM and evaluated, for example, using lm-evaluation-harness as follows:

```bash
lm_eval --model vllm \
--model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \
--tasks mmlu \
--batch_size auto
```
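
Beyond lm-evaluation-harness, the same checkpoints can be used for offline inference. A minimal sketch, assuming a single GPU (adjust `tensor_parallel_size` and `gpu_memory_utilization` for your setup):

```python
from vllm import LLM, SamplingParams

# Load a Quark mixed-precision checkpoint; the per-layer quantization scheme
# is picked up automatically from the checkpoint's quantization config.
llm = LLM(
    model="amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8",
    tensor_parallel_size=1,
    dtype="auto",
    gpu_memory_utilization=0.8,
)

outputs = llm.generate(
    ["Explain mixed-precision quantization in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```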
69 changes: 69 additions & 0 deletions tests/quantization/test_mixed_precision.py
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test quark-quantized {MXFP4, FP8} mixed precision models.

Run `pytest tests/quantization/test_mixed_precision.py`.

"""

import importlib.metadata
import importlib.util
from dataclasses import dataclass

import lm_eval
import pytest
from packaging import version

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")


@dataclass
class ModelCase:
model_id: str
tp: int


@dataclass
class EvaluationConfig:
model_name: str

def get_model_args(self) -> str:
return (
f"pretrained={self.model_name},"
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False"
)


TEST_CONFIGS = {
# Mixed-precision (AMP) model
# - Demonstrates end-to-end pipeline functionality
"amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72},
# Non-mixed-precision (PTQ) model
# - Reference for pipeline compatibility verification (no conflicts or regressions)
"amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": {
"arc_challenge": 0.53,
"mmlu": 0.61,
},
}


@pytest.mark.parametrize("model_name, accuracy_numbers", TEST_CONFIGS.items())
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mixed_precision_model_accuracies(model_name: str, accuracy_numbers: dict):
results = lm_eval.simple_evaluate(
model="vllm",
model_args=EvaluationConfig(model_name).get_model_args(),
tasks=list(accuracy_numbers.keys()),
batch_size=8,
)

rtol = 0.05

for task, expect_accuracy in accuracy_numbers.items():
measured_accuracy = results["results"][task]["acc,none"]
assert (
measured_accuracy - rtol < expect_accuracy
and measured_accuracy + rtol > expect_accuracy
), f"Expected: {expect_accuracy} | Measured: {measured_accuracy}"
32 changes: 25 additions & 7 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -114,7 +114,14 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)

if not kv_cache_set.issubset(layer_quant_set):
if not (
kv_cache_set.issubset(layer_quant_set)
or any(
fnmatch.fnmatchcase(layer_quant, pat)
for layer_quant in list(layer_quant_set)
for pat in list(kv_cache_set)
)
):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
@@ -124,10 +131,15 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
)

q_configs = [
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
quant_cfg
for name, quant_cfg in layer_quant_config.items()
if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
]
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):

if not all(
deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
for q_config in q_configs
):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
@@ -312,9 +324,15 @@ def _find_matched_config(
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]

def _matches_pattern(layer_name, pattern):
if "*" not in pattern:
return pattern in layer_name  # simple substring containment
return fnmatch.fnmatch(layer_name, pattern)

for name_pattern, config in layer_quant_config.items():
if _matches_pattern(layer_name, name_pattern):
return config

layer_type = cast(str, type(module))
layer_type_quant_config = cast(