18 changes: 15 additions & 3 deletions tests/quantization/test_torchao.py
@@ -31,9 +31,6 @@ def test_pre_quantized_model(vllm_runner):
 ])
 def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
                                                    pt_load_map_location):
-    """
-    Test loading roberta-base model with no lm_head.
-    """
     torch._dynamo.reset()
     model_name = "jerryzh168/opt-125m-int4wo"
     with vllm_runner(model_name=model_name,
@@ -47,5 +44,20 @@ def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
     print(output)
 

+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "jerryzh168/opt-125m-int4wo-per-module"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
Collaborator: nit: assert output, f"output is {output}"

Contributor Author: will update in next change, feel free to merge
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
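For context, the checkpoint the new test loads (jerryzh168/opt-125m-int4wo-per-module) is the kind of artifact produced by quantizing a model with a per-module torchao config. Below is a minimal sketch of how such a checkpoint might be produced, assuming a torchao release that exposes `AOPerModuleConfig` and `Int4WeightOnlyConfig` under `torchao.quantization`; the `AOPerModuleConfig` name and its `_default` key come from the vLLM change below, while the constructor signatures and the OPT module name are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM

# Assumption: AOPerModuleConfig and Int4WeightOnlyConfig live in
# torchao.quantization, matching the import added in the vLLM diff below.
from torchao.quantization import (AOPerModuleConfig, Int4WeightOnlyConfig,
                                  quantize_)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                             torch_dtype=torch.bfloat16)

config = AOPerModuleConfig({
    # Illustrative OPT FQN: give this projection its own (finer) group size.
    "model.decoder.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=32),
    # "_default" covers every linear module without an explicit entry.
    "_default": Int4WeightOnlyConfig(group_size=128),
})

# quantize_ rewrites the matching linear weights in place.
quantize_(model, config)
```

Saving or pushing the resulting model yields a checkpoint vLLM can load with `quantization="torchao"`; the new test can then be run with `pytest tests/quantization/test_torchao.py -k per_module_quant`.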
29 changes: 22 additions & 7 deletions vllm/model_executor/layers/quantization/torchao.py
@@ -5,10 +5,11 @@
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
@@ -55,10 +56,24 @@ def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
         return cls(ao_config)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["TorchAOLinearMethod"]:
-        if isinstance(layer, LinearBase):
-            return TorchAOLinearMethod(self)
-        return None
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if not isinstance(layer, LinearBase):
+            return None
+
+        from torchao.quantization import AOPerModuleConfig
+
+        module_fqn = prefix
+        if isinstance(self.torchao_config, AOPerModuleConfig):
+            module_fqn_to_config = self.torchao_config.module_fqn_to_config
+            c = module_fqn_to_config.get(
+                module_fqn) or module_fqn_to_config.get("_default", None)
+            if c is not None:
+                current_torchao_config = TorchAOConfig(c)
+                return TorchAOLinearMethod(current_torchao_config)
+            else:
+                return UnquantizedLinearMethod()
+
+        return TorchAOLinearMethod(self)
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -75,7 +90,7 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
-    assert isinstance(torchao_config, AOBaseConfig)
+    assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
     dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
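To make the new dispatch concrete, here is a tiny standalone sketch of the lookup `get_quant_method` now performs, with plain strings standing in for torchao config objects; the dict contents and the `resolve` helper are hypothetical, and only the `get(...) or get("_default", None)` pattern is taken from the diff.

```python
from typing import Optional

# Stand-in for AOPerModuleConfig.module_fqn_to_config; strings replace
# the real torchao config objects purely for illustration.
module_fqn_to_config = {
    "model.decoder.layers.0.self_attn.q_proj": "int4wo-group32",
    "_default": "int4wo-group128",
}

def resolve(module_fqn: str) -> Optional[str]:
    # Exact FQN match wins; otherwise fall back to "_default". A None
    # result maps to UnquantizedLinearMethod in the code above.
    return module_fqn_to_config.get(module_fqn) or module_fqn_to_config.get(
        "_default", None)

assert resolve("model.decoder.layers.0.self_attn.q_proj") == "int4wo-group32"
assert resolve("model.decoder.layers.1.fc1") == "int4wo-group128"
```

One consequence of the `or`: a module explicitly mapped to `None` falls through to `_default` rather than staying unquantized, so a per-module opt-out would need a membership check (`module_fqn in module_fqn_to_config`) instead of `get`.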