diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index c1a7010d919a..583ffc0f9030 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -677,6 +677,13 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        # handle guidance
+        if self.transformer.config.guidance_embeds:
+            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+            guidance = guidance.expand(latents.shape[0])
+        else:
+            guidance = None
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -686,13 +693,6 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
                 noise_pred = self.transformer(
                     hidden_states=latents,
                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
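For reference, a minimal standalone sketch of what the relocated block does (the names `guidance_scale`, `device`, and `batch_size` below are illustrative stand-ins, not the pipeline's actual locals): the guidance tensor is now built once before the denoising loop with an explicit float32 dtype via `torch.full`, and simply reused on every step, instead of being recreated with `torch.tensor` inside the loop.

```python
import torch

# Illustrative values; in the pipeline these come from the call arguments and latents.
guidance_scale = 3.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 2  # corresponds to latents.shape[0]

# Built once before the loop. torch.full pins the dtype to float32 explicitly,
# whereas torch.tensor([guidance_scale], device=device) without a dtype follows
# the current default float dtype.
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)  # broadcast to the batch dimension

for step in range(3):  # stands in for the per-timestep denoising loop
    # guidance is reused as-is on every iteration; no per-step allocation.
    assert guidance.shape == (batch_size,)
```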