
Commit a31e15d

Add BF16 stochastic rounding option for optimizers (#1124)

* add BF16 sr for optimizer
* update doc and benchmark scripts
* fix device
* remove fused=True since CPU does not support it
* use permalink for llm.c ref

Parent: 45e37b2

6 files changed (+226, -24 lines)

benchmarks/benchmark_low_bit_adam.py

Lines changed: 13 additions & 2 deletions

```diff
@@ -90,6 +90,7 @@ def get_parser():
     parser.add_argument("--optim", default="AdamW", choices=OPTIM_MAP.keys())
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--weight_decay", type=float, default=0)
+    parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
     parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])

@@ -206,7 +207,12 @@ def evaluate_model(model, args):
             train_batch_size=args.batch_size,
             optimizer=dict(
                 type="Adam",
-                params=dict(lr=args.lr, weight_decay=args.weight_decay, fp32_optimizer_states=False),
+                params=dict(
+                    lr=args.lr,
+                    weight_decay=args.weight_decay,
+                    fp32_optimizer_states=False,
+                    **args.optim_kwargs,
+                ),
             ),
             bf16=dict(enabled=args.full_bf16),
             zero_optimization=dict(
@@ -225,7 +231,12 @@ def evaluate_model(model, args):
     elif args.optim_cpu_offload == "ao_offload_grads":
         optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)

-    optim = optim_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    optim = optim_cls(
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+        **args.optim_kwargs,
+    )

     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
     grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
```
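
For reference, the new `--optim_kwargs` flag is parsed with `json.loads`, so extra optimizer constructor arguments arrive as a plain dict that is splatted into the calls above. A hypothetical invocation (not part of this commit) could enable the new stochastic-rounding option like this:

```python
# Hypothetical command line, for illustration only:
#   python benchmarks/benchmark_low_bit_adam.py --optim AdamW \
#       --optim_kwargs '{"bf16_stochastic_round": true}'
import json

# argparse applies json.loads to the raw string, producing the extra kwargs dict
optim_kwargs = json.loads('{"bf16_stochastic_round": true}')
assert optim_kwargs == {"bf16_stochastic_round": True}
```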

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -11,6 +11,7 @@
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 import argparse
+import json
 import time
 from functools import partial
 from pathlib import Path
@@ -108,6 +109,7 @@ def get_tinystories():
     parser.add_argument("--optim", default="AdamW")
     parser.add_argument("--lr", type=float, default=3e-4)
     parser.add_argument("--weight_decay", type=float, default=1e-2)
+    parser.add_argument("--optim_kwargs", type=json.loads, default=dict())

     parser.add_argument("--project", default="quantized_training")
     parser.add_argument("--run_name")
@@ -171,7 +173,12 @@ def insert_rmsnorm(module: torch.nn.Module):
     # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
     if args.optim == "AdamW":
         args.optim = "_AdamW"
-    optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    optim = getattr(low_bit_optim, args.optim)(
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+        **args.optim_kwargs,
+    )

     data = get_tinystories().cuda()
     args.torch_version = torch.__version__
```

test/prototype/test_low_bit_optim.py

Lines changed: 54 additions & 3 deletions

```diff
@@ -14,8 +14,12 @@
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
 from torchao.prototype import low_bit_optim
-from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5
+from torchao.prototype.low_bit_optim.quant_utils import (
+    quantize_8bit_with_qmap,
+    quantize_4bit_with_qmap,
+    _fp32_to_bf16_sr,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6

 try:
     import bitsandbytes as bnb
@@ -74,6 +78,22 @@ def test_quantize_4bit_with_qmap_compile(self, device):

         torch.testing.assert_close(actual, expected)

+    @parametrize("device", _DEVICES)
+    @parametrize("compile", [False, True])
+    def test_bf16_stochastic_round(self, device, compile):
+        x = torch.rand(32, device=device) * 100
+        x_rep = x.view(-1, 1).repeat(1, 100_000)
+
+        if compile:
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+        else:
+            x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
+
+        assert x_rep_bf16.dtype is torch.bfloat16
+
+        # must cast BF16 tensor back to FP32 so that .mean() is accurate
+        torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)
+

 class TestOptim(TestCase):
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@@ -249,13 +269,44 @@ def test_optim_cpu_offload_save_load(self):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)

+    def test_optim_bf16_stochastic_round_correctness(self):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        torch.manual_seed(2024)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model2 = copy.deepcopy(model1).bfloat16()
+
+        # small LR so that weight update is small
+        # when bf16_stochastic_round=False, the test will fail after 1 iteration
+        optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
+        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+
+        # overfit on this sample
+        x = torch.randn(4, 32, device=device)
+
+        for idx in range(5):
+            # mixed-precision training
+            with torch.autocast(device, dtype=torch.bfloat16):
+                loss1 = model1(x)
+            loss1 = loss1.sum()  # under autocast context, bf16.sum() will return fp32
+            loss1.backward()
+            optim1.step()
+            optim1.zero_grad()
+
+            # full BF16 training with stochastic round weight update
+            loss2 = model2(x.bfloat16()).sum()
+            loss2.backward()
+            optim2.step()
+            optim2.zero_grad()
+
+            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+

 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
         return 2

-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
```
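
A note on `test_bf16_stochastic_round` above: the `x_rep_bf16.float().mean(1)` check relies on stochastic rounding being unbiased. Rounding up with probability `p = (x - round_down(x)) / (round_up(x) - round_down(x))` gives expected value `p * round_up(x) + (1 - p) * round_down(x) = x`, so the mean over 100,000 independent roundings of each element converges back to the FP32 input within the `3e-5` tolerance.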

torchao/prototype/low_bit_optim/README.md

Lines changed: 22 additions & 0 deletions

````diff
@@ -5,6 +5,7 @@ This folder implements:
 - 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
 - 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
 - FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)
+- Stochastic rounding for BF16 weights (https://arxiv.org/abs/2010.06192, experimental)

 The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel. Thus, your platform must support `torch.compile()` to use these optimizers. We only test on CPU and CUDA, so there might be bugs or errors on other platforms.

@@ -56,6 +57,27 @@ ao 4-bit | 33.2 | 2900 | 42.27

 NOTE: lpmm's 4-bit AdamW does not support BF16 weights.

+## Stochastic rounding for BF16 weights
+
+BF16 has only around 3 decimal digits of precision. This means that if a weight update is smaller than roughly 1e-3 of the weight's magnitude, nearest rounding leaves the weight unchanged. This is highly problematic for full BF16 training, where we don't keep an FP32 copy of the model weights.
+
+Note that our optimizer step calculations are always done in FP32 to ensure accurate results. The "underflow" only happens when we copy the new weight value (in FP32) into the existing BF16 weight. To combat this problem, one option is to perform **stochastic rounding** when casting FP32->BF16.
+- In stochastic rounding, we round up with probability `(x - round_down(x)) / (round_up(x) - round_down(x))`, and round down otherwise.
+- It follows that successive weight updates with stochastic rounding correctly approximate high-precision weight updates in expectation.
+- Since BF16 is simply a truncation of FP32, there is an efficient implementation of FP32->BF16 stochastic rounding (the same is not true for FP32->FP16).
+- A more detailed discussion can be found at https://arxiv.org/abs/2010.06192. [llm.c](https://github.com/karpathy/llm.c/blob/7ecd8906afe6ed7a2b2cdb731c042f26d525b820/llmc/adamw.cuh#L43) also implements this approach.
+
+```python
+# a clone of torch.optim.AdamW with extra features
+from torchao.prototype.low_bit_optim import _AdamW
+
+model = ...
+model_bf16 = model.bfloat16()
+optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True)
+```
+
+All of the low-bit optimizers mentioned above also support the `bf16_stochastic_round` flag. Note that this flag only applies to BF16 weights.
+
 ## Optimizer CPU offload

 This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.
````
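
The "BF16 is simply a truncation of FP32" observation in the README addition admits a compact bit-level implementation: add a uniformly random 16-bit integer to the FP32 bit pattern and drop the low 16 bits, so the carry bumps the retained BF16 value up with exactly the probability given above. The sketch below illustrates that idea only; it is not the torchao `_fp32_to_bf16_sr` kernel.

```python
import torch

def bf16_stochastic_round_sketch(x: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch only -- not the torchao `_fp32_to_bf16_sr` implementation
    # (no special handling of NaN/inf here).
    assert x.dtype is torch.float32
    bits = x.view(torch.int32)                  # reinterpret the FP32 bit pattern
    noise = torch.randint_like(bits, 1 << 16)   # uniform integer in [0, 2^16)
    bits = (bits + noise) & -(1 << 16)          # carry may round up; low 16 bits are cleared
    return bits.view(torch.float32).bfloat16()  # exact cast since the low bits are zero

# A weight update far below one BF16 ulp survives on average with stochastic rounding
x = torch.full((100_000,), 1.0 + 1e-4)
print(x.bfloat16().float().mean())                     # 1.0 -- nearest rounding discards the update
print(bf16_stochastic_round_sketch(x).float().mean())  # ~1.0001 on average
```

As the README change notes, the same `bf16_stochastic_round=True` flag is also accepted by the low-bit optimizers in this folder (e.g. `low_bit_optim.AdamW8bit`).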
