25 changes: 11 additions & 14 deletions video/Wan2.2/generate.py
@@ -2,25 +2,22 @@
import argparse
import logging
import os
import random
import sys
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

import random
import mlx.core as mx
import mlx.nn as nn
from PIL import Image
import numpy as np

# Note: MLX doesn't have built-in distributed training support like PyTorch
# For distributed training, you would need to implement custom logic or use MPI

import wan # Assuming wan has been converted to MLX
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.utils import save_video, str2bool

# Note: MLX doesn't have built-in distributed training support like PyTorch
# For distributed training, you would need to implement custom logic or use MPI

warnings.filterwarnings('ignore')

EXAMPLE_PROMPT = {
"t2v-A14B": {
"prompt":
@@ -220,25 +217,25 @@ def generate(args):
rank = 0
world_size = 1
local_rank = 0

# Check for distributed execution environment variables
# Note: Actual distributed implementation would require custom logic
if "RANK" in os.environ:
logging.warning("MLX doesn't have built-in distributed training. Running on single device.")

_init_logging(rank)

if args.offload_model is None:
args.offload_model = False
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")

# MLX doesn't support FSDP or distributed training out of the box
if args.t5_fsdp or args.dit_fsdp:
logging.warning("FSDP is not supported in MLX. Ignoring FSDP flags.")
args.t5_fsdp = False
args.dit_fsdp = False

if args.ulysses_size > 1:
logging.warning("Sequence parallel is not supported in MLX single-device mode. Setting ulysses_size to 1.")
args.ulysses_size = 1
@@ -369,7 +366,7 @@ def generate(args):
args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix

logging.info(f"Saving generated video to {args.save_file}")

# Don't convert to numpy - keep as MLX array
save_video(
tensor=video[None], # Just add batch dimension, keep as MLX array
2 changes: 1 addition & 1 deletion video/Wan2.2/wan/__init__.py
@@ -1,2 +1,2 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .text2video import WanT2V
from .text2video import WanT2V
1 change: 0 additions & 1 deletion video/Wan2.2/wan/configs/__init__.py
@@ -1,5 +1,4 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6 changes: 3 additions & 3 deletions video/Wan2.2/wan/configs/shared_config.py
@@ -1,17 +1,17 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import mlx.core as mx
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.t5_dtype = mx.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16
wan_shared_cfg.param_dtype = mx.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
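The only behavioral change in this config is the dtype swap: mx.bfloat16 replaces torch.bfloat16 for both the T5 encoder and the transformer parameters. As a minimal sketch (illustrative only, not part of the patch), the config values can be used the same way when allocating or casting MLX arrays:

import mlx.core as mx
from wan.configs.shared_config import wan_shared_cfg

# Allocate parameters directly in the configured dtype.
w = mx.zeros((4, 4), dtype=wan_shared_cfg.param_dtype)

# Or down-cast float32 weights after loading a checkpoint.
w32 = mx.random.normal(shape=(4, 4))
w_bf16 = w32.astype(wan_shared_cfg.t5_dtype)
print(w.dtype, w_bf16.dtype)  # both should report bfloat16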
2 changes: 1 addition & 1 deletion video/Wan2.2/wan/configs/wan_i2v_A14B.py
@@ -1,5 +1,4 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg
@@ -9,6 +8,7 @@
i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)

# t5
i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'

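For orientation, these per-task configs are what generate.py reaches through WAN_CONFIGS (imported from wan.configs). A minimal lookup sketch follows; the dictionary key is assumed here to mirror the task naming used in EXAMPLE_PROMPT:

from wan.configs import WAN_CONFIGS

cfg = WAN_CONFIGS["i2v-A14B"]  # assumed key, following the task names in EXAMPLE_PROMPT
print(cfg.t5_checkpoint)       # models_t5_umt5-xxl-enc-bf16.pth
print(cfg.t5_dtype)            # mx.bfloat16, inherited from wan_shared_cfg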
83 changes: 83 additions & 0 deletions video/Wan2.2/wan/model_converter.py
@@ -0,0 +1,83 @@
import os

import torch
from safetensors.torch import save_file


def convert_pickle_to_safetensors(
    pickle_path: str,
    safetensors_path: str,
    load_method: str = "weights_only"  # Changed default to weights_only
):
    """Convert PyTorch pickle file to safetensors format."""

    print(f"Loading PyTorch weights from: {pickle_path}")

    # Try multiple loading methods in order of preference
    methods_to_try = ["weights_only", "state_dict", "full_model"]
    methods_to_try.remove(load_method)
    methods_to_try.insert(0, load_method)

    for method in methods_to_try:
        try:
            if method == "weights_only":
                state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)

            elif method == "state_dict":
                checkpoint = torch.load(pickle_path, map_location='cpu')
                if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                elif isinstance(checkpoint, dict) and 'model' in checkpoint:
                    state_dict = checkpoint['model']
                else:
                    state_dict = checkpoint

            elif method == "full_model":
                model = torch.load(pickle_path, map_location='cpu')
                if hasattr(model, 'state_dict'):
                    state_dict = model.state_dict()
                else:
                    state_dict = model

            print(f"✅ Successfully loaded with method: {method}")
            break

        except Exception as e:
            print(f"❌ Method {method} failed: {e}")
            continue
    else:
        raise RuntimeError(f"All loading methods failed for {pickle_path}")

    # Clean up the state dict
    state_dict = clean_state_dict(state_dict)

    print(f"Found {len(state_dict)} parameters")

    # Save as safetensors
    print(f"Saving to safetensors: {safetensors_path}")
    os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
    save_file(state_dict, safetensors_path)

    print("✅ Conversion complete!")
    return state_dict


def clean_state_dict(state_dict):
    """
    Clean up state dict by removing unwanted prefixes or keys.
    """
    cleaned = {}

    for key, value in state_dict.items():
        # Remove common prefixes that might interfere
        clean_key = key

        if clean_key.startswith('module.'):
            clean_key = clean_key[7:]

        if clean_key != key:
            print(f"Cleaned key: {key} -> {clean_key}")

        cleaned[clean_key] = value

    return cleaned
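For reference, a minimal way to drive this converter might look like the sketch below; the checkpoint and output paths are hypothetical placeholders and would need to point at the actual Wan2.2 weights:

if __name__ == "__main__":
    # Hypothetical paths -- substitute the real checkpoint locations.
    convert_pickle_to_safetensors(
        pickle_path="Wan2.2-T2V-A14B/models_t5_umt5-xxl-enc-bf16.pth",
        safetensors_path="Wan2.2-T2V-A14B/mlx/models_t5_umt5-xxl-enc-bf16.safetensors",
        load_method="weights_only",
    )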
57 changes: 28 additions & 29 deletions video/Wan2.2/wan/modules/model.py
@@ -4,7 +4,6 @@

import mlx.core as mx
import mlx.nn as nn
import numpy as np

__all__ = ['WanModel']

@@ -52,17 +51,17 @@ def rope_apply(x, grid_sizes, freqs):

# reshape x_i to complex representation
x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)

# precompute frequency multipliers for each dimension
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
freqs_f = mx.tile(freqs_f, (1, h, w, 1, 1)).reshape(f * h * w, -1, 2)

freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
freqs_h = mx.tile(freqs_h, (f, 1, w, 1, 1)).reshape(f * h * w, -1, 2)

freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
freqs_w = mx.tile(freqs_w, (f, h, 1, 1, 1)).reshape(f * h * w, -1, 2)

# Concatenate frequency components
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=1)
freqs_i = freqs_i[:seq_len].reshape(seq_len, 1, c, 2)
@@ -72,18 +71,18 @@ def rope_apply(x, grid_sizes, freqs):
x_imag = x_i[..., 1]
freqs_real = freqs_i[..., 0]
freqs_imag = freqs_i[..., 1]

out_real = x_real * freqs_real - x_imag * freqs_imag
out_imag = x_real * freqs_imag + x_imag * freqs_real

x_i = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, -1)

# Handle remaining sequence
if x.shape[1] > seq_len:
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)

output.append(x_i)

return mx.stack(output)
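The real/imaginary arithmetic above is complex multiplication written out explicitly: for x = a + b·i and a unit frequency f = cos θ + i·sin θ, the rotated value is x·f = (a·cos θ − b·sin θ) + i·(a·sin θ + b·cos θ). A tiny standalone sketch (illustrative only, not part of the patch) checks the formula on a single value:

import math
import mlx.core as mx

# Rotate x = 1 + 2i by 90 degrees; the result should be -2 + 1i.
x_real, x_imag = mx.array(1.0), mx.array(2.0)
theta = math.pi / 2
f_real, f_imag = mx.array(math.cos(theta)), mx.array(math.sin(theta))

out_real = x_real * f_real - x_imag * f_imag   # ~ -2.0
out_imag = x_real * f_imag + x_imag * f_real   # ~  1.0
print(out_real.item(), out_imag.item())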


@@ -136,29 +135,29 @@ def mlx_attention(
# Get shapes
b, lq, n, d = q.shape
_, lk, _, _ = k.shape

# Scale queries if needed
if q_scale is not None:
q = q * q_scale

# Compute attention scores
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]

# Compute attention scores
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]

# Apply softmax scale if provided
if softmax_scale is not None:
scores = scores * softmax_scale
else:
# Default scaling by sqrt(d)
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))

# Create attention mask
attn_mask = None

# Apply window size masking if specified
if window_size != (-1, -1):
left_window, right_window = window_size
@@ -168,20 +167,20 @@ def mlx_attention(
end = min(lk, i + right_window + 1)
window_mask[i, start:end] = 1
attn_mask = window_mask

# Apply causal masking if needed
if causal:
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
if attn_mask is None:
attn_mask = causal_mask
else:
attn_mask = mx.logical_and(attn_mask, causal_mask)

# Apply attention mask if present
if attn_mask is not None:
attn_mask = attn_mask.astype(scores.dtype)
scores = scores * attn_mask + (1 - attn_mask) * -1e4

# Apply attention mask if lengths are provided
if q_lens is not None or k_lens is not None:
if q_lens is not None:
@@ -192,22 +191,22 @@
mask = mx.arange(lk)[None, :] < k_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4

# Apply softmax
max_scores = mx.max(scores, axis=-1, keepdims=True)
scores = scores - max_scores
exp_scores = mx.exp(scores)
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
attn = exp_scores / (sum_exp + 1e-6)

# Apply dropout if needed
if dropout_p > 0 and not deterministic:
raise NotImplementedError("Dropout not implemented in MLX version")

# Compute output
out = mx.matmul(attn, v) # [b, n, lq, d]
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]

return out

class WanSelfAttention(nn.Module):
@@ -356,7 +355,7 @@ def __call__(
y = self.ffn(
self.norm2(x) * (1 + mx.squeeze(e[4], axis=2)) + mx.squeeze(e[3], axis=2))
x = x + y * mx.squeeze(e[5], axis=2)

return x


@@ -541,13 +540,13 @@ def __call__(

grid_sizes = mx.stack(
[mx.array(u.shape[1:4], dtype=mx.int32) for u in x])


x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]

seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
assert seq_lens.max() <= seq_len

# Pad sequences
x_padded = []
for u in x:
@@ -637,24 +636,24 @@ def init_weights(self):
std = math.sqrt(2.0 / (fan_in + fan_out))
self.patch_embedding.weight = mx.random.uniform(
low=-std, high=std, shape=self.patch_embedding.weight.shape)

# Initialize text embedding layers with normal distribution
text_layers = list(self.text_embedding.layers)
for i in [0, 2]: # First and third layers
layer = text_layers[i]
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
if hasattr(layer, 'bias') and layer.bias is not None:
layer.bias = mx.zeros(layer.bias.shape)

# Initialize time embedding layers
time_layers = list(self.time_embedding.layers)
for i in [0, 2]: # First and third layers
layer = time_layers[i]
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
if hasattr(layer, 'bias') and layer.bias is not None:
layer.bias = mx.zeros(layer.bias.shape)

# Initialize output head to zeros
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
self.head.head.bias = mx.zeros(self.head.head.bias.shape)
self.head.head.bias = mx.zeros(self.head.head.bias.shape)
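As a quick way to exercise the pure-MLX attention path (mlx_attention above) in isolation, a sketch like the following can be used. It assumes q, k, and v are the first three positional arguments, that they are laid out as [batch, length, heads, head_dim] (matching the shapes unpacked inside the function), and that the masking and length arguments default to disabled:

import mlx.core as mx
from wan.modules.model import mlx_attention

b, l, n, d = 1, 16, 8, 64
q = mx.random.normal(shape=(b, l, n, d))
k = mx.random.normal(shape=(b, l, n, d))
v = mx.random.normal(shape=(b, l, n, d))

out = mlx_attention(q, k, v)  # assumed positional order and defaults
print(out.shape)              # expected: (1, 16, 8, 64)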