Commit 43ffb17

Support using Int4PreshuffledTensor after loading
Summary:

Int4PreshuffledTensor has the fastest int4 kernels in fbgemm for both int4 weight-only and fp8 activation + int4 weight quantization, but the tensor can't be sliced because of the preshuffling (and slice has to preserve aliasing). We therefore load checkpoints as Int4Tensor (plain format), which can be sliced during loading, and convert the tensor to the preshuffled format after loading using `torchao.prototype.tensor_conversion.api.convert_to_packed_tensor_based_on_current_hardware`.

Test Plan:

pytest tests/quantization/test_torchao.py -k test_opt_125m_int4wo_model_running_preshuffled_kernel

For the test we uploaded a plain int4 tensor checkpoint (https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev), load it in vLLM, and check that the model is transformed to use Int4PreshuffledTensor before inference.

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Jerry Zhang <[email protected]>
1 parent df33486 commit 43ffb17
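For reference, a minimal sketch of the post-load conversion this commit relies on. It assumes a torchao build that ships the prototype conversion API (the vLLM change below guards on torchao >= 0.15.0); the helper name repack_after_loading is illustrative, not part of the commit.

    # Hedged sketch, not vLLM's exact code path: convert a plain-format
    # Int4Tensor weight to the fastest packed layout for the current GPU.
    import torch
    from torchao.prototype.tensor_conversion.api import (
        convert_to_packed_tensor_based_on_current_hardware,
    )


    def repack_after_loading(layer: torch.nn.Module) -> None:
        # Plain Int4Tensor supports slicing, so it is used while loading
        # sharded checkpoints; afterwards the weight can be repacked. On
        # H100 (SM90+) with fbgemm_gpu_genai installed this yields an
        # Int4PreshuffledTensor; on other hardware it may be a no-op.
        layer.weight = torch.nn.Parameter(
            convert_to_packed_tensor_based_on_current_hardware(layer.weight),
            requires_grad=False,
        )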

File tree

2 files changed (+208, -4 lines)

tests/quantization/test_torchao.py

Lines changed: 144 additions & 2 deletions
@@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
-def test_on_the_fly_quant_config_dict_json(vllm_runner):
+def test_online_quant_config_dict_json(vllm_runner):
     """Testing on the fly quantization, load_weights integration point,
     with config dict serialized to json string
     """
@@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):
 
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
-def test_on_the_fly_quant_config_file(vllm_runner):
+def test_online_quant_config_file(vllm_runner):
     """Testing on the fly quantization, load_weights integration point,
     with config file
     """
@@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
     ) as llm:
         output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
 
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="since torchao nightly is only compatible with torch nightly "
+    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
+    "torchao tests that require newer versions (0.14.0.dev+) for now"
+)
+def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
+    """We load a model with Int4Tensor (plain format) linear weights
+    and verify that the weight is updated to Int4PreshuffledTensor
+    after loading in vLLM
+    """
+    from torchao.quantization import Int4PreshuffledTensor
+    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
+
+    torch._dynamo.reset()
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+    model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
+    # Note: using enforce_eager=True because `bf16i4bf16_shuffled` doesn't have
+    # a meta kernel implemented yet; remove this flag once that is implemented
+    with vllm_runner(
+        model_name=model_name,
+        quantization="torchao",
+        dtype="bfloat16",
+        pt_load_map_location="cuda:0",
+        enforce_eager=True,
+    ) as llm:
+
+        def has_int4_preshuffled_tensor_weight(model):
+            return isinstance(
+                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
+                Int4PreshuffledTensor,
+            )
+
+        def get_weight_attrs(model):
+            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
+            return [
+                weight.requires_grad,
+                weight.input_dim,
+                weight.output_dim,
+                hasattr(weight, "weight_loader"),
+            ]
+
+        llm_engine = llm.get_llm().llm_engine
+        has_int4_preshuffled_tensor = any(
+            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
+        )
+        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
+
+        # make sure we are using Int4PreshuffledTensor on H100 GPUs when the
+        # fbgemm_gpu_genai library is installed; otherwise it should be using
+        # Int4Tensor
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            assert has_int4_preshuffled_tensor
+        else:
+            assert not has_int4_preshuffled_tensor
+
+        assert weight_attrs == [False, 1, 0, True]
+        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+
+    assert output
+
+
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="since torchao nightly is only compatible with torch nightly "
+    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
+    "torchao tests that require newer versions (0.14.0.dev+) for now"
+)
+def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
+    vllm_runner, monkeypatch
+):
+    """We load a bf16 model and online quantize the model to int4, then verify
+    that the weights are updated to Int4PreshuffledTensor after online quantization
+    """
+    from torchao.quantization import Int4PreshuffledTensor
+    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90
+
+    torch._dynamo.reset()
+    model_name = "facebook/opt-125m"
+
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    import json
+
+    from torchao.core.config import config_to_dict
+    from torchao.quantization import Int4WeightOnlyConfig
+
+    torchao_quant_config = Int4WeightOnlyConfig(
+        group_size=128, int4_packing_format="plain"
+    )
+    hf_overrides = {
+        "quantization_config_dict_json": json.dumps(
+            config_to_dict(torchao_quant_config)
+        )
+    }
+
+    # Note: using enforce_eager=True because `bf16i4bf16_shuffled` doesn't have
+    # a meta kernel implemented yet; remove this flag once that is implemented
+    with vllm_runner(
+        model_name=model_name,
+        quantization="torchao",
+        dtype="bfloat16",
+        pt_load_map_location="cuda:0",
+        hf_overrides=hf_overrides,
+        enforce_eager=True,
+    ) as llm:
+
+        def has_int4_preshuffled_tensor_weight(model):
+            return isinstance(
+                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
+                Int4PreshuffledTensor,
+            )
+
+        def get_weight_attrs(model):
+            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
+            return [
+                weight.requires_grad,
+                weight.input_dim,
+                weight.output_dim,
+                hasattr(weight, "weight_loader"),
+            ]
+
+        llm_engine = llm.get_llm().llm_engine
+        has_int4_preshuffled_tensor = any(
+            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
+        )
+        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]
+
+        # make sure we are using Int4PreshuffledTensor on H100 GPUs when the
+        # fbgemm_gpu_genai library is installed; otherwise it should be using
+        # Int4Tensor
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            assert has_int4_preshuffled_tensor
+        else:
+            assert not has_int4_preshuffled_tensor
+
+        assert weight_attrs == [False, 1, 0, True]
+        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+
     assert output
 
 
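As a companion to the test above, a hedged standalone sketch of the same flow outside the pytest harness. It reuses the checkpoint id and engine flags from the test; this is illustrative usage of vLLM's public LLM API, not part of the commit.

    from vllm import LLM, SamplingParams

    # Load the plain-format int4 checkpoint; vLLM's torchao integration
    # repacks the weights (e.g. to Int4PreshuffledTensor on H100 with
    # fbgemm_gpu_genai installed) after loading, before inference.
    llm = LLM(
        model="torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev",
        quantization="torchao",
        dtype="bfloat16",
        enforce_eager=True,  # bf16i4bf16_shuffled has no meta kernel yet
    )
    outputs = llm.generate(
        ["The capital of France is"], SamplingParams(max_tokens=32)
    )
    print(outputs[0].outputs[0].text)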

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 64 additions & 2 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
 import json
+import types
 from importlib.util import find_spec
 from typing import Any, Optional
 
@@ -27,6 +28,39 @@
 logger = init_logger(__name__)
 
 
+def _bond_method_to_cls(func, obj):
+    if hasattr(func, "__self__") or not callable(func):
+        # If the function is already bound to an instance, return it as is
+        return func
+    else:
+        return types.MethodType(func, obj)
+
+
+def _get_weight_attrs(param):
+    # record attributes attached to the weight, so we can
+    # recover them later
+    recorded_weight_attr = {}
+    for key in param.__dict__:
+        if hasattr(param, key):
+            attr = getattr(param, key)
+            if not callable(attr):
+                recorded_weight_attr[key] = attr
+            elif hasattr(attr, "__self__") and param is attr.__self__:
+                # if attr is a bound method of an instance and
+                # attr.__self__ points to the instance (param),
+                # we record the underlying function object
+                recorded_weight_attr[key] = attr.__func__
+            else:
+                recorded_weight_attr[key] = attr
+    return recorded_weight_attr
+
+
+def _restore_weight_attrs(param, recorded_weight_attr):
+    for attr_name, attr in recorded_weight_attr.items():
+        if not hasattr(param, attr_name):
+            setattr(param, attr_name, _bond_method_to_cls(attr, param))
+
+
 def torchao_version_at_least(torchao_version: str) -> bool:
     if find_spec("torchao"):
         try:
@@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
     return False
 
 
+if torchao_version_at_least("0.15.0"):
+    from torchao.prototype.tensor_conversion.api import (
+        convert_to_packed_tensor_based_on_current_hardware,
+    )
+else:
+    convert_to_packed_tensor_based_on_current_hardware = lambda t: t
+
+
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
 
@@ -307,12 +349,32 @@ def apply(
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.quant_config.is_checkpoint_torchao_serialized:
+            if not hasattr(layer, "weight"):
+                return
+
+            # record attributes attached to the weight, so we can
+            # recover them later
+            recorded_weight_attr = _get_weight_attrs(layer.weight)
+
+            layer.weight = Parameter(
+                convert_to_packed_tensor_based_on_current_hardware(layer.weight),
+                requires_grad=layer.weight.requires_grad,
+            )
+
+            _restore_weight_attrs(layer.weight, recorded_weight_attr)
             return
 
-        # quantize the weight on the fly if the checkpoint is not already
+        # online quantize the weight if the checkpoint is not already
        # quantized by torchao
+        recorded_weight_attr = _get_weight_attrs(layer.weight)
+
         weight = torchao_quantize_param_data(
             layer.weight, self.quant_config.torchao_config
         )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        weight = torch.nn.Parameter(
+            convert_to_packed_tensor_based_on_current_hardware(weight),
+            weight.requires_grad,
+        )
+
+        _restore_weight_attrs(weight, recorded_weight_attr)
         layer.register_parameter("weight", weight)
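Why the record/restore helpers exist: swapping layer.weight for a freshly constructed Parameter drops the metadata vLLM attached to the old parameter (input_dim, output_dim, and the bound weight_loader method). Below is a self-contained sketch of the same pattern with illustrative attribute values; it is not vLLM's actual loader code.

    import types

    import torch

    param = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
    # vLLM attaches loading metadata directly onto the parameter instance.
    param.input_dim = 1
    param.output_dim = 0
    param.weight_loader = types.MethodType(lambda self, w: None, param)

    # Record: for methods bound to this exact instance, keep the underlying
    # function object so it can be re-bound to the replacement parameter.
    recorded = {
        key: (attr.__func__ if getattr(attr, "__self__", None) is param else attr)
        for key, attr in param.__dict__.items()
    }

    # Replace the parameter (stand-in for the packed-tensor conversion).
    new_param = torch.nn.Parameter(param.detach().clone(), requires_grad=False)

    # Restore: re-bind plain functions to the new instance, copy the rest.
    for name, attr in recorded.items():
        if not hasattr(new_param, name):
            if callable(attr) and not hasattr(attr, "__self__"):
                setattr(new_param, name, types.MethodType(attr, new_param))
            else:
                setattr(new_param, name, attr)

    assert new_param.input_dim == 1 and new_param.output_dim == 0
    assert new_param.weight_loader.__self__ is new_param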
