From e28b855cc58a86ef3be0449c5e83aafe614e6bcb Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 21 Oct 2024 07:30:14 +0000 Subject: [PATCH 1/5] add BF16 sr for optimizer --- test/prototype/test_low_bit_optim.py | 57 +++++++- torchao/prototype/low_bit_optim/adam.py | 122 +++++++++++++++--- .../prototype/low_bit_optim/quant_utils.py | 25 ++++ 3 files changed, 183 insertions(+), 21 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 847ea066bc..21caca80ba 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -14,8 +14,12 @@ from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torchao.prototype import low_bit_optim -from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 +from torchao.prototype.low_bit_optim.quant_utils import ( + quantize_8bit_with_qmap, + quantize_4bit_with_qmap, + _fp32_to_bf16_sr, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 try: import bitsandbytes as bnb @@ -74,6 +78,22 @@ def test_quantize_4bit_with_qmap_compile(self, device): torch.testing.assert_close(actual, expected) + @parametrize("device", _DEVICES) + @parametrize("compile", [False, True]) + def test_bf16_stochastic_round(self, device, compile): + x = torch.rand(32, device=device) * 100 + x_rep = x.view(-1, 1).repeat(1, 100_000) + + if compile: + x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep) + else: + x_rep_bf16 = _fp32_to_bf16_sr(x_rep) + + assert x_rep_bf16.dtype is torch.bfloat16 + + # must cast BF16 tensor back to FP32 so that .mean() is accurate + torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5) + class TestOptim(TestCase): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") @@ -249,13 +269,44 @@ def test_optim_cpu_offload_save_load(self): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) + def test_optim_bf16_stochastic_round_correctness(self): + device = "cuda" + torch.manual_seed(2024) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model2 = copy.deepcopy(model1).bfloat16() + + # small LR so that weight update is small + # when bf16_stochastic_round=False, the test will fail after 1 iteration + optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5, fused=True) + optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True) + + # overfit on this sample + x = torch.randn(4, 32, device=device) + + for idx in range(5): + # mixed-precision training + with torch.autocast(device, dtype=torch.bfloat16): + loss1 = model1(x) + loss1 = loss1.sum() # under autocast context, bf16.sum() will return fp32 + loss1.backward() + optim1.step() + optim1.zero_grad() + + # full BF16 training with stochastic round weight update + loss2 = model2(x.bfloat16()).sum() + loss2.backward() + optim2.step() + optim2.zero_grad() + + torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. 
{msg}") + class TestFSDP2(FSDPTest): @property def world_size(self) -> int: return 2 - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default") + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.") @skip_if_lt_x_gpu(2) def test_fsdp2(self): optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit] diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 980b60f9a1..19e0640334 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -8,10 +8,13 @@ from .subclass_8bit import OptimState8bit from .subclass_4bit import OptimState4bit from .subclass_fp8 import OptimStateFp8 +from .quant_utils import _fp32_to_bf16_sr class _AdamBase(Optimizer): - def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, is_adamw) -> None: + def __init__( + self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw + ) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -23,6 +26,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super().__init__(params, defaults) self.block_size = block_size + self.bf16_stochastic_round = bf16_stochastic_round self.is_adamw = is_adamw def __setstate__(self, state): @@ -102,6 +106,7 @@ def step(self, closure=None): group["weight_decay"], group["eps"], self.is_adamw, + self.bf16_stochastic_round and p.dtype is torch.bfloat16, ) return loss @@ -121,14 +126,17 @@ def single_param_adam( beta2: float, weight_decay: float, eps: float, - is_adamw: bool, + IS_ADAMW: bool, + BF16_STOCHASTIC_ROUND: bool, ): # compute in FP32 for accurate calculations p_f32 = p.float() grad_f32 = grad.float() - if not is_adamw: - grad_f32 = grad_f32.add(p_f32, alpha=weight_decay) + if IS_ADAMW: + p_f32 = p_f32 - lr * weight_decay * p_f32 + else: + grad_f32 = grad_f32 + weight_decay * p_f32 bias_correction1 = 1 - beta1**step bias_correction2 = 1 - beta2**step @@ -143,16 +151,16 @@ def single_param_adam( if max_exp_avg_sq is not None: max_exp_avg_sq_f32 = torch.maximum(max_exp_avg_sq.float(), exp_avg_sq_f32) max_exp_avg_sq.copy_(max_exp_avg_sq_f32) - denom = (max_exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()).add_(eps) + denom = (max_exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()) + eps else: - denom = (exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()).add_(eps) + denom = (exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()) + eps - step_size = lr / bias_correction1 - if is_adamw: - # merge weight decay and param update in a single .add_() to make this work with quantized param - p.add_(-lr * weight_decay * p_f32 - step_size * exp_avg_f32 / denom) + p_f32 = p_f32 - lr * (exp_avg_f32 / bias_correction1) / denom + + if BF16_STOCHASTIC_ROUND: + p.copy_(_fp32_to_bf16_sr(p_f32)) else: - p.addcdiv_(exp_avg_f32, denom, value=-step_size) + p.copy_(p_f32) class Adam8bit(_AdamBase): @@ -166,8 +174,19 @@ def __init__( amsgrad=False, *, block_size=256, + bf16_stochastic_round=False, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=block_size, + 
bf16_stochastic_round=bf16_stochastic_round, + is_adamw=False, + ) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -185,8 +204,19 @@ def __init__( amsgrad=False, *, block_size=128, + bf16_stochastic_round=False, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=block_size, + bf16_stochastic_round=bf16_stochastic_round, + is_adamw=False, + ) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -204,8 +234,19 @@ def __init__( amsgrad=False, *, block_size=256, + bf16_stochastic_round=False, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=block_size, + bf16_stochastic_round=bf16_stochastic_round, + is_adamw=False, + ) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -223,8 +264,19 @@ def __init__( amsgrad=False, *, block_size=256, + bf16_stochastic_round=False, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=block_size, + bf16_stochastic_round=bf16_stochastic_round, + is_adamw=True, + ) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -242,8 +294,19 @@ def __init__( amsgrad=False, *, block_size=128, + bf16_stochastic_round=False, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=block_size, + bf16_stochastic_round=bf16_stochastic_round, + is_adamw=True, + ) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -261,8 +324,19 @@ def __init__( amsgrad=False, *, block_size=256, + bf16_stochastic_round=False, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=block_size, + bf16_stochastic_round=bf16_stochastic_round, + is_adamw=True, + ) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -278,7 +352,19 @@ def __init__( eps=1e-8, weight_decay=1e-2, amsgrad=False, + *, + bf16_stochastic_round=False, ) -> None: """AdamW optimizer that supports quantized training (parameter is quantized). 
This optimizer should only be used with torchao's quantized training.""" - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=float("inf"), is_adamw=True) + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + block_size=float("inf"), + bf16_stochastic_round=bf16_stochastic_round, + is_adamw=True, + ) diff --git a/torchao/prototype/low_bit_optim/quant_utils.py b/torchao/prototype/low_bit_optim/quant_utils.py index 0dc262ed40..556a2f290c 100644 --- a/torchao/prototype/low_bit_optim/quant_utils.py +++ b/torchao/prototype/low_bit_optim/quant_utils.py @@ -110,3 +110,28 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor): # torch.compile() cannot use uint8 as index out = qmap[codes.int()].view(scale.shape[0], -1) * scale.view(-1, 1) return out.view(codes.shape) + + +def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor: + # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16 + # - Round towards zero: [a31, ..., a16, 0, ..., 0] + # - Round away from zero: [a31, ..., a16+1, 0, ..., 0] + # (since the value can be negative, we use round towards/away from zero instead of round up/down) + # + # For stochastic rounding, we round away from zero with the probability of + # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16 + # + # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16 + rand_16bit = torch.randint(0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32) + x_f32_bits = x_f32.view(torch.int32) + x_fraction = x_f32_bits & 0xFFFF # lower 16 bits + x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits + + x_f32_bits = torch.where( + rand_16bit < x_fraction, # this is True with the probability of p_fraction + x_bf16_towards_zero + 0x10000, # this might overflow, which will result in UB due to signed integer + x_bf16_towards_zero, + ) + # alternative, slightly faster + # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000 + return x_f32_bits.view(torch.float32).bfloat16() From 218214c7f9461934af5479d0ceaaf286e94e876b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 21 Oct 2024 09:49:35 +0000 Subject: [PATCH 2/5] update doc and benchmark scripts --- benchmarks/benchmark_low_bit_adam.py | 15 +++++++++++-- .../quantized_training/pretrain_llama2.py | 9 +++++++- torchao/prototype/low_bit_optim/README.md | 22 +++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index d9f03a88bf..bd31193892 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -90,6 +90,7 @@ def get_parser(): parser.add_argument("--optim", default="AdamW", choices=OPTIM_MAP.keys()) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--weight_decay", type=float, default=0) + parser.add_argument("--optim_kwargs", type=json.loads, default=dict()) parser.add_argument("--cosine_lr_scheduler", action="store_true") parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]) @@ -206,7 +207,12 @@ def evaluate_model(model, args): train_batch_size=args.batch_size, optimizer=dict( type="Adam", - params=dict(lr=args.lr, weight_decay=args.weight_decay, fp32_optimizer_states=False), + params=dict( + lr=args.lr, + weight_decay=args.weight_decay, + fp32_optimizer_states=False, + **args.optim_kwargs, + ), ), bf16=dict(enabled=args.full_bf16), zero_optimization=dict( @@ -225,7 +231,12 @@ 
def evaluate_model(model, args): elif args.optim_cpu_offload == "ao_offload_grads": optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True) - optim = optim_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optim = optim_cls( + model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + **args.optim_kwargs, + ) lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 0085f24264..eed90fe9f6 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -11,6 +11,7 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import argparse +import json import time from functools import partial from pathlib import Path @@ -108,6 +109,7 @@ def get_tinystories(): parser.add_argument("--optim", default="AdamW") parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--optim_kwargs", type=json.loads, default=dict()) parser.add_argument("--project", default="quantized_training") parser.add_argument("--run_name") @@ -171,7 +173,12 @@ def insert_rmsnorm(module: torch.nn.Module): # only use optimizers from torchao.prototype.low_bit_optim to support quantized training if args.optim == "AdamW": args.optim = "_AdamW" - optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optim = getattr(low_bit_optim, args.optim)( + model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + **args.optim_kwargs, + ) data = get_tinystories().cuda() args.torch_version = torch.__version__ diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index ece7687801..abcf69233a 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -5,6 +5,7 @@ This folder implements: - 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861 - 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507 - FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental) +- Stochastic rounding for BF16 weight (https://arxiv.org/abs/2010.06192, experimental) The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel. Thus, your platform must support `torch.compile()` to use these optimizers. We only test on CPU and CUDA, so there might be bugs or errors on other platforms. @@ -56,6 +57,27 @@ ao 4-bit | 33.2 | 2900 | 42.27 NOTE: lpmm's 4-bit AdamW does not support BF16 weights. +## Stochastic rounding for BF16 weight + +BF16 only has around 3 decimal precision. This means that if weight update is smaller than 1e-3 of the weight magnitude, there will be no change to the weight (using nearest rounding). This is highly problematic for full BF16 training, where we don't keep an FP32 copy of model weights. + +Note that our optimizer step calculations are always done in FP32 to ensure accurate results. The "underflow" only happens when we copy the new weight value (in FP32) to the existing BF16 weight. To combat this problem, one way is to perform **stochastic rounding** when casting FP32->BF16. 
+- In stochastic rounding, we will round up with the probability of `(x - round_down(x)) / (round_up(x) - round_down(x))`, and round down otherwise. +- It follows that successive weight update with stochastic rounding will correctly approximate high-precision weight update. +- Since BF16 is simply a truncation of FP32, there is an efficient implementation for FP32->BF16 stochastic rounding (the same is not true for FP32->FP16). +- More detailed discussion can be found at https://arxiv.org/abs/2010.06192. [llm.c](https://github.com/karpathy/llm.c/blob/master/llmc/adamw.cuh#L43) also implements this approach. + +```python +# a clone of torch.optim.AdamW with extra features +from torchao.prototype.low_bit_optim import _AdamW + +model = ... +model_bf16 = model.bfloat16() +optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True) +``` + +All of our low-bit optimizers mentioned above also support `bf16_stochastic_round` flag. Note that this flag only applies to BF16 weight. + ## Optimizer CPU offload This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload. From f4220da1a4b23ea63ab17c080837731ce96d020a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 21 Oct 2024 19:27:15 +0800 Subject: [PATCH 3/5] fix device --- test/prototype/test_low_bit_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 21caca80ba..1bdd4e6b8d 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -270,7 +270,7 @@ def test_optim_cpu_offload_save_load(self): torch.testing.assert_close(p2, p1) def test_optim_bf16_stochastic_round_correctness(self): - device = "cuda" + device = "cuda" if torch.cuda.is_available() else "cpu" torch.manual_seed(2024) model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) model2 = copy.deepcopy(model1).bfloat16() From 796f285106f0fe8bc70d7ae74c0a56b42b614243 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 21 Oct 2024 19:40:20 +0800 Subject: [PATCH 4/5] remove fused=True since CPU does not support --- test/prototype/test_low_bit_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 1bdd4e6b8d..8db22ad86e 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -277,7 +277,7 @@ def test_optim_bf16_stochastic_round_correctness(self): # small LR so that weight update is small # when bf16_stochastic_round=False, the test will fail after 1 iteration - optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5, fused=True) + optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5) optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True) # overfit on this sample From 935d19868fa7358194fdbcb4f6d5ec1ab5a6adcd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 22 Oct 2024 01:11:11 +0000 Subject: [PATCH 5/5] use permalink for llm.c ref --- torchao/prototype/low_bit_optim/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index abcf69233a..bd66262609 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -65,7 +65,7 @@ Note that our optimizer step 
calculations are always done in FP32 to ensure accu - In stochastic rounding, we will round up with the probability of `(x - round_down(x)) / (round_up(x) - round_down(x))`, and round down otherwise. - It follows that successive weight update with stochastic rounding will correctly approximate high-precision weight update. - Since BF16 is simply a truncation of FP32, there is an efficient implementation for FP32->BF16 stochastic rounding (the same is not true for FP32->FP16). -- More detailed discussion can be found at https://arxiv.org/abs/2010.06192. [llm.c](https://github.com/karpathy/llm.c/blob/master/llmc/adamw.cuh#L43) also implements this approach. +- More detailed discussion can be found at https://arxiv.org/abs/2010.06192. [llm.c](https://github.com/karpathy/llm.c/blob/7ecd8906afe6ed7a2b2cdb731c042f26d525b820/llmc/adamw.cuh#L43) also implements this approach. ```python # a clone of torch.optim.AdamW with extra features
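
---

For reference, below is a minimal standalone sketch (not part of any patch above) that exercises the `_fp32_to_bf16_sr` helper introduced in PATCH 1/5 outside the test suite. It assumes a torchao build with that patch applied and runs on CPU or CUDA; the update size of `1e-3`, the weight magnitude of 100, and the iteration counts are illustrative choices rather than values taken from the patches.

```python
# Sketch only: assumes a torchao build that already contains the
# `_fp32_to_bf16_sr` helper added in PATCH 1/5.
import torch
from torchao.prototype.low_bit_optim.quant_utils import _fp32_to_bf16_sr

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Unbiasedness: averaging many stochastically rounded copies of the same FP32
#    value recovers that value, even though a single BF16 copy cannot.
x = torch.rand(16, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)
sr_mean = _fp32_to_bf16_sr(x_rep).float().mean(1)
nearest_mean = x_rep.bfloat16().float().mean(1)
print("max error, stochastic rounding:", (sr_mean - x).abs().max().item())
print("max error, nearest rounding:   ", (nearest_mean - x).abs().max().item())

# 2) Why it matters for full-BF16 training: accumulate an update far below one
#    BF16 ULP of the weight. Round-to-nearest drops it entirely; stochastic
#    rounding tracks it in expectation.
w_nearest = torch.full((10_000,), 100.0, device=device).bfloat16()
w_sr = w_nearest.clone()
update = 1e-3  # BF16 ULP at magnitude 100 is 0.5, so this is ~1/500 of an ULP
for _ in range(1000):
    w_nearest = (w_nearest.float() + update).bfloat16()  # stays at 100.0
    w_sr = _fp32_to_bf16_sr(w_sr.float() + update)       # drifts toward 101.0
print("nearest mean:   ", w_nearest.float().mean().item())  # ~100.0
print("stochastic mean:", w_sr.float().mean().item())       # ~101.0

# 3) Bit-level view for one value: the round-away-from-zero probability equals
#    the discarded lower 16 bits, interpreted as uint16, divided by 2**16.
v = torch.tensor([1.2345], device=device)
frac = (v.view(torch.int32) & 0xFFFF).item()
print("P(round away from zero) =", frac / 2**16)
```

The second part mirrors the motivation given in the README diff: at a weight magnitude of 100 the BF16 ULP is 0.5, so a 1e-3 update is always discarded by round-to-nearest, while stochastic rounding preserves it in expectation over many steps.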