diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index 1a20228765e8..6571fc9e471b 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -31,9 +31,6 @@ def test_pre_quantized_model(vllm_runner):
 ])
 def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
                                                    pt_load_map_location):
-    """
-    Test loading roberta-base model with no lm_head.
-    """
     torch._dynamo.reset()
     model_name = "jerryzh168/opt-125m-int4wo"
     with vllm_runner(model_name=model_name,
@@ -47,5 +44,20 @@ def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "jerryzh168/opt-125m-int4wo-per-module"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+        assert output
+        print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index 751002fa0945..2325b01086e0 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -5,10 +5,11 @@
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
 
@@ -55,10 +56,24 @@ def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
         return cls(ao_config)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["TorchAOLinearMethod"]:
-        if isinstance(layer, LinearBase):
-            return TorchAOLinearMethod(self)
-        return None
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if not isinstance(layer, LinearBase):
+            return None
+
+        from torchao.quantization import AOPerModuleConfig
+
+        module_fqn = prefix
+        if isinstance(self.torchao_config, AOPerModuleConfig):
+            module_fqn_to_config = self.torchao_config.module_fqn_to_config
+            c = module_fqn_to_config.get(
+                module_fqn) or module_fqn_to_config.get("_default", None)
+            if c is not None:
+                current_torchao_config = TorchAOConfig(c)
+                return TorchAOLinearMethod(current_torchao_config)
+            else:
+                return UnquantizedLinearMethod()
+
+        return TorchAOLinearMethod(self)
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -75,7 +90,7 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
-    assert isinstance(torchao_config, AOBaseConfig)
+    assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
     dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
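
For reference, a minimal sketch of how a per-module checkpoint like the one exercised by the new test might be produced. `AOPerModuleConfig`, its `module_fqn_to_config` dict, and the special `"_default"` key are taken from the diff above; `Int4WeightOnlyConfig` and the exact `quantize_` invocation are assumptions about the torchao release this PR targets, and the chosen module FQN is illustrative only.

```python
import torch
from torchao.quantization import (AOPerModuleConfig, Int4WeightOnlyConfig,
                                  quantize_)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                             torch_dtype=torch.bfloat16)

config = AOPerModuleConfig({
    # An exact module-FQN match takes priority in get_quant_method...
    "model.decoder.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(),
    # ...then "_default" covers every linear layer without its own entry.
    "_default": Int4WeightOnlyConfig(),
})
quantize_(model, config)
```

When vLLM loads such a checkpoint, `get_quant_method` resolves each linear layer by its `prefix` (the module FQN): an exact match wins, otherwise `"_default"` applies, and a layer matching neither is served by `UnquantizedLinearMethod`.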