From 85b24cfcb74f3a5a0b370577adbd86c07fc9b52d Mon Sep 17 00:00:00 2001
From: nguyen599
Date: Tue, 30 Sep 2025 13:43:12 +0700
Subject: [PATCH 1/2] Fix FSDP offload + AdamW8bit compatibility

Signed-off-by: nguyen599
---
 torchao/optim/adam.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torchao/optim/adam.py b/torchao/optim/adam.py
index 8beaffb627..12ebfe9357 100644
--- a/torchao/optim/adam.py
+++ b/torchao/optim/adam.py
@@ -179,10 +179,11 @@ def single_param_adam(
     p_f32 = p.float()
     grad_f32 = grad.float()
 
-    if IS_ADAMW:
-        p_f32 = p_f32 - lr * weight_decay * p_f32
-    else:
-        grad_f32 = grad_f32 + weight_decay * p_f32
+    if weight_decay != 0:
+        if IS_ADAMW:
+            p_f32.mul_(1 - lr * weight_decay)
+        else:
+            grad_f32 = grad_f32.add(p_f32, alpha=weight_decay)
 
     bias_correction1 = 1 - beta1**step
     bias_correction2 = 1 - beta2**step

From 593631242ee6ca445c04373cefb5062b5a84a2f4 Mon Sep 17 00:00:00 2001
From: nguyen599
Date: Tue, 30 Sep 2025 15:05:45 +0700
Subject: [PATCH 2/2] fix

Signed-off-by: nguyen599
---
 torchao/optim/adam.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchao/optim/adam.py b/torchao/optim/adam.py
index 12ebfe9357..a2941ef09e 100644
--- a/torchao/optim/adam.py
+++ b/torchao/optim/adam.py
@@ -202,7 +202,8 @@ def single_param_adam(
     else:
        denom = (exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()) + eps
 
-    p_f32 = p_f32 - lr * (exp_avg_f32 / bias_correction1) / denom
+    step_size = lr / bias_correction1
+    p_f32.addcdiv_(exp_avg_f32, denom, value=-step_size.to(p_f32.device))
 
     if BF16_STOCHASTIC_ROUND:
         p.copy_(_fp32_to_bf16_sr(p_f32))
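
Below is a minimal standalone sketch of the update path these two hunks produce; the function name, simplified signature, and state handling are assumptions for illustration, not the torchao single_param_adam itself. Decoupled weight decay becomes an in-place mul_, and the parameter step becomes a single in-place addcdiv_ with the step size explicitly moved to the parameter's device, which appears to be the device mismatch that FSDP CPU offload exposes when lr is kept as a tensor:

import torch

def adamw_step_sketch(p, grad, exp_avg, exp_avg_sq, step, lr,
                      beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    # All math on fp32 working copies, mirroring p_f32/grad_f32 in the patch.
    p_f32 = p.float()
    grad_f32 = grad.float()

    # Decoupled (AdamW-style) weight decay, in place and only when non-zero.
    if weight_decay != 0:
        p_f32.mul_(1 - lr * weight_decay)

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    # First and second moment updates (in place).
    exp_avg.lerp_(grad_f32, 1 - beta1)
    exp_avg_sq.lerp_(grad_f32.square(), 1 - beta2)

    denom = (exp_avg_sq.sqrt() / bias_correction2**0.5) + eps

    # The patched update: one in-place addcdiv_, with the (possibly tensor)
    # step size placed on the parameter's device so a CPU-offloaded p and a
    # CUDA lr tensor no longer mix devices.
    step_size = lr / bias_correction1
    if isinstance(step_size, torch.Tensor):
        step_size = step_size.to(p_f32.device)
    p_f32.addcdiv_(exp_avg, denom, value=-step_size)

    p.copy_(p_f32)

# Smoke test on CPU with a tensor lr, the combination the patch targets.
p = torch.randn(8, dtype=torch.bfloat16)
exp_avg, exp_avg_sq = torch.zeros(8), torch.zeros(8)
adamw_step_sketch(p, torch.randn(8), exp_avg, exp_avg_sq, step=1, lr=torch.tensor(1e-3))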