Commit 2019b80

Fix the impl of `to` for the int4 weight-only use case
Summary:

Note that we can't do the following right now:
* initialize and quantize the model with int4_weight_only quant on CPU
* move the model to CUDA

We'll enable this in a separate PR.

Test Plan: CI
1 parent f8789f7 commit 2019b80
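For context, here is a minimal sketch of the flow this commit fixes. The model and sizes are stand-ins (the actual test uses its own ToyLinearModel); quantize_ and int4_weight_only are the torchao APIs exercised by the test below.

    import torch
    from torchao.quantization import quantize_, int4_weight_only

    # Stand-in model; initialized directly on CUDA, since the CPU-init ->
    # CUDA-move path is what the summary defers to a separate PR.
    m = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False))
    m = m.eval().to(torch.bfloat16).to("cuda")

    # In-place int4 weight-only quantization (TensorCoreTiledAQTLayout underneath).
    quantize_(m, int4_weight_only())

    # Before this fix, .to() on the quantized model could raise even for a
    # CUDA target; a cuda -> cuda move now works.
    m.to(device="cuda")
    out = m(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"))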

2 files changed (+19, -3 lines)

test/quantization/test_quant_api.py

Lines changed: 17 additions & 1 deletion

@@ -624,7 +624,7 @@ def test_quantized_tensor_subclass_save_load(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_model_to_device(self):
+    def test_int8wo_quantized_model_to_device(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
@@ -637,6 +637,22 @@ def test_quantized_model_to_device(self):
         cuda_res = m(*example_inputs_cuda)
         self.assertEqual(cuda_res.cpu(), ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int4wo_quantized_model_to_device(self):
+        # TODO: change initial model to "cpu"
+        m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+
+        quantize_(m, int4_weight_only())
+        ref = m(*example_inputs)
+
+        example_inputs_cuda = (example_inputs[0].to("cuda"),)
+        m.to(device="cuda")
+        cuda_res = m(*example_inputs_cuda)
+        self.assertEqual(cuda_res.cpu(), ref)
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
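Both device-move tests share a name suffix, so they can be selected together with pytest's keyword filter (assuming a local pytest setup; the stated test plan is CI):

    pytest test/quantization/test_quant_api.py -k quantized_model_to_device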

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 2 additions & 2 deletions

@@ -511,8 +511,8 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
-            raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device")
+        if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
+            raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),

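Why the one-word `or` -> `and` change matters: _get_to_kwargs typically hands back a torch.device rather than a string, and a device such as torch.device("cuda", 0) does not compare equal to the bare string "cuda", so the old or-condition raised even for valid CUDA targets. A standalone sketch of the two predicates (old_check/new_check are illustrative names, not from the codebase):

    import torch

    def old_check(device):
        # Pre-fix: with `or`, the string comparison alone triggers the raise.
        return device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda")

    def new_check(device):
        # Post-fix: a torch.device must also have a non-cuda type to be rejected.
        return device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda")

    dev = torch.device("cuda", 0)
    print(old_check(dev))                  # True  -> old code raised on a CUDA device
    print(new_check(dev))                  # False -> new code allows the move
    print(new_check(torch.device("cpu")))  # True  -> CPU devices are still rejected

Note the new predicate only hard-rejects torch.device objects; a bare "cpu" string would pass it, which presumably does not arise because _get_to_kwargs normalizes the device first.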