diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index 5b48569f3505..ac12bdb5578d 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -20,11 +20,29 @@ The abstract from the paper is: *Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.* +PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers. + +- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor` +- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor` +- Partial identifier as a RegEx: `down_blocks.2`, or `attn1` +- List of identifiers (can be combo of strings and ReGex): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2,3)"]` + + + +Since RegEx is supported as a way for matching layer identifiers, it is crucial to use it correctly otherwise there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results. + + + ## AnimateDiffPAGPipeline [[autodoc]] AnimateDiffPAGPipeline - all - __call__ +## HunyuanDiTPAGPipeline +[[autodoc]] HunyuanDiTPAGPipeline + - all + - __call__ + ## StableDiffusionPAGPipeline [[autodoc]] StableDiffusionPAGPipeline - all @@ -59,4 +77,4 @@ The abstract from the paper is: ## PixArtSigmaPAGPipeline [[autodoc]] PixArtSigmaPAGPipeline - all - - __call__ \ No newline at end of file + - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d58bbdac1867..ce437a5e6b73 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -252,6 +252,7 @@ "CycleDiffusionPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", + "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", @@ -675,6 +676,7 @@ CycleDiffusionPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, + HunyuanDiTPAGPipeline, HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 784eaaa62c55..80b462ef6a4f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2147,6 +2147,253 @@ def __call__( return hidden_states +class PAGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is @@ -3468,4 +3715,6 @@ def __init__(self): CustomDiffusionAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, + PAGCFGHunyuanAttnProcessor2_0, + PAGHunyuanAttnProcessor2_0, ] diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 10f6c4a92054..46e93ea7769b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -145,6 +145,7 @@ _import_structure["pag"].extend( [ "AnimateDiffPAGPipeline", + "HunyuanDiTPAGPipeline", "StableDiffusionPAGPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionXLPAGPipeline", @@ -532,6 +533,7 @@ from .musicldm import MusicLDMPipeline from .pag import ( AnimateDiffPAGPipeline, + HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 854cfaa47b7a..d45d616dedf0 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -50,6 +50,7 @@ from .kolors import KolorsImg2ImgPipeline, KolorsPipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .pag import ( + HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, @@ -85,6 +86,7 @@ ("stable-diffusion-3", StableDiffusion3Pipeline), ("if", IFPipeline), ("hunyuan", HunyuanDiTPipeline), + ("hunyuan-pag", HunyuanDiTPAGPipeline), ("kandinsky", KandinskyCombinedPipeline), ("kandinsky22", KandinskyV22CombinedPipeline), ("kandinsky3", Kandinsky3Pipeline), diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index b80064eb5e9a..5bdbaaac0829 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] @@ -41,6 +42,7 @@ else: from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 7c9bb2d098d2..728f730c9904 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re +from typing import Dict, List, Tuple, Union + import torch +import torch.nn as nn from ...models.attention_processor import ( + Attention, + AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) @@ -25,140 +31,56 @@ class PAGMixin: - r"""Mixin class for PAG.""" - - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats: - "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type` - can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be - in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}" - """ - - layer_splits = layer.split(".") - - if len(layer_splits) > 3: - raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") - - if len(layer_splits) >= 1: - if layer_splits[0] not in ["down", "mid", "up"]: - raise ValueError( - f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}" - ) - - if len(layer_splits) >= 2: - if not layer_splits[1].startswith("block_"): - raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") - - if len(layer_splits) == 3: - layer_2 = layer_splits[2] - if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"): - raise ValueError( - f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'" - ) + r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): r""" Set the attention processor for the PAG layers. """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() + pag_attn_processors = self._pag_attn_processors + if pag_attn_processors is None: + raise ValueError( + "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." + ) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() + model: nn.Module = self.transformer - def is_self_attn(module_name): + def is_self_attn(module: nn.Module) -> bool: r""" Check if the module is self-attention module based on its name. """ - return "attn1" in module_name and "to" not in name + return isinstance(module, Attention) and not module.is_cross_attention - def get_block_type(module_name): - r""" - Get the block type from the module name. Can be "down", "mid", "up". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" - # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down" - return module_name.split(".")[0].split("_")[0] + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name - def get_block_index(module_name): - r""" - Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0" - module_name_splits = module_name.split(".") - block_index = module_name_splits[1] - if "attentions" in block_index or "motion_modules" in block_index: - return "block_0" - else: - return f"block_{block_index}" - - def get_attn_index(module_name): - r""" - Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0", - "motion_modules_1", ... - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" - # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" - module_name_split = module_name.split(".") - mid_name = module_name_split[1] - down_name = module_name_split[2] - if "attentions" in down_name: - return f"attentions_{module_name_split[3]}" - if "attentions" in mid_name: - return f"attentions_{module_name_split[2]}" - if "motion_modules" in down_name: - return f"motion_modules_{module_name_split[3]}" - if "motion_modules" in mid_name: - return f"motion_modules_{module_name_split[2]}" - - for pag_layer_input in pag_applied_layers: + for layer_id in pag_applied_layers: # for each PAG layer input, we find corresponding self-attention layers in the unet model target_modules = [] - pag_layer_input_splits = pag_layer_input.split(".") - - if len(pag_layer_input_splits) == 1: - # when the layer input only contains block_type. e.g. "mid", "down", "up" - block_type = pag_layer_input_splits[0] - for name, module in self.unet.named_modules(): - if is_self_attn(name) and get_block_type(name) == block_type: - target_modules.append(module) - - elif len(pag_layer_input_splits) == 2: - # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - ): - target_modules.append(module) - - elif len(pag_layer_input_splits) == 3: - # when the layer input contains block_type, block_index and attention_index. - # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - attn_index = pag_layer_input_splits[2] - - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - and get_attn_index(name) == attn_index - ): - target_modules.append(module) + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") for module in target_modules: module.processor = pag_attn_proc @@ -221,239 +143,95 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free cond = torch.cat([uncond, cond], dim=0) return cond - def set_pag_applied_layers(self, pag_applied_layers): + def set_pag_applied_layers( + self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - """ + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not hasattr(self, "_pag_attn_processors"): + self._pag_attn_processors = None if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors @property - def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ + def pag_scale(self) -> float: + r"""Get the scale factor for the perturbed attention guidance.""" return self._pag_scale @property - def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ + def pag_adaptive_scale(self) -> float: + r"""Get the adaptive scale factor for the perturbed attention guidance.""" return self._pag_adaptive_scale @property - def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ + def do_pag_adaptive_scaling(self) -> bool: + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ + def do_perturbed_attention_guidance(self) -> bool: + r"""Check if the perturbed attention guidance is enabled.""" return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - def pag_attn_processors(self): + def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model with the key as the name of the layer. """ - processors = {} - for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): - processors[name] = proc - return processors + if self._pag_attn_processors is None: + return {} + valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} -class PixArtPAGMixin: - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}. - """ - - # Check if the layer index is valid (should be int or str of int) - if isinstance(layer, int): - return # Valid layer index - - if isinstance(layer, str): - if layer.isdigit(): - return # Valid layer index - - # If it is not a valid layer index, raise a ValueError - raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.") - - def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() - else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() - - def is_self_attn(module_name): - r""" - Check if the module is self-attention module based on its name. - """ - return ( - "attn1" in module_name and len(module_name.split(".")) == 3 - ) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ... - - def get_block_index(module_name): - r""" - Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # transformer_blocks.23.attn -> "23" - return module_name.split(".")[1] - - for pag_layer_input in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the transformer model - target_modules = [] - - block_index = str(pag_layer_input) - - for name, module in self.transformer.named_modules(): - if is_self_attn(name) and get_block_index(name) == block_index: - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") - - for module in target_modules: - module.processor = pag_attn_proc - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers - def set_pag_applied_layers(self, pag_applied_layers): - r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) - - self.pag_applied_layers = pag_applied_layers - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale - def _get_pag_scale(self, t): - r""" - Get the scale factor for the perturbed attention guidance at timestep `t`. - """ - - if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) - if signal_scale < 0: - signal_scale = 0 - return signal_scale - else: - return self.pag_scale - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): - r""" - Apply perturbed attention guidance to the noise prediction. - - Args: - noise_pred (torch.Tensor): The noise prediction tensor. - do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. - guidance_scale (float): The scale factor for the guidance term. - t (int): The current time step. - - Returns: - torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. - """ - pag_scale = self._get_pag_scale(t) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) + processors = {} + # We could have iterated through the self.components.items() and checked if a component is + # `ModelMixin` subclassed but that can include a VAE too. + if hasattr(self, "unet"): + denoiser_module = self.unet + elif hasattr(self, "transformer"): + denoiser_module = self.transformer else: - noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - return noise_pred - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance - def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): - """ - Prepares the perturbed attention guidance for the PAG model. - - Args: - cond (torch.Tensor): The conditional input tensor. - uncond (torch.Tensor): The unconditional input tensor. - do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. - - Returns: - torch.Tensor: The prepared perturbed attention guidance tensor. - """ + raise ValueError("No denoiser module found.") - cond = torch.cat([cond] * 2, dim=0) - - if do_classifier_free_guidance: - cond = torch.cat([uncond, cond], dim=0) - return cond - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale - def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ - return self._pag_scale - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale - def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ - return self._pag_adaptive_scale - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling - def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance - def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ - return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 - - @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer - def pag_attn_processors(self): - r""" - Returns: - `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model - with the key as the name of the layer. - """ - - processors = {} - for name, proc in self.transformer.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): + for name, proc in denoiser_module.attn_processors.items(): + if proc.__class__ in valid_attn_processors: processors[name] = proc + return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py new file mode 100644 index 000000000000..63126cc5aae9 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -0,0 +1,953 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=[14], + ... ).to("cuda") + + >>> # prompt = "an astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: Optional[StableDiffusionSafetyChecker] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + requires_safety_checker: bool = True, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, + pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + prompt_embeds_2 = self._prepare_perturbed_attention_guidance( + prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance + ) + prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance( + prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance + ) + add_time_ids = torch.cat([add_time_ids] * 3, dim=0) + style = torch.cat([style] * 3, dim=0) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index 1188ffe52ed7..8e5e6cbaf5ad 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -40,7 +40,7 @@ ASPECT_RATIO_1024_BIN, ) from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN -from .pag_utils import PixArtPAGMixin +from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -61,7 +61,7 @@ >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", ... torch_dtype=torch.float16, - ... pag_applied_layers=[14], + ... pag_applied_layers=["blocks.14"], ... enable_pag=True, ... ) >>> pipe = pipe.to("cuda") @@ -132,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin): +class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): r""" [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation using PixArt-Sigma. @@ -164,7 +164,7 @@ def __init__( vae: AutoencoderKL, transformer: PixArtTransformer2DModel, scheduler: KarrasDiffusionSchedulers, - pag_applied_layers: Union[str, List[str]] = "1", # 1st transformer block + pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index e37506a60c61..b3b103742061 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -129,7 +129,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, - pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1"], ["up.block_0.attentions_0"] + pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"] ): super().__init__() if isinstance(unet, UNet2DConditionModel): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3e9a33503906..63151cb86727 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -332,6 +332,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanDiTPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanDiTPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 8f637b991056..2cd80921e932 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -429,7 +429,10 @@ def test_pag_applied_layers(self): pipe.set_progress_bar_config(disable=None) # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers - all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k] + # Note that for motion modules in AnimateDiff, both attn1 and attn2 are self-attention + all_self_attn_layers = [ + k for k in pipe.unet.attn_processors.keys() if "attn1" in k or ("motion_modules" in k and "attn2" in k) + ] original_attn_procs = pipe.unet.attn_processors pag_layers = [ "down", @@ -439,12 +442,13 @@ def test_pag_applied_layers(self): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) - # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e. + # pag_applied_layers = ["mid"], or ["mid_block.0"] should apply to all self-attention layers in mid_block, i.e. # mid_block.motion_modules.0.transformer_blocks.0.attn1.processor # mid_block.attentions.0.transformer_blocks.0.attn1.processor all_self_attn_mid_layers = [ - "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor", "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor", + "mid_block.motion_modules.0.transformer_blocks.0.attn2.processor", ] pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["mid"] @@ -452,17 +456,17 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"] + pag_layers = ["mid_block.(attentions|motion_modules)"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -474,19 +478,19 @@ def test_pag_applied_layers(self): pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["down"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert len(pipe.pag_attn_processors) == 6 + assert len(pipe.pag_attn_processors) == 10 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert (len(pipe.pag_attn_processors)) == 4 + assert (len(pipe.pag_attn_processors)) == 6 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert len(pipe.pag_attn_processors) == 2 + assert len(pipe.pag_attn_processors) == 10 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.motion_modules_2"] + pag_layers = ["motion_modules.42"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py new file mode 100644 index 000000000000..db0e257760ed --- /dev/null +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, BertModel, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + HunyuanDiT2DModel, + HunyuanDiTPAGPipeline, + HunyuanDiTPipeline, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanDiTPAGPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanDiT2DModel( + sample_size=16, + num_layers=2, + patch_size=2, + attention_head_dim=8, + num_attention_heads=3, + in_channels=4, + cross_attention_dim=32, + cross_attention_dim_t5=32, + pooled_projection_dim=16, + hidden_size=24, + activation_fn="gelu-approximate", + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = DDPMScheduler() + text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel") + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "use_resolution_binning": False, + "pag_scale": 0.0, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 16, 16, 3)) + expected_slice = np.array( + [0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_sequential_cpu_offload_forward_pass(self): + # TODO(YiYi) need to fix later + pass + + def test_sequential_offload_forward_pass_twice(self): + # TODO(YiYi) need to fix later + pass + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-3, + ) + + def test_save_load_optional_components(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0) + + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = pipe.encode_prompt( + prompt, + device=torch_device, + dtype=torch.float32, + text_encoder_index=1, + ) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) + + def test_feed_forward_chunking(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_no_chunking = image[0, -3:, -3:, -1] + + pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0) + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_chunking = image[0, -3:, -3:, -1] + + max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max() + self.assertLess(max_diff, 1e-4) + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image = pipe(**inputs)[0] + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image_fused = pipe(**inputs)[0] + image_slice_fused = image_fused[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image_disabled = pipe(**inputs)[0] + image_slice_disabled = image_disabled[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline (expect same output when pag is disabled) + pipe_sd = HunyuanDiTPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + components = self.get_dummy_components() + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 3.0 + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_pag_applied_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] + original_attn_procs = pipe.transformer.attn_processors + pag_layers = ["blocks.0", "blocks.1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) + + # blocks.0 + block_0_self_attn = ["blocks.0.attn1.processor"] + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(block_0_self_attn) + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0.attn1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(block_0_self_attn) + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.(0|1)"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert (len(pipe.pag_attn_processors)) == 2 + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0", r"blocks\.1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert len(pipe.pag_attn_processors) == 2 diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py index be86afe45be0..70b528dede56 100644 --- a/tests/pipelines/pag/test_pag_pixart_sigma.py +++ b/tests/pipelines/pag/test_pag_pixart_sigma.py @@ -127,7 +127,7 @@ def test_pag_disable_enable(self): out = pipe(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 - components["pag_applied_layers"] = [1] + components["pag_applied_layers"] = ["blocks.1"] pipe_pag = self.pipeline_class(**components) pipe_pag = pipe_pag.to(device) pipe_pag.set_progress_bar_config(disable=None) @@ -158,7 +158,7 @@ def test_pag_applied_layers(self): # "attn1" should apply to all self-attention layers. all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] - pag_layers = [0, 1] + pag_layers = ["blocks.0", "blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) @@ -228,7 +228,7 @@ def test_save_load_optional_components(self): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"]) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) @@ -282,7 +282,7 @@ def test_save_load_local(self, expected_max_difference=1e-4): pipe.save_pretrained(tmpdir, safe_serialization=False) with CaptureLogger(logger) as cap_logger: - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"]) for name in pipe_loaded.components.keys(): if name not in pipe_loaded._optional_components: diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py index a0930245b375..e9adb3ac447e 100644 --- a/tests/pipelines/pag/test_pag_sd.py +++ b/tests/pipelines/pag/test_pag_sd.py @@ -213,18 +213,18 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0"] + pag_layers = ["mid_block.attentions.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -239,17 +239,17 @@ def test_pag_applied_layers(self): assert len(pipe.pag_attn_processors) == 2 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["down_blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.attentions_1"] + pag_layers = ["down_blocks.1.attentions.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 1 diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py index 5ec3dc5555f1..589573385677 100644 --- a/tests/pipelines/pag/test_pag_sdxl.py +++ b/tests/pipelines/pag/test_pag_sdxl.py @@ -225,18 +225,18 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0"] + pag_layers = ["mid_block.attentions.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -251,17 +251,17 @@ def test_pag_applied_layers(self): assert len(pipe.pag_attn_processors) == 4 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["down_blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 4 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.attentions_1"] + pag_layers = ["down_blocks.1.attentions.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2