diff --git a/torchao/optim/adam.py b/torchao/optim/adam.py
index 8beaffb627..a2941ef09e 100644
--- a/torchao/optim/adam.py
+++ b/torchao/optim/adam.py
@@ -179,10 +179,11 @@ def single_param_adam(
     p_f32 = p.float()
     grad_f32 = grad.float()
 
-    if IS_ADAMW:
-        p_f32 = p_f32 - lr * weight_decay * p_f32
-    else:
-        grad_f32 = grad_f32 + weight_decay * p_f32
+    if weight_decay != 0:
+        if IS_ADAMW:
+            p_f32.mul_(1 - lr * weight_decay)
+        else:
+            grad_f32 = grad_f32.add(p_f32, alpha=weight_decay)
 
     bias_correction1 = 1 - beta1**step
     bias_correction2 = 1 - beta2**step
@@ -201,7 +202,8 @@ def single_param_adam(
     else:
         denom = (exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()) + eps
 
-    p_f32 = p_f32 - lr * (exp_avg_f32 / bias_correction1) / denom
+    step_size = lr / bias_correction1
+    p_f32.addcdiv_(exp_avg_f32, denom, value=-step_size.to(p_f32.device))
 
     if BF16_STOCHASTIC_ROUND:
         p.copy_(_fp32_to_bf16_sr(p_f32))
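
For reference, a minimal standalone sketch (not part of this patch) that checks the rewritten in-place AdamW path against the old out-of-place expressions. The hyperparameter values and tensor shapes below are arbitrary, and `lr`, `step`, and the bias corrections are kept as plain Python floats here, whereas the optimizer itself may carry them as tensors (hence the `step_size.to(p_f32.device)` and `bias_correction2.sqrt()` in the diff):

```python
import math

import torch

torch.manual_seed(0)

# Arbitrary stand-in values; not taken from torchao's defaults.
lr, weight_decay = 1e-3, 1e-2
beta1, beta2, eps, step = 0.9, 0.999, 1e-8, 10

p_f32 = torch.randn(4, 4)
exp_avg_f32 = torch.randn(4, 4)
exp_avg_sq_f32 = torch.rand(4, 4)

bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
denom = (exp_avg_sq_f32.sqrt() / math.sqrt(bias_correction2)) + eps

# Old formulation: decoupled weight decay and the Adam step built from fresh tensors.
old = p_f32 - lr * weight_decay * p_f32
old = old - lr * (exp_avg_f32 / bias_correction1) / denom

# New formulation: the same math applied in place, mirroring the diff.
new = p_f32.clone()
if weight_decay != 0:
    new.mul_(1 - lr * weight_decay)
step_size = lr / bias_correction1
new.addcdiv_(exp_avg_f32, denom, value=-step_size)

print(torch.allclose(old, new))  # expected: True
```

The two paths differ only in floating-point evaluation order, so they agree to within normal fp32 rounding; the in-place version avoids allocating intermediate tensors for the decay and update.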