diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e6858d842cbb..7766442f7133 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -972,15 +972,32 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * (mid - 1)
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
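For reference, a self-contained sketch of the three weighting schemes implemented above (the helper below is illustrative, not part of the patch; its outputs match the inline comments):

```python
# Sketch of _get_frame_weights: weights for averaging overlapping FreeNoise windows.
def frame_weights(num_frames: int, scheme: str = "pyramid") -> list:
    mid = num_frames // 2 if num_frames % 2 == 0 else (num_frames + 1) // 2
    if scheme == "flat":
        return [1.0] * num_frames
    if scheme == "pyramid":
        if num_frames % 2 == 0:
            half = list(range(1, mid + 1))
            return half + half[::-1]
        return list(range(1, mid)) + [mid] + list(range(mid - 1, 0, -1))
    if scheme == "delayed_reverse_sawtooth":
        # Early frames are nearly muted (0.01) so later windows dominate the average.
        head = [0.01] * (mid - 1) + ([mid] if num_frames % 2 == 0 else [])
        return head + list(range(mid, 0, -1))
    raise ValueError(f"Unsupported value for weighting_scheme={scheme}")

print(frame_weights(5, "flat"))                      # [1.0, 1.0, 1.0, 1.0, 1.0]
print(frame_weights(4, "pyramid"))                   # [1, 2, 2, 1]
print(frame_weights(5, "pyramid"))                   # [1, 2, 3, 2, 1]
print(frame_weights(4, "delayed_reverse_sawtooth"))  # [0.01, 2, 2, 1]
print(frame_weights(5, "delayed_reverse_sawtooth"))  # [0.01, 0.01, 3, 2, 1]
```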
diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py
index e91551c70953..fa37e1f9e393 100644
--- a/src/diffusers/models/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnet_sparsectrl.py
@@ -691,7 +691,6 @@ def forward(
         emb = self.time_embedding(t_emb, timestep_cond)
         emb = emb.repeat_interleave(sample_num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 73c9c70c4a11..89cdb76741f7 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -116,7 +116,7 @@ def __init__(
         self.in_channels = in_channels
 
-        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         self.proj_in = nn.Linear(in_channels, inner_dim)
 
         # 3. Define transformers blocks
@@ -2178,7 +2178,6 @@ def forward(
         emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
 
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
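Note on the two removals above (and the matching `repeat_interleave` additions in the pipelines below): duplicating the text embeddings once per frame now happens in the pipelines instead of inside the model `forward` methods, so the UNet and the ControlNet receive `encoder_hidden_states` that may differ from frame to frame, which is exactly what FreeNoise prompt interpolation needs. A shape sketch with illustrative dimensions:

```python
import torch

# Illustrative sizes: 2 videos, 16 frames, CLIP-like [seq_len=77, dim=768] embeddings.
batch_size, num_frames, seq_len, dim = 2, 16, 77, 768
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Pipelines now expand the embeddings per frame before calling the UNet:
per_frame_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
print(per_frame_embeds.shape)  # torch.Size([32, 77, 768]) == [batch * frames, seq, dim]
```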
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index a1f0374e318a..cb6f50f43c4f 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -432,7 +432,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -470,8 +469,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -557,11 +556,15 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
         num_frames: Optional[int] = 16,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -701,9 +704,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -716,22 +720,39 @@ def __call__(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ def __call__(
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
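The new `interrupt` property (mirrored in the other pipelines below) lets a step-end callback cancel the denoising loop: once `_interrupt` is set, every remaining iteration is skipped. A sketch of how a caller might use it, assuming a pipeline that exposes `callback_on_step_end`; the `stop_early` helper is illustrative:

```python
# Cancel generation after 10 denoising steps via the new interrupt flag.
def stop_early(pipe, step_index, timestep, callback_kwargs):
    if step_index >= 10:
        pipe._interrupt = True  # read back through the `interrupt` property
    return callback_kwargs

# output = pipe(prompt="a corgi on the beach", callback_on_step_end=stop_early)
```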
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
index 6e8b0e3e5fe3..5357d6d5b8d9 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -505,8 +505,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -699,6 +699,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     def __call__(
         self,
@@ -858,9 +862,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -883,22 +888,39 @@ def __call__(
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ def __call__(
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ def __call__(
                 else:
                     control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
-                    controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
 
                 if isinstance(controlnet_keep[i], list):
                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index a46682347519..e531c91c168f 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -1143,6 +1143,8 @@ def __call__(
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
index e9e0d518c806..8b037cdc34fb 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -878,6 +878,8 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
         # 4. Prepare IP-Adapter embeddings
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index 70a4201ca05c..1ebe2b9b60dd 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -246,7 +246,6 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
     def encode_prompt(
         self,
         prompt,
@@ -299,7 +298,7 @@ def encode_prompt(
         else:
             scale_lora_layers(self.text_encoder, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -582,8 +581,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -628,23 +627,20 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
 
     def prepare_latents(
         self,
-        video,
-        height,
-        width,
-        num_channels_latents,
-        batch_size,
-        timestep,
-        dtype,
-        device,
-        generator,
-        latents=None,
+        video: Optional[torch.Tensor] = None,
+        height: int = 64,
+        width: int = 64,
+        num_channels_latents: int = 4,
+        batch_size: int = 1,
+        timestep: Optional[int] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
         decode_chunk_size: int = 16,
-    ):
-        if latents is None:
-            num_frames = video.shape[1]
-        else:
-            num_frames = latents.shape[2]
-
+        add_noise: bool = False,
+    ) -> torch.Tensor:
+        num_frames = video.shape[1] if latents is None else latents.shape[2]
         shape = (
             batch_size,
             num_channels_latents,
@@ -708,8 +704,13 @@ def prepare_latents(
             if shape != latents.shape:
                 # [B, C, F, H, W]
                 raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+
             latents = latents.to(device, dtype=dtype)
 
+            if add_noise:
+                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+                latents = self.scheduler.add_noise(latents, noise, timestep)
+
         return latents
 
     @property
@@ -735,6 +736,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     def __call__(
         self,
@@ -743,6 +748,7 @@ def __call__(
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
+        enforce_inference_steps: bool = False,
         timesteps: Optional[List[int]] = None,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.5,
@@ -874,9 +880,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -884,51 +891,29 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
+        dtype = self.dtype
 
-        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
-        )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
-            image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                device,
-                batch_size * num_videos_per_prompt,
-                self.do_classifier_free_guidance,
+        # 3. Prepare timesteps
+        if not enforce_inference_steps:
+            timesteps, num_inference_steps = retrieve_timesteps(
+                self.scheduler, num_inference_steps, device, timesteps, sigmas
             )
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+            latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+        else:
+            denoising_inference_steps = int(num_inference_steps / strength)
+            timesteps, denoising_inference_steps = retrieve_timesteps(
+                self.scheduler, denoising_inference_steps, device, timesteps, sigmas
+            )
+            timesteps = timesteps[-num_inference_steps:]
+            latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
 
-        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
-        )
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
-
-        # 5. Prepare latent variables
+        # 4. Prepare latent variables
         if latents is None:
             video = self.video_processor.preprocess_video(video, height=height, width=width)
             # Move the number of frames before the number of channels.
             video = video.permute(0, 2, 1, 3, 4)
-            video = video.to(device=device, dtype=prompt_embeds.dtype)
+            video = video.to(device=device, dtype=dtype)
 
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             video=video,
@@ -937,17 +922,67 @@ def __call__(
             num_channels_latents=num_channels_latents,
             batch_size=batch_size * num_videos_per_prompt,
             timestep=latent_timestep,
-            dtype=prompt_embeds.dtype,
+            dtype=dtype,
             device=device,
             generator=generator,
             latents=latents,
             decode_chunk_size=decode_chunk_size,
+            add_noise=enforce_inference_steps,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 5. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        num_frames = latents.shape[2]
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
+        # 6. Prepare IP-Adapter embeddings
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Add image embeds for IP-Adapter
+        # 8. Add image embeds for IP-Adapter
         added_cond_kwargs = (
             {"image_embeds": image_embeds}
             if ip_adapter_image is not None or ip_adapter_image_embeds is not None
@@ -967,9 +1002,12 @@ def __call__(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
-        # 8. Denoising loop
+        # 9. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1005,14 +1043,14 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
-        # 9. Post-processing
+        # 10. Post-processing
         if output_type == "latent":
             video = latents
         else:
             video_tensor = self.decode_latents(latents, decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
-        # 10. Offload all models
+        # 11. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
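The new `enforce_inference_steps` flag changes how `strength` interacts with the schedule: by default `strength` trims the step count, while with the flag set the schedule is stretched so the requested number of steps still runs on its tail, and the video latents are freshly noised to the first kept timestep (the `add_noise=True` path in `prepare_latents`). The arithmetic, with illustrative numbers:

```python
# Timestep bookkeeping for the two video-to-video modes (values illustrative).
num_inference_steps, strength = 20, 0.5

# Default path: get_timesteps() keeps only the tail of a 20-step schedule,
# so int(num_inference_steps * strength) = 10 denoising steps actually run.
# enforce_inference_steps=True path: stretch the schedule instead.
denoising_inference_steps = int(num_inference_steps / strength)  # 40
# timesteps = retrieve_timesteps(scheduler, denoising_inference_steps, ...)
# timesteps = timesteps[-num_inference_steps:]  # run exactly the last 20 steps
# latents are then noised to timesteps[0] via scheduler.add_noise(...)
```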
diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py
index 1ee3b6d0a985..f2763f1c33cc 100644
--- a/src/diffusers/pipelines/free_noise_utils.py
+++ b/src/diffusers/pipelines/free_noise_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Callable, Dict, Optional, Union
 
 import torch
 
@@ -22,6 +22,7 @@
     DownBlockMotion,
     UpBlockMotion,
 )
+from ..pipelines.pipeline_utils import DiffusionPipeline
 from ..utils import logging
 from ..utils.torch_utils import randn_tensor
 
@@ -98,6 +99,142 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do
                 free_noise_transfomer_block.state_dict(), strict=True
             )
 
+    def _check_inputs_free_noise(
+        self,
+        prompt,
+        negative_prompt,
+        prompt_embeds,
+        negative_prompt_embeds,
+        num_frames,
+    ) -> None:
+        if not isinstance(prompt, (str, dict)):
+            raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
+
+        if negative_prompt is not None:
+            if not isinstance(negative_prompt, (str, dict)):
+                raise ValueError(
+                    f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
+                )
+
+        if prompt_embeds is not None or negative_prompt_embeds is not None:
+            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` are not supported in FreeNoise yet.")
+
+        frame_indices = [isinstance(x, int) for x in prompt.keys()]
+        frame_prompts = [isinstance(x, str) for x in prompt.values()]
+        min_frame = min(list(prompt.keys()))
+        max_frame = max(list(prompt.keys()))
+
+        if not all(frame_indices):
+            raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
+        if not all(frame_prompts):
+            raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
+        if min_frame != 0:
+            raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
+        if max_frame >= num_frames:
+            raise ValueError(
+                f"The maximum frame index in `prompt` dict must be less than {num_frames=} and follow 0-based indexing."
+            )
+
+    def _encode_prompt_free_noise(
+        self,
+        prompt: Union[str, Dict[int, str]],
+        num_frames: int,
+        device: torch.device,
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ) -> torch.Tensor:
+        if negative_prompt is None:
+            negative_prompt = ""
+
+        # Ensure that we have a dictionary of prompts
+        if isinstance(prompt, str):
+            prompt = {0: prompt}
+        if isinstance(negative_prompt, str):
+            negative_prompt = {0: negative_prompt}
+
+        self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
+
+        # Sort the prompts based on frame indices
+        prompt = dict(sorted(prompt.items()))
+        negative_prompt = dict(sorted(negative_prompt.items()))
+
+        # Ensure that we have a prompt for the last frame index
+        prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
+        negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
+
+        frame_indices = list(prompt.keys())
+        frame_prompts = list(prompt.values())
+        frame_negative_indices = list(negative_prompt.keys())
+        frame_negative_prompts = list(negative_prompt.values())
+
+        # Generate and interpolate positive prompts
+        prompt_embeds, _ = self.encode_prompt(
+            prompt=frame_prompts,
+            device=device,
+            num_images_per_prompt=num_videos_per_prompt,
+            do_classifier_free_guidance=False,
+            negative_prompt=None,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            lora_scale=lora_scale,
+            clip_skip=clip_skip,
+        )
+
+        shape = (num_frames, *prompt_embeds.shape[1:])
+        prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
+
+        for i in range(len(frame_indices) - 1):
+            start_frame = frame_indices[i]
+            end_frame = frame_indices[i + 1]
+            start_tensor = prompt_embeds[i].unsqueeze(0)
+            end_tensor = prompt_embeds[i + 1].unsqueeze(0)
+
+            prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
+                start_frame, end_frame, start_tensor, end_tensor
+            )
+
+        # Generate and interpolate negative prompts
+        negative_prompt_embeds = None
+        negative_prompt_interpolation_embeds = None
+
+        if do_classifier_free_guidance:
+            _, negative_prompt_embeds = self.encode_prompt(
+                prompt=[""] * len(frame_negative_prompts),
+                device=device,
+                num_images_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=True,
+                negative_prompt=frame_negative_prompts,
+                prompt_embeds=None,
+                negative_prompt_embeds=None,
+                lora_scale=lora_scale,
+                clip_skip=clip_skip,
+            )
+
+            negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
+
+            for i in range(len(frame_negative_indices) - 1):
+                start_frame = frame_negative_indices[i]
+                end_frame = frame_negative_indices[i + 1]
+                start_tensor = negative_prompt_embeds[i].unsqueeze(0)
+                end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
+
+                negative_prompt_interpolation_embeds[
+                    start_frame : end_frame + 1
+                ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+
+        prompt_embeds = prompt_interpolation_embeds
+        negative_prompt_embeds = negative_prompt_interpolation_embeds
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        return prompt_embeds, negative_prompt_embeds
+
     def _prepare_latents_free_noise(
         self,
         batch_size: int,
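`_encode_prompt_free_noise` above encodes one embedding per keyframe and fills every frame in between through the interpolation callback (the default, `_lerp`, is added in the next hunk). A minimal standalone equivalent of that blend, with illustrative tensor sizes:

```python
import torch

def lerp(start: torch.Tensor, end: torch.Tensor, num_indices: int) -> torch.Tensor:
    # Convex blend over num_indices frames, inclusive of both endpoint embeddings.
    alphas = torch.linspace(0, 1, num_indices).view(-1, 1, 1)
    return (1 - alphas) * start + alphas * end

emb_a = torch.randn(1, 77, 768)  # keyframe 0 embedding (illustrative size)
emb_b = torch.randn(1, 77, 768)  # keyframe 10 embedding
blend = lerp(emb_a, emb_b, num_indices=11)  # one row per frame 0..10
print(blend.shape)  # torch.Size([11, 77, 768])
print(torch.allclose(blend[0], emb_a[0]), torch.allclose(blend[-1], emb_b[0]))  # True True
```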
""" - allowed_weighting_scheme = ["pyramid"] + allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] allowed_noise_type = ["shuffle_context", "repeat_context", "random"] if context_length > self.motion_adapter.config.motion_max_seq_length: @@ -219,6 +387,7 @@ def enable_free_noise( self._free_noise_context_stride = context_stride self._free_noise_weighting_scheme = weighting_scheme self._free_noise_noise_type = noise_type + self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp if hasattr(self.unet.mid_block, "motion_modules"): blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] @@ -229,6 +398,7 @@ def enable_free_noise( self._enable_free_noise_in_block(block) def disable_free_noise(self) -> None: + r"""Disable the FreeNoise sampling mechanism.""" self._free_noise_context_length = None if hasattr(self.unet.mid_block, "motion_modules"): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index 73c53b365848..1e81fa3a158c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -734,6 +734,8 @@ def __call__( elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, @@ -805,7 +807,9 @@ def __call__( with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = torch.cat( + [latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0]) + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f0e8cfb03def..b7dfcd39edce 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -824,6 +824,8 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index 53833d6a075b..ee05f0d93824 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -51,7 +51,7 @@ def dummy_input(self): noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size * num_frames, 4, 16)).to(torch_device) return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 1354ac9ff1a8..618a5cff9912 
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 1354ac9ff1a8..618a5cff9912 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -460,6 +460,29 @@ def test_free_noise(self):
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )
 
+    def test_free_noise_multi_prompt(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        context_length = 8
+        context_stride = 4
+        pipe.enable_free_noise(context_length, context_stride)
+
+        # Make sure that pipeline works when prompt indices are within num_frames bounds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+        inputs["num_frames"] = 16
+        pipe(**inputs).frames[0]
+
+        with self.assertRaises(ValueError):
+            # Ensure that prompt indices are within bounds
+            inputs = self.get_dummy_inputs(torch_device)
+            inputs["num_frames"] = 16
+            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+            pipe(**inputs).frames[0]
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
index 3035fc1e3c61..c0ad223c6ce8 100644
--- a/tests/pipelines/animatediff/test_animatediff_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -476,6 +476,27 @@ def test_free_noise(self):
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )
 
+    def test_free_noise_multi_prompt(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        context_length = 8
+        context_stride = 4
+        pipe.enable_free_noise(context_length, context_stride)
+
+        # Make sure that pipeline works when prompt indices are within num_frames bounds
+        inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+        pipe(**inputs).frames[0]
+
+        with self.assertRaises(ValueError):
+            # Ensure that prompt indices are within bounds
+            inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+            pipe(**inputs).frames[0]
+
     def test_vae_slicing(self, video_count=2):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index cd33bf0891a5..c49790e0f262 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -491,3 +491,28 @@ def test_free_noise(self):
             1e-4,
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )
+
+    def test_free_noise_multi_prompt(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        context_length = 8
+        context_stride = 4
+        pipe.enable_free_noise(context_length, context_stride)
+
+        # Make sure that pipeline works when prompt indices are within num_frames bounds
+        inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+        inputs["num_inference_steps"] = 2
+        inputs["strength"] = 0.5
+        pipe(**inputs).frames[0]
+
+        with self.assertRaises(ValueError):
+            # Ensure that prompt indices are within bounds
+            inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+            inputs["num_inference_steps"] = 2
+            inputs["strength"] = 0.5
+            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+            pipe(**inputs).frames[0]
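Finally, `enable_free_noise` now takes a `prompt_interpolation_callback` that replaces the default `_lerp`. Judging from the call sites in `_encode_prompt_free_noise`, the callback receives `(start_index, end_index, start_tensor, end_tensor)` and must return `end_index - start_index + 1` stacked embeddings. A sketch that swaps in spherical interpolation (the slerp implementation is our illustration, not part of the patch):

```python
import torch

def slerp_callback(start_index, end_index, start_tensor, end_tensor):
    # Spherical interpolation between two [1, seq_len, dim] keyframe embeddings.
    num_indices = end_index - start_index + 1
    a, b = start_tensor.flatten().float(), end_tensor.flatten().float()
    omega = torch.acos(torch.clamp(torch.dot(a, b) / (a.norm() * b.norm()), -1.0, 1.0))
    frames = []
    for i in range(num_indices):
        t = i / max(num_indices - 1, 1)
        if omega.abs() < 1e-4:  # nearly parallel: fall back to a linear blend
            frames.append((1 - t) * start_tensor + t * end_tensor)
        else:
            coeff_a = torch.sin((1 - t) * omega) / torch.sin(omega)
            coeff_b = torch.sin(t * omega) / torch.sin(omega)
            frames.append(coeff_a * start_tensor + coeff_b * end_tensor)
    return torch.cat(frames)

# pipe.enable_free_noise(
#     context_length=16, context_stride=4, prompt_interpolation_callback=slerp_callback
# )
```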