From c95947af44a3eaeef6a9b4a0a8a6b78bfdd58727 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 31 May 2025 13:51:32 +0800
Subject: [PATCH 1/2] handle error when default dtype is BF16

---
 test/test_low_bit_optim.py     | 20 ++++++++++++++++++++
 torchao/optim/subclass_4bit.py |  6 +++---
 torchao/optim/subclass_8bit.py |  6 +++---
 3 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py
index 08fdfa569f..a80d4bbfc0 100644
--- a/test/test_low_bit_optim.py
+++ b/test/test_low_bit_optim.py
@@ -166,6 +166,26 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_optim_default_dtype_bf16(self, optim_name, device):
+        old_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
+
+        try:
+            model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+            model.to(device=device)
+            optimizer = getattr(optim, optim_name)(model.parameters())
+
+            x = torch.randn(4, 32, device=device)
+            loss = model(x).sum()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        finally:
+            torch.set_default_dtype(old_dtype)
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
diff --git a/torchao/optim/subclass_4bit.py b/torchao/optim/subclass_4bit.py
index 209d0b8cad..bc5fd33414 100644
--- a/torchao/optim/subclass_4bit.py
+++ b/torchao/optim/subclass_4bit.py
@@ -69,6 +69,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, sha
         assert codes.dtype is torch.uint8
         assert codes.ndim == 1  # flattened buffer
         assert scale.ndim == 1
+        assert qmap.dtype is torch.float32
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
@@ -101,9 +102,8 @@ def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):
 
         codes = torch.zeros(n_elems // 2, dtype=torch.uint8, device=device)
         scale = torch.zeros(n_elems // block_size, device=device)
-        qmap = torch.tensor(
-            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
-        )
+        qmap_list = get_qmap_signed() if signed else get_qmap_unsigned()
+        qmap = torch.tensor(qmap_list, dtype=torch.float32, device=device)
         return cls(codes, scale, qmap, signed, shape)
 
     def __repr__(self):
diff --git a/torchao/optim/subclass_8bit.py b/torchao/optim/subclass_8bit.py
index 58a51734d7..d3f7634526 100644
--- a/torchao/optim/subclass_8bit.py
+++ b/torchao/optim/subclass_8bit.py
@@ -62,6 +62,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
         """
         assert codes.dtype is torch.uint8
         assert scale.ndim == 1
+        assert qmap.dtype is torch.float32
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
@@ -89,9 +90,8 @@ def dequantize(self, output_dtype=None):
     def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
-        qmap = torch.tensor(
-            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
-        )
+        qmap_list = get_qmap_signed() if signed else get_qmap_unsigned()
+        qmap = torch.tensor(qmap_list, dtype=torch.float32, device=device)
         return cls(codes, scale, qmap, signed)
 
     def __repr__(self):

From e658764000c99430f8761674136bd0c65c483c49 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Wed, 4 Jun 2025 10:08:24 +0800
Subject: [PATCH 2/2] skip FP8 optim on unsupported GPUs

---
 test/test_low_bit_optim.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py
index a80d4bbfc0..692a0d9e6c 100644
--- a/test/test_low_bit_optim.py
+++ b/test/test_low_bit_optim.py
@@ -37,7 +37,6 @@
 from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_7,
     get_available_devices,
@@ -128,8 +127,6 @@ class TestOptim(TestCase):
     @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
-            if not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
@@ -169,6 +166,10 @@ def test_optim_smoke(self, optim_name, dtype, device):
     @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
     @parametrize("device", _DEVICES)
     def test_optim_default_dtype_bf16(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
         old_dtype = torch.get_default_dtype()
         torch.set_default_dtype(torch.bfloat16)
 
@@ -198,8 +199,6 @@ def test_subclass_slice(self, subclass, shape, device):
         if subclass == OptimStateFp8:
             if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
                 pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
-            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
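
Note (not part of the patches): a minimal sketch of the failure mode that PATCH 1/2 guards against. torch.tensor() built from a list of Python floats follows the global default dtype, so with a bfloat16 default the qmap buffer silently becomes bfloat16 unless the dtype is pinned; that is why the patch passes dtype=torch.float32 and asserts it in __init__. The qmap_list values below are a stand-in for get_qmap_signed(), not torchao code.

    import torch

    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        qmap_list = [i / 15 for i in range(16)]  # stand-in values, not the real qmap
        implicit = torch.tensor(qmap_list)       # follows the default -> torch.bfloat16
        pinned = torch.tensor(qmap_list, dtype=torch.float32)  # what the patch does
        print(implicit.dtype, pinned.dtype)      # torch.bfloat16 torch.float32
    finally:
        torch.set_default_dtype(old_dtype)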