diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index cab198a2a15e..82413f36e997 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") -def test_on_the_fly_quant_config_dict_json(vllm_runner): +def test_online_quant_config_dict_json(vllm_runner): """Testing on the fly quantization, load_weights integration point, with config dict serialized to json string """ @@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner): @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") -def test_on_the_fly_quant_config_file(vllm_runner): +def test_online_quant_config_file(vllm_runner): """Testing on the fly quantization, load_weights integration point, with config file """ @@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner): ) as llm: output = llm.generate_greedy(["The capital of France is"], max_tokens=4) + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch): + """We load a model with Int4Tensor (plain format) linear weights + and verify that the weight is updated to Int4PreshuffledTensor + after loading in vllm + """ + from torchao.quantization import Int4PreshuffledTensor + from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90 + + torch._dynamo.reset() + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev" + # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't + # have meta kernel implemented yet, can remove this flag after that is implemented + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + enforce_eager=True, + ) as llm: + + def has_int4_preshuffled_tensor_weight(model): + return isinstance( + model.model.decoder.layers[0].self_attn.qkv_proj.weight, + Int4PreshuffledTensor, + ) + + def get_weight_attrs(model): + weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight + return [ + weight.requires_grad, + weight.input_dim, + weight.output_dim, + hasattr(weight, "weight_loader"), + ] + + llm_engine = llm.get_llm().llm_engine + has_int4_preshuffled_tensor = any( + llm_engine.apply_model(has_int4_preshuffled_tensor_weight) + ) + weight_attrs = llm_engine.apply_model(get_weight_attrs)[0] + + # making sure we are using Int4PreshuffledTensor on H100 GPU, when + # fbgemm_gpu_genai + # library is installed, otherwise it should be using Int4Tensor + if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90(): + assert has_int4_preshuffled_tensor + else: + assert not has_int4_preshuffled_tensor + + assert weight_attrs == [False, 1, 0, True] + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant( + vllm_runner, monkeypatch +): + """We load a bf16 model and online quantize the model to int4, then verify that + the weights are updated to Int4PreshuffledTensor after online quantization + """ + from torchao.quantization import Int4PreshuffledTensor + from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90 + + torch._dynamo.reset() + model_name = "facebook/opt-125m" + + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + import json + + from torchao.core.config import config_to_dict + from torchao.quantization import Int4WeightOnlyConfig + + torchao_quant_config = Int4WeightOnlyConfig( + group_size=128, int4_packing_format="plain" + ) + hf_overrides = { + "quantization_config_dict_json": json.dumps( + config_to_dict(torchao_quant_config) + ) + } + + # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't + # have meta kernel implemented yet, can remove this flag after that is implemented + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + hf_overrides=hf_overrides, + enforce_eager=True, + ) as llm: + + def has_int4_preshuffled_tensor_weight(model): + return isinstance( + model.model.decoder.layers[0].self_attn.qkv_proj.weight, + Int4PreshuffledTensor, + ) + + def get_weight_attrs(model): + weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight + return [ + weight.requires_grad, + weight.input_dim, + weight.output_dim, + hasattr(weight, "weight_loader"), + ] + + llm_engine = llm.get_llm().llm_engine + has_int4_preshuffled_tensor = any( + llm_engine.apply_model(has_int4_preshuffled_tensor_weight) + ) + weight_attrs = llm_engine.apply_model(get_weight_attrs)[0] + + # making sure we are using Int4PreshuffledTensor on H100 GPU, when + # fbgemm_gpu_genai + # library is installed, otherwise it should be using Int4Tensor + if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90(): + assert has_int4_preshuffled_tensor + else: + assert not has_int4_preshuffled_tensor + + assert weight_attrs == [False, 1, 0, True] + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + assert output diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index f42c45dae76d..3fee71e193db 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib import json +import types from importlib.util import find_spec from typing import Any, Optional @@ -27,6 +28,39 @@ logger = init_logger(__name__) +def _bond_method_to_cls(func, obj): + if hasattr(func, "__self__") or not callable(func): + # If the function is already bound to an instance, return it as is + return func + else: + return types.MethodType(func, obj) + + +def _get_weight_attrs(param): + # record attributes attached to the weight, so we can + # recover later + recorded_weight_attr = {} + for key in param.__dict__: + if hasattr(param, key): + attr = getattr(param, key) + if not callable(attr): + recorded_weight_attr[key] = attr + elif hasattr(attr, "__self__") and param is attr.__self__: + # if attr is a bonded method for an instance, and + # attr.__self__ points to the instance (param) + # we'll record the underlying function object + recorded_weight_attr[key] = attr.__func__ + else: + recorded_weight_attr[key] = attr + return recorded_weight_attr + + +def _restore_weight_attrs(param, recorded_weight_attr): + for attr_name, attr in recorded_weight_attr.items(): + if not hasattr(param, attr_name): + setattr(param, attr_name, _bond_method_to_cls(attr, param)) + + def torchao_version_at_least(torchao_version: str) -> bool: if find_spec("torchao"): try: @@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool: return False +if torchao_version_at_least("0.15.0"): + from torchao.prototype.tensor_conversion.api import ( + convert_to_packed_tensor_based_on_current_hardware, + ) +else: + convert_to_packed_tensor_based_on_current_hardware = lambda t: t + + class TorchAOConfig(QuantizationConfig): """Config class for torchao.""" @@ -307,12 +349,32 @@ def apply( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.quant_config.is_checkpoint_torchao_serialized: + if not hasattr(layer, "weight"): + return + + # record attributes attached to the weight, so we can + # recover later + recorded_weight_attr = _get_weight_attrs(layer.weight) + + layer.weight = Parameter( + convert_to_packed_tensor_based_on_current_hardware(layer.weight), + requires_grad=layer.weight.requires_grad, + ) + + _restore_weight_attrs(layer.weight, recorded_weight_attr) return - # quantize the weight on the fly if the checkpoint is not already + # online quantize the weight if the checkpoint is not already # quantized by torchao + recorded_weight_attr = _get_weight_attrs(layer.weight) + weight = torchao_quantize_param_data( layer.weight, self.quant_config.torchao_config ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + weight = torch.nn.Parameter( + convert_to_packed_tensor_based_on_current_hardware(weight), + weight.requires_grad, + ) + + _restore_weight_attrs(weight, recorded_weight_attr) layer.register_parameter("weight", weight)