diff --git a/video/Wan2.2/generate.py b/video/Wan2.2/generate.py
index aeb785a3f..a98c2676a 100644
--- a/video/Wan2.2/generate.py
+++ b/video/Wan2.2/generate.py
@@ -2,25 +2,22 @@
 import argparse
 import logging
 import os
+import random
 import sys
 import warnings
 from datetime import datetime
 
-warnings.filterwarnings('ignore')
-
-import random
-import mlx.core as mx
-import mlx.nn as nn
 from PIL import Image
-import numpy as np
-
-# Note: MLX doesn't have built-in distributed training support like PyTorch
-# For distributed training, you would need to implement custom logic or use MPI
 
 import wan  # Assuming wan has been converted to MLX
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.utils import save_video, str2bool
 
+# Note: MLX doesn't have built-in distributed training support like PyTorch
+# For distributed training, you would need to implement custom logic or use MPI
+
+warnings.filterwarnings('ignore')
+
 EXAMPLE_PROMPT = {
     "t2v-A14B": {
         "prompt":
@@ -220,25 +217,25 @@ def generate(args):
     rank = 0
     world_size = 1
     local_rank = 0
-    
+
     # Check for distributed execution environment variables
     # Note: Actual distributed implementation would require custom logic
     if "RANK" in os.environ:
         logging.warning("MLX doesn't have built-in distributed training. Running on single device.")
-    
+
     _init_logging(rank)
 
     if args.offload_model is None:
         args.offload_model = False
         logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
-    
+
     # MLX doesn't support FSDP or distributed training out of the box
     if args.t5_fsdp or args.dit_fsdp:
         logging.warning("FSDP is not supported in MLX. Ignoring FSDP flags.")
         args.t5_fsdp = False
         args.dit_fsdp = False
-    
+
     if args.ulysses_size > 1:
         logging.warning("Sequence parallel is not supported in MLX single-device mode. Setting ulysses_size to 1.")
         args.ulysses_size = 1
@@ -369,7 +366,7 @@ def generate(args):
         args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
 
     logging.info(f"Saving generated video to {args.save_file}")
-    
+
     # Don't convert to numpy - keep as MLX array
     save_video(
         tensor=video[None],  # Just add batch dimension, keep as MLX array
diff --git a/video/Wan2.2/wan/__init__.py b/video/Wan2.2/wan/__init__.py
index c7b180b86..614aab1c5 100644
--- a/video/Wan2.2/wan/__init__.py
+++ b/video/Wan2.2/wan/__init__.py
@@ -1,2 +1,2 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-from .text2video import WanT2V
\ No newline at end of file
+from .text2video import WanT2V
diff --git a/video/Wan2.2/wan/configs/__init__.py b/video/Wan2.2/wan/configs/__init__.py
index 875afe7e1..2720176e7 100644
--- a/video/Wan2.2/wan/configs/__init__.py
+++ b/video/Wan2.2/wan/configs/__init__.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import copy
 import os
 
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'
diff --git a/video/Wan2.2/wan/configs/shared_config.py b/video/Wan2.2/wan/configs/shared_config.py
index c58ab04ff..bc806c573 100644
--- a/video/Wan2.2/wan/configs/shared_config.py
+++ b/video/Wan2.2/wan/configs/shared_config.py
@@ -1,5 +1,5 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import torch
+import mlx.core as mx
 from easydict import EasyDict
 
 #------------------------ Wan shared config ------------------------#
@@ -7,11 +7,11 @@
 
 # t5
 wan_shared_cfg.t5_model = 'umt5_xxl'
-wan_shared_cfg.t5_dtype = torch.bfloat16
+wan_shared_cfg.t5_dtype = mx.bfloat16
 wan_shared_cfg.text_len = 512
 
 # transformer
-wan_shared_cfg.param_dtype = torch.bfloat16
+wan_shared_cfg.param_dtype = mx.bfloat16
 
 # inference
 wan_shared_cfg.num_train_timesteps = 1000
diff --git a/video/Wan2.2/wan/configs/wan_i2v_A14B.py b/video/Wan2.2/wan/configs/wan_i2v_A14B.py
index f654cc6b2..dbb23196d 100644
--- a/video/Wan2.2/wan/configs/wan_i2v_A14B.py
+++ b/video/Wan2.2/wan/configs/wan_i2v_A14B.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import torch
 from easydict import EasyDict
 
 from .shared_config import wan_shared_cfg
@@ -9,6 +8,7 @@
 i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
 i2v_A14B.update(wan_shared_cfg)
 
+# t5
 i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
 i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
 
diff --git a/video/Wan2.2/wan/model_converter.py b/video/Wan2.2/wan/model_converter.py
new file mode 100644
index 000000000..15eca0e27
--- /dev/null
+++ b/video/Wan2.2/wan/model_converter.py
@@ -0,0 +1,83 @@
+import os
+
+import torch
+from safetensors.torch import save_file
+
+
+def convert_pickle_to_safetensors(
+    pickle_path: str,
+    safetensors_path: str,
+    load_method: str = "weights_only"  # Changed default to weights_only
+):
+    """Convert PyTorch pickle file to safetensors format."""
+
+    print(f"Loading PyTorch weights from: {pickle_path}")
+
+    # Try multiple loading methods in order of preference
+    methods_to_try = ["weights_only", "state_dict", "full_model"]
+    methods_to_try.remove(load_method)
+    methods_to_try.insert(0, load_method)
+
+    for method in methods_to_try:
+        try:
+            if method == "weights_only":
+                state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
+
+            elif method == "state_dict":
+                checkpoint = torch.load(pickle_path, map_location='cpu')
+                if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+                    state_dict = checkpoint['state_dict']
+                elif isinstance(checkpoint, dict) and 'model' in checkpoint:
+                    state_dict = checkpoint['model']
+                else:
+                    state_dict = checkpoint
+
+            elif method == "full_model":
+                model = torch.load(pickle_path, map_location='cpu')
+                if hasattr(model, 'state_dict'):
+                    state_dict = model.state_dict()
+                else:
+                    state_dict = model
+
+            print(f"āœ… Successfully loaded with method: {method}")
+            break
+
+        except Exception as e:
+            print(f"āŒ Method {method} failed: {e}")
+            continue
+    else:
+        raise RuntimeError(f"All loading methods failed for {pickle_path}")
+
+    # Clean up the state dict
+    state_dict = clean_state_dict(state_dict)
+
+    print(f"Found {len(state_dict)} parameters")
+
+    # Save as safetensors
+    print(f"Saving to safetensors: {safetensors_path}")
+    os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
+    save_file(state_dict, safetensors_path)
+
+    print("āœ… Conversion complete!")
+    return state_dict
+
+
+def clean_state_dict(state_dict):
+    """
+    Clean up state dict by removing unwanted prefixes or keys.
+    """
+    cleaned = {}
+
+    for key, value in state_dict.items():
+        # Remove common prefixes that might interfere
+        clean_key = key
+
+        if clean_key.startswith('module.'):
+            clean_key = clean_key[7:]
+
+        if clean_key != key:
+            print(f"Cleaned key: {key} -> {clean_key}")
+
+        cleaned[clean_key] = value
+
+    return cleaned
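Reviewer note: for readers trying the converter above, a typical invocation looks like the sketch below. The paths are hypothetical placeholders, not files shipped with the repo.

    from wan.model_converter import convert_pickle_to_safetensors

    # Hypothetical checkpoint locations; adjust to your checkout
    convert_pickle_to_safetensors(
        "ckpts/models_t5_umt5-xxl-enc-bf16.pth",          # assumed input .pth
        "ckpts/models_t5_umt5-xxl-enc-bf16.safetensors",  # assumed output path
        load_method="weights_only",
    )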
diff --git a/video/Wan2.2/wan/modules/model.py b/video/Wan2.2/wan/modules/model.py
index 6cd3aa799..aa841b346 100644
--- a/video/Wan2.2/wan/modules/model.py
+++ b/video/Wan2.2/wan/modules/model.py
@@ -4,7 +4,6 @@
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 
 __all__ = ['WanModel']
 
@@ -52,17 +51,17 @@ def rope_apply(x, grid_sizes, freqs):
 
         # reshape x_i to complex representation
         x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
-        
+
         # precompute frequency multipliers for each dimension
         freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
         freqs_f = mx.tile(freqs_f, (1, h, w, 1, 1)).reshape(f * h * w, -1, 2)
-        
+
         freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
         freqs_h = mx.tile(freqs_h, (f, 1, w, 1, 1)).reshape(f * h * w, -1, 2)
-        
+
         freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
         freqs_w = mx.tile(freqs_w, (f, h, 1, 1, 1)).reshape(f * h * w, -1, 2)
-        
+
         # Concatenate frequency components
         freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=1)
         freqs_i = freqs_i[:seq_len].reshape(seq_len, 1, c, 2)
@@ -72,18 +71,18 @@ def rope_apply(x, grid_sizes, freqs):
         x_imag = x_i[..., 1]
         freqs_real = freqs_i[..., 0]
         freqs_imag = freqs_i[..., 1]
-        
+
         out_real = x_real * freqs_real - x_imag * freqs_imag
         out_imag = x_real * freqs_imag + x_imag * freqs_real
-        
+
         x_i = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, -1)
-        
+
         # Handle remaining sequence
         if x.shape[1] > seq_len:
             x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
 
         output.append(x_i)
-    
+
     return mx.stack(output)
 
 
@@ -136,29 +135,29 @@ def mlx_attention(
     # Get shapes
     b, lq, n, d = q.shape
     _, lk, _, _ = k.shape
-    
+
     # Scale queries if needed
     if q_scale is not None:
         q = q * q_scale
-    
+
     # Compute attention scores
     q = q.transpose(0, 2, 1, 3)  # [b, n, lq, d]
     k = k.transpose(0, 2, 1, 3)  # [b, n, lk, d]
     v = v.transpose(0, 2, 1, 3)  # [b, n, lk, d]
-    
+
     # Compute attention scores
     scores = mx.matmul(q, k.transpose(0, 1, 3, 2))  # [b, n, lq, lk]
-    
+
     # Apply softmax scale if provided
     if softmax_scale is not None:
         scores = scores * softmax_scale
     else:
         # Default scaling by sqrt(d)
         scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
-    
+
     # Create attention mask
     attn_mask = None
-    
+
     # Apply window size masking if specified
     if window_size != (-1, -1):
         left_window, right_window = window_size
@@ -168,7 +167,7 @@ def mlx_attention(
             end = min(lk, i + right_window + 1)
             window_mask[i, start:end] = 1
         attn_mask = window_mask
-    
+
     # Apply causal masking if needed
     if causal:
         causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
@@ -176,12 +175,12 @@ def mlx_attention(
             attn_mask = causal_mask
         else:
             attn_mask = mx.logical_and(attn_mask, causal_mask)
-    
+
    # Apply attention mask if present
     if attn_mask is not None:
         attn_mask = attn_mask.astype(scores.dtype)
         scores = scores * attn_mask + (1 - attn_mask) * -1e4
-    
+
     # Apply attention mask if lengths are provided
     if q_lens is not None or k_lens is not None:
         if q_lens is not None:
@@ -192,22 +191,22 @@ def mlx_attention(
             mask = mx.arange(lk)[None, :] < k_lens[:, None]
             mask = mask.astype(scores.dtype)
             scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
-    
+
     # Apply softmax
     max_scores = mx.max(scores, axis=-1, keepdims=True)
     scores = scores - max_scores
     exp_scores = mx.exp(scores)
     sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
     attn = exp_scores / (sum_exp + 1e-6)
-    
+
     # Apply dropout if needed
     if dropout_p > 0 and not deterministic:
         raise NotImplementedError("Dropout not implemented in MLX version")
-    
+
     # Compute output
     out = mx.matmul(attn, v)  # [b, n, lq, d]
     out = out.transpose(0, 2, 1, 3)  # [b, lq, n, d]
-    
+
     return out
 
 class WanSelfAttention(nn.Module):
@@ -356,7 +355,7 @@ def __call__(
         y = self.ffn(
             self.norm2(x) * (1 + mx.squeeze(e[4], axis=2)) + mx.squeeze(e[3], axis=2))
         x = x + y * mx.squeeze(e[5], axis=2)
-        
+
         return x
 
 
@@ -541,13 +540,13 @@ def __call__(
 
         grid_sizes = mx.stack(
             [mx.array(u.shape[1:4], dtype=mx.int32) for u in x])
-        
+
         x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]
 
         seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
         assert seq_lens.max() <= seq_len
-        
+
         # Pad sequences
         x_padded = []
         for u in x:
@@ -637,7 +636,7 @@ def init_weights(self):
         std = math.sqrt(2.0 / (fan_in + fan_out))
         self.patch_embedding.weight = mx.random.uniform(
             low=-std, high=std, shape=self.patch_embedding.weight.shape)
-        
+
         # Initialize text embedding layers with normal distribution
         text_layers = list(self.text_embedding.layers)
         for i in [0, 2]:  # First and third layers
@@ -645,7 +644,7 @@ def init_weights(self):
             layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
             if hasattr(layer, 'bias') and layer.bias is not None:
                 layer.bias = mx.zeros(layer.bias.shape)
-        
+
         # Initialize time embedding layers
         time_layers = list(self.time_embedding.layers)
         for i in [0, 2]:  # First and third layers
@@ -653,8 +652,8 @@ def init_weights(self):
             layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
             if hasattr(layer, 'bias') and layer.bias is not None:
                 layer.bias = mx.zeros(layer.bias.shape)
-        
+
         # Initialize output head to zeros
         self.head.head.weight = mx.zeros(self.head.head.weight.shape)
         if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
-            self.head.head.bias = mx.zeros(self.head.head.bias.shape)
\ No newline at end of file
+            self.head.head.bias = mx.zeros(self.head.head.bias.shape)
diff --git a/video/Wan2.2/wan/modules/t5.py b/video/Wan2.2/wan/modules/t5.py
index fc2c57deb..a31fc55a9 100644
--- a/video/Wan2.2/wan/modules/t5.py
+++ b/video/Wan2.2/wan/modules/t5.py
@@ -1,14 +1,12 @@
 # MLX implementation for t5.py
 import logging
 import math
-from typing import Optional, Tuple, List
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 from mlx.utils import tree_unflatten
 
-from .tokenizers import HuggingfaceTokenizer
+from ..t5_model_io import sanitize
 
 __all__ = [
     'T5Model',
@@ -90,11 +88,11 @@ def __call__(self, x, context=None, mask=None, pos_bias=None):
 
         # compute attention (T5 does not use scaling)
         attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))  # [B, N, L1, L2]
-        
+
         # add position bias if provided
         if pos_bias is not None:
             attn = attn + pos_bias
-        
+
         # apply mask
         if mask is not None:
             if mask.ndim == 2:
@@ -110,14 +108,14 @@ def __call__(self, x, context=None, mask=None, pos_bias=None):
         # softmax and apply attention
         attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
         attn = self.dropout(attn)
-        
+
         # apply attention to values
         x = mx.matmul(attn, v)  # [B, N, L1, C]
-        
+
         # transpose back and reshape
         x = mx.transpose(x, (0, 2, 1, 3))  # [B, L1, N, C]
         x = x.reshape(b, l1, -1)
-        
+
         # output projection
         x = self.o(x)
         x = self.dropout(x)
@@ -237,17 +235,17 @@ def __call__(self, lq, lk):
         positions_q = mx.arange(lq)[:, None]
         positions_k = mx.arange(lk)[None, :]
         rel_pos = positions_k - positions_q
-        
+
         # Apply bucketing
         rel_pos = self._relative_position_bucket(rel_pos)
-        
+
         # Get embeddings
         rel_pos_embeds = self.embedding(rel_pos)
-        
+
         # Reshape to [1, N, Lq, Lk]
         rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
         rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
-        
+
         return rel_pos_embeds
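Reviewer note: `_relative_position_bucket` below follows the standard T5 scheme (exact buckets for small offsets, logarithmic buckets out to `max_dist`). For intuition, a scalar re-derivation of the same rule, with the usual constants assumed:

    import math

    def bucket(rel_pos, num_buckets=32, max_dist=128):
        # bidirectional: half the buckets per sign of the offset
        base = (num_buckets // 2) if rel_pos > 0 else 0
        half = num_buckets // 2
        n = abs(rel_pos)
        max_exact = half // 2
        if n < max_exact:
            return base + n  # small offsets get their own bucket
        # logarithmic buckets for offsets in [max_exact, max_dist)
        val = max_exact + int(
            math.log(n / max_exact) / math.log(max_dist / max_exact)
            * (half - max_exact))
        return base + min(val, half - 1)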
 
     def _relative_position_bucket(self, rel_pos):
@@ -264,19 +262,19 @@
         # embeddings for small and large positions
         max_exact = num_buckets // 2
         is_small = rel_pos < max_exact
-        
+
         # For large positions, use log scale
         rel_pos_large = max_exact + (
             mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
             math.log(self.max_dist / max_exact) *
             (num_buckets - max_exact)
         ).astype(mx.int32)
-        
+
         rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
-        
+
         # Combine small and large position buckets
         rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
-        
+
         return rel_buckets
 
 
@@ -305,7 +303,7 @@ def __init__(self,
             self.token_embedding = vocab
         else:
             self.token_embedding = nn.Embedding(vocab, dim)
-        
+
         self.pos_embedding = T5RelativeEmbedding(
             num_buckets, num_heads, bidirectional=True) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
@@ -352,7 +350,7 @@ def __init__(self,
             self.token_embedding = vocab
         else:
             self.token_embedding = nn.Embedding(vocab, dim)
-        
+
         self.pos_embedding = T5RelativeEmbedding(
             num_buckets, num_heads, bidirectional=False) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
@@ -425,10 +423,10 @@ def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
 
 def init_mlx_weights(module, key):
     """Initialize weights for T5 model components to match PyTorch initialization"""
-    
+
     def normal(key, shape, std=1.0):
         return mx.random.normal(key, shape) * std
-    
+
     if isinstance(module, T5LayerNorm):
         module.weight = mx.ones_like(module.weight)
     elif isinstance(module, nn.Embedding):
@@ -445,7 +443,7 @@ def init_mlx_weights(module, key):
                                    std=module.dim_ffn**-0.5)
     elif isinstance(module, T5Attention):
         # Match PyTorch initialization
-        key1, key2, key3, key4 = random.split(key, 4)
+        key1, key2, key3, key4 = mx.random.split(key, 4)
         module.q.weight = normal(key1, module.q.weight.shape,
                                  std=(module.dim * module.dim_attn)**-0.5)
         module.k.weight = normal(key2, module.k.weight.shape,
@@ -464,7 +462,7 @@ def init_mlx_weights(module, key):
         fan_in = module.weight.shape[1]
         bound = 1.0 / math.sqrt(fan_in)
         module.weight = mx.random.uniform(key, module.weight.shape, -bound, bound)
-    
+
     return module
 
 
@@ -493,7 +491,7 @@ def _t5(name,
 
     # init model
     model = model_cls(**kwargs)
-    
+
     # Initialize weights properly
     key = mx.random.key(0)
     model = init_mlx_weights(model, key)
@@ -529,6 +527,7 @@ def __init__(
         text_len,
         checkpoint_path=None,
         tokenizer_path=None,
+        dtype=mx.bfloat16
     ):
         self.text_len = text_len
         self.checkpoint_path = checkpoint_path
@@ -538,15 +537,21 @@ def __init__(
         model = umt5_xxl(
             encoder_only=True,
             return_tokenizer=False)
-        
+
         if checkpoint_path:
             logging.info(f'loading {checkpoint_path}')
             # Load weights - assuming MLX format checkpoint
             weights = mx.load(checkpoint_path)
+
+            for key, weight in weights.items():
+                if weight.dtype != dtype and mx.issubdtype(weight.dtype, mx.floating):
+                    weights[key] = weight.astype(dtype)
+
+            weights = sanitize(weights)
             model.update(tree_unflatten(list(weights.items())))
-        
+
         self.model = model
-        
+
         # init tokenizer
         from .tokenizers import HuggingfaceTokenizer
         self.tokenizer = HuggingfaceTokenizer(
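Reviewer note: one caveat in `init_mlx_weights` above. The local helper calls `mx.random.normal(key, shape)`, but `mx.random.normal` takes the shape first and the PRNG key as a keyword; the same applies to the `mx.random.uniform` call in the `nn.Linear` branch. A corrected sketch, assuming current mlx.core signatures:

    import mlx.core as mx

    def normal(key, shape, std=1.0):
        # shape is the first positional argument; the key goes by keyword
        return mx.random.normal(shape, key=key) * std

    def uniform(key, shape, bound):
        return mx.random.uniform(low=-bound, high=bound, shape=shape, key=key)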
@@ -558,11 +563,11 @@ def __call__(self, texts):
         # Handle single string input
         if isinstance(texts, str):
             texts = [texts]
-        
+
         # Tokenize texts
         tokenizer_output = self.tokenizer(
             texts, return_mask=True, add_special_tokens=True)
-        
+
         # Handle different tokenizer output formats
         if isinstance(tokenizer_output, tuple):
             ids, mask = tokenizer_output
@@ -570,47 +575,24 @@ def __call__(self, texts):
             # Assuming dict output with 'input_ids' and 'attention_mask'
             ids = tokenizer_output['input_ids']
             mask = tokenizer_output['attention_mask']
-        
+
         # Convert to MLX arrays if not already
         if not isinstance(ids, mx.array):
             ids = mx.array(ids)
         if not isinstance(mask, mx.array):
             mask = mx.array(mask)
-        
+
         # Get sequence lengths
         seq_lens = mx.sum(mask > 0, axis=1)
-        
+
         # Run encoder
         context = self.model(ids, mask)
-        
+
         # Return variable length outputs
         # Convert seq_lens to Python list for indexing
         if seq_lens.ndim == 0:  # Single value
             seq_lens_list = [seq_lens.item()]
         else:
             seq_lens_list = seq_lens.tolist()
-        
-        return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
-
-# Utility function to convert PyTorch checkpoint to MLX
-def convert_pytorch_checkpoint(pytorch_path, mlx_path):
-    """Convert PyTorch checkpoint to MLX format"""
-    import torch
-    
-    # Load PyTorch checkpoint
-    pytorch_state = torch.load(pytorch_path, map_location='cpu')
-    
-    # Convert to numpy then to MLX
-    mlx_state = {}
-    for key, value in pytorch_state.items():
-        if isinstance(value, torch.Tensor):
-            # Handle the key mapping if needed
-            mlx_key = key
-            # Convert tensor to MLX array
-            mlx_state[mlx_key] = mx.array(value.numpy())
-    
-    # Save MLX checkpoint
-    mx.save(mlx_path, mlx_state)
-    
-    return mlx_state
\ No newline at end of file
+        return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
diff --git a/video/Wan2.2/wan/modules/vae2_1.py b/video/Wan2.2/wan/modules/vae2_1.py
index b021cc0d4..404d379ab 100644
--- a/video/Wan2.2/wan/modules/vae2_1.py
+++ b/video/Wan2.2/wan/modules/vae2_1.py
@@ -1,13 +1,12 @@
 # MLX implementation of vae2_1.py
 import logging
-from typing import Optional, List, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
-
 from mlx.utils import tree_unflatten
 
+from ..vae_model_io import sanitize
+
 __all__ = [
     'Wan2_1_VAE',
 ]
@@ -35,12 +34,12 @@ def __call__(self, x, cache_x=None):
         if cache_x is not None and self._padding[4] > 0:
             x = mx.concatenate([cache_x, x], axis=1)  # Concat along time axis
             padding[4] -= cache_x.shape[1]
-        
+
         # Pad in BTHWC format
         pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
                      (padding[0], padding[1]), (0, 0)]
         x = mx.pad(x, pad_width)
-        
+
         result = super().__call__(x)
         return result
 
@@ -74,10 +73,10 @@ def __init__(self, scale_factor, mode='nearest-exact'):
 
     def __call__(self, x):
         scale_h, scale_w = self.scale_factor
-        
+
         out = mx.repeat(x, int(scale_h), axis=1)  # Repeat along H dimension
         out = mx.repeat(out, int(scale_w), axis=2)  # Repeat along W dimension
-        
+
         return out
 
 class AsymmetricPad(nn.Module):
@@ -124,17 +123,17 @@ def __init__(self, dim, mode):
             pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
             conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
             self.resample = nn.Sequential(pad_layer, conv_layer)
-            
+
             self.time_conv = CausalConv3d(
                 dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
-            
+
         else:
             self.resample = nn.Identity()
 
     def __call__(self, x, feat_cache=None, feat_idx=[0]):
         # The __call__ method logic remains unchanged from your original code
         b, t, h, w, c = x.shape
-        
+
         if self.mode == 'upsample3d':
             if feat_cache is not None:
                 idx = feat_idx[0]
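Reviewer note: `UpsampleNearest` above emulates nearest-neighbor interpolation with two `mx.repeat` calls. A quick shape check of the idea, on a toy channels-last tensor:

    import mlx.core as mx

    x = mx.arange(4).reshape(1, 2, 2, 1)               # tiny BHWC-style input
    up = mx.repeat(mx.repeat(x, 2, axis=1), 2, axis=2)
    print(up.shape)  # (1, 4, 4, 1): every pixel becomes a 2x2 block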
@@ -151,7 +150,7 @@ def __call__(self, x, feat_cache=None, feat_idx=[0]):
                     cache_x = mx.concatenate([
                         mx.zeros_like(cache_x), cache_x
                     ], axis=1)
-                    
+
                 if feat_cache[idx] == 'Rep':
                     x = self.time_conv(x)
                 else:
@@ -162,10 +161,10 @@ def __call__(self, x, feat_cache=None, feat_idx=[0]):
                 x = x.reshape(b, t, h, w, 2, c)
                 x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
                 x = x.reshape(b, t * 2, h, w, c)
-        
+
         t = x.shape[1]
         x = x.reshape(b * t, h, w, c)
-        
+
         x = self.resample(x)
 
         _, h_new, w_new, c_new = x.shape
@@ -183,7 +182,7 @@ def __call__(self, x, feat_cache=None, feat_idx=[0]):
                     mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
-        
+
         return x
 
 
@@ -204,13 +203,13 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
             nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
             CausalConv3d(out_dim, out_dim, 3, padding=1)
         )
-        
+
         self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
             if in_dim != out_dim else nn.Identity()
 
     def __call__(self, x, feat_cache=None, feat_idx=[0]):
         h = self.shortcut(x)
-        
+
         for i, layer in enumerate(self.residual.layers):
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
@@ -254,7 +253,7 @@ def __call__(self, x):
         qkv = self.to_qkv(x)  # Output: (b*t, h, w, 3*c)
         qkv = qkv.reshape(b * t, h * w, 3 * c)
         q, k, v = mx.split(qkv, 3, axis=-1)
-        
+
         # Reshape for attention
         q = q.reshape(b * t, h * w, c)
         k = k.reshape(b * t, h * w, c)
@@ -535,10 +534,10 @@ def encode(self, x, scale):
                     feat_cache=self._enc_feat_map,
                     feat_idx=self._enc_conv_idx)
                 out = mx.concatenate([out, out_], axis=1)
-        
+
         z = self.conv1(out)
         mu, log_var = mx.split(z, 2, axis=-1)  # Split along channel dimension
-        
+
         if isinstance(scale[0], mx.array):
             # Reshape scale for broadcasting in BTHWC format
             scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
@@ -605,7 +604,7 @@ def clear_cache(self):
         self._enc_feat_map = [None] * self._enc_conv_num
 
 
-def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, dtype=mx.float32, **kwargs):
     # params
     cfg = dict(
         dim=96,
@@ -624,6 +623,12 @@ def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
     if pretrained_path:
         logging.info(f'loading {pretrained_path}')
         weights = mx.load(pretrained_path)
+
+        for key, weight in weights.items():
+            if weight.dtype != dtype and mx.issubdtype(weight.dtype, mx.floating):
+                weights[key] = weight.astype(dtype)
+
+        weights = sanitize(weights)
         model.update(tree_unflatten(list(weights.items())))
 
     return model
@@ -653,6 +658,7 @@ def __init__(self,
         self.model = _video_vae(
             pretrained_path=vae_pth,
             z_dim=z_dim,
+            dtype=dtype
         )
 
     def encode(self, videos):
@@ -665,16 +671,16 @@
             # Convert CTHW -> BTHWC
             x = mx.expand_dims(video, axis=0)  # Add batch dimension
             x = x.transpose(0, 2, 3, 4, 1)  # BCTHW -> BTHWC
-            
+
             # Encode
             z = self.model.encode(x, self.scale)[0]  # Get mu only
-            
+
             # Convert back BTHWC -> CTHW and remove batch dimension
             z = z.transpose(0, 4, 1, 2, 3)  # BTHWC -> BCTHW
             z = z.squeeze(0)  # Remove batch dimension -> CTHW
-            
+
             encoded.append(z.astype(mx.float32))
-        
+
         return encoded
 
     def decode(self, zs):
@@ -687,17 +693,17 @@ def decode(self, zs):
             # Convert CTHW -> BTHWC
             x = mx.expand_dims(z, axis=0)  # Add batch dimension
             x = x.transpose(0, 2, 3, 4, 1)  # BCTHW -> BTHWC
-            
+
             # Decode
             x = self.model.decode(x, self.scale)
-            
+
             # Convert back BTHWC -> CTHW and remove batch dimension
             x = x.transpose(0, 4, 1, 2, 3)  # BTHWC -> BCTHW
             x = x.squeeze(0)  # Remove batch dimension -> CTHW
-            
+
             # Clamp values
             x = mx.clip(x, -1, 1)
-            
+
             decoded.append(x.astype(mx.float32))
-        
-        
-        return decoded
\ No newline at end of file
+
+        return decoded
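Reviewer note: MLX convolutions expect channels-last input, hence the CTHW to BTHWC permutations in encode()/decode() above. A round-trip sanity check of those exact axis orders:

    import mlx.core as mx

    c, t, h, w = 16, 4, 8, 8
    video = mx.zeros((c, t, h, w))                    # CTHW, as the VAE receives it
    x = mx.expand_dims(video, axis=0)                 # BCTHW
    x = x.transpose(0, 2, 3, 4, 1)                    # BTHWC, channels-last
    back = x.transpose(0, 4, 1, 2, 3).squeeze(0)      # BCTHW -> CTHW
    assert back.shape == video.shape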
diff --git a/video/Wan2.2/wan/t5_model_io.py b/video/Wan2.2/wan/t5_model_io.py
index 4e4bce7c3..380173c31 100644
--- a/video/Wan2.2/wan/t5_model_io.py
+++ b/video/Wan2.2/wan/t5_model_io.py
@@ -1,265 +1,50 @@
-import json
-from typing import Optional, List, Tuple
-import mlx.core as mx
-from mlx.utils import tree_unflatten
-from safetensors import safe_open
-import torch
-
-
-def check_safetensors_dtypes(safetensors_path: str):
-    """
-    Check what dtypes are in the safetensors file.
-    Useful for debugging dtype issues.
-    """
-    print(f"šŸ” Checking dtypes in: {safetensors_path}")
-
-    dtype_counts = {}
-
-    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
-        for key in f.keys():
-            tensor = f.get_tensor(key)
-            dtype_str = str(tensor.dtype)
-
-            if dtype_str not in dtype_counts:
-                dtype_counts[dtype_str] = []
-            dtype_counts[dtype_str].append(key)
-
-    print("šŸ“Š Dtype summary:")
-    for dtype, keys in dtype_counts.items():
-        print(f"   {dtype}: {len(keys)} parameters")
-        if dtype == "torch.bfloat16":
-            print(f"   āš ļø BFloat16 detected - will convert to float32")
-            print(f"   Examples: {keys[:3]}")
-
-    return dtype_counts
-
-
-def convert_tensor_dtype(tensor: torch.Tensor) -> torch.Tensor:
-    """
-    Convert tensor to MLX-compatible dtype.
-    """
-    if tensor.dtype == torch.bfloat16:
-        # Convert BFloat16 to float32
-        return tensor.float()
-    elif tensor.dtype == torch.float64:
-        # Convert float64 to float32 for efficiency
-        return tensor.float()
-    else:
-        # Keep other dtypes as-is
-        return tensor
-
-
-def map_t5_encoder_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
-    """
-    Map T5 encoder weights from PyTorch format to MLX format.
-    Following the pattern used in MLX Stable Diffusion.
-
-    Args:
-        key: Parameter name from PyTorch model
-        value: Parameter tensor
-
-    Returns:
-        List of (key, value) tuples for MLX model
-    """
-
-    # Handle the main structural difference: FFN gate layer
-    if ".ffn.gate.0.weight" in key:
-        # PyTorch has Sequential(Linear, GELU) but MLX has separate gate_proj + gate_act
-        key = key.replace(".ffn.gate.0.weight", ".ffn.gate_proj.weight")
-        return [(key, value)]
-
-    elif ".ffn.gate.0.bias" in key:
-        # Handle bias if it exists
-        key = key.replace(".ffn.gate.0.bias", ".ffn.gate_proj.bias")
-        return [(key, value)]
-
-    elif ".ffn.gate.1" in key:
-        # Skip GELU activation parameters - MLX handles this separately
-        print(f"Skipping GELU parameter: {key}")
-        return []
-
-    # Handle any other potential FFN mappings
-    elif ".ffn.fc1.weight" in key:
-        return [(key, value)]
-    elif ".ffn.fc2.weight" in key:
-        return [(key, value)]
-
-    # Handle attention layers (should be direct mapping)
-    elif ".attn.q.weight" in key:
-        return [(key, value)]
-    elif ".attn.k.weight" in key:
-        return [(key, value)]
-    elif ".attn.v.weight" in key:
-        return [(key, value)]
-    elif ".attn.o.weight" in key:
-        return [(key, value)]
-
-    # Handle embeddings and norms (direct mapping)
-    elif "token_embedding.weight" in key:
-        return [(key, value)]
-    elif "pos_embedding.embedding.weight" in key:
-        return [(key, value)]
-    elif "norm1.weight" in key or "norm2.weight" in key or "norm.weight" in key:
-        return [(key, value)]
-
-    # Default: direct mapping for any other parameters
-    else:
-        return [(key, value)]
-
-
-def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
-    """Flatten nested list of parameter tuples"""
-    return [(k, v) for p in params for (k, v) in p]
-
-
-def _load_safetensor_weights(
-    mapper_func,
-    model,
-    weight_file: str,
-    float16: bool = False
-):
-    """
-    Load safetensor weights using the mapping function.
-    Based on MLX SD pattern.
-    """
-    dtype = mx.float16 if float16 else mx.float32
-
-    # Load weights from safetensors file
-    weights = {}
-    with safe_open(weight_file, framework="pt", device="cpu") as f:
-        for key in f.keys():
-            tensor = f.get_tensor(key)
-
-            # Handle BFloat16 - convert to float32 first
-            if tensor.dtype == torch.bfloat16:
-                print(f"Converting BFloat16 to float32 for: {key}")
-                tensor = tensor.float()  # Convert to float32
-
-            weights[key] = mx.array(tensor.numpy()).astype(dtype)
-
-    # Apply mapping function
-    mapped_weights = _flatten([mapper_func(k, v) for k, v in weights.items()])
-
-    # Update model with mapped weights
-    model.update(tree_unflatten(mapped_weights))
-
-    return model
-
-
-def load_t5_encoder_from_safetensors(
-    safetensors_path: str,
-    model,  # Your MLX T5Encoder instance
-    float16: bool = False
-):
-    """
-    Load T5 encoder weights from safetensors file into MLX model.
-
-    Args:
-        safetensors_path: Path to the safetensors file
-        model: Your MLX T5Encoder model instance
-        float16: Whether to use float16 precision
-
-    Returns:
-        Model with loaded weights
-    """
-    print(f"Loading T5 encoder weights from: {safetensors_path}")
-
-    # Load and map weights
-    model = _load_safetensor_weights(
-        map_t5_encoder_weights,
-        model,
-        safetensors_path,
-        float16
-    )
-
-    print("T5 encoder weights loaded successfully!")
-    return model
-
-
-def debug_weight_mapping(safetensors_path: str, float16: bool = False):
-    """
-    Debug function to see how weights are being mapped.
-    Useful for troubleshooting conversion issues.
- """ - dtype = mx.float16 if float16 else mx.float32 - - print("=== T5 Weight Mapping Debug ===") - - with safe_open(safetensors_path, framework="pt", device="cpu") as f: - for key in f.keys(): - tensor = f.get_tensor(key) - - # Handle BFloat16 - original_dtype = tensor.dtype - if tensor.dtype == torch.bfloat16: - print(f"Converting BFloat16 to float32 for: {key}") - tensor = tensor.float() - - value = mx.array(tensor.numpy()).astype(dtype) - - # Apply mapping - mapped = map_t5_encoder_weights(key, value) - - if len(mapped) == 0: - print(f"SKIPPED: {key} ({original_dtype}) -> (no mapping)") - elif len(mapped) == 1: - new_key, new_value = mapped[0] - if new_key == key: - print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]") - else: - print(f"MAPPED: {key} ({original_dtype}) -> {new_key} [{tensor.shape}]") - else: - print(f"SPLIT: {key} ({original_dtype}) -> {len(mapped)} parameters") - for new_key, new_value in mapped: - print(f" -> {new_key} [{new_value.shape}]") - - -def convert_safetensors_to_mlx_weights( - safetensors_path: str, - output_path: str, - float16: bool = False -): - """ - Convert safetensors file to MLX weights file (.npz format). - - Args: - safetensors_path: Input safetensors file - output_path: Output MLX weights file (.npz) - float16: Whether to use float16 precision - """ - dtype = mx.float16 if float16 else mx.float32 - - print(f"Converting safetensors to MLX format...") - print(f"Input: {safetensors_path}") - print(f"Output: {output_path}") - print(f"Target dtype: {dtype}") - - # Load and convert weights - weights = {} - bfloat16_count = 0 - - with safe_open(safetensors_path, framework="pt", device="cpu") as f: - for key in f.keys(): - tensor = f.get_tensor(key) - - # Handle BFloat16 - # if tensor.dtype == torch.bfloat16: - # bfloat16_count += 1 - # tensor = tensor.float() # Convert to float32 first - - value = mx.array(tensor.numpy())#.astype(dtype) - - # Apply mapping - mapped = map_t5_encoder_weights(key, value) - - for new_key, new_value in mapped: - weights[new_key] = new_value - - if bfloat16_count > 0: - print(f"āš ļø Converted {bfloat16_count} BFloat16 tensors to float32") - - # Save as MLX format - print(f"Saving {len(weights)} parameters to: {output_path}") - mx.save_safetensors(output_path, weights) - - return weights \ No newline at end of file +def sanitize(weights): + """Following the pattern used in MLX Stable Diffusion.""" + new_weights = {} + + for key, weight in weights.items(): + # Handle the main structural difference: FFN gate layer + if ".ffn.gate.0.weight" in key: + # PyTorch has Sequential(Linear, GELU) but MLX has separate gate_proj + gate_act + key = key.replace(".ffn.gate.0.weight", ".ffn.gate_proj.weight") + + elif ".ffn.gate.0.bias" in key: + # Handle bias if it exists + key = key.replace(".ffn.gate.0.bias", ".ffn.gate_proj.bias") + + elif ".ffn.gate.1" in key: + # Skip GELU activation parameters - MLX handles this separately + print(f"Skipping GELU parameter: {key}") + continue + + # Handle any other potential FFN mappings + elif ".ffn.fc1.weight" in key: + pass + elif ".ffn.fc2.weight" in key: + pass + + # Handle attention layers (should be direct mapping) + elif ".attn.q.weight" in key: + pass + elif ".attn.k.weight" in key: + pass + elif ".attn.v.weight" in key: + pass + elif ".attn.o.weight" in key: + pass + + # Handle embeddings and norms (direct mapping) + elif "token_embedding.weight" in key: + pass + elif "pos_embedding.embedding.weight" in key: + pass + elif "norm1.weight" in key or "norm2.weight" in key or "norm.weight" 
diff --git a/video/Wan2.2/wan/t5_torch_to_sf.py b/video/Wan2.2/wan/t5_torch_to_sf.py
deleted file mode 100644
index 5b70e6ef6..000000000
--- a/video/Wan2.2/wan/t5_torch_to_sf.py
+++ /dev/null
@@ -1,233 +0,0 @@
-import os
-import torch
-from safetensors.torch import save_file
-from pathlib import Path
-import json
-
-from wan.modules.t5 import T5Model
-
-
-def convert_pickle_to_safetensors(
-    pickle_path: str,
-    safetensors_path: str,
-    model_class=None,
-    model_kwargs=None,
-    load_method: str = "weights_only"  # Changed default to weights_only
-):
-    """Convert PyTorch pickle file to safetensors format."""
-
-    print(f"Loading PyTorch weights from: {pickle_path}")
-
-    # Try multiple loading methods in order of preference
-    methods_to_try = [load_method, "weights_only", "state_dict", "full_model"]
-
-    for method in methods_to_try:
-        try:
-            if method == "weights_only":
-                state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
-
-            elif method == "state_dict":
-                checkpoint = torch.load(pickle_path, map_location='cpu')
-                if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-                    state_dict = checkpoint['state_dict']
-                elif isinstance(checkpoint, dict) and 'model' in checkpoint:
-                    state_dict = checkpoint['model']
-                else:
-                    state_dict = checkpoint
-
-            elif method == "full_model":
-                model = torch.load(pickle_path, map_location='cpu')
-                if hasattr(model, 'state_dict'):
-                    state_dict = model.state_dict()
-                else:
-                    state_dict = model
-
-            print(f"āœ… Successfully loaded with method: {method}")
-            break
-
-        except Exception as e:
-            print(f"āŒ Method {method} failed: {e}")
-            continue
-    else:
-        raise RuntimeError(f"All loading methods failed for {pickle_path}")
-
-    # Clean up the state dict
-    state_dict = clean_state_dict(state_dict)
-
-    print(f"Found {len(state_dict)} parameters")
-
-    # Convert BF16 to FP32 if needed
-    for key, tensor in state_dict.items():
-        if tensor.dtype == torch.bfloat16:
-            state_dict[key] = tensor.to(torch.float32)
-            print(f"Converted {key} from bfloat16 to float32")
-
-    # Save as safetensors
-    print(f"Saving to safetensors: {safetensors_path}")
-    os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
-    save_file(state_dict, safetensors_path)
-
-    print("āœ… T5 conversion complete!")
-    return state_dict
-
-
-def clean_state_dict(state_dict):
-    """
-    Clean up state dict by removing unwanted prefixes or keys.
-    """
-    cleaned = {}
-
-    for key, value in state_dict.items():
-        # Remove common prefixes that might interfere
-        clean_key = key
-
-        if clean_key.startswith('module.'):
-            clean_key = clean_key[7:]
-
-        if clean_key.startswith('model.'):
-            clean_key = clean_key[6:]
-
-        cleaned[clean_key] = value
-
-        if clean_key != key:
-            print(f"Cleaned key: {key} -> {clean_key}")
-
-    return cleaned
-
-
-def load_with_your_torch_model(pickle_path: str, model_class, **model_kwargs):
-    """
-    Load pickle weights into your specific PyTorch T5 model implementation.
-
-    Args:
-        pickle_path: Path to pickle file
-        model_class: Your T5Encoder class
-        **model_kwargs: Arguments for your model constructor
-    """
-
-    print("Method 1: Loading into your PyTorch T5 model")
-
-    # Initialize your model
-    model = model_class(**model_kwargs)
-
-    # Load checkpoint
-    checkpoint = torch.load(pickle_path, map_location='cpu')
-
-    # Handle different checkpoint formats
-    if isinstance(checkpoint, dict):
-        if 'state_dict' in checkpoint:
-            state_dict = checkpoint['state_dict']
-        elif 'model' in checkpoint:
-            state_dict = checkpoint['model']
-        else:
-            # Assume the dict IS the state dict
-            state_dict = checkpoint
-    else:
-        # Assume it's a model object
-        state_dict = checkpoint.state_dict()
-
-    # Clean and load
-    state_dict = clean_state_dict(state_dict)
-    model.load_state_dict(state_dict, strict=False)  # Use strict=False to ignore missing keys
-
-    return model, state_dict
-
-
-def explore_pickle_file(pickle_path: str):
-    """
-    Explore the contents of a pickle file to understand its structure.
-    """
-    print(f"šŸ” Exploring pickle file: {pickle_path}")
-
-    try:
-        # Try loading with weights_only first (safer)
-        print("\n--- Trying weights_only=True ---")
-        try:
-            data = torch.load(pickle_path, map_location='cpu', weights_only=True)
-            print(f"āœ… Loaded with weights_only=True")
-            print(f"Type: {type(data)}")
-
-            if isinstance(data, dict):
-                print(f"Dictionary with {len(data)} keys:")
-                for i, key in enumerate(data.keys()):
-                    print(f"   {key}: {type(data[key])}")
-                    if hasattr(data[key], 'shape'):
-                        print(f"      Shape: {data[key].shape}")
-                    if i >= 9:  # Show first 10 keys
-                        break
-
-        except Exception as e:
-            print(f"āŒ weights_only=True failed: {e}")
-
-            # Try regular loading
-            print("\n--- Trying regular loading ---")
-            data = torch.load(pickle_path, map_location='cpu')
-            print(f"āœ… Loaded successfully")
-            print(f"Type: {type(data)}")
-
-            if hasattr(data, 'state_dict'):
-                print("šŸ“‹ Found state_dict method")
-                state_dict = data.state_dict()
-                print(f"State dict has {len(state_dict)} parameters")
-
-            elif isinstance(data, dict):
-                print(f"šŸ“‹ Dictionary with keys: {list(data.keys())}")
-
-                # Check for common checkpoint keys
-                if 'state_dict' in data:
-                    print("Found 'state_dict' key")
-                    print(f"state_dict has {len(data['state_dict'])} parameters")
-                elif 'model' in data:
-                    print("Found 'model' key")
-                    print(f"model has {len(data['model'])} parameters")
-
-    except Exception as e:
-        print(f"āŒ Failed to load: {e}")
-
-
-def full_conversion_pipeline(
-    pickle_path: str,
-    safetensors_path: str,
-    torch_model_class=None,
-    model_kwargs=None
-):
-    """
-    Complete pipeline: pickle -> safetensors -> ready for MLX conversion
-    """
-
-    print("šŸš€ Starting full conversion pipeline")
-    print("="*50)
-
-    # Step 1: Explore the pickle file
-    print("Step 1: Exploring pickle file structure")
-    explore_pickle_file(pickle_path)
-
-    # Step 2: Convert to safetensors
-    print(f"\nStep 2: Converting to safetensors")
-
-    # Try different loading methods
-    for method in ["weights_only", "state_dict", "full_model"]:
-        try:
-            print(f"\nTrying load method: {method}")
-            state_dict = convert_pickle_to_safetensors(
-                pickle_path,
-                safetensors_path,
-                model_class=torch_model_class,
-                model_kwargs=model_kwargs,
-                load_method=method
-            )
-            print(f"āœ… Success with method: {method}")
-            break
-
-        except Exception as e:
-            print(f"āŒ Method {method} failed: {e}")
-            continue
-    else:
-        print("āŒ All methods failed!")
-        return None
-
-    print(f"\nšŸŽ‰ Conversion complete!")
-    print(f"Safetensors file saved to: {safetensors_path}")
{safetensors_path}") - print(f"Ready for MLX conversion using the previous script!") - - return state_dict \ No newline at end of file diff --git a/video/Wan2.2/wan/text2video.py b/video/Wan2.2/wan/text2video.py index 8bd067a9f..08eb8aaaa 100644 --- a/video/Wan2.2/wan/text2video.py +++ b/video/Wan2.2/wan/text2video.py @@ -1,15 +1,12 @@ # MLX implementation of text2video.py import gc -import glob import logging import math import os import random import sys -from contextlib import contextmanager -from functools import partial -from typing import Optional, Tuple, List, Dict, Any, Union +from typing import Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -24,8 +21,8 @@ retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .wan_model_io import load_wan_2_2_from_safetensors -from .wan_model_io import convert_wan_2_2_safetensors_to_mlx, convert_multiple_wan_2_2_safetensors_to_mlx, load_wan_2_2_from_safetensors class WanT2V: def __init__( @@ -53,94 +50,58 @@ def __init__( self.patch_size = config.patch_size self.num_train_timesteps = config.num_train_timesteps self.boundary = config.boundary - # Convert PyTorch dtype to MLX dtype - if str(config.param_dtype) == 'torch.bfloat16': - self.param_dtype = mx.bfloat16 - elif str(config.param_dtype) == 'torch.float16': - self.param_dtype = mx.float16 - elif str(config.param_dtype) == 'torch.float32': - self.param_dtype = mx.float32 - else: - self.param_dtype = mx.float32 # default - + self.param_dtype = config.param_dtype + self.t5_dtype = config.t5_dtype + # Initialize T5 text encoder print(f"checkpoint_dir is: {checkpoint_dir}") t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint) - mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors') - if not os.path.exists(mlx_t5_path): + if not os.path.exists(t5_checkpoint_path): # Check if it's a .pth file that needs conversion pth_path = t5_checkpoint_path.replace('.safetensors', '.pth') if os.path.exists(pth_path): logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}") - from .t5_torch_to_sf import convert_pickle_to_safetensors - convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only") - # Convert torch safetensors to MLX safetensors - from .t5_model_io import convert_safetensors_to_mlx_weights - convert_safetensors_to_mlx_weights(t5_checkpoint_path, mlx_t5_path, float16=(self.param_dtype == mx.float16)) + from .model_converter import convert_pickle_to_safetensors + convert_pickle_to_safetensors(pth_path, t5_checkpoint_path) else: raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}") - t5_checkpoint_path = mlx_t5_path # Use the MLX version logging.info(f"Loading T5 text encoder... 
from {t5_checkpoint_path}") self.text_encoder = T5EncoderModel( text_len=config.text_len, checkpoint_path=t5_checkpoint_path, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer)) - + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + dtype=self.t5_dtype + ) + # Initialize VAE - with automatic conversion vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint) if not os.path.exists(vae_path): # Check for PyTorch VAE file to convert - pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth') + pth_vae_path = vae_path.replace('.safetensors', '.pth') if not os.path.exists(pth_vae_path): # Try alternative naming pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth') - + if os.path.exists(pth_vae_path): - logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}") - from .vae_model_io import convert_pytorch_to_mlx - convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16)) + logging.info(f"Converting VAE PyTorch model to safetensors: {pth_vae_path}") + from .model_converter import convert_pickle_to_safetensors + convert_pickle_to_safetensors(pth_vae_path, vae_path) else: raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}") logging.info("Loading VAE...") + # self.vae = Wan2_1_VAE(vae_pth=vae_path, dtype=self.param_dtype) self.vae = Wan2_1_VAE(vae_pth=vae_path) - + # Load low and high noise models logging.info(f"Creating WanModel from {checkpoint_dir}") - - # Helper function to load model with automatic conversion - def load_model_with_conversion(checkpoint_dir, subfolder, config, param_dtype): - """Load model with automatic PyTorch to MLX conversion if needed.""" - - # Look for existing MLX files - mlx_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model_mlx.safetensors") - mlx_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_mlx_model*.safetensors") - mlx_files = glob.glob(mlx_pattern) - - # If no MLX files, convert PyTorch files - if not os.path.exists(mlx_single) and not mlx_files: - pytorch_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model.safetensors") - pytorch_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model-*.safetensors") - pytorch_files = glob.glob(pytorch_pattern) - - if os.path.exists(pytorch_single): - logging.info(f"Converting PyTorch model to MLX: {pytorch_single}") - convert_wan_2_2_safetensors_to_mlx( - pytorch_single, - mlx_single, - float16=(param_dtype == mx.float16) - ) - elif pytorch_files: - logging.info(f"Converting {len(pytorch_files)} PyTorch files to MLX") - convert_multiple_wan_2_2_safetensors_to_mlx( - os.path.join(checkpoint_dir, subfolder), - float16=(param_dtype == mx.float16) - ) - mlx_files = glob.glob(mlx_pattern) # Update file list - else: - raise FileNotFoundError(f"No model files found in {os.path.join(checkpoint_dir, subfolder)}") - + + # Helper function to load model + def load_model(checkpoint_dir, subfolder, config, param_dtype): + # Look for existing safetensors single file + safetensors_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model.safetensors") + # Create model model = WanModel( model_type='t2v', @@ -159,25 +120,25 @@ def load_model_with_conversion(checkpoint_dir, subfolder, config, param_dtype): cross_attn_norm=getattr(config, 'cross_attn_norm', True), eps=getattr(config, 'eps', 1e-6) ) - + # Load weights - if os.path.exists(mlx_single): - logging.info(f"Loading single MLX file: {mlx_single}") - model = 
@@ -186,14 +147,14 @@ def __init__(
         self.low_noise_model = self._configure_model(self.low_noise_model, convert_model_dtype)
 
         logging.info("Loading high noise model")
-        self.high_noise_model = load_model_with_conversion(
+        self.high_noise_model = load_model(
             checkpoint_dir,
             config.high_noise_checkpoint,
             self.config,
             self.param_dtype
         )
         self.high_noise_model = self._configure_model(self.high_noise_model, convert_model_dtype)
-        
+
         self.sp_size = 1  # No sequence parallel in single device
 
         self.sample_neg_prompt = config.sample_neg_prompt
@@ -212,12 +173,12 @@ def _configure_model(self, model: nn.Module, convert_model_dtype: bool) -> nn.Module:
             The configured model.
         """
         model.eval()
-        
+
         if convert_model_dtype:
            # In MLX, we would need to manually convert parameters
            # This would be implemented in the actual model class
            pass
-        
+
         return model
 
     def _prepare_model_for_timestep(self, t, boundary, offload_model):
@@ -230,7 +191,7 @@ def _prepare_model_for_timestep(self, t, boundary, offload_model):
         else:
             required_model_name = 'low_noise_model'
             offload_model_name = 'high_noise_model'
-        
+
         # MLX doesn't need the CPU offloading logic, just return the right model
         return getattr(self, required_model_name)
 
@@ -279,7 +240,7 @@ def generate(
         # Preprocess
         guide_scale = (guide_scale, guide_scale) if isinstance(
             guide_scale, float) else guide_scale
-        
+
         F = frame_num
         target_shape = (
             self.vae.model.z_dim,
            (F - 1) // self.vae_stride[0] + 1,
             size[1] // self.vae_stride[1],
             size[0] // self.vae_stride[2]
         )
-        
+
         seq_len = math.ceil(
             (target_shape[2] * target_shape[3]) /
             (self.patch_size[1] * self.patch_size[2]) *
             target_shape[1] / self.sp_size
         ) * self.sp_size
-        
+
         if n_prompt == "":
             n_prompt = self.sample_neg_prompt
-        
+
         # Set random seed
         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
         mx.random.seed(seed)
-        
+
         # Encode text prompts
         context = self.text_encoder([input_prompt])
         context_null = self.text_encoder([n_prompt])
-        
+
         # Generate initial noise
         noise = [
             mx.random.normal(
                 shape=target_shape,
                 dtype=mx.float32
             )
         ]
-        
+
         # Set boundary
         boundary = self.boundary * self.num_train_timesteps
-        
+
         # Initialize scheduler
         if sample_solver == 'unipc':
             sample_scheduler = FlowUniPCMultistepScheduler(
@@ -340,26 +301,26 @@ def generate(
             )
         else:
             raise NotImplementedError("Unsupported solver.")
-        
+
         # Sample videos
         latents = noise
-        
+
         arg_c = {'context': context, 'seq_len': seq_len}
         arg_null = {'context': context_null, 'seq_len': seq_len}
 
         mx.eval(latents)
-        
+
         # Denoising loop
         for _, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = mx.array([t])
-            
+
             # Select model based on timestep
             model = self._prepare_model_for_timestep(
                 t, boundary, offload_model
             )
             sample_guide_scale = guide_scale[1] if t.item() >= boundary else guide_scale[0]
-            
+
             # Model predictions
             noise_pred_cond = model(
                 latent_model_input, t=timestep, **arg_c
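Reviewer note: a worked check of the seq_len computed earlier in generate(), assuming Wan's usual strides (vae_stride=(4, 8, 8), patch_size=(1, 2, 2)) and a 1280x720, 81-frame request; the constants are assumptions, not values read from this diff:

    F, H, W = 81, 720, 1280
    lat_t = (F - 1) // 4 + 1                        # 21 latent frames
    lat_h, lat_w = H // 8, W // 8                   # 90 x 160 latent pixels
    tokens_per_frame = (lat_h // 2) * (lat_w // 2)  # 45 * 80 = 3600 patches
    seq_len = tokens_per_frame * lat_t              # 75600 tokens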
@@ -370,13 +331,13 @@ def generate(
                 latent_model_input, t=timestep, **arg_null
             )[0]
             mx.eval(noise_pred_uncond)  # Force evaluation
-            
+
             # Classifier-free guidance
             noise_pred = noise_pred_uncond + sample_guide_scale * (
                 noise_pred_cond - noise_pred_uncond
             )
             mx.eval(noise_pred)  # Force evaluation
-            
+
             # Scheduler step
             temp_x0 = sample_scheduler.step(
                 mx.expand_dims(noise_pred, axis=0),
@@ -387,15 +348,15 @@ def generate(
 
             latents = [mx.squeeze(temp_x0, axis=0)]
             mx.eval(latents)
-        
+
         # Decode final latents
         x0 = latents
         videos = self.vae.decode(x0)
-        
+
         # Cleanup
         del noise, latents
         del sample_scheduler
         if offload_model:
             gc.collect()
-        
-        return videos[0]
\ No newline at end of file
+
+        return videos[0]
diff --git a/video/Wan2.2/wan/utils/fm_solvers.py b/video/Wan2.2/wan/utils/fm_solvers.py
index 9fe67a3c6..e1be4ced5 100644
--- a/video/Wan2.2/wan/utils/fm_solvers.py
+++ b/video/Wan2.2/wan/utils/fm_solvers.py
@@ -2,11 +2,10 @@
 from typing import List, Optional, Tuple, Union
 
 import mlx.core as mx
-import numpy as np
 
 
 def get_sampling_sigmas(sampling_steps, shift):
-    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
+    sigma = mx.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
     sigma = (shift * sigma / (1 + (shift - 1) * sigma))
 
     return sigma
@@ -48,9 +47,9 @@ class FlowDPMSolverMultistepScheduler:
     MLX implementation of FlowDPMSolverMultistepScheduler.
     A fast dedicated high-order solver for diffusion ODEs.
     """
-    
+
     order = 1
-    
+
     def __init__(
         self,
         num_train_timesteps: int = 1000,
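Reviewer note: the shift transform used in get_sampling_sigmas and in the schedulers, sigma' = shift * sigma / (1 + (shift - 1) * sigma), biases sampling toward high noise when shift > 1. A quick numeric check:

    shift, sigma = 5.0, 0.5
    shifted = shift * sigma / (1 + (shift - 1) * sigma)  # 2.5 / 3.0 ā‰ˆ 0.833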
""" - + order = 1 - + def __init__( self, num_train_timesteps: int = 1000, @@ -89,52 +88,52 @@ def __init__( 'variance_type': variance_type, 'invert_sigmas': invert_sigmas, } - + # Validate algorithm type if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": self.config['algorithm_type'] = "dpmsolver++" else: raise NotImplementedError(f"{algorithm_type} is not implemented") - + # Validate solver type if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.config['solver_type'] = "midpoint" else: raise NotImplementedError(f"{solver_type} is not implemented") - + # Initialize scheduling self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + alphas = mx.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1] sigmas = 1.0 - alphas sigmas = mx.array(sigmas, dtype=mx.float32) - + if not use_dynamic_shifting: sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - + self.sigmas = sigmas self.timesteps = sigmas * num_train_timesteps - + self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None self._begin_index = None - + self.sigma_min = float(self.sigmas[-1]) self.sigma_max = float(self.sigmas[0]) - + @property def step_index(self): return self._step_index - + @property def begin_index(self): return self._begin_index - + def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index - + def set_timesteps( self, num_inference_steps: Union[int, None] = None, @@ -148,17 +147,17 @@ def set_timesteps( raise ValueError( "you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" ) - + if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] - + sigmas = mx.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1)[:-1] + if self.config['use_dynamic_shifting']: sigmas = self.time_shift(mu, 1.0, sigmas) else: if shift is None: shift = self.config['shift'] sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - + if self.config['final_sigmas_type'] == "sigma_min": sigma_last = self.sigma_min elif self.config['final_sigmas_type'] == "zero": @@ -167,31 +166,31 @@ def set_timesteps( raise ValueError( f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}" ) - + timesteps = sigmas * self.config['num_train_timesteps'] - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - + sigmas = mx.concatenate([sigmas, mx.array([sigma_last])]).astype(mx.float32) + self.sigmas = mx.array(sigmas) self.timesteps = mx.array(timesteps, dtype=mx.int64) - + self.num_inference_steps = len(timesteps) - + self.model_outputs = [None] * self.config['solver_order'] self.lower_order_nums = 0 - + self._step_index = None self._begin_index = None - + def _threshold_sample(self, sample: mx.array) -> mx.array: """Dynamic thresholding method.""" dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape - + # Flatten sample for quantile calculation - sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - + sample_flat = sample.reshape(batch_size, channels * mx.prod(remaining_dims)) + abs_sample = mx.abs(sample_flat) - + # Compute quantile s = mx.quantile( abs_sample, @@ -200,22 +199,22 @@ def _threshold_sample(self, sample: mx.array) -> mx.array: keepdims=True ) s = mx.clip(s, 1, self.config['sample_max_value']) - + # Threshold and 
@@ -200,22 +199,22 @@
         # Threshold and normalize
         sample_flat = mx.clip(sample_flat, -s, s) / s
-        
+
         sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
         return sample.astype(dtype)
-    
+
     def _sigma_to_t(self, sigma):
         return sigma * self.config['num_train_timesteps']
-    
+
     def _sigma_to_alpha_sigma_t(self, sigma):
         return 1 - sigma, sigma
-    
+
     def time_shift(self, mu: float, sigma: float, t: mx.array):
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
-    
+
     def convert_model_output(
         self,
         model_output: mx.array,
@@ -233,12 +232,12 @@ def convert_model_output(
                     f"prediction_type given as {self.config['prediction_type']} must be "
                     f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                 )
-            
+
             if self.config['thresholding']:
                 x0_pred = self._threshold_sample(x0_pred)
-            
+
             return x0_pred
-        
+
         # DPM-Solver needs to solve an integral of the noise prediction model
         elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
             if self.config['prediction_type'] == "flow_prediction":
@@ -249,15 +248,15 @@ def convert_model_output(
                     f"prediction_type given as {self.config['prediction_type']} must be "
                     f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                 )
-            
+
             if self.config['thresholding']:
                 sigma_t = self.sigmas[self.step_index]
                 x0_pred = sample - sigma_t * model_output
                 x0_pred = self._threshold_sample(x0_pred)
                 epsilon = model_output + x0_pred
-            
+
             return epsilon
-    
+
     def dpm_solver_first_order_update(
         self,
         model_output: mx.array,
@@ -269,12 +268,12 @@ def dpm_solver_first_order_update(
         sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
-        
+
         lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
         lambda_s = mx.log(alpha_s) - mx.log(sigma_s)
-        
+
         h = lambda_t - lambda_s
-        
+
         if self.config['algorithm_type'] == "dpmsolver++":
             x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
         elif self.config['algorithm_type'] == "dpmsolver":
@@ -293,9 +292,9 @@ def dpm_solver_first_order_update(
                 2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
                 sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
             )
-        
+
         return x_t
-    
+
     def multistep_dpm_solver_second_order_update(
         self,
         model_output_list: List[mx.array],
@@ -309,21 +308,21 @@ def multistep_dpm_solver_second_order_update(
             self.sigmas[self.step_index],
             self.sigmas[self.step_index - 1],
         )
-        
+
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
         alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
-        
+
         lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
         lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
         lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
-        
+
         m0, m1 = model_output_list[-1], model_output_list[-2]
-        
+
         h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
         r0 = h_0 / h
         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
-        
+
         if self.config['algorithm_type'] == "dpmsolver++":
             if self.config['solver_type'] == "midpoint":
                 x_t = (
@@ -382,9 +381,9 @@ def multistep_dpm_solver_second_order_update(
                     2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
                     sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                 )
-        
+
         return x_t
-    
+
     def multistep_dpm_solver_third_order_update(
         self,
         model_output_list: List[mx.array],
@@ -398,26 +397,26 @@ def multistep_dpm_solver_third_order_update(
             self.sigmas[self.step_index - 1],
             self.sigmas[self.step_index - 2],
         )
-        
+
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
         alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
self._sigma_to_alpha_sigma_t(sigma_s1) alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) - + lambda_t = mx.log(alpha_t) - mx.log(sigma_t) lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0) lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1) lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2) - + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - + if self.config['algorithm_type'] == "dpmsolver++": x_t = ( (sigma_t / sigma_s0) * sample - @@ -432,25 +431,27 @@ def multistep_dpm_solver_third_order_update( (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) - + return x_t - + def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps - - indices = mx.where(schedule_timesteps == timestep)[0] + + arr = schedule_timesteps.tolist() + indices = [i for i, v in enumerate(arr) if v == timestep] + pos = 1 if len(indices) > 1 else 0 - - return int(indices[pos]) - + + return indices[pos] + def _init_step_index(self, timestep): """Initialize the step_index counter for the scheduler.""" if self.begin_index is None: self._step_index = self.index_for_timestep(timestep) else: self._step_index = self._begin_index - + def step( self, model_output: mx.array, @@ -465,10 +466,10 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - + if self.step_index is None: self._init_step_index(timestep) - + # Improve numerical stability for small number of steps lower_order_final = ( (self.step_index == len(self.timesteps) - 1) and @@ -481,15 +482,15 @@ def step( self.config['lower_order_final'] and len(self.timesteps) < 15 ) - + model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config['solver_order'] - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - + # Upcast to avoid precision issues sample = sample.astype(mx.float32) - + # Generate noise if needed for SDE variants if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: noise = mx.random.normal(model_output.shape, dtype=mx.float32) @@ -497,7 +498,7 @@ def step( noise = variance_noise.astype(mx.float32) else: noise = None - + if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.dpm_solver_first_order_update( model_output, sample=sample, noise=noise @@ -510,25 +511,25 @@ def step( prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, sample=sample ) - + if self.lower_order_nums < self.config['solver_order']: self.lower_order_nums += 1 - + # Cast sample back to expected dtype prev_sample = prev_sample.astype(model_output.dtype) - + # Increase step index self._step_index += 1 - + if not return_dict: return (prev_sample,) - + return SchedulerOutput(prev_sample=prev_sample) - + def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array: """Scale model input - no scaling needed for this scheduler.""" return sample - + def add_noise( self, original_samples: mx.array, @@ -538,7 +539,7 @@ def add_noise( """Add noise to original samples.""" sigmas = 
self.sigmas.astype(original_samples.dtype) schedule_timesteps = self.timesteps - + # Get step indices if self.begin_index is None: step_indices = [ @@ -549,14 +550,14 @@ def add_noise( step_indices = [self.step_index] * timesteps.shape[0] else: step_indices = [self.begin_index] * timesteps.shape[0] - + sigma = sigmas[step_indices] while len(sigma.shape) < len(original_samples.shape): sigma = mx.expand_dims(sigma, -1) - + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - + def __len__(self): - return self.config['num_train_timesteps'] \ No newline at end of file + return self.config['num_train_timesteps'] diff --git a/video/Wan2.2/wan/utils/fm_solvers_unipc.py b/video/Wan2.2/wan/utils/fm_solvers_unipc.py index 39d59c829..74e9f7243 100644 --- a/video/Wan2.2/wan/utils/fm_solvers_unipc.py +++ b/video/Wan2.2/wan/utils/fm_solvers_unipc.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple, Union import mlx.core as mx -import numpy as np class SchedulerOutput: @@ -16,9 +15,9 @@ class FlowUniPCMultistepScheduler: MLX implementation of UniPCMultistepScheduler. A training-free framework designed for the fast sampling of diffusion models. """ - + order = 1 - + def __init__( self, num_train_timesteps: int = 1000, @@ -57,7 +56,7 @@ def __init__( 'steps_offset': steps_offset, 'final_sigmas_type': final_sigmas_type, } - + # Validate solver type if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: @@ -66,20 +65,20 @@ def __init__( raise NotImplementedError( f"{solver_type} is not implemented for {self.__class__}" ) - + self.predict_x0 = predict_x0 # setable values self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + alphas = mx.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1] sigmas = 1.0 - alphas sigmas = mx.array(sigmas, dtype=mx.float32) - + if not use_dynamic_shifting: sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - + self.sigmas = sigmas self.timesteps = sigmas * num_train_timesteps - + self.model_outputs = [None] * solver_order self.timestep_list = [None] * solver_order self.lower_order_nums = 0 @@ -88,24 +87,24 @@ def __init__( self.last_sample = None self._step_index = None self._begin_index = None - + self.sigma_min = float(self.sigmas[-1]) self.sigma_max = float(self.sigmas[0]) - + @property def step_index(self): """The index counter for current timestep.""" return self._step_index - + @property def begin_index(self): """The index for the first timestep.""" return self._begin_index - + def set_begin_index(self, begin_index: int = 0): """Sets the begin index for the scheduler.""" self._begin_index = begin_index - + def set_timesteps( self, num_inference_steps: Union[int, None] = None, @@ -119,17 +118,17 @@ def set_timesteps( raise ValueError( "you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" ) - + if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] - + sigmas = mx.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1)[:-1] + if self.config['use_dynamic_shifting']: sigmas = self.time_shift(mu, 1.0, sigmas) else: if shift is None: shift = self.config['shift'] sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - + if self.config['final_sigmas_type'] == "sigma_min": sigma_last = self.sigma_min elif self.config['final_sigmas_type'] == "zero": @@ -138,35 +137,35 @@ def set_timesteps( 
raise ValueError( f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}" ) - + timesteps = sigmas * self.config['num_train_timesteps'] - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - + sigmas = mx.concatenate([sigmas, mx.array([sigma_last])]).astype(mx.float32) + self.sigmas = mx.array(sigmas) self.timesteps = mx.array(timesteps, dtype=mx.int64) - + self.num_inference_steps = len(timesteps) - + self.model_outputs = [None] * self.config['solver_order'] self.lower_order_nums = 0 self.last_sample = None if self.solver_p: self.solver_p.set_timesteps(self.num_inference_steps, device=device) - + # add an index counter for schedulers self._step_index = None self._begin_index = None - + def _threshold_sample(self, sample: mx.array) -> mx.array: """Dynamic thresholding method.""" dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape - + # Flatten sample for quantile calculation - sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - + sample_flat = sample.reshape(batch_size, channels * math.prod(remaining_dims)) + abs_sample = mx.abs(sample_flat) - + # Compute quantile s = mx.quantile( abs_sample, @@ -175,22 +174,22 @@ def _threshold_sample(self, sample: mx.array) -> mx.array: keepdims=True ) s = mx.clip(s, 1, self.config['sample_max_value']) - + # Threshold and normalize sample_flat = mx.clip(sample_flat, -s, s) / s - + sample = sample_flat.reshape(batch_size, channels, *remaining_dims) return sample.astype(dtype) - + def _sigma_to_t(self, sigma): return sigma * self.config['num_train_timesteps'] - + def _sigma_to_alpha_sigma_t(self, sigma): return 1 - sigma, sigma - + def time_shift(self, mu: float, sigma: float, t: mx.array): return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) - + def convert_model_output( self, model_output: mx.array, @@ -200,7 +199,7 @@ def convert_model_output( """Convert the model output to the corresponding type the UniPC algorithm needs.""" sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - + if self.predict_x0: if self.config['prediction_type'] == "flow_prediction": sigma_t = self.sigmas[self.step_index] @@ -210,10 +209,10 @@ def convert_model_output( f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' " f"for the UniPCMultistepScheduler." ) - + if self.config['thresholding']: x0_pred = self._threshold_sample(x0_pred) - + return x0_pred else: if self.config['prediction_type'] == "flow_prediction": @@ -224,15 +223,15 @@ def convert_model_output( f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' " f"for the UniPCMultistepScheduler." 
) - + if self.config['thresholding']: sigma_t = self.sigmas[self.step_index] x0_pred = sample - sigma_t * model_output x0_pred = self._threshold_sample(x0_pred) epsilon = model_output + x0_pred - + return epsilon - + def multistep_uni_p_bh_update( self, model_output: mx.array, @@ -242,24 +241,24 @@ def multistep_uni_p_bh_update( ) -> mx.array: """One step for the UniP (B(h) version).""" model_output_list = self.model_outputs - + s0 = self.timestep_list[-1] m0 = model_output_list[-1] x = sample - + if self.solver_p: x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t - + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - + lambda_t = mx.log(alpha_t) - mx.log(sigma_t) lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0) - + h = lambda_t - lambda_s0 - + rks = [] D1s = [] for i in range(1, order): @@ -270,35 +269,35 @@ def multistep_uni_p_bh_update( rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) - + rks.append(1.0) rks = mx.array(rks) - + R = [] b = [] - + hh = -h if self.predict_x0 else h h_phi_1 = mx.exp(hh) - 1 # h\phi_1(h) = e^h - 1 h_phi_k = h_phi_1 / hh - 1 - + factorial_i = 1 - + if self.config['solver_type'] == "bh1": B_h = hh elif self.config['solver_type'] == "bh2": B_h = mx.exp(hh) - 1 else: raise NotImplementedError() - + for i in range(1, order + 1): R.append(mx.power(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i - + R = mx.stack(R) b = mx.array(b) - + if len(D1s) > 0: D1s = mx.stack(D1s, axis=1) # (B, K) # for order 2, we use a simplified version @@ -308,7 +307,7 @@ def multistep_uni_p_bh_update( rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1], stream=mx.cpu).astype(x.dtype) else: D1s = None - + if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: @@ -323,10 +322,10 @@ def multistep_uni_p_bh_update( else: pred_res = 0 x_t = x_t_ - sigma_t * B_h * pred_res - + x_t = x_t.astype(x.dtype) return x_t - + def multistep_uni_c_bh_update( self, this_model_output: mx.array, @@ -337,21 +336,21 @@ def multistep_uni_c_bh_update( ) -> mx.array: """One step for the UniC (B(h) version).""" model_output_list = self.model_outputs - + m0 = model_output_list[-1] x = last_sample x_t = this_sample model_t = this_model_output - + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - + lambda_t = mx.log(alpha_t) - mx.log(sigma_t) lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0) - + h = lambda_t - lambda_s0 - + rks = [] D1s = [] for i in range(1, order): @@ -362,46 +361,46 @@ def multistep_uni_c_bh_update( rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) - + rks.append(1.0) rks = mx.array(rks) - + R = [] b = [] - + hh = -h if self.predict_x0 else h h_phi_1 = mx.exp(hh) - 1 h_phi_k = h_phi_1 / hh - 1 - + factorial_i = 1 - + if self.config['solver_type'] == "bh1": B_h = hh elif self.config['solver_type'] == "bh2": B_h = mx.exp(hh) - 1 else: raise NotImplementedError() - + for i in range(1, order + 1): R.append(mx.power(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i - + R = mx.stack(R) b = mx.array(b) - + if len(D1s) > 0: D1s = mx.stack(D1s, axis=1) else: D1s = None - + # for order 
1, we use a simplified version if order == 1: rhos_c = mx.array([0.5], dtype=x.dtype) else: rhos_c = mx.linalg.solve(R, b, stream=mx.cpu).astype(x.dtype) - + if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: @@ -418,27 +417,27 @@ def multistep_uni_c_bh_update( corr_res = 0 D1_t = model_t - m0 x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) - + x_t = x_t.astype(x.dtype) return x_t - + def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps condition = schedule_timesteps == timestep indices = mx.argmax(condition.astype(mx.int32)) - + # Convert scalar to int and return return int(indices) - + def _init_step_index(self, timestep): """Initialize the step_index counter for the scheduler.""" if self.begin_index is None: self._step_index = self.index_for_timestep(timestep) else: self._step_index = self._begin_index - + def step( self, model_output: mx.array, @@ -452,16 +451,16 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - + if self.step_index is None: self._init_step_index(timestep) - + use_corrector = ( self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None ) - + model_output_convert = self.convert_model_output( model_output, sample=sample ) @@ -472,14 +471,14 @@ def step( this_sample=sample, order=self.this_order, ) - + for i in range(self.config['solver_order'] - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.timestep_list[i] = self.timestep_list[i + 1] - + self.model_outputs[-1] = model_output_convert self.timestep_list[-1] = timestep - + if self.config['lower_order_final']: this_order = min( self.config['solver_order'], @@ -487,32 +486,32 @@ def step( ) else: this_order = self.config['solver_order'] - + self.this_order = min(this_order, self.lower_order_nums + 1) assert self.this_order > 0 - + self.last_sample = sample prev_sample = self.multistep_uni_p_bh_update( model_output=model_output, sample=sample, order=self.this_order, ) - + if self.lower_order_nums < self.config['solver_order']: self.lower_order_nums += 1 - + # Increase step index self._step_index += 1 - + if not return_dict: return (prev_sample,) - + return SchedulerOutput(prev_sample=prev_sample) - + def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array: """Scale model input - no scaling needed for this scheduler.""" return sample - + def add_noise( self, original_samples: mx.array, @@ -522,7 +521,7 @@ def add_noise( """Add noise to original samples.""" sigmas = self.sigmas.astype(original_samples.dtype) schedule_timesteps = self.timesteps - + # Get step indices if self.begin_index is None: step_indices = [ @@ -533,14 +532,14 @@ def add_noise( step_indices = [self.step_index] * timesteps.shape[0] else: step_indices = [self.begin_index] * timesteps.shape[0] - + sigma = sigmas[step_indices] while len(sigma.shape) < len(original_samples.shape): sigma = mx.expand_dims(sigma, -1) - + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - + def __len__(self): - return self.config['num_train_timesteps'] \ No newline at end of file + return self.config['num_train_timesteps'] diff --git a/video/Wan2.2/wan/utils/qwen_vl_utils.py b/video/Wan2.2/wan/utils/qwen_vl_utils.py index bf0e83286..b4c17bd2d 100644 --- a/video/Wan2.2/wan/utils/qwen_vl_utils.py +++ 
b/video/Wan2.2/wan/utils/qwen_vl_utils.py @@ -15,8 +15,8 @@ import requests import torch import torchvision -from packaging import version from PIL import Image +from packaging import version from torchvision import io, transforms from torchvision.transforms import InterpolationMode @@ -343,8 +343,7 @@ def extract_vision_info( def process_vision_info( conversations: list[dict] | list[list[dict]], -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | - None]: +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: vision_infos = extract_vision_info(conversations) ## Read images or videos image_inputs = [] diff --git a/video/Wan2.2/wan/utils/utils.py b/video/Wan2.2/wan/utils/utils.py index 987245e91..3dc4e6d38 100644 --- a/video/Wan2.2/wan/utils/utils.py +++ b/video/Wan2.2/wan/utils/utils.py @@ -26,35 +26,35 @@ def make_grid(tensor, nrow=8, normalize=True, value_range=(-1, 1)): """MLX equivalent of torchvision.utils.make_grid""" # tensor shape: (batch, channels, height, width) batch_size, channels, height, width = tensor.shape - + # Calculate grid dimensions ncol = nrow nrow_actual = (batch_size + ncol - 1) // ncol - + # Create grid grid_height = height * nrow_actual + (nrow_actual - 1) * 2 # 2 pixel padding grid_width = width * ncol + (ncol - 1) * 2 - + # Initialize grid with zeros grid = mx.zeros((channels, grid_height, grid_width)) - + # Fill grid for idx in range(batch_size): row = idx // ncol col = idx % ncol - + y_start = row * (height + 2) y_end = y_start + height x_start = col * (width + 2) x_end = x_start + width - + img = tensor[idx] if normalize: # Normalize to [0, 1] img = (img - value_range[0]) / (value_range[1] - value_range[0]) - + grid[:, y_start:y_end, x_start:x_end] = img - + return grid @@ -73,7 +73,7 @@ def save_video(tensor, try: # preprocess tensor = mx.clip(tensor, value_range[0], value_range[1]) - + # tensor shape: (batch, channels, frames, height, width) # Process each frame frames = [] @@ -81,11 +81,11 @@ def save_video(tensor, frame = tensor[:, :, frame_idx, :, :] # (batch, channels, height, width) grid = make_grid(frame, nrow=nrow, normalize=normalize, value_range=value_range) frames.append(grid) - + # Stack frames and convert to (frames, height, width, channels) tensor = mx.stack(frames, axis=0) # (frames, channels, height, width) tensor = mx.transpose(tensor, [0, 2, 3, 1]) # (frames, height, width, channels) - + # Convert to uint8 tensor = (tensor * 255).astype(mx.uint8) tensor_np = np.array(tensor) @@ -112,14 +112,14 @@ def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)): try: # Clip values tensor = mx.clip(tensor, value_range[0], value_range[1]) - + # Make grid grid = make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range) - + # Convert to (height, width, channels) and uint8 grid = mx.transpose(grid, [1, 2, 0]) # (height, width, channels) grid = (grid * 255).astype(mx.uint8) - + # Save using imageio imageio.imwrite(save_file, np.array(grid)) return save_file @@ -157,13 +157,13 @@ def str2bool(v): def masks_like(tensor, zero=False, generator=None, p=0.2): """ Generate masks similar to input tensors. 
- + Args: tensor: List of MLX arrays zero: Whether to apply zero masking generator: Random generator (for MLX, we use mx.random.seed instead) p: Probability for random masking - + Returns: Tuple of two lists of masks """ @@ -197,14 +197,14 @@ def masks_like(tensor, zero=False, generator=None, p=0.2): def best_output_size(w, h, dw, dh, expected_area): """ Calculate the best output size given constraints. - + Args: w: Width h: Height dw: Width divisor dh: Height divisor expected_area: Target area - + Returns: Tuple of (output_width, output_height) """ @@ -230,4 +230,4 @@ def best_output_size(w, h, dw, dh, expected_area): ratio2 / ratio): return ow1, oh1 else: - return ow2, oh2 \ No newline at end of file + return ow2, oh2 diff --git a/video/Wan2.2/wan/vae_model_io.py b/video/Wan2.2/wan/vae_model_io.py index 1c77c0fe0..cb579408a 100644 --- a/video/Wan2.2/wan/vae_model_io.py +++ b/video/Wan2.2/wan/vae_model_io.py @@ -1,69 +1,42 @@ -import torch import mlx.core as mx -import numpy as np -from typing import Dict, Tuple -from safetensors import safe_open - -def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = False): - """ - Convert PyTorch VAE weights to MLX format with correct mapping. - """ - print(f"Converting {pytorch_path} -> {output_path}") - dtype = mx.float16 if float16 else mx.float32 - - # Load PyTorch weights - if pytorch_path.endswith('.safetensors'): - weights = {} - with safe_open(pytorch_path, framework="pt", device="cpu") as f: - for key in f.keys(): - tensor = f.get_tensor(key) - if tensor.dtype == torch.bfloat16: - tensor = tensor.float() - weights[key] = tensor.numpy() - else: - checkpoint = torch.load(pytorch_path, map_location='cpu') - weights = {} - state_dict = checkpoint if isinstance(checkpoint, dict) and 'state_dict' not in checkpoint else checkpoint.get('state_dict', checkpoint) - for key, tensor in state_dict.items(): - if tensor.dtype == torch.bfloat16: - tensor = tensor.float() - weights[key] = tensor.numpy() - + + +def sanitize(weights): # Convert weights - mlx_weights = {} - + new_weights = {} + for key, value in weights.items(): # Skip these if any(skip in key for skip in ["num_batches_tracked", "running_mean", "running_var"]): continue - + # Convert weight formats - if value.ndim == 5 and "weight" in key: # Conv3d weights + elif value.ndim == 5 and "weight" in key: # Conv3d weights # PyTorch: (out_channels, in_channels, D, H, W) # MLX: (out_channels, D, H, W, in_channels) - value = np.transpose(value, (0, 2, 3, 4, 1)) + value = mx.transpose(value, (0, 2, 3, 4, 1)) elif value.ndim == 4 and "weight" in key: # Conv2d weights # PyTorch: (out_channels, in_channels, H, W) # MLX Conv2d expects: (out_channels, H, W, in_channels) - value = np.transpose(value, (0, 2, 3, 1)) + value = mx.transpose(value, (0, 2, 3, 1)) elif value.ndim == 1 and "bias" in key: # Conv biases # Keep as is - MLX uses same format pass - + # Map the key new_key = key - + # Map residual block internals within Sequential # PyTorch: encoder.downsamples.0.residual.0.gamma # MLX: encoder.downsamples.layers.0.residual.layers.0.gamma import re - + # Add .layers to Sequential modules new_key = re.sub(r'\.downsamples\.(\d+)', r'.downsamples.layers.\1', new_key) new_key = re.sub(r'\.upsamples\.(\d+)', r'.upsamples.layers.\1', new_key) new_key = re.sub(r'\.middle\.(\d+)', r'.middle.layers.\1', new_key) new_key = re.sub(r'\.head\.(\d+)', r'.head.layers.\1', new_key) - + # Map residual Sequential internals if ".residual." 
in new_key: match = re.search(r'\.residual\.(\d+)\.', new_key) @@ -85,7 +58,7 @@ def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = new_key = re.sub(r'\.residual\.5\.', '.residual.layers.5.', new_key) elif idx == 6: # Second Conv3d new_key = re.sub(r'\.residual\.6\.', '.residual.layers.6.', new_key) - + # Map resample internals if ".resample." in new_key: # In both Encoder and Decoder Resample blocks, the Conv2d is at index 1 @@ -93,15 +66,15 @@ def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = # We just need to map PyTorch's .1 to MLX's .layers.1 if ".resample.1." in new_key: new_key = new_key.replace(".resample.1.", ".resample.layers.1.") - + # The layers at index 0 (ZeroPad2d, Upsample) have no weights, so we can # safely skip any keys associated with them. if ".resample.0." in key: continue - + # Map head internals (already using Sequential in MLX) # Just need to handle the layers index - + # Handle shortcut layers if ".shortcut." in new_key and "Identity" not in key: # Shortcut Conv3d layers - keep as is @@ -109,52 +82,35 @@ def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = elif "Identity" in key: # Skip Identity modules continue - + # Handle time_conv in Resample if "time_conv" in new_key: # Keep as is - already correctly named pass - + # Handle attention layers if "to_qkv" in new_key or "proj" in new_key: # Keep as is - already correctly named pass - - # In the conversion script + + # In the conversion script if "gamma" in new_key: # Squeeze gamma from (C, 1, 1) or (C, 1, 1, 1) to just (C,) - value = np.squeeze(value) # This removes all dimensions of size 1 + value = mx.squeeze(value) # This removes all dimensions of size 1 # Result will always be 1D array of shape (C,) - - # Add to MLX weights - mlx_weights[new_key] = mx.array(value).astype(dtype) - + + # Add to new weights + new_weights[new_key] = value + # Verify critical layers are present critical_prefixes = [ "encoder.conv1", "decoder.conv1", "conv1", "conv2", "encoder.head.layers.2", "decoder.head.layers.2" # Updated for Sequential ] - + for prefix in critical_prefixes: - found = any(k.startswith(prefix) for k in mlx_weights.keys()) + found = any(k.startswith(prefix) for k in new_weights.keys()) if not found: print(f"WARNING: No weights found for {prefix}") - - print(f"Converted {len(mlx_weights)} parameters") - - # Print a few example keys for verification - print("\nExample converted keys:") - for i, key in enumerate(sorted(mlx_weights.keys())[:10]): - print(f" {key}") - - # Save - if output_path.endswith('.safetensors'): - mx.save_safetensors(output_path, mlx_weights) - else: - mx.savez(output_path, **mlx_weights) - - print(f"\nSaved to {output_path}") - print("\nAll converted keys:") - for key in sorted(mlx_weights.keys()): - print(f" {key}: {mlx_weights[key].shape}") - return mlx_weights \ No newline at end of file + + return new_weights diff --git a/video/Wan2.2/wan/wan_model_io.py b/video/Wan2.2/wan/wan_model_io.py index 62be6b31e..8f78e65c7 100644 --- a/video/Wan2.2/wan/wan_model_io.py +++ b/video/Wan2.2/wan/wan_model_io.py @@ -1,84 +1,88 @@ # wan_model_io.py -from typing import List, Tuple, Set, Dict +import glob import os +from typing import Tuple, Set, Dict + import mlx.core as mx from mlx.utils import tree_unflatten, tree_flatten -from safetensors import safe_open -import torch -import numpy as np -import glob -def map_wan_2_2_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]: - """Map PyTorch WAN 2.2 weights to 
MLX format.""" - +def sanitize(weights): + """Sanitize PyTorch WAN 2.2 weights to adapt MLX format.""" + # Only add .layers to Sequential WITHIN components, not to blocks themselves # blocks.N stays as blocks.N (not blocks.layers.N) - + # Handle Sequential layers - PyTorch uses .0, .1, .2, MLX uses .layers.0, .layers.1, .layers.2 # Only for components INSIDE blocks and top-level modules - if ".ffn." in key and not ".layers." in key: - # Replace .ffn.0 with .ffn.layers.0, etc. - key = key.replace(".ffn.0.", ".ffn.layers.0.") - key = key.replace(".ffn.1.", ".ffn.layers.1.") - key = key.replace(".ffn.2.", ".ffn.layers.2.") - - if "text_embedding." in key and not ".layers." in key: - for i in range(10): - key = key.replace(f"text_embedding.{i}.", f"text_embedding.layers.{i}.") - - if "time_embedding." in key and not ".layers." in key: - for i in range(10): - key = key.replace(f"time_embedding.{i}.", f"time_embedding.layers.{i}.") - - if "time_projection." in key and not ".layers." in key: - for i in range(10): - key = key.replace(f"time_projection.{i}.", f"time_projection.layers.{i}.") - - # Handle conv transpose for patch_embedding - if "patch_embedding.weight" in key: - # PyTorch Conv3d: (out_channels, in_channels, D, H, W) - # MLX Conv3d: (out_channels, D, H, W, in_channels) - value = mx.transpose(value, (0, 2, 3, 4, 1)) - - return [(key, value)] + + new_weights = {} + + for key, weight in weights.items(): + if ".ffn." in key and not ".layers." in key: + # Replace .ffn.0 with .ffn.layers.0, etc. + key = key.replace(".ffn.0.", ".ffn.layers.0.") + key = key.replace(".ffn.1.", ".ffn.layers.1.") + key = key.replace(".ffn.2.", ".ffn.layers.2.") + + elif "text_embedding." in key and not ".layers." in key: + for i in range(10): + key = key.replace(f"text_embedding.{i}.", f"text_embedding.layers.{i}.") + + elif "time_embedding." in key and not ".layers." in key: + for i in range(10): + key = key.replace(f"time_embedding.{i}.", f"time_embedding.layers.{i}.") + + elif "time_projection." in key and not ".layers." in key: + for i in range(10): + key = key.replace(f"time_projection.{i}.", f"time_projection.layers.{i}.") + + # Handle conv transpose for patch_embedding + elif "patch_embedding.weight" in key: + # PyTorch Conv3d: (out_channels, in_channels, D, H, W) + # MLX Conv3d: (out_channels, D, H, W, in_channels) + weight = mx.transpose(weight, (0, 2, 3, 4, 1)) + + new_weights[key] = weight + + return new_weights def check_parameter_mismatch(model, weights: Dict[str, mx.array]) -> Tuple[Set[str], Set[str]]: """ Check for parameter mismatches between model and weights. - + Returns: (model_only, weights_only): Sets of parameter names that exist only in model or weights """ # Get all parameter names from model model_params = dict(tree_flatten(model.parameters())) model_keys = set(model_params.keys()) - + # Remove computed buffers that aren't loaded from weights computed_buffers = {'freqs'} # Add any other computed buffers here model_keys = model_keys - computed_buffers - + # Get all parameter names from weights weight_keys = set(weights.keys()) - + # Find differences model_only = model_keys - weight_keys weights_only = weight_keys - model_keys - + return model_only, weights_only def load_wan_2_2_from_safetensors( - safetensors_path: str, + safetensors_path: str, model, - float16: bool = False, + dtype=mx.float32, check_mismatch: bool = True ): """ Load WAN 2.2 Model weights from safetensors file(s) into MLX model. 
- + Args: safetensors_path: Path to safetensors file or directory model: MLX model instance @@ -86,209 +90,77 @@ def load_wan_2_2_from_safetensors( check_mismatch: Whether to check for parameter mismatches """ if os.path.isdir(safetensors_path): - # Multiple files (14B model) - only diffusion_mlx_model files - pattern = os.path.join(safetensors_path, "diffusion_mlx_model*.safetensors") + # Multiple files (14B model) - only diffusion_pytorch_model files + pattern = os.path.join(safetensors_path, "diffusion_pytorch_model*.safetensors") safetensor_files = sorted(glob.glob(pattern)) - print(f"Found {len(safetensor_files)} diffusion_mlx_model safetensors files") - + print(f"Found {len(safetensor_files)} diffusion_pytorch_model safetensors files") + # Load all files and merge weights all_weights = {} for file_path in safetensor_files: print(f"Loading: {file_path}") weights = mx.load(file_path) all_weights.update(weights) - + + for key, weight in all_weights.items(): + if weight.dtype != dtype and mx.issubdtype(weight.dtype, mx.floating): + all_weights[key] = weight.astype(dtype) + + all_weights = sanitize(all_weights) + if check_mismatch: model_only, weights_only = check_parameter_mismatch(model, all_weights) - + if model_only: print(f"\nāš ļø WARNING: {len(model_only)} parameters in model but NOT in weights:") for param in sorted(model_only)[:10]: # Show first 10 print(f" - {param}") if len(model_only) > 10: print(f" ... and {len(model_only) - 10} more") - + if weights_only: print(f"\nāš ļø WARNING: {len(weights_only)} parameters in weights but NOT in model:") for param in sorted(weights_only)[:10]: # Show first 10 print(f" - {param}") if len(weights_only) > 10: print(f" ... and {len(weights_only) - 10} more") - + if not model_only and not weights_only: print("\nāœ… Perfect match: All parameters align between model and weights!") - + model.update(tree_unflatten(list(all_weights.items()))) else: # Single file print(f"Loading single file: {safetensors_path}") weights = mx.load(safetensors_path) - + + for key, weight in weights.items(): + if weight.dtype != dtype and mx.issubdtype(weight.dtype, mx.floating): + weights[key] = weight.astype(dtype) + + weights = sanitize(weights) + if check_mismatch: model_only, weights_only = check_parameter_mismatch(model, weights) - + if model_only: print(f"\nāš ļø WARNING: {len(model_only)} parameters in model but NOT in weights:") for param in sorted(model_only)[:10]: # Show first 10 print(f" - {param}") if len(model_only) > 10: print(f" ... and {len(model_only) - 10} more") - + if weights_only: print(f"\nāš ļø WARNING: {len(weights_only)} parameters in weights but NOT in model:") for param in sorted(weights_only)[:10]: # Show first 10 print(f" - {param}") if len(weights_only) > 10: print(f" ... and {len(weights_only) - 10} more") - + if not model_only and not weights_only: print("\nāœ… Perfect match: All parameters align between model and weights!") - + model.update(tree_unflatten(list(weights.items()))) - + print("\nWAN 2.2 Model weights loaded successfully!") return model - - -def convert_wan_2_2_safetensors_to_mlx( - safetensors_path: str, - output_path: str, - float16: bool = False, - model=None # Optional: provide model instance to check parameter alignment -): - """ - Convert WAN 2.2 PyTorch safetensors file to MLX weights file. 
- - Args: - safetensors_path: Input safetensors file - output_path: Output MLX weights file (.safetensors) - float16: Whether to use float16 precision - model: Optional MLX model instance to check parameter alignment - """ - dtype = mx.float16 if float16 else mx.float32 - - print(f"Converting WAN 2.2 safetensors to MLX format...") - print(f"Input: {safetensors_path}") - print(f"Output: {output_path}") - print(f"Target dtype: {dtype}") - - # Load and convert weights - weights = {} - bfloat16_count = 0 - - with safe_open(safetensors_path, framework="pt", device="cpu") as f: - keys = list(f.keys()) - print(f"Processing {len(keys)} parameters...") - - for key in keys: - tensor = f.get_tensor(key) - - # Handle BFloat16 - if tensor.dtype == torch.bfloat16: - bfloat16_count += 1 - tensor = tensor.float() # Convert to float32 first - - value = mx.array(tensor.numpy()).astype(dtype) - - # Apply mapping - mapped = map_wan_2_2_weights(key, value) - - for new_key, new_value in mapped: - weights[new_key] = new_value - - if bfloat16_count > 0: - print(f"āš ļø Converted {bfloat16_count} BFloat16 tensors to {dtype}") - - # Check parameter alignment if model provided - if model is not None: - print("\nChecking parameter alignment with model...") - model_only, weights_only = check_parameter_mismatch(model, weights) - - if model_only: - print(f"\nāš ļø WARNING: {len(model_only)} parameters in model but NOT in converted weights:") - for param in sorted(model_only)[:10]: - print(f" - {param}") - if len(model_only) > 10: - print(f" ... and {len(model_only) - 10} more") - - if weights_only: - print(f"\nāš ļø WARNING: {len(weights_only)} parameters in converted weights but NOT in model:") - for param in sorted(weights_only)[:10]: - print(f" - {param}") - if len(weights_only) > 10: - print(f" ... and {len(weights_only) - 10} more") - - if not model_only and not weights_only: - print("\nāœ… Perfect match: All parameters align between model and converted weights!") - - # Save as MLX format - print(f"\nSaving {len(weights)} parameters to: {output_path}") - mx.save_safetensors(output_path, weights) - - # Print a few example keys for verification - print("\nExample converted keys:") - for i, key in enumerate(sorted(weights.keys())[:10]): - print(f" {key}: {weights[key].shape}") - - return weights - - -def convert_multiple_wan_2_2_safetensors_to_mlx( - checkpoint_dir: str, - float16: bool = False -): - """Convert multiple WAN 2.2 PyTorch safetensors files to MLX format.""" - # Find all PyTorch model files - pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors") - pytorch_files = sorted(glob.glob(pytorch_pattern)) - - if not pytorch_files: - raise FileNotFoundError(f"No PyTorch model files found matching: {pytorch_pattern}") - - print(f"Converting {len(pytorch_files)} PyTorch files to MLX format...") - - for i, pytorch_file in enumerate(pytorch_files, 1): - # Extract the suffix (e.g., "00001-of-00006") - basename = os.path.basename(pytorch_file) - suffix = basename.replace("diffusion_pytorch_model-", "").replace(".safetensors", "") - - # Create MLX filename - mlx_file = os.path.join(checkpoint_dir, f"diffusion_mlx_model-{suffix}.safetensors") - - print(f"\nConverting {i}/{len(pytorch_files)}: {basename}") - convert_wan_2_2_safetensors_to_mlx(pytorch_file, mlx_file, float16) - - print("\nAll files converted successfully!") - - -def debug_wan_2_2_weight_mapping(safetensors_path: str, float16: bool = False): - """ - Debug function to see how WAN 2.2 weights are being mapped. 
- """ - dtype = mx.float16 if float16 else mx.float32 - - print("=== WAN 2.2 Weight Mapping Debug ===") - - with safe_open(safetensors_path, framework="pt", device="cpu") as f: - # Check first 30 keys to see the mapping - for i, key in enumerate(f.keys()): - if i >= 30: - break - - tensor = f.get_tensor(key) - - # Handle BFloat16 - original_dtype = tensor.dtype - if tensor.dtype == torch.bfloat16: - tensor = tensor.float() - - value = mx.array(tensor.numpy()).astype(dtype) - - # Apply mapping - mapped = map_wan_2_2_weights(key, value) - - new_key, new_value = mapped[0] - if new_key == key: - print(f"UNCHANGED: {key} [{tensor.shape}]") - else: - print(f"MAPPED: {key} -> {new_key} [{tensor.shape}]") \ No newline at end of file