
Conversation

@Edwardf0t1
Contributor

@Edwardf0t1 Edwardf0t1 commented Jun 18, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

This PR adds Nvidia TensorRT Model Optimizer (ModelOpt) config adaptation with auto-detection. It is part of the effort to unify ModelOpt's config format with compressed-tensors' config format.

  • Update config parsing to look for quant_library instead of quant_method.
  • Maintain backward compatibility by checking both field names (see the sketch below).
  • Add ModelOpt config adaptation to handle it as a quantization method option.
  • Add a unit test for a ModelOpt FP8 checkpoint in test_modelopt.py.

It is essentially format standardization while preserving library-specific functionality.
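
A minimal sketch of the intended parsing fallback, using an illustrative helper name rather than the actual vLLM implementation:

# Illustrative only: pick the quantization backend name from a HF quantization_config.
# Prefers the newer "quant_library" field and falls back to "quant_method"
# for backward compatibility with existing checkpoints.
def resolve_quant_backend(quant_config: dict):
    library = quant_config.get("quant_library")
    if library is not None:
        return library
    return quant_config.get("quant_method")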

Test Plan

from vllm import LLM, SamplingParams


def main():
    # Path (or HF repo id) of the ModelOpt-produced FP8 checkpoint described below.
    model_id = "Llama-3.1-8B-Instruct-FP8"
    sampling_params = SamplingParams(temperature=0.7, top_p=0.9)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Select the modelopt quantization backend when loading the checkpoint.
    llm = LLM(model=model_id, quantization="modelopt")
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()

Llama-3.1-8B-Instruct-FP8 is produced by ModelOpt with per-tensor FP8 weights and activations. It can be generated by running the following command from the directory containing hf_ptq.py in the ModelOpt repository:

python hf_ptq.py --pyt_ckpt_path meta-llama/Llama-3.1-8B-Instruct --qformat fp8 --export_fmt hf --export_path Llama-3.1-8B-Instruct-FP8 --trust_remote_code

The quantization config in the exported config.json would look like the following:

    "quantization_config": {
        "config_groups": {
            "group_0": {
                "input_activations": {
                    "dynamic": false,
                    "num_bits": 8,
                    "type": "float"
                },
                "weights": {
                    "dynamic": false,
                    "num_bits": 8,
                    "type": "float"
                },
                "targets": ["Linear"]
            }
        },
        "ignore": [
            "lm_head"
        ],
        "quant_algo": "FP8",
        "kv_cache_scheme": {
            "dynamic": false,
            "num_bits": 8,
            "type": "float"
        },
        "quant_method": "modelopt",
        "producer": {
            "name": "modelopt",
            "version": "0.33.0"
        }
    }
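
For reference, here is a rough sketch of how auto-detection could map the exported config above onto a ModelOpt quantization method. The helper and the returned names are illustrative assumptions, not the actual override_quantization_method implementation:

# Illustrative only: detect a ModelOpt checkpoint from fields shown in the config above.
def detect_modelopt_method(quant_cfg: dict):
    producer = quant_cfg.get("producer", {}).get("name")
    if producer != "modelopt" and quant_cfg.get("quant_method") != "modelopt":
        return None
    algo = quant_cfg.get("quant_algo", "")
    if algo == "FP8":
        return "modelopt"       # per-tensor FP8 weights/activations
    if algo in ("NVFP4", "FP4"):
        return "modelopt_fp4"   # NVFP4 scheme mentioned in the PR summary
    return None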

Test Result

Prompt: 'Hello, my name is', Generated text: ' Helen and I am a 6th grade math teacher. I love teaching math'
Prompt: 'The president of the United States is', Generated text: ' the head of the federal government. He is the leader of the country and is'
Prompt: 'The capital of France is', Generated text: ' Paris, a city that has been called the most romantic city in the world.'
Prompt: 'The future of AI is', Generated text: ' not just about technology, but about the people and the world we create with it'

(Optional) Documentation Update

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @Edwardf0t1, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the clarity and flexibility of quantization configuration parsing within vLLM. My primary goal was to standardize the naming convention for quantization libraries while ensuring backward compatibility, and to integrate robust support for Nvidia TensorRT Model Optimizer (ModelOpt) quantized models. This allows vLLM to seamlessly load and utilize models quantized with ModelOpt's FP8 and NVFP4 schemes, adapting to different configuration formats.

Highlights

  • Quantization Config Parsing Improvement: I've updated the quantization configuration parsing logic to prioritize the quant_library field for improved semantic clarity. This change also ensures backward compatibility by falling back to the quant_method field if quant_library is not present.
  • Nvidia ModelOpt Integration: I've added comprehensive support for Nvidia TensorRT Model Optimizer (ModelOpt) quantization, specifically for both FP8 and NVFP4 formats. This includes recognizing these as valid quantization methods within vLLM.
  • Automatic ModelOpt Detection and Format Adaptation: I've implemented override_quantization_method for ModelOpt FP8 and NVFP4 configurations. This allows vLLM to automatically detect and apply the correct ModelOpt quantization method based on the hf_quant_config.json file, supporting both traditional ModelOpt nested config structures and flattened compressed-tensors style formats.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Code Review (/gemini review): Performs a code review for the current pull request in its current state.
Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request improves the semantic clarity of quantization configurations and adds adaptation for Nvidia ModelOpt configurations, including auto-detection. The changes involve updating config parsing to prioritize quant_library over quant_method while maintaining backward compatibility, and adding modelopt config adaptation. The code has been reviewed and suggestions have been provided to improve robustness and clarity.

@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch from 785f0b0 to aca87c2 on June 18, 2025 18:47
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch from 5257ade to 447d23b on July 2, 2025 04:47
@Edwardf0t1
Contributor Author

Hi @mgoin @robertgshaw2-redhat, it was very nice meeting you and the team yesterday regarding the collaboration between NVIDIA ModelOpt and llm-compressor.

Could you help review this PR as discussed?

Member

@mgoin mgoin left a comment


The current implementation doesn't seem to be utilizing the structure of the CT format; instead it includes duplicate/fixed information through the "quant_algo" and "kv_cache_scheme" entries.
What I mean is, I would expect your Llama FP8 config to be more like this:

    "quantization_config": {
        "config_groups": {
            "group_0": {
                "input_activations": {
                    "dynamic": false,
                    "strategy": "tensor",
                    "num_bits": 8,
                    "type": "float"
                },
                "weights": {
                    "dynamic": false,
                    "strategy": "tensor",
                    "num_bits": 8,
                    "type": "float"
                },
                "targets": ["Linear"]
            }
        },
        "ignore": [
            "lm_head"
        ],
        "kv_cache_scheme": {
            "dynamic": false,
            "strategy": "tensor",
            "num_bits": 8,
            "type": "float"
        },
        "quant_method": "compressed-tensors",
         "producer": {
            "name": "modelopt",
            "version": "0.33.0"
        }
    }

And then the vLLM modelopt backend would have matching checks for that "FP8" scheme based on the sub-configs. This is like the _is_fp4a4_nvfp4-style functions we have in CT:

def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
    if weight_quant is None or input_quant is None:
        return False
    is_tensor_group_quant = (
        weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
        and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value)
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    is_group_size_16 = (weight_quant.group_size == 16
                        and input_quant.group_size == 16)
    is_float_type = (weight_quant.type == QuantizationType.FLOAT
                     and input_quant.type == QuantizationType.FLOAT.value)
    is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
    return (is_tensor_group_quant and is_float_type and is_4_bits
            and is_group_size_16 and is_symmetric)
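
For the FP8 case discussed here, a matching check could look roughly like the sketch below. The name and field checks mirror the snippet above and are illustrative, not existing vLLM code:

def _is_fp8_w8a8_per_tensor(self, weight_quant, input_quant):
    # Sketch: static per-tensor FP8 weights and activations, mirroring the
    # field checks used in _is_fp4a4_nvfp4 above.
    if weight_quant is None or input_quant is None:
        return False
    is_tensor = (weight_quant.strategy == QuantizationStrategy.TENSOR.value
                 and input_quant.strategy == QuantizationStrategy.TENSOR.value)
    is_float_type = (weight_quant.type == QuantizationType.FLOAT
                     and input_quant.type == QuantizationType.FLOAT)
    is_8_bits = weight_quant.num_bits == 8 and input_quant.num_bits == 8
    is_static = not weight_quant.dynamic and not input_quant.dynamic
    return is_tensor and is_float_type and is_8_bits and is_static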

It would also be good to add a small config unit test to make sure vLLM parses the expected format and dispatches to the quant method correctly, similar to the tests here:

def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):

@Edwardf0t1
Contributor Author

Edwardf0t1 commented Jul 9, 2025

The current implementation doesn't seem to be utilizing the structure of the CT format … (quoting @mgoin's review comment above)

Thank you for the feedback! @mgoin

  • I think using "quant_method": "compressed-tensors" isn’t semantically meaningful in this context, since we’re quantizing the model with modelopt and the actual method is fp8. Our intention is for vLLM to accept "quant_library": "modelopt" so we can reuse the CT config logic, rather than embedding "quant_method": "compressed-tensors" directly in the quant configs exported by modelopt.

  • Regarding the required keywords in the config — aside from "strategy": "tensor" and "targets": ["Linear"] — are there any other fields that must be explicitly defined for the CT format? I’d like to add a small unit test to cover the config.

@dsikka
Contributor

dsikka commented Jul 9, 2025

The current implementation doesn't seem to be utilizing the structure of the CT format … (quoting @mgoin's review comment and @Edwardf0t1's reply above)

@Edwardf0t1

For a list of arguments that are defined in each ct structure, you can refer to the Pydantic model here:
https://github.com/neuralmagic/compressed-tensors/blob/b163bd9994c6b274f672f2846e1a64568f8ab1d5/src/compressed_tensors/quantization/quant_args.py#L146 (defined for each group of weights and input activations)
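
As an illustration only (values chosen to match the NVFP4 checks quoted earlier, not an authoritative schema), a config group in that format might carry fields like:

nvfp4_group_example = {
    "weights": {
        "num_bits": 4,
        "type": "float",
        "symmetric": True,
        "strategy": "tensor_group",
        "group_size": 16,
        "dynamic": False,
    },
    "input_activations": {
        "num_bits": 4,
        "type": "float",
        "symmetric": True,
        "strategy": "tensor_group",
        "group_size": 16,
        "dynamic": False,
    },
    "targets": ["Linear"],
}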

@Edwardf0t1
Contributor Author

Edwardf0t1 commented Jul 10, 2025

For a list of arguments that are defined in each ct structure, you can refer to the Pydantic model here … (quoting @dsikka's reply above)

Thanks for the pointer, @dsikka.

I was trying to find which arguments need to be explicitly defined in the model's quant config. It looks like strategy is optional and its value can be inferred from group_size, so it's not a must-have.
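
Roughly, the inference I have in mind (an assumption about the rule, not the compressed-tensors implementation) would be:

def infer_strategy(strategy, group_size):
    # Sketch of "strategy can be inferred from group_size": an explicit value
    # wins; otherwise a positive group_size implies group quantization and the
    # default is per-tensor.
    if strategy is not None:
        return strategy
    if group_size is not None and group_size > 0:
        return "group"
    return "tensor"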

@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch from 447d23b to eefe692 on July 10, 2025 07:34
@Edwardf0t1 Edwardf0t1 changed the title from "Improve quant config semantic clarity, add Nvidia ModelOpt config adaptation" to "Add Nvidia ModelOpt config adaptation" on July 10, 2025
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch 2 times, most recently from 0d30361 to 79c5e76 on July 10, 2025 20:59
@Edwardf0t1
Contributor Author

@mgoin @dsikka Would you mind taking another look at this PR? Thanks!

Comment on lines +28 to +32
Member


Do you want to skip this test for now until you have a public checkpoint? I think this will break the quantization test

Contributor Author


Sounds good!

Comment on lines +15 to +23
Member


Why does this require V0?

Contributor Author


Actually, I was aligning with the test here, which requires V0: https://github.com/vllm-project/vllm/blob/main/tests/quantization/test_compressed_tensors.py#L44

Do you know which types of tests require V0?

@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch 2 times, most recently from 1477e32 to 0a3c6ee on July 11, 2025 23:03
@mergify

mergify bot commented Jul 12, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Edwardf0t1.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 12, 2025
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch from 0a3c6ee to 47e8419 on July 12, 2025 07:28
@mergify mergify bot removed the needs-rebase label Jul 12, 2025
Signed-off-by: Zhiyu Cheng <[email protected]>
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/unify-config-modelopt branch from 47e8419 to 6dcc860 on July 12, 2025 07:32
@mgoin mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 12, 2025
@mgoin mgoin merged commit 6b46c4b into vllm-project:main Jul 21, 2025
79 checks passed
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
@Edwardf0t1 Edwardf0t1 mentioned this pull request Aug 12, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025