Description
Newer models like Mochi-1 run the text encoder and VAE decoding in FP32 while keeping the denoising process under `torch.bfloat16` autocast.
Currently, our pipelines cannot run the different models involved in different precisions because we set a single global `torch_dtype` while initializing the pipeline.
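For context, here is a minimal sketch of that kind of mixed-precision flow. The function and call signatures are illustrative, not Mochi-1's actual implementation:

```python
import torch

def denoise_mixed_precision(text_encoder, transformer, vae, scheduler, latents, input_ids, timesteps):
    # prompt encoding in FP32
    text_encoder.to(torch.float32)
    prompt_embeds = text_encoder(input_ids)[0]

    # denoising loop runs under bfloat16 autocast
    with torch.autocast("cuda", dtype=torch.bfloat16):
        for t in timesteps:
            # illustrative call signature for the denoiser
            noise_pred = transformer(latents, t, prompt_embeds)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # VAE decoding in FP32
    vae.to(torch.float32)
    return vae.decode(latents.to(torch.float32)).sample
```

A single pipeline-wide `torch_dtype` has no way to express this per-stage split.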
We have some pipelines, like SDXL, where the VAE has a config attribute called `force_upcast`, which is handled within the pipeline implementation like so:
From `src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py`, lines 1264 to 1275 at cfdeebd:

```python
if not output_type == "latent":
    # make sure the VAE is in float32 mode, as it overflows in float16
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

    if needs_upcasting:
        self.upcast_vae()
        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
    elif latents.dtype != self.vae.dtype:
        if torch.backends.mps.is_available():
            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
            self.vae = self.vae.to(latents.dtype)
```
Another way to achieve this could be to decouple the pipeline's major computation stages so that users can choose whatever supported `torch_dtype` they want for each stage. Here is an example, and a rough sketch below.
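A rough sketch of that decoupled flow with today's SDXL pipeline, using the public `encode_prompt` method and the `output_type="latent"` escape hatch (the exact staging here is illustrative):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Stage 1: prompt encoding in FP32
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
prompt_embeds, negative_embeds, pooled_embeds, negative_pooled_embeds = pipe.encode_prompt(
    prompt="a photo of an astronaut", device="cuda"
)

# Stage 2: denoising in bfloat16, stopping at the latents
latents = pipe(
    prompt_embeds=prompt_embeds.to(torch.bfloat16),
    negative_prompt_embeds=negative_embeds.to(torch.bfloat16),
    pooled_prompt_embeds=pooled_embeds.to(torch.bfloat16),
    negative_pooled_prompt_embeds=negative_pooled_embeds.to(torch.bfloat16),
    output_type="latent",
).images

# Stage 3: VAE decoding in FP32
pipe.vae.to(torch.float32)
image = pipe.vae.decode(latents.to(torch.float32) / pipe.vae.config.scaling_factor).sample
image = pipe.image_processor.postprocess(image)[0]
```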
But this is an involved process and a power-user thing, IMO. What if we could allow users to pass a `torch_dtype` map like so:
{"unet": torch.bfloat16, "vae": torch.float32, "text_encoder": torch.float32}This along with @a-r-r-o-w's idea of an upcast marker could really benefit the pipelines that are not resilient to precision changes.