diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index 8eb6ddde11..379d5f7e76 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -16,30 +16,6 @@ _replace_with_custom_fn_if_matches_filter, quantize_, ) -from torchao.quantization.subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, -) - - -def _int8wo_api(mod, **kwargs): - quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False) - - -def _int8da_int8w_api(mod, **kwargs): - quantize_( - mod, - Int8DynamicActivationInt8WeightConfig(**kwargs), - set_inductor_config=False, - ) - - -def _int4wo_api(mod, **kwargs): - kwargs_copy = kwargs.copy() - if "groupsize" in kwargs_copy: - kwargs_copy["group_size"] = kwargs_copy["groupsize"] - del kwargs_copy["groupsize"] - quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False) class ToyLinearModel(torch.nn.Module): @@ -117,38 +93,18 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -) -_ref_change_linear_weights_to_int4_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) -) - - torch._dynamo.config.cache_size_limit = 50000 @torch.no_grad -def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): - if kwargs is None: - kwargs = {} - +def _bench_quantized_tensor_subclass_perf(api, config, M, N, K): m = ToyLinearModel( M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda" ).eval() m_bf16 = copy.deepcopy(m) - m_ref = copy.deepcopy(m) example_inputs = m.example_inputs() - api(m, **kwargs) - - # reference - ref_api(m_ref, **kwargs) - - res = m(*example_inputs) - ref = m_ref(*example_inputs) - - assert torch.equal(res, ref) + api(m, config) # Pass both model and config # perf comparison from torchao.utils import benchmark_model @@ -158,22 +114,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): RUNS = 100 torch._dynamo.reset() - m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True) - benchmark_model(m_ref, WARMUP, example_inputs) - ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs) + m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) + benchmark_model(m_bf16, WARMUP, example_inputs) + bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) torch._dynamo.reset() m = torch.compile(m, mode="max-autotune", fullgraph=True) benchmark_model(m, WARMUP, example_inputs) elapsed_time = benchmark_model(m, RUNS, example_inputs) - torch._dynamo.reset() - m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) - benchmark_model(m_bf16, WARMUP, example_inputs) - bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) - print( - f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" + f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" ) @@ -182,24 +133,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): (20, 2048, 2048), ] - print("_int8da_int8w_api") - + print("Int8DynamicActivationInt8WeightConfig") for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( - _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K + quantize_, + Int8DynamicActivationInt8WeightConfig(), + M, + N, + K, ) - print("_int8wo_api") - + print("Int8WeightOnlyConfig") for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( - _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K + quantize_, + Int8WeightOnlyConfig(), + M, + N, + K, ) - print("_int4wo_api") - kwargs = {"groupsize": 32, "version": 1} - + print("Int4WeightOnlyConfig") for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( - _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs + quantize_, + Int4WeightOnlyConfig(group_size=32), + M, + N, + K, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index f99cf4a1b4..4dc67d0690 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -43,33 +43,21 @@ Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, - _replace_with_custom_fn_if_matches_filter, quantize_, ) from torchao.quantization.quant_primitives import ( MappingType, + choose_qparams_affine, dequantize_affine, -) -from torchao.quantization.smoothquant import ( - SmoothFakeDynamicallyQuantizedLinear, - get_scale, - smooth_fq_linear_to_inference, - swap_linear_with_smooth_fq_linear, -) -from torchao.quantization.subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, + quantize_affine, ) from torchao.quantization.utils import ( LoggingTensorMode, _apply_logging_hook, _fqn_to_op_to_shape_to_count, - _quant_int8_dynamic_per_token_linear, _quantize_activation_per_token_absmax, compute_error, dequantize_per_channel, - dynamically_quantize_per_channel, ) from torchao.quantization.utils import ( compute_error as SQNR, @@ -210,262 +198,6 @@ def wrapper(*args, **kwargs): return wrapper -class SmoothquantUnitTest(unittest.TestCase): - # first, let's reproduce the graphic from the paper, Figure 4, to ensure - # we are calculating the scales correctly - def test_figure_4(self): - X = torch.FloatTensor([1, -16, 2, 6, -2, 8, -1, -9]).reshape(1, 2, 4) - W = torch.FloatTensor([2, 1, -2, 1, -1, -1, 2, -1, -2, -1, -1, 1]).reshape(4, 3) - X_mul_W = torch.matmul(X, W) - - smoothquant_scale = get_scale( - torch.amax(torch.abs(X), dim=(0, 1)), - torch.amax(torch.abs(W), dim=1), - alpha=0.5, - ) - - # reproduce scaled calculation - X_scaled = X / smoothquant_scale.reshape(1, 1, -1) - W_scaled = torch.matmul(torch.diag(smoothquant_scale), W) - X_scaled_mul_scaled_W = torch.matmul(X_scaled, W_scaled) - assert torch.allclose(X_mul_W, X_scaled_mul_scaled_W), "not close!" - assert X_mul_W.shape == X_scaled_mul_scaled_W.shape - - # next, run the above test on a sample of representative inputs - def test_tensors(self): - x_shape = (1, 5, 7) - w_shape = (7, 9) - for i in range(3): - X = torch.randn(x_shape) * 10 - W = torch.randn(w_shape) - s = get_scale( - torch.amax(torch.abs(X), dim=(0, 1)), - torch.amax(torch.abs(W), dim=1), - alpha=0.5, - ) - - Y = torch.matmul(X, W) - Y_ref = torch.matmul( - X / s.reshape(1, 1, -1), - torch.matmul(torch.diag(s), W), - ) - assert torch.allclose(Y, Y_ref, atol=1e-3, rtol=1e-3), "not close!" - - def _test_smooth_linear_impl(self, x_shape, lin_shape, device): - orig_backend = torch.backends.quantized.engine - # so we can use the full range - torch.backends.quantized.engine = "qnnpack" - - x = torch.randn(*x_shape, device=device) * 9 + 10 - - lin_fp32 = nn.Linear(*lin_shape, device=device) # misc: ignore - lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( - copy.deepcopy(lin_fp32), alpha=0.25 - ) - lin_smooth_skip_scaling = SmoothFakeDynamicallyQuantizedLinear.from_float( - copy.deepcopy(lin_fp32), alpha=0.25 - ) - - lin_fp32_copy = copy.deepcopy(lin_fp32) # assignment: ignore - lin_fp32_copy.qconfig = torch.ao.quantization.QConfig( # assignment: ignore - activation=None, - weight=torch.ao.quantization.default_per_channel_weight_observer, - ) - lin_dynamic_q = torch.ao.nn.quantized.dynamic.Linear.from_float( - lin_fp32_copy.cpu() - ) - - y_ref = lin_fp32(x) - - # calibrate the smoothquant versions - y_smooth_nocalib = lin_smooth(x) - _ = lin_smooth_skip_scaling(x) - lin_smooth.to_inference() - lin_smooth_skip_scaling.debug_skip_scaling = True - lin_smooth_skip_scaling.to_inference() - - # verify that with scaling turned off, numerics match quantized version - y_smooth_fq_only = lin_smooth_skip_scaling(x) - y_smooth_fq = lin_smooth(x) - y_dynamic_q = lin_dynamic_q(x.cpu()).to(device) - - # print('y_ref', y_ref) - # print('y_smooth_nocalib', y_smooth_nocalib) - # print('y_smooth_fq', y_smooth_fq) - # print('y_smooth_fq_only', y_smooth_fq_only) - # print('y_dynamic_q', y_dynamic_q) - - sqnr_smooth_fq = compute_error(y_ref, y_smooth_fq) - sqnr_dynamic_q = compute_error(y_ref, y_dynamic_q) - sqnr_fq = compute_error(y_smooth_fq_only, y_dynamic_q) - # print('sqnr_smooth', sqnr_smooth_fq, 'sqnr_dynamic', sqnr_dynamic_q, 'sqnr_fq', sqnr_fq) - - assert torch.allclose(y_ref, y_smooth_nocalib), ( - "y_ref not close to y_smooth_nocalib" - ) - # after https://github.com/pytorch-labs/ao_benchmarks/pull/32, - # numerics do not match exactly between production c++ code - # and this Python code - # assert torch.allclose( - # y_smooth_fq_only, y_dynamic_q, - # atol=torch.max(y_smooth_fq_only).item()*0.01, - # rtol=0.00001), \ - # 'y_smooth_fq_only not close to y_dynamic_q' - - self.assertTrue(sqnr_smooth_fq.item() >= 40.0, f"got: {sqnr_smooth_fq.item()}") - self.assertTrue(sqnr_dynamic_q.item() >= 40.0, f"got: {sqnr_dynamic_q.item()}") - self.assertTrue(sqnr_fq.item() >= 40.0, f"got: {sqnr_fq.item()}") - - # Restore backend - torch.backends.quantized.engine = orig_backend - - def test_smooth_linear_cpu(self): - self._test_smooth_linear_impl((1, 5, 3), (3, 4), "cpu") - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_smooth_linear_cuda(self): - self._test_smooth_linear_impl((1, 32, 32), (32, 16), "cuda") - - def test_smooth_linear_edge_cases(self): - orig_backend = torch.backends.quantized.engine - # so we can use the full range - torch.backends.quantized.engine = "qnnpack" - lin_fp32 = nn.Linear(3, 4) - lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( - lin_fp32, alpha=0.25 - ) - - # test different ranks - x0 = torch.randn(4, 5, 3) - x1 = torch.randn(1, 8, 5, 3) - x2 = torch.randn(2, 3, 7, 5, 3) - - # calibrate - _ = lin_smooth(x0) - _ = lin_smooth(x1) - _ = lin_smooth(x2) - - # inference - lin_smooth.to_inference() - _ = lin_smooth(x0) - _ = lin_smooth(x1) - _ = lin_smooth(x2) - - # Restore backend - torch.backends.quantized.engine = orig_backend - - def test_swap(self): - m = nn.Sequential( - nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)), - nn.Linear(4, 4), - ) - m_copy = copy.deepcopy(m) - swap_linear_with_smooth_fq_linear(m_copy, skip_fqn_list=["0.2"]) - - # verify all linears are swapped - assert isinstance(m_copy[0][0], SmoothFakeDynamicallyQuantizedLinear) - assert isinstance(m_copy[0][1], nn.ReLU) - # this one was skipped - assert isinstance(m_copy[0][2], nn.Linear) - assert isinstance(m_copy[1], SmoothFakeDynamicallyQuantizedLinear) - - # verify results do not change without smoothing - x = torch.randn(4, 4) - y_ref = m(x) - y = m_copy(x) - assert torch.allclose(y_ref, y) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_weight_t_and_non_t_numerics_match(self): - # verify that numerics match whether weight is stored - # in transposed format (for cuBLAS) vs non-transposed format - # (for torch.compile) - dtype = torch.half - device = "cuda" - lin_ref = nn.Linear(32, 16, dtype=dtype, device=device) - lin_eager_t = copy.deepcopy(lin_ref) - lin_opt_t = copy.deepcopy(lin_eager_t) - lin_opt = copy.deepcopy(lin_eager_t) - lin_eager_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_eager_t) - lin_opt_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt_t) - lin_opt = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt) - lin_opt.store_w_int_repr_t = False - - x = torch.randn(32, 32, dtype=dtype, device=device) - - y_calib_eager_t = lin_eager_t(x) - y_calib_opt_t = lin_opt_t(x) - y_calib_opt = lin_opt(x) - torch.testing.assert_close(y_calib_eager_t, y_calib_opt_t) - torch.testing.assert_close(y_calib_eager_t, y_calib_opt) - - lin_eager_t.to_inference() - lin_opt_t.to_inference() - lin_opt.to_inference() - - torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt_t.W_int_repr) - torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt.W_int_repr) - - lin_opt_t = torch.compile(lin_opt_t, mode="max-autotune") - lin_opt = torch.compile(lin_opt, mode="max-autotune") - - y_ref = lin_ref(x) - y_eager = lin_eager_t(x) - y_opt_t = lin_opt_t(x) - y_opt = lin_opt(x) - - if not torch.any(torch.isinf(y_ref)) and torch.any(torch.isinf(y_eager)): - # eager mode torch._int_mm is sometimes buggy, when this happens - # we can't really compare the compiled version against it properly - print("eager mode torch._int_mm known bad, test is inconclusive") - return - - sqnr_eager_opt_t = compute_error(y_eager, y_opt_t) - sqnr_eager_opt = compute_error(y_eager, y_opt) - # since torch.compile for a torch.half model can - # change numerics significantly, we can only test for a high SQNR here - # and not for closeness - self.assertTrue(sqnr_eager_opt_t >= 45.0) - self.assertTrue(sqnr_eager_opt >= 45.0) - # y_opt_t and y_opt should be equivalent - torch.testing.assert_close(y_opt_t, y_opt) - - def test_selective_torch_compile(self): - m = nn.Sequential( - nn.Linear(4, 4), - nn.Sequential( - nn.Linear(4, 4), - nn.Linear(4, 4), - ), - nn.Linear(4, 4), - ) - x = torch.randn(4, 4) - y_ref = m(x) - - _replace_with_custom_fn_if_matches_filter( - m, - lambda mod: torch.compile(mod), - lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "1.0", - ) - - self.assertTrue(isinstance(m[0], torch._dynamo.eval_frame.OptimizedModule)) - self.assertTrue(isinstance(m[1][0], nn.Linear)) - self.assertTrue(isinstance(m[1][1], torch._dynamo.eval_frame.OptimizedModule)) - self.assertTrue(isinstance(m[2], torch._dynamo.eval_frame.OptimizedModule)) - - y = m(x) - torch.testing.assert_close(y, y_ref) - - def test_debug_x_absmax(self): - m = nn.Sequential(nn.Linear(3, 4)) - x0 = torch.randn(4, 5, 3) - m(x0) - swap_linear_with_smooth_fq_linear(m) - # no calibration, straight to inference, should not crash - smooth_fq_linear_to_inference(m, debug_skip_calibration=True) - m(x0) - - class PythonQuantUtilOpUnitTest(unittest.TestCase): def _test_dynamic_quant_per_channel_numerics_impl( self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device @@ -476,9 +208,27 @@ def _test_dynamic_quant_per_channel_numerics_impl( # torch.aminmax support half on cpu x = torch.randn(16, 32, device=device, dtype=float_dtype) - y_vals, y_scale, y_zero_point = dynamically_quantize_per_channel( - x, qmin, qmax, int_dtype + + eps = torch.finfo(torch.float32).eps + block_size = (1, x.shape[1]) + zero_point_dtype = torch.int64 + + mapping_type = MappingType.SYMMETRIC + scale, zero_point = choose_qparams_affine( + x, + mapping_type, + block_size, + target_dtype=int_dtype, + quant_min=qmin, + quant_max=qmax, + eps=eps, + zero_point_dtype=zero_point_dtype, ) + y_vals = quantize_affine( + x, block_size, scale, zero_point, int_dtype, qmin, qmax + ) + y_scale = scale + y_zero_point = zero_point min_val, max_val = torch.aminmax(x, dim=1) @@ -561,30 +311,6 @@ def test_quantize_per_token_xpu(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_quantize_per_token_impl("xpu", dtype) - def _test_per_token_linear_impl(self, device, dtype): - x = torch.randn(2, 16, 8, device=device, dtype=dtype) - w = torch.randn(16, 8, device=device, dtype=dtype) - wq, w_scales, _w_zp = dynamically_quantize_per_channel(w, -127, 127, torch.int8) - # Note: need to make the weight contiguous because we are - # testing in eager mode and cuBlas will not give correct results - # for a transposed weight - y = _quant_int8_dynamic_per_token_linear( - x, wq.t().contiguous(), w_scales, None, dtype - ) - y_ref = torch.matmul(x, w.t()) - sqnr = compute_error(y_ref, y) - self.assertTrue(sqnr >= 42.0) - - def test_per_token_linear_cpu(self): - for dtype in (torch.float32,): - self._test_per_token_linear_impl("cpu", dtype) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @skip_if_rocm("ROCm enablement in progress") - def test_per_token_linear_cuda(self): - for dtype in (torch.float32, torch.float16, torch.bfloat16): - self._test_per_token_linear_impl("cuda", dtype) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test__int_mm(self): # TODO(future): figure out what here needs to move to PT core, @@ -681,62 +407,6 @@ def _test_dequantize_impl( f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_dequantize_int8_dynamic_quant_subclass(self, device, dtype): - self._test_dequantize_impl( - Int8DynamicallyQuantizedLinearWeight.from_float, - device, - 35, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): - self._test_dequantize_impl( - Int8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - if dtype != torch.bfloat16: - self.skipTest("Currently only supports bfloat16.") - for test_shape in [(16, 1024, 16)] + ( - [(1, 1024, 8)] if device == "cuda" else [] - ): - self._test_dequantize_impl( - Int4WeightOnlyQuantizedLinearWeight.from_float, - device, - 15, - test_shape=test_shape, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - if dtype != torch.bfloat16: - self.skipTest("Currently only supports bfloat16.") - m_shapes = [16, 256] + ([1] if device == "cuda" else []) - n_shapes = [16] + ([8, 13] if device == "cuda" else []) - for groupsize in [256, 128]: - for inner_k_tiles in [8, 4, 2]: - for m in m_shapes: - for n in n_shapes: - self._test_dequantize_impl( - lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float( - w, groupsize, inner_k_tiles - ), - device, - 15, - test_shape=[m, 256, n], - test_dtype=dtype, - ) - @run_supported_device_dtype def _test_lin_weight_subclass_impl( self, @@ -771,22 +441,6 @@ def _test_lin_weight_subclass_impl( f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_dynamic_quant_subclass(self, device, dtype): - self._test_lin_weight_subclass_impl( - Int8DynamicallyQuantizedLinearWeight.from_float, - device, - 35, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_weight_only_quant_subclass(self, device, dtype): - undo_recommended_configs() - self._test_lin_weight_subclass_impl( - Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype - ) - @parameterized.expand(COMMON_DEVICE_DTYPE) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -891,46 +545,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype test_dtype=dtype, ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - def test_int4_weight_only_quant_subclass(self, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - if dtype != torch.bfloat16: - self.skipTest(f"Fails for {dtype}") - for test_shape in [(16, 1024, 16)] + ( - [(1, 1024, 8)] if device == "cuda" else [] - ): - self._test_lin_weight_subclass_impl( - Int4WeightOnlyQuantizedLinearWeight.from_float, - device, - 10, - test_shape=test_shape, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - @unittest.skip("Skip to fix CI until we deprecate these APIs long term") - def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): - if dtype != torch.bfloat16: - self.skipTest(f"Fails for {dtype}") - m_shapes = [16, 256] + ([1] if device == "cuda" else []) - n_shapes = [16] + ([8, 13] if device == "cuda" else []) - for groupsize in [128, 64]: - for inner_k_tiles in [8, 4, 2]: - for m in m_shapes: - for n in n_shapes: - self._test_lin_weight_subclass_impl( - lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float( - w, groupsize, inner_k_tiles - ), - device, - 10, - test_shape=[m, 256, n], - test_dtype=dtype, - ) - @torch.no_grad() @run_supported_device_dtype def _test_lin_weight_subclass_api_impl( @@ -1120,7 +734,6 @@ def test_dynamic_quant(self): sqnr = compute_error(y_ref, y_test) self.assertGreater(sqnr, 40.0) - # self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear)) class TestWeightOnlyInt8Quant(unittest.TestCase): @@ -1324,30 +937,6 @@ def test_save_load_int4woqtensors(self, device, dtype): self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype) -class TorchCompileUnitTest(unittest.TestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_fullgraph(self): - lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) - lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( - lin_fp16, alpha=0.25 - ) - - x0 = torch.randn(17, 1, 32, device="cuda", dtype=torch.float16) - - # calibrate - _ = lin_smooth(x0) - - # inference - lin_smooth.to_inference() - - # torch.compile - lin_smooth_opt = torch.compile(lin_smooth, fullgraph=True) - # print(lin_smooth_opt) - - lin_smooth_opt(x0) - # print(y) - - class UtilsUnitTest(unittest.TestCase): def test_shape_logger(self): x = torch.randn(4, 4) @@ -1371,88 +960,6 @@ def test_shape_logger(self): pass -class SmoothquantIntegrationTest(unittest.TestCase): - @torch.no_grad() - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skip("Seg fault?") - def test_non_dynamically_quantizable_linear(self): - if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): - self.skipTest("test requires SM capability of at least (8, 0).") - model = ( - torch.nn.Sequential( - torch.nn.modules.linear.NonDynamicallyQuantizableLinear(32, 32), - torch.nn.ReLU(), - ) - .to("cuda") - .to(torch.bfloat16) - ) - example_input = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16) - ref = model(example_input) - swap_linear_with_smooth_fq_linear(model) - model(ref) - smooth_fq_linear_to_inference(model) - model_c = torch.compile(model, mode="max-autotune") - out = model_c(example_input) - sqnr = SQNR(ref, out) - self.assertTrue(sqnr >= 25) - self.assertTrue(isinstance(model[0], SmoothFakeDynamicallyQuantizedLinear)) - - @torch.inference_mode() - @unittest.skipIf(is_fbcode(), "can't load tokenizer") - def test_on_dummy_distilbert(self): - # https://huggingface.co/distilbert-base-uncased#how-to-use - from transformers import ( # type: ignore[import-untyped] - DistilBertModel, - DistilBertTokenizer, - ) - - tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") - model = DistilBertModel.from_pretrained("distilbert-base-uncased") - # print(model) - text = "Replace me by any text you'd like." - encoded_input = tokenizer(text, return_tensors="pt") - output_ref = model(**encoded_input) - # print(output_ref) - - # - # smooth_quant - # - model_copy = copy.deepcopy(model) - swap_linear_with_smooth_fq_linear(model_copy, alpha=0.75) - # calibrate - model_copy(**encoded_input) - # inference - smooth_fq_linear_to_inference(model_copy) - output_1_2 = model_copy(**encoded_input) - # print(output_1_1) - # print(output_1_2) - sqnr_sq = compute_error( - output_ref.last_hidden_state, output_1_2.last_hidden_state - ) - print("sqnr_sq", sqnr_sq) - self.assertTrue(sqnr_sq >= 20.0) - - # - # reference - dynamic linear quant - # - model_copy2 = copy.deepcopy(model) - qconfig = torch.ao.quantization.QConfig( - activation=None, - weight=torch.ao.quantization.default_per_channel_weight_observer, - ) - model_copy2 = torch.ao.quantization.quantize_dynamic( - model_copy2, - {torch.nn.Linear: qconfig}, - ) - output_2_2 = model_copy2(**encoded_input) - # print(output_2_2) - sqnr_pt_quant = compute_error( - output_ref.last_hidden_state, output_2_2.last_hidden_state - ) - print("sqnr_pt_quant", sqnr_pt_quant) - self.assertTrue(sqnr_sq >= 8.0) - - class TestAutoQuant(unittest.TestCase): @parameterized.expand( combine_parameters( diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b5ea7bf09a..13d1eaad08 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -60,10 +60,6 @@ from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import ( IntxUnpackedToInt8Tensor, ) -from torchao.quantization.subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, -) from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_rocm from torchao.utils import ( @@ -168,14 +164,6 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -) -_ref_change_linear_weights_to_int4_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) -) - - class TestQuantFlow(TestCase): GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + ( ["xpu"] if torch.xpu.is_available() else [] @@ -447,54 +435,6 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") - def test_quantized_tensor_subclass_int4(self): - for device in self.GPU_DEVICES: - # use 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to(device) - m_copy = copy.deepcopy(m) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) - - group_size = 32 - if device == "xpu": - quantize_( - m, - Int4WeightOnlyConfig( - group_size=group_size, layout=Int4XPULayout(), version=1 - ), - ) - else: - quantize_(m, Int4WeightOnlyConfig(group_size=group_size, version=1)) - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - assert isinstance(m.linear2.weight, AffineQuantizedTensor) - - # reference - _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_quantized_tensor_subclass_int8_wo(self): - m = ToyLinearModel().eval().to(torch.bfloat16) - m_copy = copy.deepcopy(m) - example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - - quantize_(m, Int8WeightOnlyConfig()) - - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - assert isinstance(m.linear2.weight, AffineQuantizedTensor) - - # reference - _ref_change_linear_weights_to_int8_woqtensors(m_copy) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load(self): m = ToyLinearModel().eval().to(torch.bfloat16) diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index 1240bbacd0..2edbe33774 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -40,9 +40,7 @@ MappingType, ZeroPointDomain, ) -from torchao.quantization.subclass import ( # noqa - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, +from torchao.quantization.subclass import ( QuantizedLinearWeightBase, ) from torchao.quantization.utils import _quantize_activation_per_token_absmax diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..bdb1d90c04 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -100,21 +100,11 @@ IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) -from .smoothquant import ( - SmoothFakeDynamicallyQuantizedLinear, - SmoothFakeDynQuantMixin, - get_scale, - set_smooth_fq_attribute, - smooth_fq_linear_to_inference, - swap_linear_with_smooth_fq_linear, -) -from .subclass import * # noqa: F403 from .transform_module import register_quantize_module_handler from .unified import Quantizer, TwoStepQuantizer from .utils import ( compute_error, ) -from .weight_only import WeightOnlyInt8QuantLinear # TODO: remove after migration of APIs are done AOPerModuleConfig = ModuleFqnToConfig @@ -172,13 +162,6 @@ "Int4TilePackedTo4dTensor", "Float8Tensor", "Int4OpaqueTensor", - # smooth quant - subject to change - "get_scale", - "SmoothFakeDynQuantMixin", - "SmoothFakeDynamicallyQuantizedLinear", - "swap_linear_with_smooth_fq_linear", - "smooth_fq_linear_to_inference", - "set_smooth_fq_attribute", "compute_error", # building blocks "to_linear_activation_quantized", @@ -210,7 +193,6 @@ "Int4WeightOnlyQuantizer", "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightLinear", - "WeightOnlyInt8QuantLinear", "TwoStepQuantizer", "Quantizer", # Layouts for quant_api diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index eb19a00923..be9d546c66 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -41,11 +41,6 @@ PerRow, PerTensor, ) -from .subclass import ( # noqa - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, - QuantizedLinearWeightBase, -) __all__ = [ "AutoQuantizableLinearWeight", diff --git a/torchao/quantization/dynamic_quant.py b/torchao/quantization/dynamic_quant.py deleted file mode 100644 index 5c6ee9c8f9..0000000000 --- a/torchao/quantization/dynamic_quant.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from .utils import ( - _quant_int8_dynamic_per_token_linear, - dynamically_quantize_per_channel, -) - -__all__ = ["DynamicallyPerAxisQuantizedLinear"] - - -class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear): - """ - This class is a replacement for `torch.nn.Linear`. It implements a - quantized matmul using int8 dynamic symmetric per-token activation, - and int8 symmetric per-channel weight quantization - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - ) -> None: - super().__init__(in_features, out_features, bias) - - def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Performs the forward pass of the quantized linear layer which consists - of int8 dynamic symmetric per-token activation and int8 symmetric per-channel weight - quantization - - Args: - X (torch.Tensor): The input floating point tensor to the quantized linear layer. - - Returns: - torch.Tensor: The output floating point tensor after the quantized matmul and rescale. - - """ - - Y = _quant_int8_dynamic_per_token_linear( - X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype - ) - return Y - - @classmethod - def from_float(cls, mod: torch.nn.Linear) -> "DynamicallyPerAxisQuantizedLinear": - """ - Converts a `mod` of class `torch.nn.Linear` to the - `DynamicallyPerAxisQuantizedLinear` class - - Args: - mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert. - - Returns: - DynamicallyPerAxisQuantizedLinear: The converted quantized linear module. - - """ - - # create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, - fake_out_features, - bias=mod.bias is not None, - ) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel( - mod.weight, -128, 127, torch.int8 - ) - new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t()) - new_mod.W_scales = nn.Parameter(W_scales) - new_mod.bias = mod.bias - del new_mod.weight - - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 15caddcadc..44686c5171 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -126,9 +126,6 @@ ZeroPointDomain, quantize_affine, ) -from .subclass import ( - QuantizedLinearWeightBase, -) from .unified import Quantizer, TwoStepQuantizer from .utils import _get_per_token_block_size @@ -285,7 +282,6 @@ def _is_linear(mod, *args): return ( isinstance(mod, torch.nn.Linear) and hasattr(mod, "weight") - and not isinstance(mod.weight, QuantizedLinearWeightBase) and not isinstance(mod.weight, AutoQuantizableLinearWeight) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) diff --git a/torchao/quantization/smoothquant.py b/torchao/quantization/smoothquant.py deleted file mode 100644 index 3420f3c8b2..0000000000 --- a/torchao/quantization/smoothquant.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Testing out accuracy-only implementation of SmoothQuant -(https://arxiv.org/pdf/2211.10438.pdf) -Note: this is an application of input-weight equalization, with the addition that the -multiplication by scale is fused into the preceding layer, specifically for relevant -parts of transformer blocks. -""" - -import torch -import torch.nn.functional as F - -from .utils import ( - _quant_int8_dynamic_per_token_linear, - dynamically_quantize_per_channel, -) - -__all__ = [ - "get_scale", - "SmoothFakeDynQuantMixin", - "SmoothFakeDynamicallyQuantizedLinear", - "swap_linear_with_smooth_fq_linear", - "smooth_fq_linear_to_inference", - "set_smooth_fq_attribute", -] - - -def get_scale(X_absmax, W_absmax, alpha=0.5): - """ - Calculate the scale based on abs(max(X)), abs(max(W)), and alpha. - - Args: - X_absmax (torch.Tensor): Absolute maximum values of the input tensor X. - W_absmax (torch.Tensor): Absolute maximum values of the weight tensor W. - alpha (float, optional): Scaling factor. Defaults to 0.5. - - Returns: - torch.Tensor: The calculated scale of dimension `k` if X is of dimension `b*n*k` and W is of dimension `k*m`. - """ - X_pow = torch.pow(X_absmax, alpha) - W_pow = torch.pow(W_absmax, 1.0 - alpha) - div = X_pow / W_pow - return div.reshape(-1) - - -class SmoothFakeDynQuantMixin(torch.nn.Module): - def init_smoothquant_variables(self, alpha): - self.calibrating = True - self.x_running_abs_max = None - self.register_buffer("smooth_scale", None) - self.alpha = alpha - # debug only - self.debug_skip_scaling = False - # self.debug_skip_scaling = True - - # Currently torch._int_mm cuBLAS underlying kernel does not work with - # non-contiguous weight. However, torch.compil'ing through - # torch._int_mm leads to triton code which is ~2x faster if the weight - # is transposed. So, for now we have a debug flag to toggle whether - # we store the quantized weight transposed, so that we can get correct - # numerics both in eager mode and after torch.compile. - # The default is True for cuBLAS / eager mode, set to False for - # torch.compile. - # self.store_w_int_repr_t = True - self.store_w_int_repr_t = False - - def update_x_running_abs_max(self, X): - # update the running max of incoming activations - all_dims_except_last = tuple(range(len(X.shape) - 1)) - cur_abs_max = torch.amax(torch.abs(X), dim=all_dims_except_last) - if self.x_running_abs_max is None: - self.x_running_abs_max = cur_abs_max - else: - self.x_running_abs_max = torch.max(cur_abs_max, self.x_running_abs_max) - - def get_scaled_quantized_w(self): - # inference - assert self.smooth_scale is not None, ( - "self.smooth_scale is None, did you turn on inference?" - ) - W = self.weight - - # scale weight - # in the future, this can be done ahead of time instead of - # during inference - if not self.debug_skip_scaling: - # TODO(future): do below in `to_inference` instead of here - W = torch.matmul( - torch.diag(self.smooth_scale), W.transpose(0, 1) - ).transpose(0, 1) - - # fake quantize input and weight, and then do matmul in fp32/fp16 - # in the future, this should be replaced with quantized kernels which - # work on NVIDIA GPUs (such as protoquant's implementation) - W_int_repr, W_scales, W_zps = dynamically_quantize_per_channel( - W, -128, 127, torch.int8 - ) - W_int_repr = W_int_repr.contiguous() - return W_int_repr, W_scales, W_zps - - def to_inference(self): - raise NotImplementedError() - - def fold_weight(self): - # note: _W_zps are zeroes and they are ignored - # TODO(future PR): set up serialization for this - W_int_repr, self.W_scales, _W_zps = self.get_scaled_quantized_w() - # need to store transposed weights to make eager mode matmul - # op work in cuBlas, or non-transposed to make it fast in torch.compile - if self.store_w_int_repr_t: - self.register_buffer("W_int_repr", W_int_repr.transpose(0, 1).contiguous()) - else: - self.register_buffer("W_int_repr", W_int_repr.contiguous()) - del self.weight - - def set_debug_x_absmax(self): - """ - Sets `self.x_running_abs_max` to a value which will lead to smooth scale - of all ones if `alpha=0.5`, to enable performance benchmarking without - calibration. - """ - raise NotImplementedError() - - -class SmoothFakeDynamicallyQuantizedLinear(SmoothFakeDynQuantMixin, torch.nn.Linear): - """ - This is a replacement for `torch.nn.Linear` which implements dynamic per-token - activation quantization and dynamic per-channel weight quantization based on - Smoothquant scaling. - """ - - def __init__(self, *args, **kwargs): - alpha = kwargs.pop("alpha") - super().__init__(*args, **kwargs) - self.init_smoothquant_variables(alpha) - - def forward(self, X, *args, **kwargs): - if self.calibrating: - self.update_x_running_abs_max(X) - Y = F.linear(X, self.weight, self.bias) - else: - if not self.debug_skip_scaling: - # Ideally this would be fused into preceding layers - # but in practice torch.compile fuses it with other - # ops so the slowdown is minimal - X = X / self.smooth_scale - W_int_repr_t = ( - self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t() - ) - Y = _quant_int8_dynamic_per_token_linear( - X, W_int_repr_t, self.W_scales, self.bias, X.dtype - ) - return Y - - @classmethod - def from_float(cls, mod, alpha=0.5): - """ - Converts a `mod` of class `torch.nn.Linear` to the smooth fake quantized - version of it. Note: requires calibration. - """ - # create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, fake_out_features, bias=mod.bias is not None, alpha=alpha - ) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - new_mod.weight = mod.weight - new_mod.bias = mod.bias - # TODO: test when creation is on cuda - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod - - def to_inference(self): - """ - Calculates the smoothquant scale based on calibration - in preparation for inference - """ - assert self.x_running_abs_max is not None, "no calibration data found" - self.calibrating = False - self.smooth_scale = get_scale( - self.x_running_abs_max, - torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values, - alpha=self.alpha, - ) - self.fold_weight() - - def set_debug_x_absmax(self): - w_absmax = torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values - self.x_running_abs_max = w_absmax - - -# -# utils to use the smooth linear on real models -# - -source_cls_to_target_cls = { - torch.nn.Linear: SmoothFakeDynamicallyQuantizedLinear, - torch.nn.modules.linear.NonDynamicallyQuantizableLinear: SmoothFakeDynamicallyQuantizedLinear, -} - - -def swap_linear_with_smooth_fq_linear( - model, skip_fqn_list=None, cur_fqn="", alpha=0.5 -) -> None: - """ - Replaces linear layers in the model with their SmoothFakeDynamicallyQuantizedLinear equivalents. - - Args: - model (torch.nn.Module): The model containing linear layers to be replaced. - skip_fqn_list (list of str, optional): List of fully qualified names to skip during replacement. Defaults to None. - cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". - alpha (float, optional): The scaling factor for SmoothQuant. Defaults to 0.5. - - Returns: - None - """ - - name_to_child = dict(model.named_children()) - for name, child in name_to_child.items(): - if cur_fqn == "": - new_fqn = name - else: - new_fqn = f"{cur_fqn}.{name}" - if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and ( - type(child) in source_cls_to_target_cls.keys() - ): - target_cls = source_cls_to_target_cls[type(child)] - new_child = target_cls.from_float(child, alpha=alpha) - setattr(model, name, new_child) - else: - swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha) - - -def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None: - """ - Prepares the model for inference by calculating the smoothquant scale for each SmoothFakeDynamicallyQuantizedLinear layer. - - Args: - model (torch.nn.Module): The model containing SmoothFakeDynamicallyQuantizedLinear layers. - debug_skip_calibration (bool, optional): If True, sets the running maximum of activations to a debug value for performance benchmarking. - Defaults to False. - - Returns: - None - """ - for _, mod in model.named_modules(): - if isinstance(mod, tuple(source_cls_to_target_cls.values())): - if debug_skip_calibration: - mod.set_debug_x_absmax() - mod.to_inference() - - -# useful for quickly toggling smoothquant debug settings on all smoothquant -# modules in a model -def set_smooth_fq_attribute(model, attribute_name, new_attribute_val): - for _, mod in model.named_modules(): - if isinstance(mod, tuple(source_cls_to_target_cls.values())): - if hasattr(mod, attribute_name): - setattr(mod, attribute_name, new_attribute_val) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py deleted file mode 100644 index caffef7b58..0000000000 --- a/torchao/quantization/subclass.py +++ /dev/null @@ -1,702 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.quantization.utils import ( - _quant_int8_dynamic_per_token_linear, - dequantize_per_channel, - dynamically_quantize_per_channel, - groupwise_affine_quantize_tensor, - unpack_tinygemm_scales_and_zeros, -) -from torchao.utils import ( - check_cpu_version, - check_xpu_version, - find_multiple, -) - -from .quant_primitives import ( - ZeroPointDomain, -) - -__all__ = [ - "Int8DynamicallyQuantizedLinearWeight", - "Int8WeightOnlyQuantizedLinearWeight", - "Int4WeightOnlyQuantizedLinearWeight", -] - - -aten = torch.ops.aten - - -class QuantizedLinearWeightBase(torch.Tensor): - """ - Base quantized tensor subclass for quantized linear weights. When the from_float method is used, - to create an instance of any QuantizedLinearWeightBase, we assume the input - weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels. - - The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, - regardless of the internal representation's type or orientation. - """ - - @staticmethod - def __new__(cls, int_data, transposed, shape, *args, **kwargs): - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - assert "dtype" in kwargs - assert not kwargs.get("requires_grad", False) - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, transposed, *args, **kwargs): - self.int_data = int_data - - self.transposed = transposed - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - pass - - def __repr__(self): - return ( - f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " - f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def dequantize(self): - pass - - def int_repr(self): - pass - - def q_params(self): - pass - - def half(self): - return self.to(torch.float16) - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - - def _apply_fn_to_data(self, fn): - pass - - def _change_shape(self): - pass - - def __tensor_flatten__(self): - pass - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - pass - - @classmethod - def from_float(cls, input_float): - pass - - # __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func is torch.nn.functional.linear: - mat1, w_qtensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - assert not w_qtensor.transposed - return cls._quantized_op(mat1, w_qtensor, bias) - - try: - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - except Exception: - print(f"ERR: subclass doesn't implement {func}") - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - # two scenarios where we currently fall back to vanilla mm: - # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation - # for consistency and to allow people to test - # 2 - we're given non-floats - quantizing long to int8 is crazy - if ( - func in [aten.mm.default, aten.addmm.default] - and args[0].is_floating_point() - and args[0].is_cuda - ): - if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - mat1, w_qtensor, bias = ( - args[1], - args[2], - args[0], - ) - else: - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - mat1, w_qtensor, bias = ( - args[0], - args[1], - None if len(args) == 2 else args[2], - ) - # call the quantized op for the specific type - # of quantized tensor subclass - return cls._quantized_op(mat1, w_qtensor, bias) - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.t.default: - args[0].transposed = not args[0].transposed - new = args[0]._change_shape(args[0].shape[::-1]) - return return_and_correct_aliasing(func, args, kwargs, new) - - if func is aten._to_copy.default: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - - -class ConstructTensorSubclass(torch.nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - self.args = args - self.kwargs = kwargs - - def forward(self, x): - pass - - def right_inverse(self, tensor_subclass_instance): - fields, _ = tensor_subclass_instance.__tensor_flatten__() - return [getattr(tensor_subclass_instance, field) for field in fields] - - -@torch._dynamo.allow_in_graph -def from_qtensor_components_int8dyn(*args, **kwargs): - return Int8DynamicallyQuantizedLinearWeight(*args, **kwargs) - - -class ConstructTensorSubclassInt8Dyn(ConstructTensorSubclass): - def forward(self, int_data, q_scales): - return from_qtensor_components_int8dyn( - int_data, q_scales, *self.args, **self.kwargs - ) - - -class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase): - """ - A Tensor subclass that when applied to a weight used in a linear op/module, changes the - linear op to a dynamically quantized linear op with symmetric per-token and per-channel - quantization on the activation and weight respectively. - """ - - subclass_constructor = ConstructTensorSubclassInt8Dyn - - @staticmethod - def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs): - if dtype is None: - dtype = q_scales.dtype - kwargs["dtype"] = dtype - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs): - self.q_scales = q_scales - super().__init__(int_data, transposed) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return _quant_int8_dynamic_per_token_linear( - act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype - ) - - def dequantize(self, dtype=None): - """ - Obtain the dequantized version of the quantized tensor subclass - """ - zero_points = torch.zeros( - self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype - ) - # zero_points = 0 - # TODO: fix dtype here? `to(self.dtype)` is not overwritten by `dtype` arg? - dq_t = dequantize_per_channel( - self.int_data.t(), - self.q_scales, - zero_points, - self.dtype if dtype is None else dtype, - ).to(self.dtype) - # data was transposed to dequantize so make sure shape is correct - return dq_t if not self.transposed else dq_t.t() - - def int_repr(self): - """ - Get the internal integer representation of the quantized tensor - """ - return self.int_data if self.transposed else self.int_data.t() - - def q_params(self): - """ - Get the quantization scales for the quantized tensor - """ - return {"q_scales": self.q_scales} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.q_scales.to(kwargs["device"]), - self.transposed, - self.shape, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.q_scales), - self.transposed, - self.shape, - dtype=self.dtype, - ) - - # `QuantizedLinearWeightBase` inconsistently. - - def _change_shape(self, shape): - return self.__class__( - self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype - ) - - def __tensor_flatten__(self): - # note: the order of args must match the order of args in __init__ - return ["int_data", "q_scales"], [self.transposed, self.shape, self.dtype] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None - ): - int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] - transposed, shape, dtype = tensor_attributes - return cls( - int_data, - q_scales, - transposed, - shape if outer_size is None else outer_size, - dtype=dtype, - strides=outer_stride, - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127, dtype=None): - """ - Method used to convert a linear weight tensor to an instance of the - Int8DynamicallyQuantizedLinearWeight subclass. - - Example usage:: - - model.lin_mod.weight = ( - Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) - ) - """ - if dtype is None: - dtype = input_float.dtype - - # because we call transpose in dequantization - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - # the desired representation shape for fast quantized matmul is - # transposed compared to how it's stored as a linear weight, - # i.e. we want in_channels as dim=0 and out_channels (and quantized axis) as dim=1 - # however the external representation of our tensor will maintain the correct - # shape attribute which needs to be tracked directly. - int_data = w_int_repr.contiguous().t() - if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight): - int_data = int_data.contiguous() - return cls( - int_data, - w_scales, - False, - input_float.shape, - dtype=dtype, - ) - - -@torch._dynamo.allow_in_graph -def from_qtensor_components_int8wo(*args, **kwargs): - return Int8WeightOnlyQuantizedLinearWeight(*args, **kwargs) - - -class ConstructTensorSubclassInt8wo(ConstructTensorSubclass): - def forward(self, int_data, q_scales): - return from_qtensor_components_int8wo( - int_data, q_scales, *self.args, **self.kwargs - ) - - -class Int8WeightOnlyQuantizedLinearWeight(Int8DynamicallyQuantizedLinearWeight): - """ - A Tensor subclass that when applied to a weight used in a linear op/module, - changes the linear op to a weight-only quantized linear op with symmetric - per-channel quantization on the weight. - """ - - subclass_constructor = ConstructTensorSubclassInt8wo - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - orig_dtype = act_mat.dtype - y = ( - torch.mm( - act_mat.reshape(-1, act_mat.shape[-1]), - w_qtensor.int_data.to(act_mat.dtype), - ) - * w_qtensor.q_scales - ) - y = y.reshape(*act_mat.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias - return y.to(orig_dtype) - - -@torch._dynamo.allow_in_graph -def from_qtensor_components_int4wo(*args, **kwargs): - return Int4WeightOnlyQuantizedLinearWeight(*args, **kwargs) - - -class ConstructTensorSubclassInt4wo(ConstructTensorSubclass): - def forward(self, int_data, scales_and_zeros): - return from_qtensor_components_int4wo( - int_data, scales_and_zeros, *self.args, **self.kwargs - ) - - -class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): - """ - A Tensor subclass that when applied to a weight used in a linear op/module, - changes that linear op to a weight-only int4 quantized linear op with groupwise - affine quantization on the weight. - """ - - subclass_constructor = ConstructTensorSubclassInt4wo - - @staticmethod - def __new__( - cls, - int_data, - scales_and_zeros, - transposed, - shape, - groupsize=128, - inner_k_tiles=8, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, - dtype=None, - **kwargs, - ): - if dtype is None: - dtype = scales_and_zeros.dtype - kwargs["dtype"] = dtype - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data, - scales_and_zeros, - transposed, - shape, - groupsize, - inner_k_tiles, - zero_point_domain, - preserve_zero, - dtype, - **kwargs, - ): - # the transposed flag tracks whether the tensor subclass has been transposed relative - # to how a weight is normally stored in a linear i.e. [out_features, in_features]. - # tracking both transposed and shape is slightly redundant but corner cases like - # square matrices can cause issues otherwise - - self.scales_and_zeros = scales_and_zeros - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.zero_point_domain = zero_point_domain - self.preserve_zero = preserve_zero - super().__init__(int_data, transposed) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape and pad activation - act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) - pad_size = find_multiple(act_mat.shape[-1], 1024) - act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) - - # matmul - if check_cpu_version(act_mat.device): - y = aten._weight_int4pack_mm_for_cpu( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) - elif check_xpu_version(act_mat.device): - if not w_qtensor.zero_point_domain == ZeroPointDomain.INT: - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) - else: - y = aten._weight_int4pack_mm_with_scales_and_zeros( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros[0], - w_qtensor.scales_and_zeros[1], - ) - else: - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) - - # remove out_feature padding - orig_out_features = ( - w_qtensor.shape[-1] if w_qtensor.transposed else w_qtensor.shape[-2] - ) - y = y[:, :orig_out_features] - - y = y.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: - y += bias - return y.to(orig_dtype) - - def dequantize(self): - eye_shape = self.shape[1] if not self.transposed else self.shape[0] - w_dq = self._quantized_op( - torch.eye(eye_shape, device=self.device, dtype=self.dtype), self, None - ) - # we dequantized using linear with the identity matrix, output has shape [in_channels, out_channels] - # so we need to transpose back to get the original shape unless self.transposed is set. - w_dq = w_dq if self.transposed else w_dq.t() - return w_dq.to(self.dtype) - - def int_repr(self): - return self.int_data - - def q_params(self): - scales, zero_points = unpack_tinygemm_scales_and_zeros( - self.scales_and_zeros, - ) - return {"q_scales": scales, "q_zero_points": zero_points} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.scales_and_zeros.to(kwargs["device"]), - self.transposed, - self.shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.scales_and_zeros), - self.transposed, - self.shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - dtype=self.dtype, - ) - - # `QuantizedLinearWeightBase` inconsistently. - - def _change_shape(self, shape): - return self.__class__( - self.int_data, - self.scales_and_zeros, - self.transposed, - shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - dtype=self.dtype, - ) - - def __tensor_flatten__(self): - return ["int_data", "scales_and_zeros"], ( - self.transposed, - self.shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - self.dtype, - ) - - @classmethod - - # `QuantizedLinearWeightBase` inconsistently. - - def __tensor_unflatten__( - cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None - ): - int_data, scales_and_zeros = ( - tensor_data_dict["int_data"], - tensor_data_dict["scales_and_zeros"], - ) - ( - transposed, - shape, - groupsize, - inner_k_tiles, - zero_point_domain, - preserve_zero, - dtype, - ) = attributes - return cls( - int_data, - scales_and_zeros, - transposed, - shape if outer_size is None else outer_size, - groupsize, - inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - dtype=dtype, - strides=outer_stride, - ) - - @classmethod - def from_float( - cls, - input_float, - groupsize=128, - inner_k_tiles=8, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, - dtype=None, - ): - """ - Method used to convert a linear weight tensor to an instance of the - Int4WeightOnlyQuantizedLinearWeight subclass. - - Example usage:: - - model.lin_mod.weight = ( - Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight) - ) - """ - if dtype is None: - dtype = input_float.dtype - - int_data, scales_and_zeros, transposed, groupsize, inner_k_tils = ( - cls.to_qtensor_components( - input_float, - groupsize, - inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - ) - ) - return cls( - int_data, - scales_and_zeros, - transposed, - input_float.shape, - groupsize, - inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - dtype=dtype, - ) - - @classmethod - def to_qtensor_components( - cls, - input_float, - groupsize=128, - inner_k_tiles=8, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, - ): - assert groupsize in [256, 128, 64, 32] - assert inner_k_tiles in [8, 4, 2] - orig_out_features, orig_in_features = input_float.shape - - # padding - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input_float = torch.nn.functional.pad( - input_float, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - - # quantization and packing - input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( - input_float, - 4, - groupsize, - dtype=input_float.dtype, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - ) - if check_cpu_version(input_float.device): - int_data = aten._convert_weight_to_int4pack_for_cpu( - input_int4x8, inner_k_tiles - ) - else: - int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) - return int_data, scales_and_zeros, False, groupsize, inner_k_tiles diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index b4b1a1087d..c5e2a8e704 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -42,8 +42,6 @@ __all__ = [ "compute_error", "_quantize_activation_per_token_absmax", - "_quant_int8_dynamic_per_token_linear", - "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", "get_groupwise_affine_qparams", @@ -192,26 +190,6 @@ def _quantize_activation_per_token_absmax(t): return quantized, scale -def _quant_int8_dynamic_per_token_linear( - x, - w_vals_int8_t, - w_scales, - bias, - out_dtype, -): - """ - like F.linear, but with int8 dynamic quantization of activation, - and a quantized weight - """ - x_vals_int8, x_scales = _quantize_activation_per_token_absmax(x) - mm_out = _quant_int8_per_token_matmul( - x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype - ) - if bias is not None: - mm_out = mm_out + bias - return mm_out - - def _quant_int8_per_token_matmul( x_vals_int8, x_scales, @@ -272,37 +250,6 @@ def _quant_int8_per_token_matmul( return y -def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): - """ - assumes symmetric quantization - assumes axis == 0 - assumes dense memory format - TODO(future): relax ^ as needed - """ - - assert x.dim() == 2, "only support 2d Tensors" - - eps = torch.finfo(torch.float32).eps - block_size = (1, x.shape[1]) - zero_point_dtype = torch.int64 - - mapping_type = MappingType.SYMMETRIC - scale, zero_point = choose_qparams_affine( - x, - mapping_type, - block_size, - target_dtype=target_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) - quant = quantize_affine( - x, block_size, scale, zero_point, target_dtype, quant_min, quant_max - ) - return quant, scale, zero_point - - # reference: https://fburl.com/code/vfsygwd0 def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32): block_size = int_repr.shape diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py deleted file mode 100644 index fb30c14936..0000000000 --- a/torchao/quantization/weight_only.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from .utils import dynamically_quantize_per_channel - -__all__ = ["WeightOnlyInt8QuantLinear"] - - -class WeightOnlyInt8QuantLinear(torch.nn.Linear): - """ - This class is a replacement for `torch.nn.Linear`. It implements a - mixed dtype matrix multiplication using int8 symmetric per-channel weight quantization. - - The primary goal of this class is to leverage int8 quantization for weights to reduce the - memory footprint and computational requirements while performing linear transformations. - This can be particularly beneficial for deploying models in low latency environments - - Attributes: - w_int8 (torch.Tensor): The quantized weights in int8 format. - scales (torch.Tensor): The scaling factors for each channel to convert the quantized - weights back to floating point format during the forward pass. - """ - - def __init__(self, *args, **kwargs): - """ - Initializes the WeightOnlyInt8QuantLinear module. - - Args: - *args: Variable length argument list for `torch.nn.Linear`. - **kwargs: Arbitrary keyword arguments. - Must include 'w_int8' (int8 quantized weights) and 'scales' (scaling factors). - """ - w_int8 = kwargs.pop("w_int8") - scales = kwargs.pop("scales") - super().__init__(*args, **kwargs) - - self.register_buffer("w_int8", w_int8) - self.register_buffer("scales", scales) - - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Performs the forward pass of the quantized linear layer, which consists of - mixed dtype matrix multiplication using int8 symmetric per-channel weight quantization. - - Args: - x (torch.Tensor): The input floating point tensor to the quantized linear layer. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - torch.Tensor: The output floating point tensor after the quantized matrix multiplication - and rescale. - """ - x_view = x.view(-1, x.shape[-1]) - y = torch.mm(x_view, self.w_int8.to(x.dtype)) * self.scales - y = y.reshape(*x.shape[:-1], -1) - if self.bias is not None: - y += self.bias - return y - - @classmethod - def from_float(cls, mod: torch.nn.Linear): - """ - Converts a `torch.nn.Linear` module to a `WeightOnlyInt8QuantLinear` module. - - This method performs the conversion by dynamically quantizing the weights of the original - floating point linear layer to int8 format and creating a new `WeightOnlyInt8QuantLinear` - instance with these quantized weights and the corresponding scaling factors. - - Args: - mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert. - - Returns: - WeightOnlyInt8QuantLinear: The converted quantized linear module with int8 weights. - """ - w_fp32 = mod.weight - w_int8, scales, _zp = dynamically_quantize_per_channel( - w_fp32, -128, 127, torch.int8 - ) - # Create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, - fake_out_features, - bias=mod.bias is not None, - w_int8=w_int8.t().contiguous(), - scales=scales, - ) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - del new_mod.weight - new_mod.bias = mod.bias - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod