tests/quantization/test_torchao.py (144 additions, 2 deletions)
@@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
-def test_on_the_fly_quant_config_dict_json(vllm_runner):
+def test_online_quant_config_dict_json(vllm_runner):
"""Testing on the fly quantization, load_weights integration point,
with config dict serialized to json string
"""
@@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
-def test_on_the_fly_quant_config_file(vllm_runner):
+def test_online_quant_config_file(vllm_runner):
"""Testing on the fly quantization, load_weights integration point,
with config file
"""
@@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
"""We load a model with Int4Tensor (plain format) linear weights
and verify that the weight is updated to Int4PreshuffledTensor
after loading in vllm
"""
from torchao.quantization import Int4PreshuffledTensor
from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90

torch._dynamo.reset()
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
    # Note: using enforce_eager=True because `bf16i4bf16_shuffled` doesn't have
    # a meta kernel implemented yet; this flag can be removed once one is added.
with vllm_runner(
model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0",
enforce_eager=True,
) as llm:

def has_int4_preshuffled_tensor_weight(model):
return isinstance(
model.model.decoder.layers[0].self_attn.qkv_proj.weight,
Int4PreshuffledTensor,
)

def get_weight_attrs(model):
weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
return [
weight.requires_grad,
weight.input_dim,
weight.output_dim,
hasattr(weight, "weight_loader"),
]

llm_engine = llm.get_llm().llm_engine
has_int4_preshuffled_tensor = any(
llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
)
weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]

        # Make sure we are using Int4PreshuffledTensor on SM90+ GPUs (e.g. H100)
        # when the fbgemm_gpu_genai library is installed; otherwise the weight
        # should remain a plain Int4Tensor.
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
assert has_int4_preshuffled_tensor
else:
assert not has_int4_preshuffled_tensor

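        # Expected order from get_weight_attrs above: requires_grad, input_dim,
        # output_dim, and whether a weight_loader is attached.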
assert weight_attrs == [False, 1, 0, True]
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
vllm_runner, monkeypatch
):
"""We load a bf16 model and online quantize the model to int4, then verify that
the weights are updated to Int4PreshuffledTensor after online quantization
"""
from torchao.quantization import Int4PreshuffledTensor
from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90

torch._dynamo.reset()
model_name = "facebook/opt-125m"

monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

import json

from torchao.core.config import config_to_dict
from torchao.quantization import Int4WeightOnlyConfig

torchao_quant_config = Int4WeightOnlyConfig(
group_size=128, int4_packing_format="plain"
)
hf_overrides = {
"quantization_config_dict_json": json.dumps(
config_to_dict(torchao_quant_config)
)
}
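    # Passing the serialized torchao config via hf_overrides exercises the
    # online-quantization path: the bf16 checkpoint is quantized at load time.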

    # Note: using enforce_eager=True because `bf16i4bf16_shuffled` doesn't have
    # a meta kernel implemented yet; this flag can be removed once one is added.
with vllm_runner(
model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0",
hf_overrides=hf_overrides,
enforce_eager=True,
) as llm:

def has_int4_preshuffled_tensor_weight(model):
return isinstance(
model.model.decoder.layers[0].self_attn.qkv_proj.weight,
Int4PreshuffledTensor,
)

def get_weight_attrs(model):
weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
return [
weight.requires_grad,
weight.input_dim,
weight.output_dim,
hasattr(weight, "weight_loader"),
]

llm_engine = llm.get_llm().llm_engine
has_int4_preshuffled_tensor = any(
llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
)
weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]

        # Make sure we are using Int4PreshuffledTensor on SM90+ GPUs (e.g. H100)
        # when the fbgemm_gpu_genai library is installed; otherwise the weight
        # should remain a plain Int4Tensor.
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
assert has_int4_preshuffled_tensor
else:
assert not has_int4_preshuffled_tensor

assert weight_attrs == [False, 1, 0, True]
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

assert output


vllm/model_executor/layers/quantization/torchao.py (64 additions, 2 deletions)
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import types
from importlib.util import find_spec
from typing import Any, Optional

@@ -27,6 +28,39 @@
logger = init_logger(__name__)


def _bond_method_to_cls(func, obj):
if hasattr(func, "__self__") or not callable(func):
# If the function is already bound to an instance, return it as is
return func
else:
return types.MethodType(func, obj)


def _get_weight_attrs(param):
    # Record the attributes attached to the weight so we can
    # restore them later.
recorded_weight_attr = {}
for key in param.__dict__:
if hasattr(param, key):
attr = getattr(param, key)
if not callable(attr):
recorded_weight_attr[key] = attr
elif hasattr(attr, "__self__") and param is attr.__self__:
            # if attr is a method bound to an instance and attr.__self__
            # points to that instance (param), record the underlying
            # function object
recorded_weight_attr[key] = attr.__func__
else:
recorded_weight_attr[key] = attr
return recorded_weight_attr


def _restore_weight_attrs(param, recorded_weight_attr):
for attr_name, attr in recorded_weight_attr.items():
if not hasattr(param, attr_name):
setattr(param, attr_name, _bond_method_to_cls(attr, param))


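# A minimal sketch (hypothetical loader; not part of this patch) of the
# round-trip the three helpers above implement: a `weight_loader` bound to the
# old Parameter is recorded as its underlying function and re-bound to the new
# Parameter, so it does not keep writing into the stale instance:
#
#     def _loader(self, loaded_weight):  # hypothetical weight_loader
#         self.data.copy_(loaded_weight)
#
#     old = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
#     old.output_dim = 0
#     old.weight_loader = types.MethodType(_loader, old)
#
#     recorded = _get_weight_attrs(old)  # records _loader, not the bound method
#     new = torch.nn.Parameter(torch.ones(4), requires_grad=False)
#     _restore_weight_attrs(new, recorded)
#     assert new.weight_loader.__self__ is new  # loader now bound to `new`
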
def torchao_version_at_least(torchao_version: str) -> bool:
if find_spec("torchao"):
try:
@@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
return False


if torchao_version_at_least("0.15.0"):
from torchao.prototype.tensor_conversion.api import (
convert_to_packed_tensor_based_on_current_hardware,
)
else:
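    # On older torchao, fall back to a no-op so weights keep their current packing.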
convert_to_packed_tensor_based_on_current_hardware = lambda t: t


class TorchAOConfig(QuantizationConfig):
"""Config class for torchao."""

@@ -307,12 +349,32 @@ def apply(

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.quant_config.is_checkpoint_torchao_serialized:
if not hasattr(layer, "weight"):
return

            # Record the attributes attached to the weight so we can
            # restore them later.
recorded_weight_attr = _get_weight_attrs(layer.weight)

layer.weight = Parameter(
convert_to_packed_tensor_based_on_current_hardware(layer.weight),
requires_grad=layer.weight.requires_grad,
)

_restore_weight_attrs(layer.weight, recorded_weight_attr)
return

-            # quantize the weight on the fly if the checkpoint is not already
+            # online quantize the weight if the checkpoint is not already
# quantized by torchao
recorded_weight_attr = _get_weight_attrs(layer.weight)

weight = torchao_quantize_param_data(
layer.weight, self.quant_config.torchao_config
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
weight = torch.nn.Parameter(
convert_to_packed_tensor_based_on_current_hardware(weight),
weight.requires_grad,
)

_restore_weight_attrs(weight, recorded_weight_attr)
layer.register_parameter("weight", weight)
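
# End-to-end, the online-quant path above is roughly equivalent to the
# following sketch of torchao's public API (assumes torchao >= 0.15.0 for the
# repacking call, and a CUDA device):
#
#     import torch
#     from torchao.quantization import Int4WeightOnlyConfig, quantize_
#
#     linear = torch.nn.Linear(256, 256, bias=False,
#                              dtype=torch.bfloat16, device="cuda")
#     quantize_(linear, Int4WeightOnlyConfig(group_size=128,
#                                            int4_packing_format="plain"))
#     # linear.weight is now a torchao int4 tensor subclass; repack it for the
#     # current hardware (Int4PreshuffledTensor on SM90+ with fbgemm_gpu_genai):
#     linear.weight = torch.nn.Parameter(
#         convert_to_packed_tensor_based_on_current_hardware(linear.weight),
#         requires_grad=False,
#     )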