6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/ltx_video.md
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__

## LTXConditionPipeline

[[autodoc]] LTXConditionPipeline
- all
- __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
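For orientation, here is a minimal sketch of how the newly documented `LTXConditionPipeline` might be loaded and run, assuming a 0.9.5-style checkpoint id and the usual LTX call arguments; neither is specified by this diff, so treat both as placeholders and defer to the docstring added above.

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint id; replace with the checkpoint this pipeline targets.
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Call arguments mirror the existing LTX pipelines; the conditioning inputs
# specific to this pipeline are omitted because they are not shown in this hunk.
video = pipe(
    prompt="A ship sailing through a storm at dusk",
    negative_prompt="worst quality, inconsistent motion, blurry",
    width=768,
    height=512,
    num_frames=161,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ship.mp4", fps=24)
```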
23 changes: 15 additions & 8 deletions scripts/convert_ltx_to_diffusers.py
@@ -105,6 +105,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
"decoder.timestep_scale_multiplier": remove_keys_,
}


@@ -268,6 +269,9 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
"timestep_scale_multiplier": 1000.0,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
return config
@@ -346,14 +350,17 @@ def get_args():
for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
if args.version == "0.9.5":
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)

pipe = LTXPipeline(
scheduler=scheduler,
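The converter now picks the scheduler per checkpoint version: 0.9.5 gets a plain `FlowMatchEulerDiscreteScheduler` without dynamic shifting, while earlier versions keep the resolution-dependent shift. As a rough illustration of what `base_shift`/`max_shift` and `base_image_seq_len`/`max_image_seq_len` feed into when dynamic shifting is enabled, here is the linear interpolation the flow-match pipelines typically use to derive the shift parameter `mu`; the exact helper lives in the pipeline code, so consider this a sketch rather than the converter's own logic.

```python
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 1024,
    max_seq_len: int = 4096,
    base_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    # Linearly interpolate the shift between base_shift and max_shift as the
    # token count grows from base_seq_len to max_seq_len.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b
```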
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -347,6 +347,7 @@
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXPipeline",
"Lumina2Text2ImgPipeline",
@@ -857,6 +858,7 @@
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Text2ImgPipeline,
22 changes: 17 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -921,12 +921,14 @@ def __init__(
timestep_conditioning: bool = False,
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
timestep_scale_multiplier: float = 1.0,
) -> None:
super().__init__()

self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.out_channels = out_channels * patch_size**2
self.timestep_scale_multiplier = timestep_scale_multiplier

block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
@@ -981,9 +983,7 @@ def __init__(
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)

@@ -992,7 +992,7 @@ def __init__(
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)

if self.timestep_scale_multiplier is not None:
if temb is not None:
temb = temb * self.timestep_scale_multiplier

if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -1105,6 +1105,9 @@ def __init__(
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
timestep_scale_multiplier: float = 1.0,
) -> None:
super().__init__()

@@ -1135,15 +1138,24 @@
inject_noise=decoder_inject_noise,
upsample_residual=upsample_residual,
upsample_factor=upsample_factor,
timestep_scale_multiplier=timestep_scale_multiplier,
)

latents_mean = torch.zeros((latent_channels,), requires_grad=False)
latents_std = torch.ones((latent_channels,), requires_grad=False)
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)

self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = (
patch_size * 2 ** sum(spatio_temporal_scaling)
if spatial_compression_ratio is None
else spatial_compression_ratio
)
self.temporal_compression_ratio = (
patch_size_t * 2 ** sum(spatio_temporal_scaling)
if temporal_compression_ratio is None
else temporal_compression_ratio
)

# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
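The VAE config change above makes the compression ratios explicit for 0.9.5 while keeping the derived values as a fallback. A quick check of that fallback arithmetic, using values assumed to match the default LTX VAE config:

```python
# Assumed defaults for the pre-0.9.5 LTX VAE; the 0.9.5 config now passes the
# ratios explicitly instead of deriving them.
patch_size, patch_size_t = 4, 1
spatio_temporal_scaling = (True, True, True, False)  # three downsampling stages

spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)     # 4 * 8 = 32
temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)  # 1 * 8 = 8
print(spatial_compression_ratio, temporal_compression_ratio)  # 32 8
```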
99 changes: 66 additions & 33 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -115,47 +115,77 @@ def __init__(
self.theta = theta
self._causal_rope_fix = _causal_rope_fix

def forward(
def _prepare_video_coords(
self,
hidden_states: torch.Tensor,
batch_size: int,
num_frames: int,
height: int,
width: int,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)

rope_interpolation_scale: Tuple[torch.Tensor, float, float],
frame_rate: float,
device: torch.device,
) -> torch.Tensor:
# Always compute rope in fp32
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

if rope_interpolation_scale is not None:
if isinstance(rope_interpolation_scale, tuple):
# This will be deprecated in v0.34.0
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
if isinstance(rope_interpolation_scale, tuple):
# This will be deprecated in v0.34.0
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
else:
if not self._causal_rope_fix:
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
else:
if not self._causal_rope_fix:
grid[:, 0:1] = (
grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
)
else:
grid[:, 0:1] = (
((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
* self.patch_size_t
/ self.base_num_frames
)
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
grid[:, 0:1] = (
((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
* self.patch_size_t
/ self.base_num_frames
)
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

grid = grid.flatten(2, 4).transpose(1, 2)

return grid

def forward(
self,
hidden_states: torch.Tensor,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
video_coords: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)

if video_coords is None:
grid = self._prepare_video_coords(
batch_size,
num_frames,
height,
width,
rope_interpolation_scale=rope_interpolation_scale,
frame_rate=frame_rate,
device=hidden_states.device,
)
else:
grid = torch.stack(
[
video_coords[:, 0] / self.base_num_frames,
video_coords[:, 1] / self.base_height,
video_coords[:, 2] / self.base_width,
],
dim=-1,
)

start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
@@ -387,11 +417,12 @@ def forward(
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
frame_rate: int,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
video_coords: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -414,7 +445,9 @@ def forward(
msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
deprecate("rope_interpolation_scale", "0.34.0", msg)

image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)
image_rotary_emb = self.rope(
hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords
)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
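The RoPE refactor above splits grid construction into `_prepare_video_coords` and lets callers pass precomputed `video_coords` (one `(frame, row, col)` triple per token, normalized inside the module by `base_num_frames`, `base_height`, and `base_width`) instead of a regular `num_frames` x `height` x `width` grid. Below is a sketch of building such a tensor for a plain latent grid; how `LTXConditionPipeline` actually fills these coordinates is not shown in this hunk, so treat the construction as an illustrative assumption.

```python
import torch

batch_size, num_frames, height, width = 1, 5, 16, 24  # latent-grid sizes (assumed)

grid_f, grid_h, grid_w = torch.meshgrid(
    torch.arange(num_frames, dtype=torch.float32),
    torch.arange(height, dtype=torch.float32),
    torch.arange(width, dtype=torch.float32),
    indexing="ij",
)
# Shape (batch, 3, seq_len): one (frame, row, col) coordinate per token,
# matching the indexing used by the new video_coords branch.
video_coords = torch.stack([grid_f, grid_h, grid_w], dim=0).flatten(1)
video_coords = video_coords.unsqueeze(0).repeat(batch_size, 1, 1)

# image_rotary_emb = transformer.rope(hidden_states, video_coords=video_coords)
```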
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -260,7 +260,7 @@
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -610,7 +610,7 @@
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaText2ImgPipeline
from .lumina2 import Lumina2Text2ImgPipeline
from .marigold import (
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/ltx/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline

else: