From b2e99eca18f6284922eb2aa5bdd86ec52ed70e4c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 10:58:59 +0800 Subject: [PATCH 01/45] initial commit --- torchao/prototype/quantized_training/int8.py | 64 ++++---- .../int8_mixed_precision.py | 125 +++++++++++++++ .../prototype/quantized_training/int8_mm.py | 142 ++++++++++++++++++ 3 files changed, 298 insertions(+), 33 deletions(-) create mode 100644 torchao/prototype/quantized_training/int8_mixed_precision.py create mode 100644 torchao/prototype/quantized_training/int8_mm.py diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index c301f011c2..e4cb8b705f 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -12,13 +12,41 @@ _c10d_functional = torch.ops._c10d_functional +@torch.no_grad() +def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): + """Normal rounding will always round down small changes in weight update. To tackle this problem, + stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The + probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next + integer value. Thus, stochastic rounding also approximates the floating point value exactly. + + Currently this function differs from AQT's `int8_weight_only()` in the following way: + 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input + to FP32 before quantization. Output scale maintains the original input dtype. + 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is + done here. + 3. Apply scale: AQT uses `input * (1 / scale)`, while this function performs `input / scale`. + """ + # absmax symmetric quantization + scale = tensor.abs().amax(1) / 127 # same dtype as tensor + inv_scale = 1.0 / scale.float().clip(1e-12) + tensor = tensor.float() * inv_scale.view(-1, 1) # slightly faster than divide directly + + if stochastic_rounding: + tensor = (tensor + torch.rand_like(tensor)).floor() + else: + tensor = tensor.round() + + tensor = tensor.clip(-128, 127).to(torch.int8) + return tensor, scale + + class Int8QTLinearWeight(Tensor): """INT8 symmetric quantization weight, with absmax scaling [-127, 127]. The main difference of this tensor subclass from AffineQuantizedTensor: 1. `F.linear` is differentiable i.e. backward is defined. 2. All in-place ops, such as `aten.copy_`, will perform stochastic rounding. `Int8QTLinearWeight.from_float()` does not perform stochastic rounding. - 3. The numerics for quantization is slightly different. See `Int8QTLinearWeight.quantize()` + 3. The numerics for quantization is slightly different. See `quantize_int8_rowwise()` for more details. """ @@ -55,42 +83,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) - @staticmethod - @torch.no_grad() - def quantize(tensor: Tensor, stochastic_rounding: bool = False): - """Normal rounding will always round down small changes in weight update. To tackle this problem, - stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The - probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next - integer value. Thus, stochastic rounding also approximates the floating point value exactly. - - Currently this function differs from AQT's `int8_weight_only()` in the following way: - 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input - to FP32 before quantization, and downcast scale to original dtype. - 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is - done here. - 3. Apply scale: AQT uses `input * (1 / scale)`, while this function performs `input / scale`. - """ - original_dtype = tensor.dtype - tensor = tensor.float() - - # absmax symmetric quantization - scale = tensor.abs().amax(-1) / 127 - tensor = tensor / scale.clip(1e-12).view(-1, 1) - - if stochastic_rounding: - tensor = (tensor + torch.rand_like(tensor)).floor() - else: - tensor = tensor.round() - - tensor = tensor.clip(-128, 127).to(torch.int8) - return tensor, scale.to(original_dtype) - @classmethod def from_float(cls, tensor: Tensor): """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed. This function is not differentiable. """ - int_data, scale = cls.quantize(tensor.detach()) + int_data, scale = quantize_int8_rowwise(tensor.detach()) out = cls(int_data, scale) out.requires_grad_(tensor.requires_grad) return out @@ -201,7 +199,7 @@ def _(func, types, args, kwargs): args[0].scale.copy_(args[1].scale, **kwargs) elif isinstance(args[0], Int8QTLinearWeight): - int_data, scale = Int8QTLinearWeight.quantize(args[1], stochastic_rounding=True) + int_data, scale = quantize_int8_rowwise(args[1], stochastic_rounding=True) args[0].int_data.copy_(int_data, **kwargs) args[0].scale.copy_(scale, **kwargs) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py new file mode 100644 index 0000000000..a3020c4974 --- /dev/null +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -0,0 +1,125 @@ +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.utils._triton import has_triton + +from .int8 import quantize_int8_rowwise + +if has_triton(): + from .int8_mm import int8_mm_dequant + +else: + + def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: + return (A * A_scale_rowwise.view(-1, 1)) @ (B * B_scale_colwise.view(1, -1)) + + +aten = torch.ops.aten + + +class Int8MixedPrecisionConfig(NamedTuple): + forward: bool = False + backward_grad_input: bool = False + backward_grad_weight: bool = False + + +class Int8MixedPrecisionLinearWeight(Tensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): + return Tensor._make_wrapper_subclass( + cls, + data.shape, + dtype=data.dtype, + device=data.device, + ) + + @torch._dynamo.disable + def __init__(self, data: Tensor, config: Int8MixedPrecisionConfig): + self._data = data + self.config = config + + def __tensor_flatten__(self): + return ["_data"], [self.config] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["_data"], *tensor_attributes) + + def __repr__(self): + return self._data.__repr__() + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or dict() + + if func is F.linear: + return _Int8MixedPrecisionLinear.apply(*args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func in (aten.detach.default, aten.clone.default, aten._to_copy.default): + return cls(func(args[0]._data, *args[1:], **kwargs), args[0].config) + + # TODO: some ops should return the original class i.e. in-place ops + args = [x._data if isinstance(x, cls) else x for x in args] + return func(*args, **kwargs) + + +class _Int8MixedPrecisionLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Tensor | None = None): + ctx.save_for_backward(input, weight) + ctx.bias = bias is not None + + if weight.config.forward: + batch_dims = input.shape[:-1] + input = input.view(-1, weight.shape[1]) + input_i8, input_scale = quantize_int8_rowwise(input) + weight_i8, weight_scale = quantize_int8_rowwise(weight) + out = int8_mm_dequant(input_i8, weight_i8.T, input_scale, weight_scale) + out = out.view(*batch_dims, weight.shape[0]) + else: + out = input @ weight.T + + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + weight: Int8MixedPrecisionLinearWeight + sr = weight.config.stochastic_rounding + + batch_dims = grad_output.shape[:-1] + grad_output = grad_output.view(-1, weight.shape[0]) + input = input.view(-1, weight.shape[1]) + + if ctx.needs_input_grad[0]: + if weight.config.backward_grad_input: + grad_output_i8, grad_output_scale = quantize_int8_rowwise(grad_output) + weight_i8_t, weight_scale = quantize_int8_rowwise(weight.T) + grad_input = int8_mm_dequant(grad_output_i8, weight_i8_t.T, grad_output_scale, weight_scale) + else: + grad_input = grad_output @ weight + grad_input = grad_input.view(*batch_dims, weight.shape[1]) + + if ctx.needs_input_grad[1]: + if weight.config.backward_grad_weight: + grad_output_i8_t, grad_output_scale = quantize_int8_rowwise(grad_output.T) + input_i8_t, input_scale = quantize_int8_rowwise(input.T) + grad_weight = int8_mm_dequant(grad_output_i8_t, input_i8_t.T, grad_output_scale, input_scale) + else: + grad_weight = grad_output.T @ input + + if ctx.needs_input_grad[2] and ctx.bias: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py new file mode 100644 index 0000000000..1f8984e66a --- /dev/null +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -0,0 +1,142 @@ +# TODO: might merge this with torchao/kernel/intmm_triton.py + +import torch +import triton +import triton.language as tl +from torch import Tensor + +lib = torch.library.Library("torchao", "FRAGMENT") + + +# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) +configs = [ + (128, 256, 64, 3, 8), + (64, 256, 32, 4, 4), + (128, 128, 32, 4, 4), + (128, 64, 32, 4, 4), + (64, 128, 32, 4, 4), + (128, 32, 32, 4, 4), + (64, 32, 32, 5, 2), + (32, 64, 32, 5, 2), + # Good config for fp8 inputs + (128, 256, 128, 3, 8), + (256, 128, 128, 3, 8), + (256, 64, 128, 4, 4), + (64, 256, 128, 4, 4), + (128, 128, 128, 4, 4), + (128, 64, 64, 4, 4), + (64, 128, 64, 4, 4), + (128, 32, 64, 4, 4), + # https://github.com/pytorch/pytorch/blob/7868b65c4d4f34133607b0166f08e9fbf3b257c4/torch/_inductor/kernel/mm_common.py#L172 + (64, 64, 32, 2, 4), + (64, 128, 32, 3, 4), + (128, 64, 32, 3, 4), + (64, 128, 32, 4, 8), + (128, 64, 32, 4, 8), + (64, 32, 32, 5, 8), + (32, 64, 32, 5, 8), + (128, 128, 32, 2, 8), + (64, 64, 64, 3, 8), + (128, 256, 128, 3, 8), + (256, 128, 128, 3, 8), +] + +configs = [ + triton.Config(dict(BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K), num_stages=num_stages, num_warps=num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps in configs +] + + +@triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) +@triton.jit +def _int8_mm_dequant_kernel( + # fmt: off + A_ptr, B_ptr, C_ptr, + A_scale_rowwise_ptr, + B_scale_colwise_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr = 8, + EVEN_K: tl.constexpr = True, + # fmt: on +): + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A_ptr + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B_ptr + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + a_scale = tl.load(A_scale_rowwise_ptr + idx_m, mask=idx_m < M).to(tl.float32) + b_scale = tl.load(B_scale_colwise_ptr + idx_n, mask=idx_n < N).to(tl.float32) + acc = acc.to(tl.float32) * a_scale * b_scale + + # inductor generates a suffix + xindex = idx_m * stride_cm + idx_n * stride_cn + tl.store(C_ptr + tl.broadcast_to(xindex, mask.shape), acc, mask) + + +lib.define("int8_mm_dequant(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale) -> Tensor") + + +def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: + return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) + + +@torch.library.impl(lib, "int8_mm_dequant", "Meta") +def _(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): + return torch.empty((A.shape[0], B.shape[1]), device=A.device, dtype=A_scale_rowwise.dtype) + + +@torch.library.impl(lib, "int8_mm_dequant", "CUDA") +def int8_mm_dequant_cuda(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): + assert A.dtype is torch.int8 and B.dtype is torch.int8 + assert A.shape[1] == B.shape[0] + M, K = A.shape + _, N = B.shape + assert A_scale_rowwise.squeeze().shape == (M,) + assert B_scale_colwise.squeeze().shape == (N,) + C = torch.empty(M, N, device=A.device, dtype=A_scale_rowwise.dtype) + grid = lambda meta: (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + _int8_mm_dequant_kernel[grid]( + A, B, C, A_scale_rowwise, B_scale_colwise, M, N, K, *A.stride(), *B.stride(), *C.stride(), EVEN_K=K % 2 == 0 + ) + return C From 255abe9c9881f6ded22a8d7a7ceb1686c6b93c08 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 11:09:44 +0800 Subject: [PATCH 02/45] expose some UX. update test --- test/prototype/test_quantized_training.py | 13 +++++++----- .../prototype/quantized_training/__init__.py | 11 +++++++++- .../int8_mixed_precision.py | 21 ++++++++++++++----- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 6b4b6a6be9..51543a3901 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -9,7 +9,10 @@ from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests from torchao.prototype.low_bit_optim import _AdamW -from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training +from torchao.prototype.quantized_training import ( + int8_weight_only_quantized_training, + quantize_int8_rowwise, +) from torchao.quantization.quant_api import quantize_ from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 @@ -35,7 +38,7 @@ def test_int8_stochastic_rounding(self, device): x = torch.randn(32, device=device) x_samples = x.view(1, -1).repeat(100_000, 1) - x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True) + x_int8, x_scale = quantize_int8_rowwise(x_samples, stochastic_rounding=True) x_dequant_samples = x_int8 * x_scale.view(-1, 1) x_dequant_mean = x_dequant_samples.mean(0) @@ -46,7 +49,7 @@ def test_int8_stochastic_rounding(self, device): @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear(self, leading_dims, bias, device): + def test_int8_weight_only_linear(self, leading_dims, bias, device): _reset() embed_dim = 32 @@ -77,7 +80,7 @@ def test_int8_linear(self, leading_dims, bias, device): @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_compile(self, leading_dims, bias, device): + def test_int8_weight_only_linear_compile(self, leading_dims, bias, device): _reset() embed_dim = 128 @@ -105,7 +108,7 @@ def test_int8_linear_compile(self, leading_dims, bias, device): @parametrize("compile", [False, True]) @parametrize("device", _DEVICES) - def test_int8_linear_training(self, compile, device): + def test_int8_weight_only_linear_training(self, compile, device): _reset() bsize = 4 embed_dim = 32 diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index 6c7f8eb9b1..99e5d3a876 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1 +1,10 @@ -from .int8 import Int8QTLinearWeight, int8_weight_only_quantized_training +from .int8 import ( + Int8QTLinearWeight, + int8_weight_only_quantized_training, + quantize_int8_rowwise, +) +from .int8_mixed_precision import ( + Int8MixedPrecisionConfig, + Int8MixedPrecisionLinearWeight, + int8_mixed_precision_training, +) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index a3020c4974..ad45eebca3 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -1,8 +1,7 @@ from typing import NamedTuple import torch -import torch.nn.functional as F -from torch import Tensor +from torch import Tensor, nn from torch.utils._triton import has_triton from .int8 import quantize_int8_rowwise @@ -55,7 +54,7 @@ def __repr__(self): def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or dict() - if func is F.linear: + if func is torch.nn.functional.linear: return _Int8MixedPrecisionLinear.apply(*args, **kwargs) with torch._C.DisableTorchFunctionSubclass(): @@ -94,9 +93,7 @@ def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Te def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - weight: Int8MixedPrecisionLinearWeight - sr = weight.config.stochastic_rounding batch_dims = grad_output.shape[:-1] grad_output = grad_output.view(-1, weight.shape[0]) @@ -123,3 +120,17 @@ def backward(ctx, grad_output): grad_bias = grad_output.sum(0) return grad_input, grad_weight, grad_bias + + +def int8_mixed_precision_training(config: Int8MixedPrecisionConfig = Int8MixedPrecisionConfig()): + # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` + # when we have this out of prototype (or there are stable trainable tensor subclasses), + # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. + def apply_int8_linear_weight(linear: nn.Linear): + linear.weight = nn.Parameter( + Int8MixedPrecisionLinearWeight(linear.weight.detach(), config), + requires_grad=linear.weight.requires_grad, + ) + return linear + + return apply_int8_linear_weight From efb53bff39cae3aff3bc54cea5f2fe37fa0a8427 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 03:56:03 +0000 Subject: [PATCH 03/45] add test. update bench --- .../quantized_training/pretrain_llama2.py | 9 +- test/prototype/test_quantized_training.py | 91 +++++++++++++------ .../int8_mixed_precision.py | 2 + 3 files changed, 75 insertions(+), 27 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index de3ed04e8f..0b2692d6f9 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -20,7 +20,11 @@ from torchao._models.llama.model import ModelArgs, Transformer from torchao.prototype import low_bit_optim -from torchao.prototype.quantized_training import int8_weight_only_quantized_training +from torchao.prototype.quantized_training import ( + Int8MixedPrecisionConfig, + int8_mixed_precision_training, + int8_weight_only_quantized_training, +) from torchao.quantization.quant_api import quantize_ @@ -118,6 +122,9 @@ def get_tinystories(): enable_activation_checkpointing(layer) if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + elif args.quantize == "int8_mixed_precision": + cfg = Int8MixedPrecisionConfig(True, True, True) + quantize_(model, int8_mixed_precision_training(), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 51543a3901..34cfa7c5ca 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -11,7 +11,9 @@ from torchao.prototype.low_bit_optim import _AdamW from torchao.prototype.quantized_training import ( int8_weight_only_quantized_training, + int8_mixed_precision_training, quantize_int8_rowwise, + Int8MixedPrecisionConfig, ) from torchao.quantization.quant_api import quantize_ from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 @@ -46,10 +48,18 @@ def test_int8_stochastic_rounding(self, device): # due to the statistical nature, this assertion may still fail, though very rarely. torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4) + @staticmethod + def _forward_and_backward(module, input, grad): + # clone input, since we want to inspect its gradient later + input = input.detach().clone().requires_grad_(True) + output = module(input) + output.backward(grad) + return input, output + @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_weight_only_linear(self, leading_dims, bias, device): + def test_int8_weight_only_correctness(self, leading_dims, bias, device): _reset() embed_dim = 32 @@ -58,20 +68,13 @@ def test_int8_weight_only_linear(self, leading_dims, bias, device): quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False) linear_fp32.weight.data = linear_int8.weight.data.dequantize() - input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device) - input_int8 = input_fp32.clone() - input_fp32.requires_grad_(True) - input_int8.requires_grad_(True) + input = torch.randn(leading_dims + (embed_dim,), device=device) + grad = torch.randn(leading_dims + (embed_dim,), device=device) - # test forward - out_fp32 = linear_fp32(input_fp32) - out_int8 = linear_int8(input_int8) - torch.testing.assert_close(out_fp32, out_int8) + input_fp32, out_fp32 = self._forward_and_backward(linear_fp32, input, grad) + input_int8, out_int8 = self._forward_and_backward(linear_int8, input, grad) - # test backward - grad = torch.randn(leading_dims + (embed_dim,), device=device) - out_fp32.backward(grad) - out_int8.backward(grad) + torch.testing.assert_close(out_fp32, out_int8) torch.testing.assert_close(input_fp32.grad, input_int8.grad) torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad) if bias: @@ -80,7 +83,7 @@ def test_int8_weight_only_linear(self, leading_dims, bias, device): @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) @parametrize("device", _DEVICES) - def test_int8_weight_only_linear_compile(self, leading_dims, bias, device): + def test_int8_weight_only_compile(self, leading_dims, bias, device): _reset() embed_dim = 128 @@ -89,18 +92,13 @@ def test_int8_weight_only_linear_compile(self, leading_dims, bias, device): linear_compiled = copy.deepcopy(linear_eager) linear_compiled.compile() - input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10 - input_compiled = input_eager.clone() - input_eager.requires_grad_(True) - input_compiled.requires_grad_(True) + input = torch.randn(leading_dims + (embed_dim,), device=device) * 10 + grad = torch.randn(leading_dims + (embed_dim,), device=device) - out_eager = linear_eager(input_eager) - out_compiled = linear_compiled(input_compiled) - torch.testing.assert_close(out_eager, out_compiled) + input_eager, out_eager = self._forward_and_backward(linear_eager, input, grad) + input_compiled, out_compiled = self._forward_and_backward(linear_compiled, input, grad) - grad = torch.randn(leading_dims + (embed_dim,), device=device) - out_eager.backward(grad) - out_compiled.backward(grad) + torch.testing.assert_close(out_eager, out_compiled) torch.testing.assert_close(input_eager.grad, input_compiled.grad) torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad) if bias: @@ -108,7 +106,7 @@ def test_int8_weight_only_linear_compile(self, leading_dims, bias, device): @parametrize("compile", [False, True]) @parametrize("device", _DEVICES) - def test_int8_weight_only_linear_training(self, compile, device): + def test_int8_weight_only_training(self, compile, device): _reset() bsize = 4 embed_dim = 32 @@ -120,7 +118,6 @@ def test_int8_weight_only_linear_training(self, compile, device): nn.Linear(embed_dim * 2, n_classes), ).to(device) model_int8 = copy.deepcopy(model_fp32) - # don't set inductor flags to speed up CI time quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False) if compile: @@ -147,6 +144,48 @@ def test_int8_weight_only_linear_training(self, compile, device): optim_int8.step() optim_int8.zero_grad() + @parametrize("compile", [False, True]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_int8_mixed_precision_training(self, compile): + _reset() + bsize = 4 + embed_dim = 32 + device = "cuda" + config = Int8MixedPrecisionConfig(True, True, True) + + # only use 1 matmul shape to reduce triton autotune time + model_ref = nn.Sequential( + nn.Linear(embed_dim, embed_dim, bias=False), + nn.GELU(), + nn.Linear(embed_dim, embed_dim), + ).to(device) + model_int8mp = copy.deepcopy(model_ref) + quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) + + if compile: + model_ref.compile() + model_int8mp.compile() + + optim_ref = torch.optim.AdamW(model_ref.parameters()) + optim_int8mp = torch.optim.AdamW(model_int8mp.parameters()) + + for i in range(5): + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(embed_dim, size=(bsize,), device=device) + loss_ref = F.cross_entropy(model_ref(inputs), labels) + loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels) + + rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item()) + assert rel_error < 3e-2, (i, rel_error) + + loss_ref.backward() + optim_ref.step() + optim_ref.zero_grad() + + loss_int8mp.backward() + optim_int8mp.step() + optim_int8mp.zero_grad() + class TestFSDP2(FSDPTest): @property diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index ad45eebca3..784c245555 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -122,6 +122,8 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias +# NOTE: should default config set all to True instead? -> speedup out-of-the-box. +# only if there are convergence issues, turn off some INT8 matmuls in backward. def int8_mixed_precision_training(config: Int8MixedPrecisionConfig = Int8MixedPrecisionConfig()): # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` # when we have this out of prototype (or there are stable trainable tensor subclasses), From 0a510f511e56004b86735bf41bfe935aafdd9673 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 05:29:04 +0000 Subject: [PATCH 04/45] update test. add doc --- .../quantized_training/pretrain_llama2.py | 5 ++- test/prototype/test_quantized_training.py | 30 +++++++++---- .../prototype/quantized_training/README.md | 44 +++++++++++++++++-- 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 0b2692d6f9..9c81a586fd 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -120,11 +120,12 @@ def get_tinystories(): if args.activation_checkpointing: for layer in model.layers: enable_activation_checkpointing(layer) + # NOTE: don't apply to LM head since there are memory issues. if args.quantize == "int8_weight_only": - quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + quantize_(model.layers, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize == "int8_mixed_precision": cfg = Int8MixedPrecisionConfig(True, True, True) - quantize_(model, int8_mixed_precision_training(), set_inductor_config=False) + quantize_(model.layers, int8_mixed_precision_training(cfg), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 34cfa7c5ca..c0e93e2794 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -16,9 +16,9 @@ Int8MixedPrecisionConfig, ) from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 -if not TORCH_VERSION_AFTER_2_3: +if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Requires torch>=2.4", allow_module_level=True) @@ -190,21 +190,29 @@ def test_int8_mixed_precision_training(self, compile): class TestFSDP2(FSDPTest): @property def world_size(self) -> int: - return 2 + return 1 - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(1) def test_fsdp2(self): # FSDP2 + compiled quantized training fails with PyTorch 2.4 compile_layer_choices = [False] - if TORCH_VERSION_AFTER_2_4: + if TORCH_VERSION_AT_LEAST_2_5: compile_layer_choices.append(True) + # need to run separately since they will timeout self.run_subtests( {"compile_layer": compile_layer_choices}, self._test_fsdp2, + quantize_fn=int8_weight_only_quantized_training(), ) - - def _test_fsdp2(self, compile_layer): + # TODO: fix FSDP ops. when sharding, need to return the original class + # self.run_subtests( + # {"compile_layer": compile_layer_choices}, + # self._test_fsdp2, + # quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, True, True)), + # ) + + def _test_fsdp2(self, quantize_fn, compile_layer): import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer @@ -223,7 +231,7 @@ def _test_fsdp2(self, compile_layer): ) torch.manual_seed(42) base_model = Transformer(model_args).cuda() - quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False) + quantize_(base_model, quantize_fn, set_inductor_config=False) fsdp_model = copy.deepcopy(base_model) if compile_layer: @@ -236,6 +244,11 @@ def _test_fsdp2(self, compile_layer): fully_shard(layer) fully_shard(fsdp_model) + for m in fsdp_model.modules(): + if isinstance(m, nn.Linear): + print(m.weight) + print(m.weight._local_tensor) + base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) @@ -256,6 +269,7 @@ def _test_fsdp2(self, compile_layer): base_optim.step() # due to stochastic rounding, use a pretty large tolerance here + # TODO: might want to use difference tolerance for different quantize_fn rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() assert rel_error < 0.05, rel_error diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 9b2980aa2b..78e470e524 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -2,7 +2,7 @@ This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. We take inspirations from: - Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)] -- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] +- JetFire: [[paper](https://arxiv.org/abs/2403.12422)] [[code](https://github.com/thu-ml/Jetfire-INT8Training)] Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. @@ -23,7 +23,7 @@ Usage ```python from torchao.prototype.quantized_training import int8_weight_only_quantized_training from torchao.prototype.low_bit_optim import _AdamW -from torchao.quantization.quant_api import quantize_ +from torchao.quantization import quantize_ model = ... quantize_(model, int8_weight_only_quantized_training()) @@ -46,8 +46,46 @@ BF16 compile | 10.16915 INT8 QT eager | 10.11437 INT8 QT compile | 10.03365 +## INT8 mixed-precision + +On NVIDIA GPUs, INT8 Tensor Cores can be up to 3x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. This is inspired from prior works: + +- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] +- SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)] + +Usage + +```python +from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionConfig +from torchao.quantization import quantize_ + +model = ... +config = Int8MixedPrecisionConfig( + forward=True, + backward_grad_input=True, + backward_grad_weight=True, +) +quantize_(model, int8_mixed_precision_training(config)) + +# train model as usual +``` + +During training, there are 3 matmuls involved in each `nn.Linear` layer: +- 1 in forward: `output = input @ weight.T` +- 2 in backward: + - `grad_input = grad_output @ weight` + - `grad_weight = grad_output.T @ input` + +You can configure which matmul to be applied with INT8 mixed-precision using `Int8MixedPrecisionConfig` shown above. If convergence is an issue, we recommend leaving `backward_grad_weight` in original matmul precision, and also `backward_grad_input` if the issue still persists. + +Note: +- When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT. +- When we only apply INT8 mixed-precision to `forward` and `backward_grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling. + +TODO: add some benchmarks + ## Future ideas -- INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which is 2x faster than FP16/BF16 Tensor Cores. +- Tile-wise INT8 quantization to keep quantized weight for both forward and backward pass (similar to JetFire). - INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels). - FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy. From f80ea8ce861cd9d52077191682d14f7f98198b68 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 05:35:18 +0000 Subject: [PATCH 05/45] fix ngpu --- test/prototype/test_quantized_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index c0e93e2794..f6fef27b03 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -190,9 +190,9 @@ def test_int8_mixed_precision_training(self, compile): class TestFSDP2(FSDPTest): @property def world_size(self) -> int: - return 1 + return 2 - @skip_if_lt_x_gpu(1) + @skip_if_lt_x_gpu(2) def test_fsdp2(self): # FSDP2 + compiled quantized training fails with PyTorch 2.4 compile_layer_choices = [False] From 4a404ce1548ab05546ffa99a33606c4480bdd803 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 06:21:52 +0000 Subject: [PATCH 06/45] fix FSDP --- test/prototype/test_quantized_training.py | 24 ++-- .../int8_mixed_precision.py | 112 ++++++++++++++---- 2 files changed, 100 insertions(+), 36 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index f6fef27b03..6956228020 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -200,19 +200,21 @@ def test_fsdp2(self): compile_layer_choices.append(True) # need to run separately since they will timeout + # due to stochastic rounding, use a pretty large tolerance here self.run_subtests( {"compile_layer": compile_layer_choices}, self._test_fsdp2, quantize_fn=int8_weight_only_quantized_training(), + tolerance=0.05, ) - # TODO: fix FSDP ops. when sharding, need to return the original class - # self.run_subtests( - # {"compile_layer": compile_layer_choices}, - # self._test_fsdp2, - # quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, True, True)), - # ) - - def _test_fsdp2(self, quantize_fn, compile_layer): + self.run_subtests( + {"compile_layer": compile_layer_choices}, + self._test_fsdp2, + quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, True, True)), + tolerance=1e-6 + ) + + def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer @@ -267,11 +269,9 @@ def _test_fsdp2(self, quantize_fn, compile_layer): if param.grad is not None: dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) base_optim.step() - - # due to stochastic rounding, use a pretty large tolerance here - # TODO: might want to use difference tolerance for different quantize_fn + rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < 0.05, rel_error + assert rel_error < tolerance, rel_error instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 784c245555..4d40965bc0 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -1,9 +1,12 @@ -from typing import NamedTuple +from typing import Any, NamedTuple, Optional, Tuple import torch from torch import Tensor, nn +from torch.utils._python_dispatch import return_and_correct_aliasing from torch.utils._triton import has_triton +from torchao.dtypes.utils import _dispatch__torch_dispatch__, _dispatch__torch_function__, _implements + from .int8 import quantize_int8_rowwise if has_triton(): @@ -16,6 +19,8 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional class Int8MixedPrecisionConfig(NamedTuple): @@ -25,6 +30,10 @@ class Int8MixedPrecisionConfig(NamedTuple): class Int8MixedPrecisionLinearWeight(Tensor): + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + @staticmethod @torch._dynamo.disable def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): @@ -50,33 +59,89 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No def __repr__(self): return self._data.__repr__() - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or dict() - - if func is torch.nn.functional.linear: - return _Int8MixedPrecisionLinear.apply(*args, **kwargs) - - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - if func in (aten.detach.default, aten.clone.default, aten._to_copy.default): - return cls(func(args[0]._data, *args[1:], **kwargs), args[0].config) - - # TODO: some ops should return the original class i.e. in-place ops - args = [x._data if isinstance(x, cls) else x for x in args] - return func(*args, **kwargs) + def to_original(self): + return self._data.clone() + + def fsdp_pre_all_gather(self, mesh): + return (self._data,), (self.config,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Tensor] = None, + ): + (data,) = all_gather_outputs + (config,) = metadata + return Int8MixedPrecisionLinearWeight(data, config), all_gather_outputs + + +implements = Int8MixedPrecisionLinearWeight.implements + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + return _Int8MixedPrecisionLinear.apply(*args, **kwargs) + + +@implements( + [ + aten.detach.default, + aten.clone.default, + aten._to_copy.default, + # FSDP ops + aten.slice.Tensor, + aten.new_zeros.default, + aten.view.default, + aten.as_strided.default, + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, + ] +) +def _(func, types, args, kwargs): + out = Int8MixedPrecisionLinearWeight(func(args[0]._data, *args[1:], **kwargs), args[0].config) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@implements( + [ + aten.copy_.default, + aten.addcdiv_.default, + aten.add_.Tensor, + ] +) +def _(func, types, args, kwargs): + unpacked_args = [x._data if isinstance(x, Int8MixedPrecisionLinearWeight) else x for x in args] + func(*unpacked_args, **kwargs) + return args[0] + + +# called by optimizers. return a normal tensor +@implements(aten.zeros_like.default) +def _(func, types, args, kwargs): + return func(args[0]._data, *args[1:], **kwargs) + + +# FSDP op +@implements(aten.split.Tensor) +def _(func, types, args, kwargs): + data_list = func(args[0]._data, *args[1:], **kwargs) + return [Int8MixedPrecisionLinearWeight(x, args[0].config) for x in data_list] class _Int8MixedPrecisionLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Tensor | None = None): + def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Optional[Tensor] = None): + ctx.config = weight.config + weight = weight._data ctx.save_for_backward(input, weight) ctx.bias = bias is not None - if weight.config.forward: + if ctx.config.forward: batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) input_i8, input_scale = quantize_int8_rowwise(input) @@ -93,14 +158,13 @@ def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Te def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - weight: Int8MixedPrecisionLinearWeight batch_dims = grad_output.shape[:-1] grad_output = grad_output.view(-1, weight.shape[0]) input = input.view(-1, weight.shape[1]) if ctx.needs_input_grad[0]: - if weight.config.backward_grad_input: + if ctx.config.backward_grad_input: grad_output_i8, grad_output_scale = quantize_int8_rowwise(grad_output) weight_i8_t, weight_scale = quantize_int8_rowwise(weight.T) grad_input = int8_mm_dequant(grad_output_i8, weight_i8_t.T, grad_output_scale, weight_scale) @@ -109,7 +173,7 @@ def backward(ctx, grad_output): grad_input = grad_input.view(*batch_dims, weight.shape[1]) if ctx.needs_input_grad[1]: - if weight.config.backward_grad_weight: + if ctx.config.backward_grad_weight: grad_output_i8_t, grad_output_scale = quantize_int8_rowwise(grad_output.T) input_i8_t, input_scale = quantize_int8_rowwise(input.T) grad_weight = int8_mm_dequant(grad_output_i8_t, input_i8_t.T, grad_output_scale, input_scale) From 42abc152da0f726a645694d23db786474c2bc4e6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 07:36:33 +0000 Subject: [PATCH 07/45] fix --- test/prototype/test_quantized_training.py | 11 ++++------- .../quantized_training/int8_mixed_precision.py | 1 + 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 6956228020..6d26ffc255 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -207,10 +207,12 @@ def test_fsdp2(self): quantize_fn=int8_weight_only_quantized_training(), tolerance=0.05, ) + # triton autotune takes too long. only test with compile_layer=False self.run_subtests( - {"compile_layer": compile_layer_choices}, + dict(), self._test_fsdp2, quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, True, True)), + compile_layer=False, tolerance=1e-6 ) @@ -246,11 +248,6 @@ def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): fully_shard(layer) fully_shard(fsdp_model) - for m in fsdp_model.modules(): - if isinstance(m, nn.Linear): - print(m.weight) - print(m.weight._local_tensor) - base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) @@ -269,7 +266,7 @@ def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): if param.grad is not None: dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) base_optim.step() - + rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() assert rel_error < tolerance, rel_error diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 4d40965bc0..fa067e4513 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -112,6 +112,7 @@ def _(func, types, args, kwargs): aten.copy_.default, aten.addcdiv_.default, aten.add_.Tensor, + aten.mul_.Tensor, ] ) def _(func, types, args, kwargs): From e826d487de0a6399e43011d7b089207c63bba7e7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 08:40:35 +0000 Subject: [PATCH 08/45] fix fsdp test --- test/prototype/test_quantized_training.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 6d26ffc255..fc539bf970 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -193,27 +193,24 @@ def world_size(self) -> int: return 2 @skip_if_lt_x_gpu(2) + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, "requires PyTorch>=2.5") def test_fsdp2(self): - # FSDP2 + compiled quantized training fails with PyTorch 2.4 - compile_layer_choices = [False] - if TORCH_VERSION_AT_LEAST_2_5: - compile_layer_choices.append(True) - - # need to run separately since they will timeout # due to stochastic rounding, use a pretty large tolerance here self.run_subtests( - {"compile_layer": compile_layer_choices}, + {"compile_layer": [False, True]}, self._test_fsdp2, quantize_fn=int8_weight_only_quantized_training(), tolerance=0.05, ) + # triton autotune takes too long. only test with compile_layer=False + # and apply INT8 matmul to forward pass only. self.run_subtests( dict(), self._test_fsdp2, - quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, True, True)), + quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, False, False)), compile_layer=False, - tolerance=1e-6 + tolerance=1e-6, ) def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): From 2ab9df307dab93eba532a49f551e0fe78315186b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 09:07:42 +0000 Subject: [PATCH 09/45] fix --- test/prototype/test_quantized_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index fc539bf970..a39c6640c0 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -193,7 +193,7 @@ def world_size(self) -> int: return 2 @skip_if_lt_x_gpu(2) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, "requires PyTorch>=2.5") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch>=2.5") def test_fsdp2(self): # due to stochastic rounding, use a pretty large tolerance here self.run_subtests( From c89b95037d8ec58636a150af12632f18f5ea7210 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 20:33:11 +0800 Subject: [PATCH 10/45] grammar --- torchao/prototype/quantized_training/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 78e470e524..ad63b87aee 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -48,7 +48,7 @@ INT8 QT compile | 10.03365 ## INT8 mixed-precision -On NVIDIA GPUs, INT8 Tensor Cores can be up to 3x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. This is inspired from prior works: +On NVIDIA GPUs, INT8 Tensor Cores can be up to 3x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. This is inspired by prior works: - AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] - SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)] From cde7e8f42d614e7630564681301be758282f1f55 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 26 Aug 2024 21:18:23 +0800 Subject: [PATCH 11/45] simplify fsdp test --- .../quantized_training/pretrain_llama2.py | 2 ++ test/prototype/test_quantized_training.py | 17 +++++------------ .../quantized_training/int8_mixed_precision.py | 2 +- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 9c81a586fd..697d523364 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -120,6 +120,7 @@ def get_tinystories(): if args.activation_checkpointing: for layer in model.layers: enable_activation_checkpointing(layer) + # NOTE: don't apply to LM head since there are memory issues. if args.quantize == "int8_weight_only": quantize_(model.layers, int8_weight_only_quantized_training(), set_inductor_config=False) @@ -128,6 +129,7 @@ def get_tinystories(): quantize_(model.layers, int8_mixed_precision_training(cfg), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") + print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}") diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index a39c6640c0..22cd6db241 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -16,7 +16,7 @@ Int8MixedPrecisionConfig, ) from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Requires torch>=2.4", allow_module_level=True) @@ -193,11 +193,10 @@ def world_size(self) -> int: return 2 @skip_if_lt_x_gpu(2) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires PyTorch>=2.5") def test_fsdp2(self): # due to stochastic rounding, use a pretty large tolerance here self.run_subtests( - {"compile_layer": [False, True]}, + dict(), self._test_fsdp2, quantize_fn=int8_weight_only_quantized_training(), tolerance=0.05, @@ -209,11 +208,10 @@ def test_fsdp2(self): dict(), self._test_fsdp2, quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, False, False)), - compile_layer=False, tolerance=1e-6, ) - def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): + def _test_fsdp2(self, quantize_fn, tolerance): import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer @@ -229,19 +227,14 @@ def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): vocab_size=vocab_size, max_seq_len=seq_len, dropout_p=0, + weight_tying=False, # INT8 mixed-precision will fail if weight_tying=True ) torch.manual_seed(42) base_model = Transformer(model_args).cuda() quantize_(base_model, quantize_fn, set_inductor_config=False) fsdp_model = copy.deepcopy(base_model) - if compile_layer: - for layer in base_model.layers: - layer.compile() - for layer in fsdp_model.layers: - if compile_layer: - layer.compile() fully_shard(layer) fully_shard(fsdp_model) @@ -265,7 +258,7 @@ def _test_fsdp2(self, quantize_fn, compile_layer, tolerance): base_optim.step() rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < tolerance, rel_error + assert rel_error < tolerance, (iter_idx, rel_error) instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index fa067e4513..12ad8e4e2d 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -57,7 +57,7 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No return cls(tensor_data_dict["_data"], *tensor_attributes) def __repr__(self): - return self._data.__repr__() + return f"{self.__class__.__name__}(data={self._data}, config={self.config})" def to_original(self): return self._data.clone() From 691da9dfe08228a8dc5becec2b4656c0422788fb Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 11:42:56 +0800 Subject: [PATCH 12/45] update benchmark script --- .../quantized_training/pretrain_llama2.py | 19 +++++++++++++------ test/prototype/test_quantized_training.py | 3 +-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 697d523364..43bff73761 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -9,6 +9,7 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import argparse +import time from functools import partial from pathlib import Path @@ -33,11 +34,11 @@ def enable_activation_checkpointing(m: torch.nn.Module): assert not hasattr(m, "_forward") m._forward = m.forward - m.forward = partial(checkpoint, m.forward) + m.forward = partial(checkpoint, m.forward, use_reentrant=False) def get_loss(model: Transformer, batch: torch.Tensor): - logits = model(batch)[:, :-1].flatten(0, 1) + logits = model(batch)[:, :-1].float().flatten(0, 1) labels = batch[:, 1:].flatten() return torch.nn.functional.cross_entropy(logits, labels) @@ -88,6 +89,7 @@ def get_tinystories(): parser.add_argument("--head_dim", type=int, default=64) parser.add_argument("--quantize") + parser.add_argument("--quantize_lm_head", action="store_true") parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") @@ -121,17 +123,18 @@ def get_tinystories(): for layer in model.layers: enable_activation_checkpointing(layer) - # NOTE: don't apply to LM head since there are memory issues. + module_to_quantize = model if args.quantize_lm_head else model.layers if args.quantize == "int8_weight_only": - quantize_(model.layers, int8_weight_only_quantized_training(), set_inductor_config=False) + quantize_(module_to_quantize, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize == "int8_mixed_precision": cfg = Int8MixedPrecisionConfig(True, True, True) - quantize_(model.layers, int8_mixed_precision_training(cfg), set_inductor_config=False) + quantize_(module_to_quantize, int8_mixed_precision_training(cfg), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}") + torch.cuda.reset_peak_memory_stats() # don't count memory occupied by unquantized weights # only use optimizers from torchao.prototype.low_bit_optim to support quantized training if args.optim == "AdamW": @@ -146,6 +149,7 @@ def get_tinystories(): pbar = tqdm(total=args.n_steps, dynamic_ncols=True) model.train() _get_loss = torch.compile(get_loss) if args.compile else get_loss + time0 = time.time() while step < args.n_steps: # randomly select a continuous chunk, then reshape it @@ -160,8 +164,11 @@ def get_tinystories(): loss=loss.item(), lr=optim.param_groups[0]["lr"], max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, - max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9, ) + if step > 0: + time1 = time.time() + log_dict["tokens_per_second"] = (log_interval * args.batch_size * args.seq_len) / (time1 - time0) + time0 = time1 run.log(log_dict, step=step) pbar.set_postfix(loss=log_dict["loss"]) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 22cd6db241..4a1c79a615 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -202,8 +202,7 @@ def test_fsdp2(self): tolerance=0.05, ) - # triton autotune takes too long. only test with compile_layer=False - # and apply INT8 matmul to forward pass only. + # triton autotune takes too long. apply INT8 matmul to forward pass only. self.run_subtests( dict(), self._test_fsdp2, From 3540e79e970f30646ad6676597548c6b4027c9ce Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 13:35:16 +0800 Subject: [PATCH 13/45] update --- .../quantized_training/benchmark_int8mm.py | 45 +++++++++++++++++++ .../quantized_training/pretrain_llama2.py | 6 +-- .../prototype/quantized_training/README.md | 33 ++++++++------ .../int8_mixed_precision.py | 29 +++++++----- .../prototype/quantized_training/int8_mm.py | 1 + 5 files changed, 87 insertions(+), 27 deletions(-) create mode 100644 benchmarks/quantized_training/benchmark_int8mm.py diff --git a/benchmarks/quantized_training/benchmark_int8mm.py b/benchmarks/quantized_training/benchmark_int8mm.py new file mode 100644 index 0000000000..85892afa85 --- /dev/null +++ b/benchmarks/quantized_training/benchmark_int8mm.py @@ -0,0 +1,45 @@ +import pandas as pd +import torch +from triton.testing import do_bench + +from torchao.prototype.quantized_training.int8_mm import int8_mm_dequant + + +def bench_f(f, *args): + return do_bench(lambda: f(*args), fast_flush=False, return_mode="median") + + +shapes = [(sz, sz, sz) for sz in [1024, 2048, 4096]] + +# Llama-8B shapes +shapes += [ + # linear in attention + (32_768, 4096, 4096), + (4096, 4096, 32_768), + # linear in feed-forward + (32_768, 14_336, 4096), + (32_768, 4096, 14_336), + (14_336, 4096, 32_768), +] + +data = [] +for M, N, K in shapes: + print(f"{M=}, {N=}, {K=}") + + A_bf16 = torch.randn(M, K).bfloat16().cuda() + B_bf16 = torch.randn(N, K).bfloat16().cuda() + A_i8 = torch.randint(-128, 127, size=(M, K), dtype=torch.int8).cuda() + B_i8 = torch.randint(-128, 127, size=(N, K), dtype=torch.int8).cuda() + A_scale = torch.randn(M).bfloat16().cuda() + B_scale = torch.randn(N).bfloat16().cuda() + + # benchmark F.linear() i.e. A @ B.T + bf16_time = bench_f(torch.mm, A_bf16, B_bf16.T) + i8_time = bench_f(torch._int_mm, A_i8, B_i8.T) + i8_dequant_time = bench_f(int8_mm_dequant, A_i8, B_i8.T, A_scale, B_scale) + + sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time] + data.append(sample) + +df = pd.DataFrame(data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"]) +print(df.to_markdown()) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 43bff73761..f9f5a6eb45 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -89,7 +89,6 @@ def get_tinystories(): parser.add_argument("--head_dim", type=int, default=64) parser.add_argument("--quantize") - parser.add_argument("--quantize_lm_head", action="store_true") parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") @@ -123,12 +122,11 @@ def get_tinystories(): for layer in model.layers: enable_activation_checkpointing(layer) - module_to_quantize = model if args.quantize_lm_head else model.layers if args.quantize == "int8_weight_only": - quantize_(module_to_quantize, int8_weight_only_quantized_training(), set_inductor_config=False) + quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize == "int8_mixed_precision": cfg = Int8MixedPrecisionConfig(True, True, True) - quantize_(module_to_quantize, int8_mixed_precision_training(cfg), set_inductor_config=False) + quantize_(model, int8_mixed_precision_training(cfg), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index ad63b87aee..5170ee2b73 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -37,14 +37,14 @@ Only `torch.optim.Adam` and optimizers from `torchao.prototype.low_bit_optim` ar See [#644](https://github.com/pytorch/ao/pull/644) for some early results. -TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to transposed weight. Memory benchamark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing. +TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to transposed weight. Benchamark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing, 4070Ti SUPER. -Model | Peak memory (GB) -----------------|----------------- -BF16 eager | 11.06847 -BF16 compile | 10.16915 -INT8 QT eager | 10.11437 -INT8 QT compile | 10.03365 +Model | Peak memory (GB) | toks/s +----------------|------------------|------- +BF16 eager | 11.07 | 6200 +BF16 compile | 10.25 | 9000 +INT8 QT eager | 10.12 | 5600 +INT8 QT compile | 9.84 | 8700 ## INT8 mixed-precision @@ -61,9 +61,9 @@ from torchao.quantization import quantize_ model = ... config = Int8MixedPrecisionConfig( - forward=True, - backward_grad_input=True, - backward_grad_weight=True, + output=True, + grad_input=True, + grad_weight=True, ) quantize_(model, int8_mixed_precision_training(config)) @@ -76,13 +76,20 @@ During training, there are 3 matmuls involved in each `nn.Linear` layer: - `grad_input = grad_output @ weight` - `grad_weight = grad_output.T @ input` -You can configure which matmul to be applied with INT8 mixed-precision using `Int8MixedPrecisionConfig` shown above. If convergence is an issue, we recommend leaving `backward_grad_weight` in original matmul precision, and also `backward_grad_input` if the issue still persists. +You can configure which matmul to be applied with INT8 mixed-precision using `Int8MixedPrecisionConfig` shown above. If convergence is an issue, we recommend leaving `grad_weight` in original matmul precision, and also `grad_input` if the issue still persists. Note: - When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT. -- When we only apply INT8 mixed-precision to `forward` and `backward_grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling. +- When we only apply INT8 mixed-precision to `output` and `grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling. -TODO: add some benchmarks +Pre-train Llama2-1B on C4 realnewslike subset. bs=32, seq_len=2048 -> 65k tok/batch. Train for 20k steps (1.3B tokens). Using 4090. INT8 mixed precision is not applied to LM head. + +Config | Tok/s | Peak mem (GB) | Val loss +---------------------|-------|---------------|--------- +BF16 (baseline) | ~17k | 19.47 | 2.97 +INT8 mixed-precision | ~29k | 19.47 | 2.90 + +See [#748](https://github.com/pytorch/ao/pull/748) for more results. ## Future ideas diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 12ad8e4e2d..5b74e9597c 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -15,7 +15,9 @@ else: def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: - return (A * A_scale_rowwise.view(-1, 1)) @ (B * B_scale_colwise.view(1, -1)) + A_scaled = A * A_scale_rowwise.view(-1, 1) + B_scaled = B * B_scale_colwise.view(1, -1) + return A_scaled @ B_scaled aten = torch.ops.aten @@ -24,9 +26,9 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi class Int8MixedPrecisionConfig(NamedTuple): - forward: bool = False - backward_grad_input: bool = False - backward_grad_weight: bool = False + output: bool = False + grad_input: bool = False + grad_weight: bool = False class Int8MixedPrecisionLinearWeight(Tensor): @@ -121,10 +123,17 @@ def _(func, types, args, kwargs): return args[0] -# called by optimizers. return a normal tensor -@implements(aten.zeros_like.default) +# return normal tensor +@implements( + [ + aten.zeros_like.default, # called by optimizers + aten.add.Tensor, + aten.mul.Tensor, + ] +) def _(func, types, args, kwargs): - return func(args[0]._data, *args[1:], **kwargs) + unpacked_args = [x._data if isinstance(x, Int8MixedPrecisionLinearWeight) else x for x in args] + return func(*unpacked_args, **kwargs) # FSDP op @@ -142,7 +151,7 @@ def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Op ctx.save_for_backward(input, weight) ctx.bias = bias is not None - if ctx.config.forward: + if ctx.config.output: batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) input_i8, input_scale = quantize_int8_rowwise(input) @@ -165,7 +174,7 @@ def backward(ctx, grad_output): input = input.view(-1, weight.shape[1]) if ctx.needs_input_grad[0]: - if ctx.config.backward_grad_input: + if ctx.config.grad_input: grad_output_i8, grad_output_scale = quantize_int8_rowwise(grad_output) weight_i8_t, weight_scale = quantize_int8_rowwise(weight.T) grad_input = int8_mm_dequant(grad_output_i8, weight_i8_t.T, grad_output_scale, weight_scale) @@ -174,7 +183,7 @@ def backward(ctx, grad_output): grad_input = grad_input.view(*batch_dims, weight.shape[1]) if ctx.needs_input_grad[1]: - if ctx.config.backward_grad_weight: + if ctx.config.grad_weight: grad_output_i8_t, grad_output_scale = quantize_int8_rowwise(grad_output.T) input_i8_t, input_scale = quantize_int8_rowwise(input.T) grad_weight = int8_mm_dequant(grad_output_i8_t, input_i8_t.T, grad_output_scale, input_scale) diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 1f8984e66a..927c2d48b2 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -8,6 +8,7 @@ lib = torch.library.Library("torchao", "FRAGMENT") +# TODO: prune configs to speedup triton autotune # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) configs = [ From f9d4e2a9fc3eeab0b2ccc4a22970fc5050b48a7b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 27 Aug 2024 14:07:21 +0800 Subject: [PATCH 14/45] make claim more conservative --- torchao/prototype/quantized_training/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 5170ee2b73..b537354f70 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -48,7 +48,7 @@ INT8 QT compile | 9.84 | 8700 ## INT8 mixed-precision -On NVIDIA GPUs, INT8 Tensor Cores can be up to 3x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. This is inspired by prior works: +On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. This is inspired by prior works: - AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] - SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)] From 64f707a8d114b591537100336e1ebe7ab7120348 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 28 Aug 2024 12:22:20 +0800 Subject: [PATCH 15/45] register fused adam --- .../prototype/quantized_training/int8_mixed_precision.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 5b74e9597c..a4eaad8d84 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -123,6 +123,12 @@ def _(func, types, args, kwargs): return args[0] +@implements([aten._fused_adam_.default, aten._fused_adamw_.default]) +def _(func, types, args, kwargs): + params = [x._data if isinstance(x, Int8MixedPrecisionLinearWeight) else x for x in args[0]] + func(params, *args[1:], **kwargs) + + # return normal tensor @implements( [ From b3770d308046d90b2369ac1a2e3e4359d8d6b6cc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 3 Sep 2024 09:23:37 +0000 Subject: [PATCH 16/45] update benchmark script --- benchmarks/quantized_training/pretrain_llama2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index f9f5a6eb45..b5f11e1155 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -122,11 +122,13 @@ def get_tinystories(): for layer in model.layers: enable_activation_checkpointing(layer) + # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. + # TODO: might want to do the same for int8_weight_only to standardize. if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize == "int8_mixed_precision": cfg = Int8MixedPrecisionConfig(True, True, True) - quantize_(model, int8_mixed_precision_training(cfg), set_inductor_config=False) + quantize_(model.layers, int8_mixed_precision_training(cfg), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") From dd33823ef749f2ff28b4d0745a8554ca12b473c4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 03:01:27 +0000 Subject: [PATCH 17/45] add more ops --- torchao/prototype/quantized_training/int8_mixed_precision.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index a4eaad8d84..da47facf5f 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -93,6 +93,7 @@ def _(func, types, args, kwargs): aten.detach.default, aten.clone.default, aten._to_copy.default, + aten.empty_like.default, # FSDP ops aten.slice.Tensor, aten.new_zeros.default, @@ -115,6 +116,10 @@ def _(func, types, args, kwargs): aten.addcdiv_.default, aten.add_.Tensor, aten.mul_.Tensor, + # param init functions + aten.uniform_.default, + aten.erfinv_.default, + aten.clamp_.default, ] ) def _(func, types, args, kwargs): From b96769a7974d24faff134c09142978131af744bc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 03:03:24 +0000 Subject: [PATCH 18/45] update default --- torchao/prototype/quantized_training/README.md | 1 + .../prototype/quantized_training/int8_mixed_precision.py | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index b537354f70..396a96e337 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -60,6 +60,7 @@ from torchao.prototype.quantized_training import int8_mixed_precision_training, from torchao.quantization import quantize_ model = ... +# by default, apply INT8 matmul to all 3 matmuls config = Int8MixedPrecisionConfig( output=True, grad_input=True, diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index da47facf5f..45cb4df27f 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -26,9 +26,9 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi class Int8MixedPrecisionConfig(NamedTuple): - output: bool = False - grad_input: bool = False - grad_weight: bool = False + output: bool = True + grad_input: bool = True + grad_weight: bool = True class Int8MixedPrecisionLinearWeight(Tensor): @@ -207,8 +207,6 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias -# NOTE: should default config set all to True instead? -> speedup out-of-the-box. -# only if there are convergence issues, turn off some INT8 matmuls in backward. def int8_mixed_precision_training(config: Int8MixedPrecisionConfig = Int8MixedPrecisionConfig()): # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` # when we have this out of prototype (or there are stable trainable tensor subclasses), From 2b16ebbf6317a3dabe41aaac42cffb14d113786e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 18:50:08 +0800 Subject: [PATCH 19/45] use TorchAOBaseTensor --- .../prototype/quantized_training/int8_mixed_precision.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 45cb4df27f..fb02937ab6 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -5,7 +5,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torch.utils._triton import has_triton -from torchao.dtypes.utils import _dispatch__torch_dispatch__, _dispatch__torch_function__, _implements +from torchao.utils import TorchAOBaseTensor from .int8 import quantize_int8_rowwise @@ -31,11 +31,7 @@ class Int8MixedPrecisionConfig(NamedTuple): grad_weight: bool = True -class Int8MixedPrecisionLinearWeight(Tensor): - implements = classmethod(_implements) - __torch_function__ = classmethod(_dispatch__torch_function__) - __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) - +class Int8MixedPrecisionLinearWeight(TorchAOBaseTensor): @staticmethod @torch._dynamo.disable def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): From 117cc60036bba5cdab2fa33ac1cf12be5a7162a0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 20:04:27 +0800 Subject: [PATCH 20/45] fix fsdp param_dtype --- torchao/prototype/quantized_training/int8_mixed_precision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index fb02937ab6..3ba3802e81 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -61,6 +61,8 @@ def to_original(self): return self._data.clone() def fsdp_pre_all_gather(self, mesh): + # TODO: pre-quantize weight here -> reduce comm bandwidth. + # we will need another tensor subclass to hold the quantized weight. return (self._data,), (self.config,) def fsdp_post_all_gather( @@ -73,7 +75,7 @@ def fsdp_post_all_gather( ): (data,) = all_gather_outputs (config,) = metadata - return Int8MixedPrecisionLinearWeight(data, config), all_gather_outputs + return Int8MixedPrecisionLinearWeight(data.to(param_dtype), config), all_gather_outputs implements = Int8MixedPrecisionLinearWeight.implements From ae370585a9f6bbe14e62c7bb2d3bf8f60258bf70 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 20:05:35 +0800 Subject: [PATCH 21/45] fix param_dtype --- torchao/prototype/quantized_training/int8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index a50813a166..3fe6e0d34c 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -110,7 +110,7 @@ def fsdp_post_all_gather( out: Optional[Tensor] = None, ): int_data, scale = all_gather_outputs - return Int8QTLinearWeight(int_data, scale), all_gather_outputs + return Int8QTLinearWeight(int_data, scale.to(param_dtype)), all_gather_outputs class _Int8WeightOnlyLinear(torch.autograd.Function): From ae4eb213e20d12d91502ec5188efa82fa9523a88 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 20:09:27 +0800 Subject: [PATCH 22/45] dtype check to prevent unnecessary errors --- torchao/prototype/quantized_training/int8_mm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 927c2d48b2..6c96deaf8a 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -130,6 +130,7 @@ def _(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): @torch.library.impl(lib, "int8_mm_dequant", "CUDA") def int8_mm_dequant_cuda(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): assert A.dtype is torch.int8 and B.dtype is torch.int8 + assert A_scale_rowwise.dtype is B_scale_colwise.dtype assert A.shape[1] == B.shape[0] M, K = A.shape _, N = B.shape From 730c90cc5e0a34543c7b8663d73c783369267ee1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 20:46:20 +0800 Subject: [PATCH 23/45] move checks --- torchao/prototype/quantized_training/int8_mm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 6c96deaf8a..3412d6b40b 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -119,6 +119,11 @@ def _int8_mm_dequant_kernel( def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: + assert A.dtype is torch.int8 and B.dtype is torch.int8 + assert A_scale_rowwise.dtype is B_scale_colwise.dtype + assert A.shape[1] == B.shape[0] + assert A_scale_rowwise.squeeze().shape == A.shape[0] + assert B_scale_colwise.squeeze().shape == B.shape[0] return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) @@ -129,13 +134,8 @@ def _(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): @torch.library.impl(lib, "int8_mm_dequant", "CUDA") def int8_mm_dequant_cuda(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): - assert A.dtype is torch.int8 and B.dtype is torch.int8 - assert A_scale_rowwise.dtype is B_scale_colwise.dtype - assert A.shape[1] == B.shape[0] M, K = A.shape _, N = B.shape - assert A_scale_rowwise.squeeze().shape == (M,) - assert B_scale_colwise.squeeze().shape == (N,) C = torch.empty(M, N, device=A.device, dtype=A_scale_rowwise.dtype) grid = lambda meta: (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) _int8_mm_dequant_kernel[grid]( From c470a24bd79b8f65cab010f3f6e3377dc3b5db3d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 20:54:28 +0800 Subject: [PATCH 24/45] add note --- torchao/prototype/quantized_training/int8_mixed_precision.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 3ba3802e81..d64fddeba4 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -193,9 +193,12 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[1]: if ctx.config.grad_weight: + # TODO: check if transpose+quantize are fused grad_output_i8_t, grad_output_scale = quantize_int8_rowwise(grad_output.T) input_i8_t, input_scale = quantize_int8_rowwise(input.T) - grad_weight = int8_mm_dequant(grad_output_i8_t, input_i8_t.T, grad_output_scale, input_scale) + # grad_weight = int8_mm_dequant(grad_output_i8_t, input_i8_t.T, grad_output_scale, input_scale) + # this is slightly faster + grad_weight = int8_mm_dequant(input_i8_t, grad_output_i8_t.T, input_scale, grad_output_scale).T else: grad_weight = grad_output.T @ input From 7c1d760ebc1b9a3a29bc11f18adf361c13342b14 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 20:58:49 +0800 Subject: [PATCH 25/45] fix --- torchao/prototype/quantized_training/int8_mm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 3412d6b40b..e22bb85732 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -122,8 +122,8 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi assert A.dtype is torch.int8 and B.dtype is torch.int8 assert A_scale_rowwise.dtype is B_scale_colwise.dtype assert A.shape[1] == B.shape[0] - assert A_scale_rowwise.squeeze().shape == A.shape[0] - assert B_scale_colwise.squeeze().shape == B.shape[0] + assert A_scale_rowwise.squeeze().shape == (A.shape[0],) + assert B_scale_colwise.squeeze().shape == (B.shape[1],) return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) From 0e15e2dbd9cd782c70641a8827b5a744a9c23e4c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Sep 2024 22:26:22 +0800 Subject: [PATCH 26/45] simplify script --- .../quantized_training/pretrain_llama2.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index b5f11e1155..8b7b1ab2d4 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -3,6 +3,7 @@ # # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only +# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision import os @@ -19,7 +20,7 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from torchao._models.llama.model import ModelArgs, Transformer +from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import ( Int8MixedPrecisionConfig, @@ -29,6 +30,15 @@ from torchao.quantization.quant_api import quantize_ +# not official models +transformer_configs.update( + ( + ("470M", dict(n_layer=24, n_head=16, dim=1024, intermediate_size=4096)), + ("1B", dict(n_layer=24, n_head=24, dim=1536, intermediate_size=6144)), + ) +) + + # hack from fairseq # https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py def enable_activation_checkpointing(m: torch.nn.Module): @@ -82,12 +92,7 @@ def get_tinystories(): if __name__ == "__main__": parser = argparse.ArgumentParser() - # default config is 470M - parser.add_argument("--d_model", type=int, default=1024) - parser.add_argument("--depth", type=int, default=24) - parser.add_argument("--ffn_size", type=int, default=4096) - parser.add_argument("--head_dim", type=int, default=64) - + parser.add_argument("--model", default="470M", choices=transformer_configs.keys()) parser.add_argument("--quantize") parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") @@ -108,13 +113,8 @@ def get_tinystories(): if args.seed is not None: torch.manual_seed(args.seed) - config = ModelArgs( - block_size=args.seq_len, - n_layer=args.depth, - n_head=args.d_model // args.head_dim, - dim=args.d_model, - intermediate_size=args.ffn_size, - ) + config = ModelArgs.from_name(args.model) + config.block_size = args.seq_len model = Transformer(config).bfloat16().cuda() with torch.device("cuda"): model.setup_caches(args.batch_size, args.seq_len, training=True) @@ -164,6 +164,7 @@ def get_tinystories(): loss=loss.item(), lr=optim.param_groups[0]["lr"], max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, + max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9, ) if step > 0: time1 = time.time() From 208188ca382ebd29e7121fa4ca07d5bd19d457bd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 5 Sep 2024 14:22:24 +0800 Subject: [PATCH 27/45] add module-based UX --- .../prototype/quantized_training/__init__.py | 1 + .../int8_mixed_precision.py | 40 +++++++++++++++---- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index 99e5d3a876..4c0fe90c98 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -5,6 +5,7 @@ ) from .int8_mixed_precision import ( Int8MixedPrecisionConfig, + Int8MixedPrecisionLinear, Int8MixedPrecisionLinearWeight, int8_mixed_precision_training, ) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index d64fddeba4..fcd480bb13 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -31,6 +31,9 @@ class Int8MixedPrecisionConfig(NamedTuple): grad_weight: bool = True +_DEFAULT_CONFIG = Int8MixedPrecisionConfig() + + class Int8MixedPrecisionLinearWeight(TorchAOBaseTensor): @staticmethod @torch._dynamo.disable @@ -83,7 +86,9 @@ def fsdp_post_all_gather( @implements(torch.nn.functional.linear) def _(func, types, args, kwargs): - return _Int8MixedPrecisionLinear.apply(*args, **kwargs) + act, weight = args[:2] + bias = args[2] if len(args) > 2 else None + return _Int8MixedPrecisionLinear.apply(act, weight._data, bias, weight.config) @implements( @@ -152,11 +157,32 @@ def _(func, types, args, kwargs): return [Int8MixedPrecisionLinearWeight(x, args[0].config) for x in data_list] +# alternative UX +class Int8MixedPrecisionLinear(nn.Linear): + def __init__(self, *args, config: Int8MixedPrecisionConfig, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.config = config + + def forward(self, input: Tensor) -> Tensor: + return _Int8MixedPrecisionLinear.apply(input, self.weight, self.bias, self.config) + + def extra_repr(self): + return f"{super().extra_repr()}, config={self.config}" + + @classmethod + def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionConfig = _DEFAULT_CONFIG): + if module.__class__ is nn.Linear: # exact match, don't swap nn.Linear subclasses + module.__class__ = cls + module.config = config + return + for child in module.children(): + cls.convert_linear(child) + + class _Int8MixedPrecisionLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Optional[Tensor] = None): - ctx.config = weight.config - weight = weight._data + def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8MixedPrecisionConfig): + ctx.config = config ctx.save_for_backward(input, weight) ctx.bias = bias is not None @@ -176,7 +202,7 @@ def forward(ctx, input: Tensor, weight: Int8MixedPrecisionLinearWeight, bias: Op @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None + grad_input = grad_weight = grad_bias = grad_config = None batch_dims = grad_output.shape[:-1] grad_output = grad_output.view(-1, weight.shape[0]) @@ -205,10 +231,10 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[2] and ctx.bias: grad_bias = grad_output.sum(0) - return grad_input, grad_weight, grad_bias + return grad_input, grad_weight, grad_bias, grad_config -def int8_mixed_precision_training(config: Int8MixedPrecisionConfig = Int8MixedPrecisionConfig()): +def int8_mixed_precision_training(config: Int8MixedPrecisionConfig = _DEFAULT_CONFIG): # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` # when we have this out of prototype (or there are stable trainable tensor subclasses), # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. From 77aafdba3bbe6586c4776501d4b1c7bfa0f74a5d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 5 Sep 2024 15:35:06 +0800 Subject: [PATCH 28/45] fix --- .../int8_mixed_precision.py | 23 +++++++++---------- .../prototype/quantized_training/int8_mm.py | 5 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index fcd480bb13..c4d3ed81a9 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -176,7 +176,13 @@ def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionConfig = _D module.config = config return for child in module.children(): - cls.convert_linear(child) + cls.convert_linear(child, config) + + +def _dynamic_int8_linear(input: Tensor, weight: Tensor) -> Tensor: + input_i8, input_scale = quantize_int8_rowwise(input) + weight_i8, weight_scale = quantize_int8_rowwise(weight) + return int8_mm_dequant(input_i8, weight_i8.T, input_scale, weight_scale) class _Int8MixedPrecisionLinear(torch.autograd.Function): @@ -189,9 +195,7 @@ def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: if ctx.config.output: batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) - input_i8, input_scale = quantize_int8_rowwise(input) - weight_i8, weight_scale = quantize_int8_rowwise(weight) - out = int8_mm_dequant(input_i8, weight_i8.T, input_scale, weight_scale) + out = _dynamic_int8_linear(input, weight) out = out.view(*batch_dims, weight.shape[0]) else: out = input @ weight.T @@ -210,9 +214,7 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[0]: if ctx.config.grad_input: - grad_output_i8, grad_output_scale = quantize_int8_rowwise(grad_output) - weight_i8_t, weight_scale = quantize_int8_rowwise(weight.T) - grad_input = int8_mm_dequant(grad_output_i8, weight_i8_t.T, grad_output_scale, weight_scale) + grad_input = _dynamic_int8_linear(grad_output, weight.T) else: grad_input = grad_output @ weight grad_input = grad_input.view(*batch_dims, weight.shape[1]) @@ -220,11 +222,8 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[1]: if ctx.config.grad_weight: # TODO: check if transpose+quantize are fused - grad_output_i8_t, grad_output_scale = quantize_int8_rowwise(grad_output.T) - input_i8_t, input_scale = quantize_int8_rowwise(input.T) - # grad_weight = int8_mm_dequant(grad_output_i8_t, input_i8_t.T, grad_output_scale, input_scale) - # this is slightly faster - grad_weight = int8_mm_dequant(input_i8_t, grad_output_i8_t.T, input_scale, grad_output_scale).T + # grad_weight = _dynamic_int8_linear(grad_output.T, input.T) + grad_weight = _dynamic_int8_linear(input.T, grad_output.T).T # this is slightly faster else: grad_weight = grad_output.T @ input diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index e22bb85732..8a345edcc4 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -52,7 +52,6 @@ @triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) @triton.jit def _int8_mm_dequant_kernel( - # fmt: off A_ptr, B_ptr, C_ptr, A_scale_rowwise_ptr, B_scale_colwise_ptr, @@ -65,7 +64,6 @@ def _int8_mm_dequant_kernel( BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr = 8, EVEN_K: tl.constexpr = True, - # fmt: on ): # based on triton.ops.matmul pid = tl.program_id(0) @@ -124,6 +122,9 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi assert A.shape[1] == B.shape[0] assert A_scale_rowwise.squeeze().shape == (A.shape[0],) assert B_scale_colwise.squeeze().shape == (B.shape[1],) + # TODO: handle this inside triton kernel + A_scale_rowwise = A_scale_rowwise.contiguous() + B_scale_colwise = B_scale_colwise.contiguous() return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) From d367f772c99cc7915669f8553809fab55cb98b8b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 08:58:57 +0800 Subject: [PATCH 29/45] use FP8 impl of __torch_dispatch__ --- .../int8_mixed_precision.py | 140 +++++++----------- 1 file changed, 55 insertions(+), 85 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index c4d3ed81a9..b81f29cf24 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -1,12 +1,10 @@ from typing import Any, NamedTuple, Optional, Tuple import torch +import torch.utils._pytree as pytree from torch import Tensor, nn -from torch.utils._python_dispatch import return_and_correct_aliasing from torch.utils._triton import has_triton -from torchao.utils import TorchAOBaseTensor - from .int8 import quantize_int8_rowwise if has_triton(): @@ -20,11 +18,6 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi return A_scaled @ B_scaled -aten = torch.ops.aten -c10d_functional = torch.ops.c10d_functional -_c10d_functional = torch.ops._c10d_functional - - class Int8MixedPrecisionConfig(NamedTuple): output: bool = True grad_input: bool = True @@ -34,7 +27,24 @@ class Int8MixedPrecisionConfig(NamedTuple): _DEFAULT_CONFIG = Int8MixedPrecisionConfig() -class Int8MixedPrecisionLinearWeight(TorchAOBaseTensor): +# adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor +aten = torch.ops.aten +_ops_to_preserve_subclass = { + aten.detach.default, + aten.empty_like.default, + aten.new_zeros.default, + aten.slice.Tensor, + aten.copy_.default, + aten.view.default, + aten.as_strided.default, + aten._to_copy.default, + aten._pin_memory.default, + aten.split.Tensor, + aten.clone.default, +} + + +class Int8MixedPrecisionLinearWeight(Tensor): @staticmethod @torch._dynamo.disable def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): @@ -43,6 +53,9 @@ def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): data.shape, dtype=data.dtype, device=data.device, + strides=data.stride(), + storage_offset=data.storage_offset(), + pin_memory=data.is_pinned(), ) @torch._dynamo.disable @@ -63,6 +76,38 @@ def __repr__(self): def to_original(self): return self._data.clone() + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = dict() + + if func is torch.nn.functional.linear: + act, weight = args[:2] + bias = args[2] if len(args) > 2 else None + return _Int8MixedPrecisionLinear.apply(act, weight._data, bias, weight.config) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + config = None + + def unwrap(x: cls): + nonlocal config + if config is None: + config = x.config + else: + assert x.config == config + return x._data + + args, kwargs = pytree.tree_map_only(cls, unwrap, (args, kwargs)) + out = func(*args, **kwargs) + if func not in _ops_to_preserve_subclass: + return out + else: + return pytree.tree_map_only(Tensor, lambda x: cls(x, config), out) + def fsdp_pre_all_gather(self, mesh): # TODO: pre-quantize weight here -> reduce comm bandwidth. # we will need another tensor subclass to hold the quantized weight. @@ -81,82 +126,6 @@ def fsdp_post_all_gather( return Int8MixedPrecisionLinearWeight(data.to(param_dtype), config), all_gather_outputs -implements = Int8MixedPrecisionLinearWeight.implements - - -@implements(torch.nn.functional.linear) -def _(func, types, args, kwargs): - act, weight = args[:2] - bias = args[2] if len(args) > 2 else None - return _Int8MixedPrecisionLinear.apply(act, weight._data, bias, weight.config) - - -@implements( - [ - aten.detach.default, - aten.clone.default, - aten._to_copy.default, - aten.empty_like.default, - # FSDP ops - aten.slice.Tensor, - aten.new_zeros.default, - aten.view.default, - aten.as_strided.default, - c10d_functional.all_gather_into_tensor.default, - _c10d_functional.all_gather_into_tensor.default, - c10d_functional.wait_tensor.default, - _c10d_functional.wait_tensor.default, - ] -) -def _(func, types, args, kwargs): - out = Int8MixedPrecisionLinearWeight(func(args[0]._data, *args[1:], **kwargs), args[0].config) - return return_and_correct_aliasing(func, args, kwargs, out) - - -@implements( - [ - aten.copy_.default, - aten.addcdiv_.default, - aten.add_.Tensor, - aten.mul_.Tensor, - # param init functions - aten.uniform_.default, - aten.erfinv_.default, - aten.clamp_.default, - ] -) -def _(func, types, args, kwargs): - unpacked_args = [x._data if isinstance(x, Int8MixedPrecisionLinearWeight) else x for x in args] - func(*unpacked_args, **kwargs) - return args[0] - - -@implements([aten._fused_adam_.default, aten._fused_adamw_.default]) -def _(func, types, args, kwargs): - params = [x._data if isinstance(x, Int8MixedPrecisionLinearWeight) else x for x in args[0]] - func(params, *args[1:], **kwargs) - - -# return normal tensor -@implements( - [ - aten.zeros_like.default, # called by optimizers - aten.add.Tensor, - aten.mul.Tensor, - ] -) -def _(func, types, args, kwargs): - unpacked_args = [x._data if isinstance(x, Int8MixedPrecisionLinearWeight) else x for x in args] - return func(*unpacked_args, **kwargs) - - -# FSDP op -@implements(aten.split.Tensor) -def _(func, types, args, kwargs): - data_list = func(args[0]._data, *args[1:], **kwargs) - return [Int8MixedPrecisionLinearWeight(x, args[0].config) for x in data_list] - - # alternative UX class Int8MixedPrecisionLinear(nn.Linear): def __init__(self, *args, config: Int8MixedPrecisionConfig, **kwargs) -> None: @@ -180,6 +149,7 @@ def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionConfig = _D def _dynamic_int8_linear(input: Tensor, weight: Tensor) -> Tensor: + # TODO: check if we need to enforce .contiguous() for input_i8 and weight_i8 input_i8, input_scale = quantize_int8_rowwise(input) weight_i8, weight_scale = quantize_int8_rowwise(weight) return int8_mm_dequant(input_i8, weight_i8.T, input_scale, weight_scale) From d24a894b3758bcb0b9f3da928e555e5e7628cfdf Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 09:18:28 +0800 Subject: [PATCH 30/45] rename _dynamice interface --- .../int8_mixed_precision.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index b81f29cf24..4cdda43baf 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -80,7 +80,7 @@ def to_original(self): def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = dict() - + if func is torch.nn.functional.linear: act, weight = args[:2] bias = args[2] if len(args) > 2 else None @@ -148,11 +148,24 @@ def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionConfig = _D cls.convert_linear(child, config) -def _dynamic_int8_linear(input: Tensor, weight: Tensor) -> Tensor: - # TODO: check if we need to enforce .contiguous() for input_i8 and weight_i8 - input_i8, input_scale = quantize_int8_rowwise(input) - weight_i8, weight_scale = quantize_int8_rowwise(weight) - return int8_mm_dequant(input_i8, weight_i8.T, input_scale, weight_scale) +def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: + # INT8 matmul is the most performant when A is row-major and B is column-major. + # thus, we transpose B before quantization. + + # A and B (before quantization) might not be contiguous, especially in backward. + # it's also not guaranteed that A_i8 and B_t_i8 (after quantization) are contiguous, + # thus we have to call .contiguous() on them. + # hope that the .contiguous() calls will be fused into quantize op by torch.compile() + # TODO: investigate if calling .contiguous() before quantization is better. + # TODO: check if transpose+quantize are fused. + A_i8, A_scale_rowwise = quantize_int8_rowwise(A) + B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) + return int8_mm_dequant( + A_i8.contiguous(), + B_t_i8.contiguous().T, + A_scale_rowwise, + B_scale_colwise, + ) class _Int8MixedPrecisionLinear(torch.autograd.Function): @@ -165,7 +178,7 @@ def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: if ctx.config.output: batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) - out = _dynamic_int8_linear(input, weight) + out = _dynamic_int8_mm(input, weight.T) out = out.view(*batch_dims, weight.shape[0]) else: out = input @ weight.T @@ -184,16 +197,15 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[0]: if ctx.config.grad_input: - grad_input = _dynamic_int8_linear(grad_output, weight.T) + grad_input = _dynamic_int8_mm(grad_output, weight) else: grad_input = grad_output @ weight grad_input = grad_input.view(*batch_dims, weight.shape[1]) if ctx.needs_input_grad[1]: if ctx.config.grad_weight: - # TODO: check if transpose+quantize are fused - # grad_weight = _dynamic_int8_linear(grad_output.T, input.T) - grad_weight = _dynamic_int8_linear(input.T, grad_output.T).T # this is slightly faster + # grad_weight = _dynamic_int8_mm(grad_output.T, input) + grad_weight = _dynamic_int8_mm(input.T, grad_output).T # this is slightly faster else: grad_weight = grad_output.T @ input From fb09b248869cc45a9438aa4be42e368b09cc68ce Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 09:26:12 +0800 Subject: [PATCH 31/45] update test --- test/prototype/test_quantized_training.py | 12 +++++------- .../quantized_training/int8_mixed_precision.py | 4 ++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 4a1c79a615..da5a45ef24 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -13,7 +13,6 @@ int8_weight_only_quantized_training, int8_mixed_precision_training, quantize_int8_rowwise, - Int8MixedPrecisionConfig, ) from torchao.quantization.quant_api import quantize_ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 @@ -151,7 +150,6 @@ def test_int8_mixed_precision_training(self, compile): bsize = 4 embed_dim = 32 device = "cuda" - config = Int8MixedPrecisionConfig(True, True, True) # only use 1 matmul shape to reduce triton autotune time model_ref = nn.Sequential( @@ -160,7 +158,7 @@ def test_int8_mixed_precision_training(self, compile): nn.Linear(embed_dim, embed_dim), ).to(device) model_int8mp = copy.deepcopy(model_ref) - quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) + quantize_(model_int8mp, int8_mixed_precision_training(), set_inductor_config=False) if compile: model_ref.compile() @@ -202,11 +200,10 @@ def test_fsdp2(self): tolerance=0.05, ) - # triton autotune takes too long. apply INT8 matmul to forward pass only. self.run_subtests( dict(), self._test_fsdp2, - quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, False, False)), + quantize_fn=int8_mixed_precision_training(), tolerance=1e-6, ) @@ -219,6 +216,8 @@ def _test_fsdp2(self, quantize_fn, tolerance): batch_size = 3 vocab_size = 32 seq_len = 64 + + # NOTE: if weight_tying=True and we also quantize LM head, INT8 mixed-precision will fail. model_args = ModelArgs( n_layers=2, n_heads=2, @@ -226,11 +225,10 @@ def _test_fsdp2(self, quantize_fn, tolerance): vocab_size=vocab_size, max_seq_len=seq_len, dropout_p=0, - weight_tying=False, # INT8 mixed-precision will fail if weight_tying=True ) torch.manual_seed(42) base_model = Transformer(model_args).cuda() - quantize_(base_model, quantize_fn, set_inductor_config=False) + quantize_(base_model.layers, quantize_fn, set_inductor_config=False) fsdp_model = copy.deepcopy(base_model) for layer in fsdp_model.layers: diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 4cdda43baf..be76888a0a 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -123,6 +123,10 @@ def fsdp_post_all_gather( ): (data,) = all_gather_outputs (config,) = metadata + if out is not None: + assert isinstance(out, Int8MixedPrecisionLinearWeight) + assert out.config == config + return return Int8MixedPrecisionLinearWeight(data.to(param_dtype), config), all_gather_outputs From 3372644fd71bfa58888f9357085e8b30dc2fb0f4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 10:08:11 +0800 Subject: [PATCH 32/45] fix compile on 2.4 --- torchao/prototype/quantized_training/int8_mixed_precision.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index be76888a0a..3892d01151 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -55,7 +55,7 @@ def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): device=data.device, strides=data.stride(), storage_offset=data.storage_offset(), - pin_memory=data.is_pinned(), + # pin_memory=data.is_pinned(), # this will fail compile on 2.4 ) @torch._dynamo.disable @@ -82,7 +82,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = dict() if func is torch.nn.functional.linear: - act, weight = args[:2] + act = args[0] + weight: cls = args[1] bias = args[2] if len(args) > 2 else None return _Int8MixedPrecisionLinear.apply(act, weight._data, bias, weight.config) From 9e05b5c2959de22d1003f3ae17c4bcf61cd78fe7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 11:26:14 +0800 Subject: [PATCH 33/45] log torch version --- benchmarks/quantized_training/pretrain_llama2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 8b7b1ab2d4..9ec6b14865 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -142,6 +142,7 @@ def get_tinystories(): optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) data = get_tinystories().cuda() + args.torch_version = torch.__version__ run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) step = 0 From 6e4e6845b69a97b3b13934bb407f5fc9369936ae Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 11:32:41 +0800 Subject: [PATCH 34/45] make log interval customizable --- benchmarks/quantized_training/pretrain_llama2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 9ec6b14865..4496ec8b42 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -108,6 +108,7 @@ def get_tinystories(): parser.add_argument("--project", default="int8_quantized_training") parser.add_argument("--run_name") parser.add_argument("--seed", type=int) + parser.add_argument("--log_interval", type=int, default=10) args = parser.parse_args() if args.seed is not None: @@ -146,7 +147,6 @@ def get_tinystories(): run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) step = 0 - log_interval = 50 pbar = tqdm(total=args.n_steps, dynamic_ncols=True) model.train() _get_loss = torch.compile(get_loss) if args.compile else get_loss @@ -160,7 +160,7 @@ def get_tinystories(): loss = _get_loss(model, batch) loss.backward() - if step % log_interval == 0: + if step % args.log_interval == 0: log_dict = dict( loss=loss.item(), lr=optim.param_groups[0]["lr"], @@ -169,7 +169,7 @@ def get_tinystories(): ) if step > 0: time1 = time.time() - log_dict["tokens_per_second"] = (log_interval * args.batch_size * args.seq_len) / (time1 - time0) + log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0) time0 = time1 run.log(log_dict, step=step) pbar.set_postfix(loss=log_dict["loss"]) From b395858cb385b394be80594109bb7b5abfbe9591 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 11:46:23 +0800 Subject: [PATCH 35/45] make naming for explicit --- .../quantized_training/pretrain_llama2.py | 4 +- .../prototype/quantized_training/README.md | 16 +++--- .../prototype/quantized_training/__init__.py | 6 +-- torchao/prototype/quantized_training/int8.py | 49 ++++++++++--------- .../int8_mixed_precision.py | 26 +++++----- 5 files changed, 53 insertions(+), 48 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 4496ec8b42..fc87c2cd6e 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -23,7 +23,6 @@ from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import ( - Int8MixedPrecisionConfig, int8_mixed_precision_training, int8_weight_only_quantized_training, ) @@ -128,8 +127,7 @@ def get_tinystories(): if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize == "int8_mixed_precision": - cfg = Int8MixedPrecisionConfig(True, True, True) - quantize_(model.layers, int8_mixed_precision_training(cfg), set_inductor_config=False) + quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 396a96e337..456dbf30e1 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -56,15 +56,19 @@ On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP1 Usage ```python -from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionConfig +from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig from torchao.quantization import quantize_ model = ... -# by default, apply INT8 matmul to all 3 matmuls -config = Int8MixedPrecisionConfig( + +# apply INT8 matmul to all 3 matmuls +quantize_(model, int8_mixed_precision_training()) + +# customize which matmul is left in original precision. +config = Int8MixedPrecisionTrainingConfig( output=True, grad_input=True, - grad_weight=True, + grad_weight=False, ) quantize_(model, int8_mixed_precision_training(config)) @@ -77,10 +81,10 @@ During training, there are 3 matmuls involved in each `nn.Linear` layer: - `grad_input = grad_output @ weight` - `grad_weight = grad_output.T @ input` -You can configure which matmul to be applied with INT8 mixed-precision using `Int8MixedPrecisionConfig` shown above. If convergence is an issue, we recommend leaving `grad_weight` in original matmul precision, and also `grad_input` if the issue still persists. +You can configure which matmul to be applied with INT8 mixed-precision (shown above). If convergence is an issue, we recommend leaving `grad_weight` in original matmul precision, and also `grad_input` if the issue still persists. Note: -- When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT. +- When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT for INT8 dynamic activations + INT8 weight quantization (A8W8). - When we only apply INT8 mixed-precision to `output` and `grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling. Pre-train Llama2-1B on C4 realnewslike subset. bs=32, seq_len=2048 -> 65k tok/batch. Train for 20k steps (1.3B tokens). Using 4090. INT8 mixed precision is not applied to LM head. diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index 4c0fe90c98..9f697e1362 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1,11 +1,11 @@ from .int8 import ( - Int8QTLinearWeight, + Int8QuantizedTrainingLinearWeight, int8_weight_only_quantized_training, quantize_int8_rowwise, ) from .int8_mixed_precision import ( - Int8MixedPrecisionConfig, + Int8MixedPrecisionTrainingConfig, Int8MixedPrecisionLinear, - Int8MixedPrecisionLinearWeight, + Int8MixedPrecisionTrainingLinearWeight, int8_mixed_precision_training, ) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 3fe6e0d34c..36f8dc03ab 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -40,7 +40,7 @@ def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): return tensor, scale -class Int8QTLinearWeight(TorchAOBaseTensor): +class Int8QuantizedTrainingLinearWeight(TorchAOBaseTensor): """INT8 symmetric quantization weight, with absmax scaling [-127, 127]. The main difference of this tensor subclass from AffineQuantizedTensor: 1. `F.linear` is differentiable i.e. backward is defined. @@ -110,12 +110,12 @@ def fsdp_post_all_gather( out: Optional[Tensor] = None, ): int_data, scale = all_gather_outputs - return Int8QTLinearWeight(int_data, scale.to(param_dtype)), all_gather_outputs + return Int8QuantizedTrainingLinearWeight(int_data, scale.to(param_dtype)), all_gather_outputs class _Int8WeightOnlyLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Int8QTLinearWeight, bias: Optional[Tensor] = None): + def forward(ctx, input: Tensor, weight: Int8QuantizedTrainingLinearWeight, bias: Optional[Tensor] = None): ctx.save_for_backward(input, weight) ctx.bias = bias is not None @@ -134,12 +134,15 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias -@Int8QTLinearWeight.implements(torch.nn.functional.linear) +implements = Int8QuantizedTrainingLinearWeight.implements + + +@implements(torch.nn.functional.linear) def _(func, types, args, kwargs): return _Int8WeightOnlyLinear.apply(*args, **kwargs) -@Int8QTLinearWeight.implements( +@implements( [ aten.detach.default, aten.clone.default, @@ -153,20 +156,20 @@ def _(func, types, args, kwargs): ) def _(func, types, args, kwargs): # will error out if try to slice 2nd dim - out = Int8QTLinearWeight( + out = Int8QuantizedTrainingLinearWeight( func(args[0].int_data, *args[1:], **kwargs), func(args[0].scale, *args[1:], **kwargs), ) return return_and_correct_aliasing(func, args, kwargs, out) -@Int8QTLinearWeight.implements(aten._to_copy.default) +@implements(aten._to_copy.default) def _(func, types, args, kwargs): # only perform dtype casting on scale, which determines the appearance dtype # TODO: handle non_blocking kwarg? device = kwargs.get("device", None) dtype = kwargs.get("dtype", None) - out = Int8QTLinearWeight( + out = Int8QuantizedTrainingLinearWeight( args[0].int_data.to(device=device), args[0].scale.to(device=device, dtype=dtype), ) @@ -174,7 +177,7 @@ def _(func, types, args, kwargs): # to make training work with existing PyTorch optimizers, we return a normal tensor -@Int8QTLinearWeight.implements(aten.zeros_like.default) +@implements(aten.zeros_like.default) def _(func, types, args, kwargs): dtype = kwargs.get("dtype", args[0].dtype) device = kwargs.get("device", args[0].device) @@ -182,19 +185,19 @@ def _(func, types, args, kwargs): # out-of-place math ops always return plain tensor -@Int8QTLinearWeight.implements([aten.sub.Tensor, aten.mul.Tensor]) +@implements([aten.sub.Tensor, aten.mul.Tensor]) def _(func, types, args, kwargs): - args = [x.dequantize() if isinstance(x, Int8QTLinearWeight) else x for x in args] + args = [x.dequantize() if isinstance(x, Int8QuantizedTrainingLinearWeight) else x for x in args] return func(*args, **kwargs) -@Int8QTLinearWeight.implements(aten.copy_.default) +@implements(aten.copy_.default) def _(func, types, args, kwargs): - if isinstance(args[0], Int8QTLinearWeight) and isinstance(args[1], Int8QTLinearWeight): + if isinstance(args[0], Int8QuantizedTrainingLinearWeight) and isinstance(args[1], Int8QuantizedTrainingLinearWeight): args[0].int_data.copy_(args[1].int_data, **kwargs) args[0].scale.copy_(args[1].scale, **kwargs) - elif isinstance(args[0], Int8QTLinearWeight): + elif isinstance(args[0], Int8QuantizedTrainingLinearWeight): int_data, scale = quantize_int8_rowwise(args[1], stochastic_rounding=True) args[0].int_data.copy_(int_data, **kwargs) args[0].scale.copy_(scale, **kwargs) @@ -205,7 +208,7 @@ def _(func, types, args, kwargs): return args[0] -@Int8QTLinearWeight.implements([aten.addcdiv_.default, aten.add_.Tensor]) +@implements([aten.addcdiv_.default, aten.add_.Tensor]) def _(func, types, args, kwargs): original = args[0] out = func(args[0].dequantize(), *args[1:], **kwargs) @@ -213,20 +216,20 @@ def _(func, types, args, kwargs): # FSDP ops -@Int8QTLinearWeight.implements(aten.split.Tensor) +@implements(aten.split.Tensor) def _(func, types, args, kwargs): if len(args) == 3 and args[2] != 0: raise NotImplementedError("Int8QTLinearWeight only supports split at dim=0") - int8_weight: Int8QTLinearWeight = args[0] + int8_weight: Int8QuantizedTrainingLinearWeight = args[0] int_data_list = func(int8_weight.int_data, *args[1:], **kwargs) scale_list = func(int8_weight.scale, *args[1:], **kwargs) - out = [Int8QTLinearWeight(int_data, scale) for int_data, scale in zip(int_data_list, scale_list)] + out = [Int8QuantizedTrainingLinearWeight(int_data, scale) for int_data, scale in zip(int_data_list, scale_list)] return out -@Int8QTLinearWeight.implements(aten.new_zeros.default) +@implements(aten.new_zeros.default) def _(func, types, args, kwargs): size = args[1] if len(size) != 2: @@ -237,7 +240,7 @@ def _(func, types, args, kwargs): dtype = kwargs.get("dtype", args[0].dtype) int_data = torch.zeros(size, device=device, dtype=torch.int8) scale = torch.zeros(size[0], device=device, dtype=dtype) - return Int8QTLinearWeight(int_data, scale) + return Int8QuantizedTrainingLinearWeight(int_data, scale) # FSDP2 will call these two ops, expecting a view, not a copy. It doesn't make sense to @@ -245,9 +248,9 @@ def _(func, types, args, kwargs): # since this is channel-wise quantization. # Thus, this is a workaround for FSDP2. Users SHOULD NOT call these ops directly, since # they will produce unexpected or wrong results. -@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) +@implements([aten.view.default, aten.as_strided.default]) def _(func, types, args, kwargs): - out = Int8QTLinearWeight(args[0].int_data, args[0].scale) + out = Int8QuantizedTrainingLinearWeight(args[0].int_data, args[0].scale) return return_and_correct_aliasing(func, args, kwargs, out) @@ -257,7 +260,7 @@ def int8_weight_only_quantized_training(): # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. def apply_int8_linear_weight(linear: nn.Linear): linear.weight = nn.Parameter( - Int8QTLinearWeight.from_float(linear.weight), + Int8QuantizedTrainingLinearWeight.from_float(linear.weight), requires_grad=linear.weight.requires_grad, ) return linear diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 3892d01151..e56b634481 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -18,13 +18,13 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi return A_scaled @ B_scaled -class Int8MixedPrecisionConfig(NamedTuple): +class Int8MixedPrecisionTrainingConfig(NamedTuple): output: bool = True grad_input: bool = True grad_weight: bool = True -_DEFAULT_CONFIG = Int8MixedPrecisionConfig() +_DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() # adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor @@ -44,10 +44,10 @@ class Int8MixedPrecisionConfig(NamedTuple): } -class Int8MixedPrecisionLinearWeight(Tensor): +class Int8MixedPrecisionTrainingLinearWeight(Tensor): @staticmethod @torch._dynamo.disable - def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): + def __new__(cls, data: Tensor, config: Int8MixedPrecisionTrainingConfig): return Tensor._make_wrapper_subclass( cls, data.shape, @@ -59,7 +59,7 @@ def __new__(cls, data: Tensor, config: Int8MixedPrecisionConfig): ) @torch._dynamo.disable - def __init__(self, data: Tensor, config: Int8MixedPrecisionConfig): + def __init__(self, data: Tensor, config: Int8MixedPrecisionTrainingConfig): self._data = data self.config = config @@ -125,15 +125,15 @@ def fsdp_post_all_gather( (data,) = all_gather_outputs (config,) = metadata if out is not None: - assert isinstance(out, Int8MixedPrecisionLinearWeight) + assert isinstance(out, Int8MixedPrecisionTrainingLinearWeight) assert out.config == config return - return Int8MixedPrecisionLinearWeight(data.to(param_dtype), config), all_gather_outputs + return Int8MixedPrecisionTrainingLinearWeight(data.to(param_dtype), config), all_gather_outputs -# alternative UX +# alternative UX. to be deleted class Int8MixedPrecisionLinear(nn.Linear): - def __init__(self, *args, config: Int8MixedPrecisionConfig, **kwargs) -> None: + def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> None: super().__init__(*args, **kwargs) self.config = config @@ -144,7 +144,7 @@ def extra_repr(self): return f"{super().extra_repr()}, config={self.config}" @classmethod - def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionConfig = _DEFAULT_CONFIG): + def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): if module.__class__ is nn.Linear: # exact match, don't swap nn.Linear subclasses module.__class__ = cls module.config = config @@ -175,7 +175,7 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: class _Int8MixedPrecisionLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8MixedPrecisionConfig): + def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8MixedPrecisionTrainingConfig): ctx.config = config ctx.save_for_backward(input, weight) ctx.bias = bias is not None @@ -220,13 +220,13 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, grad_config -def int8_mixed_precision_training(config: Int8MixedPrecisionConfig = _DEFAULT_CONFIG): +def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` # when we have this out of prototype (or there are stable trainable tensor subclasses), # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. def apply_int8_linear_weight(linear: nn.Linear): linear.weight = nn.Parameter( - Int8MixedPrecisionLinearWeight(linear.weight.detach(), config), + Int8MixedPrecisionTrainingLinearWeight(linear.weight.detach(), config), requires_grad=linear.weight.requires_grad, ) return linear From 986c590dfc0ca0bf21e02468c20926468bd140ee Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 13:04:12 +0800 Subject: [PATCH 36/45] update readme --- .../prototype/quantized_training/README.md | 35 +++++++++++-------- .../prototype/quantized_training/int8_mm.py | 3 +- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 456dbf30e1..c6a18e2f75 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -1,20 +1,32 @@ # Quantized training -This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. We take inspirations from: +This folder contains experimental work on quantized training (QT), with a focus on INT8. We take inspirations from: + +- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] +- SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)] - Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)] - JetFire: [[paper](https://arxiv.org/abs/2403.12422)] [[code](https://github.com/thu-ml/Jetfire-INT8Training)] -Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. +The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. However, terminologies for INT8 training are generally not standardized yet. To be precise, we use these terms with the following meaning: -In precise terms, the probability of rounding up is `x - ⌊x⌋`. Note that when the value is exactly an integer value, the probability of rounding up is zero. +- **Quantized training**: model weights are quantized. This is a strict requirement. Does not matter what is the compute precision. Examples of this: Q-GaLore, JetFire. +- **INT8 mixed-precision training**: model weights are in original precision, while compute dtype for some or all ops is in INT8. We call it like this because it is similar to FP16/BF16 mixed-precision training. One difference is that in FP16/BF16 mixed-precision training, matmul will return FP16/BF16 outputs, while for INT8 mixed-precision training, the returned dtype is usually not INT8. Examples include Google AQT and SwitchBack. -There are 2 main benefits for training in this way: -1. Reduce memory footprint. Also reduce communication bandwidth in distributed setting. -2. What you train is what you serve ([WYTIWYS](https://github.com/google/aqt?tab=readme-ov-file#features)). +There are 3 main benefits of using low-precision dtype for training (the extent depends on the actual strategies): -Currently we only support weight-only channel-wise INT8 symmetric quantization. +- **Memory**: reduce memory footprint by model weights, activations, gradients, and distributed communication bandwidth. +- **Speed**: speedup compute-bound ops with low-precision hardware instructions (e.g. INT8 Tensor Cores) and speedup memory-bound ops with quantized inputs/outputs. +- [What you train is what you serve](https://github.com/google/aqt?tab=readme-ov-file#features). + +[`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training on single GPU for strategies implemented in this folder. -## INT8 weight only +## INT8 quantized training + +Typically, quantized weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. + +In precise terms, the probability of rounding up is `x - ⌊x⌋`. Note that when the value is exactly an integer value, the probability of rounding up is zero. + +Currently we only support weight-only channel-wise INT8 symmetric quantization. In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization `[-127, 127]`. In the forward and backward pass, the weights are upcast to activations' dtype (e.g. BF16). Therefore, their gradients are also in activations' dtype. @@ -33,8 +45,6 @@ optim = _AdamW(model.parameters(), lr=3e-4) Only `torch.optim.Adam` and optimizers from `torchao.prototype.low_bit_optim` are known to work with quantized training in this folder. This is because we implement stochastic rounding logic within tensor subclass instead of the optimizer. We provide `torchao.prototype.low_bit_optim._AdamW` as an alternative to `torch.optim.AdamW` specifically for this purpose. -[`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training using this INT8 quantized training. - See [#644](https://github.com/pytorch/ao/pull/644) for some early results. TODO: investigate suboptimal memory saving when `torch.compile()` is used. Might be due to transposed weight. Benchamark for Llama2-1B, bs=4, seq_len=2048, activation checkpointing, 4070Ti SUPER. @@ -48,10 +58,7 @@ INT8 QT compile | 9.84 | 8700 ## INT8 mixed-precision -On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. This is inspired by prior works: - -- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)] -- SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)] +On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. Usage diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 8a345edcc4..6dfac79b9c 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -122,7 +122,8 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi assert A.shape[1] == B.shape[0] assert A_scale_rowwise.squeeze().shape == (A.shape[0],) assert B_scale_colwise.squeeze().shape == (B.shape[1],) - # TODO: handle this inside triton kernel + # TODO: (low priority) investigate if handling strided scales inside triton kernel or + # simply calling .contiguous() like here (which hopefully is fused with prior ops) is faster. A_scale_rowwise = A_scale_rowwise.contiguous() B_scale_colwise = B_scale_colwise.contiguous() return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) From 35df4471d756e0d6169cb962682005c3b43b4322 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 15:41:50 +0800 Subject: [PATCH 37/45] some change --- .../int8_mixed_precision.py | 68 ++++++++++--------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index e56b634481..11ad3c417c 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -27,21 +27,7 @@ class Int8MixedPrecisionTrainingConfig(NamedTuple): _DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() -# adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor aten = torch.ops.aten -_ops_to_preserve_subclass = { - aten.detach.default, - aten.empty_like.default, - aten.new_zeros.default, - aten.slice.Tensor, - aten.copy_.default, - aten.view.default, - aten.as_strided.default, - aten._to_copy.default, - aten._pin_memory.default, - aten.split.Tensor, - aten.clone.default, -} class Int8MixedPrecisionTrainingLinearWeight(Tensor): @@ -51,11 +37,10 @@ def __new__(cls, data: Tensor, config: Int8MixedPrecisionTrainingConfig): return Tensor._make_wrapper_subclass( cls, data.shape, + data.stride(), + data.storage_offset(), dtype=data.dtype, device=data.device, - strides=data.stride(), - storage_offset=data.storage_offset(), - # pin_memory=data.is_pinned(), # this will fail compile on 2.4 ) @torch._dynamo.disable @@ -85,11 +70,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): act = args[0] weight: cls = args[1] bias = args[2] if len(args) > 2 else None - return _Int8MixedPrecisionLinear.apply(act, weight._data, bias, weight.config) + return _Int8MixedPrecisionTrainingLinear.apply(act, weight._data, bias, weight.config) with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) + # adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): config = None @@ -104,10 +90,28 @@ def unwrap(x: cls): args, kwargs = pytree.tree_map_only(cls, unwrap, (args, kwargs)) out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - else: + + if func in { + aten.copy_.default, + aten.add_.Tensor, + }: + return args[0] + elif func in { + aten.t.default, + aten.detach.default, + aten.empty_like.default, + aten.new_zeros.default, + aten.slice.Tensor, + aten.view.default, + aten.as_strided.default, + aten._to_copy.default, + aten._pin_memory.default, + aten.split.Tensor, + aten.clone.default, + }: return pytree.tree_map_only(Tensor, lambda x: cls(x, config), out) + else: + return out def fsdp_pre_all_gather(self, mesh): # TODO: pre-quantize weight here -> reduce comm bandwidth. @@ -138,7 +142,7 @@ def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> self.config = config def forward(self, input: Tensor) -> Tensor: - return _Int8MixedPrecisionLinear.apply(input, self.weight, self.bias, self.config) + return _Int8MixedPrecisionTrainingLinear.apply(input, self.weight, self.bias, self.config) def extra_repr(self): return f"{super().extra_repr()}, config={self.config}" @@ -157,8 +161,7 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: # INT8 matmul is the most performant when A is row-major and B is column-major. # thus, we transpose B before quantization. - # A and B (before quantization) might not be contiguous, especially in backward. - # it's also not guaranteed that A_i8 and B_t_i8 (after quantization) are contiguous, + # it's not guaranteed that A_i8 and B_t_i8 (after quantization) are contiguous, # thus we have to call .contiguous() on them. # hope that the .contiguous() calls will be fused into quantize op by torch.compile() # TODO: investigate if calling .contiguous() before quantization is better. @@ -173,14 +176,10 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: ) -class _Int8MixedPrecisionLinear(torch.autograd.Function): +class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8MixedPrecisionTrainingConfig): - ctx.config = config - ctx.save_for_backward(input, weight) - ctx.bias = bias is not None - - if ctx.config.output: + def forward(input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8MixedPrecisionTrainingConfig): + if config.output: batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) out = _dynamic_int8_mm(input, weight.T) @@ -191,6 +190,13 @@ def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], config: out = out + bias if bias is not None else out return out + @staticmethod + def setup_context(ctx, inputs, output): + input, weight, bias, config = inputs + ctx.config = config + ctx.save_for_backward(input, weight) + ctx.bias = bias is not None + @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors From 7164551869d98e72a71b4ed0ac8705eaa40eafbc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 17:29:31 +0800 Subject: [PATCH 38/45] fix big bug --- test/prototype/test_quantized_training.py | 4 +- .../prototype/quantized_training/__init__.py | 1 - .../int8_mixed_precision.py | 64 ++++++------------- .../prototype/quantized_training/int8_mm.py | 6 +- 4 files changed, 25 insertions(+), 50 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index da5a45ef24..9fad43e974 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -174,13 +174,15 @@ def test_int8_mixed_precision_training(self, compile): loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels) rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item()) - assert rel_error < 3e-2, (i, rel_error) + assert rel_error < 3e-3, (i, rel_error) loss_ref.backward() optim_ref.step() optim_ref.zero_grad() loss_int8mp.backward() + for p in model_int8mp.parameters(): + assert p.grad is not None optim_int8mp.step() optim_int8mp.zero_grad() diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index 9f697e1362..ccf2f5375d 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -5,7 +5,6 @@ ) from .int8_mixed_precision import ( Int8MixedPrecisionTrainingConfig, - Int8MixedPrecisionLinear, Int8MixedPrecisionTrainingLinearWeight, int8_mixed_precision_training, ) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 11ad3c417c..04aab02805 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -67,10 +67,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = dict() if func is torch.nn.functional.linear: - act = args[0] - weight: cls = args[1] - bias = args[2] if len(args) > 2 else None - return _Int8MixedPrecisionTrainingLinear.apply(act, weight._data, bias, weight.config) + return _Int8MixedPrecisionTrainingLinear.apply(*args) with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -88,13 +85,13 @@ def unwrap(x: cls): assert x.config == config return x._data - args, kwargs = pytree.tree_map_only(cls, unwrap, (args, kwargs)) - out = func(*args, **kwargs) + out = func( + *pytree.tree_map_only(cls, unwrap, args), + **pytree.tree_map_only(cls, unwrap, kwargs), + ) - if func in { - aten.copy_.default, - aten.add_.Tensor, - }: + if func is aten.copy_.default: + # return original object return args[0] elif func in { aten.t.default, @@ -109,8 +106,10 @@ def unwrap(x: cls): aten.split.Tensor, aten.clone.default, }: + # return new wrapped object return pytree.tree_map_only(Tensor, lambda x: cls(x, config), out) else: + # return new unwrapped object return out def fsdp_pre_all_gather(self, mesh): @@ -135,54 +134,31 @@ def fsdp_post_all_gather( return Int8MixedPrecisionTrainingLinearWeight(data.to(param_dtype), config), all_gather_outputs -# alternative UX. to be deleted -class Int8MixedPrecisionLinear(nn.Linear): - def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.config = config - - def forward(self, input: Tensor) -> Tensor: - return _Int8MixedPrecisionTrainingLinear.apply(input, self.weight, self.bias, self.config) - - def extra_repr(self): - return f"{super().extra_repr()}, config={self.config}" - - @classmethod - def convert_linear(cls, module: nn.Module, config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): - if module.__class__ is nn.Linear: # exact match, don't swap nn.Linear subclasses - module.__class__ = cls - module.config = config - return - for child in module.children(): - cls.convert_linear(child, config) - - def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: # INT8 matmul is the most performant when A is row-major and B is column-major. # thus, we transpose B before quantization. - # it's not guaranteed that A_i8 and B_t_i8 (after quantization) are contiguous, + # it's not guaranteed that outputs after quantization are contiguous, # thus we have to call .contiguous() on them. # hope that the .contiguous() calls will be fused into quantize op by torch.compile() - # TODO: investigate if calling .contiguous() before quantization is better. # TODO: check if transpose+quantize are fused. A_i8, A_scale_rowwise = quantize_int8_rowwise(A) B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) return int8_mm_dequant( A_i8.contiguous(), B_t_i8.contiguous().T, - A_scale_rowwise, - B_scale_colwise, + A_scale_rowwise.contiguous(), + B_scale_colwise.contiguous(), ) class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function): @staticmethod - def forward(input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8MixedPrecisionTrainingConfig): - if config.output: + def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]): + if weight.config.output: batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) - out = _dynamic_int8_mm(input, weight.T) + out = _dynamic_int8_mm(input, weight._data.T) out = out.view(*batch_dims, weight.shape[0]) else: out = input @ weight.T @@ -192,15 +168,15 @@ def forward(input: Tensor, weight: Tensor, bias: Optional[Tensor], config: Int8M @staticmethod def setup_context(ctx, inputs, output): - input, weight, bias, config = inputs - ctx.config = config - ctx.save_for_backward(input, weight) + input, weight, bias = inputs + ctx.config = weight.config + ctx.save_for_backward(input, weight._data) ctx.bias = bias is not None @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors - grad_input = grad_weight = grad_bias = grad_config = None + grad_input = grad_weight = grad_bias = None batch_dims = grad_output.shape[:-1] grad_output = grad_output.view(-1, weight.shape[0]) @@ -223,7 +199,7 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[2] and ctx.bias: grad_bias = grad_output.sum(0) - return grad_input, grad_weight, grad_bias, grad_config + return grad_input, grad_weight, grad_bias def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 6dfac79b9c..b316e82208 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -122,10 +122,8 @@ def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwi assert A.shape[1] == B.shape[0] assert A_scale_rowwise.squeeze().shape == (A.shape[0],) assert B_scale_colwise.squeeze().shape == (B.shape[1],) - # TODO: (low priority) investigate if handling strided scales inside triton kernel or - # simply calling .contiguous() like here (which hopefully is fused with prior ops) is faster. - A_scale_rowwise = A_scale_rowwise.contiguous() - B_scale_colwise = B_scale_colwise.contiguous() + assert A_scale_rowwise.is_contiguous() + assert B_scale_colwise.is_contiguous() return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) From b14ab6dd5941e963a473f43bc093bad57e89f64a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 17:54:11 +0800 Subject: [PATCH 39/45] add docstring. update _get_linear_inserter --- .../prototype/quantized_training/README.md | 1 + torchao/prototype/quantized_training/int8.py | 15 ++---- .../int8_mixed_precision.py | 47 ++++++++++++------- torchao/quantization/quant_api.py | 5 +- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index c6a18e2f75..a0a659f634 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -93,6 +93,7 @@ You can configure which matmul to be applied with INT8 mixed-precision (shown ab Note: - When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT for INT8 dynamic activations + INT8 weight quantization (A8W8). - When we only apply INT8 mixed-precision to `output` and `grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling. +- Apply stochastic rounding to INT8 quantization may improve matmul accuracy. However, from our testing, this seems to be unnecessary, thus we don't implement it at the moment. Pre-train Llama2-1B on C4 realnewslike subset. bs=32, seq_len=2048 -> 65k tok/batch. Train for 20k steps (1.3B tokens). Using 4090. INT8 mixed precision is not applied to LM head. diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 36f8dc03ab..82b3ee1d9e 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -1,10 +1,11 @@ from typing import Any, Optional, Tuple import torch -from torch import Tensor, nn +from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import TorchAOBaseTensor +from torchao.quantization.quant_api import _get_linear_subclass_inserter aten = torch.ops.aten @@ -255,14 +256,4 @@ def _(func, types, args, kwargs): def int8_weight_only_quantized_training(): - # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` - # when we have this out of prototype (or there are stable trainable tensor subclasses), - # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. - def apply_int8_linear_weight(linear: nn.Linear): - linear.weight = nn.Parameter( - Int8QuantizedTrainingLinearWeight.from_float(linear.weight), - requires_grad=linear.weight.requires_grad, - ) - return linear - - return apply_int8_linear_weight + return _get_linear_subclass_inserter(Int8QuantizedTrainingLinearWeight.from_float, allow_requires_grad=True) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 04aab02805..5210ff754f 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -2,9 +2,11 @@ import torch import torch.utils._pytree as pytree -from torch import Tensor, nn +from torch import Tensor from torch.utils._triton import has_triton +from torchao.quantization.quant_api import _get_linear_subclass_inserter + from .int8 import quantize_int8_rowwise if has_triton(): @@ -31,6 +33,11 @@ class Int8MixedPrecisionTrainingConfig(NamedTuple): class Int8MixedPrecisionTrainingLinearWeight(Tensor): + """Linear weight for INT8 mixed-precision training. The weight is in original precision (e.g. FP32 or BF16). + During training, weight and activation are dynamically quantized and cast to INT8 to utilize INT8 Tensor Cores, + and then scaled back to original precision. This is also applied to backward pass. + """ + @staticmethod @torch._dynamo.disable def __new__(cls, data: Tensor, config: Int8MixedPrecisionTrainingConfig): @@ -135,13 +142,23 @@ def fsdp_post_all_gather( def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: - # INT8 matmul is the most performant when A is row-major and B is column-major. - # thus, we transpose B before quantization. + """Dynamically quantize A and B to perform INT8 matmul, then scale the results back to original precision. + To fuse scaling to matmul output, we use row-wise scaling for A and column-wise scaling for B. + + We transpose B before quantization for 2 reasons: + - INT8 matmul is the most performant when A is row-major and B is column-major. + - Row-wise scaling for B.T is column-wise scaling for B -> we only need to implement row-wise scaling. + + Note that inputs and outputs of `quantize_int8_rowwise()` are not guaranteed to be contiguous. We call + `.contiguous()` to outputs of the quantize op to make sure: + - Performant layout for INT8 matmul inputs (see above). + - Scales are contiguous (this is a limitation of our triton kernel). - # it's not guaranteed that outputs after quantization are contiguous, - # thus we have to call .contiguous() on them. - # hope that the .contiguous() calls will be fused into quantize op by torch.compile() - # TODO: check if transpose+quantize are fused. + We hope that the `.contiguous()` calls, as well as possible layout transpose before quantization, are + fused into quantize op by torch compiler. + + TODO: check if transpose+quantize are actually fused. + """ A_i8, A_scale_rowwise = quantize_int8_rowwise(A) B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) return int8_mm_dequant( @@ -203,14 +220,8 @@ def backward(ctx, grad_output): def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): - # TODO: right now `_get_linear_subclass_inserter()` will always set `requires_grad=False` - # when we have this out of prototype (or there are stable trainable tensor subclasses), - # update `_get_linear_subclass_inserter()` to allow `requires_grad=True`. - def apply_int8_linear_weight(linear: nn.Linear): - linear.weight = nn.Parameter( - Int8MixedPrecisionTrainingLinearWeight(linear.weight.detach(), config), - requires_grad=linear.weight.requires_grad, - ) - return linear - - return apply_int8_linear_weight + return _get_linear_subclass_inserter( + Int8MixedPrecisionTrainingLinearWeight, + config=config, + allow_requires_grad=True, + ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 89bccf1264..d562637342 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -295,12 +295,13 @@ def _quantization_type(weight: torch.Tensor): def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, **kwargs): +def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs): """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs) to the weight of linear module """ def insert_subclass(lin): - lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=False) + requires_grad = allow_requires_grad and lin.weight.requires_grad + lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad) lin.extra_repr = types.MethodType(_linear_extra_repr, lin) return lin From dbbc90f51892e4cde588ce0bbe44ce46a60d7460 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 6 Sep 2024 22:20:49 +0800 Subject: [PATCH 40/45] add TorchAOBaseTensor back --- .../int8_mixed_precision.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 5210ff754f..059efec40c 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -6,6 +6,7 @@ from torch.utils._triton import has_triton from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.utils import TorchAOBaseTensor from .int8 import quantize_int8_rowwise @@ -32,7 +33,7 @@ class Int8MixedPrecisionTrainingConfig(NamedTuple): aten = torch.ops.aten -class Int8MixedPrecisionTrainingLinearWeight(Tensor): +class Int8MixedPrecisionTrainingLinearWeight(TorchAOBaseTensor): """Linear weight for INT8 mixed-precision training. The weight is in original precision (e.g. FP32 or BF16). During training, weight and activation are dynamically quantized and cast to INT8 to utilize INT8 Tensor Cores, and then scaled back to original precision. This is also applied to backward pass. @@ -68,17 +69,6 @@ def __repr__(self): def to_original(self): return self._data.clone() - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = dict() - - if func is torch.nn.functional.linear: - return _Int8MixedPrecisionTrainingLinear.apply(*args) - - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - # adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -141,6 +131,11 @@ def fsdp_post_all_gather( return Int8MixedPrecisionTrainingLinearWeight(data.to(param_dtype), config), all_gather_outputs +@Int8MixedPrecisionTrainingLinearWeight.implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs) + + def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: """Dynamically quantize A and B to perform INT8 matmul, then scale the results back to original precision. To fuse scaling to matmul output, we use row-wise scaling for A and column-wise scaling for B. From 8d918f18bac80c2f6f64b9ed55882c9c391d55e6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 7 Sep 2024 09:33:27 +0800 Subject: [PATCH 41/45] fix FSDP --- test/prototype/test_quantized_training.py | 109 +++++++++++++----- .../prototype/quantized_training/README.md | 28 ++++- .../int8_mixed_precision.py | 14 ++- 3 files changed, 118 insertions(+), 33 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 9fad43e974..0b1882f099 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -1,25 +1,30 @@ +import pytest + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 + +if not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("Requires torch>=2.4", allow_module_level=True) + import copy -import pytest import torch +import torch.distributed as dist import torch.nn.functional as F from torch import nn +from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer from torchao.prototype.low_bit_optim import _AdamW from torchao.prototype.quantized_training import ( - int8_weight_only_quantized_training, + Int8MixedPrecisionTrainingConfig, int8_mixed_precision_training, + int8_weight_only_quantized_training, quantize_int8_rowwise, ) from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Requires torch>=2.4", allow_module_level=True) - _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -187,34 +192,21 @@ def test_int8_mixed_precision_training(self, compile): optim_int8mp.zero_grad() +_FSDP_WORLD_SIZE = 2 + + class TestFSDP2(FSDPTest): @property def world_size(self) -> int: - return 2 + return _FSDP_WORLD_SIZE - @skip_if_lt_x_gpu(2) - def test_fsdp2(self): + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + def test_fsdp2_correctness(self): # due to stochastic rounding, use a pretty large tolerance here - self.run_subtests( - dict(), - self._test_fsdp2, - quantize_fn=int8_weight_only_quantized_training(), - tolerance=0.05, - ) - - self.run_subtests( - dict(), - self._test_fsdp2, - quantize_fn=int8_mixed_precision_training(), - tolerance=1e-6, - ) + self._test_fsdp2(int8_weight_only_quantized_training(), tolerance=0.05) + self._test_fsdp2(int8_mixed_precision_training(), tolerance=1e-6) def _test_fsdp2(self, quantize_fn, tolerance): - import torch.distributed as dist - from torch.distributed._composable.fsdp import fully_shard - from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer - - _reset() batch_size = 3 vocab_size = 32 seq_len = 64 @@ -246,19 +238,76 @@ def _test_fsdp2(self, quantize_fn, tolerance): fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) fsdp_loss = fsdp_model(inp).sum() fsdp_loss.backward() + for param in fsdp_model.parameters(): + assert param.grad is not None fsdp_optim.step() base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) base_loss = base_model(inp).sum() base_loss.backward() for param in base_model.parameters(): - if param.grad is not None: - dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) + assert param.grad is not None + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) base_optim.step() rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() assert rel_error < tolerance, (iter_idx, rel_error) + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + def test_int8_mixed_precision_fsdp2_mixed_precision(self): + batch_size = 3 + vocab_size = 32 + seq_len = 64 + tolerance = 1e-6 + + # NOTE: if weight_tying=True and we also quantize LM head, INT8 mixed-precision will fail. + model_args = ModelArgs( + n_layers=2, + n_heads=2, + dim=128, + vocab_size=vocab_size, + max_seq_len=seq_len, + dropout_p=0, + ) + torch.manual_seed(42) + base_model = Transformer(model_args).cuda() + mp_model = copy.deepcopy(base_model) + + quantize_(base_model.layers, int8_mixed_precision_training(), set_inductor_config=False) + for layer in base_model.layers: + fully_shard(layer) + fully_shard(base_model) + + mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype) + quantize_(mp_model.layers, int8_mixed_precision_training(config), set_inductor_config=False) + for layer in mp_model.layers: + fully_shard(layer) + fully_shard(mp_model) + + base_optim = torch.optim.AdamW(base_model.parameters(), lr=1e-2) + mp_optim = torch.optim.AdamW(mp_model.parameters(), lr=1e-2) + + torch.manual_seed(42 + self.rank + 1) + for iter_idx in range(5): + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + mp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + mp_loss = mp_model(inp).sum() + mp_loss.backward() + for param in mp_model.parameters(): + assert param.grad is not None + mp_optim.step() + + base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + base_loss = base_model(inp).sum() + base_loss.backward() + for param in base_model.parameters(): + assert param.grad is not None + base_optim.step() + + rel_error = (mp_loss - base_loss).abs() / base_loss.abs() + assert rel_error < tolerance, (iter_idx, rel_error) + instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index a0a659f634..1dde72598c 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -60,7 +60,7 @@ INT8 QT compile | 9.84 | 8700 On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. -Usage +### Basic usage ```python from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig @@ -104,6 +104,32 @@ INT8 mixed-precision | ~29k | 19.47 | 2.90 See [#748](https://github.com/pytorch/ao/pull/748) for more results. +### FSDP support + +Out of the box, this INT8 mixed-precision training is not compatible with FSDP2 `MixedPrecisionPolicy(param_dtype=param_dtype)`, where `param_dtype` != model dtype. As a workaround, you will need to manually specify the FSDP2's `param_dtype` in `Int8MixedPrecisionTrainingConfig` + +```python +from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig +from torchao.quantization import quantize_ + +model = ... # FP32 model + +# setup configs +mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) +int8mp_config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype) + +# exclude LM head +quantize_(model.layers, int8_mixed_precision_training(int8mp_config)) + +# shard the model w/ FSDP2 +for layer in model.layers: + fully_shard(layer, mp_policy=mp_policy) +fully_shard(model, mp_policy=mp_policy) + +# train model as usual +``` + ## Future ideas - Tile-wise INT8 quantization to keep quantized weight for both forward and backward pass (similar to JetFire). diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 059efec40c..59d052d790 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -26,6 +26,10 @@ class Int8MixedPrecisionTrainingConfig(NamedTuple): grad_input: bool = True grad_weight: bool = True + # workaround for FSDP2 with `MixedPrecisionPolicy(param_dtype)` + # see `Int8MixedPrecisionTrainingLinearWeight.fsdp_pre_all_gather()` for more details. + fsdp_param_dtype: Optional[torch.dtype] = None + _DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() @@ -112,7 +116,13 @@ def unwrap(x: cls): def fsdp_pre_all_gather(self, mesh): # TODO: pre-quantize weight here -> reduce comm bandwidth. # we will need another tensor subclass to hold the quantized weight. - return (self._data,), (self.config,) + + # doing dtype casting to `param_dtype` in `fsdp_post_all_gather()` will give wrong results. + # as a workaround, we do it in `fsdp_pre_all_gather()` instead. since `param_dtype` is not + # exposed to `fsdp_pre_all_gather()`, we need to specify it in the config. + # this workaround can be removed once we implement INT8 communication. + data = self._data.to(dtype=self.config.fsdp_param_dtype) + return (data,), (self.config,) def fsdp_post_all_gather( self, @@ -128,7 +138,7 @@ def fsdp_post_all_gather( assert isinstance(out, Int8MixedPrecisionTrainingLinearWeight) assert out.config == config return - return Int8MixedPrecisionTrainingLinearWeight(data.to(param_dtype), config), all_gather_outputs + return Int8MixedPrecisionTrainingLinearWeight(data, config), all_gather_outputs @Int8MixedPrecisionTrainingLinearWeight.implements(torch.nn.functional.linear) From d67a9336cec8f125d9ce9645c5fd5ea7dc94d76c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 7 Sep 2024 17:00:44 +0800 Subject: [PATCH 42/45] update FSDP test. add autocast support --- test/prototype/test_quantized_training.py | 112 ++++++++---------- torchao/prototype/quantized_training/int8.py | 2 +- .../int8_mixed_precision.py | 3 + 3 files changed, 51 insertions(+), 66 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 0b1882f099..bcc92e07eb 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -149,8 +149,17 @@ def test_int8_weight_only_training(self, compile, device): optim_int8.zero_grad() @parametrize("compile", [False, True]) + @parametrize( + "config", + [ + Int8MixedPrecisionTrainingConfig(), + Int8MixedPrecisionTrainingConfig(output=False), + Int8MixedPrecisionTrainingConfig(grad_input=False), + Int8MixedPrecisionTrainingConfig(grad_weight=False), + ], + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_int8_mixed_precision_training(self, compile): + def test_int8_mixed_precision_training(self, compile, config): _reset() bsize = 4 embed_dim = 32 @@ -163,7 +172,7 @@ def test_int8_mixed_precision_training(self, compile): nn.Linear(embed_dim, embed_dim), ).to(device) model_int8mp = copy.deepcopy(model_ref) - quantize_(model_int8mp, int8_mixed_precision_training(), set_inductor_config=False) + quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) if compile: model_ref.compile() @@ -202,11 +211,36 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) def test_fsdp2_correctness(self): - # due to stochastic rounding, use a pretty large tolerance here - self._test_fsdp2(int8_weight_only_quantized_training(), tolerance=0.05) - self._test_fsdp2(int8_mixed_precision_training(), tolerance=1e-6) + test_args = [ + ( + int8_weight_only_quantized_training(), # quantize_fn for base model + int8_weight_only_quantized_training(), # quantize_fn for FSDP model + MixedPrecisionPolicy(), + 0.05, # tolerance. due to stochastic rounding, use a pretty large tolerance here + ), + ( + int8_mixed_precision_training(), + int8_mixed_precision_training(), + MixedPrecisionPolicy(), + 1e-6, + ), + ( + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). + # To keep it simple, we just use a larger tolerance here. + int8_mixed_precision_training(), + int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)), + MixedPrecisionPolicy(param_dtype=torch.bfloat16), + 1e-2, + ), + ] + self.run_subtests({"args": test_args}, self._run_subtest) + + def _run_subtest(self, args): + base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args - def _test_fsdp2(self, quantize_fn, tolerance): batch_size = 3 vocab_size = 32 seq_len = 64 @@ -222,18 +256,21 @@ def _test_fsdp2(self, quantize_fn, tolerance): ) torch.manual_seed(42) base_model = Transformer(model_args).cuda() - quantize_(base_model.layers, quantize_fn, set_inductor_config=False) fsdp_model = copy.deepcopy(base_model) + quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False) + quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False) + for layer in fsdp_model.layers: - fully_shard(layer) - fully_shard(fsdp_model) + fully_shard(layer, mp_policy=mp_policy) + fully_shard(fsdp_model, mp_policy=mp_policy) + # start testing base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) torch.manual_seed(42 + self.rank + 1) - for iter_idx in range(5): + for iter_idx in range(10): inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) fsdp_loss = fsdp_model(inp).sum() @@ -253,61 +290,6 @@ def _test_fsdp2(self, quantize_fn, tolerance): rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() assert rel_error < tolerance, (iter_idx, rel_error) - @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) - def test_int8_mixed_precision_fsdp2_mixed_precision(self): - batch_size = 3 - vocab_size = 32 - seq_len = 64 - tolerance = 1e-6 - - # NOTE: if weight_tying=True and we also quantize LM head, INT8 mixed-precision will fail. - model_args = ModelArgs( - n_layers=2, - n_heads=2, - dim=128, - vocab_size=vocab_size, - max_seq_len=seq_len, - dropout_p=0, - ) - torch.manual_seed(42) - base_model = Transformer(model_args).cuda() - mp_model = copy.deepcopy(base_model) - - quantize_(base_model.layers, int8_mixed_precision_training(), set_inductor_config=False) - for layer in base_model.layers: - fully_shard(layer) - fully_shard(base_model) - - mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) - config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype) - quantize_(mp_model.layers, int8_mixed_precision_training(config), set_inductor_config=False) - for layer in mp_model.layers: - fully_shard(layer) - fully_shard(mp_model) - - base_optim = torch.optim.AdamW(base_model.parameters(), lr=1e-2) - mp_optim = torch.optim.AdamW(mp_model.parameters(), lr=1e-2) - - torch.manual_seed(42 + self.rank + 1) - for iter_idx in range(5): - inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") - mp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) - mp_loss = mp_model(inp).sum() - mp_loss.backward() - for param in mp_model.parameters(): - assert param.grad is not None - mp_optim.step() - - base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) - base_loss = base_model(inp).sum() - base_loss.backward() - for param in base_model.parameters(): - assert param.grad is not None - base_optim.step() - - rel_error = (mp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < tolerance, (iter_idx, rel_error) - instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 82b3ee1d9e..828655f04c 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -111,7 +111,7 @@ def fsdp_post_all_gather( out: Optional[Tensor] = None, ): int_data, scale = all_gather_outputs - return Int8QuantizedTrainingLinearWeight(int_data, scale.to(param_dtype)), all_gather_outputs + return Int8QuantizedTrainingLinearWeight(int_data, scale), all_gather_outputs class _Int8WeightOnlyLinear(torch.autograd.Function): diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 59d052d790..960c1858f9 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -143,6 +143,9 @@ def fsdp_post_all_gather( @Int8MixedPrecisionTrainingLinearWeight.implements(torch.nn.functional.linear) def _(func, types, args, kwargs): + if torch.is_autocast_enabled("cuda"): + dtype = torch.get_autocast_gpu_dtype() + args = tuple(x.to(dtype) if x is not None else x for x in args) return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs) From 6122aaae2702fa12a02896a41869c64868e28e6e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 9 Sep 2024 10:41:46 +0800 Subject: [PATCH 43/45] reduce iter --- test/prototype/test_quantized_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index bcc92e07eb..bffff16fc1 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -270,7 +270,7 @@ def _run_subtest(self, args): fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False) torch.manual_seed(42 + self.rank + 1) - for iter_idx in range(10): + for iter_idx in range(5): inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) fsdp_loss = fsdp_model(inp).sum() From 0d65b26eb7634fb013c2d2b3b93d713130d3e6d4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 9 Sep 2024 12:03:14 +0800 Subject: [PATCH 44/45] update int8_mm fallback --- .../prototype/quantized_training/int8_mixed_precision.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 960c1858f9..0abbdaee0a 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -15,10 +15,11 @@ else: + # This is less performant than the explicit hand-written Triton kernel, though things might + # change in the future. + # Multiplying B_scale first is faster than the other way round. def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: - A_scaled = A * A_scale_rowwise.view(-1, 1) - B_scaled = B * B_scale_colwise.view(1, -1) - return A_scaled @ B_scaled + return torch._int_mm(A, B) * B_scale_colwise * A_scale_rowwise.view(-1, 1) class Int8MixedPrecisionTrainingConfig(NamedTuple): From 6082d308ebea1274a318faeed9bc28e7bc6ea3fa Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 9 Sep 2024 14:35:44 +0800 Subject: [PATCH 45/45] put leading dims logic to _dynamic_int8_mm --- .../int8_mixed_precision.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 0abbdaee0a..0f96e348ba 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -168,27 +168,25 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: TODO: check if transpose+quantize are actually fused. """ - A_i8, A_scale_rowwise = quantize_int8_rowwise(A) + # A may have more than 2 dims, while B must be exactly 2-dim + A_i8, A_scale_rowwise = quantize_int8_rowwise(A.view(-1, A.shape[-1])) B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) - return int8_mm_dequant( + out = int8_mm_dequant( A_i8.contiguous(), B_t_i8.contiguous().T, A_scale_rowwise.contiguous(), B_scale_colwise.contiguous(), ) + return out.view(*A.shape[:-1], out.shape[-1]) class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function): @staticmethod def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]): if weight.config.output: - batch_dims = input.shape[:-1] - input = input.view(-1, weight.shape[1]) out = _dynamic_int8_mm(input, weight._data.T) - out = out.view(*batch_dims, weight.shape[0]) else: - out = input @ weight.T - + out = input @ weight._data.T out = out + bias if bias is not None else out return out @@ -204,18 +202,15 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - batch_dims = grad_output.shape[:-1] - grad_output = grad_output.view(-1, weight.shape[0]) - input = input.view(-1, weight.shape[1]) - if ctx.needs_input_grad[0]: if ctx.config.grad_input: grad_input = _dynamic_int8_mm(grad_output, weight) else: grad_input = grad_output @ weight - grad_input = grad_input.view(*batch_dims, weight.shape[1]) if ctx.needs_input_grad[1]: + grad_output = grad_output.view(-1, weight.shape[0]) + input = input.view(-1, weight.shape[1]) if ctx.config.grad_weight: # grad_weight = _dynamic_int8_mm(grad_output.T, input) grad_weight = _dynamic_int8_mm(input.T, grad_output).T # this is slightly faster