 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
 from torchao.prototype import low_bit_optim
-from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5
+from torchao.prototype.low_bit_optim.quant_utils import (
+    quantize_8bit_with_qmap,
+    quantize_4bit_with_qmap,
+    _fp32_to_bf16_sr,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6

 try:
     import bitsandbytes as bnb
@@ -74,6 +78,22 @@ def test_quantize_4bit_with_qmap_compile(self, device): |

         torch.testing.assert_close(actual, expected)

+    @parametrize("device", _DEVICES)
+    @parametrize("compile", [False, True])
+    def test_bf16_stochastic_round(self, device, compile):
+        x = torch.rand(32, device=device) * 100
+        x_rep = x.view(-1, 1).repeat(1, 100_000)
+
+        if compile:
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+        else:
+            x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
+
+        assert x_rep_bf16.dtype is torch.bfloat16
+
+        # must cast BF16 tensor back to FP32 so that .mean() is accurate
+        torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)
+

 class TestOptim(TestCase):
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
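Note: the body of `_fp32_to_bf16_sr` is not part of this diff. As context for the test above, here is a minimal reference sketch of FP32-to-BF16 stochastic rounding, assuming the standard bit-manipulation formulation; the name `fp32_to_bf16_sr_reference` is hypothetical and not from this PR:

import torch

def fp32_to_bf16_sr_reference(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference, NOT the code under test.
    # Reinterpret the FP32 bits as int32, add uniform noise over the 16 low
    # (mantissa) bits that BF16 truncation discards, then zero them out.
    # The carry rounds x up to the next BF16 value with probability equal to
    # the discarded fraction, so the output equals x in expectation.
    bits = x.float().contiguous().view(torch.int32)
    noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32, device=x.device)
    return ((bits + noise) & -65536).view(torch.float32).to(torch.bfloat16)

Because the rounding is unbiased, averaging the 100,000 stochastically rounded copies of each value in `test_bf16_stochastic_round` should recover the FP32 input within the given tolerances.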
@@ -249,13 +269,44 @@ def test_optim_cpu_offload_save_load(self): |
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)

+    def test_optim_bf16_stochastic_round_correctness(self):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        torch.manual_seed(2024)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model2 = copy.deepcopy(model1).bfloat16()
+
+        # small LR so that the weight update is small
+        # with bf16_stochastic_round=False, the test fails after 1 iteration
+        optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
+        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+
+        # overfit on this sample
+        x = torch.randn(4, 32, device=device)
+
+        for idx in range(5):
+            # mixed-precision training
+            with torch.autocast(device, dtype=torch.bfloat16):
+                loss1 = model1(x)
+            loss1 = loss1.sum()  # under the autocast context, bf16.sum() would return fp32
+            loss1.backward()
+            optim1.step()
+            optim1.zero_grad()
+
+            # full BF16 training with stochastic-rounded weight updates
+            loss2 = model2(x.bfloat16()).sum()
+            loss2.backward()
+            optim2.step()
+            optim2.zero_grad()
+
+            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+

 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
         return 2

-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="requires PyTorch >= 2.6")
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
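Note on why `test_optim_bf16_stochastic_round_correctness` needs stochastic rounding: with round-to-nearest, a weight update much smaller than the BF16 ulp (2**-7, or about 0.0078, at 1.0) is lost entirely, while stochastic rounding preserves it in expectation. A small demonstration, reusing the hypothetical `fp32_to_bf16_sr_reference` sketch above:

import torch

w = torch.ones(100_000)  # FP32 weights, all 1.0
update = 1e-4            # far below the BF16 ulp at 1.0

nearest = (w + update).to(torch.bfloat16)           # round-to-nearest: update vanishes
stochastic = fp32_to_bf16_sr_reference(w + update)  # hypothetical SR sketch from above

print(nearest.float().mean().item())     # 1.0 -- the update was swallowed
print(stochastic.float().mean().item())  # ~1.0001 -- preserved on average

This is why model2, trained entirely in BF16, can track the FP32 master-weight run in the test: at lr=1e-5 the per-step updates routinely fall below the BF16 ulp of the weights.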