diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b879f7d3536d..7e58f8bf7ec6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -144,6 +144,7 @@ "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", + "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", "HeunDiscreteScheduler", @@ -526,6 +527,7 @@ DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, + EDMEulerScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a3683535e74e..d8476459b2c9 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -52,6 +52,7 @@ _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] + _import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"] _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"] _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] @@ -144,6 +145,7 @@ from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler + from .scheduling_edm_euler import EDMEulerScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py new file mode 100644 index 000000000000..e62a486cc214 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -0,0 +1,381 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete +class EDMEulerSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class EDMEulerScheduler(SchedulerMixin, ConfigMixin): + """ + Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1]. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.002): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable + range is [0, 10]. + sigma_max (`float`, *optional*, defaults to 80.0): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable + range is [0.2, 80.0]. + sigma_data (`float`, *optional*, defaults to 0.5): + The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1]. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + rho (`float`, *optional*, defaults to 7.0): + The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1]. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + num_train_timesteps: int = 1000, + prediction_type: str = "epsilon", + rho: float = 7.0, + ): + # setable values + self.num_inference_steps = None + + ramp = torch.linspace(0, 1, num_train_timesteps) + sigmas = self._compute_sigmas(ramp) + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.is_scale_input_called = False + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increae 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + c_noise = 0.25 * torch.log(sigma) + + return c_noise + + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + ramp = np.linspace(0, 1, self.num_inference_steps) + sigmas = self._compute_sigmas(ramp) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EDMEulerSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EDMEulerScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 9c6c068b918c..5dbdb82884bc 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -45,6 +45,7 @@ class KarrasDiffusionSchedulers(Enum): DEISMultistepScheduler = 12 UniPCMultistepScheduler = 13 DPMSolverSDEScheduler = 14 + EDMEulerScheduler = 15 @dataclass diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index a4f5436038ea..40515b1674ff 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -855,6 +855,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class EDMEulerScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EulerAncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index ff3c64408e89..04721b4a8cc1 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -209,6 +209,7 @@ def test_karras_schedulers_shape(self): "KDPM2DiscreteScheduler", "KDPM2AncestralDiscreteScheduler", "DPMSolverSDEScheduler", + "EDMEulerScheduler", ] components = self.get_dummy_components() pipe = self.pipeline_class(**components) diff --git a/tests/schedulers/test_scheduler_edm_euler.py b/tests/schedulers/test_scheduler_edm_euler.py new file mode 100644 index 000000000000..9d2adea6ca60 --- /dev/null +++ b/tests/schedulers/test_scheduler_edm_euler.py @@ -0,0 +1,206 @@ +import inspect +import tempfile +import unittest +from typing import Dict, List, Tuple + +import torch + +from diffusers import EDMEulerScheduler + +from .test_schedulers import SchedulerCommonTest + + +class EDMEulerSchedulerTest(SchedulerCommonTest): + scheduler_classes = (EDMEulerScheduler,) + forward_default_kwargs = (("num_inference_steps", 10),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 256, + "sigma_min": 0.002, + "sigma_max": 80.0, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_full_loop_no_noise(self, num_inference_steps=10, seed=0): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(scheduler.timesteps): + scaled_sample = scheduler.scale_model_input(sample, t) + + model_output = model(scaled_sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 34.1855) < 1e-3 + assert abs(result_mean.item() - 0.044) < 1e-3 + + def test_full_loop_device(self, num_inference_steps=10, seed=0): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(scheduler.timesteps): + scaled_sample = scheduler.scale_model_input(sample, t) + + model_output = model(scaled_sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 34.1855) < 1e-3 + assert abs(result_mean.item() - 0.044) < 1e-3 + + # Override test_from_save_pretrined to use EDMEulerScheduler-specific logic + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + + scheduler.set_timesteps(num_inference_steps) + new_scheduler.set_timesteps(num_inference_steps) + timestep = scheduler.timesteps[0] + + sample = self.dummy_sample + + scaled_sample = scheduler.scale_model_input(sample, timestep) + residual = 0.1 * scaled_sample + + new_scaled_sample = new_scheduler.scale_model_input(sample, timestep) + new_residual = 0.1 * new_scaled_sample + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.manual_seed(0) + output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.manual_seed(0) + new_output = new_scheduler.step(new_residual, timestep, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + # Override test_from_save_pretrined to use EDMEulerScheduler-specific logic + def test_step_shape(self): + num_inference_steps = 10 + + scheduler_config = self.get_scheduler_config() + scheduler = self.scheduler_classes[0](**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + timestep_0 = scheduler.timesteps[0] + timestep_1 = scheduler.timesteps[1] + + sample = self.dummy_sample + scaled_sample = scheduler.scale_model_input(sample, timestep_0) + residual = 0.1 * scaled_sample + + output_0 = scheduler.step(residual, timestep_0, sample).prev_sample + output_1 = scheduler.step(residual, timestep_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + # Override test_from_save_pretrined to use EDMEulerScheduler-specific logic + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", 50) + + timestep = 0 + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + timestep = scheduler.timesteps[0] + + sample = self.dummy_sample + scaled_sample = scheduler.scale_model_input(sample, timestep) + residual = 0.1 * scaled_sample + + # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.manual_seed(0) + outputs_dict = scheduler.step(residual, timestep, sample, **kwargs) + + scheduler.set_timesteps(num_inference_steps) + + scaled_sample = scheduler.scale_model_input(sample, timestep) + residual = 0.1 * scaled_sample + + # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.manual_seed(0) + outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple, outputs_dict) + + @unittest.skip(reason="EDMEulerScheduler does not support beta schedules.") + def test_trained_betas(self): + pass diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index a7c159ffbd06..9982db771dea 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -30,6 +30,7 @@ DDIMScheduler, DEISMultistepScheduler, DiffusionPipeline, + EDMEulerScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, @@ -385,6 +386,9 @@ def check_over_configs(self, time_step=0, **config): scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) time_step = scaled_sigma_max + if scheduler_class == EDMEulerScheduler: + time_step = scheduler.timesteps[-1] + if scheduler_class == VQDiffusionScheduler: num_vec_classes = scheduler_config["num_vec_classes"] sample = self.dummy_sample(num_vec_classes) @@ -693,6 +697,8 @@ def test_scheduler_public_api(self): # Get valid timestep based on sigma_max, which should always be in timestep schedule. scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max) + elif scheduler_class == EDMEulerScheduler: + scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1]) else: scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) @@ -710,6 +716,8 @@ def test_add_noise_device(self): # Get valid timestep based on sigma_max, which should always be in timestep schedule. scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max) scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max) + if scheduler_class == EDMEulerScheduler: + scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1]) else: scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape)