|
| 1 | +from copy import deepcopy |
| 2 | +import pytest |
| 3 | +import torch |
| 4 | +import tempfile |
| 5 | +from torchao.quantization import quantize_ |
| 6 | +from torchao.utils import ( |
| 7 | + TORCH_VERSION_AT_LEAST_2_2, |
| 8 | + TORCH_VERSION_AT_LEAST_2_4, |
| 9 | + TORCH_VERSION_AT_LEAST_2_5, |
| 10 | +) |
| 11 | +from torchao.quantization.utils import ( |
| 12 | + dynamically_quantize_per_channel, |
| 13 | + dequantize_per_channel, |
| 14 | +) |
| 15 | +from torchao.prototype.smoothquant import ( |
| 16 | + insert_smooth_quant_observer_, |
| 17 | + smooth_quant, |
| 18 | + SmoothQuantObservedLinear, |
| 19 | + save_smooth_quant_recipe, |
| 20 | + load_smooth_quant_recipe |
| 21 | +) |
| 22 | + |
| 23 | +class ToyLinearModel(torch.nn.Module): |
| 24 | + def __init__(self, m=512, n=256, k=128): |
| 25 | + super().__init__() |
| 26 | + self.linear1 = torch.nn.Linear(m, n, bias=False) |
| 27 | + self.linear2 = torch.nn.Linear(n, k, bias=False) |
| 28 | + self.linear3 = torch.nn.Linear(k, 1, bias=False) |
| 29 | + |
| 30 | + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): |
| 31 | + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] |
| 32 | + |
| 33 | + def forward(self, x): |
| 34 | + x = self.linear1(x) |
| 35 | + x = self.linear2(x) |
| 36 | + x = self.linear3(x) |
| 37 | + return x |
| 38 | + |
| 39 | + |
| 40 | +bias_list = [True, False] |
| 41 | +alpha_list = [None, 0.5, 0.75] |
| 42 | +quant_mode_list = ["static", "dynamic"] |
| 43 | +devices = ["cpu"] |
| 44 | +if torch.cuda.is_available(): |
| 45 | + devices.append("cuda") |
| 46 | +idtypes = (torch.float, torch.bfloat16, torch.half) |
| 47 | + |
| 48 | +if TORCH_VERSION_AT_LEAST_2_5: |
| 49 | + # This test case will trigger recompilation many times, so set a large cache_size_limit here |
| 50 | + torch._dynamo.config.cache_size_limit = 128 |
| 51 | + |
| 52 | +@pytest.mark.parametrize("bias", bias_list) |
| 53 | +@pytest.mark.parametrize("alpha", alpha_list) |
| 54 | +@pytest.mark.parametrize("quant_mode", quant_mode_list) |
| 55 | +@pytest.mark.parametrize("device", devices) |
| 56 | +@pytest.mark.parametrize("idtype", idtypes) |
| 57 | +def test_compute(bias, alpha, quant_mode, device, idtype): |
| 58 | + class Linear(torch.nn.Module): |
| 59 | + def __init__(self, bias: bool): |
| 60 | + super().__init__() |
| 61 | + self.fc = torch.nn.Linear(32, 32, bias) |
| 62 | + self.fc.weight.data = torch.randn_like(self.fc.weight.data) |
| 63 | + |
| 64 | + def forward(self, x): |
| 65 | + return self.fc(x) |
| 66 | + |
| 67 | + m = Linear(bias).eval().to(idtype).to(device) |
| 68 | + m_ref = deepcopy(m) |
| 69 | + data = torch.randn(2, 32, dtype=idtype, device=device) |
| 70 | + |
| 71 | + # calibrate |
| 72 | + insert_smooth_quant_observer_(m, alpha, quant_mode) |
| 73 | + m(data) |
| 74 | + # quantize |
| 75 | + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) |
| 76 | + quantize_(m, smooth_quant(), is_observed_linear) |
| 77 | + with torch.inference_mode(): |
| 78 | + if TORCH_VERSION_AT_LEAST_2_5: |
| 79 | + m = torch.compile(m, fullgraph=True) |
| 80 | + out = m(data) |
| 81 | + |
| 82 | + # reference |
| 83 | + weight = m_ref.fc.weight.data.float() |
| 84 | + b = m_ref.fc.bias if bias else None |
| 85 | + x_abs_max_per_ic = torch.abs(data).max(dim=0).values |
| 86 | + w_abs_max_per_ic = torch.abs(weight).max(dim=0).values |
| 87 | + smoothing_factor = 1 if alpha is None else ( |
| 88 | + torch.pow(x_abs_max_per_ic, alpha) / torch.pow( |
| 89 | + w_abs_max_per_ic, 1 - alpha) |
| 90 | + ) |
| 91 | + act = data / smoothing_factor |
| 92 | + wei = weight * smoothing_factor |
| 93 | + qw, w_scales, w_zps = dynamically_quantize_per_channel( |
| 94 | + wei, -127, 127, torch.int8 |
| 95 | + ) |
| 96 | + fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype) |
| 97 | + if quant_mode == "static": |
| 98 | + # activation is quantized per-tensor |
| 99 | + act_min, act_max = torch.aminmax(act.float()) |
| 100 | + max_val_pos = torch.max(-act_min, act_max) |
| 101 | + act_scale = max_val_pos / 127.0 |
| 102 | + fq_act = torch.quantize_per_tensor( |
| 103 | + act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8 |
| 104 | + ).dequantize().to(idtype) |
| 105 | + out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) |
| 106 | + else: |
| 107 | + # activation is quantized per-row (batch * sequence_length) |
| 108 | + qx, x_scales, x_zps = dynamically_quantize_per_channel( |
| 109 | + act.float(), -127, 127, torch.int8 |
| 110 | + ) |
| 111 | + fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype) |
| 112 | + out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) |
| 113 | + |
| 114 | + # BFloat16 and Float16 have larger errors |
| 115 | + atol = 0.1 if idtype == torch.float else ( |
| 116 | + 0.2 if idtype == torch.half else 0.3 |
| 117 | + ) |
| 118 | + assert torch.allclose(out, out_ref.to(idtype), atol=atol) |
| 119 | + |
| 120 | + |
| 121 | +@pytest.mark.parametrize("alpha", alpha_list) |
| 122 | +@pytest.mark.parametrize("quant_mode", quant_mode_list) |
| 123 | +@pytest.mark.parametrize("device", devices) |
| 124 | +@pytest.mark.parametrize("idtype", idtypes) |
| 125 | +def test_save_load_recipe(alpha, quant_mode, device, idtype): |
| 126 | + dataset_size = 20 |
| 127 | + l1, l2, l3 = 512, 256, 128 |
| 128 | + original_dtype = idtype |
| 129 | + n_calib_examples = 10 |
| 130 | + sequence_length = 5 |
| 131 | + |
| 132 | + m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) |
| 133 | + m_save_load = deepcopy(m) |
| 134 | + |
| 135 | + dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) |
| 136 | + calibration_data = dataset[:n_calib_examples] |
| 137 | + |
| 138 | + # calibrate |
| 139 | + insert_smooth_quant_observer_(m, alpha, quant_mode) |
| 140 | + insert_smooth_quant_observer_(m_save_load, alpha, quant_mode) |
| 141 | + |
| 142 | + for example in calibration_data: |
| 143 | + m(example.to(device)) |
| 144 | + m_save_load(example.to(device)) |
| 145 | + |
| 146 | + with tempfile.NamedTemporaryFile() as fp: |
| 147 | + save_path = fp.name |
| 148 | + save_smooth_quant_recipe(m_save_load, save_path) |
| 149 | + load_smooth_quant_recipe(m_save_load, save_path) |
| 150 | + |
| 151 | + # quantize |
| 152 | + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) |
| 153 | + quantize_(m, smooth_quant(), is_observed_linear) |
| 154 | + if TORCH_VERSION_AT_LEAST_2_5: |
| 155 | + # earlier versions are not compatible |
| 156 | + m = torch.compile(m, fullgraph=True) |
| 157 | + m_save_load = torch.compile(m_save_load, fullgraph=True) |
| 158 | + out_list = [m(data.squeeze(0)) for data in dataset] |
| 159 | + out = torch.cat(out_list) |
| 160 | + save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset] |
| 161 | + save_load_out = torch.cat(save_load_out_list) |
| 162 | + |
| 163 | + assert out is not None |
| 164 | + assert save_load_out is not None |
| 165 | + assert torch.allclose(out, save_load_out) |
0 commit comments