From 0a49adb563626607c3b44fb7ed2e32cea74d3e12 Mon Sep 17 00:00:00 2001 From: swimmincatt35 Date: Mon, 3 Mar 2025 12:39:48 -0500 Subject: [PATCH] (CH) eval mode appended --- training/phema.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/training/phema.py b/training/phema.py index e4ac23e..218af00 100644 --- a/training/phema.py +++ b/training/phema.py @@ -93,6 +93,9 @@ def __init__(self, net, stds=[0.010, 0.050, 0.100]): self.net = net self.stds = stds self.emas = [copy.deepcopy(net) for _std in stds] + # Added eval() mode, prevent EDM2 to self-normalize and change weights at sampling/evaluation. + for ema in self.emas: + ema.eval() @torch.no_grad() def reset(self): @@ -133,6 +136,8 @@ def __init__(self, net, ema_beta=0.9999, halflife_Mimg=None, rampup_ratio=None): self.halflife_Mimg = halflife_Mimg self.rampup_ratio = rampup_ratio self.ema = copy.deepcopy(net) + # Added eval() mode, prevent EDM2 to self-normalize and change weights at sampling/evaluation. + self.ema.eval() @torch.no_grad() def reset(self):