|
| 1 | +import torch |
| 2 | +import copy |
| 3 | +import pytest |
| 4 | + |
| 5 | +from torch import nn |
| 6 | +from torch.testing._internal.common_utils import TestCase, run_tests |
| 7 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 |
| 8 | +from torchao.dtypes import MarlinQQQLayout |
| 9 | +from torchao.quantization.quant_api import ( |
| 10 | + quantize_, |
| 11 | + int8_dynamic_activation_int4_weight, |
| 12 | +) |
| 13 | +from torchao.quantization.marlin_qqq import ( |
| 14 | + pack_to_marlin_qqq, |
| 15 | + unpack_from_marlin_qqq, |
| 16 | +) |
| 17 | +from torchao.quantization.quant_primitives import ( |
| 18 | + choose_qparams_and_quantize_affine_qqq, |
| 19 | + MappingType, |
| 20 | +) |
| 21 | + |
| 22 | + |
| 23 | +class MarlinQQQ(TestCase): |
| 24 | + def setUp(self): |
| 25 | + super().setUp() |
| 26 | + torch.manual_seed(0) |
| 27 | + |
| 28 | + self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda") |
| 29 | + self.model = ( |
| 30 | + nn.Sequential( |
| 31 | + nn.Linear(4096, 21504), |
| 32 | + nn.Linear(21504, 4096), |
| 33 | + nn.ReLU(), |
| 34 | + nn.Linear(4096, 21504), |
| 35 | + nn.Linear(21504, 4096), |
| 36 | + ) |
| 37 | + .half() |
| 38 | + .cuda() |
| 39 | + ) |
| 40 | + |
| 41 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
| 42 | + def test_marlin_qqq(self): |
| 43 | + output_ref = self.model(self.input) |
| 44 | + for group_size in [-1, 128]: |
| 45 | + modelq = copy.deepcopy(self.model) |
| 46 | + quantize_( |
| 47 | + modelq, |
| 48 | + int8_dynamic_activation_int4_weight( |
| 49 | + group_size=group_size, |
| 50 | + mapping_type=MappingType.SYMMETRIC, |
| 51 | + input_mapping_type=MappingType.SYMMETRIC, |
| 52 | + layout=MarlinQQQLayout(), |
| 53 | + ), |
| 54 | + ) |
| 55 | + output = modelq(self.input) |
| 56 | + |
| 57 | + assert torch.allclose( |
| 58 | + output, output_ref, atol=1e-1 |
| 59 | + ), "Results are not close" |
| 60 | + |
| 61 | + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") |
| 62 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
| 63 | + def test_marlin_qqq_compile(self): |
| 64 | + model_copy = copy.deepcopy(self.model) |
| 65 | + model_copy.forward = torch.compile(model_copy.forward, fullgraph=True) |
| 66 | + output_ref = model_copy(self.input) |
| 67 | + |
| 68 | + for group_size in [-1, 128]: |
| 69 | + modelq = copy.deepcopy(self.model) |
| 70 | + quantize_( |
| 71 | + modelq, |
| 72 | + int8_dynamic_activation_int4_weight( |
| 73 | + group_size=group_size, |
| 74 | + mapping_type=MappingType.SYMMETRIC, |
| 75 | + input_mapping_type=MappingType.SYMMETRIC, |
| 76 | + layout=MarlinQQQLayout(), |
| 77 | + ), |
| 78 | + ) |
| 79 | + modelq.forward = torch.compile(modelq.forward, fullgraph=True) |
| 80 | + output = modelq(self.input) |
| 81 | + |
| 82 | + assert torch.allclose( |
| 83 | + output, output_ref, atol=1e-1 |
| 84 | + ), "Results are not close" |
| 85 | + |
| 86 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
| 87 | + def test_pack_unpack_equivalence(self): |
| 88 | + num_bits = 4 |
| 89 | + shape = (11008, 4096) |
| 90 | + mapping_type = MappingType.SYMMETRIC |
| 91 | + |
| 92 | + w = torch.rand(shape, dtype=torch.float16, device="cuda") |
| 93 | + |
| 94 | + for group_size in [-1, 128]: |
| 95 | + # Quantize weights |
| 96 | + q_w, s_group, s_channel = choose_qparams_and_quantize_affine_qqq( |
| 97 | + w, mapping_type, num_bits, group_size |
| 98 | + ) |
| 99 | + |
| 100 | + q_w = q_w.t() |
| 101 | + s_group = s_group.t() |
| 102 | + s_channel = s_channel.t() |
| 103 | + |
| 104 | + # Test pack/unpack equivalence |
| 105 | + q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq( |
| 106 | + q_w, s_group, s_channel, num_bits, group_size |
| 107 | + ) |
| 108 | + unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq( |
| 109 | + q_w_comp, |
| 110 | + packed_s_group, |
| 111 | + packed_s_channel, |
| 112 | + q_w.shape, |
| 113 | + num_bits, |
| 114 | + group_size, |
| 115 | + ) |
| 116 | + |
| 117 | + assert torch.equal( |
| 118 | + q_w, unpacked_q_w |
| 119 | + ), "Unpacked weights do not match original weights" |
| 120 | + assert torch.equal( |
| 121 | + s_channel, unpacked_s_channel |
| 122 | + ), "Unpacked s_channel do not match original s_channel" |
| 123 | + assert torch.equal( |
| 124 | + s_group, unpacked_s_group |
| 125 | + ), "Unpacked s_group do not match original s_group" |
| 126 | + |
| 127 | + |
| 128 | +if __name__ == "__main__": |
| 129 | + run_tests() |
0 commit comments