15 changes: 2 additions & 13 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -715,11 +715,6 @@ def __init__(
     ) -> None:
         super().__init__()

-        # Store normalization parameters as tensors
-        self.mean = torch.tensor(latents_mean)
-        self.std = torch.tensor(latents_std)
-        self.scale = torch.stack([self.mean, 1.0 / self.std])  # Shape: [2, C]
-
         self.z_dim = z_dim
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
@@ -751,7 +746,6 @@ def _count_conv3d(model):
         self._enc_feat_map = [None] * self._enc_conv_num

     def _encode(self, x: torch.Tensor) -> torch.Tensor:
-        scale = self.scale.type_as(x)
         self.clear_cache()
         ## cache
         t = x.shape[2]
@@ -770,8 +764,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:

         enc = self.quant_conv(out)
         mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
-        logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
         enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc
@@ -798,10 +790,8 @@ def encode(
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         self.clear_cache()
-        # z: [b,c,t,h,w]
-        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)

         iter_ = z.shape[2]
         x = self.post_quant_conv(z)
@@ -835,8 +825,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
             returned.
         """
-        scale = self.scale.type_as(z)
-        decoded = self._decode(z, scale).sample
+        decoded = self._decode(z).sample
         if not return_dict:
             return (decoded,)
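The deletions above take the per-channel affine normalization out of `AutoencoderKLWan` itself; the pipeline hunks below re-apply the same map around the VAE calls. A minimal round-trip sketch of that normalization (illustrative only: the `z_dim` and statistics below are made up, while in the real model they come from `vae.config`):

```python
import torch

z_dim = 16                                    # assumed latent channel count
latents_mean = torch.randn(z_dim)             # stand-in for vae.config.latents_mean
latents_std = torch.rand(z_dim) + 0.5         # stand-in for vae.config.latents_std

mean = latents_mean.view(1, z_dim, 1, 1, 1)
inv_std = (1.0 / latents_std).view(1, z_dim, 1, 1, 1)

z = torch.randn(2, z_dim, 5, 8, 8)            # [batch, channel, frame, height, width]
z_norm = (z - mean) * inv_std                 # what _encode() used to fold in
z_back = z_norm / inv_std + mean              # what the pipelines now do before decode()
assert torch.allclose(z, z_back, atol=1e-6)   # the two maps are exact inverses
```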
9 changes: 9 additions & 0 deletions src/diffusers/pipelines/wan/pipeline_wan.py
@@ -563,6 +563,15 @@ def __call__(

         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
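One consequence worth noting (a sketch, not from the PR): since `decode()` no longer denormalizes, code that feeds raw Wan latents to the VAE directly, for example after running the pipeline with `output_type="latent"`, now has to invert the normalization itself. A hypothetical helper built only from the config fields visible in the diff:

```python
import torch

def denormalize_wan_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    """Map normalized Wan latents back to the VAE's native latent space (sketch)."""
    mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    std = (
        torch.tensor(vae.config.latents_std)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    return latents * std + mean  # inverse of (z - mean) / std

# usage sketch:
# video = vae.decode(denormalize_wan_latents(latents, vae), return_dict=False)[0]
```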
20 changes: 20 additions & 0 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -392,6 +392,17 @@ def prepare_latents(
         latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
         latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            latents.device, latents.dtype
+        )
+
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
@@ -654,6 +665,15 @@ def __call__(

         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
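In the i2v pipeline the pattern appears in both directions: `prepare_latents` pushes the encoded conditioning latents into the normalized space, and `__call__` pulls the denoised latents back out before decoding. One readability quirk of the added code: the local variable `latents_std` stores the reciprocal of `config.latents_std`, so the two directions read as multiply-in / divide-out. A tiny sketch with made-up statistics:

```python
import torch

cfg_mean = torch.tensor([0.1, -0.2, 0.0])   # stand-in for vae.config.latents_mean
cfg_std = torch.tensor([0.9, 1.1, 1.3])     # stand-in for vae.config.latents_std

latents_mean = cfg_mean
latents_std = 1.0 / cfg_std                 # as in the diff: holds 1/std, not std

z = torch.randn(3)
cond = (z - latents_mean) * latents_std     # prepare_latents: into normalized space
back = cond / latents_std + latents_mean    # __call__: out again before vae.decode
assert torch.allclose(z, back, atol=1e-6)
```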