@@ -266,6 +266,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.text_preprocessor = TextPreprocessor()
         self.default_sample_size = default_sample_size
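+        # Default guidance scale; updated by __call__ and read by do_classifier_free_guidance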
+        self._guidance_scale = 1.0
 
         self.register_modules(
             transformer=transformer,
@@ -277,10 +278,15 @@ def __init__(
 
         self.register_to_config(default_sample_size=self.default_sample_size)
 
-        self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
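+        # No VAE: skip image post-processing setup (latent-only workflows)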
+        if vae is not None:
+            self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        else:
+            self.image_processor = None
 
     @property
     def vae_scale_factor(self):
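+        # Fall back to the default 8x spatial compression when no VAE is loaded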
+        if self.vae is None:
+            return 8
         if hasattr(self.vae, "spatial_compression_ratio"):
             return self.vae.spatial_compression_ratio
         else:  # Flux VAE
@@ -291,6 +297,10 @@ def do_classifier_free_guidance(self):
291297 """Check if classifier-free guidance is enabled based on guidance scale."""
292298 return self ._guidance_scale > 1.0
293299
300+ @property
301+ def guidance_scale (self ):
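+        """Guidance scale set by the most recent pipeline call; read by do_classifier_free_guidance."""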
+        return self._guidance_scale
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -318,28 +328,58 @@ def prepare_latents(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        device: torch.device,
+        device: Optional[torch.device] = None,
         do_classifier_free_guidance: bool = True,
         negative_prompt: str = "",
+        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         prompt_attention_mask: Optional[torch.BoolTensor] = None,
         negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
     ):
         """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
-        if prompt_embeds is not None:
-            # Use precomputed embeddings
-            return (
-                prompt_embeds,
-                prompt_attention_mask,
-                negative_prompt_embeds if do_classifier_free_guidance else None,
-                negative_prompt_attention_mask if do_classifier_free_guidance else None,
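+        # Resolve the device lazily so callers may omit it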
+        if device is None:
+            device = self._execution_device
+
+        if prompt_embeds is None:
+            if isinstance(prompt, str):
+                prompt = [prompt]
+            # Encode the prompts
+            text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = (
+                self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
             )
-
-        if isinstance(prompt, str):
-            prompt = [prompt]
-
-        return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
+            prompt_embeds = text_embeddings
+            prompt_attention_mask = cross_attn_mask
+            negative_prompt_embeds = uncond_text_embeddings
+            negative_prompt_attention_mask = uncond_cross_attn_mask
+
+        # Duplicate embeddings for each generation per prompt
+        if num_images_per_prompt > 1:
+            # Repeat prompt embeddings
+            bs_embed, seq_len, _ = prompt_embeds.shape
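+            # (bs, seq, dim) -> (bs * num_images_per_prompt, seq, dim)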
+            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+            if prompt_attention_mask is not None:
+                prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+                prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+            # Repeat negative embeddings if using CFG
+            if do_classifier_free_guidance and negative_prompt_embeds is not None:
+                bs_embed, seq_len, _ = negative_prompt_embeds.shape
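+                # Same expansion as for the positive embeddings above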
+                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+                if negative_prompt_attention_mask is not None:
+                    negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+                    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+        return (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds if do_classifier_free_guidance else None,
+            negative_prompt_attention_mask if do_classifier_free_guidance else None,
+        )
 
     def _tokenize_prompts(self, prompts: List[str], device: torch.device):
         """Tokenize and clean prompts."""
@@ -549,6 +589,11 @@ def __call__(
         width = width or self.default_sample_size
 
         if use_resolution_binning:
+            if self.image_processor is None:
+                raise ValueError(
+                    "Resolution binning requires a VAE with image_processor, but VAE is not available. "
+                    "Set use_resolution_binning=False or provide a VAE."
+                )
             if self.default_sample_size <= 256:
                 aspect_ratio_bin = ASPECT_RATIO_256_BIN
             else:
@@ -570,13 +615,20 @@ def __call__(
             negative_prompt_embeds,
         )
 
+        if self.vae is None and output_type not in ["latent", "pt"]:
+            raise ValueError(
+                f"VAE is required for output_type='{output_type}' but it is not available. "
+                "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
+            )
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # Use execution device (handles offloading scenarios including group offloading)
         device = self._execution_device
 
         self._guidance_scale = guidance_scale
@@ -587,11 +639,15 @@ def __call__(
             device,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
+            num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
+        # Expose the standard names for callback parity
+        prompt_embeds = text_embeddings
+        negative_prompt_embeds = uncond_text_embeddings
 
         # 3. Prepare timesteps
         if timesteps is not None:
@@ -602,8 +658,14 @@ def __call__(
             self.scheduler.set_timesteps(num_inference_steps, device=device)
             timesteps = self.scheduler.timesteps
 
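+        # Record the resolved number of denoising steps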
+        self.num_timesteps = len(timesteps)
+
         # 4. Prepare latent variables
-        num_channels_latents = self.vae.config.latent_channels
+        if self.vae is not None:
+            num_channels_latents = self.vae.config.latent_channels
+        else:
+            # When vae is None, get latent channels from transformer
+            num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -675,7 +737,7 @@ def __call__(
                 progress_bar.update()
 
         # 8. Post-processing
-        if output_type == "latent":
+        if output_type == "latent" or (output_type == "pt" and self.vae is None):
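+            # Without a VAE, output_type="pt" returns the raw latent tensor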
             image = latents
         else:
             # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)