From d0a81ae604c567ad2119cd578a60a21561779958 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 Aug 2024 16:21:29 +0200 Subject: [PATCH 01/16] update --- .../animatediff/pipeline_animatediff.py | 10 ++- .../pipeline_animatediff_video2video.py | 64 +++++++++++++------ 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index a1f0374e318a..e407b06837c5 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -557,11 +557,15 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, num_frames: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, @@ -701,6 +705,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -783,6 +788,9 @@ def __call__( # 8. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 70a4201ca05c..38c0d5098447 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -628,23 +628,20 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device): def prepare_latents( self, - video, - height, - width, - num_channels_latents, - batch_size, - timestep, - dtype, - device, - generator, - latents=None, + video: Optional[torch.Tensor] = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, decode_chunk_size: int = 16, - ): - if latents is None: - num_frames = video.shape[1] - else: - num_frames = latents.shape[2] - + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] shape = ( batch_size, num_channels_latents, @@ -708,8 +705,13 @@ def prepare_latents( if shape != latents.shape: # [B, C, F, H, W] raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + latents = latents.to(device, dtype=dtype) + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + return latents @property @@ -735,6 +737,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + 
return self._interrupt + @torch.no_grad() def __call__( self, @@ -743,6 +749,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + enforce_inference_steps: bool = False, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, guidance_scale: float = 7.5, @@ -874,6 +881,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -917,11 +925,20 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, device, timesteps, sigmas + ) + num_inference_steps += 1 + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) # 5. Prepare latent variables if latents is None: @@ -942,6 +959,7 @@ def __call__( generator=generator, latents=latents, decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -970,6 +988,10 @@ def __call__( # 8. 
Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + + print("here:", t, self.scheduler.sigmas[-num_inference_steps:]) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From d55903d0b244f0a00d6c055ad4a257da82d0984b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 Aug 2024 17:20:05 +0200 Subject: [PATCH 02/16] implement prompt interpolation --- .../models/unets/unet_motion_model.py | 1 - .../animatediff/pipeline_animatediff.py | 56 ++++--- .../pipeline_animatediff_video2video.py | 110 +++++++------ src/diffusers/pipelines/free_noise_utils.py | 147 +++++++++++++++++- 4 files changed, 244 insertions(+), 70 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 73c9c70c4a11..7ac61821edd6 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2178,7 +2178,6 @@ def forward( emb = emb if aug_emb is None else emb + aug_emb emb = emb.repeat_interleave(repeats=num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index e407b06837c5..1f5ff3dd731e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -432,7 +432,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, @@ -470,8 +469,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -708,7 +707,7 @@ def __call__( self._interrupt = False # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -721,22 +720,39 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 38c0d5098447..6fcb2e30206c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -246,7 +246,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, @@ -299,7 +298,7 @@ def encode_prompt( else: scale_lora_layers(self.text_encoder, lora_scale) - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -582,8 +581,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -884,7 +883,7 @@ def __call__( self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -892,39 +891,9 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + dtype = self.dtype - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, - ip_adapter_image_embeds, - device, - batch_size * num_videos_per_prompt, - self.do_classifier_free_guidance, - ) - - # 4. Prepare timesteps + # 3. Prepare timesteps if not enforce_inference_steps: timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas @@ -936,16 +905,15 @@ def __call__( timesteps, denoising_inference_steps = retrieve_timesteps( self.scheduler, denoising_inference_steps, device, timesteps, sigmas ) - num_inference_steps += 1 timesteps = timesteps[-num_inference_steps:] latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - # 5. Prepare latent variables + # 4. Prepare latent variables if latents is None: video = self.video_processor.preprocess_video(video, height=height, width=width) # Move the number of frames before the number of channels. video = video.permute(0, 2, 1, 3, 4) - video = video.to(device=device, dtype=prompt_embeds.dtype) + video = video.to(device=device, dtype=dtype) num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( video=video, @@ -954,7 +922,7 @@ def __call__( num_channels_latents=num_channels_latents, batch_size=batch_size * num_videos_per_prompt, timestep=latent_timestep, - dtype=prompt_embeds.dtype, + dtype=dtype, device=device, generator=generator, latents=latents, @@ -962,10 +930,59 @@ def __call__( add_noise=enforce_inference_steps, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 5. 
Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Add image embeds for IP-Adapter + # 8. Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None @@ -985,13 +1002,12 @@ def __call__( self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - # 8. Denoising loop + # 9. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - print("here:", t, self.scheduler.sigmas[-num_inference_steps:]) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1027,14 +1043,14 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - # 9. Post-processing + # 10. Post-processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) - # 10. Offload all models + # 11. Offload all models self.maybe_free_model_hooks() if not return_dict: diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f8128abb9b58..17c52381d920 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Union +from typing import Callable, Dict, Optional, Union import torch @@ -22,6 +22,7 @@ DownBlockMotion, UpBlockMotion, ) +from ..pipelines.pipeline_utils import DiffusionPipeline from ..utils import logging from ..utils.torch_utils import randn_tensor @@ -97,7 +98,135 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) - + + def _check_inputs_free_noise( + self, + prompt, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + num_frames, + ) -> None: + if not isinstance(prompt, (str, dict)): + raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") + + if negative_prompt is not None: + if not isinstance(negative_prompt, (str, dict)): + raise ValueError(f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}") + + if prompt_embeds is not None or negative_prompt_embeds is not None: + raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") + + frame_indices = [isinstance(x, int) for x in prompt.keys()] + frame_prompts = [isinstance(x, str) for x in prompt.values()] + min_frame = min(list(prompt.keys())) + max_frame = max(list(prompt.keys())) + + if not all(frame_indices): + raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.") + if not all(frame_prompts): + raise ValueError("Expected str values in `prompt` dict for FreeNoise.") + if min_frame != 0: + raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") + if max_frame >= num_frames: + raise ValueError(f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing.") + + def _encode_prompt_free_noise( + self, + prompt: Union[str, Dict[int, str]], + num_frames: int, + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, Dict[int, str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ) -> torch.Tensor: + if negative_prompt is None: + negative_prompt = "" + + # Ensure that we have a dictionary of prompts + if isinstance(prompt, str): + prompt = {0: prompt} + if isinstance(negative_prompt, str): + negative_prompt = {0: negative_prompt} + + self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) + + # Sort the prompts based on frame indices + prompt = dict(sorted(prompt.items())) + negative_prompt = dict(sorted(negative_prompt.items())) + + # Ensure that we have a prompt for the last frame index + prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] + negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] + + frame_indices = list(prompt.keys()) + frame_prompts = list(prompt.values()) + frame_negative_indices = list(negative_prompt.keys()) + frame_negative_prompts = list(negative_prompt.values()) + + # Generate and interpolate positive prompts + prompt_embeds, _ = self.encode_prompt( + prompt=frame_prompts, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=False, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + shape = 
(num_frames, *prompt_embeds.shape[1:]) + prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) + + for i in range(len(frame_indices) - 1): + start_frame = frame_indices[i] + end_frame = frame_indices[i + 1] + start_tensor = prompt_embeds[i].unsqueeze(0) + end_tensor = prompt_embeds[i + 1].unsqueeze(0) + + prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + + # Generate and interpolate negative prompts + negative_prompt_embeds = None + negative_prompt_interpolation_embeds = None + + if do_classifier_free_guidance: + _, negative_prompt_embeds = self.encode_prompt( + prompt=[""] * len(frame_negative_prompts), + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=True, + negative_prompt=frame_negative_prompts, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) + + for i in range(len(frame_negative_indices) - 1): + start_frame = frame_negative_indices[i] + end_frame = frame_negative_indices[i + 1] + start_tensor = negative_prompt_embeds[i].unsqueeze(0) + end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) + + negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + + prompt_embeds = prompt_interpolation_embeds + negative_prompt_embeds = negative_prompt_interpolation_embeds + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds, negative_prompt_embeds + def _prepare_latents_free_noise( self, batch_size: int, @@ -171,6 +300,18 @@ def _prepare_latents_free_noise( latents = latents[:, :, :num_frames] return latents + + def _lerp(self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor: + num_indices = end_index - start_index + 1 + interpolated_tensors = [] + + for i in range(num_indices): + alpha = i / (num_indices - 1) + interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor + interpolated_tensors.append(interpolated_tensor) + + interpolated_tensors = torch.cat(interpolated_tensors) + return interpolated_tensors def enable_free_noise( self, @@ -178,6 +319,7 @@ def enable_free_noise( context_stride: int = 4, weighting_scheme: str = "pyramid", noise_type: str = "shuffle_context", + prompt_interpolation_callback: Optional[Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> None: r""" Enable long video generation using FreeNoise. 
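(Note on the default interpolation added above: `_lerp` fills every frame between two prompt keyframes with a linear blend of their embeddings. For keyframes at indices 0 and 4, for example, num_indices is 5 and the blend weights alpha are 0.0, 0.25, 0.5, 0.75 and 1.0, so frame 0 carries the first prompt's embedding and frame 4 the second's. A user-supplied `prompt_interpolation_callback` is invoked with the same four arguments as `_lerp` at the call sites above, i.e. (start_frame, end_frame, start_tensor, end_tensor), and could substitute e.g. spherical interpolation.)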
@@ -219,6 +361,7 @@ def enable_free_noise( self._free_noise_context_stride = context_stride self._free_noise_weighting_scheme = weighting_scheme self._free_noise_noise_type = noise_type + self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: From a86eabe0bd27dfa5ba57938e57790961f2c2cf6a Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 Aug 2024 17:20:32 +0200 Subject: [PATCH 03/16] make style --- .../animatediff/pipeline_animatediff.py | 2 +- .../pipeline_animatediff_video2video.py | 2 +- src/diffusers/pipelines/free_noise_utils.py | 56 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 1f5ff3dd731e..cb6f50f43c4f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -745,7 +745,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) - + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 6fcb2e30206c..1ebe2b9b60dd 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -960,7 +960,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) - + # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 17c52381d920..f710206d7730 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -98,7 +98,7 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) - + def _check_inputs_free_noise( self, prompt, @@ -109,11 +109,13 @@ def _check_inputs_free_noise( ) -> None: if not isinstance(prompt, (str, dict)): raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") - + if negative_prompt is not None: if not isinstance(negative_prompt, (str, dict)): - raise ValueError(f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}") - + raise ValueError( + f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}" + ) + if prompt_embeds is not None or negative_prompt_embeds is not None: raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") @@ -129,7 +131,9 @@ def _check_inputs_free_noise( if min_frame != 0: raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") if max_frame >= num_frames: - raise ValueError(f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing.") + raise ValueError( + f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." + ) def _encode_prompt_free_noise( self, @@ -146,23 +150,23 @@ def _encode_prompt_free_noise( ) -> torch.Tensor: if negative_prompt is None: negative_prompt = "" - + # Ensure that we have a dictionary of prompts if isinstance(prompt, str): prompt = {0: prompt} if isinstance(negative_prompt, str): negative_prompt = {0: negative_prompt} - + self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) - + # Sort the prompts based on frame indices prompt = dict(sorted(prompt.items())) negative_prompt = dict(sorted(negative_prompt.items())) - + # Ensure that we have a prompt for the last frame index prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] - + frame_indices = list(prompt.keys()) frame_prompts = list(prompt.values()) frame_negative_indices = list(negative_prompt.keys()) @@ -180,7 +184,7 @@ def _encode_prompt_free_noise( lora_scale=lora_scale, clip_skip=clip_skip, ) - + shape = (num_frames, *prompt_embeds.shape[1:]) prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) @@ -190,7 +194,9 @@ def _encode_prompt_free_noise( start_tensor = prompt_embeds[i].unsqueeze(0) end_tensor = prompt_embeds[i + 1].unsqueeze(0) - prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback( + start_frame, end_frame, start_tensor, end_tensor + ) # Generate and interpolate negative prompts negative_prompt_embeds = None @@ -208,7 +214,7 @@ def _encode_prompt_free_noise( lora_scale=lora_scale, clip_skip=clip_skip, ) - + 
negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) for i in range(len(frame_negative_indices) - 1): @@ -217,16 +223,18 @@ def _encode_prompt_free_noise( start_tensor = negative_prompt_embeds[i].unsqueeze(0) end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) - negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) - + negative_prompt_interpolation_embeds[ + start_frame : end_frame + 1 + ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + prompt_embeds = prompt_interpolation_embeds negative_prompt_embeds = negative_prompt_interpolation_embeds - + if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds, negative_prompt_embeds - + def _prepare_latents_free_noise( self, batch_size: int, @@ -300,16 +308,18 @@ def _prepare_latents_free_noise( latents = latents[:, :, :num_frames] return latents - - def _lerp(self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor: + + def _lerp( + self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor + ) -> torch.Tensor: num_indices = end_index - start_index + 1 interpolated_tensors = [] - + for i in range(num_indices): alpha = i / (num_indices - 1) interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor interpolated_tensors.append(interpolated_tensor) - + interpolated_tensors = torch.cat(interpolated_tensors) return interpolated_tensors @@ -319,7 +329,9 @@ def enable_free_noise( context_stride: int = 4, weighting_scheme: str = "pyramid", noise_type: str = "shuffle_context", - prompt_interpolation_callback: Optional[Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None, + prompt_interpolation_callback: Optional[ + Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] + ] = None, ) -> None: r""" Enable long video generation using FreeNoise. 
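With patches 02 and 03 applied, a FreeNoise-enabled AnimateDiff pipeline accepts a dict mapping 0-based frame indices to prompts, and `_encode_prompt_free_noise` interpolates the embeddings for the frames in between. A rough usage sketch follows; the checkpoint names, frame indices and prompt texts are illustrative assumptions, not part of these patches:

    import torch
    from diffusers import AnimateDiffPipeline, MotionAdapter

    # Any AnimateDiff-compatible motion adapter and SD1.5 base checkpoint should work here.
    adapter = MotionAdapter.from_pretrained(
        "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
    )
    pipe = AnimateDiffPipeline.from_pretrained(
        "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
    ).to("cuda")

    # Prompt interpolation is only active once FreeNoise is enabled.
    pipe.enable_free_noise(context_length=16, context_stride=4)

    # Keys are 0-based frame indices: the smallest key must be 0 and the largest
    # must be smaller than num_frames; frames in between get interpolated embeddings.
    prompt = {
        0: "a panda surfing on a calm ocean wave",
        32: "a panda surfing through a thunderstorm, dramatic lighting",
        63: "a panda resting on the beach at sunset",
    }

    output = pipe(
        prompt=prompt,
        negative_prompt="bad quality, blurry",
        num_frames=64,
        guidance_scale=7.5,
        num_inference_steps=25,
    )
    frames = output.frames[0]

The same dict form applies to AnimateDiffVideoToVideoPipeline, where the number of frames used for interpolation is taken from the prepared latents rather than a num_frames argument.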
From 94438e1439994900dcf69f39b7cd327acf61c20f Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 02:05:32 +0200 Subject: [PATCH 04/16] resnet memory optimizations --- src/diffusers/models/attention.py | 16 ++- src/diffusers/models/attention_processor.py | 2 + .../models/unets/unet_motion_model.py | 102 ++++++++++++++++-- src/diffusers/pipelines/free_noise_utils.py | 7 ++ 4 files changed, 118 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e6858d842cbb..edccbc990fae 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -43,6 +43,12 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output +def _experimental_split_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, split_size: int, split_dim: int +) -> torch.Tensor: + return torch.cat([ff(hs_split) for hs_split in hidden_states.split(split_size, dim=split_dim)], dim=split_dim) + + @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" @@ -525,7 +531,10 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + ff_output = _experimental_split_feed_forward( + self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim + ) else: ff_output = self.ff(norm_hidden_states) @@ -1095,7 +1104,10 @@ def forward( norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + ff_output = _experimental_split_feed_forward( + self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim + ) else: ff_output = self.ff(norm_hidden_states) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2ab1606b345..e3ebf1077dc2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2221,6 +2221,8 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj + # TODO: figure out a better way to do this + # hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0) hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 7ac61821edd6..5d0bc7b810bb 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -49,6 +49,18 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def _chunked_resnet_forward( + resnet: ResnetBlock2D, hidden_states: torch.Tensor, temb: torch.Tensor, chunk_size: int, chunk_dim: int +) -> torch.Tensor: + return torch.cat( + [ + resnet(hs_split, t_split) + for hs_split, t_split in zip(hidden_states.split(chunk_size, chunk_dim), temb.split(chunk_size, chunk_dim)) + ], + dim=chunk_dim, + ) + + @dataclass class UNetMotionOutput(BaseOutput): """ @@ -116,7 +128,7 @@ def __init__( self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, 
affine=True) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks @@ -306,6 +318,12 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -344,7 +362,12 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -493,6 +516,12 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -540,7 +569,12 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -695,6 +729,12 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -766,7 +806,12 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -866,6 +911,12 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -929,7 +980,12 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -1065,6 +1121,12 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size = chunk_size + self._chunk_dim = dim def forward( self, @@ -1080,7 +1142,12 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - hidden_states = self.resnets[0](hidden_states, temb) + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + self.resnets[0], hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: @@ -1125,11 +1192,18 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] + hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(hidden_states, temb) + + if self._chunk_size is not None: + hidden_states = _chunked_resnet_forward( + resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + ) + else: + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1970,6 +2044,20 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) + def enable_resnet_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + chunk_size = chunk_size or 1 + + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_resnet"): + logger.debug(f"Enabling chunked resnet inference in: {name}") + module.set_chunk_resnet(chunk_size, dim) + + def disable_resnet_chunking(self) -> None: + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_resnet"): + logger.debug(f"Disabling chunked resnet inference in: {name}") + module.set_chunk_resnet(None) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f710206d7730..ff5af5af01da 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -332,6 +332,8 @@ def enable_free_noise( prompt_interpolation_callback: Optional[ Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, + _chunk_size_resnet: Optional[int] = None, + _chunk_size_feed_forward: Optional[int] = None, ) -> None: r""" Enable long video generation using FreeNoise. 
@@ -379,6 +381,11 @@ def enable_free_noise( for block in blocks: self._enable_free_noise_in_block(block) + if _chunk_size_resnet is not None: + self.unet.enable_resnet_chunking(_chunk_size_resnet, dim=0) + if _chunk_size_feed_forward is not None: + self.unet.enable_forward_chunking(_chunk_size_feed_forward, dim=0) + def disable_free_noise(self) -> None: self._free_noise_context_length = None From 74e3ab088cb300bc6e3a053753ff4b6eb3a95781 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 06:14:44 +0200 Subject: [PATCH 05/16] more memory optimizations; todo: refactor --- src/diffusers/models/attention.py | 23 +- .../models/unets/unet_motion_model.py | 309 ++++++++++++++---- src/diffusers/pipelines/free_noise_utils.py | 9 + 3 files changed, 271 insertions(+), 70 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index edccbc990fae..5005ad118894 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1096,19 +1096,38 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - hidden_states = torch.where( - num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, ).to(dtype) + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ).to(dtype) + # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: + # norm_hidden_states = torch.cat([ + # self.norm3(hs_split) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) + # ], dim=self._chunk_dim) + # ff_output = torch.cat([ + # self.ff(self.norm3(hs_split)) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) + # ], dim=self._chunk_dim) + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) ff_output = _experimental_split_feed_forward( self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim ) else: + norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) hidden_states = ff_output + hidden_states diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 5d0bc7b810bb..e01f5dbfe628 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -61,6 +61,29 @@ def _chunked_resnet_forward( ) +def _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + chunk_size: int, + chunk_dim: int, +) -> torch.Tensor: + return torch.cat( + [ + attn( + hs_split, ehs_split, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict=False + )[0] + for hs_split, ehs_split in zip( + hidden_states.split(chunk_size, chunk_dim), encoder_hidden_states.split(chunk_size, chunk_dim) + ) + ], + dim=chunk_dim, + ) + + @dataclass class UNetMotionOutput(BaseOutput): """ @@ -152,6 +175,12 @@ def __init__( ) self.proj_out = nn.Linear(inner_dim, in_channels) + self._chunk_size_motion_module = None + 
self._chunk_dim_motion_module = 0 + + def set_chunk_motion_module(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_motion_module = chunk_size + self._chunk_dim_motion_module = dim def forward( self, @@ -203,13 +232,37 @@ def forward( # 2. Blocks for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) + if encoder_hidden_states is None: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split in hidden_states.split(self._chunk_size_motion_module) + ], + dim=self._chunk_dim_motion_module, + ) + else: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=ehs_split, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split, ehs_split in zip( + hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), + encoder_hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), + ) + ], + dim=self._chunk_dim_motion_module, + ) # 3. Output hidden_states = self.proj_out(hidden_states) @@ -318,12 +371,12 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_dim_resnet = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim def forward( self, @@ -362,9 +415,9 @@ def custom_forward(*inputs): ) else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) @@ -375,7 +428,16 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + downsampler(hs_split) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -516,12 +578,18 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_size_attn = None + self._chunk_dim_resnet = 0 + self._chunk_dim_attn = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim + + def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_attn = chunk_size + self._chunk_dim_attn = dim def forward( self, @@ -569,21 +637,34 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, 
self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self._chunk_size_attn is not None: + hidden_states = _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + self._chunk_size_resnet, + self._chunk_dim_resnet, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -597,7 +678,16 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + downsampler(hs_split) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -729,12 +819,18 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_size_attn = None + self._chunk_dim_resnet = 0 + self._chunk_dim_attn = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim + + def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_attn = chunk_size + self._chunk_dim_attn = dim def forward( self, @@ -806,21 +902,34 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self._chunk_size_attn is not None: + hidden_states = _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + self._chunk_size_resnet, + self._chunk_dim_resnet, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -828,7 +937,16 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + if 
self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + upsampler(hs_split, upsample_size) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -911,12 +1029,12 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_dim_resnet = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim def forward( self, @@ -980,9 +1098,9 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) @@ -991,7 +1109,16 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + if self._chunk_size_resnet is not None: + hidden_states = torch.cat( + [ + upsampler(hs_split, upsample_size) + for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) + ], + dim=self._chunk_dim_resnet, + ) + else: + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1121,12 +1248,18 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False - self._chunk_size = None - self._chunk_dim = 0 + self._chunk_size_resnet = None + self._chunk_size_attn = None + self._chunk_dim_resnet = 0 + self._chunk_dim_attn = 0 def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size = chunk_size - self._chunk_dim = dim + self._chunk_size_resnet = chunk_size + self._chunk_dim_resnet = dim + + def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + self._chunk_size_attn = chunk_size + self._chunk_dim_attn = dim def forward( self, @@ -1142,9 +1275,9 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - self.resnets[0], hidden_states, temb, self._chunk_size, self._chunk_dim + self.resnets[0], hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = self.resnets[0](hidden_states, temb) @@ -1184,23 +1317,35 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + if self._chunk_size_attn is not None: + hidden_states = _chunked_attn_forward( + attn, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + self._chunk_size_resnet, + self._chunk_dim_resnet, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - if self._chunk_size is not None: + if self._chunk_size_resnet is not None: hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size, self._chunk_dim + resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet ) else: hidden_states = resnet(hidden_states, temb) @@ -2045,7 +2190,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int fn_recursive_feed_forward(module, None, 0) def enable_resnet_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 1 + chunk_size = chunk_size or 16 for name, module in self.named_modules(): if hasattr(module, "set_chunk_resnet"): @@ -2058,6 +2203,34 @@ def disable_resnet_chunking(self) -> None: logger.debug(f"Disabling chunked resnet inference in: {name}") module.set_chunk_resnet(None) + def enable_attn_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + chunk_size = chunk_size or 16 + + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_attn"): + logger.debug(f"Enabling chunked attn inference in: {name}") + module.set_chunk_attn(chunk_size, dim) + + def disable_attn_chunking(self) -> None: + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_attn"): + logger.debug(f"Disabling chunked attn inference in: {name}") + module.set_chunk_attn(None) + + def enable_motion_module_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + chunk_size = chunk_size or 256 + + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_motion_module"): + logger.debug(f"Enabling chunked motion module inference in: {name}") + module.set_chunk_motion_module(chunk_size, dim) + + def disable_motion_module_chunking(self) -> None: + for name, module in self.named_modules(): + if hasattr(module, "set_chunk_motion_module"): + logger.debug(f"Disabling chunked motion module inference in: {name}") + module.set_chunk_motion_module(None) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 
ff5af5af01da..e76ae2358b8d 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -70,6 +70,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow motion_module.transformer_blocks[i].load_state_dict( basic_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim + ) def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -98,6 +101,9 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim + ) def _check_inputs_free_noise( self, @@ -332,6 +338,7 @@ def enable_free_noise( prompt_interpolation_callback: Optional[ Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, + _chunk_size_attn: Optional[int] = None, _chunk_size_resnet: Optional[int] = None, _chunk_size_feed_forward: Optional[int] = None, ) -> None: @@ -381,6 +388,8 @@ def enable_free_noise( for block in blocks: self._enable_free_noise_in_block(block) + if _chunk_size_attn is not None: + self.unet.enable_attn_chunking(_chunk_size_attn, dim=0) if _chunk_size_resnet is not None: self.unet.enable_resnet_chunking(_chunk_size_resnet, dim=0) if _chunk_size_feed_forward is not None: From ec91064966ac944c7ae0cb0cd2d89345e33c796e Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 18:42:15 +0200 Subject: [PATCH 06/16] update --- src/diffusers/models/attention.py | 25 +++++-- .../models/unets/unet_motion_model.py | 69 +++++++++++-------- src/diffusers/pipelines/free_noise_utils.py | 2 +- 3 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5005ad118894..83d7ae6c448c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -981,15 +981,32 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: return frame_indices def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: - if weighting_scheme == "pyramid": + if weighting_scheme == "flat": + weights = [1.0] * num_frames + + elif weighting_scheme == "pyramid": if num_frames % 2 == 0: # num_frames = 4 => [1, 2, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) + mid = num_frames // 2 + weights = list(range(1, mid + 1)) weights = weights + weights[::-1] else: # num_frames = 5 => [1, 2, 3, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) - weights = weights + [num_frames // 2 + 1] + weights[::-1] + mid = (num_frames + 1) // 2 + weights = list(range(1, mid)) + weights = weights + [mid] + weights[::-1] + + elif weighting_scheme == "delayed_reverse_sawtooth": + if num_frames % 2 == 0: + # num_frames = 4 => [0.01, 2, 2, 1] + mid = num_frames // 2 + weights = [0.01] * (mid - 1) + [mid] + weights = weights + list(range(mid, 0, -1)) + else: + # num_frames = 5 => [0.01, 0.01, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = [0.01] * (mid - 1) + weights = weights + list(range(mid, 0, -1)) else: raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") diff --git
a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index e01f5dbfe628..26a29e5e4802 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -232,36 +232,47 @@ def forward( # 2. Blocks for block in self.transformer_blocks: - if encoder_hidden_states is None: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=None, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split in hidden_states.split(self._chunk_size_motion_module) - ], - dim=self._chunk_dim_motion_module, - ) + if self._chunk_size_motion_module is not None: + if encoder_hidden_states is None: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split in hidden_states.split(self._chunk_size_motion_module) + ], + dim=self._chunk_dim_motion_module, + ) + else: + hidden_states = torch.cat( + [ + block( + hs_split, + encoder_hidden_states=ehs_split, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + for hs_split, ehs_split in zip( + hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), + encoder_hidden_states.split( + self._chunk_size_motion_module, self._chunk_dim_motion_module + ), + ) + ], + dim=self._chunk_dim_motion_module, + ) else: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=ehs_split, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split, ehs_split in zip( - hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), - encoder_hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), - ) - ], - dim=self._chunk_dim_motion_module, + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, ) # 3. Output diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index e76ae2358b8d..cbf323be15bc 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -364,7 +364,7 @@ def enable_free_noise( TODO """ - allowed_weighting_scheme = ["pyramid"] + allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] allowed_noise_type = ["shuffle_context", "repeat_context", "random"] if context_length > self.motion_adapter.config.motion_max_seq_length: From 65686818ab13e4fa0b17163139566dd2e390a59b Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 18 Aug 2024 23:54:55 +0200 Subject: [PATCH 07/16] update animatediff controlnet with latest changes --- .../pipeline_animatediff_controlnet.py | 64 +++++++++++++------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 6e8b0e3e5fe3..5357d6d5b8d9 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -505,8 +505,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -699,6 +699,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() def __call__( self, @@ -858,9 +862,10 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if prompt is not None and isinstance(prompt, (str, dict)): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -883,22 +888,39 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( @@ -990,6 +1012,9 @@ def __call__( # 8. 
Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1002,7 +1027,6 @@ def __call__( else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] From 761c44d116665836ef238991ab01f4cb0acf86db Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 21 Aug 2024 11:47:31 +0200 Subject: [PATCH 08/16] refactor chunked inference changes --- .../models/unets/unet_motion_model.py | 385 +++--------------- src/diffusers/pipelines/free_noise_utils.py | 119 +++++- 2 files changed, 154 insertions(+), 350 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 26a29e5e4802..6125feba5899 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -49,41 +49,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _chunked_resnet_forward( - resnet: ResnetBlock2D, hidden_states: torch.Tensor, temb: torch.Tensor, chunk_size: int, chunk_dim: int -) -> torch.Tensor: - return torch.cat( - [ - resnet(hs_split, t_split) - for hs_split, t_split in zip(hidden_states.split(chunk_size, chunk_dim), temb.split(chunk_size, chunk_dim)) - ], - dim=chunk_dim, - ) - - -def _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - chunk_size: int, - chunk_dim: int, -) -> torch.Tensor: - return torch.cat( - [ - attn( - hs_split, ehs_split, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict=False - )[0] - for hs_split, ehs_split in zip( - hidden_states.split(chunk_size, chunk_dim), encoder_hidden_states.split(chunk_size, chunk_dim) - ) - ], - dim=chunk_dim, - ) - - @dataclass class UNetMotionOutput(BaseOutput): """ @@ -175,12 +140,6 @@ def __init__( ) self.proj_out = nn.Linear(inner_dim, in_channels) - self._chunk_size_motion_module = None - self._chunk_dim_motion_module = 0 - - def set_chunk_motion_module(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_motion_module = chunk_size - self._chunk_dim_motion_module = dim def forward( self, @@ -228,55 +187,20 @@ def forward( hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(input=hidden_states) # 2. 
Blocks for block in self.transformer_blocks: - if self._chunk_size_motion_module is not None: - if encoder_hidden_states is None: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=None, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split in hidden_states.split(self._chunk_size_motion_module) - ], - dim=self._chunk_dim_motion_module, - ) - else: - hidden_states = torch.cat( - [ - block( - hs_split, - encoder_hidden_states=ehs_split, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - for hs_split, ehs_split in zip( - hidden_states.split(self._chunk_size_motion_module, self._chunk_dim_motion_module), - encoder_hidden_states.split( - self._chunk_size_motion_module, self._chunk_dim_motion_module - ), - ) - ], - dim=self._chunk_dim_motion_module, - ) - else: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) # 3. Output - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(input=hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -382,12 +306,6 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size_resnet = None - self._chunk_dim_resnet = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim def forward( self, @@ -426,12 +344,7 @@ def custom_forward(*inputs): ) else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -439,16 +352,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - downsampler(hs_split) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -589,18 +493,6 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - self._chunk_size_resnet = None - self._chunk_size_attn = None - self._chunk_dim_resnet = 0 - self._chunk_dim_attn = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim - - def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_attn = chunk_size - self._chunk_dim_attn = dim def forward( self, @@ -639,42 +531,17 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - 
attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self._chunk_size_attn is not None: - hidden_states = _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - self._chunk_size_resnet, - self._chunk_dim_resnet, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, @@ -689,16 +556,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - downsampler(hs_split) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -830,18 +688,6 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size_resnet = None - self._chunk_size_attn = None - self._chunk_dim_resnet = 0 - self._chunk_dim_attn = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim - - def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_attn = chunk_size - self._chunk_dim_attn = dim def forward( self, @@ -904,42 +750,17 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self._chunk_size_attn is not None: - hidden_states = _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - self._chunk_size_resnet, - self._chunk_dim_resnet, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + 
cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, @@ -948,16 +769,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - upsampler(hs_split, upsample_size) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1040,12 +852,6 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._chunk_size_resnet = None - self._chunk_dim_resnet = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim def forward( self, @@ -1109,27 +915,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - if self._chunk_size_resnet is not None: - hidden_states = torch.cat( - [ - upsampler(hs_split, upsample_size) - for hs_split in hidden_states.split(self._chunk_size_resnet, self._chunk_dim_resnet) - ], - dim=self._chunk_dim_resnet, - ) - else: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1259,18 +1051,6 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False - self._chunk_size_resnet = None - self._chunk_size_attn = None - self._chunk_dim_resnet = 0 - self._chunk_dim_attn = 0 - - def set_chunk_resnet(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_resnet = chunk_size - self._chunk_dim_resnet = dim - - def set_chunk_attn(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - self._chunk_size_attn = chunk_size - self._chunk_dim_attn = dim def forward( self, @@ -1286,15 +1066,19 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - self.resnets[0], hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1307,14 +1091,6 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1328,38 +1104,11 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: - if self._chunk_size_attn is not None: - hidden_states = _chunked_attn_forward( - attn, - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - self._chunk_size_resnet, - self._chunk_dim_resnet, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - - if self._chunk_size_resnet is not None: - hidden_states = _chunked_resnet_forward( - resnet, hidden_states, temb, self._chunk_size_resnet, self._chunk_dim_resnet - ) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states @@ -2200,48 +1949,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - def enable_resnet_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 16 - - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_resnet"): - logger.debug(f"Enabling chunked resnet inference in: {name}") - module.set_chunk_resnet(chunk_size, dim) - - def disable_resnet_chunking(self) -> None: - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_resnet"): - logger.debug(f"Disabling chunked resnet inference in: {name}") - module.set_chunk_resnet(None) - - def enable_attn_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 16 - - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_attn"): - logger.debug(f"Enabling chunked attn inference in: {name}") - module.set_chunk_attn(chunk_size, dim) - - def disable_attn_chunking(self) -> None: - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_attn"): - logger.debug(f"Disabling chunked attn inference in: {name}") - 
module.set_chunk_attn(None) - - def enable_motion_module_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - chunk_size = chunk_size or 256 - - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_motion_module"): - logger.debug(f"Enabling chunked motion module inference in: {name}") - module.set_chunk_motion_module(chunk_size, dim) - - def disable_motion_module_chunking(self) -> None: - for name, module in self.named_modules(): - if hasattr(module, "set_chunk_motion_module"): - logger.debug(f"Disabling chunked motion module inference in: {name}") - module.set_chunk_motion_module(None) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index cbf323be15bc..817360b3eaea 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( + AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -30,6 +34,59 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class ChunkedInferenceModule(nn.Module): + def __init__( + self, + module: nn.Module, + chunk_size: int = 1, + chunk_dim: int = 0, + input_kwargs_to_chunk: List[str] = ["hidden_states"], + ) -> None: + super().__init__() + + self.module = module + self.chunk_size = chunk_size + self.chunk_dim = chunk_dim + self.input_kwargs_to_chunk = set(input_kwargs_to_chunk) + + def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + r"""Forward method of `ChunkedInferenceModule`. + + All inputs that should be chunked should be passed as keyword arguments. Only those keywords arguments will be + chunked that are specified in `inputs_to_chunk` when initializing the module. + """ + chunked_inputs = {} + + for key in list(kwargs.keys()): + if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]): + continue + chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim) + kwargs.pop(key) + + results = [] + for chunked_input in zip(*chunked_inputs.values()): + inputs = dict(zip(chunked_inputs.keys(), chunked_input)) + inputs.update(kwargs) + + for key, input in inputs.items(): + if torch.is_tensor(input): + print(key, input.shape) + else: + print(key) + + intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) + results.append(intermediate_tensor_or_tensor_tuple) + + if isinstance(results[0], torch.Tensor): + return torch.cat(results, dim=self.chunk_dim) + elif isinstance(results[0], tuple): + return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)]) + else: + raise ValueError( + "In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." 
+ ) + + class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -338,9 +395,6 @@ def enable_free_noise( prompt_interpolation_callback: Optional[ Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] ] = None, - _chunk_size_attn: Optional[int] = None, - _chunk_size_resnet: Optional[int] = None, - _chunk_size_feed_forward: Optional[int] = None, ) -> None: r""" Enable long video generation using FreeNoise. @@ -388,13 +442,6 @@ def enable_free_noise( for block in blocks: self._enable_free_noise_in_block(block) - if _chunk_size_attn is not None: - self.unet.enable_attn_chunking(_chunk_size_attn, dim=0) - if _chunk_size_resnet is not None: - self.unet.enable_resnet_chunking(_chunk_size_resnet, dim=0) - if _chunk_size_feed_forward is not None: - self.unet.enable_forward_chunking(_chunk_size_feed_forward, dim=0) - def disable_free_noise(self) -> None: self._free_noise_context_length = None @@ -402,6 +449,56 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) + def _enable_chunked_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int + ) -> None: + for motion_module in motion_modules: + motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"]) + + for i in range(len(motion_module.transformer_blocks)): + motion_module.transformer_blocks[i] = ChunkedInferenceModule( + motion_module.transformer_blocks[i], + spatial_chunk_size, + 0, + ["hidden_states", "encoder_hidden_states"], + ) + + motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"]) + + def _enable_chunked_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_chunk_size: int + ) -> None: + for i in range(len(attentions)): + attentions[i] = ChunkedInferenceModule( + attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"] + ) + + def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None: + for i in range(len(resnets)): + resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"]) + + def _enable_chunked_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int + ) -> None: + for i in range(len(samplers)): + samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"]) + + def enable_free_noise_chunked_inference( + self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16 + ) -> None: + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + if getattr(block, "motion_modules", None) is not None: + self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size) + if getattr(block, "attentions", None) is not None: + self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size) + if getattr(block, "resnets", None) is not None: + self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size) + if getattr(block, "downsamplers", None) is not None: + self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size) + if getattr(block, "upsamplers", None) is not None: + self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size) + @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") 
and self._free_noise_context_length is not None From 6830fb08052039c549708f4a8ec77d3d4c29b6d7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 21 Aug 2024 11:52:05 +0200 Subject: [PATCH 09/16] remove print statements --- src/diffusers/pipelines/free_noise_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 817360b3eaea..1c1147a4949e 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -68,12 +68,6 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: inputs = dict(zip(chunked_inputs.keys(), chunked_input)) inputs.update(kwargs) - for key, input in inputs.items(): - if torch.is_tensor(input): - print(key, input.shape) - else: - print(key) - intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) results.append(intermediate_tensor_or_tensor_tuple) From c887d7700b5e45f1939f9913d3b3d18449f7a46c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 24 Aug 2024 01:54:35 +0200 Subject: [PATCH 10/16] undo memory optimization changes --- src/diffusers/models/attention.py | 28 +---- src/diffusers/models/attention_processor.py | 2 - .../models/unets/unet_motion_model.py | 101 ++++++++++------- src/diffusers/pipelines/free_noise_utils.py | 103 +----------------- 4 files changed, 64 insertions(+), 170 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 83d7ae6c448c..efeb553c1947 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -43,12 +43,6 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output -def _experimental_split_feed_forward( - ff: nn.Module, hidden_states: torch.Tensor, split_size: int, split_dim: int -) -> torch.Tensor: - return torch.cat([ff(hs_split) for hs_split in hidden_states.split(split_size, dim=split_dim)], dim=split_dim) - - @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" @@ -531,10 +525,7 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - ff_output = _experimental_split_feed_forward( - self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim - ) + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) @@ -1124,27 +1115,12 @@ def forward( dim=1, ).to(dtype) - # hidden_states = torch.where( - # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values - # ).to(dtype) - # 3. 
Feed-forward norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: - # norm_hidden_states = torch.cat([ - # self.norm3(hs_split) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) - # ], dim=self._chunk_dim) - # ff_output = torch.cat([ - # self.ff(self.norm3(hs_split)) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim) - # ], dim=self._chunk_dim) - - # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - ff_output = _experimental_split_feed_forward( - self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim - ) + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: - norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) hidden_states = ff_output + hidden_states diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 17fbdc526a6d..9f9bc5a46e10 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2372,8 +2372,6 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - # TODO: figure out a better way to do this - # hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0) hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 6125feba5899..89cdb76741f7 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -187,12 +187,12 @@ def forward( hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(input=hidden_states) + hidden_states = self.proj_in(hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( - hidden_states=hidden_states, + hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, @@ -200,7 +200,7 @@ def forward( ) # 3. 
Output - hidden_states = self.proj_out(input=hidden_states) + hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -344,7 +344,7 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -352,7 +352,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states=hidden_states) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -531,18 +531,25 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) - - hidden_states = attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -556,7 +563,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states=hidden_states) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -750,18 +757,25 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) - - hidden_states = attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -769,7 +783,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -915,13 +929,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) hidden_states = 
motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1066,19 +1080,10 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: - hidden_states = attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1091,6 +1096,14 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1104,11 +1117,19 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(input_tensor=hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 1c1147a4949e..fe4fdd5d0f3e 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Union import torch -import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock -from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( - AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -34,53 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class ChunkedInferenceModule(nn.Module): - def __init__( - self, - module: nn.Module, - chunk_size: int = 1, - chunk_dim: int = 0, - input_kwargs_to_chunk: List[str] = ["hidden_states"], - ) -> None: - super().__init__() - - self.module = module - self.chunk_size = chunk_size - self.chunk_dim = chunk_dim - self.input_kwargs_to_chunk = set(input_kwargs_to_chunk) - - def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - r"""Forward method of `ChunkedInferenceModule`. - - All inputs that should be chunked should be passed as keyword arguments. Only those keywords arguments will be - chunked that are specified in `inputs_to_chunk` when initializing the module. - """ - chunked_inputs = {} - - for key in list(kwargs.keys()): - if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]): - continue - chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim) - kwargs.pop(key) - - results = [] - for chunked_input in zip(*chunked_inputs.values()): - inputs = dict(zip(chunked_inputs.keys(), chunked_input)) - inputs.update(kwargs) - - intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) - results.append(intermediate_tensor_or_tensor_tuple) - - if isinstance(results[0], torch.Tensor): - return torch.cat(results, dim=self.chunk_dim) - elif isinstance(results[0], tuple): - return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)]) - else: - raise ValueError( - "In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." 
- ) - - class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -443,56 +392,6 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) - def _enable_chunked_inference_motion_modules_( - self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int - ) -> None: - for motion_module in motion_modules: - motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"]) - - for i in range(len(motion_module.transformer_blocks)): - motion_module.transformer_blocks[i] = ChunkedInferenceModule( - motion_module.transformer_blocks[i], - spatial_chunk_size, - 0, - ["hidden_states", "encoder_hidden_states"], - ) - - motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"]) - - def _enable_chunked_inference_attentions_( - self, attentions: List[Transformer2DModel], temporal_chunk_size: int - ) -> None: - for i in range(len(attentions)): - attentions[i] = ChunkedInferenceModule( - attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"] - ) - - def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None: - for i in range(len(resnets)): - resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"]) - - def _enable_chunked_inference_samplers_( - self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int - ) -> None: - for i in range(len(samplers)): - samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"]) - - def enable_free_noise_chunked_inference( - self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16 - ) -> None: - blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] - for block in blocks: - if getattr(block, "motion_modules", None) is not None: - self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size) - if getattr(block, "attentions", None) is not None: - self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size) - if getattr(block, "resnets", None) is not None: - self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size) - if getattr(block, "downsamplers", None) is not None: - self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size) - if getattr(block, "upsamplers", None) is not None: - self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size) - @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None From 88be756284969eb309f7ec58d54c1d7518aa9610 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 24 Aug 2024 02:04:24 +0200 Subject: [PATCH 11/16] update docstrings --- src/diffusers/pipelines/free_noise_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index fe4fdd5d0f3e..21247edb8867 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -355,10 +355,24 @@ def enable_free_noise( weighting_scheme (`str`, defaults to `pyramid`): Weighting scheme for averaging latents after accumulation in FreeNoise blocks. 
The following weighting schemes are supported currently: + - "flat" + Performs weighted averaging with a flat weight pattern: [1, 1, 1, 1, 1]. - "pyramid" - Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. + Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. + - "delayed_reverse_sawtooth" + Performs weighted averaging with low weights for earlier frames and high-to-low weights for + later frames: [0.01, 0.01, 3, 2, 1]. noise_type (`str`, defaults to "shuffle_context"): - TODO + Must be one of ["shuffle_context", "repeat_context", "random"]. + - "shuffle_context" + Shuffles a fixed batch of `context_length` latents to create a final latent of size + `num_frames`. This is usually the best setting for most generation scenarios. However, there + might be noticeable repetition in the kinds of motion/animation generated. + - "repeat_context" + Repeats a fixed batch of `context_length` latents to create a final latent of size + `num_frames`. + - "random" + The final latents are random without any repetition. """ allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] allowed_noise_type = ["shuffle_context", "repeat_context", "random"] if context_length > self.motion_adapter.config.motion_max_seq_length: From 91ded8323b9fcbd0776b2ca578a54bf626d7f6e8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 25 Aug 2024 14:24:22 +0200 Subject: [PATCH 12/16] fix tests --- src/diffusers/models/controlnet_sparsectrl.py | 1 - .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 2 ++ .../animatediff/pipeline_animatediff_sparsectrl.py | 2 ++ src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py | 6 +++++- tests/models/unets/test_models_unet_motion.py | 2 +- 5 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index e91551c70953..fa37e1f9e393 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -691,7 +691,6 @@ def forward( emb = self.time_embedding(t_emb, timestep_cond) emb = emb.repeat_interleave(sample_num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0) # 2.
pre-process batch_size, channels, num_frames, height, width = sample.shape diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index a46682347519..e531c91c168f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -1143,6 +1143,8 @@ def __call__( add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index e9e0d518c806..8b037cdc34fb 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -878,6 +878,8 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + # 4. Prepare IP-Adapter embeddings if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index 73c53b365848..1e81fa3a158c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -734,6 +734,8 @@ def __call__( elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, @@ -805,7 +807,9 @@ def __call__( with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = torch.cat( + [latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0]) + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index 53833d6a075b..ee05f0d93824 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -51,7 +51,7 @@ def dummy_input(self): noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size * num_frames, 4, 16)).to(torch_device) return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} From d1ed0d7bf2269e8a251b0e9ee8b1f22896cce510 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 25 Aug 2024 
22:29:02 +0200 Subject: [PATCH 13/16] fix pia tests --- src/diffusers/pipelines/pia/pipeline_pia.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f0e8cfb03def..b7dfcd39edce 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -824,6 +824,8 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, From 81ce0e091e354c38ace44b382df1378a01cfc98b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 27 Aug 2024 13:06:18 +0200 Subject: [PATCH 14/16] apply suggestions from review --- src/diffusers/models/attention.py | 11 ++--------- src/diffusers/pipelines/free_noise_utils.py | 6 ------ 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index efeb553c1947..7766442f7133 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1104,15 +1104,8 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - hidden_states = torch.cat( - [ - torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) - for accumulated_split, num_times_split in zip( - accumulated_values.split(self.context_length, dim=1), - num_times_accumulated.split(self.context_length, dim=1), - ) - ], - dim=1, + hidden_states = torch.where( + num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values ).to(dtype) # 3. 
From 1dc97180b9c5b530bd21e53e0dec06c784f31dac Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 28 Aug 2024 11:02:34 +0200
Subject: [PATCH 15/16] add tests

---
 .../pipelines/animatediff/test_animatediff.py | 23 +++++++++++++++++
 .../test_animatediff_controlnet.py            | 21 ++++++++++++++++
 .../test_animatediff_video2video.py           | 25 +++++++++++++++++++
 3 files changed, 69 insertions(+)

diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 1354ac9ff1a8..7b67681c2d69 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -460,6 +460,29 @@ def test_free_noise(self):
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )
 
+    def test_free_noise_multi_prompt(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        context_length = 8
+        context_stride = 4
+        pipe.enable_free_noise(context_length, context_stride)
+
+        # Make sure that
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+        inputs["num_frames"] = 16
+        pipe(**inputs).frames[0]
+
+        with self.assertRaises(ValueError):
+            # Ensure that prompt indices are within bounds
+            inputs = self.get_dummy_inputs(torch_device)
+            inputs["num_frames"] = 16
+            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+            pipe(**inputs).frames[0]
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
index 3035fc1e3c61..5992f731ca6d 100644
--- a/tests/pipelines/animatediff/test_animatediff_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -476,6 +476,27 @@ def test_free_noise(self):
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )
 
+    def test_free_noise_multi_prompt(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        context_length = 8
+        context_stride = 4
+        pipe.enable_free_noise(context_length, context_stride)
+
+        # Make sure that
+        inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+        pipe(**inputs).frames[0]
+
+        with self.assertRaises(ValueError):
+            # Ensure that prompt indices are within bounds
+            inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+            pipe(**inputs).frames[0]
+
     def test_vae_slicing(self, video_count=2):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index cd33bf0891a5..6a47fe2ba184 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -491,3 +491,28 @@ def test_free_noise(self):
             1e-4,
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )
+
+    def test_free_noise_multi_prompt(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        context_length = 8
+        context_stride = 4
+        pipe.enable_free_noise(context_length, context_stride)
+
+        # Make sure that
+        inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
+        inputs["num_inference_steps"] = 2
+        inputs["strength"] = 0.5
+        pipe(**inputs).frames[0]
+
+        with self.assertRaises(ValueError):
+            # Ensure that prompt indices are within bounds
+            inputs = self.get_dummy_inputs(torch_device, num_frames=16)
+            inputs["num_inference_steps"] = 2
+            inputs["strength"] = 0.5
+            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
+            pipe(**inputs).frames[0]
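The three tests added above exercise the same public behaviour: once FreeNoise is enabled, `prompt` may be a dict mapping frame indices to prompts, and an index outside the `num_frames` range raises a `ValueError`. A usage sketch along the lines of those tests; the checkpoint IDs below are placeholders and are not referenced anywhere in this patch series:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Placeholder model IDs -- substitute the SD base checkpoint and motion
# adapter you actually use.
adapter = MotionAdapter.from_pretrained("path/to/motion-adapter", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "path/to/stable-diffusion-base", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# FreeNoise must be enabled before passing frame-indexed (dict) prompts.
pipe.enable_free_noise(context_length=16, context_stride=4)

# Prompts are interpolated between the given frame indices; an index >= num_frames
# (e.g. 42 with num_frames=16) raises a ValueError, as the new tests assert.
output = pipe(
    prompt={0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"},
    num_frames=16,
    num_inference_steps=25,
)
frames = output.frames[0]
```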
From de88f16c2468330121cf15e01d0f24a135afd2f6 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 28 Aug 2024 11:04:23 +0200
Subject: [PATCH 16/16] update comment

---
 tests/pipelines/animatediff/test_animatediff.py             | 2 +-
 tests/pipelines/animatediff/test_animatediff_controlnet.py  | 2 +-
 tests/pipelines/animatediff/test_animatediff_video2video.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 7b67681c2d69..618a5cff9912 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -470,7 +470,7 @@ def test_free_noise_multi_prompt(self):
         context_stride = 4
         pipe.enable_free_noise(context_length, context_stride)
 
-        # Make sure that
+        # Make sure that pipeline works when prompt indices are within num_frames bounds
         inputs = self.get_dummy_inputs(torch_device)
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
         inputs["num_frames"] = 16
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
index 5992f731ca6d..c0ad223c6ce8 100644
--- a/tests/pipelines/animatediff/test_animatediff_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -486,7 +486,7 @@ def test_free_noise_multi_prompt(self):
         context_stride = 4
         pipe.enable_free_noise(context_length, context_stride)
 
-        # Make sure that
+        # Make sure that pipeline works when prompt indices are within num_frames bounds
         inputs = self.get_dummy_inputs(torch_device, num_frames=16)
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
         pipe(**inputs).frames[0]
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index 6a47fe2ba184..c49790e0f262 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -502,7 +502,7 @@ def test_free_noise_multi_prompt(self):
         context_stride = 4
         pipe.enable_free_noise(context_length, context_stride)
 
-        # Make sure that
+        # Make sure that pipeline works when prompt indices are within num_frames bounds
         inputs = self.get_dummy_inputs(torch_device, num_frames=16)
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
         inputs["num_inference_steps"] = 2