
Commit 4aeccfe

add pipeline test + corresponding fixes
1 parent 027dbd5 commit 4aeccfe

4 files changed: 353 additions & 17 deletions

src/diffusers/models/transformers/transformer_photon.py

Lines changed: 4 additions & 1 deletion
@@ -74,6 +74,8 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
         Tensor of the same shape as `xq` with rotary embeddings applied.
     """
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+    # Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
+    freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq)

@@ -409,7 +411,8 @@ def forward(

         device = img_q.device
         ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
-        joint_mask = torch.cat([attention_mask.to(torch.bool), ones_img], dim=-1)
+        attention_mask = attention_mask.to(device=device, dtype=torch.bool)
+        joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
         attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1)

         kv_packed = torch.cat([k, v], dim=-1)
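Both hunks guard against the same failure mode: with CPU or group offloading, a cached tensor (the rotary table here, or the text attention mask in `forward`) can sit on a different device than the activations it is combined with. A minimal sketch of the `apply_rope` case, not part of the commit, with illustrative shapes and assuming a CUDA device is available:

import torch

def apply_rope(xq, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # the fix: align device (and dtype) before the elementwise multiply below
    freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq)

if torch.cuda.is_available():
    xq = torch.randn(1, 4, 16, 64, device="cuda")  # queries already moved to the GPU
    freqs_cis = torch.randn(16, 32, 2, 2)          # rotary table left behind on the CPU
    out = apply_rope(xq, freqs_cis)                # without the .to(), the multiply raises a device-mismatch RuntimeError
    assert out.device == xq.device and out.shape == xq.shape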

src/diffusers/pipelines/photon/pipeline_photon.py

Lines changed: 78 additions & 16 deletions
@@ -266,6 +266,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.text_preprocessor = TextPreprocessor()
         self.default_sample_size = default_sample_size
+        self._guidance_scale = 1.0

         self.register_modules(
             transformer=transformer,
@@ -277,10 +278,15 @@ def __init__(

         self.register_to_config(default_sample_size=self.default_sample_size)

-        self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        if vae is not None:
+            self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        else:
+            self.image_processor = None

     @property
     def vae_scale_factor(self):
+        if self.vae is None:
+            return 8
         if hasattr(self.vae, "spatial_compression_ratio"):
             return self.vae.spatial_compression_ratio
         else:  # Flux VAE
@@ -291,6 +297,10 @@ def do_classifier_free_guidance(self):
         """Check if classifier-free guidance is enabled based on guidance scale."""
         return self._guidance_scale > 1.0

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -318,28 +328,58 @@
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        device: torch.device,
+        device: Optional[torch.device] = None,
         do_classifier_free_guidance: bool = True,
         negative_prompt: str = "",
+        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         prompt_attention_mask: Optional[torch.BoolTensor] = None,
         negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
     ):
         """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
-        if prompt_embeds is not None:
-            # Use precomputed embeddings
-            return (
-                prompt_embeds,
-                prompt_attention_mask,
-                negative_prompt_embeds if do_classifier_free_guidance else None,
-                negative_prompt_attention_mask if do_classifier_free_guidance else None,
+        if device is None:
+            device = self._execution_device
+
+        if prompt_embeds is None:
+            if isinstance(prompt, str):
+                prompt = [prompt]
+            # Encode the prompts
+            text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = (
+                self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
             )
-
-        if isinstance(prompt, str):
-            prompt = [prompt]
-
-        return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
+            prompt_embeds = text_embeddings
+            prompt_attention_mask = cross_attn_mask
+            negative_prompt_embeds = uncond_text_embeddings
+            negative_prompt_attention_mask = uncond_cross_attn_mask
+
+        # Duplicate embeddings for each generation per prompt
+        if num_images_per_prompt > 1:
+            # Repeat prompt embeddings
+            bs_embed, seq_len, _ = prompt_embeds.shape
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+            if prompt_attention_mask is not None:
+                prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+                prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+            # Repeat negative embeddings if using CFG
+            if do_classifier_free_guidance and negative_prompt_embeds is not None:
+                bs_embed, seq_len, _ = negative_prompt_embeds.shape
+                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+                if negative_prompt_attention_mask is not None:
+                    negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+                    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+        return (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds if do_classifier_free_guidance else None,
+            negative_prompt_attention_mask if do_classifier_free_guidance else None,
+        )

     def _tokenize_prompts(self, prompts: List[str], device: torch.device):
         """Tokenize and clean prompts."""
@@ -549,6 +589,11 @@ def __call__(
         width = width or self.default_sample_size

         if use_resolution_binning:
+            if self.image_processor is None:
+                raise ValueError(
+                    "Resolution binning requires a VAE with image_processor, but VAE is not available. "
+                    "Set use_resolution_binning=False or provide a VAE."
+                )
             if self.default_sample_size <= 256:
                 aspect_ratio_bin = ASPECT_RATIO_256_BIN
             else:
@@ -570,13 +615,20 @@
             negative_prompt_embeds,
         )

+        if self.vae is None and output_type not in ["latent", "pt"]:
+            raise ValueError(
+                f"VAE is required for output_type='{output_type}' but it is not available. "
+                "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
+            )
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]

+        # Use execution device (handles offloading scenarios including group offloading)
         device = self._execution_device

         self._guidance_scale = guidance_scale
@@ -587,11 +639,15 @@
             device,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
+            num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
+        # Expose standard names for callbacks parity
+        prompt_embeds = text_embeddings
+        negative_prompt_embeds = uncond_text_embeddings

         # 3. Prepare timesteps
         if timesteps is not None:
@@ -602,8 +658,14 @@
             self.scheduler.set_timesteps(num_inference_steps, device=device)
             timesteps = self.scheduler.timesteps

+        self.num_timesteps = len(timesteps)
+
         # 4. Prepare latent variables
-        num_channels_latents = self.vae.config.latent_channels
+        if self.vae is not None:
+            num_channels_latents = self.vae.config.latent_channels
+        else:
+            # When vae is None, get latent channels from transformer
+            num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -675,7 +737,7 @@
             progress_bar.update()

         # 8. Post-processing
-        if output_type == "latent":
+        if output_type == "latent" or (output_type == "pt" and self.vae is None):
             image = latents
         else:
             # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
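Taken together, these changes let the pipeline run end to end without a VAE, which the new test exercises. A hypothetical usage sketch: the checkpoint id is a placeholder, passing `vae=None` at load time is assumed to work as it does for other optional diffusers components, and only the `output_type`, `use_resolution_binning`, and `num_images_per_prompt` behavior comes from this diff:

from diffusers import PhotonPipeline

# "org/photon-checkpoint" is a placeholder, not a real repo id
pipe = PhotonPipeline.from_pretrained("org/photon-checkpoint", vae=None)

out = pipe(
    prompt="a photo of a cat",
    num_images_per_prompt=2,       # now expanded inside encode_prompt
    use_resolution_binning=False,  # binning needs the image_processor, which is None without a VAE
    output_type="latent",          # "pt" would likewise return raw latents when vae is None
)
latents = out.images  # un-decoded latents; decode later with a separately loaded VAE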

tests/pipelines/photon/__init__.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# Dummy encoder for testing
+import torch
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class DummyT5GemmaConfig(PretrainedConfig):
+    """Dummy config for testing that mimics T5GemmaEncoder config."""
+
+    model_type = "dummy_t5gemma"
+
+    def __init__(self, hidden_size=8, vocab_size=256000, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+
+
+class DummyT5GemmaEncoder(PreTrainedModel):
+    """Dummy T5GemmaEncoder for testing that supports serialization."""
+
+    config_class = DummyT5GemmaConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.embed = torch.nn.Embedding(config.vocab_size, config.hidden_size)
+
+    def forward(self, input_ids, attention_mask=None, output_hidden_states=False, **kwargs):
+        hidden_states = self.embed(input_ids)
+        return {"last_hidden_state": hidden_states}
+
+
+# Register the dummy model with transformers AutoConfig and AutoModel
+AutoConfig.register("dummy_t5gemma", DummyT5GemmaConfig)
+AutoModel.register(DummyT5GemmaConfig, DummyT5GemmaEncoder)
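The closing `AutoConfig`/`AutoModel` registration is what lets the dummy encoder survive a `save_pretrained`/`from_pretrained` round trip, which pipeline serialization tests rely on. A sketch of that round trip, assuming the module above is importable as `tests.pipelines.photon` and using illustrative sizes:

import tempfile

import torch
from transformers import AutoModel

from tests.pipelines.photon import DummyT5GemmaConfig, DummyT5GemmaEncoder  # importing also runs the registrations

encoder = DummyT5GemmaEncoder(DummyT5GemmaConfig(hidden_size=8))

with tempfile.TemporaryDirectory() as tmp:
    encoder.save_pretrained(tmp)               # writes config.json with model_type "dummy_t5gemma"
    reloaded = AutoModel.from_pretrained(tmp)  # resolved through the registration above

input_ids = torch.randint(0, 100, (1, 5))
assert reloaded(input_ids)["last_hidden_state"].shape == (1, 5, 8)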
