Commit f02db41

Support using Int4PreshuffledTensor after loading

Summary:

Int4PreshuffledTensor uses the fastest fbgemm int4 kernels, for both int4 weight-only and fp8 activation + int4 weight quantization, but we can't slice the tensor due to the preshuffling (and slicing has to preserve aliasing). So we serialize with Int4Tensor (plain format), which can be sliced during loading, and convert the tensor to the preshuffled format after loading using the `torchao.prototype.tensor_conversion.api.convert_to_packed_tensor_based_on_current_hardware` function.

Test Plan:

pytest tests/quantization/test_torchao.py -k test_opt_125m_int4wo_model_running_preshuffled_kernel

For the test we uploaded a plain int4 tensor checkpoint (https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev), load it in vLLM, and then check that the model has been transformed to use Int4PreshuffledTensor before inference.

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Jerry Zhang <[email protected]>
1 parent c312468 commit f02db41
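
As context for the change below, here is a minimal sketch (not part of this commit) of the round trip the summary describes: quantize to the plain, sliceable Int4Tensor format, then repack for the current hardware. It assumes torchao 0.14.0.dev+, where `quantize_` and `Int4WeightOnlyConfig` are the public quantization entry points; the `version=2` argument mirrors the "-v2-" checkpoint naming but is an assumption about the config API.

    # Sketch only, not from the commit; APIs per the assumptions above.
    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_
    from torchao.prototype.tensor_conversion.api import (
        convert_to_packed_tensor_based_on_current_hardware)

    linear = torch.nn.Linear(256, 512, bias=False,
                             dtype=torch.bfloat16, device="cuda")

    # Quantize to the plain Int4Tensor format; the plain layout supports
    # slicing, so vLLM's weight loader can shard it during loading.
    quantize_(linear, Int4WeightOnlyConfig(group_size=128, version=2))

    # Once loading (and any slicing) is done, repack into the fastest
    # layout for the current hardware, e.g. Int4PreshuffledTensor on
    # GPUs with fbgemm kernels.
    linear.weight = torch.nn.Parameter(
        convert_to_packed_tensor_based_on_current_hardware(linear.weight),
        requires_grad=False)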

File tree

2 files changed: +33 −0 lines changed
tests/quantization/test_torchao.py

Lines changed: 27 additions & 0 deletions

@@ -211,5 +211,32 @@ def test_reload_weights():
     # print("-" * 60)


+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+# @pytest.mark.skip(
+#     reason="since torchao nightly is only compatible with torch nightly "
+#     "currently (https://github.com/pytorch/ao/issues/2919), we'll have to "
+#     "skip torchao tests that require newer versions (0.14.0.dev+) for now")
+def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner):
+    """Load a model with Int4Tensor (plain format) linear weights and
+    verify that the weight is updated to Int4PreshuffledTensor after
+    loading in vLLM.
+    """
+    torch._dynamo.reset()
+    model_name = ("torchao-testing/opt-125m-Int4WeightOnlyConfig-v2"
+                  "-0.14.0.dev")
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
+        orig_model = model_runner.model
+        print("orig model:", orig_model)
+
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+        assert output
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

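The test prints the model and checks that greedy generation succeeds. A stricter check along the lines of the test plan could assert the weight type directly; a hedged sketch (not in the commit) follows, where the `Int4PreshuffledTensor` import path from `torchao.quantization` is an assumption:

    # Hedged sketch: assert that loading repacked the plain int4 weights.
    from torchao.quantization import Int4PreshuffledTensor

    preshuffled = [name for name, p in orig_model.named_parameters()
                   if isinstance(p, Int4PreshuffledTensor)]
    assert preshuffled, "expected int4 weights repacked to preshuffled format"
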
vllm/model_executor/layers/quantization/torchao.py

Lines changed: 6 additions & 0 deletions

@@ -260,6 +260,12 @@ def apply(

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.quant_config.is_checkpoint_torchao_serialized:
+            from torchao.prototype.tensor_conversion.api import (
+                convert_to_packed_tensor_based_on_current_hardware)
+            if hasattr(layer, "weight"):
+                layer.weight = Parameter(
+                    convert_to_packed_tensor_based_on_current_hardware(
+                        layer.weight))
             return

         # quantize the weight on the fly if the checkpoint is not already
