diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index 4e2baa4..985ac87 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -1,36 +1,33 @@ from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig -from mlx_video.models.wan import WanModel, WanModelConfig # Audio VAE components from mlx_video.models.ltx_2.audio_vae import ( AudioDecoder, AudioEncoder, + AudioLatentShape, + AudioPatchifier, + PerChannelStatistics, Vocoder, decode_audio, - AudioPatchifier, - AudioLatentShape, - PerChannelStatistics, ) # Conditioning -from mlx_video.models.ltx_2.conditioning import ( - VideoConditionByLatentIndex, -) +from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex # Utilities from mlx_video.models.ltx_2.utils import ( convert_audio_encoder, get_model_path, - load_safetensors, load_config, + load_safetensors, save_weights, ) +from mlx_video.models.wan import WanModel, WanModelConfig __all__ = [ # Models "LTXModel", "LTXModelConfig", - # Audio VAE "AudioDecoder", "AudioEncoder", diff --git a/mlx_video/lora/__init__.py b/mlx_video/lora/__init__.py index 4c0d81b..c4398e3 100644 --- a/mlx_video/lora/__init__.py +++ b/mlx_video/lora/__init__.py @@ -6,10 +6,7 @@ from mlx_video.lora.apply import ( apply_loras_to_model, apply_loras_to_weights, ) -from mlx_video.lora.loader import ( - load_lora_weights, - load_multiple_loras, -) +from mlx_video.lora.loader import load_lora_weights, load_multiple_loras from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights __all__ = [ diff --git a/mlx_video/lora/apply.py b/mlx_video/lora/apply.py index 97b694e..dadb62d 100644 --- a/mlx_video/lora/apply.py +++ b/mlx_video/lora/apply.py @@ -66,7 +66,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str: candidates = [lora_key] for prefix in prefixes_to_strip: if lora_key.startswith(prefix): - candidates.append(lora_key[len(prefix):]) + candidates.append(lora_key[len(prefix) :]) for candidate in candidates: # Try as-is @@ -80,33 +80,36 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str: transformed = transformed.replace(".ffn.0.", ".ffn.fc1.") transformed = transformed.replace(".ffn.2.", ".ffn.fc2.") if transformed.endswith(".ffn.0"): - transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1" + transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1" if transformed.endswith(".ffn.2"): - transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2" + transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2" # Text embedding: text_embedding.0 → text_embedding_0 transformed = transformed.replace("text_embedding.0.", "text_embedding_0.") transformed = transformed.replace("text_embedding.2.", "text_embedding_1.") if transformed.endswith("text_embedding.0"): - transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0" + transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0" if transformed.endswith("text_embedding.2"): - transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1" + transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1" # Time embedding: time_embedding.0 → time_embedding_0 transformed = transformed.replace("time_embedding.0.", "time_embedding_0.") transformed = transformed.replace("time_embedding.2.", "time_embedding_1.") if transformed.endswith("time_embedding.0"): - transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0" + transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0" if transformed.endswith("time_embedding.2"): - transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1" + transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1" # Time projection: time_projection.1 → time_projection transformed = transformed.replace("time_projection.1.", "time_projection.") if transformed.endswith("time_projection.1"): - transformed = transformed[:-len("time_projection.1")] + "time_projection" + transformed = transformed[: -len("time_projection.1")] + "time_projection" # Patch embedding: patch_embedding → patch_embedding_proj - if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed: + if ( + "patch_embedding" in transformed + and "patch_embedding_proj" not in transformed + ): transformed = transformed.replace("patch_embedding", "patch_embedding_proj") if f"{transformed}.weight" in model_keys or transformed in model_keys: @@ -115,7 +118,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str: # Return best attempt with prefix stripped for prefix in prefixes_to_strip: if lora_key.startswith(prefix): - return lora_key[len(prefix):] + return lora_key[len(prefix) :] return lora_key @@ -134,21 +137,25 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str: for prefix in prefixes_to_strip: if lora_key.startswith(prefix): - normalized = lora_key[len(prefix):] + normalized = lora_key[len(prefix) :] if f"{normalized}.weight" in model_keys or normalized in model_keys: return normalized transformed = normalized if transformed.endswith(".to_out.0"): - transformed = transformed[:-len(".to_out.0")] + ".to_out" + transformed = transformed[: -len(".to_out.0")] + ".to_out" transformed = transformed.replace(".to_out.0.", ".to_out.") transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.") transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in") transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.") transformed = transformed.replace(".ff.net.2", ".ff.proj_out") - transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") - transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in") + transformed = transformed.replace( + ".audio_ff.net.0.proj.", ".audio_ff.proj_in." + ) + transformed = transformed.replace( + ".audio_ff.net.0.proj", ".audio_ff.proj_in" + ) transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out") @@ -158,7 +165,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str: # Try transformations on the original key transformed = lora_key if transformed.endswith(".to_out.0"): - transformed = transformed[:-len(".to_out.0")] + ".to_out" + transformed = transformed[: -len(".to_out.0")] + ".to_out" transformed = transformed.replace(".to_out.0.", ".to_out.") transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.") transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in") @@ -170,7 +177,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str: for prefix in prefixes_to_strip: if lora_key.startswith(prefix): - return lora_key[len(prefix):] + return lora_key[len(prefix) :] return lora_key @@ -226,7 +233,9 @@ def apply_loras_to_weights( skipped_count += 1 skipped_modules.append(module_name) if verbose and skipped_count <= 5: - print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND") + print( + f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND" + ) similar = [ k for k in list(model_keys)[:1000] @@ -251,13 +260,21 @@ def apply_loras_to_weights( if is_quantized: scales = modified_weights[scales_key] biases = modified_weights[biases_key] - group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits) + group_size = (original_weight.shape[-1] * 32) // ( + scales.shape[-1] * quantization_bits + ) dequantized = mx.dequantize( - original_weight, scales, biases, group_size=group_size, bits=quantization_bits + original_weight, + scales, + biases, + group_size=group_size, + bits=quantization_bits, ) modified = apply_lora_to_linear(dequantized, loras) # Re-quantize with same parameters - new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits) + new_w, new_scales, new_biases = mx.quantize( + modified, group_size=group_size, bits=quantization_bits + ) modified_weights[weight_key] = new_w modified_weights[scales_key] = new_scales modified_weights[biases_key] = new_biases @@ -346,9 +363,15 @@ def apply_loras_to_model( parent = model try: for part in parts[:-1]: - parent = getattr(parent, part) if not part.isdigit() else parent[int(part)] + parent = ( + getattr(parent, part) if not part.isdigit() else parent[int(part)] + ) leaf_name = parts[-1] - target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)] + target = ( + getattr(parent, leaf_name) + if not leaf_name.isdigit() + else parent[int(leaf_name)] + ) except (AttributeError, IndexError, TypeError): skipped.append(lora_key) if verbose: @@ -358,8 +381,11 @@ def apply_loras_to_model( if isinstance(target, nn.QuantizedLinear): # Dequantize → merge LoRA → replace with bf16 Linear weight = mx.dequantize( - target.weight, target.scales, target.biases, - group_size=target.group_size, bits=target.bits, + target.weight, + target.scales, + target.biases, + group_size=target.group_size, + bits=target.bits, ) merged = apply_lora_to_linear(weight, loras) new_linear = nn.Linear(merged.shape[1], merged.shape[0]) @@ -379,7 +405,9 @@ def apply_loras_to_model( else: skipped.append(lora_key) if verbose: - print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear") + print( + f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear" + ) continue if applied_count > 0: diff --git a/mlx_video/lora/loader.py b/mlx_video/lora/loader.py index adf11b1..2a44aca 100644 --- a/mlx_video/lora/loader.py +++ b/mlx_video/lora/loader.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List import mlx.core as mx diff --git a/mlx_video/models/__init__.py b/mlx_video/models/__init__.py index e591cba..4c49754 100644 --- a/mlx_video/models/__init__.py +++ b/mlx_video/models/__init__.py @@ -1,3 +1,2 @@ - from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig from mlx_video.models.wan import WanModel, WanModelConfig diff --git a/mlx_video/models/ltx_2/__init__.py b/mlx_video/models/ltx_2/__init__.py index 7e58251..f382326 100644 --- a/mlx_video/models/ltx_2/__init__.py +++ b/mlx_video/models/ltx_2/__init__.py @@ -1,8 +1,7 @@ - +from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio from mlx_video.models.ltx_2.config import ( LTXModelConfig, - TransformerConfig, LTXModelType, + TransformerConfig, ) from mlx_video.models.ltx_2.ltx import LTXModel, X0Model -from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio diff --git a/mlx_video/models/ltx_2/adaln.py b/mlx_video/models/ltx_2/adaln.py index fee57c1..6d61129 100644 --- a/mlx_video/models/ltx_2/adaln.py +++ b/mlx_video/models/ltx_2/adaln.py @@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding class AdaLayerNormSingle(nn.Module): - def __init__( self, embedding_dim: int, @@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module): ) self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + self.linear = nn.Linear( + embedding_dim, embedding_coefficient * embedding_dim, bias=True + ) def __call__( self, @@ -56,15 +57,19 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): use_additional_conditions: bool = False, timestep_proj_dim: int = 256, ): - + super().__init__() self.embedding_dim = embedding_dim self.size_emb_dim = size_emb_dim self.use_additional_conditions = use_additional_conditions - self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim) + self.time_proj = Timesteps( + timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.timestep_embedder = TimestepEmbedding( + timestep_proj_dim, embedding_dim, out_dim=embedding_dim + ) if use_additional_conditions and size_emb_dim > 0: self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim) @@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): # Add additional conditions if enabled if self.use_additional_conditions and self.size_emb_dim > 0: if resolution is not None and aspect_ratio is not None: - additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype) + additional_embeds = self.additional_embedder( + resolution, aspect_ratio, hidden_dtype + ) timesteps_emb = timesteps_emb + additional_embeds return timesteps_emb diff --git a/mlx_video/models/ltx_2/audio_vae/__init__.py b/mlx_video/models/ltx_2/audio_vae/__init__.py index 79a1679..59509e0 100644 --- a/mlx_video/models/ltx_2/audio_vae/__init__.py +++ b/mlx_video/models/ltx_2/audio_vae/__init__.py @@ -1,10 +1,10 @@ """Audio VAE module for LTX-2 audio generation.""" -from .attention import AttentionType, AttnBlock, make_attn -from .audio_vae import AudioDecoder, AudioEncoder, decode_audio -from .audio_processor import load_audio, ensure_stereo, waveform_to_mel -from .causal_conv_2d import CausalConv2d, make_conv2d from ..config import CausalityAxis +from .attention import AttentionType, AttnBlock, make_attn +from .audio_processor import ensure_stereo, load_audio, waveform_to_mel +from .audio_vae import AudioDecoder, AudioEncoder, decode_audio +from .causal_conv_2d import CausalConv2d, make_conv2d from .downsample import Downsample, build_downsampling_path from .normalization import NormType, PixelNorm, build_normalization_layer from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics diff --git a/mlx_video/models/ltx_2/audio_vae/attention.py b/mlx_video/models/ltx_2/audio_vae/attention.py index 38c5744..a4f868f 100644 --- a/mlx_video/models/ltx_2/audio_vae/attention.py +++ b/mlx_video/models/ltx_2/audio_vae/attention.py @@ -32,7 +32,9 @@ class AttnBlock(nn.Module): self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def __call__(self, x: mx.array) -> mx.array: """ @@ -103,6 +105,8 @@ def make_attn( elif attn_type == AttentionType.NONE: return Identity() elif attn_type == AttentionType.LINEAR: - raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + raise NotImplementedError( + f"Attention type {attn_type.value} is not supported yet." + ) else: raise ValueError(f"Unknown attention type: {attn_type}") diff --git a/mlx_video/models/ltx_2/audio_vae/audio_processor.py b/mlx_video/models/ltx_2/audio_vae/audio_processor.py index ed5ff7a..915575f 100644 --- a/mlx_video/models/ltx_2/audio_vae/audio_processor.py +++ b/mlx_video/models/ltx_2/audio_vae/audio_processor.py @@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog using librosa for macOS/MLX compatibility. """ -from pathlib import Path -import numpy as np import mlx.core as mx +import numpy as np def load_audio( @@ -99,14 +98,16 @@ def waveform_to_mel( for ch in range(channels): # Magnitude spectrogram (power=1.0) - S = np.abs(librosa.stft( - waveform[ch], - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - center=True, - pad_mode="reflect", - )) + S = np.abs( + librosa.stft( + waveform[ch], + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=True, + pad_mode="reflect", + ) + ) # Mel filterbank with slaney normalization mel_basis = librosa.filters.mel( diff --git a/mlx_video/models/ltx_2/audio_vae/audio_vae.py b/mlx_video/models/ltx_2/audio_vae/audio_vae.py index e9954ed..415c222 100644 --- a/mlx_video/models/ltx_2/audio_vae/audio_vae.py +++ b/mlx_video/models/ltx_2/audio_vae/audio_vae.py @@ -1,15 +1,15 @@ """Audio VAE encoder and decoder for LTX-2.""" -from typing import Dict from pathlib import Path +from typing import Dict import mlx.core as mx import mlx.nn as nn from mlx_vlm.models.base import check_array_shape -from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig + +from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis from .attention import AttentionType, make_attn from .causal_conv_2d import make_conv2d -from ..config import CausalityAxis from .downsample import build_downsampling_path from .normalization import NormType, build_normalization_layer from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics @@ -39,7 +39,9 @@ def build_mid_block( causality_axis=causality_axis, ) mid["attn_1"] = ( - make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None + make_attn(channels, attn_type=attn_type, norm_type=norm_type) + if add_attention + else None ) mid["block_2"] = ResnetBlock( in_channels=channels, @@ -93,7 +95,10 @@ class AudioEncoder(nn.Module): self.attn_type = config.attn_type self.conv_in = make_conv2d( - config.in_channels, self.ch, kernel_size=3, stride=1, + config.in_channels, + self.ch, + kernel_size=3, + stride=1, causality_axis=self.causality_axis, ) @@ -125,7 +130,10 @@ class AudioEncoder(nn.Module): self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) out_channels = 2 * config.z_channels if config.double_z else config.z_channels self.conv_out = make_conv2d( - block_in, out_channels, kernel_size=3, stride=1, + block_in, + out_channels, + kernel_size=3, + stride=1, causality_axis=self.causality_axis, ) @@ -160,7 +168,11 @@ class AudioEncoder(nn.Module): continue if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: - value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1)) + value = ( + value + if check_array_shape(value) + else mx.transpose(value, (0, 2, 3, 1)) + ) sanitized[new_key] = value return sanitized @@ -168,11 +180,14 @@ class AudioEncoder(nn.Module): @classmethod def from_pretrained(cls, model_path: Path) -> "AudioEncoder": """Load audio encoder from pretrained weights.""" - from mlx_video.models.ltx_2.config import AudioEncoderModelConfig import json + from mlx_video.models.ltx_2.config import AudioEncoderModelConfig + model_path = Path(model_path) - config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) + config = AudioEncoderModelConfig.from_dict( + json.load(open(model_path / "config.json")) + ) encoder = cls(config) weights = mx.load(str(model_path / "model.safetensors")) encoder.load_weights(list(weights.items()), strict=True) @@ -265,7 +280,6 @@ class AudioDecoder(nn.Module): """ super().__init__() - # Per-channel statistics for denormalizing latents # Uses ch (base channel count) to match the patchified latent dimension # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16) @@ -305,7 +319,11 @@ class AudioDecoder(nn.Module): self.z_shape = (1, config.z_channels, base_resolution, base_resolution) self.conv_in = make_conv2d( - config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + config.z_channels, + base_block_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, ) self.mid = build_mid_block( @@ -334,9 +352,15 @@ class AudioDecoder(nn.Module): initial_block_channels=base_block_channels, ) - self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.norm_out = build_normalization_layer( + final_block_channels, normtype=self.norm_type + ) self.conv_out = make_conv2d( - final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + final_block_channels, + config.out_ch, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, ) def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: @@ -371,7 +395,11 @@ class AudioDecoder(nn.Module): # PyTorch: (out_channels, in_channels, H, W) # MLX: (out_channels, H, W, in_channels) if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: - value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1)) + value = ( + value + if check_array_shape(value) + else mx.transpose(value, (0, 2, 3, 1)) + ) sanitized[new_key] = value @@ -380,17 +408,19 @@ class AudioDecoder(nn.Module): @classmethod def from_pretrained(cls, model_path: Path) -> "AudioDecoder": """Load audio VAE decoder from pretrained model.""" - from mlx_video.models.ltx_2.config import AudioDecoderModelConfig import json - config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) + from mlx_video.models.ltx_2.config import AudioDecoderModelConfig + + config = AudioDecoderModelConfig.from_dict( + json.load(open(model_path / "config.json")) + ) decoder = cls(config) weights = mx.load(str(model_path / "model.safetensors")) # weights = decoder.sanitize(weights) decoder.load_weights(list(weights.items()), strict=True) return decoder - def __call__(self, sample: mx.array) -> mx.array: """ Decode latent features back to audio spectrograms. @@ -414,7 +444,9 @@ class AudioDecoder(nn.Module): return self._adjust_output_shape(h, target_shape) - def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]: + def _denormalize_latents( + self, sample: mx.array + ) -> tuple[mx.array, AudioLatentShape]: """Denormalize latents using per-channel statistics.""" # sample shape: (B, H, W, C) in MLX format latent_shape = AudioLatentShape( @@ -436,7 +468,9 @@ class AudioDecoder(nn.Module): batch=latent_shape.batch, channels=self.out_ch, frames=target_frames, - mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + mel_bins=( + self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins + ), ) return sample, target_shape @@ -462,7 +496,10 @@ class AudioDecoder(nn.Module): # Step 1: Crop first to avoid exceeding target dimensions decoded_output = decoded_output[ - :, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels + :, + : min(current_time, target_time), + : min(current_freq, target_freq), + :target_channels, ] # Step 2: Calculate padding needed for time and frequency dimensions @@ -514,7 +551,9 @@ class AudioDecoder(nn.Module): return mx.tanh(h) if self.tanh_out else h -def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array: +def decode_audio( + latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder" +) -> mx.array: """ Decode an audio latent representation using the provided audio decoder and vocoder. Args: diff --git a/mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py b/mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py index b303268..4cc8233 100644 --- a/mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py +++ b/mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py @@ -53,8 +53,16 @@ class CausalConv2d(nn.Module): # For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width) if self.causality_axis == CausalityAxis.NONE: # Non-causal: symmetric padding - self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2) - elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY): + self.padding = ( + pad_h // 2, + pad_h - pad_h // 2, + pad_w // 2, + pad_w - pad_w // 2, + ) + elif self.causality_axis in ( + CausalityAxis.WIDTH, + CausalityAxis.WIDTH_COMPATIBILITY, + ): # Causal on width: pad left (before width axis) self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0) elif self.causality_axis == CausalityAxis.HEIGHT: @@ -90,7 +98,10 @@ class CausalConv2d(nn.Module): if any(p > 0 for p in self.padding): # MLX pad expects: [(before_0, after_0), (before_1, after_1), ...] # For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C - x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)]) + x = mx.pad( + x, + [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)], + ) return self.conv(x) @@ -124,7 +135,14 @@ def make_conv2d( if causality_axis is not None: # For causal convolution, padding is handled internally by CausalConv2d return CausalConv2d( - in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis + in_channels, + out_channels, + kernel_size, + stride, + dilation, + groups, + bias, + causality_axis, ) else: # For non-causal convolution, use symmetric padding if not specified diff --git a/mlx_video/models/ltx_2/audio_vae/downsample.py b/mlx_video/models/ltx_2/audio_vae/downsample.py index 8831668..f80c18e 100644 --- a/mlx_video/models/ltx_2/audio_vae/downsample.py +++ b/mlx_video/models/ltx_2/audio_vae/downsample.py @@ -5,8 +5,8 @@ from typing import Set, Tuple import mlx.core as mx import mlx.nn as nn -from .attention import AttentionType, make_attn from ..config import CausalityAxis +from .attention import AttentionType, make_attn from .normalization import NormType from .resnet import ResnetBlock @@ -34,7 +34,9 @@ class Downsample(nn.Module): if self.with_conv: # Do time downsampling here # no asymmetric padding in MLX conv, must do it ourselves - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def __call__(self, x: mx.array) -> mx.array: """ @@ -116,10 +118,14 @@ def build_downsampling_path( ) block_in = block_out if curr_res in attn_resolutions: - stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type) + stage["attn"][i_block] = make_attn( + block_in, attn_type=attn_type, norm_type=norm_type + ) if i_level != num_resolutions - 1: - stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + stage["downsample"] = Downsample( + block_in, resamp_with_conv, causality_axis=causality_axis + ) curr_res = curr_res // 2 down_modules[i_level] = stage diff --git a/mlx_video/models/ltx_2/audio_vae/normalization.py b/mlx_video/models/ltx_2/audio_vae/normalization.py index 361c6b4..8376a21 100644 --- a/mlx_video/models/ltx_2/audio_vae/normalization.py +++ b/mlx_video/models/ltx_2/audio_vae/normalization.py @@ -51,7 +51,9 @@ def build_normalization_layer( A normalization layer """ if normtype == NormType.GROUP: - return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True) + return nn.GroupNorm( + num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True + ) if normtype == NormType.PIXEL: # For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1) # PyTorch uses dim=1 for channels-first format (B, C, H, W) diff --git a/mlx_video/models/ltx_2/audio_vae/resnet.py b/mlx_video/models/ltx_2/audio_vae/resnet.py index ca20f67..c1b2ee5 100644 --- a/mlx_video/models/ltx_2/audio_vae/resnet.py +++ b/mlx_video/models/ltx_2/audio_vae/resnet.py @@ -1,12 +1,12 @@ """ResNet blocks for audio VAE and vocoder.""" -from typing import List, Tuple +from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .causal_conv_2d import make_conv2d from ..config import CausalityAxis +from .causal_conv_2d import make_conv2d from .normalization import NormType, build_normalization_layer LRELU_SLOPE = 0.1 @@ -125,7 +125,11 @@ class ResnetBlock(nn.Module): self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) self.conv1 = make_conv2d( - in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + in_channels, + out_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, ) if temb_channels > 0: @@ -134,17 +138,29 @@ class ResnetBlock(nn.Module): self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) self.dropout_rate = dropout self.conv2 = make_conv2d( - out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + out_channels, + out_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = make_conv2d( - in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + in_channels, + out_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, ) else: self.nin_shortcut = make_conv2d( - in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + in_channels, + out_channels, + kernel_size=1, + stride=1, + causality_axis=causality_axis, ) def __call__( @@ -168,7 +184,9 @@ class ResnetBlock(nn.Module): if temb is not None and self.temb_channels > 0: # temb: (B, temb_channels) -> (B, out_channels) # Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting - h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1) + h = h + mx.expand_dims( + mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1 + ) h = self.norm2(h) h = nn.silu(h) diff --git a/mlx_video/models/ltx_2/audio_vae/upsample.py b/mlx_video/models/ltx_2/audio_vae/upsample.py index 734ccab..d443049 100644 --- a/mlx_video/models/ltx_2/audio_vae/upsample.py +++ b/mlx_video/models/ltx_2/audio_vae/upsample.py @@ -5,9 +5,9 @@ from typing import Set, Tuple import mlx.core as mx import mlx.nn as nn +from ..config import CausalityAxis from .attention import AttentionType, make_attn from .causal_conv_2d import make_conv2d -from ..config import CausalityAxis from .normalization import NormType from .resnet import ResnetBlock @@ -42,7 +42,11 @@ class Upsample(nn.Module): self.causality_axis = causality_axis if self.with_conv: self.conv = make_conv2d( - in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + in_channels, + in_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, ) def __call__(self, x: mx.array) -> mx.array: @@ -124,10 +128,14 @@ def build_upsampling_path( ) block_in = block_out if curr_res in attn_resolutions: - stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type) + stage["attn"][i_block] = make_attn( + block_in, attn_type=attn_type, norm_type=norm_type + ) if level != 0: - stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + stage["upsample"] = Upsample( + block_in, resamp_with_conv, causality_axis=causality_axis + ) curr_res *= 2 up_modules[level] = stage diff --git a/mlx_video/models/ltx_2/audio_vae/vocoder.py b/mlx_video/models/ltx_2/audio_vae/vocoder.py index 71b548c..70e2722 100644 --- a/mlx_video/models/ltx_2/audio_vae/vocoder.py +++ b/mlx_video/models/ltx_2/audio_vae/vocoder.py @@ -7,8 +7,8 @@ Supports: """ import math -from typing import List, Tuple from pathlib import Path +from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -32,7 +32,9 @@ class Snake(nn.Module): def __init__(self, in_features: int, alpha_logscale: bool = True) -> None: super().__init__() self.alpha_logscale = alpha_logscale - self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + self.alpha = ( + mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + ) def __call__(self, x: mx.array) -> mx.array: # x: (N, L, C) in MLX format @@ -48,8 +50,12 @@ class SnakeBeta(nn.Module): def __init__(self, in_features: int, alpha_logscale: bool = True) -> None: super().__init__() self.alpha_logscale = alpha_logscale - self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) - self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + self.alpha = ( + mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + ) + self.beta = ( + mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + ) def __call__(self, x: mx.array) -> mx.array: alpha = self.alpha @@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array: ) -def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array: +def kaiser_sinc_filter1d( + cutoff: float, half_width: float, kernel_size: int +) -> mx.array: """Compute a Kaiser-windowed sinc filter.""" even = kernel_size % 2 == 0 half_size = kernel_size // 2 @@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> # Kaiser window - compute using scipy-compatible formula import numpy as np + window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32)) if even: @@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]: """Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler).""" import numpy as np + rolloff = 0.99 lowpass_filter_width = 6 width = math.ceil(lowpass_filter_width / rolloff) @@ -187,10 +197,16 @@ class UpSample1d(nn.Module): self.kernel_size = filt.shape[2] self.filter = filt else: - self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) self.pad = self.kernel_size // ratio - 1 - self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 - self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + self.pad_left = ( + self.pad * self.stride + (self.kernel_size - self.stride) // 2 + ) + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) self.filter = kaiser_sinc_filter1d( cutoff=0.5 / ratio, half_width=0.6 / ratio, @@ -215,10 +231,12 @@ class UpSample1d(nn.Module): filt = self.filter.astype(x.dtype) # (1, 1, K) filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1) - x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1) + x = self.ratio * mx.conv_transpose1d( + x, filt, stride=self.stride + ) # (N*C, L', 1) # Trim padding - x = x[:, self.pad_left:-self.pad_right, :] + x = x[:, self.pad_left : -self.pad_right, :] x = x.reshape(n, c, -1) # (N, C, L') x = mx.transpose(x, (0, 2, 1)) # (N, L', C) @@ -285,16 +303,24 @@ class AMPBlock1(nn.Module): self.convs1 = { i: nn.Conv1d( - channels, channels, kernel_size, stride=1, - dilation=d, padding=get_padding(kernel_size, d), + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), ) for i, d in enumerate(dilation) } self.convs2 = { i: nn.Conv1d( - channels, channels, kernel_size, stride=1, - dilation=1, padding=get_padding(kernel_size, 1), + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=get_padding(kernel_size, 1), ) for i in range(len(dilation)) } @@ -348,7 +374,9 @@ class STFTFn(nn.Module): y = mx.concatenate([first, y], axis=1) # forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX - basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1) + basis = mx.transpose( + self.forward_basis.astype(y.dtype), (0, 2, 1) + ) # (514, K, 1) # Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514) spec = mx.conv1d(y, basis, stride=self.hop_length) @@ -358,8 +386,10 @@ class STFTFn(nn.Module): real = spec[..., :n_freqs] imag = spec[..., n_freqs:] - magnitude = mx.sqrt(real ** 2 + imag ** 2) - phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype) + magnitude = mx.sqrt(real**2 + imag**2) + phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype( + real.dtype + ) # Output: (B, T_frames, n_freqs) in MLX channels-last return magnitude, phase @@ -368,7 +398,9 @@ class STFTFn(nn.Module): class MelSTFT(nn.Module): """Causal log-mel spectrogram from precomputed STFT bases.""" - def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None: + def __init__( + self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int + ) -> None: super().__init__() self.stft_fn = STFTFn(filter_length, hop_length, win_length) n_freqs = filter_length // 2 + 1 @@ -385,7 +417,9 @@ class MelSTFT(nn.Module): """ magnitude, phase = self.stft_fn(y) # magnitude: (B, T_frames, n_freqs) - mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels) + mel = ( + magnitude @ self.mel_basis.astype(magnitude.dtype).T + ) # (B, T_frames, n_mels) log_mel = mx.log(mx.clip(mel, 1e-5, None)) # Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format return mx.transpose(log_mel, (0, 2, 1)) @@ -415,8 +449,11 @@ class Vocoder(nn.Module): in_channels = 128 if config.stereo else 64 self.conv_pre = nn.Conv1d( - in_channels, config.upsample_initial_channel, - kernel_size=7, stride=1, padding=3, + in_channels, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, ) # Upsampling layers @@ -424,11 +461,13 @@ class Vocoder(nn.Module): for i, (stride, kernel_size) in enumerate( zip(config.upsample_rates, config.upsample_kernel_sizes) ): - in_ch = config.upsample_initial_channel // (2 ** i) + in_ch = config.upsample_initial_channel // (2**i) out_ch = config.upsample_initial_channel // (2 ** (i + 1)) self.ups[i] = nn.ConvTranspose1d( - in_ch, out_ch, - kernel_size=kernel_size, stride=stride, + in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, padding=(kernel_size - stride) // 2, ) @@ -442,7 +481,9 @@ class Vocoder(nn.Module): config.resblock_kernel_sizes, config.resblock_dilation_sizes ): self.resblocks[block_idx] = AMPBlock1( - ch, kernel_size, tuple(dilations), + ch, + kernel_size, + tuple(dilations), activation=config.activation, ) block_idx += 1 @@ -455,10 +496,14 @@ class Vocoder(nn.Module): for kernel_size, dilations in zip( config.resblock_kernel_sizes, config.resblock_dilation_sizes ): - self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) + self.resblocks[block_idx] = resblock_class( + ch, kernel_size, tuple(dilations) + ) block_idx += 1 - final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates)) + final_channels = config.upsample_initial_channel // ( + 2 ** len(config.upsample_rates) + ) # Post-activation if self.is_amp: @@ -468,8 +513,11 @@ class Vocoder(nn.Module): # Final conv out_channels = 2 if config.stereo else 1 self.conv_post = nn.Conv1d( - final_channels, out_channels, - kernel_size=7, stride=1, padding=3, + final_channels, + out_channels, + kernel_size=7, + stride=1, + padding=3, bias=config.use_bias_at_final, ) @@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module): """ x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate _, _, length_low_rate = x.shape - output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate + output_length = ( + length_low_rate * self.output_sampling_rate // self.input_sampling_rate + ) # Pad to hop_length multiple remainder = length_low_rate % self.hop_length @@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE: model.load_weights(list(weights.items()), strict=False) return model - - diff --git a/mlx_video/models/ltx_2/conditioning/__init__.py b/mlx_video/models/ltx_2/conditioning/__init__.py index 3f8516e..08d7c97 100644 --- a/mlx_video/models/ltx_2/conditioning/__init__.py +++ b/mlx_video/models/ltx_2/conditioning/__init__.py @@ -1,3 +1,6 @@ """Conditioning modules for LTX-2 video generation.""" -from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning +from mlx_video.models.ltx_2.conditioning.latent import ( + VideoConditionByLatentIndex, + apply_conditioning, +) diff --git a/mlx_video/models/ltx_2/conditioning/latent.py b/mlx_video/models/ltx_2/conditioning/latent.py index acf3d99..4f101b2 100644 --- a/mlx_video/models/ltx_2/conditioning/latent.py +++ b/mlx_video/models/ltx_2/conditioning/latent.py @@ -5,7 +5,7 @@ the video generation process at specific frame positions. """ from dataclasses import dataclass -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple import mlx.core as mx @@ -22,6 +22,7 @@ class VideoConditionByLatentIndex: frame_idx: Frame index to condition (0 = first frame) strength: Denoising strength (1.0 = full denoise, 0.0 = keep original) """ + latent: mx.array frame_idx: int = 0 strength: float = 1.0 @@ -41,6 +42,7 @@ class LatentState: denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where 1.0 = full denoise, 0.0 = keep clean """ + latent: mx.array clean_latent: mx.array denoise_mask: mx.array @@ -130,15 +132,15 @@ def apply_conditioning( if frame_idx <= i < end_idx: # Use conditioning latent cond_idx = i - frame_idx - latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) - clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) + latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1]) + clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1]) # Set mask: 1.0 - strength means less denoising for conditioned frames mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype)) else: # Keep original - latent_list.append(state.latent[:, :, i:i+1]) - clean_list.append(state.clean_latent[:, :, i:i+1]) - mask_list.append(state.denoise_mask[:, :, i:i+1]) + latent_list.append(state.latent[:, :, i : i + 1]) + clean_list.append(state.clean_latent[:, :, i : i + 1]) + mask_list.append(state.denoise_mask[:, :, i : i + 1]) state.latent = mx.concatenate(latent_list, axis=2) state.clean_latent = mx.concatenate(clean_list, axis=2) diff --git a/mlx_video/models/ltx_2/config.py b/mlx_video/models/ltx_2/config.py index 4692d45..3d96ab7 100644 --- a/mlx_video/models/ltx_2/config.py +++ b/mlx_video/models/ltx_2/config.py @@ -1,4 +1,3 @@ - import inspect from dataclasses import dataclass, field from enum import Enum @@ -22,9 +21,11 @@ class LTXRopeType(Enum): SPLIT = "split" TWO_D = "2d" + class AttentionType(Enum): DEFAULT = "default" + @dataclass class BaseModelConfig: @@ -46,7 +47,7 @@ class BaseModelConfig: if v is not None: if isinstance(v, Enum): result[k] = v.value - elif hasattr(v, 'to_dict'): + elif hasattr(v, "to_dict"): result[k] = v.to_dict() else: result[k] = v @@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig): out_channels: int = 128 latent_channels: int = 128 patch_size: int = 4 - encoder_blocks: List[tuple] = field(default_factory=lambda: [ - ("res_x", {"num_layers": 4}), - ("compress_space_res", {"multiplier": 2}), - ("res_x", {"num_layers": 6}), - ("compress_time_res", {"multiplier": 2}), - ("res_x", {"num_layers": 6}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}), - ]) - decoder_blocks: List[tuple] = field(default_factory=lambda: [ - ("res_x", {"num_layers": 5, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 5, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 5, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 5, "inject_noise": False}), - ]) + encoder_blocks: List[tuple] = field( + default_factory=lambda: [ + ("res_x", {"num_layers": 4}), + ("compress_space_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ] + ) + decoder_blocks: List[tuple] = field( + default_factory=lambda: [ + ("res_x", {"num_layers": 5, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 5, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 5, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 5, "inject_noise": False}), + ] + ) @dataclass @@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig): audio_in_channels: int = 128 audio_out_channels: int = 128 audio_cross_attention_dim: int = 2048 - audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video) + audio_caption_channels: int = ( + 3840 # Input dim for audio text embeddings (same as video) + ) # Positional embedding config positional_embedding_theta: float = 10000.0 @@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig): ) - class CausalityAxis(Enum): """Enum for specifying the causality axis in causal convolutions.""" @@ -237,21 +243,22 @@ class AudioDecoderModelConfig(BaseModelConfig): def __post_init__(self): """Convert string enum values to proper enum types.""" # Import here to avoid circular imports - from .audio_vae.normalization import NormType from .audio_vae.attention import AttentionType - + from .audio_vae.normalization import NormType + # Convert causality_axis string to enum if isinstance(self.causality_axis, str): self.causality_axis = CausalityAxis(self.causality_axis) - + # Convert norm_type string to enum if isinstance(self.norm_type, str): self.norm_type = NormType(self.norm_type) - + # Convert attn_type string to enum if isinstance(self.attn_type, str): self.attn_type = AttentionType(self.attn_type) + @dataclass class AudioEncoderModelConfig(BaseModelConfig): ch: int = 128 @@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig): def __post_init__(self): """Convert string enum values to proper enum types.""" - from .audio_vae.normalization import NormType from .audio_vae.attention import AttentionType + from .audio_vae.normalization import NormType if isinstance(self.causality_axis, str): self.causality_axis = CausalityAxis(self.causality_axis) @@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig): dropout: float = 0.0 timestep_conditioning: bool = False + @dataclass class VideoEncoderModelConfig(BaseModelConfig): convolution_dimensions: int = 3 @@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig): norm_layer: Enum = None latent_log_var: Enum = None encoder_spatial_padding_mode: Enum = None - encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}), - ("compress_space_res", {"multiplier": 2}), - ("res_x", {"num_layers": 6}), - ("compress_time_res", {"multiplier": 2}), - ("res_x", {"num_layers": 6}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}) - ]) + encoder_blocks: List[tuple] = field( + default_factory=lambda: [ + ("res_x", {"num_layers": 4}), + ("compress_space_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ] + ) def __post_init__(self): + from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType - from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType if self.norm_layer is None: self.norm_layer = NormLayerType.PIXEL_NORM @@ -371,10 +382,12 @@ class VideoEncoderModelConfig(BaseModelConfig): if isinstance(self.latent_log_var, str): self.latent_log_var = LogVarianceType(self.latent_log_var) if isinstance(self.encoder_spatial_padding_mode, str): - self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode) + self.encoder_spatial_padding_mode = PaddingModeType( + self.encoder_spatial_padding_mode + ) def to_dict(self) -> dict[str, Any]: result = super().to_dict() if self.encoder_blocks is not None: result["encoder_blocks"] = [list(block) for block in self.encoder_blocks] - return result \ No newline at end of file + return result diff --git a/mlx_video/models/ltx_2/convert.py b/mlx_video/models/ltx_2/convert.py index bc6d239..ffc2a65 100644 --- a/mlx_video/models/ltx_2/convert.py +++ b/mlx_video/models/ltx_2/convert.py @@ -49,7 +49,6 @@ from typing import Dict import mlx.core as mx - # ─── Key prefix routing ────────────────────────────────────────────────────── TRANSFORMER_PREFIX = "model.diffusion_model." @@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: if "audio_embeddings_connector" in key or "video_embeddings_connector" in key: continue - new_key = key[len(TRANSFORMER_PREFIX):] + new_key = key[len(TRANSFORMER_PREFIX) :] new_key = new_key.replace(".to_out.0.", ".to_out.") new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") @@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: else: continue elif key.startswith(VAE_DECODER_PREFIX): - new_key = key[len(VAE_DECODER_PREFIX):] + new_key = key[len(VAE_DECODER_PREFIX) :] else: continue @@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: if value.dtype != mx.float32: value = value.astype(mx.float32) elif key.startswith(VAE_ENCODER_PREFIX): - new_key = key[len(VAE_ENCODER_PREFIX):] + new_key = key[len(VAE_ENCODER_PREFIX) :] else: continue @@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: new_key = None if key.startswith(AUDIO_DECODER_PREFIX): - new_key = key[len(AUDIO_DECODER_PREFIX):] + new_key = key[len(AUDIO_DECODER_PREFIX) :] elif key.startswith(AUDIO_STATS_PREFIX): if "mean-of-means" in key: new_key = "per_channel_statistics.mean_of_means" @@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: new_key = None if key.startswith(AUDIO_ENCODER_PREFIX): - new_key = key[len(AUDIO_ENCODER_PREFIX):] + new_key = key[len(AUDIO_ENCODER_PREFIX) :] elif key.startswith(AUDIO_STATS_PREFIX): if "mean-of-means" in key: new_key = "per_channel_statistics.mean_of_means" @@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: if not key.startswith(VOCODER_PREFIX): continue - new_key = key[len(VOCODER_PREFIX):] + new_key = key[len(VOCODER_PREFIX) :] # Handle Conv1d/ConvTranspose1d weight shape conversion if "weight" in new_key and value.ndim == 3: @@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array # aggregate_embed weights (text_embedding_projection.*) for key, value in weights.items(): if key.startswith(TEXT_PROJ_PREFIX): - new_key = key[len(TEXT_PROJ_PREFIX):] + new_key = key[len(TEXT_PROJ_PREFIX) :] extracted[new_key] = value # video_embeddings_connector for key, value in weights.items(): if key.startswith(VIDEO_CONNECTOR_PREFIX): - suffix = key[len(VIDEO_CONNECTOR_PREFIX):] + suffix = key[len(VIDEO_CONNECTOR_PREFIX) :] new_key = "video_embeddings_connector." + sanitize_connector_key(suffix) extracted[new_key] = value # audio_embeddings_connector for key, value in weights.items(): if key.startswith(AUDIO_CONNECTOR_PREFIX): - suffix = key[len(AUDIO_CONNECTOR_PREFIX):] + suffix = key[len(AUDIO_CONNECTOR_PREFIX) :] new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix) extracted[new_key] = value @@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path): # ─── Source resolution ───────────────────────────────────────────────────────── # Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc. -MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?Pdistilled|dev)\.safetensors$") +MONOLITHIC_PATTERN = re.compile( + r"^ltx-[\d.]+-\d+b-(?Pdistilled|dev)\.safetensors$" +) # Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors, # ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc. -UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$") +UPSCALER_PATTERN = re.compile( + r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$" +) def resolve_source(source: str, variant: str) -> Path: @@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict: def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict: """Infer VAE decoder config from weights.""" # Check for timestep conditioning keys - has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights) + has_timestep = any( + "last_time_embedder" in k or "last_scale_shift_table" in k for k in weights + ) # Count channel multipliers from up_blocks max_block = -1 @@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): config = infer_transformer_config(transformer_weights) save_config(config, output_path / "transformer") t_params = sum(v.size for v in transformer_weights.values()) - print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards") + print( + f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards" + ) # 2. VAE Decoder print(" [2/7] VAE Decoder...") @@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): ] else: upscaler_files = [ - f.name for f in source_dir.iterdir() + f.name + for f in source_dir.iterdir() if f.is_file() and UPSCALER_PATTERN.match(f.name) ] @@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): print(f"\nDone! Converted {all_converted}/{total_keys} keys") if all_converted < total_keys: known_prefixes = ( - TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX, - VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX, - AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX, - VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX, + TRANSFORMER_PREFIX, + VAE_DECODER_PREFIX, + VAE_ENCODER_PREFIX, + VAE_STATS_PREFIX, + AUDIO_DECODER_PREFIX, + AUDIO_ENCODER_PREFIX, + AUDIO_STATS_PREFIX, + VOCODER_PREFIX, + TEXT_PROJ_PREFIX, + VIDEO_CONNECTOR_PREFIX, + AUDIO_CONNECTOR_PREFIX, ) - skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)] + skipped = [ + k for k in all_weights if not any(k.startswith(p) for p in known_prefixes) + ] if skipped: print(f" Skipped {len(skipped)} keys:") for k in sorted(skipped)[:20]: diff --git a/mlx_video/models/ltx_2/generate.py b/mlx_video/models/ltx_2/generate.py index 08f0840..6c3fc72 100644 --- a/mlx_video/models/ltx_2/generate.py +++ b/mlx_video/models/ltx_2/generate.py @@ -14,30 +14,46 @@ import mlx.core as mx import numpy as np from PIL import Image from rich.console import Console -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn from rich.panel import Panel +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, +) # Rich console for styled output console = Console() +from mlx_video.models.ltx_2.conditioning import ( + VideoConditionByLatentIndex, + apply_conditioning, +) +from mlx_video.models.ltx_2.conditioning.latent import LatentState, apply_denoise_mask from mlx_video.models.ltx_2.ltx import LTXModel from mlx_video.models.ltx_2.transformer import Modality - -from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path -from mlx_video.models.ltx_2.video_vae.decoder import VideoDecoder -from mlx_video.models.ltx_2.video_vae import VideoEncoder -from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig from mlx_video.models.ltx_2.upsampler import load_upsampler, upsample_latents -from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex, apply_conditioning -from mlx_video.models.ltx_2.conditioning.latent import LatentState, apply_denoise_mask +from mlx_video.models.ltx_2.video_vae import VideoEncoder +from mlx_video.models.ltx_2.video_vae.decoder import VideoDecoder +from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig +from mlx_video.utils import ( + get_model_path, + load_image, + prepare_image_for_encoding, +) class PipelineType(Enum): """Pipeline type selector.""" - DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG - DEV = "dev" # Single-stage, dynamic sigmas, CFG - DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res) + + DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG + DEV = "dev" # Single-stage, dynamic sigmas, CFG + DEV_TWO_STAGE = ( + "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res) + ) DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages @@ -56,7 +72,9 @@ AUDIO_HOP_LENGTH = 160 AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4 AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying AUDIO_MEL_BINS = 16 -AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 +AUDIO_LATENTS_PER_SECOND = ( + AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR +) # 25 # Default negative prompt for CFG (dev pipeline) # Matches PyTorch LTX-2 reference DEFAULT_NEGATIVE_PROMPT from constants.py @@ -157,7 +175,7 @@ def load_and_merge_lora( new_key = key for old, new in _LORA_KEY_REPLACEMENTS: if new_key.endswith(old): - new_key = new_key[:-len(old)] + new + new_key = new_key[: -len(old)] + new else: new_key = new_key.replace(old + ".", new + ".") sanitized_pairs[new_key] = pair @@ -197,7 +215,9 @@ def load_and_merge_lora( delta = (lora_b * strength) @ lora_a base_weight = flat_weights.pop(weight_key) - merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype) + merged_weight = (base_weight.astype(mx.float32) + delta).astype( + base_weight.dtype + ) batch.append((weight_key, merged_weight)) del base_weight merged_count += 1 @@ -259,8 +279,12 @@ def apg_delta( # Optionally clamp guidance norm for stability if norm_threshold > 0: - guidance_norm = mx.sqrt(mx.sum(guidance ** 2, axis=(-1, -2, -3), keepdims=True) + 1e-8) - scale_factor = mx.minimum(mx.ones_like(guidance_norm), norm_threshold / guidance_norm) + guidance_norm = mx.sqrt( + mx.sum(guidance**2, axis=(-1, -2, -3), keepdims=True) + 1e-8 + ) + scale_factor = mx.minimum( + mx.ones_like(guidance_norm), norm_threshold / guidance_norm + ) guidance = guidance * scale_factor # Project guidance onto cond direction @@ -270,7 +294,7 @@ def apg_delta( # Projection coefficient: (guidance · cond) / (cond · cond) dot_product = mx.sum(guidance_flat * cond_flat, axis=1, keepdims=True) - squared_norm = mx.sum(cond_flat ** 2, axis=1, keepdims=True) + 1e-8 + squared_norm = mx.sum(cond_flat**2, axis=1, keepdims=True) + 1e-8 proj_coeff = dot_product / squared_norm # Reshape back and compute parallel/orthogonal components @@ -320,7 +344,7 @@ def ltx2_scheduler( # Apply shift transformation power = 1 - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): sigmas = np.where( sigmas != 0, math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), @@ -371,10 +395,12 @@ def create_position_grid( h_coords = np.arange(0, height, patch_size_h) w_coords = np.arange(0, width, patch_size_w) - t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') + t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij") patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) - patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1) + patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape( + 3, 1, 1, 1 + ) patch_ends = patch_starts + patch_size_delta latent_coords = np.stack([patch_starts, patch_ends], axis=-1) @@ -382,14 +408,14 @@ def create_position_grid( latent_coords = latent_coords.reshape(3, num_patches, 2) latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) - scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1) + scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape( + 1, 3, 1, 1 + ) pixel_coords = (latent_coords * scale_factors).astype(np.float32) if causal_fix: pixel_coords[:, 0, :, :] = np.clip( - pixel_coords[:, 0, :, :] + 1 - temporal_scale, - a_min=0, - a_max=None + pixel_coords[:, 0, :, :] + 1 - temporal_scale, a_min=0, a_max=None ) # Divide temporal coords by fps @@ -413,6 +439,7 @@ def create_audio_position_grid( is_causal: bool = True, ) -> mx.array: """Create temporal position grid for audio RoPE.""" + def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) mel_frame = latent_frame * downsample_factor @@ -443,6 +470,7 @@ def compute_audio_frames(num_video_frames: int, fps: float) -> int: # Distilled Pipeline Denoising (no CFG, fixed sigmas) # ============================================================================= + def denoise_distilled( latents: mx.array, positions: mx.array, @@ -488,7 +516,9 @@ def denoise_distilled( b, c, f, h, w = latents.shape num_tokens = f * h * w # Cast to model dtype for transformer input - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype) + latents_flat = mx.transpose( + mx.reshape(latents, (b, c, -1)), (0, 2, 1) + ).astype(dtype) if state is not None: denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) @@ -515,8 +545,16 @@ def denoise_distilled( audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype) # A2V: frozen audio uses timesteps=0 (tells model audio is clean) - a_ts = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype) - a_sig = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype) + a_ts = ( + mx.zeros((ab, at), dtype=dtype) + if audio_frozen + else mx.full((ab, at), sigma, dtype=dtype) + ) + a_sig = ( + mx.zeros((ab,), dtype=dtype) + if audio_frozen + else mx.full((ab,), sigma, dtype=dtype) + ) audio_modality = Modality( latent=audio_flat, timesteps=a_ts, @@ -527,7 +565,9 @@ def denoise_distilled( sigma=a_sig, ) - velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) + velocity, audio_velocity = transformer( + video=video_modality, audio=audio_modality + ) mx.eval(velocity) if audio_velocity is not None: mx.eval(audio_velocity) @@ -544,10 +584,14 @@ def denoise_distilled( ab, ac, at, af = audio_latents.shape audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) - audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32) + audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype( + mx.float32 + ) if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask) + denoised = apply_denoise_mask( + denoised, state.clean_latent.astype(mx.float32), state.denoise_mask + ) mx.eval(denoised) if audio_denoised is not None: @@ -558,7 +602,10 @@ def denoise_distilled( sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32 if enable_audio and audio_denoised is not None and not audio_frozen: - audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32 + audio_latents = ( + audio_denoised + + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32 + ) else: latents = denoised if enable_audio and audio_denoised is not None and not audio_frozen: @@ -577,6 +624,7 @@ def denoise_distilled( # Dev Pipeline Denoising (with CFG, dynamic sigmas) # ============================================================================= + def denoise_dev( latents: mx.array, positions: mx.array, @@ -647,7 +695,8 @@ def denoise_dev( disable=not verbose, ) as progress: passes = ["CFG"] if use_cfg else [] - if use_stg: passes.append("STG") + if use_stg: + passes.append("STG") label = "+".join(passes) if passes else "uncond" task = progress.add_task(f"[cyan]Denoising ({label})[/]", total=num_steps) @@ -658,7 +707,9 @@ def denoise_dev( b, c, f, h, w = latents.shape num_tokens = f * h * w # Cast to model dtype for transformer input - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype) + latents_flat = mx.transpose( + mx.reshape(latents, (b, c, -1)), (0, 2, 1) + ).astype(dtype) if state is not None: denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) @@ -689,7 +740,9 @@ def denoise_dev( # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) - x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32) + x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype( + mx.float32 + ) # Start with positive prediction x0_guided_f32 = x0_pos_f32 @@ -709,29 +762,39 @@ def denoise_dev( velocity_neg, _ = transformer(video=video_modality_neg, audio=None) # Convert negative velocity to x0 using per-token timesteps - x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32) + x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype( + mx.float32 + ) # Apply guidance to x0 predictions # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 if use_apg: # APG: decompose into parallel/orthogonal components for stability x0_guided_f32 = x0_pos_f32 + apg_delta( - x0_pos_f32, x0_neg_f32, cfg_scale, - eta=apg_eta, norm_threshold=apg_norm_threshold + x0_pos_f32, + x0_neg_f32, + cfg_scale, + eta=apg_eta, + norm_threshold=apg_norm_threshold, ) else: # Standard CFG - x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) + x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * ( + x0_pos_f32 - x0_neg_f32 + ) # STG pass: skip self-attention at specified blocks if use_stg: velocity_ptb, _ = transformer( - video=video_modality_pos, audio=None, + video=video_modality_pos, + audio=None, stg_video_blocks=stg_blocks, ) mx.eval(velocity_ptb) - x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype(mx.float32) + x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype( + mx.float32 + ) x0_guided_f32 = x0_guided_f32 + stg_scale * (x0_pos_f32 - x0_ptb_f32) # Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) @@ -743,12 +806,16 @@ def denoise_dev( x0_guided_f32 = x0_guided_f32 * v_factor # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) - denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) + denoised = mx.reshape( + mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w) + ) sigma_f32 = mx.array(sigma, dtype=mx.float32) if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask) + denoised = apply_denoise_mask( + denoised, state.clean_latent.astype(mx.float32), state.denoise_mask + ) # Euler step in float32 (latents stay in float32) if sigma_next > 0: @@ -853,8 +920,10 @@ def denoise_dev_av( disable=not verbose, ) as progress: passes = ["CFG"] if use_cfg else [] - if use_stg: passes.append("STG") - if use_modality: passes.append("Mod") + if use_stg: + passes.append("STG") + if use_modality: + passes.append("Mod") label = "+".join(passes) if passes else "uncond" task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=num_steps) @@ -865,7 +934,9 @@ def denoise_dev_av( # Flatten video latents (cast to model dtype for transformer input) b, c, f, h, w = video_latents.shape num_video_tokens = f * h * w - video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)).astype(dtype) + video_flat = mx.transpose( + mx.reshape(video_latents, (b, c, -1)), (0, 2, 1) + ).astype(dtype) # Flatten audio latents (cast to model dtype for transformer input) ab, ac, at, af = audio_latents.shape @@ -874,7 +945,9 @@ def denoise_dev_av( # Compute timesteps if video_state is not None: - denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.reshape( + video_state.denoise_mask, (b, 1, f, 1, 1) + ) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat @@ -882,35 +955,67 @@ def denoise_dev_av( video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) # A2V: frozen audio uses timesteps=0 (tells model audio is clean) - audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype) + audio_timesteps = ( + mx.zeros((ab, at), dtype=dtype) + if audio_frozen + else mx.full((ab, at), sigma, dtype=dtype) + ) # Positive conditioning pass sigma_array = mx.full((b,), sigma, dtype=dtype) - audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype) + audio_sigma_array = ( + mx.zeros((ab,), dtype=dtype) + if audio_frozen + else mx.full((ab,), sigma, dtype=dtype) + ) video_modality_pos = Modality( - latent=video_flat, timesteps=video_timesteps, positions=video_positions, - context=video_embeddings_pos, context_mask=None, enabled=True, - positional_embeddings=precomputed_video_rope, sigma=sigma_array, + latent=video_flat, + timesteps=video_timesteps, + positions=video_positions, + context=video_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_video_rope, + sigma=sigma_array, ) audio_modality_pos = Modality( - latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, - context=audio_embeddings_pos, context_mask=None, enabled=True, - positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, + latent=audio_flat, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_audio_rope, + sigma=audio_sigma_array, + ) + video_vel_pos, audio_vel_pos = transformer( + video=video_modality_pos, audio=audio_modality_pos ) - video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) mx.eval(video_vel_pos, audio_vel_pos) # Convert velocity to denoised (x0) using per-token timesteps # This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity # For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant) # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity - video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) - audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af)) - video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1) - audio_timesteps_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1) + video_flat_f32 = mx.transpose( + mx.reshape(video_latents, (b, c, -1)), (0, 2, 1) + ) + audio_flat_f32 = mx.reshape( + mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af) + ) + video_timesteps_f32 = mx.expand_dims( + video_timesteps.astype(mx.float32), axis=-1 + ) + audio_timesteps_f32 = mx.expand_dims( + audio_timesteps.astype(mx.float32), axis=-1 + ) - video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32) - audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32) + video_x0_pos_f32 = ( + video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32) + ) + audio_x0_pos_f32 = ( + audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32) + ) # Start with positive prediction video_x0_guided_f32 = video_x0_pos_f32 @@ -919,57 +1024,105 @@ def denoise_dev_av( # Pass 2: CFG (negative conditioning) if use_cfg: video_modality_neg = Modality( - latent=video_flat, timesteps=video_timesteps, positions=video_positions, - context=video_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_video_rope, sigma=sigma_array, + latent=video_flat, + timesteps=video_timesteps, + positions=video_positions, + context=video_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_video_rope, + sigma=sigma_array, ) audio_modality_neg = Modality( - latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, - context=audio_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, + latent=audio_flat, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_audio_rope, + sigma=audio_sigma_array, + ) + video_vel_neg, audio_vel_neg = transformer( + video=video_modality_neg, audio=audio_modality_neg ) - video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) mx.eval(video_vel_neg, audio_vel_neg) - video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32) - audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) + video_x0_neg_f32 = ( + video_flat_f32 + - video_timesteps_f32 * video_vel_neg.astype(mx.float32) + ) + audio_x0_neg_f32 = ( + audio_flat_f32 + - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) + ) if use_apg: video_x0_guided_f32 = video_x0_pos_f32 + apg_delta( - video_x0_pos_f32, video_x0_neg_f32, cfg_scale, - eta=apg_eta, norm_threshold=apg_norm_threshold + video_x0_pos_f32, + video_x0_neg_f32, + cfg_scale, + eta=apg_eta, + norm_threshold=apg_norm_threshold, ) else: - video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) - audio_x0_guided_f32 = audio_x0_pos_f32 + (audio_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) + video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * ( + video_x0_pos_f32 - video_x0_neg_f32 + ) + audio_x0_guided_f32 = audio_x0_pos_f32 + (audio_cfg_scale - 1.0) * ( + audio_x0_pos_f32 - audio_x0_neg_f32 + ) # Pass 3: STG (self-attention perturbation at specified blocks) if use_stg: video_vel_ptb, audio_vel_ptb = transformer( - video=video_modality_pos, audio=audio_modality_pos, - stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks, + video=video_modality_pos, + audio=audio_modality_pos, + stg_video_blocks=stg_video_blocks, + stg_audio_blocks=stg_audio_blocks, ) mx.eval(video_vel_ptb, audio_vel_ptb) - video_x0_ptb_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_ptb.astype(mx.float32) - audio_x0_ptb_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_ptb.astype(mx.float32) + video_x0_ptb_f32 = ( + video_flat_f32 + - video_timesteps_f32 * video_vel_ptb.astype(mx.float32) + ) + audio_x0_ptb_f32 = ( + audio_flat_f32 + - audio_timesteps_f32 * audio_vel_ptb.astype(mx.float32) + ) - video_x0_guided_f32 = video_x0_guided_f32 + stg_scale * (video_x0_pos_f32 - video_x0_ptb_f32) - audio_x0_guided_f32 = audio_x0_guided_f32 + stg_scale * (audio_x0_pos_f32 - audio_x0_ptb_f32) + video_x0_guided_f32 = video_x0_guided_f32 + stg_scale * ( + video_x0_pos_f32 - video_x0_ptb_f32 + ) + audio_x0_guided_f32 = audio_x0_guided_f32 + stg_scale * ( + audio_x0_pos_f32 - audio_x0_ptb_f32 + ) # Pass 4: Modality isolation (skip all cross-modal attention) if use_modality: video_vel_iso, audio_vel_iso = transformer( - video=video_modality_pos, audio=audio_modality_pos, + video=video_modality_pos, + audio=audio_modality_pos, skip_cross_modal=True, ) mx.eval(video_vel_iso, audio_vel_iso) - video_x0_iso_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_iso.astype(mx.float32) - audio_x0_iso_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_iso.astype(mx.float32) + video_x0_iso_f32 = ( + video_flat_f32 + - video_timesteps_f32 * video_vel_iso.astype(mx.float32) + ) + audio_x0_iso_f32 = ( + audio_flat_f32 + - audio_timesteps_f32 * audio_vel_iso.astype(mx.float32) + ) - video_x0_guided_f32 = video_x0_guided_f32 + (modality_scale - 1.0) * (video_x0_pos_f32 - video_x0_iso_f32) - audio_x0_guided_f32 = audio_x0_guided_f32 + (modality_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_iso_f32) + video_x0_guided_f32 = video_x0_guided_f32 + (modality_scale - 1.0) * ( + video_x0_pos_f32 - video_x0_iso_f32 + ) + audio_x0_guided_f32 = audio_x0_guided_f32 + (modality_scale - 1.0) * ( + audio_x0_pos_f32 - audio_x0_iso_f32 + ) # Apply CFG rescale (std-ratio rescaling to reduce over-saturation) if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality): @@ -981,7 +1134,9 @@ def denoise_dev_av( audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) - video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) + video_denoised_f32 = mx.reshape( + mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w) + ) audio_denoised_f32 = mx.reshape(audio_x0_guided_f32, (ab, at, ac, af)) audio_denoised_f32 = mx.transpose(audio_denoised_f32, (0, 2, 1, 3)) @@ -992,7 +1147,9 @@ def denoise_dev_av( if video_state is not None: clean_f32 = video_state.clean_latent.astype(mx.float32) mask_f32 = video_state.denoise_mask.astype(mx.float32) - video_denoised_f32 = video_denoised_f32 * mask_f32 + clean_f32 * (1.0 - mask_f32) + video_denoised_f32 = video_denoised_f32 * mask_f32 + clean_f32 * ( + 1.0 - mask_f32 + ) mx.eval(video_denoised_f32, audio_denoised_f32) @@ -1005,7 +1162,9 @@ def denoise_dev_av( video_latents = video_latents + video_velocity_f32 * dt_f32 if not audio_frozen: - audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32 + audio_velocity_f32 = ( + audio_latents - audio_denoised_f32 + ) / sigma_f32 audio_latents = audio_latents + audio_velocity_f32 * dt_f32 else: video_latents = video_denoised_f32 @@ -1056,7 +1215,11 @@ def denoise_res2s_av( bongmath_max_iter: Max bong iterations per step. """ from mlx_video.models.ltx_2.rope import precompute_freqs_cis - from mlx_video.models.ltx_2.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise + from mlx_video.models.ltx_2.samplers import ( + get_new_noise, + get_res2s_coefficients, + sde_noise_step, + ) if audio_cfg_rescale is None: audio_cfg_rescale = cfg_rescale @@ -1117,7 +1280,9 @@ def denoise_res2s_av( """Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format.""" b, c, f, h, w = v_latents.shape num_video_tokens = f * h * w - video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype) + video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype( + dtype + ) ab, ac, at, af = a_latents.shape audio_flat = mx.transpose(a_latents, (0, 2, 1, 3)) @@ -1131,28 +1296,50 @@ def denoise_res2s_av( video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) - audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype) + audio_timesteps = ( + mx.zeros((ab, at), dtype=dtype) + if audio_frozen + else mx.full((ab, at), sigma, dtype=dtype) + ) sigma_array = mx.full((b,), sigma, dtype=dtype) - audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype) + audio_sigma_array = ( + mx.zeros((ab,), dtype=dtype) + if audio_frozen + else mx.full((ab,), sigma, dtype=dtype) + ) # Pass 1: Positive conditioning video_modality_pos = Modality( - latent=video_flat, timesteps=video_timesteps, positions=video_positions, - context=video_embeddings_pos, context_mask=None, enabled=True, - positional_embeddings=precomputed_video_rope, sigma=sigma_array, + latent=video_flat, + timesteps=video_timesteps, + positions=video_positions, + context=video_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_video_rope, + sigma=sigma_array, ) audio_modality_pos = Modality( - latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, - context=audio_embeddings_pos, context_mask=None, enabled=True, - positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, + latent=audio_flat, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_audio_rope, + sigma=audio_sigma_array, + ) + video_vel_pos, audio_vel_pos = transformer( + video=video_modality_pos, audio=audio_modality_pos ) - video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) mx.eval(video_vel_pos, audio_vel_pos) # Convert velocity to x0 video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)) - audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af)) + audio_flat_f32 = mx.reshape( + mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af) + ) video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1) audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1) @@ -1165,51 +1352,90 @@ def denoise_res2s_av( # Pass 2: CFG if use_cfg: video_modality_neg = Modality( - latent=video_flat, timesteps=video_timesteps, positions=video_positions, - context=video_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_video_rope, sigma=sigma_array, + latent=video_flat, + timesteps=video_timesteps, + positions=video_positions, + context=video_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_video_rope, + sigma=sigma_array, ) audio_modality_neg = Modality( - latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, - context=audio_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, + latent=audio_flat, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_audio_rope, + sigma=audio_sigma_array, + ) + video_vel_neg, audio_vel_neg = transformer( + video=video_modality_neg, audio=audio_modality_neg ) - video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) mx.eval(video_vel_neg, audio_vel_neg) - video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32) - audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32) + video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype( + mx.float32 + ) + audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype( + mx.float32 + ) - video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg) - audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg) + video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * ( + video_x0_pos - video_x0_neg + ) + audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * ( + audio_x0_pos - audio_x0_neg + ) # Pass 3: STG if use_stg: video_vel_ptb, audio_vel_ptb = transformer( - video=video_modality_pos, audio=audio_modality_pos, - stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks, + video=video_modality_pos, + audio=audio_modality_pos, + stg_video_blocks=stg_video_blocks, + stg_audio_blocks=stg_audio_blocks, ) mx.eval(video_vel_ptb, audio_vel_ptb) - video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32) - audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32) + video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype( + mx.float32 + ) + audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype( + mx.float32 + ) - video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb) - audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb) + video_x0_guided = video_x0_guided + stg_scale * ( + video_x0_pos - video_x0_ptb + ) + audio_x0_guided = audio_x0_guided + stg_scale * ( + audio_x0_pos - audio_x0_ptb + ) # Pass 4: Modality isolation if use_modality: video_vel_iso, audio_vel_iso = transformer( - video=video_modality_pos, audio=audio_modality_pos, + video=video_modality_pos, + audio=audio_modality_pos, skip_cross_modal=True, ) mx.eval(video_vel_iso, audio_vel_iso) - video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32) - audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32) + video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype( + mx.float32 + ) + audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype( + mx.float32 + ) - video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso) - audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso) + video_x0_guided = video_x0_guided + (modality_scale - 1.0) * ( + video_x0_pos - video_x0_iso + ) + audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * ( + audio_x0_pos - audio_x0_iso + ) # Rescale (separate factors for video and audio) if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality): @@ -1222,7 +1448,9 @@ def denoise_res2s_av( audio_x0_guided = audio_x0_guided * a_factor # Reshape to spatial - video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w)) + video_denoised = mx.reshape( + mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w) + ) audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af)) audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3)) @@ -1246,11 +1474,16 @@ def denoise_res2s_av( disable=not verbose, ) as progress: passes = ["res2s"] - if use_cfg: passes.append("CFG") - if use_stg: passes.append("STG") - if use_modality: passes.append("Mod") + if use_cfg: + passes.append("CFG") + if use_stg: + passes.append("STG") + if use_modality: + passes.append("Mod") label = "+".join(passes) - task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps) + task = progress.add_task( + f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps + ) for step_idx in range(n_full_steps): sigma = sigmas_list[step_idx] @@ -1289,10 +1522,14 @@ def denoise_res2s_av( substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3) substep_noise_v = get_new_noise(video_latents.shape, key1) - x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v) + x_mid_video = sde_noise_step( + x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v + ) if not audio_frozen: substep_noise_a = get_new_noise(audio_latents.shape, key2) - x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a) + x_mid_audio = sde_noise_step( + x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a + ) mx.eval(x_mid_video, x_mid_audio) # ============================================================ @@ -1314,7 +1551,9 @@ def denoise_res2s_av( # Stage 2: Evaluate denoiser at midpoint sigma # ============================================================ denoised_video_2, denoised_audio_2 = _eval_guided_denoise( - x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma + x_mid_video.astype(mx.float32), + x_mid_audio.astype(mx.float32), + sub_sigma, ) # ============================================================ @@ -1326,14 +1565,20 @@ def denoise_res2s_av( # SDE noise injection at step level step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3) step_noise_v = get_new_noise(video_latents.shape, key1) - x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v) + x_next_video = sde_noise_step( + x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v + ) video_latents = x_next_video.astype(mx.float32) if not audio_frozen: eps_2_audio = denoised_audio_2 - x_anchor_audio - x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio) + x_next_audio = x_anchor_audio + h * ( + b1 * eps_1_audio + b2 * eps_2_audio + ) step_noise_a = get_new_noise(audio_latents.shape, key2) - x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a) + x_next_audio = sde_noise_step( + x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a + ) audio_latents = x_next_audio.astype(mx.float32) mx.eval(video_latents, audio_latents) @@ -1356,6 +1601,7 @@ def denoise_res2s_av( # Audio Loading and Processing # ============================================================================= + def load_audio_decoder(model_path: Path, pipeline: PipelineType): """Load audio VAE decoder.""" from mlx_video.models.ltx_2.audio_vae import AudioDecoder @@ -1385,7 +1631,7 @@ def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RA audio = np.clip(audio, -1.0, 1.0) audio_int16 = (audio * 32767).astype(np.int16) - with wave.open(str(path), 'wb') as wf: + with wave.open(str(path), "wb") as wf: wf.setnchannels(2 if audio_int16.ndim == 2 else 1) wf.setsampwidth(2) wf.setframerate(sample_rate) @@ -1397,13 +1643,18 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): import subprocess cmd = [ - "ffmpeg", "-y", - "-i", str(video_path), - "-i", str(audio_path), - "-c:v", "copy", - "-c:a", "aac", + "ffmpeg", + "-y", + "-i", + str(video_path), + "-i", + str(audio_path), + "-c:v", + "copy", + "-c:a", + "aac", "-shortest", - str(output_path) + str(output_path), ] try: @@ -1421,6 +1672,7 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): # Unified Generate Function # ============================================================================= + def generate_video( model_repo: str, text_encoder_repo: str, @@ -1504,20 +1756,28 @@ def generate_video( start_time = time.time() # Validate dimensions - is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ) + is_two_stage = pipeline in ( + PipelineType.DISTILLED, + PipelineType.DEV_TWO_STAGE, + PipelineType.DEV_TWO_STAGE_HQ, + ) divisor = 64 if is_two_stage else 32 assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}" assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}" if num_frames % 8 != 1: adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 - console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using: {adjusted_num_frames}[/]") + console.print( + f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using: {adjusted_num_frames}[/]" + ) num_frames = adjusted_num_frames is_i2v = image is not None is_a2v = audio_file is not None if is_a2v and audio: - raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.") + raise ValueError( + "Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one." + ) # A2V implicitly enables audio path through the transformer if is_a2v: audio = True @@ -1538,25 +1798,37 @@ def generate_video( console.print(Panel(header, expand=False)) console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") - if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ): + if pipeline in ( + PipelineType.DEV, + PipelineType.DEV_TWO_STAGE, + PipelineType.DEV_TWO_STAGE_HQ, + ): audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else "" stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else "" mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else "" - console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}{audio_cfg_info}, Rescale: {cfg_rescale}{stg_info}{mod_info}[/]") + console.print( + f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}{audio_cfg_info}, Rescale: {cfg_rescale}{stg_info}{mod_info}[/]" + ) if is_i2v: - console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") + console.print( + f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]" + ) # Always compute audio frames - PyTorch distilled pipeline unconditionally # generates audio alongside video (model was trained with joint audio-video). # The --audio flag only controls whether audio is decoded and saved to output. audio_frames = compute_audio_frames(num_frames, fps) if audio: - console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]") + console.print( + f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]" + ) # Get model path model_path = get_model_path(model_repo) - text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) + text_encoder_path = ( + model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) + ) # Resolve spatial upscaler path for two-stage pipelines upscaler_path = None @@ -1564,7 +1836,11 @@ def generate_video( if is_two_stage: if spatial_upscaler is not None: # User-specified upscaler file - upscaler_path = model_path / spatial_upscaler if not Path(spatial_upscaler).is_absolute() else Path(spatial_upscaler) + upscaler_path = ( + model_path / spatial_upscaler + if not Path(spatial_upscaler).is_absolute() + else Path(spatial_upscaler) + ) if not upscaler_path.exists(): # Try as a filename within model_path upscaler_path = model_path / spatial_upscaler @@ -1575,7 +1851,9 @@ def generate_video( upscaler_scale = 2.0 else: # Auto-detect: prefer x2 upscaler - upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) + upscaler_files = sorted( + model_path.glob("*spatial-upscaler-x2*.safetensors") + ) if upscaler_files: upscaler_path = upscaler_files[0] upscaler_scale = 2.0 @@ -1595,6 +1873,7 @@ def generate_video( # Read transformer config to detect model version import json + transformer_config_path = model_path / "transformer" / "config.json" has_prompt_adaln = False if transformer_config_path.exists(): @@ -1604,6 +1883,7 @@ def generate_video( # Load text encoder with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"): from mlx_video.models.ltx_2.text_encoder import LTX2TextEncoder + text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) mx.eval(text_encoder.parameters()) @@ -1612,23 +1892,46 @@ def generate_video( # Optionally enhance the prompt if enhance_prompt: console.print("[bold magenta]✨ Enhancing prompt[/]") - prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) - console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") + prompt = text_encoder.enhance_t2v( + prompt, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, + verbose=verbose, + ) + console.print( + f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]" + ) # Encode prompts - always get audio embeddings since the model was trained # with joint audio-video processing (PyTorch unconditionally generates audio) - if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ): + if pipeline in ( + PipelineType.DEV, + PipelineType.DEV_TWO_STAGE, + PipelineType.DEV_TWO_STAGE_HQ, + ): # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG - video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) - video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) + video_embeddings_pos, audio_embeddings_pos = text_encoder( + prompt, return_audio_embeddings=True + ) + video_embeddings_neg, audio_embeddings_neg = text_encoder( + negative_prompt, return_audio_embeddings=True + ) model_dtype = video_embeddings_pos.dtype - mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) + mx.eval( + video_embeddings_pos, + video_embeddings_neg, + audio_embeddings_pos, + audio_embeddings_neg, + ) # For dev-two-stage, stage 2 uses single positive embedding (no CFG) if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ): text_embeddings = video_embeddings_pos else: # Distilled pipeline - single embedding - text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) + text_embeddings, audio_embeddings = text_encoder( + prompt, return_audio_embeddings=True + ) mx.eval(text_embeddings, audio_embeddings) model_dtype = text_embeddings.dtype @@ -1638,7 +1941,9 @@ def generate_video( # Load transformer transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..." with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"): - transformer = LTXModel.from_pretrained(model_path=model_path / "transformer", strict=True) + transformer = LTXModel.from_pretrained( + model_path=model_path / "transformer", strict=True + ) console.print("[green]✓[/] Transformer loaded") @@ -1649,7 +1954,9 @@ def generate_video( stg_blocks = [28] else: stg_blocks = [29] - console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]") + console.print( + f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]" + ) # ========================================================================== # A2V: Encode input audio to frozen latents @@ -1658,11 +1965,17 @@ def generate_video( a2v_waveform = None a2v_sr = None if is_a2v: - from mlx_video.models.ltx_2.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel - from mlx_video.models.ltx_2.utils import convert_audio_encoder from mlx_video.models.ltx_2.audio_vae import AudioEncoder + from mlx_video.models.ltx_2.audio_vae.audio_processor import ( + ensure_stereo, + load_audio, + waveform_to_mel, + ) + from mlx_video.models.ltx_2.utils import convert_audio_encoder - with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"): + with console.status( + "[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots" + ): video_duration = num_frames / fps # Load audio @@ -1677,10 +1990,18 @@ def generate_video( a2v_sr = sr # Compute mel-spectrogram - mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64) + mel = waveform_to_mel( + waveform, + sample_rate=sr, + n_fft=1024, + hop_length=AUDIO_HOP_LENGTH, + n_mels=64, + ) # Convert audio encoder weights if needed, then load - encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2") + encoder_dir = convert_audio_encoder( + model_path, source_repo="Lightricks/LTX-2" + ) audio_encoder = AudioEncoder.from_pretrained(encoder_dir) mx.eval(audio_encoder.parameters()) @@ -1698,14 +2019,19 @@ def generate_video( a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :] elif t_encoded < audio_frames: pad_size = audio_frames - t_encoded - padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype) + padding = mx.zeros( + (1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), + dtype=model_dtype, + ) a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2) mx.eval(a2v_audio_latents) del audio_encoder mx.clear_cache() - console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})") + console.print( + f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})" + ) # ========================================================================== # Pipeline-specific generation logic @@ -1720,18 +2046,30 @@ def generate_video( stage1_image_latent = None stage2_image_latent = None if is_i2v: - with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + with console.status( + "[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots" + ): + vae_encoder = VideoEncoder.from_pretrained( + model_path / "vae" / "encoder" + ) s1_h, s1_w = stage1_h * 32, stage1_w * 32 - input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype) - stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype) + input_image = load_image( + image, height=s1_h, width=s1_w, dtype=model_dtype + ) + stage1_image_tensor = prepare_image_for_encoding( + input_image, s1_h, s1_w, dtype=model_dtype + ) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) s2_h, s2_w = stage2_h * 32, stage2_w * 32 - input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype) - stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype) + input_image = load_image( + image, height=s2_h, width=s2_w, dtype=model_dtype + ) + stage2_image_tensor = prepare_image_for_encoding( + input_image, s2_h, s2_w, dtype=model_dtype + ) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) @@ -1740,7 +2078,9 @@ def generate_video( console.print("[green]✓[/] VAE encoder loaded and image encoded") # Stage 1 - console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {stage1_w*32}x{stage1_h*32} (8 steps)") + console.print( + f"\n[bold yellow]⚡ Stage 1:[/] Generating at {stage1_w*32}x{stage1_h*32} (8 steps)" + ) mx.random.seed(seed) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) @@ -1748,7 +2088,13 @@ def generate_video( # Init audio latents/positions: use encoded A2V latents or random audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) + audio_latents = ( + a2v_audio_latents + if is_a2v + else mx.random.normal( + (1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS) + ).astype(model_dtype) + ) mx.eval(audio_positions, audio_latents) # Apply I2V conditioning @@ -1760,40 +2106,63 @@ def generate_video( clean_latent=mx.zeros(latent_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=stage1_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) state1 = apply_conditioning(state1, [conditioning]) noise = mx.random.normal(latent_shape, dtype=model_dtype) noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) scaled_mask = state1.denoise_mask * noise_scale state1 = LatentState( - latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state1.clean_latent, denoise_mask=state1.denoise_mask, ) latents = state1.latent mx.eval(latents) else: - latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) + latents = mx.random.normal( + (1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype + ) mx.eval(latents) latents, audio_latents = denoise_distilled( - latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, - verbose=verbose, state=state1, - audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, + latents, + positions, + text_embeddings, + transformer, + STAGE_1_SIGMAS, + verbose=verbose, + state=state1, + audio_latents=audio_latents, + audio_positions=audio_positions, + audio_embeddings=audio_embeddings, audio_frozen=is_a2v, ) # Upsample latents - with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"): + with console.status( + f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots" + ): if upscaler_path is None or not upscaler_path.exists(): raise FileNotFoundError(f"No spatial upscaler found in {model_path}") upsampler, upscaler_scale = load_upsampler(str(upscaler_path)) mx.eval(upsampler.parameters()) - vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + vae_decoder = VideoDecoder.from_pretrained( + str(model_path / "vae" / "decoder") + ) - latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) + latents = upsample_latents( + latents, + upsampler, + vae_decoder.per_channel_statistics.mean, + vae_decoder.per_channel_statistics.std, + ) mx.eval(latents) del upsampler @@ -1801,7 +2170,9 @@ def generate_video( console.print("[green]✓[/] Latents upsampled") # Stage 2 - console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {stage2_w*32}x{stage2_h*32} (3 steps)") + console.print( + f"\n[bold yellow]⚡ Stage 2:[/] Refining at {stage2_w*32}x{stage2_h*32} (3 steps)" + ) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -1812,14 +2183,19 @@ def generate_video( clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=stage2_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) state2 = apply_conditioning(state2, [conditioning]) noise = mx.random.normal(latents.shape).astype(model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = state2.denoise_mask * noise_scale state2 = LatentState( - latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state2.clean_latent, denoise_mask=state2.denoise_mask, ) @@ -1836,14 +2212,22 @@ def generate_video( if audio_latents is not None and not is_a2v: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) + audio_latents = audio_noise * audio_noise_scale + audio_latents * ( + mx.array(1.0, dtype=model_dtype) - audio_noise_scale + ) mx.eval(audio_latents) # Joint video + audio refinement (no CFG, positive embeddings only) latents, audio_latents = denoise_distilled( - latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, - verbose=verbose, state=state2, - audio_latents=audio_latents, audio_positions=audio_positions, + latents, + positions, + text_embeddings, + transformer, + STAGE_2_SIGMAS, + verbose=verbose, + state=state2, + audio_latents=audio_latents, + audio_positions=audio_positions, audio_embeddings=audio_embeddings, audio_frozen=is_a2v, ) @@ -1856,11 +2240,19 @@ def generate_video( # Load VAE encoder for I2V image_latent = None if is_i2v: - with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + with console.status( + "[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots" + ): + vae_encoder = VideoEncoder.from_pretrained( + model_path / "vae" / "encoder" + ) - input_image = load_image(image, height=height, width=width, dtype=model_dtype) - image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + input_image = load_image( + image, height=height, width=width, dtype=model_dtype + ) + image_tensor = prepare_image_for_encoding( + input_image, height, width, dtype=model_dtype + ) image_latent = vae_encoder(image_tensor) mx.eval(image_latent) @@ -1871,9 +2263,13 @@ def generate_video( # Generate sigma schedule with token-count-dependent shifting sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) - console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") + console.print( + f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]" + ) - console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})") + console.print( + f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})" + ) mx.random.seed(seed) video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) @@ -1881,7 +2277,14 @@ def generate_video( # Always init audio latents/positions - PyTorch unconditionally generates audio audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + audio_latents = ( + a2v_audio_latents + if is_a2v + else mx.random.normal( + (1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), + dtype=model_dtype, + ) + ) mx.eval(audio_positions, audio_latents) # Initialize latents with optional I2V conditioning @@ -1893,14 +2296,17 @@ def generate_video( clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=image_latent, frame_idx=image_frame_idx, strength=image_strength + ) video_state = apply_conditioning(video_state, [conditioning]) noise = mx.random.normal(video_latent_shape, dtype=model_dtype) noise_scale = sigmas[0] scaled_mask = video_state.denoise_mask * noise_scale video_state = LatentState( - latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=video_state.clean_latent, denoise_mask=video_state.denoise_mask, ) @@ -1912,16 +2318,28 @@ def generate_video( # Always use A/V denoising - PyTorch always processes audio+video jointly latents, audio_latents = denoise_dev_av( - latents, audio_latents, - video_positions, audio_positions, - video_embeddings_pos, video_embeddings_neg, - audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, + latents, + audio_latents, + video_positions, + audio_positions, + video_embeddings_pos, + video_embeddings_neg, + audio_embeddings_pos, + audio_embeddings_neg, + transformer, + sigmas, + cfg_scale=cfg_scale, audio_cfg_scale=audio_cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, - stg_scale=stg_scale, stg_video_blocks=stg_blocks, - stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + cfg_rescale=cfg_rescale, + verbose=verbose, + video_state=video_state, + use_apg=use_apg, + apg_eta=apg_eta, + apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, + stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, + modality_scale=modality_scale, audio_frozen=is_a2v, ) @@ -1940,18 +2358,30 @@ def generate_video( stage1_image_latent = None stage2_image_latent = None if is_i2v: - with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + with console.status( + "[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots" + ): + vae_encoder = VideoEncoder.from_pretrained( + model_path / "vae" / "encoder" + ) s1_h, s1_w = stage1_h * 32, stage1_w * 32 - input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype) - stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype) + input_image = load_image( + image, height=s1_h, width=s1_w, dtype=model_dtype + ) + stage1_image_tensor = prepare_image_for_encoding( + input_image, s1_h, s1_w, dtype=model_dtype + ) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) s2_h, s2_w = stage2_h * 32, stage2_w * 32 - input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype) - stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype) + input_image = load_image( + image, height=s2_h, width=s2_w, dtype=model_dtype + ) + stage2_image_tensor = prepare_image_for_encoding( + input_image, s2_h, s2_w, dtype=model_dtype + ) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) @@ -1962,9 +2392,13 @@ def generate_video( # Stage 1: Dev denoising at reduced resolution with CFG sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) - console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") + console.print( + f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]" + ) - console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {stage1_w*32}x{stage1_h*32} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})") + console.print( + f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {stage1_w*32}x{stage1_h*32} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})" + ) mx.random.seed(seed) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) @@ -1972,7 +2406,14 @@ def generate_video( # Always init audio latents/positions - PyTorch unconditionally generates audio audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + audio_latents = ( + a2v_audio_latents + if is_a2v + else mx.random.normal( + (1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), + dtype=model_dtype, + ) + ) mx.eval(audio_positions, audio_latents) # Apply I2V conditioning for stage 1 @@ -1984,14 +2425,19 @@ def generate_video( clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=stage1_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) state1 = apply_conditioning(state1, [conditioning]) noise = mx.random.normal(stage1_shape, dtype=model_dtype) noise_scale = sigmas[0] scaled_mask = state1.denoise_mask * noise_scale state1 = LatentState( - latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state1.clean_latent, denoise_mask=state1.denoise_mask, ) @@ -2003,31 +2449,52 @@ def generate_video( # Stage 1: Always use joint AV denoising (matches PyTorch) latents, audio_latents = denoise_dev_av( - latents, audio_latents, - positions, audio_positions, - video_embeddings_pos, video_embeddings_neg, - audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, + latents, + audio_latents, + positions, + audio_positions, + video_embeddings_pos, + video_embeddings_neg, + audio_embeddings_pos, + audio_embeddings_neg, + transformer, + sigmas, + cfg_scale=cfg_scale, audio_cfg_scale=audio_cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, - stg_scale=stg_scale, stg_video_blocks=stg_blocks, - stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + cfg_rescale=cfg_rescale, + verbose=verbose, + video_state=state1, + use_apg=use_apg, + apg_eta=apg_eta, + apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, + stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, + modality_scale=modality_scale, audio_frozen=is_a2v, ) mx.eval(audio_latents) # Upsample latents - with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"): + with console.status( + f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots" + ): if upscaler_path is None or not upscaler_path.exists(): raise FileNotFoundError(f"No spatial upscaler found in {model_path}") upsampler, upscaler_scale = load_upsampler(str(upscaler_path)) mx.eval(upsampler.parameters()) - vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + vae_decoder = VideoDecoder.from_pretrained( + str(model_path / "vae" / "decoder") + ) - latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) + latents = upsample_latents( + latents, + upsampler, + vae_decoder.per_channel_statistics.mean, + vae_decoder.per_channel_statistics.std, + ) mx.eval(latents) del upsampler @@ -2042,16 +2509,22 @@ def generate_video( lora_path = str(lora_files[0]) console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]") else: - console.print("[yellow]⚠️ No LoRA file found. Stage 2 will use base weights.[/]") + console.print( + "[yellow]⚠️ No LoRA file found. Stage 2 will use base weights.[/]" + ) if lora_path is not None: - with console.status("[blue]🔧 Merging distilled LoRA weights...[/]", spinner="dots"): + with console.status( + "[blue]🔧 Merging distilled LoRA weights...[/]", spinner="dots" + ): load_and_merge_lora(transformer, lora_path, strength=lora_strength) # Stage 2: Distilled refinement at full resolution (no CFG) # Matches PyTorch: re-noise audio at sigma=0.909375, then jointly refine # both video and audio through the distilled schedule using the LoRA-merged model. - console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)") + console.print( + f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)" + ) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -2062,14 +2535,19 @@ def generate_video( clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=stage2_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) state2 = apply_conditioning(state2, [conditioning]) noise = mx.random.normal(latents.shape).astype(model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = state2.denoise_mask * noise_scale state2 = LatentState( - latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state2.clean_latent, denoise_mask=state2.denoise_mask, ) @@ -2086,14 +2564,22 @@ def generate_video( if audio_latents is not None and not is_a2v: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) + audio_latents = audio_noise * audio_noise_scale + audio_latents * ( + mx.array(1.0, dtype=model_dtype) - audio_noise_scale + ) mx.eval(audio_latents) # Joint video + audio refinement (no CFG, positive embeddings only) latents, audio_latents = denoise_distilled( - latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, - verbose=verbose, state=state2, - audio_latents=audio_latents, audio_positions=audio_positions, + latents, + positions, + text_embeddings, + transformer, + STAGE_2_SIGMAS, + verbose=verbose, + state=state2, + audio_latents=audio_latents, + audio_positions=audio_positions, audio_embeddings=audio_embeddings_pos, audio_frozen=is_a2v, ) @@ -2107,28 +2593,50 @@ def generate_video( # ====================================================================== # HQ defaults: STG disabled, lower rescale, fewer steps (PyTorch LTX_2_3_HQ_PARAMS) - hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25 - hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5 - hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45 - hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15 - hq_stg_scale = stg_scale if stg_scale != 1.0 else 0.0 # Override default 1.0 → 0.0 + hq_lora_strength_s1 = ( + lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25 + ) + hq_lora_strength_s2 = ( + lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5 + ) + hq_cfg_rescale = ( + cfg_rescale if cfg_rescale != 0.7 else 0.45 + ) # Override default 0.7 → 0.45 + hq_steps = ( + num_inference_steps if num_inference_steps != 30 else 15 + ) # Override default 30 → 15 + hq_stg_scale = ( + stg_scale if stg_scale != 1.0 else 0.0 + ) # Override default 1.0 → 0.0 # Load VAE encoder for I2V stage1_image_latent = None stage2_image_latent = None if is_i2v: - with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + with console.status( + "[blue]Loading VAE encoder and encoding image...[/]", spinner="dots" + ): + vae_encoder = VideoEncoder.from_pretrained( + model_path / "vae" / "encoder" + ) s1_h, s1_w = stage1_h * 32, stage1_w * 32 - input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype) - stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype) + input_image = load_image( + image, height=s1_h, width=s1_w, dtype=model_dtype + ) + stage1_image_tensor = prepare_image_for_encoding( + input_image, s1_h, s1_w, dtype=model_dtype + ) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) s2_h, s2_w = stage2_h * 32, stage2_w * 32 - input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype) - stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype) + input_image = load_image( + image, height=s2_h, width=s2_w, dtype=model_dtype + ) + stage2_image_tensor = prepare_image_for_encoding( + input_image, s2_h, s2_w, dtype=model_dtype + ) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) @@ -2143,27 +2651,45 @@ def generate_video( lora_path = str(lora_files[0]) console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]") else: - console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]") + console.print( + "[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]" + ) if lora_path is not None: - with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"): - load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1) + with console.status( + f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", + spinner="dots", + ): + load_and_merge_lora( + transformer, lora_path, strength=hq_lora_strength_s1 + ) # Stage 1: res_2s denoising at reduced resolution with CFG # HQ passes actual token count to scheduler (unlike regular dev-two-stage) num_tokens = latent_frames * stage1_h * stage1_w sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens) mx.eval(sigmas) - console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]") + console.print( + f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]" + ) - console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {stage1_w*32}x{stage1_h*32} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})") + console.print( + f"\n[bold yellow]Stage 1:[/] res_2s at {stage1_w*32}x{stage1_h*32} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})" + ) mx.random.seed(seed) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + audio_latents = ( + a2v_audio_latents + if is_a2v + else mx.random.normal( + (1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), + dtype=model_dtype, + ) + ) mx.eval(audio_positions, audio_latents) # Apply I2V conditioning for stage 1 @@ -2175,14 +2701,19 @@ def generate_video( clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=stage1_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) state1 = apply_conditioning(state1, [conditioning]) noise = mx.random.normal(stage1_shape, dtype=model_dtype) noise_scale = sigmas[0] scaled_mask = state1.denoise_mask * noise_scale state1 = LatentState( - latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state1.clean_latent, denoise_mask=state1.denoise_mask, ) @@ -2194,16 +2725,26 @@ def generate_video( # Stage 1: res_2s with CFG (STG disabled for HQ by default) latents, audio_latents = denoise_res2s_av( - latents, audio_latents, - positions, audio_positions, - video_embeddings_pos, video_embeddings_neg, - audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, + latents, + audio_latents, + positions, + audio_positions, + video_embeddings_pos, + video_embeddings_neg, + audio_embeddings_pos, + audio_embeddings_neg, + transformer, + sigmas, + cfg_scale=cfg_scale, audio_cfg_scale=audio_cfg_scale, - cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0, - verbose=verbose, video_state=state1, - stg_scale=hq_stg_scale, stg_video_blocks=stg_blocks, - stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + cfg_rescale=hq_cfg_rescale, + audio_cfg_rescale=1.0, + verbose=verbose, + video_state=state1, + stg_scale=hq_stg_scale, + stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, + modality_scale=modality_scale, noise_seed=seed, audio_frozen=is_a2v, ) @@ -2211,15 +2752,24 @@ def generate_video( mx.eval(audio_latents) # Upsample latents - with console.status(f"[magenta]Upsampling latents {upscaler_scale}x...[/]", spinner="dots"): + with console.status( + f"[magenta]Upsampling latents {upscaler_scale}x...[/]", spinner="dots" + ): if upscaler_path is None or not upscaler_path.exists(): raise FileNotFoundError(f"No spatial upscaler found in {model_path}") upsampler, upscaler_scale = load_upsampler(str(upscaler_path)) mx.eval(upsampler.parameters()) - vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + vae_decoder = VideoDecoder.from_pretrained( + str(model_path / "vae" / "decoder") + ) - latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) + latents = upsample_latents( + latents, + upsampler, + vae_decoder.per_channel_statistics.mean, + vae_decoder.per_channel_statistics.std, + ) mx.eval(latents) del upsampler @@ -2230,11 +2780,18 @@ def generate_video( if lora_path is not None: additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1 if additional_strength > 0: - with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"): - load_and_merge_lora(transformer, lora_path, strength=additional_strength) + with console.status( + f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", + spinner="dots", + ): + load_and_merge_lora( + transformer, lora_path, strength=additional_strength + ) # Stage 2: res_2s refinement at full resolution (no CFG) - console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {stage2_w*32}x{stage2_h*32} (3 steps, no CFG)") + console.print( + f"\n[bold yellow]Stage 2:[/] res_2s refining at {stage2_w*32}x{stage2_h*32} (3 steps, no CFG)" + ) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -2245,14 +2802,19 @@ def generate_video( clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + conditioning = VideoConditionByLatentIndex( + latent=stage2_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) state2 = apply_conditioning(state2, [conditioning]) noise = mx.random.normal(latents.shape).astype(model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = state2.denoise_mask * noise_scale state2 = LatentState( - latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + latent=noise * scaled_mask + + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state2.clean_latent, denoise_mask=state2.denoise_mask, ) @@ -2269,19 +2831,29 @@ def generate_video( if audio_latents is not None and not is_a2v: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) + audio_latents = audio_noise * audio_noise_scale + audio_latents * ( + mx.array(1.0, dtype=model_dtype) - audio_noise_scale + ) mx.eval(audio_latents) # Stage 2: res_2s with no CFG (positive embeddings only) stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32) latents, audio_latents = denoise_res2s_av( - latents, audio_latents, - positions, audio_positions, - video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2) - audio_embeddings_pos, audio_embeddings_pos, - transformer, stage2_sigmas, cfg_scale=1.0, # no CFG + latents, + audio_latents, + positions, + audio_positions, + video_embeddings_pos, + video_embeddings_pos, # both pos (no neg for stage 2) + audio_embeddings_pos, + audio_embeddings_pos, + transformer, + stage2_sigmas, + cfg_scale=1.0, # no CFG audio_cfg_scale=1.0, - cfg_rescale=0.0, verbose=verbose, video_state=state2, + cfg_rescale=0.0, + verbose=verbose, + video_state=state2, noise_seed=seed + 1, audio_frozen=is_a2v, ) @@ -2323,7 +2895,8 @@ def generate_video( if stream and tiling_config is not None: import cv2 - fourcc = cv2.VideoWriter_fourcc(*'avc1') + + fourcc = cv2.VideoWriter_fourcc(*"avc1") video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) stream_progress = Progress( SpinnerColumn(), @@ -2333,7 +2906,9 @@ def generate_video( console=console, ) stream_progress.start() - stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames) + stream_task = stream_progress.add_task( + "[cyan]Streaming frames[/]", total=num_frames + ) def on_frames_ready(frames: mx.array, _start_idx: int): frames = mx.squeeze(frames, axis=0) @@ -2345,14 +2920,31 @@ def generate_video( for frame in frames_np: video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) stream_progress.advance(stream_task) + else: on_frames_ready = None if tiling_config is not None: - spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" - temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" - console.print(f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]") - video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready) + spatial_info = ( + f"{tiling_config.spatial_config.tile_size_in_pixels}px" + if tiling_config.spatial_config + else "none" + ) + temporal_info = ( + f"{tiling_config.temporal_config.tile_size_in_frames}f" + if tiling_config.temporal_config + else "none" + ) + console.print( + f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]" + ) + video = vae_decoder.decode_tiled( + latents, + tiling_config=tiling_config, + tiling_mode=tiling, + debug=verbose, + on_frames_ready=on_frames_ready, + ) else: console.print("[dim] Tiling: disabled[/]") video = vae_decoder(latents) @@ -2378,15 +2970,16 @@ def generate_video( video_np = np.array(video) if audio: - temp_video_path = output_path.with_suffix('.temp.mp4') + temp_video_path = output_path.with_suffix(".temp.mp4") save_path = temp_video_path else: save_path = output_path try: import cv2 + h, w = video_np.shape[1], video_np.shape[2] - fourcc = cv2.VideoWriter_fourcc(*'avc1') + fourcc = cv2.VideoWriter_fourcc(*"avc1") out = cv2.VideoWriter(str(save_path), fourcc, fps, (w, h)) for frame in video_np: out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) @@ -2415,7 +3008,9 @@ def generate_video( mel_spectrogram = audio_decoder(audio_latents) mx.eval(mel_spectrogram) - console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]") + console.print( + f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]" + ) audio_waveform = vocoder(mel_spectrogram) mx.eval(audio_waveform) @@ -2425,18 +3020,24 @@ def generate_video( audio_np = audio_np[0] # Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE) - vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE) + vocoder_sample_rate = getattr( + vocoder, "output_sampling_rate", AUDIO_SAMPLE_RATE + ) del audio_decoder, vocoder mx.clear_cache() console.print("[green]✓[/] Audio decoded") - audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') + audio_path = ( + Path(output_audio_path) + if output_audio_path + else output_path.with_suffix(".wav") + ) save_audio(audio_np, audio_path, vocoder_sample_rate) console.print(f"[green]✅ Saved audio to[/] {audio_path}") with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"): - temp_video_path = output_path.with_suffix('.temp.mp4') + temp_video_path = output_path.with_suffix(".temp.mp4") success = mux_video_audio(temp_video_path, audio_path, output_path) if success: console.print(f"[green]✅ Saved video with audio to[/] {output_path}") @@ -2458,11 +3059,13 @@ def generate_video( elapsed = time.time() - start_time minutes, seconds = divmod(elapsed, 60) time_str = f"{int(minutes)}m {seconds:.1f}s" if minutes >= 1 else f"{seconds:.1f}s" - console.print(Panel( - f"[bold green]🎉 Done![/] Generated in {time_str} ({elapsed/num_frames:.2f}s/frame)\n" - f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB", - expand=False - )) + console.print( + Panel( + f"[bold green]🎉 Done![/] Generated in {time_str} ({elapsed/num_frames:.2f}s/frame)\n" + f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB", + expand=False, + ) + ) if audio: return video_np, audio_np @@ -2493,55 +3096,216 @@ Examples: # With Audio (works with both pipelines) python -m mlx_video.generate --prompt "Ocean waves crashing" --audio python -m mlx_video.generate --prompt "A jazz band playing" --audio --pipeline dev - """ + """, ) - parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate") - parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"], - help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)") - parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, - help="Negative prompt for CFG (dev pipeline only)") - parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") - parser.add_argument("--width", "-W", type=int, default=512, help="Output video width") - parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames") - parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)") - parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale for video (dev pipeline only, default 3.0)") - parser.add_argument("--audio-cfg-scale", type=float, default=7.0, help="CFG guidance scale for audio (default 7.0, PyTorch default)") - parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)") + parser.add_argument( + "--prompt", + "-p", + type=str, + required=True, + help="Text description of the video to generate", + ) + parser.add_argument( + "--pipeline", + type=str, + default="distilled", + choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"], + help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default=DEFAULT_NEGATIVE_PROMPT, + help="Negative prompt for CFG (dev pipeline only)", + ) + parser.add_argument( + "--height", "-H", type=int, default=512, help="Output video height" + ) + parser.add_argument( + "--width", "-W", type=int, default=512, help="Output video width" + ) + parser.add_argument( + "--num-frames", "-n", type=int, default=33, help="Number of frames" + ) + parser.add_argument( + "--steps", + type=int, + default=30, + help="Number of inference steps (dev pipeline only, default 30)", + ) + parser.add_argument( + "--cfg-scale", + type=float, + default=3.0, + help="CFG guidance scale for video (dev pipeline only, default 3.0)", + ) + parser.add_argument( + "--audio-cfg-scale", + type=float, + default=7.0, + help="CFG guidance scale for audio (default 7.0, PyTorch default)", + ) + parser.add_argument( + "--cfg-rescale", + type=float, + default=0.7, + help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)", + ) parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed") parser.add_argument("--fps", type=int, default=24, help="Frames per second") - parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path") - parser.add_argument("--save-frames", action="store_true", help="Save individual frames as images") - parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2", help="Model repository") - parser.add_argument("--text-encoder-repo", type=str, default=None, help="Text encoder repository") + parser.add_argument( + "--output-path", "-o", type=str, default="output.mp4", help="Output video path" + ) + parser.add_argument( + "--save-frames", action="store_true", help="Save individual frames as images" + ) + parser.add_argument( + "--model-repo", type=str, default="Lightricks/LTX-2", help="Model repository" + ) + parser.add_argument( + "--text-encoder-repo", type=str, default=None, help="Text encoder repository" + ) parser.add_argument("--verbose", action="store_true", help="Verbose output") - parser.add_argument("--enhance-prompt", action="store_true", help="Enhance the prompt using Gemma") - parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for prompt enhancement") - parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for prompt enhancement") - parser.add_argument("--image", "-i", type=str, default=None, help="Path to conditioning image for I2V") - parser.add_argument("--image-strength", type=float, default=1.0, help="Conditioning strength for I2V") - parser.add_argument("--image-frame-idx", type=int, default=0, help="Frame index to condition for I2V") - parser.add_argument("--tiling", type=str, default="auto", - choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], - help="Tiling mode for VAE decoding") - parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") - parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") - parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning") - parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)") - parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") - parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") - parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") - parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") - parser.add_argument("--stg-scale", type=float, default=1.0, help="STG (Spatiotemporal Guidance) scale (default 1.0, 0.0 = disabled)") - parser.add_argument("--stg-blocks", type=int, nargs="+", default=None, help="Transformer block indices for STG perturbation (default: [29] for LTX-2, [28] for LTX-2.3)") - parser.add_argument("--modality-scale", type=float, default=3.0, help="Cross-modal guidance scale (default 3.0, 1.0 = disabled)") - parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") - parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") - parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)") - parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)") - parser.add_argument("--spatial-upscaler", type=str, default=None, - help="Spatial upscaler filename (e.g. ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors). " - "Auto-detects x2 by default. Use this to select x1.5 or a specific version.") + parser.add_argument( + "--enhance-prompt", action="store_true", help="Enhance the prompt using Gemma" + ) + parser.add_argument( + "--max-tokens", type=int, default=512, help="Max tokens for prompt enhancement" + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Temperature for prompt enhancement", + ) + parser.add_argument( + "--image", + "-i", + type=str, + default=None, + help="Path to conditioning image for I2V", + ) + parser.add_argument( + "--image-strength", + type=float, + default=1.0, + help="Conditioning strength for I2V", + ) + parser.add_argument( + "--image-frame-idx", + type=int, + default=0, + help="Frame index to condition for I2V", + ) + parser.add_argument( + "--tiling", + type=str, + default="auto", + choices=[ + "auto", + "none", + "default", + "aggressive", + "conservative", + "spatial", + "temporal", + ], + help="Tiling mode for VAE decoding", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream frames to output as they're decoded", + ) + parser.add_argument( + "--audio", + "-a", + action="store_true", + help="Enable synchronized audio generation", + ) + parser.add_argument( + "--audio-file", + type=str, + default=None, + help="Path to audio file for A2V (audio-to-video) conditioning", + ) + parser.add_argument( + "--audio-start-time", + type=float, + default=0.0, + help="Start time in seconds for audio file (default: 0.0)", + ) + parser.add_argument( + "--output-audio", type=str, default=None, help="Output audio path" + ) + parser.add_argument( + "--apg", + action="store_true", + help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)", + ) + parser.add_argument( + "--apg-eta", + type=float, + default=1.0, + help="APG parallel component weight (1.0 = keep full parallel)", + ) + parser.add_argument( + "--apg-norm-threshold", + type=float, + default=0.0, + help="APG guidance norm clamp (0 = no clamping)", + ) + parser.add_argument( + "--stg-scale", + type=float, + default=1.0, + help="STG (Spatiotemporal Guidance) scale (default 1.0, 0.0 = disabled)", + ) + parser.add_argument( + "--stg-blocks", + type=int, + nargs="+", + default=None, + help="Transformer block indices for STG perturbation (default: [29] for LTX-2, [28] for LTX-2.3)", + ) + parser.add_argument( + "--modality-scale", + type=float, + default=3.0, + help="Cross-modal guidance scale (default 3.0, 1.0 = disabled)", + ) + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to LoRA safetensors file (dev-two-stage pipeline)", + ) + parser.add_argument( + "--lora-strength", + type=float, + default=1.0, + help="LoRA merge strength (dev-two-stage pipeline, default 1.0)", + ) + parser.add_argument( + "--lora-strength-stage-1", + type=float, + default=0.25, + help="LoRA strength for HQ stage 1 (default 0.25)", + ) + parser.add_argument( + "--lora-strength-stage-2", + type=float, + default=0.5, + help="LoRA strength for HQ stage 2 (default 0.5)", + ) + parser.add_argument( + "--spatial-upscaler", + type=str, + default=None, + help="Spatial upscaler filename (e.g. ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors). " + "Auto-detects x2 by default. Use this to select x1.5 or a specific version.", + ) args = parser.parse_args() pipeline_map = { diff --git a/mlx_video/models/ltx_2/ltx.py b/mlx_video/models/ltx_2/ltx.py index 18496b8..ec21a6e 100644 --- a/mlx_video/models/ltx_2/ltx.py +++ b/mlx_video/models/ltx_2/ltx.py @@ -1,15 +1,14 @@ +from pathlib import Path from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from pathlib import Path + +from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle from mlx_video.models.ltx_2.config import ( LTXModelConfig, - LTXModelType, LTXRopeType, - TransformerConfig, ) -from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle from mlx_video.models.ltx_2.rope import precompute_freqs_cis from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection from mlx_video.models.ltx_2.transformer import ( @@ -58,11 +57,17 @@ class TransformerArgsPreprocessor: ) -> Tuple[mx.array, mx.array]: timestep = timestep * self.timestep_scale_multiplier - timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) + timestep_emb, embedded_timestep = self.adaln( + timestep.reshape(-1), hidden_dtype=hidden_dtype + ) # Reshape to (batch, tokens, dim) - timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) - embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])) + timestep_emb = mx.reshape( + timestep_emb, (batch_size, -1, timestep_emb.shape[-1]) + ) + embedded_timestep = mx.reshape( + embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]) + ) return timestep_emb, embedded_timestep @@ -74,9 +79,15 @@ class TransformerArgsPreprocessor: hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * self.timestep_scale_multiplier - timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) - timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) - embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])) + timestep_emb, embedded_timestep = adaln( + timestep.reshape(-1), hidden_dtype=hidden_dtype + ) + timestep_emb = mx.reshape( + timestep_emb, (batch_size, -1, timestep_emb.shape[-1]) + ) + embedded_timestep = mx.reshape( + embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]) + ) return timestep_emb, embedded_timestep def _prepare_context( @@ -107,7 +118,9 @@ class TransformerArgsPreprocessor: # Convert boolean/int mask to float mask # 0 -> -inf (masked), 1 -> 0 (not masked) mask = (attention_mask.astype(x_dtype) - 1) * 1e9 - mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) + mask = mx.reshape( + mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) return mask def _prepare_positional_embeddings( @@ -132,9 +145,15 @@ class TransformerArgsPreprocessor: def prepare(self, modality: Modality) -> TransformerArgs: x = self.patchify_proj(modality.latent) - timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) - context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) - attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) + timestep, embedded_timestep = self._prepare_timestep( + modality.timesteps, x.shape[0], hidden_dtype=x.dtype + ) + context, attention_mask = self._prepare_context( + modality.context, x, modality.context_mask + ) + attention_mask = self._prepare_attention_mask( + attention_mask, modality.latent.dtype + ) # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation) if modality.positional_embeddings is not None: @@ -152,8 +171,13 @@ class TransformerArgsPreprocessor: prompt_timestep = None prompt_embedded_timestep = None if self.prompt_adaln is not None and modality.sigma is not None: - prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln( - self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype, + prompt_timestep, prompt_embedded_timestep = ( + self._prepare_timestep_with_adaln( + self.prompt_adaln, + modality.sigma, + x.shape[0], + hidden_dtype=x.dtype, + ) ) return TransformerArgs( @@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor: ) # Prepare cross-attention timestep embeddings - cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( - timestep=modality.timesteps, - timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, - batch_size=transformer_args.x.shape[0], - hidden_dtype=transformer_args.x.dtype, + cross_scale_shift_timestep, cross_gate_timestep = ( + self._prepare_cross_attention_timestep( + timestep=modality.timesteps, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=transformer_args.x.dtype, + ) ) return replace( @@ -254,17 +280,25 @@ class MultiModalTransformerArgsPreprocessor: av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier - scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) - scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.reshape(-1), hidden_dtype=hidden_dtype + ) + scale_shift_timestep = mx.reshape( + scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]) + ) - gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) - gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) + gate_timestep, _ = self.cross_gate_adaln( + timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype + ) + gate_timestep = mx.reshape( + gate_timestep, (batch_size, -1, gate_timestep.shape[-1]) + ) return scale_shift_timestep, gate_timestep class LTXModel(nn.Module): - + def __init__(self, config: LTXModelConfig): super().__init__() @@ -285,18 +319,25 @@ class LTXModel(nn.Module): self._init_video(config) if config.model_type.is_audio_enabled(): - self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos + self.audio_positional_embedding_max_pos = ( + config.audio_positional_embedding_max_pos + ) self.audio_num_attention_heads = config.audio_num_attention_heads self.audio_inner_dim = config.audio_inner_dim self._init_audio(config) # Initialize cross-modal components - if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled(): + if ( + config.model_type.is_video_enabled() + and config.model_type.is_audio_enabled() + ): cross_pe_max_pos = max( config.positional_embedding_max_pos[0], config.audio_positional_embedding_max_pos[0], ) - self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier + self.av_ca_timestep_scale_multiplier = ( + config.av_ca_timestep_scale_multiplier + ) self.audio_cross_attention_dim = config.audio_cross_attention_dim self._init_audio_video(config) @@ -308,10 +349,14 @@ class LTXModel(nn.Module): self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True) adaln_coefficient = 9 if config.has_prompt_adaln else 6 - self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=adaln_coefficient + ) if config.has_prompt_adaln: - self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2 + ) else: self.caption_projection = PixArtAlphaTextProjection( in_features=config.caption_channels, @@ -323,13 +368,19 @@ class LTXModel(nn.Module): self.proj_out = nn.Linear(self.inner_dim, config.out_channels) def _init_audio(self, config: LTXModelConfig) -> None: - self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) + self.audio_patchify_proj = nn.Linear( + config.audio_in_channels, self.audio_inner_dim, bias=True + ) audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6 - self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient) + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient + ) if config.has_prompt_adaln: - self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, embedding_coefficient=2 + ) else: self.audio_caption_projection = PixArtAlphaTextProjection( in_features=config.audio_caption_channels, @@ -338,7 +389,9 @@ class LTXModel(nn.Module): # Output components self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim)) - self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False) + self.audio_norm_out = nn.LayerNorm( + self.audio_inner_dim, eps=config.norm_eps, affine=False + ) self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels) def _init_audio_video(self, config: LTXModelConfig) -> None: @@ -361,8 +414,13 @@ class LTXModel(nn.Module): embedding_coefficient=1, ) - def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None: - if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled(): + def _init_preprocessors( + self, config: LTXModelConfig, cross_pe_max_pos: Optional[int] + ) -> None: + if ( + config.model_type.is_video_enabled() + and config.model_type.is_audio_enabled() + ): # Multi-modal preprocessors self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.patchify_proj, @@ -468,7 +526,8 @@ class LTXModel(nn.Module): stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set() for idx, block in self.transformer_blocks.items(): video, audio = block( - video=video, audio=audio, + video=video, + audio=audio, skip_video_self_attn=(idx in stg_v_set), skip_audio_self_attn=(idx in stg_a_set), skip_cross_modal=skip_cross_modal, @@ -483,7 +542,7 @@ class LTXModel(nn.Module): x: mx.array, embedded_timestep: mx.array, ) -> mx.array: - + # scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim) # embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim) table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim) @@ -526,8 +585,12 @@ class LTXModel(nn.Module): raise ValueError("Audio is not enabled for this model") # Preprocess arguments - video_args = self.video_args_preprocessor.prepare(video) if video is not None else None - audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None + video_args = ( + self.video_args_preprocessor.prepare(video) if video is not None else None + ) + audio_args = ( + self.audio_args_preprocessor.prepare(audio) if audio is not None else None + ) # Process transformer blocks video_out, audio_out = self._process_transformer_blocks( @@ -567,7 +630,7 @@ class LTXModel(nn.Module): def sanitize(self, weights: dict) -> dict: sanitized = {} - + has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights) if not has_raw_prefix: return weights @@ -577,7 +640,10 @@ class LTXModel(nn.Module): if not key.startswith("model.diffusion_model."): continue - if "audio_embeddings_connector" in key or "video_embeddings_connector" in key: + if ( + "audio_embeddings_connector" in key + or "video_embeddings_connector" in key + ): continue # Remove 'model.diffusion_model.' prefix @@ -612,9 +678,11 @@ class LTXModel(nn.Module): for weight_file in model_path.glob("*.safetensors"): weights.update(mx.load(str(weight_file))) - sanitized = model.sanitize(weights) - sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} + sanitized = { + k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v + for k, v in sanitized.items() + } model.load_weights(list(sanitized.items()), strict=strict) mx.eval(model.parameters()) @@ -625,7 +693,7 @@ class LTXModel(nn.Module): class X0Model(nn.Module): def __init__(self, velocity_model: LTXModel): - + super().__init__() self.velocity_model = velocity_model @@ -639,13 +707,18 @@ class X0Model(nn.Module): ) -> Tuple[Optional[mx.array], Optional[mx.array]]: vx, ax = self.velocity_model( - video, audio, + video, + audio, stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks, skip_cross_modal=skip_cross_modal, ) - denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None - denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None + denoised_video = ( + to_denoised(video.latent, vx, video.timesteps) if vx is not None else None + ) + denoised_audio = ( + to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None + ) return denoised_video, denoised_audio diff --git a/mlx_video/models/ltx_2/postprocess.py b/mlx_video/models/ltx_2/postprocess.py index 03ef61d..7865975 100644 --- a/mlx_video/models/ltx_2/postprocess.py +++ b/mlx_video/models/ltx_2/postprocess.py @@ -1,9 +1,10 @@ import numpy as np -from typing import Optional -def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray: +def bilateral_filter( + image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75 +) -> np.ndarray: """Apply bilateral filter to reduce grid artifacts while preserving edges. Args: @@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig """ try: import cv2 + return cv2.bilateralFilter(image, d, sigma_color, sigma_space) except ImportError: # Fallback to simple Gaussian blur if cv2 not available @@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray: """ try: import cv2 + return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0) except ImportError: # Simple box blur fallback from scipy.ndimage import uniform_filter - return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8) + + return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype( + np.uint8 + ) -def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray: +def unsharp_mask( + image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0 +) -> np.ndarray: """Apply unsharp masking to enhance edges after blur. Args: @@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am """ try: import cv2 + blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma) sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0) return np.clip(sharpened, 0, 255).astype(np.uint8) @@ -81,23 +90,23 @@ def reduce_grid_artifacts( if method == "bilateral": d = max(3, int(5 * strength)) sigma = 50 + 50 * strength - processed = np.stack([ - bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma) - for frame in video - ]) + processed = np.stack( + [ + bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma) + for frame in video + ] + ) elif method == "gaussian": kernel_size = max(3, int(3 + 4 * strength)) if kernel_size % 2 == 0: kernel_size += 1 - processed = np.stack([ - gaussian_blur(frame, kernel_size=kernel_size) - for frame in video - ]) + processed = np.stack( + [gaussian_blur(frame, kernel_size=kernel_size) for frame in video] + ) elif method == "frequency": - processed = np.stack([ - remove_grid_frequency(frame, grid_size=8) - for frame in video - ]) + processed = np.stack( + [remove_grid_frequency(frame, grid_size=8) for frame in video] + ) else: raise ValueError(f"Unknown method: {method}") @@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray: result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8) return result - - - diff --git a/mlx_video/models/ltx_2/rope.py b/mlx_video/models/ltx_2/rope.py index 21de1d4..2915b55 100644 --- a/mlx_video/models/ltx_2/rope.py +++ b/mlx_video/models/ltx_2/rope.py @@ -1,4 +1,3 @@ - import math from typing import List, Optional, Tuple @@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array: """ # x: (..., dim) where dim is even x_even = x[..., 0::2] # [x0, x2, x4, ...] - x_odd = x[..., 1::2] # [x1, x3, x5, ...] + x_odd = x[..., 1::2] # [x1, x3, x5, ...] # Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...] rotated = mx.stack([-x_odd, x_even], axis=-1) return mx.reshape(rotated, x.shape) + def apply_rotary_emb_1d( q: mx.array, k: mx.array, @@ -228,9 +228,9 @@ def get_fractional_positions( Fractional positions in range [-1, 1] after scaling """ n_pos_dims = indices_grid.shape[1] - assert n_pos_dims == len(max_pos), ( - f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" - ) + assert n_pos_dims == len( + max_pos + ), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" # Divide each dimension by its max position fractional_positions = [] @@ -392,11 +392,15 @@ def precompute_freqs_cis( if max_pos is None: max_pos = [20, 2048, 2048] - if double_precision: return _precompute_freqs_cis_double_precision( - indices_grid, dim, theta, max_pos, use_middle_indices_grid, - num_attention_heads, rope_type + indices_grid, + dim, + theta, + max_pos, + use_middle_indices_grid, + num_attention_heads, + rope_type, ) # Keep positions in float32 for RoPE computation. @@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision( # Compute frequencies: outer product # scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1) # freq_indices: (num_indices,) -> (1, 1, 1, num_indices) - freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1)) + freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape( + freq_indices, (1, 1, 1, -1) + ) # freqs: (B, T, n_dims, num_indices) # Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims) diff --git a/mlx_video/models/ltx_2/samplers.py b/mlx_video/models/ltx_2/samplers.py index 489780b..b97faa0 100644 --- a/mlx_video/models/ltx_2/samplers.py +++ b/mlx_video/models/ltx_2/samplers.py @@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation. """ import math -from typing import Optional import mlx.core as mx - # --------------------------------------------------------------------------- # Phi functions and RK coefficients (pure Python math, no MLX needed) # --------------------------------------------------------------------------- + def phi(j: int, neg_h: float) -> float: """Compute phi_j(z) where z = -h (negative step size in log-space). @@ -43,6 +42,7 @@ def get_res2s_coefficients( Returns: (a21, b1, b2): RK coefficients. """ + def get_phi(j: int, neg_h: float) -> float: cache_key = (j, neg_h) if cache_key in phi_cache: @@ -69,6 +69,7 @@ def get_res2s_coefficients( # SDE noise injection # --------------------------------------------------------------------------- + def get_sde_coeff( sigma_next: float, ) -> tuple[float, float, float]: @@ -139,7 +140,9 @@ def sde_noise_step( denoised_next = sample_f32 - sigma * eps_next # Mix deterministic and stochastic components - x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32 + x_noised = ( + alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32 + ) return x_noised @@ -148,6 +151,7 @@ def sde_noise_step( # Noise generation # --------------------------------------------------------------------------- + def channelwise_normalize(x: mx.array) -> mx.array: """Normalize each channel to zero mean and unit variance over spatial dims. diff --git a/mlx_video/models/ltx_2/text_encoder.py b/mlx_video/models/ltx_2/text_encoder.py index 4f14c8a..fbff7e1 100644 --- a/mlx_video/models/ltx_2/text_encoder.py +++ b/mlx_video/models/ltx_2/text_encoder.py @@ -1,25 +1,25 @@ - - import functools import logging import math import re -from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn -import numpy as np -from rich.console import Console -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn - -from mlx_video.utils import rms_norm, apply_quantization -from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb - -from mlx_vlm.models.gemma3.language import Gemma3Model from mlx_vlm.models.gemma3.config import TextConfig +from mlx_vlm.models.gemma3.language import Gemma3Model +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, +) +from mlx_video.utils import apply_quantization, rms_norm # Path to system prompts PROMPTS_DIR = Path(__file__).parent / "prompts" @@ -36,11 +36,10 @@ def _load_system_prompt(prompt_name: str) -> str: class LanguageModel(nn.Module): - def __init__(self, config: TextConfig): super().__init__() # Create config matching LTX-2 text encoder requirements - self.config = config + self.config = config # Create the Gemma3Model from mlx-vlm self.model = Gemma3Model(self.config) @@ -51,7 +50,7 @@ class LanguageModel(nn.Module): attention_mask: Optional[mx.array], dtype: mx.Dtype, ) -> mx.array: - + causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_)) if attention_mask is not None: @@ -59,15 +58,25 @@ class LanguageModel(nn.Module): padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len) combined = causal_mask[None, :, :] & padding_mask[:, None, :] - min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 - mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype), - mx.full(combined.shape, min_val, dtype=dtype)) + min_val = ( + mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 + ) + mask = mx.where( + combined, + mx.zeros(combined.shape, dtype=dtype), + mx.full(combined.shape, min_val, dtype=dtype), + ) return mask[:, None, :, :] else: # No padding mask, just causal - min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 - mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype), - mx.full((seq_len, seq_len), min_val, dtype=dtype)) + min_val = ( + mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 + ) + mask = mx.where( + causal_mask, + mx.zeros((seq_len, seq_len), dtype=dtype), + mx.full((seq_len, seq_len), min_val, dtype=dtype), + ) return mask[None, None, :, :] # (1, 1, seq, seq) def __call__( @@ -91,7 +100,11 @@ class LanguageModel(nn.Module): batch_size, seq_len = inputs.shape # Get embeddings - h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs) + h = ( + input_embeddings + if input_embeddings is not None + else self.model.embed_tokens(inputs) + ) # Apply Gemma scaling h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype) @@ -103,11 +116,12 @@ class LanguageModel(nn.Module): if cache is None: cache = [None] * len(self.model.layers) - full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype) + full_causal_mask = self._create_causal_mask_with_padding( + seq_len, attention_mask, h.dtype + ) sliding_mask = full_causal_mask - num_layers = len(self.model.layers) for i, layer in enumerate(self.model.layers): is_global = ( @@ -147,9 +161,9 @@ class LanguageModel(nn.Module): for key, value in weights.items(): if key.startswith(prefix): if hasattr(value, "dtype") and value.dtype == mx.float32: - sanitized[key[len(prefix):]] = value.astype(mx.bfloat16) + sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16) else: - sanitized[key[len(prefix):]] = value + sanitized[key[len(prefix) :]] = value return sanitized @property @@ -158,6 +172,7 @@ class LanguageModel(nn.Module): def make_cache(self): from mlx_vlm.models.cache import KVCache, RotatingKVCache + caches = [] for i in range(len(self.layers)): if ( @@ -172,6 +187,7 @@ class LanguageModel(nn.Module): @classmethod def from_pretrained(cls, model_path: str): import json + weight_files = sorted(Path(model_path).glob("*.safetensors")) config_file = Path(model_path) / "config.json" config_dict = {} @@ -179,7 +195,9 @@ class LanguageModel(nn.Module): with open(config_file, "r") as f: config_dict = json.load(f) - language_model = cls(config=TextConfig.from_dict(config_dict["text_config"])) + language_model = cls( + config=TextConfig.from_dict(config_dict["text_config"]) + ) else: raise ValueError(f"Config file not found at {model_path}") @@ -188,19 +206,18 @@ class LanguageModel(nn.Module): for i, wf in enumerate(weight_files): weights.update(mx.load(str(wf))) - if hasattr(language_model, "sanitize"): weights = language_model.sanitize(weights=weights) - - apply_quantization(model=language_model, weights=weights, quantization=quantization) + apply_quantization( + model=language_model, weights=weights, quantization=quantization + ) language_model.load_weights(list(weights.items()), strict=False) return language_model - class ConnectorAttention(nn.Module): def __init__( @@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module): k = self.k_norm(k) # Reshape to (B, H, T, D) for SPLIT RoPE - q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) - k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) - v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) + q = mx.reshape( + q, (batch_size, seq_len, self.num_heads, self.head_dim) + ).transpose(0, 2, 1, 3) + k = mx.reshape( + k, (batch_size, seq_len, self.num_heads, self.head_dim) + ).transpose(0, 2, 1, 3) + v = mx.reshape( + v, (batch_size, seq_len, self.num_heads, self.head_dim) + ).transpose(0, 2, 1, 3) if pe is not None: q = self._apply_split_rope(q, pe[0], pe[1]) @@ -304,7 +327,7 @@ class ConnectorAttention(nn.Module): out2 = x2 * cos_freq + x1 * sin_freq return mx.concatenate([out1, out2], axis=-1).astype(input_dtype) - + class GEGLU(nn.Module): """GELU-gated linear unit.""" @@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module): class ConnectorTransformerBlock(nn.Module): - def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False): + def __init__( + self, + dim: int = 3840, + num_heads: int = 30, + head_dim: int = 128, + has_gate_logits: bool = False, + ): super().__init__() - self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) + self.attn1 = ConnectorAttention( + dim, num_heads, head_dim, has_gate_logits=has_gate_logits + ) self.ff = ConnectorFeedForward(dim) def __call__( @@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module): self.positional_embedding_max_pos = positional_embedding_max_pos or [1] self.transformer_1d_blocks = { - i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) + i: ConnectorTransformerBlock( + dim, num_heads, head_dim, has_gate_logits=has_gate_logits + ) for i in range(num_layers) } if num_learnable_registers > 0: self.learnable_registers = mx.zeros((num_learnable_registers, dim)) - def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]: + def _precompute_freqs_cis( + self, seq_len: int, dtype: mx.Dtype + ) -> Tuple[mx.array, mx.array]: """Compute RoPE frequencies for connector (SPLIT type matching PyTorch). Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2). @@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module): # Binary mask: 1 for valid tokens, 0 for padded # attention_mask is additive: 0 for valid, large negative for padded - mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq) + mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype( + mx.int32 + ) # (batch, seq) # Tile registers to match sequence length, cast to hidden_states dtype num_tiles = seq_len // self.num_learnable_registers - registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim) + registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype( + dtype + ) # (seq_len, dim) # Process each batch item (PyTorch uses advanced indexing) result_list = [] @@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module): # Extract valid tokens (where mask is 1) # Since we have left-padded input, valid tokens are at the end - valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim) + valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim) # Pad with zeros on the right to get back to seq_len pad_length = seq_len - num_valid if pad_length > 0: padding = mx.zeros((pad_length, dim), dtype=dtype) - adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim) + adjusted = mx.concatenate( + [valid_tokens, padding], axis=0 + ) # (seq_len, dim) else: adjusted = valid_tokens # Create flipped mask: 1s at front (where valid tokens now are), 0s at back - flipped_mask = mx.concatenate([ - mx.ones((num_valid,), dtype=mx.int32), - mx.zeros((pad_length,), dtype=mx.int32) - ], axis=0) # (seq,) + flipped_mask = mx.concatenate( + [ + mx.ones((num_valid,), dtype=mx.int32), + mx.zeros((pad_length,), dtype=mx.int32), + ], + axis=0, + ) # (seq,) # Combine: valid tokens at front, registers at back flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1) - combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers + combined = ( + flipped_mask_expanded * adjusted + + (1 - flipped_mask_expanded) * registers + ) result_list.append(combined) hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim) @@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module): # Process through transformer blocks for i in range(len(self.transformer_1d_blocks)): - hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis) + hidden_states = self.transformer_1d_blocks[i]( + hidden_states, attention_mask, freqs_cis + ) # Final RMS norm hidden_states = rms_norm(hidden_states) @@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module): return hidden_states, attention_mask - def norm_and_concat_hidden_states( hidden_states: List[mx.array], attention_mask: mx.array, @@ -567,8 +615,12 @@ def norm_and_concat_hidden_states( mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) # Compute masked min/max per layer - x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype)) - x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype)) + x_for_min = mx.where( + mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype) + ) + x_for_max = mx.where( + mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype) + ) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) range_val = x_max - x_min @@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms( dtype = encoded_text.dtype # Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D - variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L) + variance = mx.mean( + encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True + ) # (B, T, 1, L) normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6) normed = normed.astype(dtype) @@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array: class GemmaFeaturesExtractor(nn.Module): """V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization.""" - def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False): + def __init__( + self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False + ): super().__init__() self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias) @@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module): if mode == "video": target_dim = self.video_aggregate_embed.weight.shape[0] - return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) + return self.video_aggregate_embed( + _rescale_norm(normed, target_dim, self.embedding_dim) + ) else: target_dim = self.audio_aggregate_embed.weight.shape[0] - return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) - - - + return self.audio_aggregate_embed( + _rescale_norm(normed, target_dim, self.embedding_dim) + ) class AudioEmbeddingsConnector(nn.Module): @@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module): video_output_dim = 4096 audio_output_dim = 2048 self.feature_extractor_v2 = GemmaFeaturesExtractorV2( - flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated) - embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale) + flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated) + embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale) video_output_dim=video_output_dim, audio_output_dim=audio_output_dim, bias=True, @@ -728,37 +785,57 @@ class LTX2TextEncoder(nn.Module): # connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors # config (nested under config.transformer.connector_positional_embedding_max_pos) self.video_embeddings_connector = Embeddings1DConnector( - dim=video_output_dim, num_heads=32, head_dim=128, - num_layers=8, num_learnable_registers=128, - positional_embedding_max_pos=[4096], has_gate_logits=True, + dim=video_output_dim, + num_heads=32, + head_dim=128, + num_layers=8, + num_learnable_registers=128, + positional_embedding_max_pos=[4096], + has_gate_logits=True, ) self.audio_embeddings_connector = Embeddings1DConnector( - dim=audio_output_dim, num_heads=32, head_dim=64, - num_layers=8, num_learnable_registers=128, - positional_embedding_max_pos=[4096], has_gate_logits=True, + dim=audio_output_dim, + num_heads=32, + head_dim=64, + num_layers=8, + num_learnable_registers=128, + positional_embedding_max_pos=[4096], + has_gate_logits=True, ) else: # LTX-2: shared feature extractor, 3840-dim connectors - self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim) + self.feature_extractor = GemmaFeaturesExtractor( + feature_input_dim, hidden_dim + ) self.video_embeddings_connector = Embeddings1DConnector( - dim=hidden_dim, num_heads=30, head_dim=128, - num_layers=2, num_learnable_registers=128, + dim=hidden_dim, + num_heads=30, + head_dim=128, + num_layers=2, + num_learnable_registers=128, positional_embedding_max_pos=[1], ) self.audio_embeddings_connector = Embeddings1DConnector( - dim=hidden_dim, num_heads=30, head_dim=128, - num_layers=2, num_learnable_registers=128, + dim=hidden_dim, + num_heads=30, + head_dim=128, + num_layers=2, + num_learnable_registers=128, positional_embedding_max_pos=[1], ) self.processor = None - def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"): + def load( + self, + model_path: Optional[str] = None, + text_encoder_path: Optional[str] = "google/gemma-3-12b-it", + ): if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir(): text_encoder_path = str(Path(text_encoder_path) / "text_encoder") - + self.language_model = LanguageModel.from_pretrained(text_encoder_path) # Load transformer weights for feature extractor and connector. @@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module): if transformer_weights: self._load_feature_extractors(transformer_weights, is_reformatted) - self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted) - self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted) + self._load_connector( + "video_embeddings_connector", transformer_weights, is_reformatted + ) + self._load_connector( + "audio_embeddings_connector", transformer_weights, is_reformatted + ) else: - print("WARNING: No transformer weights found for text projection connectors. " - "Text conditioning will use uninitialized weights!") + print( + "WARNING: No transformer weights found for text projection connectors. " + "Text conditioning will use uninitialized weights!" + ) # Load tokenizer from transformers import AutoTokenizer + tokenizer_path = model_path / "tokenizer" if tokenizer_path.exists(): - self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True) + self.processor = AutoTokenizer.from_pretrained( + str(tokenizer_path), trust_remote_code=True + ) else: try: - self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True) + self.processor = AutoTokenizer.from_pretrained( + text_encoder_path, trust_remote_code=True + ) except Exception: - self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True) + self.processor = AutoTokenizer.from_pretrained( + "google/gemma-3-12b-it", trust_remote_code=True + ) # Set left padding to match official LTX-2 text encoder self.processor.padding_side = "left" @@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module): submodule.bias = weights[b_key] else: # LTX-2: single aggregate_embed - agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" + agg_key = ( + "aggregate_embed.weight" + if is_reformatted + else "text_embedding_projection.aggregate_embed.weight" + ) if agg_key in weights: self.feature_extractor.aggregate_embed.weight = weights[agg_key] @@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module): prefix = f"{name}." for key, value in weights.items(): if key.startswith(prefix): - connector_weights[key[len(prefix):]] = value + connector_weights[key[len(prefix) :]] = value else: mono_prefix = f"model.diffusion_model.{name}." for key, value in weights.items(): if key.startswith(mono_prefix): - connector_weights[key[len(mono_prefix):]] = value + connector_weights[key[len(mono_prefix) :]] = value if not connector_weights: return @@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module): input_ids = mx.array(inputs["input_ids"]) attention_mask = mx.array(inputs["attention_mask"]) - _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) + _, all_hidden_states = self.language_model( + inputs=input_ids, + input_embeddings=None, + attention_mask=attention_mask, + output_hidden_states=True, + ) if self.has_prompt_adaln: # LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale) - video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video") + video_features = self.feature_extractor_v2( + all_hidden_states, attention_mask, mode="video" + ) additive_mask = (attention_mask - 1).astype(video_features.dtype) - additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + additive_mask = ( + additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + ) - video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) + video_embeddings, _ = self.video_embeddings_connector( + video_features, additive_mask + ) if return_audio_embeddings: - audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio") + audio_features = self.feature_extractor_v2( + all_hidden_states, attention_mask, mode="audio" + ) audio_mask = (attention_mask - 1).astype(audio_features.dtype) audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 - audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask) + audio_embeddings, _ = self.audio_embeddings_connector( + audio_features, audio_mask + ) return video_embeddings, audio_embeddings else: return video_embeddings, attention_mask @@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module): video_features = self.feature_extractor(concat_hidden) additive_mask = (attention_mask - 1).astype(video_features.dtype) - additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + additive_mask = ( + additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + ) - video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) + video_embeddings, _ = self.video_embeddings_connector( + video_features, additive_mask + ) if return_audio_embeddings: - audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask) + audio_embeddings, _ = self.audio_embeddings_connector( + video_features, additive_mask + ) return video_embeddings, audio_embeddings else: return video_embeddings, attention_mask @@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module): # Remove leading/trailing whitespace response = response.strip() # Remove any leading punctuation - response = re.sub(r'^[^\w\s]+', '', response) + response = re.sub(r"^[^\w\s]+", "", response) return response def _apply_chat_template( @@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module): elif isinstance(content, list): # Handle multimodal content (image + text) text_parts = [c["text"] for c in content if c.get("type") == "text"] - formatted += f"user\n{' '.join(text_parts)}\n" + formatted += ( + f"user\n{' '.join(text_parts)}\n" + ) elif role == "assistant": formatted += f"model\n{content}\n" # Add generation prompt @@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module): from mlx_lm import stream_generate from mlx_lm.sample_utils import make_logits_processors, make_sampler except ImportError: - logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.") + logging.warning( + "mlx-lm not available for prompt enhancement. Using original prompt." + ) return prompt if self.processor is None: @@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module): ) input_ids = mx.array(inputs["input_ids"]) - sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1)) + sampler = make_sampler( + kwargs.get("temperature", 0.7), + kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + ) logits_processors = make_logits_processors( kwargs.get("logit_bias", None), kwargs.get("repetition_penalty", 1.3), @@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module): mx.clear_cache() # Decode only the new tokens - enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True) + enhanced_prompt = self.processor.decode( + generated_tokens, skip_special_tokens=True + ) enhanced_prompt = self._clean_response(enhanced_prompt) logging.info(f"Enhanced prompt: {enhanced_prompt}") return enhanced_prompt - def enhance_i2v( self, prompt: str, @@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder: encoder = LTX2TextEncoder() encoder.load(model_path=model_path) return encoder - diff --git a/mlx_video/models/ltx_2/text_projection.py b/mlx_video/models/ltx_2/text_projection.py index 55e684d..29165ca 100644 --- a/mlx_video/models/ltx_2/text_projection.py +++ b/mlx_video/models/ltx_2/text_projection.py @@ -11,7 +11,7 @@ class PixArtAlphaTextProjection(nn.Module): out_features: int | None = None, bias: bool = True, ): - + super().__init__() out_features = out_features or hidden_size diff --git a/mlx_video/models/ltx_2/transformer.py b/mlx_video/models/ltx_2/transformer.py index 2144acf..2f2c914 100644 --- a/mlx_video/models/ltx_2/transformer.py +++ b/mlx_video/models/ltx_2/transformer.py @@ -4,8 +4,8 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig from mlx_video.models.ltx_2.attention import Attention +from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig from mlx_video.models.ltx_2.feed_forward import FeedForward from mlx_video.utils import rms_norm @@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module): # timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim) timestep_reshaped = mx.reshape( - timestep, - (batch_size, timestep.shape[1], num_ada_params, -1) + timestep, (batch_size, timestep.shape[1], num_ada_params, -1) ) # Extract the relevant indices @@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module): ) # Squeeze the sequence dimension if it's 1 - scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada) - gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada) + scale_shift_squeezed = tuple( + mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada + ) + gate_squeezed = tuple( + mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada + ) return (*scale_shift_squeezed, *gate_squeezed) @@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module): # Check which modalities to run run_vx = video is not None and video.enabled and vx.size > 0 run_ax = audio is not None and audio.enabled and ax.size > 0 - run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal - run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal + run_a2v = ( + run_vx + and (audio is not None and audio.enabled and ax.size > 0) + and not skip_cross_modal + ) + run_v2a = ( + run_ax + and (video is not None and video.enabled and vx.size > 0) + and not skip_cross_modal + ) # Process video self-attention and cross-attention with text if run_vx: @@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module): # Self-attention with RoPE (skip_attention=True for STG perturbation) norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa - vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa + vx = ( + vx + + self.attn1( + norm_vx, + pe=video.positional_embeddings, + skip_attention=skip_video_self_attn, + ) + * vgate_msa + ) # Cross-attention with text context if self.has_prompt_adaln: @@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module): self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9) ) vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values( - self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2) + self.prompt_scale_shift_table, + vx.shape[0], + video.prompt_timesteps, + slice(0, 2), ) attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q - encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv - vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q + encoder_hidden_states = ( + video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv + ) + vx = ( + vx + + self.attn2( + attn_input, + context=encoder_hidden_states, + mask=video.context_mask, + ) + * vgate_q + ) else: vx = vx + self.attn2( rms_norm(vx, eps=self.norm_eps), @@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module): # Self-attention with RoPE (skip_attention=True for STG perturbation) norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa - ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa + ax = ( + ax + + self.audio_attn1( + norm_ax, + pe=audio.positional_embeddings, + skip_attention=skip_audio_self_attn, + ) + * agate_msa + ) # Cross-attention with text context if self.has_prompt_adaln: # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln ashift_q, ascale_q, agate_q = self.get_ada_values( - self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9) + self.audio_scale_shift_table, + ax.shape[0], + audio.timesteps, + slice(6, 9), ) aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values( - self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2) + self.audio_prompt_scale_shift_table, + ax.shape[0], + audio.prompt_timesteps, + slice(0, 2), + ) + attn_input_a = ( + rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q + ) + encoder_hidden_states_a = ( + audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv + ) + ax = ( + ax + + self.audio_attn2( + attn_input_a, + context=encoder_hidden_states_a, + mask=audio.context_mask, + ) + * agate_q ) - attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q - encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv - ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q else: ax = ax + self.audio_attn2( rms_norm(ax, eps=self.norm_eps), diff --git a/mlx_video/models/ltx_2/upsampler.py b/mlx_video/models/ltx_2/upsampler.py index 1056687..8ea8cd1 100644 --- a/mlx_video/models/ltx_2/upsampler.py +++ b/mlx_video/models/ltx_2/upsampler.py @@ -1,4 +1,5 @@ from typing import Tuple, Union + import mlx.core as mx import mlx.nn as nn @@ -36,11 +37,20 @@ class Conv3d(nn.Module): self.groups = groups # Weight shape: (C_out, KD, KH, KW, C_in) - scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5 + scale = ( + 1.0 + / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5 + ) self.weight = mx.random.uniform( low=-scale, high=scale, - shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels), + shape=( + out_channels, + kernel_size[0], + kernel_size[1], + kernel_size[2], + in_channels, + ), ) if bias: @@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module): n, d, h, w, c = x.shape input_dtype = x.dtype - x = x.astype(mx.float32) # Reshape to (N, D*H*W, num_groups, C//num_groups) @@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module): self.den = den # Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num) - self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1) + self.conv = nn.Conv2d( + mid_channels, num * num * mid_channels, kernel_size=3, padding=1 + ) self.pixel_shuffle = PixelShuffle2D(num, num) self.blur_down = BlurDownsample(stride=den) @@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module): x = self.conv(x) x = self.pixel_shuffle(x) # H*num, W*num - x = self.blur_down(x) # H*num/den, W*num/den + x = self.blur_down(x) # H*num/den, W*num/den _, h_out, w_out, _ = x.shape x = mx.reshape(x, (n, d, h_out, w_out, c)) @@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module): def _rational_for_scale(scale: float) -> Tuple[int, int]: """Convert a float scale to a rational fraction (numerator, denominator).""" from fractions import Fraction + frac = Fraction(scale).limit_denominator(10) return frac.numerator, frac.denominator @@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module): self.initial_norm = GroupNorm3d(32, mid_channels) # Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking - self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)} + self.res_blocks = { + i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage) + } # Upsampler: 2D spatial upsampling (frame-by-frame) if rational_resampler: - self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale) + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=spatial_scale + ) else: self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels) # Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking - self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)} + self.post_upsample_res_blocks = { + i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage) + } # Final projection self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1) @@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module): Returns: Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first """ + def debug_stats(name, t): if debug: mx.eval(t) - print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}") + print( + f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}" + ) if debug: print(" [DEBUG] LatentUpsampler forward pass:") @@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]: # x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2)) # x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample # Both formats may have upsampler.blur_down.kernel, so use channel count - conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight" + conv_key = ( + "upsampler.conv.weight" + if "upsampler.conv.weight" in raw_weights + else "upsampler.0.weight" + ) if conv_key in raw_weights: out_channels = raw_weights[conv_key].shape[0] ratio = out_channels // mid_channels @@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]: rational_resampler = False spatial_scale = 2.0 - print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}") + print( + f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}" + ) # Create model upsampler = LatentUpsampler( diff --git a/mlx_video/models/ltx_2/utils.py b/mlx_video/models/ltx_2/utils.py index a603539..5f70378 100644 --- a/mlx_video/models/ltx_2/utils.py +++ b/mlx_video/models/ltx_2/utils.py @@ -109,6 +109,7 @@ def convert_audio_encoder( return encoder_dir from huggingface_hub import hf_hub_download + vae_path = hf_hub_download( source_repo, "audio_vae/diffusion_pytorch_model.safetensors", diff --git a/mlx_video/models/ltx_2/video_vae/__init__.py b/mlx_video/models/ltx_2/video_vae/__init__.py index c154eea..fa19c4b 100644 --- a/mlx_video/models/ltx_2/video_vae/__init__.py +++ b/mlx_video/models/ltx_2/video_vae/__init__.py @@ -1,8 +1,8 @@ -from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder -from mlx_video.models.ltx_2.video_vae.encoder import encode_image from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder +from mlx_video.models.ltx_2.video_vae.encoder import encode_image from mlx_video.models.ltx_2.video_vae.tiling import ( - TilingConfig, SpatialTilingConfig, TemporalTilingConfig, + TilingConfig, ) +from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder diff --git a/mlx_video/models/ltx_2/video_vae/convolution.py b/mlx_video/models/ltx_2/video_vae/convolution.py index 4fe089d..db45568 100644 --- a/mlx_video/models/ltx_2/video_vae/convolution.py +++ b/mlx_video/models/ltx_2/video_vae/convolution.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array: # Height padding (axis 2) if pad_h > 0: # Get reflection indices - exclude boundary - top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion - bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion + top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion + bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][ + :, :, ::-1, :, : + ] # Flip bottom portion x = mx.concatenate([top_pad, x, bottom_pad], axis=2) # Width padding (axis 3) if pad_w > 0: - left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion - right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion + left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion + right_pad = x[:, :, :, -pad_w - 1 : -1, :][ + :, :, :, ::-1, : + ] # Flip right portion x = mx.concatenate([left_pad, x, right_pad], axis=3) return x @@ -50,7 +54,7 @@ def make_conv_nd( causal: bool = False, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ) -> nn.Module: - + if dims == 2: return CausalConv2d( in_channels=in_channels, @@ -118,15 +122,17 @@ class CausalConv3d(nn.Module): ) def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array: - + use_causal = causal if causal is not None else self.causal - # Apply temporal padding via frame replication + # Apply temporal padding via frame replication # Only apply if kernel_size > 1 if self.time_kernel_size > 1: if use_causal: # Causal: replicate first frame kernel_size-1 times at the beginning - first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2) + first_frame_pad = mx.repeat( + x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2 + ) x = mx.concatenate([first_frame_pad, x], axis=2) else: # Non-causal: replicate first frame at start, last frame at end @@ -176,7 +182,6 @@ class CausalConv3d(nn.Module): """ b, d, h, w, c = x.shape - total_elements = d * h * w * c max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk @@ -191,11 +196,10 @@ class CausalConv3d(nn.Module): overlap = kernel_t - 1 - expected_output_frames = d - overlap outputs = [] - out_idx = 0 + out_idx = 0 # Process chunks in_start = 0 diff --git a/mlx_video/models/ltx_2/video_vae/decoder.py b/mlx_video/models/ltx_2/video_vae/decoder.py index 0da4a61..7d1b8e3 100644 --- a/mlx_video/models/ltx_2/video_vae/decoder.py +++ b/mlx_video/models/ltx_2/video_vae/decoder.py @@ -15,14 +15,14 @@ Architecture (from PyTorch weights): """ import math -from typing import Optional, Dict from pathlib import Path +from typing import Dict, Optional import mlx.core as mx import mlx.nn as nn from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType -from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics +from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling @@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module): def __init__(self, embedding_dim: int): super().__init__() self.timestep_embedder = TimestepEmbedding( - in_channels=256, - time_embed_dim=embedding_dim + in_channels=256, time_embed_dim=embedding_dim ) - def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array: + def __call__( + self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32 + ) -> mx.array: timesteps_proj = get_timestep_embedding( - timestep, - embedding_dim=256, - flip_sin_to_cos=True, - downscale_freq_shift=0 + timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0 ) timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) return timesteps_emb @@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module): def _make_conv_wrapper(self, in_ch, out_ch, padding_mode): """Create a wrapper object with a 'conv' attribute to match PyTorch naming.""" + class ConvWrapper(nn.Module): def __init__(self_inner): super().__init__() @@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module): padding=1, spatial_padding_mode=padding_mode, ) + def __call__(self_inner, x, causal=False): return self_inner.conv(x, causal=causal) + return ConvWrapper() def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: """Apply pixel normalization.""" - return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps) + return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps) def __call__( self, @@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module): if self.timestep_conditioning and timestep_embed is not None: # scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1) # Combine table with timestep embedding - ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1) + ada_values = self.scale_shift_table[ + None, :, :, None, None, None + ] # (1, 4, C, 1, 1, 1) # Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1) channels = self.scale_shift_table.shape[1] ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1) @@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module): # Time embedder for this block group: embed_dim = 4 * channels if timestep_conditioning: - self.time_embedder = PixArtAlphaTimestepEmbedder( - embedding_dim=channels * 4 - ) + self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4) # Use dict with int keys for MLX to track parameters properly self.res_blocks = { i: ResnetBlock3DSimple( channels, spatial_padding_mode, - timestep_conditioning=timestep_conditioning + timestep_conditioning=timestep_conditioning, ) for i in range(num_layers) } @@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module): if self.timestep_conditioning and timestep is not None: batch_size = x.shape[0] timestep_embed = self.time_embedder( - timestep.flatten(), - hidden_dtype=x.dtype + timestep.flatten(), hidden_dtype=x.dtype ) # Reshape to (B, 4*C, 1, 1, 1) for broadcasting timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1) @@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module): padding=1, spatial_padding_mode=spatial_padding_mode, ) + def __call__(self_inner, x, causal=False): return self_inner.conv(x, causal=causal) + self.conv_in = ConvInWrapper() # Build up blocks from config @@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module): block_type = block_def[0] ch = block_def[1] if block_type == "res": - num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block - self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning) + num_layers = ( + block_def[2] if len(block_def) > 2 else num_layers_per_block + ) + self.up_blocks[idx] = ResBlockGroup( + ch, num_layers, spatial_padding_mode, timestep_conditioning + ) elif block_type == "d2s": reduction = block_def[2] if len(block_def) > 2 else 2 stride = block_def[3] if len(block_def) > 3 else (2, 2, 2) @@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module): ) final_out_channels = out_channels * patch_size * patch_size + class ConvOutWrapper(nn.Module): def __init__(self_inner): super().__init__() @@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module): padding=1, spatial_padding_mode=spatial_padding_mode, ) + def __call__(self_inner, x, causal=False): return self_inner.conv(x, causal=causal) + self.conv_out = ConvOutWrapper() self.act = nn.SiLU() @@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module): return weights for key, value in weights.items(): new_key = key - + if not key.startswith("vae.") or key.startswith("vae.encoder."): continue @@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module): if key.startswith("vae.decoder."): new_key = key.replace("vae.decoder.", "") - # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I) if ".conv.weight" in key and value.ndim == 5: value = mx.transpose(value, (0, 2, 3, 4, 1)) @@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module): if ".conv.weight" in new_key or ".conv.bias" in new_key: - if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key: + if ( + ".conv.conv.weight" not in new_key + and ".conv.conv.bias" not in new_key + ): new_key = new_key.replace(".conv.weight", ".conv.conv.weight") new_key = new_key.replace(".conv.bias", ".conv.conv.bias") @@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module): return sanitized @classmethod - def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder": + def from_pretrained( + cls, model_path: Path, strict: bool = True + ) -> "LTX2VideoDecoder": """Load a pretrained decoder from a directory with config.json and weights. Args: @@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module): for wf in weight_files: weights.update(mx.load(str(wf))) - # Infer block structure from weights decoder_blocks = cls._infer_blocks(weights) @@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module): return final_blocks - - def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: """Apply pixel normalization.""" - return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps) + return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps) def __call__( self, @@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module): debug: bool = False, chunked_conv: bool = False, ) -> mx.array: - batch_size = sample.shape[0] - - # Add noise if timestep conditioning is enabled if self.timestep_conditioning: noise = mx.random.normal(sample.shape) * self.decode_noise_scale sample = noise + (1.0 - self.decode_noise_scale) * sample - sample = self.per_channel_statistics.un_normalize(sample) - if timestep is None and self.timestep_conditioning: timestep = mx.full((batch_size,), self.decode_timestep) @@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module): scaled_timestep = timestep * self.timestep_scale_multiplier x = self.conv_in(sample, causal=causal) - for i, block in self.up_blocks.items(): if isinstance(block, ResBlockGroup): @@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module): x = block(x, causal=causal, chunked_conv=chunked_conv) else: x = block(x, causal=causal) - x = self.pixel_norm(x) - if self.timestep_conditioning and scaled_timestep is not None: embedded_timestep = self.last_time_embedder( - scaled_timestep.flatten(), - hidden_dtype=x.dtype + scaled_timestep.flatten(), hidden_dtype=x.dtype ) embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1) - ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1) + ada_values = self.last_scale_shift_table[ + None, :, :, None, None, None + ] # (1, 2, 128, 1, 1, 1) ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1) ada_values = ada_values + ts_reshaped @@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module): scale = ada_values[:, 1] x = x * (1 + scale) + shift - x = self.act(x) - x = self.conv_out(x, causal=causal) - + # Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4) x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1) - return x @@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module): # Auto-enable chunked conv for modes where it helps (larger tiles) # Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks - use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial") + use_chunked_conv = tiling_mode in ( + "conservative", + "none", + "auto", + "default", + "spatial", + ) if not needs_spatial_tiling and not needs_temporal_tiling: # No tiling needed, use regular decode - return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv) + return self( + sample, + causal=causal, + timestep=timestep, + debug=debug, + chunked_conv=use_chunked_conv, + ) return decode_with_tiling( decoder_fn=self, diff --git a/mlx_video/models/ltx_2/video_vae/encoder.py b/mlx_video/models/ltx_2/video_vae/encoder.py index a605da0..2a29458 100644 --- a/mlx_video/models/ltx_2/video_vae/encoder.py +++ b/mlx_video/models/ltx_2/video_vae/encoder.py @@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation. """ import mlx.core as mx -from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder +from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder def encode_image( diff --git a/mlx_video/models/ltx_2/video_vae/ops.py b/mlx_video/models/ltx_2/video_vae/ops.py index d730d2f..a03b643 100644 --- a/mlx_video/models/ltx_2/video_vae/ops.py +++ b/mlx_video/models/ltx_2/video_vae/ops.py @@ -1,6 +1,5 @@ """Operations for Video VAE.""" -from typing import List, Tuple import mlx.core as mx import mlx.nn as nn @@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a new_c = c * patch_size_hw * patch_size_hw * patch_size_t # Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw) - x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)) + x = mx.reshape( + x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw) + ) # Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W') # PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph @@ -101,7 +102,7 @@ class PerChannelStatistics(nn.Module): Normalized tensor """ # Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1) - dtype = x.dtype + dtype = x.dtype # Cast to float32 for precision mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1) @@ -117,7 +118,7 @@ class PerChannelStatistics(nn.Module): Returns: Denormalized tensor """ - dtype = x.dtype + dtype = x.dtype # Cast to float32 for precision mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1) diff --git a/mlx_video/models/ltx_2/video_vae/resnet.py b/mlx_video/models/ltx_2/video_vae/resnet.py index 686636d..0bea4d3 100644 --- a/mlx_video/models/ltx_2/video_vae/resnet.py +++ b/mlx_video/models/ltx_2/video_vae/resnet.py @@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module): timestep_conditioning: bool = False, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ): - + super().__init__() out_channels = out_channels or in_channels @@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module): causal: bool = True, generator: Optional[int] = None, ) -> mx.array: - + residual = x # First block @@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module): attention_head_dim: Optional[int] = None, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ): - + super().__init__() self.num_layers = num_layers diff --git a/mlx_video/models/ltx_2/video_vae/sampling.py b/mlx_video/models/ltx_2/video_vae/sampling.py index 034c5a6..7e351ba 100644 --- a/mlx_video/models/ltx_2/video_vae/sampling.py +++ b/mlx_video/models/ltx_2/video_vae/sampling.py @@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module): class DepthToSpaceUpsample(nn.Module): - + def __init__( self, dims: int, @@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module): out_channels_reduction_factor: int = 1, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ): - + super().__init__() if isinstance(stride, int): @@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module): return x - def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array: + def __call__( + self, x: mx.array, causal: bool = True, chunked_conv: bool = False + ) -> mx.array: b, c, d, h, w = x.shape st, sh, sw = self.stride @@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module): return x - def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array: + def _chunked_conv_depth_to_space( + self, x: mx.array, causal: bool = True + ) -> mx.array: """Chunked conv + depth_to_space that processes in temporal chunks. This reduces peak memory by avoiding the full high-channel intermediate tensor. diff --git a/mlx_video/models/ltx_2/video_vae/tiling.py b/mlx_video/models/ltx_2/video_vae/tiling.py index ad4c442..75ec47d 100644 --- a/mlx_video/models/ltx_2/video_vae/tiling.py +++ b/mlx_video/models/ltx_2/video_vae/tiling.py @@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d( # Apply right ramp (fade out) if ramp_right > 0: # Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1] - fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)] + fade_out = [ + (ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1) + ] for i in range(ramp_right): mask[length - ramp_right + i] *= fade_out[i] @@ -71,11 +73,17 @@ class SpatialTilingConfig: def __post_init__(self) -> None: if self.tile_size_in_pixels < 64: - raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + raise ValueError( + f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}" + ) if self.tile_size_in_pixels % 32 != 0: - raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + raise ValueError( + f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}" + ) if self.tile_overlap_in_pixels % 32 != 0: - raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + raise ValueError( + f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}" + ) if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: raise ValueError( f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" @@ -91,11 +99,17 @@ class TemporalTilingConfig: def __post_init__(self) -> None: if self.tile_size_in_frames < 16: - raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + raise ValueError( + f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}" + ) if self.tile_size_in_frames % 8 != 0: - raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + raise ValueError( + f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}" + ) if self.tile_overlap_in_frames % 8 != 0: - raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + raise ValueError( + f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}" + ) if self.tile_overlap_in_frames >= self.tile_size_in_frames: raise ValueError( f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" @@ -113,15 +127,21 @@ class TilingConfig: def default(cls) -> "TilingConfig": """Default tiling: 512px spatial, 64 frame temporal.""" return cls( - spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), - temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=512, tile_overlap_in_pixels=64 + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=64, tile_overlap_in_frames=24 + ), ) @classmethod def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig": """Spatial tiling only (for short videos with large resolution).""" return cls( - spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap), + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap + ), temporal_config=None, ) @@ -130,23 +150,33 @@ class TilingConfig: """Temporal tiling only (for long videos with small resolution).""" return cls( spatial_config=None, - temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap + ), ) @classmethod def aggressive(cls) -> "TilingConfig": """Aggressive tiling for very large videos (smaller tiles, much lower memory).""" return cls( - spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64), - temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8), + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=256, tile_overlap_in_pixels=64 + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=32, tile_overlap_in_frames=8 + ), ) @classmethod def conservative(cls) -> "TilingConfig": """Conservative tiling (larger tiles, less memory savings but faster).""" return cls( - spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64), - temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24), + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=768, tile_overlap_in_pixels=64 + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=96, tile_overlap_in_frames=24 + ), ) @classmethod @@ -186,10 +216,14 @@ class TilingConfig: temporal_config = None if needs_spatial: - spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64) + spatial_config = SpatialTilingConfig( + tile_size_in_pixels=512, tile_overlap_in_pixels=64 + ) if needs_temporal: - temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24) + temporal_config = TemporalTilingConfig( + tile_size_in_frames=64, tile_overlap_in_frames=24 + ) return cls(spatial_config=spatial_config, temporal_config=temporal_config) @@ -197,16 +231,21 @@ class TilingConfig: @dataclass class DimensionIntervals: """Intervals for splitting a single dimension.""" + starts: List[int] ends: List[int] left_ramps: List[int] right_ramps: List[int] -def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: +def split_in_spatial( + size: int, overlap: int, dimension_size: int +) -> DimensionIntervals: """Split a spatial dimension into intervals.""" if dimension_size <= size: - return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) + return DimensionIntervals( + starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0] + ) amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) starts = [i * (size - overlap) for i in range(amount)] @@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI left_ramps = [0] + [overlap] * (amount - 1) right_ramps = [overlap] * (amount - 1) + [0] - return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + return DimensionIntervals( + starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps + ) -def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: +def split_in_temporal( + size: int, overlap: int, dimension_size: int +) -> DimensionIntervals: """Split a temporal dimension into intervals with causal adjustment.""" if dimension_size <= size: - return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) + return DimensionIntervals( + starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0] + ) # Start with spatial split intervals = split_in_spatial(size, overlap, dimension_size) @@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension starts[i] = starts[i] - 1 left_ramps[i] = left_ramps[i] + 1 - return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps) + return DimensionIntervals( + starts=starts, + ends=intervals.ends, + left_ramps=left_ramps, + right_ramps=intervals.right_ramps, + ) -def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: +def map_temporal_slice( + begin: int, end: int, left_ramp: int, right_ramp: int, scale: int +) -> Tuple[slice, mx.array]: """Map temporal latent interval to output coordinates and mask.""" start = begin * scale stop = 1 + (end - 1) * scale left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0 right_ramp_scaled = right_ramp * scale - mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True) + mask = compute_trapezoidal_mask_1d( + stop - start, left_ramp_scaled, right_ramp_scaled, True + ) return slice(start, stop), mask -def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: +def map_spatial_slice( + begin: int, end: int, left_ramp: int, right_ramp: int, scale: int +) -> Tuple[slice, mx.array]: """Map spatial latent interval to output coordinates and mask.""" start = begin * scale stop = end * scale left_ramp_scaled = left_ramp * scale right_ramp_scaled = right_ramp * scale - mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False) + mask = compute_trapezoidal_mask_1d( + stop - start, left_ramp_scaled, right_ramp_scaled, False + ) return slice(start, stop), mask @@ -315,7 +373,9 @@ def decode_with_tiling( temporal_overlap = 0 # Compute intervals for each dimension - temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + temporal_intervals = split_in_temporal( + temporal_tile_size, temporal_overlap, f_latent + ) height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) @@ -338,7 +398,9 @@ def decode_with_tiling( t_right = temporal_intervals.right_ramps[t_idx] # Map temporal coordinates - out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + out_t_slice, t_mask = map_temporal_slice( + t_start, t_end, t_left, t_right, temporal_scale + ) for h_idx in range(num_h_tiles): h_start = height_intervals.starts[h_idx] @@ -347,7 +409,9 @@ def decode_with_tiling( h_right = height_intervals.right_ramps[h_idx] # Map height coordinates - out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) + out_h_slice, h_mask = map_spatial_slice( + h_start, h_end, h_left, h_right, spatial_scale + ) for w_idx in range(num_w_tiles): w_start = width_intervals.starts[w_idx] @@ -356,13 +420,23 @@ def decode_with_tiling( w_right = width_intervals.right_ramps[w_idx] # Map width coordinates - out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) + out_w_slice, w_mask = map_spatial_slice( + w_start, w_end, w_left, w_right, spatial_scale + ) # Extract tile latents (small slice) - tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] + tile_latents = latents[ + :, :, t_start:t_end, h_start:h_end, w_start:w_end + ] # Decode tile - tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) + tile_output = decoder_fn( + tile_latents, + causal=causal, + timestep=timestep, + debug=False, + chunked_conv=chunked_conv, + ) mx.eval(tile_output) # Clear tile_latents reference @@ -385,13 +459,15 @@ def decode_with_tiling( w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask blend_mask = ( - t_mask_slice.reshape(1, 1, -1, 1, 1) * - h_mask_slice.reshape(1, 1, 1, -1, 1) * - w_mask_slice.reshape(1, 1, 1, 1, -1) + t_mask_slice.reshape(1, 1, -1, 1, 1) + * h_mask_slice.reshape(1, 1, 1, -1, 1) + * w_mask_slice.reshape(1, 1, 1, 1, -1) ) # Slice tile output to match - tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) + tile_output_slice = tile_output[ + :, :, :actual_t, :actual_h, :actual_w + ].astype(mx.float32) # Clear full tile_output del tile_output @@ -409,11 +485,37 @@ def decode_with_tiling( weighted_tile = tile_output_slice * blend_mask # Update output using slice assignment - output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( - output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile + output[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] = ( + output[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] + + weighted_tile ) - weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( - weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask + weights[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] = ( + weights[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] + + blend_mask ) # Force evaluation to free memory @@ -445,10 +547,12 @@ def decode_with_tiling( if next_tile_start_latent == 0: next_tile_start_out = 0 else: - next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale + next_tile_start_out = ( + 1 + (next_tile_start_latent - 1) * temporal_scale + ) # We need to track how many frames we've already emitted - if not hasattr(decode_with_tiling, '_emitted_frames'): + if not hasattr(decode_with_tiling, "_emitted_frames"): decode_with_tiling._emitted_frames = 0 emitted = decode_with_tiling._emitted_frames @@ -456,7 +560,10 @@ def decode_with_tiling( # Normalize and emit frames [emitted, next_tile_start_out) finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] finalized_weights = mx.maximum(finalized_weights, 1e-8) - finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights + finalized_output = ( + output[:, :, emitted:next_tile_start_out, :, :] + / finalized_weights + ) finalized_output = finalized_output.astype(latents.dtype) mx.eval(finalized_output) @@ -473,7 +580,7 @@ def decode_with_tiling( # Emit remaining frames if callback provided if on_frames_ready is not None: - emitted = getattr(decode_with_tiling, '_emitted_frames', 0) + emitted = getattr(decode_with_tiling, "_emitted_frames", 0) if emitted < out_f: remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) mx.eval(remaining_output) @@ -481,7 +588,7 @@ def decode_with_tiling( del remaining_output # Reset emitted frames counter for next call - if hasattr(decode_with_tiling, '_emitted_frames'): + if hasattr(decode_with_tiling, "_emitted_frames"): del decode_with_tiling._emitted_frames # Clean up weights diff --git a/mlx_video/models/ltx_2/video_vae/video_vae.py b/mlx_video/models/ltx_2/video_vae/video_vae.py index 45a447d..bd85086 100644 --- a/mlx_video/models/ltx_2/video_vae/video_vae.py +++ b/mlx_video/models/ltx_2/video_vae/video_vae.py @@ -8,12 +8,15 @@ import mlx.core as mx import mlx.nn as nn from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType -from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify +from mlx_video.models.ltx_2.video_vae.ops import ( + PerChannelStatistics, + patchify, + unpatchify, +) from mlx_video.models.ltx_2.video_vae.resnet import ( NormLayerType, ResnetBlock3D, UNetMidBlock3D, - get_norm_layer, ) from mlx_video.models.ltx_2.video_vae.sampling import ( DepthToSpaceUpsample, @@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm class LogVarianceType(Enum): """Log variance mode for VAE.""" + PER_CHANNEL = "per_channel" UNIFORM = "uniform" CONSTANT = "constant" @@ -229,7 +233,6 @@ class VideoEncoder(nn.Module): config: VideoEncoderModelConfig with encoder parameters """ super().__init__() - from mlx_video.models.ltx_2.config import VideoEncoderModelConfig self.patch_size = config.patch_size self.norm_layer = config.norm_layer @@ -241,10 +244,12 @@ class VideoEncoder(nn.Module): encoder_spatial_padding_mode = config.encoder_spatial_padding_mode # Per-channel statistics for normalizing latents - self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels) + self.per_channel_statistics = PerChannelStatistics( + latent_channels=config.out_channels + ) # After patchify, channels increase by patch_size^2 - in_channels = config.in_channels * config.patch_size ** 2 + in_channels = config.in_channels * config.patch_size**2 feature_channels = config.out_channels # Initial convolution @@ -262,7 +267,11 @@ class VideoEncoder(nn.Module): # Use dict with int keys for MLX to track parameters (lists are NOT tracked) self.down_blocks = {} for idx, (block_name, block_params) in enumerate(encoder_blocks): - block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + block_config = ( + {"num_layers": block_params} + if isinstance(block_params, int) + else block_params + ) block, feature_channels = _make_encoder_block( block_name=block_name, @@ -291,7 +300,10 @@ class VideoEncoder(nn.Module): conv_out_channels = config.out_channels if config.latent_log_var == LogVarianceType.PER_CHANNEL: conv_out_channels *= 2 - elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + elif config.latent_log_var in { + LogVarianceType.UNIFORM, + LogVarianceType.CONSTANT, + }: conv_out_channels += 1 self.conv_out = CausalConv3d( @@ -349,13 +361,16 @@ class VideoEncoder(nn.Module): elif self.latent_log_var == LogVarianceType.CONSTANT: sample = sample[:, :-1, ...] approx_ln_0 = -30 - sample = mx.concatenate([ - sample, - mx.full_like(sample, approx_ln_0), - ], axis=1) + sample = mx.concatenate( + [ + sample, + mx.full_like(sample, approx_ln_0), + ], + axis=1, + ) # Split into means and logvar, normalize means - means = sample[:, :self.latent_channels, ...] + means = sample[:, : self.latent_channels, ...] return self.per_channel_statistics.normalize(means) def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: @@ -409,6 +424,7 @@ class VideoEncoder(nn.Module): Loaded VideoEncoder instance """ import json + from mlx_video.models.ltx_2.config import VideoEncoderModelConfig # Load config @@ -474,7 +490,7 @@ class VideoDecoder(nn.Module): decoder_blocks = [] self.patch_size = patch_size - out_channels = out_channels * patch_size ** 2 + out_channels = out_channels * patch_size**2 self.causal = causal self.timestep_conditioning = timestep_conditioning self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS @@ -510,7 +526,11 @@ class VideoDecoder(nn.Module): # Use dict with int keys for MLX to track parameters (lists are NOT tracked) self.up_blocks = {} for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)): - block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + block_config = ( + {"num_layers": block_params} + if isinstance(block_params, int) + else block_params + ) block, feature_channels = _make_decoder_block( block_name=block_name, diff --git a/mlx_video/models/wan2/attention.py b/mlx_video/models/wan2/attention.py index b0a6f2f..36f7608 100644 --- a/mlx_video/models/wan2/attention.py +++ b/mlx_video/models/wan2/attention.py @@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module): v = self.v(x_w).reshape(b, s, n, d) # RoPE in float32 for precision (official uses float64) - q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin) - k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin) + q = rope_apply( + q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin + ) + k = rope_apply( + k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin + ) # Cast back to weight dtype for efficient attention (matching official q.to(v.dtype)) q = q.astype(w_dtype).transpose(0, 2, 1, 3) @@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module): q, k, v, scale=self.scale, mask=mask ) else: - out = mx.fast.scaled_dot_product_attention( - q, k, v, scale=self.scale - ) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale) out = out.transpose(0, 2, 1, 3).reshape(b, s, -1) return self.o(out) @@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module): q, k, v, scale=self.scale, mask=mask ) else: - out = mx.fast.scaled_dot_product_attention( - q, k, v, scale=self.scale - ) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale) out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d) return self.o(out) diff --git a/mlx_video/models/wan2/convert.py b/mlx_video/models/wan2/convert.py index 5636565..657eee7 100644 --- a/mlx_video/models/wan2/convert.py +++ b/mlx_video/models/wan2/convert.py @@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.utils -import numpy as np logger = logging.getLogger(__name__) @@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]: return weights -def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: +def sanitize_wan_transformer_weights( + weights: Dict[str, mx.array] +) -> Dict[str, mx.array]: """Convert Wan2.2 transformer weight keys to MLX model structure. Wan2.2 keys follow the pattern: @@ -246,8 +247,8 @@ def _load_lora_configs( Shared between weight-merging and runtime-wrapping paths. """ - from mlx_video.lora import LoRAConfig, load_multiple_loras from mlx_video.generate_wan import Colors + from mlx_video.lora import LoRAConfig, load_multiple_loras print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") @@ -264,7 +265,9 @@ def _load_lora_configs( module_to_loras = load_multiple_loras(configs) if not module_to_loras: - print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}") + print( + f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}" + ) return module_to_loras @@ -279,8 +282,8 @@ def load_and_apply_loras( For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). """ - from mlx_video.lora import apply_loras_to_weights from mlx_video.generate_wan import Colors + from mlx_video.lora import apply_loras_to_weights if not lora_configs: return model_weights @@ -289,12 +292,17 @@ def load_and_apply_loras( if not module_to_loras: return model_weights - print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}") + print( + f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}" + ) if verbose: print(f" Model has {len(model_weights)} weight keys") modified_weights = apply_loras_to_weights( - model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits + model_weights, + module_to_loras, + verbose=verbose, + quantization_bits=quantization_bits, ) print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}") @@ -435,8 +443,10 @@ def convert_wan_checkpoint( src_model_type = src_config.get("model_type", "t2v") src_text_len = src_config.get("text_len", 512) - print(f" Source config: dim={src_dim}, layers={src_num_layers}, " - f"heads={src_num_heads}, type={src_model_type}") + print( + f" Source config: dim={src_dim}, layers={src_num_layers}, " + f"heads={src_num_heads}, type={src_model_type}" + ) # Use preset for known TI2V 5B configuration if src_model_type == "ti2v" and src_dim == 3072: @@ -513,8 +523,11 @@ def convert_wan_checkpoint( weights = load_torch_weights(str(vae_path)) if is_wan22_vae: from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + include_encoder = config.model_type in ("ti2v", "i2v") - weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder) + weights = sanitize_wan22_vae_weights( + weights, include_encoder=include_encoder + ) else: weights = sanitize_wan_vae_weights(weights) # Always save VAE in float32 — official Wan2.2 runs VAE decode in @@ -527,7 +540,9 @@ def convert_wan_checkpoint( # Quantize transformer weights if requested if quantize: - print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...") + print( + f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..." + ) _quantize_saved_model(output_dir, config, is_dual, bits, group_size) print(f"\nConversion complete! Output: {output_dir}") @@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool: return False # Quantize attention Q/K/V/O and FFN fc1/fc2 quantize_patterns = ( - ".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o", - ".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o", - ".ffn.fc1", ".ffn.fc2", + ".self_attn.q", + ".self_attn.k", + ".self_attn.v", + ".self_attn.o", + ".cross_attn.q", + ".cross_attn.k", + ".cross_attn.v", + ".cross_attn.o", + ".ffn.fc1", + ".ffn.fc2", ) return any(path.endswith(p) for p in quantize_patterns) @@ -684,14 +706,20 @@ def quantize_mlx_model( # Build model config from mlx_video.models.wan.config import WanModelConfig - config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__} + config_dict = { + k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__ + } for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"): if key in config_dict and isinstance(config_dict[key], list): config_dict[key] = tuple(config_dict[key]) config = WanModelConfig(**config_dict) # Copy non-transformer files to output dir (skip large model weights) - transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"} + transformer_files = { + "low_noise_model.safetensors", + "high_noise_model.safetensors", + "model.safetensors", + } if dst.resolve() != src.resolve(): dst.mkdir(parents=True, exist_ok=True) for f in src.iterdir(): @@ -763,11 +791,18 @@ if __name__ == "__main__": if args.quantize_only: quantize_mlx_model( - args.checkpoint_dir, args.output_dir, - bits=args.bits, group_size=args.group_size, + args.checkpoint_dir, + args.output_dir, + bits=args.bits, + group_size=args.group_size, ) else: convert_wan_checkpoint( - args.checkpoint_dir, args.output_dir, args.dtype, args.model_version, - quantize=args.quantize, bits=args.bits, group_size=args.group_size, + args.checkpoint_dir, + args.output_dir, + args.dtype, + args.model_version, + quantize=args.quantize, + bits=args.bits, + group_size=args.group_size, ) diff --git a/mlx_video/models/wan2/generate.py b/mlx_video/models/wan2/generate.py index cc5d895..789a78d 100644 --- a/mlx_video/models/wan2/generate.py +++ b/mlx_video/models/wan2/generate.py @@ -4,18 +4,15 @@ import argparse import gc import math import random -import sys import time from pathlib import Path import mlx.core as mx -import mlx.nn as nn import numpy as np from tqdm import tqdm from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image from mlx_video.models.wan.loading import ( - _clean_text, encode_text, load_t5_encoder, load_vae_decoder, @@ -24,6 +21,7 @@ from mlx_video.models.wan.loading import ( ) from mlx_video.models.wan.postprocess import save_video + class Colors: """ANSI color codes for terminal output.""" @@ -37,6 +35,7 @@ class Colors: DIM = "\033[2m" RESET = "\033[0m" + # Backward-compat alias (tests and external code may use the old name) _build_i2v_mask = build_i2v_mask @@ -143,10 +142,13 @@ def generate_video( for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"): if key in config_dict and isinstance(config_dict[key], list): config_dict[key] = tuple(config_dict[key]) - config = WanModelConfig(**{ - k: v for k, v in config_dict.items() - if k in WanModelConfig.__dataclass_fields__ - }) + config = WanModelConfig( + **{ + k: v + for k, v in config_dict.items() + if k in WanModelConfig.__dataclass_fields__ + } + ) else: # Auto-detect: dual model files → 2.2, single model → 2.1 if (model_dir / "low_noise_model.safetensors").exists(): @@ -182,7 +184,9 @@ def generate_video( if "patch_embedding_proj.weight" in k: actual_dim = v.shape[0] if actual_dim != config.dim: - print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}") + print( + f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}" + ) if actual_dim <= 2048: config = WanModelConfig.wan21_t2v_1_3b() else: @@ -192,13 +196,20 @@ def generate_video( # Auto-correct Wan2.2 VAE params from stale configs if config.in_dim == 48 and config.vae_z_dim != 48: - print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}") - config = WanModelConfig(**{ - **{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()}, - "vae_z_dim": 48, - "vae_stride": (4, 16, 16), - "sample_fps": 24, - }) + print( + f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}" + ) + config = WanModelConfig( + **{ + **{ + f.name: getattr(config, f.name) + for f in config.__dataclass_fields__.values() + }, + "vae_z_dim": 48, + "vae_stride": (4, 16, 16), + "sample_fps": 24, + } + ) # Apply defaults from config if not overridden if steps is None: @@ -227,7 +238,9 @@ def generate_video( gen_frames = num_frames if trim_first_frames > 0: gen_frames = num_frames + trim_first_frames * 4 - print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}") + print( + f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}" + ) version_str = f"Wan{config.model_version}" mode_str = "dual-model" if is_dual else "single-model" @@ -247,10 +260,16 @@ def generate_video( if is_i2v: print(f" Image: {image}") if neg_prompt_resolved and neg_prompt_resolved.strip(): - neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved + neg_display = ( + neg_prompt_resolved[:60] + "..." + if len(neg_prompt_resolved) > 60 + else neg_prompt_resolved + ) print(f" Neg prompt: {neg_display}") print(f" Size: {width}x{height}, Frames: {num_frames}") - print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}") + print( + f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}" + ) if cfg_disabled: print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)") print(f"{Colors.RESET}") @@ -275,12 +294,16 @@ def generate_video( height = align_h if width == 0: width = align_w - print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}") + print( + f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}" + ) # Enforce max_area constraint (model-specific resolution limit) if config.max_area > 0 and height * width > config.max_area: old_h, old_w = height, width - width, height = _best_output_size(width, height, align_w, align_h, config.max_area) + width, height = _best_output_size( + width, height, align_w, align_h, config.max_area + ) print( f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area " f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}" @@ -309,6 +332,7 @@ def generate_video( # Load tokenizer from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") # Encode prompts @@ -318,12 +342,15 @@ def generate_video( context_null = None mx.eval(context) else: - context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len) + context_null = encode_text( + t5_encoder, tokenizer, neg_prompt_resolved, config.text_len + ) mx.eval(context, context_null) # Free T5 from memory del t5_encoder - gc.collect(); mx.clear_cache() + gc.collect() + mx.clear_cache() print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}") # I2V: encode image to latent space @@ -346,18 +373,25 @@ def generate_video( img = Image.open(image).convert("RGB") scale = max(width / img.width, height / img.height) - img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS) + img = img.resize( + (round(img.width * scale), round(img.height * scale)), Image.LANCZOS + ) x1, y1 = (img.width - width) // 2, (img.height - height) // 2 img = img.crop((x1, y1, x1 + width, y1 + height)) - img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3] + img_arr = mx.array( + np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0 + ) # [H, W, 3] img_chw = img_arr.transpose(2, 0, 1) # [3, H, W] # Build video: first frame = image, rest = zeros -> [3, F, H, W] # Chunked encoding processes 1-frame + 4-frame chunks with temporal caching - video = mx.concatenate([ - img_chw[:, None, :, :], - mx.zeros((3, num_frames - 1, height, width)), - ], axis=1) + video = mx.concatenate( + [ + img_chw[:, None, :, :], + mx.zeros((3, num_frames - 1, height, width)), + ], + axis=1, + ) # Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat] vae_enc = load_vae_encoder(vae_path, config) @@ -367,12 +401,17 @@ def generate_video( # Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W] msk = mx.ones((1, num_frames, h_latent, w_latent)) - msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) + msk = mx.concatenate( + [msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1 + ) # Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat] - msk = mx.concatenate([ - mx.repeat(msk[:, :1], 4, axis=1), - msk[:, 1:], - ], axis=1) + msk = mx.concatenate( + [ + mx.repeat(msk[:, :1], 4, axis=1), + msk[:, 1:], + ], + axis=1, + ) # Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat] msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] @@ -395,13 +434,16 @@ def generate_video( del vae_enc, img_tensor - gc.collect(); mx.clear_cache() + gc.collect() + mx.clear_cache() print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}") # Load transformer models print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}") if quantization: - print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}") + print( + f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}" + ) t2 = time.time() # Merge per-model LoRAs with shared LoRAs @@ -412,10 +454,16 @@ def generate_video( if is_dual: low_noise_path = model_dir / "low_noise_model.safetensors" high_noise_path = model_dir / "high_noise_model.safetensors" - low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low) - high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high) + low_noise_model = load_wan_model( + low_noise_path, config, quantization, loras=_loras_low + ) + high_noise_model = load_wan_model( + high_noise_path, config, quantization, loras=_loras_high + ) else: - single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single) + single_model = load_wan_model( + model_dir / "model.safetensors", config, quantization, loras=_loras_single + ) print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}") # Precompute text embeddings once (avoids redundant MLP in every step) @@ -437,8 +485,12 @@ def generate_video( context_emb_low = low_noise_model.embed_text([context, context_null]) context_emb_high = high_noise_model.embed_text([context, context_null]) mx.eval(context_emb_low, context_emb_high) - context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0) - context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0) + context_cfg_low = mx.concatenate( + [context_emb_low[0:1], context_emb_low[1:2]], axis=0 + ) + context_cfg_high = mx.concatenate( + [context_emb_high[0:1], context_emb_high[1:2]], axis=0 + ) else: context_emb = single_model.embed_text([context, context_null]) mx.eval(context_emb) @@ -534,7 +586,7 @@ def generate_video( rcs = rope_cos_sin # Use compiled forward when available (faster after first trace) - _call = getattr(model, '_compiled', model) + _call = getattr(model, "_compiled", model) if cfg_disabled: # No CFG: B=1 forward pass (2x faster than B=2 CFG batch) @@ -552,7 +604,9 @@ def generate_video( y_arg = [y_i2v] if is_i2v_channel_concat else None if is_dual: - ctx = context_cond_high if timestep_val >= boundary else context_cond_low + ctx = ( + context_cond_high if timestep_val >= boundary else context_cond_low + ) else: ctx = context_cond preds = _call( @@ -571,7 +625,11 @@ def generate_video( if is_dual: gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0] else: - gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] + gs = ( + guide_scale + if isinstance(guide_scale, (int, float)) + else guide_scale[0] + ) if is_i2v_mask_blend: t_tokens = i2v_mask_tokens * timestep_val @@ -586,8 +644,10 @@ def generate_video( y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None - ctx = context_cfg if not is_dual else ( - context_cfg_high if timestep_val >= boundary else context_cfg_low + ctx = ( + context_cfg + if not is_dual + else (context_cfg_high if timestep_val >= boundary else context_cfg_low) ) preds = _call( [latents, latents], @@ -618,16 +678,24 @@ def generate_video( if debug_latents: lat_np = np.array(latents) # [C, T, H, W] n_t = lat_np.shape[1] - print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}") - print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}") + print( + f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}" + ) + print( + f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}" + ) for t_pos in range(min(n_t, 8)): frame = lat_np[:, t_pos, :, :] - print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} " - f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}") + print( + f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} " + f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}" + ) if n_t > 8: interior = lat_np[:, 4:, :, :] - print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} " - f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}") + print( + f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} " + f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}" + ) print() # Free transformer models and text embeddings @@ -646,7 +714,8 @@ def generate_video( del model, kv, context if context_null is not None: del context_null - gc.collect(); mx.clear_cache() + gc.collect() + mx.clear_cache() # Load VAE and decode print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}") @@ -677,13 +746,25 @@ def generate_video( elif tiling == "temporal": tiling_config = TilingConfig.temporal_only() else: - print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + print( + f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}" + ) tiling_config = TilingConfig.auto(height, width, num_frames) if tiling_config is not None: - spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" - temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" - print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + spatial_info = ( + f"{tiling_config.spatial_config.tile_size_in_pixels}px" + if tiling_config.spatial_config + else "none" + ) + temporal_info = ( + f"{tiling_config.temporal_config.tile_size_in_frames}f" + if tiling_config.temporal_config + else "none" + ) + print( + f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}" + ) if is_wan22_vae: from mlx_video.models.wan.vae22 import denormalize_latents @@ -718,7 +799,9 @@ def generate_video( if trim_first_frames > 0: trim_pixels = trim_first_frames * 4 video = video[trim_pixels:] - print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}") + print( + f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}" + ) save_video(video, output_path, fps=config.sample_fps) print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}") @@ -727,58 +810,124 @@ def generate_video( def main(): parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)") - parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory") - parser.add_argument("--prompt", type=str, required=True, help="Text prompt") - parser.add_argument("--image", type=str, default=None, - help="Path to input image for I2V (omit for T2V mode)") - parser.add_argument("--negative-prompt", type=str, default=None, - help="Negative prompt for CFG (default: official Chinese prompt from config)") - parser.add_argument("--no-negative-prompt", action="store_true", - help="Disable negative prompt (use empty string instead of config default)") - parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)") - parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)") - parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)") - parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)") - parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair") - parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)") - parser.add_argument("--seed", type=int, default=-1, help="Random seed") - parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path") parser.add_argument( - "--scheduler", type=str, default="unipc", + "--model-dir", + type=str, + required=True, + help="Path to converted MLX model directory", + ) + parser.add_argument("--prompt", type=str, required=True, help="Text prompt") + parser.add_argument( + "--image", + type=str, + default=None, + help="Path to input image for I2V (omit for T2V mode)", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default=None, + help="Negative prompt for CFG (default: official Chinese prompt from config)", + ) + parser.add_argument( + "--no-negative-prompt", + action="store_true", + help="Disable negative prompt (use empty string instead of config default)", + ) + parser.add_argument( + "--width", type=int, default=1280, help="Video width (default: 1280)" + ) + parser.add_argument( + "--height", + type=int, + default=704, + help="Video height (default: 704; 720p models use 704)", + ) + parser.add_argument( + "--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)" + ) + parser.add_argument( + "--steps", + type=int, + default=None, + help="Number of diffusion steps (default: from config)", + ) + parser.add_argument( + "--guide-scale", + type=str, + default=None, + help="Guidance scale: single float or low,high pair", + ) + parser.add_argument( + "--shift", + type=float, + default=None, + help="Noise schedule shift (default: from config)", + ) + parser.add_argument("--seed", type=int, default=-1, help="Random seed") + parser.add_argument( + "--output-path", type=str, default="output.mp4", help="Output video path" + ) + parser.add_argument( + "--scheduler", + type=str, + default="unipc", choices=["euler", "dpm++", "unipc"], help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)", ) parser.add_argument( - "--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"), + "--lora", + nargs=2, + action="append", + metavar=("PATH", "STRENGTH"), help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8", ) parser.add_argument( - "--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"), + "--lora-high", + nargs=2, + action="append", + metavar=("PATH", "STRENGTH"), help="Apply a LoRA to high-noise model only (dual-model, repeatable)", ) parser.add_argument( - "--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"), + "--lora-low", + nargs=2, + action="append", + metavar=("PATH", "STRENGTH"), help="Apply a LoRA to low-noise model only (dual-model, repeatable)", ) parser.add_argument( "--tiling", type=str, default="auto", - choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + choices=[ + "auto", + "none", + "default", + "aggressive", + "conservative", + "spatial", + "temporal", + ], help="VAE tiling mode to reduce memory during decoding (default: auto)", ) parser.add_argument( - "--no-compile", action="store_true", + "--no-compile", + action="store_true", help="Disable mx.compile on models (for debugging)", ) parser.add_argument( - "--trim-first-frames", type=int, default=0, metavar="N", + "--trim-first-frames", + type=int, + default=0, + metavar="N", help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. " - "Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). " - "Default: 0 (disabled)", + "Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). " + "Default: 0 (disabled)", ) parser.add_argument( - "--debug-latents", action="store_true", + "--debug-latents", + action="store_true", help="Print per-temporal-position latent statistics after denoising (diagnostic)", ) args = parser.parse_args() diff --git a/mlx_video/models/wan2/i2v_utils.py b/mlx_video/models/wan2/i2v_utils.py index 98a4752..0558130 100644 --- a/mlx_video/models/wan2/i2v_utils.py +++ b/mlx_video/models/wan2/i2v_utils.py @@ -21,7 +21,9 @@ def preprocess_image(image_path: str, width: int, height: int) -> mx.array: # Resize so that the image covers the target size (LANCZOS) scale = max(width / img.width, height / img.height) - img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS) + img = img.resize( + (round(img.width * scale), round(img.height * scale)), Image.LANCZOS + ) # Center crop x1 = (img.width - width) // 2 diff --git a/mlx_video/models/wan2/loading.py b/mlx_video/models/wan2/loading.py index 35e3d12..e83b0de 100644 --- a/mlx_video/models/wan2/loading.py +++ b/mlx_video/models/wan2/loading.py @@ -6,7 +6,12 @@ import mlx.core as mx import mlx.nn as nn -def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None): +def load_wan_model( + model_path: Path, + config, + quantization: dict | None = None, + loras: list | None = None, +): """Load and initialize WanModel, with optional quantization and LoRA support. Args: @@ -93,9 +98,11 @@ def load_vae_decoder(model_path: Path, config=None): if is_wan22: from mlx_video.models.wan.vae22 import Wan22VAEDecoder + vae = Wan22VAEDecoder(z_dim=48) else: from mlx_video.models.wan.vae import WanVAE + vae = WanVAE(z_dim=16) weights = mx.load(str(model_path)) @@ -140,6 +147,7 @@ def _clean_text(text: str) -> str: try: import ftfy + text = ftfy.fix_text(text) except ImportError: pass diff --git a/mlx_video/models/wan2/model.py b/mlx_video/models/wan2/model.py index 989e712..6684537 100644 --- a/mlx_video/models/wan2/model.py +++ b/mlx_video/models/wan2/model.py @@ -1,4 +1,5 @@ import math + import mlx.core as mx import mlx.nn as nn import numpy as np @@ -37,7 +38,9 @@ class Head(nn.Module): proj_dim = math.prod(patch_size) * out_dim self.norm = WanLayerNorm(dim, eps) self.head = nn.Linear(dim, proj_dim) - self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32) + self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype( + mx.float32 + ) def __call__(self, x: mx.array, e: mx.array) -> mx.array: """ @@ -111,20 +114,23 @@ class WanModel(nn.Module): # Reference computes three rope_params with different dim normalizations # so each axis (temporal/height/width) gets its own full frequency range. d = dim // config.num_heads - self.freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + self.freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) # Precompute sinusoidal inv_freq for time embedding. half = config.freq_dim // 2 self._inv_freq = mx.array( - np.power(10000.0, -np.arange(half, dtype=np.float64) / half - ).astype(np.float32) + np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype( + np.float32 + ) ) - def _patchify(self, x: mx.array) -> tuple: """Convert video tensor to patch embeddings. @@ -297,12 +303,19 @@ class WanModel(nn.Module): seq_lens_list.append(p.shape[1]) x = mx.concatenate( [ - mx.concatenate( - [p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)], - axis=1, + ( + mx.concatenate( + [ + p, + mx.zeros( + (1, seq_len - p.shape[1], self.dim), dtype=p.dtype + ), + ], + axis=1, + ) + if p.shape[1] < seq_len + else p ) - if p.shape[1] < seq_len - else p for p in patches ], axis=0, @@ -315,9 +328,7 @@ class WanModel(nn.Module): t = t[None] sinusoid = t[..., None].astype(mx.float32) * self._inv_freq - sin_emb = mx.concatenate( - [mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1 - ) + sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1) if t.ndim == 1: # Standard T2V: scalar timestep per batch element [B] diff --git a/mlx_video/models/wan2/postprocess.py b/mlx_video/models/wan2/postprocess.py index 4c24fc6..f916a0f 100644 --- a/mlx_video/models/wan2/postprocess.py +++ b/mlx_video/models/wan2/postprocess.py @@ -1,6 +1,8 @@ -import numpy as np from pathlib import Path +import numpy as np + + def save_video(frames: np.ndarray, output_path: str, fps: int = 16): """Save video frames to MP4. @@ -11,6 +13,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16): """ try: import imageio + writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8) for frame in frames: writer.append_data(frame) @@ -18,6 +21,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16): except ImportError: try: import cv2 + h, w = frames.shape[1], frames.shape[2] fourcc = cv2.VideoWriter_fourcc(*"avc1") writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h)) @@ -27,9 +31,11 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16): except (ImportError, Exception): # Last resort: save as individual PNGs from PIL import Image + out_dir = Path(output_path).parent / Path(output_path).stem out_dir.mkdir(parents=True, exist_ok=True) for i, frame in enumerate(frames): Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png") - print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)") - + print( + f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)" + ) diff --git a/mlx_video/models/wan2/rope.py b/mlx_video/models/wan2/rope.py index d992607..1ad93ae 100644 --- a/mlx_video/models/wan2/rope.py +++ b/mlx_video/models/wan2/rope.py @@ -1,4 +1,3 @@ -import math import mlx.core as mx import numpy as np @@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array: Complex frequency tensor of shape [max_seq_len, dim // 2]. """ assert dim % 2 == 0 - freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * ( - 1.0 - / np.power( - theta, - np.arange(0, dim, 2, dtype=np.float64) / dim, - ) - )[None, :] + freqs = ( + np.arange(max_seq_len, dtype=np.float64)[:, None] + * ( + 1.0 + / np.power( + theta, + np.arange(0, dim, 2, dtype=np.float64) / dim, + ) + )[None, :] + ) # Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2] cos_freqs = np.cos(freqs).astype(np.float32) sin_freqs = np.sin(freqs).astype(np.float32) @@ -46,9 +48,9 @@ def rope_apply( # Check if all batch elements have the same grid (common for CFG B=2) f0, h0, w0 = grid_sizes[0] seq_len = f0 * h0 * w0 - all_same_grid = all( - grid_sizes[i] == grid_sizes[0] for i in range(1, b) - ) if b > 1 else True + all_same_grid = ( + all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True + ) if all_same_grid: # Vectorized path: apply RoPE to all batch elements at once @@ -57,7 +59,9 @@ def rope_apply( x_imag = x_seq[..., 1] out_real = x_real * cos_f - x_imag * sin_f out_imag = x_real * sin_f + x_imag * cos_f - x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d) + x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape( + b, seq_len, n, d + ) if seq_len < s: x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1) return x_rotated @@ -102,17 +106,11 @@ def rope_apply( # Build per-position frequencies by expanding along grid dims # temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2] - ft = mx.broadcast_to( - freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2) - ) + ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)) # height: [1,h,1,d_h,2] -> [f,h,w,d_h,2] - fh = mx.broadcast_to( - freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2) - ) + fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)) # width: [1,1,w,d_w,2] -> [f,h,w,d_w,2] - fw = mx.broadcast_to( - freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2) - ) + fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)) # Concatenate: [f*h*w, half_d, 2] freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2) diff --git a/mlx_video/models/wan2/scheduler.py b/mlx_video/models/wan2/scheduler.py index 15de21b..067b14e 100644 --- a/mlx_video/models/wan2/scheduler.py +++ b/mlx_video/models/wan2/scheduler.py @@ -7,9 +7,8 @@ for the same quality as Euler. import math -import numpy as np - import mlx.core as mx +import numpy as np def _compute_sigmas( @@ -25,9 +24,7 @@ def _compute_sigmas( Returns num_steps+1 values (the last being 0.0 for the terminal state). """ # sigma bounds from unshifted training schedule (constructor uses shift=1) - alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[ - ::-1 - ] + alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1] sigmas_unshifted = 1.0 - alphas sigma_max = float(sigmas_unshifted[0]) # (N-1)/N sigma_min = float(sigmas_unshifted[-1]) # 0.0 @@ -65,7 +62,10 @@ class FlowMatchEulerScheduler: sample: mx.array, ) -> mx.array: """Euler step: x_next = x + (sigma_next - sigma_cur) * v.""" - dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index] + dt = ( + self._sigmas_float[self._step_index + 1] + - self._sigmas_float[self._step_index] + ) x_next = sample + dt * model_output self._step_index += 1 return x_next @@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler: # Decide order: 1st for first step, last step (if lower_order_final # and few steps), otherwise 2nd - use_first_order = ( - self._prev_x0 is None - or ( - self.lower_order_final - and i == self._num_steps - 1 - and self._num_steps < 15 - ) + use_first_order = self._prev_x0 is None or ( + self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15 ) if use_first_order or sigma_next == 0.0: diff --git a/mlx_video/models/wan2/text_encoder.py b/mlx_video/models/wan2/text_encoder.py index b81a072..604e63e 100644 --- a/mlx_video/models/wan2/text_encoder.py +++ b/mlx_video/models/wan2/text_encoder.py @@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module): is_small = rel_pos < max_exact rel_pos_f = rel_pos.astype(mx.float32) - rel_pos_large = ( - max_exact - + ( - mx.log(rel_pos_f / max_exact) - / math.log(self.max_dist / max_exact) - * (num_buckets - max_exact) - ).astype(mx.int32) - ) + rel_pos_large = max_exact + ( + mx.log(rel_pos_f / max_exact) + / math.log(self.max_dist / max_exact) + * (num_buckets - max_exact) + ).astype(mx.int32) rel_pos_large = mx.minimum( rel_pos_large, mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32), ) - rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large) + rel_buckets = rel_buckets + mx.where( + is_small, rel_pos.astype(mx.int32), rel_pos_large + ) return rel_buckets def __call__(self, lq: int, lk: int) -> mx.array: @@ -115,7 +114,7 @@ class T5Attention(nn.Module): v = v.transpose(0, 2, 1, 3) # QK^T (no scaling) — compute in float32 for precision - attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)) + attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2) # Add position bias if pos_bias is not None: diff --git a/mlx_video/models/wan2/tiling.py b/mlx_video/models/wan2/tiling.py index 73f2624..9023c8d 100644 --- a/mlx_video/models/wan2/tiling.py +++ b/mlx_video/models/wan2/tiling.py @@ -75,7 +75,11 @@ def decode_with_tiling( b, c, f_latent, h_latent, w_latent = latents.shape # Compute output shape - out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale) + out_f = ( + (1 + (f_latent - 1) * temporal_scale) + if causal_temporal + else (f_latent * temporal_scale) + ) out_h = h_latent * spatial_scale out_w = w_latent * spatial_scale @@ -98,9 +102,13 @@ def decode_with_tiling( # Compute intervals for each dimension if causal_temporal: - temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + temporal_intervals = split_in_temporal( + temporal_tile_size, temporal_overlap, f_latent + ) else: - temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent) + temporal_intervals = split_in_spatial( + temporal_tile_size, temporal_overlap, f_latent + ) height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) @@ -124,9 +132,13 @@ def decode_with_tiling( # Map temporal coordinates if causal_temporal: - out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + out_t_slice, t_mask = map_temporal_slice( + t_start, t_end, t_left, t_right, temporal_scale + ) else: - out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale) + out_t_slice, t_mask = map_spatial_slice( + t_start, t_end, t_left, t_right, temporal_scale + ) for h_idx in range(num_h_tiles): h_start = height_intervals.starts[h_idx] @@ -135,7 +147,9 @@ def decode_with_tiling( h_right = height_intervals.right_ramps[h_idx] # Map height coordinates - out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) + out_h_slice, h_mask = map_spatial_slice( + h_start, h_end, h_left, h_right, spatial_scale + ) for w_idx in range(num_w_tiles): w_start = width_intervals.starts[w_idx] @@ -144,13 +158,23 @@ def decode_with_tiling( w_right = width_intervals.right_ramps[w_idx] # Map width coordinates - out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) + out_w_slice, w_mask = map_spatial_slice( + w_start, w_end, w_left, w_right, spatial_scale + ) # Extract tile latents (small slice) - tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] + tile_latents = latents[ + :, :, t_start:t_end, h_start:h_end, w_start:w_end + ] # Decode tile - tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) + tile_output = decoder_fn( + tile_latents, + causal=causal, + timestep=timestep, + debug=False, + chunked_conv=chunked_conv, + ) mx.eval(tile_output) # Clear tile_latents reference @@ -173,13 +197,15 @@ def decode_with_tiling( w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask blend_mask = ( - t_mask_slice.reshape(1, 1, -1, 1, 1) * - h_mask_slice.reshape(1, 1, 1, -1, 1) * - w_mask_slice.reshape(1, 1, 1, 1, -1) + t_mask_slice.reshape(1, 1, -1, 1, 1) + * h_mask_slice.reshape(1, 1, 1, -1, 1) + * w_mask_slice.reshape(1, 1, 1, 1, -1) ) # Slice tile output to match - tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) + tile_output_slice = tile_output[ + :, :, :actual_t, :actual_h, :actual_w + ].astype(mx.float32) # Clear full tile_output del tile_output @@ -196,11 +222,37 @@ def decode_with_tiling( weighted_tile = tile_output_slice * blend_mask # Update output using slice assignment - output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( - output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile + output[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] = ( + output[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] + + weighted_tile ) - weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( - weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask + weights[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] = ( + weights[ + :, + :, + t_out_start:t_out_end, + h_out_start:h_out_end, + w_out_start:w_out_end, + ] + + blend_mask ) # Force evaluation to free memory @@ -232,12 +284,14 @@ def decode_with_tiling( if next_tile_start_latent == 0: next_tile_start_out = 0 elif causal_temporal: - next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale + next_tile_start_out = ( + 1 + (next_tile_start_latent - 1) * temporal_scale + ) else: next_tile_start_out = next_tile_start_latent * temporal_scale # We need to track how many frames we've already emitted - if not hasattr(decode_with_tiling, '_emitted_frames'): + if not hasattr(decode_with_tiling, "_emitted_frames"): decode_with_tiling._emitted_frames = 0 emitted = decode_with_tiling._emitted_frames @@ -245,7 +299,10 @@ def decode_with_tiling( # Normalize and emit frames [emitted, next_tile_start_out) finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] finalized_weights = mx.maximum(finalized_weights, 1e-8) - finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights + finalized_output = ( + output[:, :, emitted:next_tile_start_out, :, :] + / finalized_weights + ) finalized_output = finalized_output.astype(latents.dtype) mx.eval(finalized_output) @@ -262,7 +319,7 @@ def decode_with_tiling( # Emit remaining frames if callback provided if on_frames_ready is not None: - emitted = getattr(decode_with_tiling, '_emitted_frames', 0) + emitted = getattr(decode_with_tiling, "_emitted_frames", 0) if emitted < out_f: remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) mx.eval(remaining_output) @@ -270,7 +327,7 @@ def decode_with_tiling( del remaining_output # Reset emitted frames counter for next call - if hasattr(decode_with_tiling, '_emitted_frames'): + if hasattr(decode_with_tiling, "_emitted_frames"): del decode_with_tiling._emitted_frames # Clean up weights diff --git a/mlx_video/models/wan2/transformer.py b/mlx_video/models/wan2/transformer.py index 7186b82..ea1c058 100644 --- a/mlx_video/models/wan2/transformer.py +++ b/mlx_video/models/wan2/transformer.py @@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module): # Cross-attention (with optional norm on context) self.norm3 = ( - WanLayerNorm(dim, eps, elementwise_affine=True) - if cross_attn_norm - else None + WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None ) self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps) @@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module): self.ffn = WanFFN(dim, ffn_dim) # Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision) - self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32) + self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype( + mx.float32 + ) def __call__( self, @@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module): # Self-attention with modulation (hidden state stays in w_dtype) x_mod = self.norm1(x) * (1 + e1) + e0 - y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask) + y = self.self_attn( + x_mod, + seq_lens, + grid_sizes, + freqs, + rope_cos_sin=rope_cos_sin, + attn_mask=attn_mask, + ) x = x + y * e2 # Cross-attention (no modulation, just norm) diff --git a/mlx_video/models/wan2/vae.py b/mlx_video/models/wan2/vae.py index faa2372..ecc539a 100644 --- a/mlx_video/models/wan2/vae.py +++ b/mlx_video/models/wan2/vae.py @@ -6,19 +6,45 @@ so weights load directly without key sanitization. import mlx.core as mx import mlx.nn as nn -import numpy as np - CACHE_T = 2 # Per-channel normalization statistics for z_dim=16 VAE_MEAN = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921, + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, ] VAE_STD = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160, + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, ] @@ -50,7 +76,9 @@ class CausalConv3d(nn.Module): self._pad_w = padding[2] # MLX Conv3d: weight shape [O, D, H, W, I] - self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)) + self.weight = mx.zeros( + (out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels) + ) self.bias = mx.zeros((out_channels,)) def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array: @@ -67,8 +95,16 @@ class CausalConv3d(nn.Module): x = mx.concatenate([pad_t, x], axis=2) if self._pad_h > 0 or self._pad_w > 0: - x = mx.pad(x, [(0, 0), (0, 0), (0, 0), - (self._pad_h, self._pad_h), (self._pad_w, self._pad_w)]) + x = mx.pad( + x, + [ + (0, 0), + (0, 0), + (0, 0), + (self._pad_h, self._pad_h), + (self._pad_w, self._pad_w), + ], + ) x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C] out = self._conv3d(x) @@ -118,7 +154,11 @@ class RMS_norm(nn.Module): def __call__(self, x: mx.array) -> mx.array: norm_dim = 1 if self.channel_first else -1 # L2 normalize along channel dim (matches F.normalize) - norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None)) + norm = mx.sqrt( + mx.clip( + mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None + ) + ) return (x / norm) * self.scale * self.gamma @@ -133,12 +173,12 @@ class ResidualBlock(nn.Module): def __init__(self, in_dim: int, out_dim: int): super().__init__() self.residual = [ - RMS_norm(in_dim, images=False), # [0] - None, # [1] SiLU + RMS_norm(in_dim, images=False), # [0] + None, # [1] SiLU CausalConv3d(in_dim, out_dim, 3, padding=1), # [2] - RMS_norm(out_dim, images=False), # [3] - None, # [4] SiLU - None, # [5] Dropout + RMS_norm(out_dim, images=False), # [3] + None, # [4] SiLU + None, # [5] Dropout CausalConv3d(out_dim, out_dim, 3, padding=1), # [6] ] self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None @@ -226,13 +266,16 @@ class Resample(nn.Module): # resample.0 = Upsample (no params), resample.1 = Conv2d self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)] if mode == "upsample3d": - self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0) + ) else: # resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2) self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)] if mode == "downsample3d": self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: """x: [B, C, T, H, W]""" @@ -272,8 +315,7 @@ class Resample(nn.Module): else: # Subsequent chunks: use cached frame as temporal context cache_x = x[:, :, -1:] - x = self.time_conv( - x, cache_x=feat_cache[idx][:, :, -1:]) + x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: @@ -328,8 +370,8 @@ class Decoder3d(nn.Module): # Output head: [RMS_norm, SiLU (no params), CausalConv3d] self.head = [ - RMS_norm(dims[-1], images=False), # [0] - None, # [1] SiLU + RMS_norm(dims[-1], images=False), # [0] + None, # [1] SiLU CausalConv3d(dims[-1], 3, 3, padding=1), # [2] ] @@ -405,8 +447,7 @@ class Encoder3d(nn.Module): idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:] if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: - cache_x = mx.concatenate( - [feat_cache[idx][:, :, -1:], cache_x], axis=2) + cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2) x = self.conv1(x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -431,8 +472,7 @@ class Encoder3d(nn.Module): idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:] if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: - cache_x = mx.concatenate( - [feat_cache[idx][:, :, -1:], cache_x], axis=2) + cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2) x = self.head[2](x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -583,7 +623,7 @@ class WanVAE(nn.Module): decoder_fn=tile_decode, latents=z_denorm, tiling_config=tiling_config, - spatial_scale=8, # 3× spatial 2× upsamples = 8× - temporal_scale=4, # 2× temporal upsamples × 2 = 4× + spatial_scale=8, # 3× spatial 2× upsamples = 8× + temporal_scale=4, # 2× temporal upsamples × 2 = 4× causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T) ) diff --git a/mlx_video/models/wan2/vae22.py b/mlx_video/models/wan2/vae22.py index a1b233f..4d26b95 100644 --- a/mlx_video/models/wan2/vae22.py +++ b/mlx_video/models/wan2/vae22.py @@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed. """ import logging -import math import mlx.core as mx import mlx.nn as nn @@ -19,23 +18,111 @@ logger = logging.getLogger(__name__) CACHE_T = 2 # Per-channel normalization for z_dim=48 latent space -VAE22_MEAN = mx.array([ - -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, - -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, - -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, - -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, - -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, - 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667, -]) +VAE22_MEAN = mx.array( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ] +) -VAE22_STD = mx.array([ - 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, - 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, - 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, - 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, - 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, - 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744, -]) +VAE22_STD = mx.array( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ] +) class CausalConv3d(nn.Module): @@ -65,9 +152,9 @@ class CausalConv3d(nn.Module): self._pad_w = padding[2] # Weight: [O, D, H, W, I] for MLX - self.weight = mx.zeros(( - out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels - )) + self.weight = mx.zeros( + (out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels) + ) self.bias = mx.zeros((out_channels,)) def __call__(self, x, cache_x=None): @@ -96,8 +183,16 @@ class CausalConv3d(nn.Module): # Spatial padding if self._pad_h > 0 or self._pad_w > 0: - x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h), - (self._pad_w, self._pad_w), (0, 0)]) + x = mx.pad( + x, + [ + (0, 0), + (0, 0), + (self._pad_h, self._pad_h), + (self._pad_w, self._pad_w), + (0, 0), + ], + ) T_padded = x.shape[1] H_padded, W_padded = x.shape[2], x.shape[3] @@ -113,8 +208,9 @@ class CausalConv3d(nn.Module): for d in range(kd): frame = x[:, t_start + d] # [B, H_padded, W_padded, C] w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I] - conv_out = mx.conv_general(frame, w2d, - stride=(self.stride[1], self.stride[2])) + conv_out = mx.conv_general( + frame, w2d, stride=(self.stride[1], self.stride[2]) + ) accum = conv_out if accum is None else accum + conv_out outputs.append(accum + self.bias) @@ -126,7 +222,7 @@ class RMS_norm(nn.Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 # Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze self.gamma = mx.ones((dim,)) @@ -134,7 +230,9 @@ class RMS_norm(nn.Module): # x: [..., C] (channels-last) # PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps) l2_sq = mx.sum(x * x, axis=-1, keepdims=True) - return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma + return ( + x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma + ) class ResidualBlock(nn.Module): @@ -145,11 +243,7 @@ class ResidualBlock(nn.Module): # Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d] # We store as named layers matching PyTorch's indices self.residual = ResidualBlockLayers(in_dim, out_dim) - self.shortcut = ( - CausalConv3d(in_dim, out_dim, 1) - if in_dim != out_dim - else None - ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None def __call__(self, x, feat_cache=None, feat_idx=None): h = self.shortcut(x) if self.shortcut is not None else x @@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module): # Save last CACHE_T frames before conv (for next chunk's context) cache_x = x[:, -CACHE_T:] if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate( - [feat_cache[idx][:, -1:], cache_x], axis=1 - ) + cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1) out = conv(x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -231,7 +323,9 @@ class AttentionBlock(nn.Module): x = self.norm(x) # QKV via 1x1 conv2d (equivalent to linear on last dim) - qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C] + qkv = ( + mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias + ) # [BT, H, W, 3C] qkv = qkv.reshape(B * T, H * W, 3 * C) q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C] @@ -240,8 +334,10 @@ class AttentionBlock(nn.Module): k = k[:, None, :, :] v = v[:, None, :, :] - scale = C ** -0.5 - out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C] + scale = C**-0.5 + out = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale + ) # [BT, 1, HW, C] out = out.squeeze(1).reshape(B * T, H, W, C) # Project output @@ -270,16 +366,24 @@ class DupUp3D(nn.Module): x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats] # Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s] - x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s) + x = x.reshape( + B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s + ) # Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C] x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4) # Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C] - x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels) + x = x.reshape( + B, + T * self.factor_t, + H * self.factor_s, + W * self.factor_s, + self.out_channels, + ) if first_chunk: - x = x[:, self.factor_t - 1:, :, :, :] + x = x[:, self.factor_t - 1 :, :, :, :] return x @@ -348,7 +452,9 @@ class Resample(nn.Module): self.resample_weight = mx.zeros((dim, 3, 3, dim)) self.resample_bias = mx.zeros((dim,)) # time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1)) - self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) else: raise ValueError(f"Unsupported mode: {mode}") @@ -369,7 +475,9 @@ class Resample(nn.Module): """Apply strided Conv2d for downsampling. x: [N, H, W, C].""" # ZeroPad2d((0,1,0,1)): pad right=1, bottom=1 x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) - return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias + return ( + mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias + ) def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None): # x: [B, T, H, W, C] @@ -444,14 +552,17 @@ class Resample(nn.Module): class Up_ResidualBlock(nn.Module): """Upsampling residual block with optional DupUp3D shortcut.""" - def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False): + def __init__( + self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False + ): super().__init__() self.up_flag = up_flag # DupUp3D shortcut (no learnable params) if up_flag: self.avg_shortcut = DupUp3D( - in_dim, out_dim, + in_dim, + out_dim, factor_t=2 if temperal_upsample else 1, factor_s=2 if up_flag else 1, ) @@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module): class Down_ResidualBlock(nn.Module): """Downsampling residual block with AvgDown3D shortcut.""" - def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False): + def __init__( + self, + in_dim, + out_dim, + num_res_blocks, + temperal_downsample=False, + down_flag=False, + ): super().__init__() self.down_flag = down_flag # AvgDown3D shortcut (no learnable params, always present) self.avg_shortcut = AvgDown3D( - in_dim, out_dim, + in_dim, + out_dim, factor_t=2 if temperal_downsample else 1, factor_s=2 if down_flag else 1, ) @@ -562,13 +681,15 @@ class Decoder3d(nn.Module): self.upsamples = [] for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): t_up = temperal_upsample[i] if i < len(temperal_upsample) else False - self.upsamples.append(Up_ResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - num_res_blocks=num_res_blocks + 1, - temperal_upsample=t_up, - up_flag=(i != len(dim_mult) - 1), - )) + self.upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks + 1, + temperal_upsample=t_up, + up_flag=(i != len(dim_mult) - 1), + ) + ) # Output head: [RMS_norm, SiLU, CausalConv3d] self.head = Head22(dims[-1]) @@ -612,13 +733,15 @@ class Encoder3d(nn.Module): for i in range(len(dim_mult)): in_d, out_d = dims[i], dims[i + 1] t_down = temperal_downsample[i] if i < len(temperal_downsample) else False - self.downsamples.append(Down_ResidualBlock( - in_dim=in_d, - out_dim=out_d, - num_res_blocks=num_res_blocks, - temperal_downsample=t_down, - down_flag=(i < len(dim_mult) - 1), - )) + self.downsamples.append( + Down_ResidualBlock( + in_dim=in_d, + out_dim=out_d, + num_res_blocks=num_res_blocks, + temperal_downsample=t_down, + down_flag=(i < len(dim_mult) - 1), + ) + ) # Middle blocks (same as decoder) out_dim = dims[-1] @@ -658,9 +781,7 @@ class Encoder3d(nn.Module): idx = feat_idx[0] cache_x = x[:, -CACHE_T:] if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate( - [feat_cache[idx][:, -1:], cache_x], axis=1 - ) + cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1) x = self.conv1(x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -700,9 +821,7 @@ class Head22(nn.Module): idx = feat_idx[0] cache_x = x[:, -CACHE_T:] if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate( - [feat_cache[idx][:, -1:], cache_x], axis=1 - ) + cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1) x = self.layer_2(x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module): if i == 0: chunk = x[:, :1] else: - chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i] + chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i] chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx) if out is None: out = chunk_out @@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module): # conv1 (pointwise) + split into mu, log_var out = self.conv1(out) - mu = out[:, :, :, :, :self.z_dim] + mu = out[:, :, :, :, : self.z_dim] # Normalize mu = normalize_latents(mu) @@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module): decoder_fn=tile_decode, latents=z_cf, tiling_config=tiling_config, - spatial_scale=16, # 8× conv upsample + 2× unpatchify - temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal) + spatial_scale=16, # 8× conv upsample + 2× unpatchify + temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal) causal_temporal=True, ) diff --git a/mlx_video/utils.py b/mlx_video/utils.py index 2cd8647..eb8903b 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -1,14 +1,15 @@ import math +from functools import partial +from pathlib import Path from typing import Optional, Union import mlx.core as mx import mlx.nn as nn import numpy as np -from functools import partial -from pathlib import Path from huggingface_hub import snapshot_download from PIL import Image + def get_model_path(model_repo: str): """Get or download LTX-2 model path.""" try: @@ -17,15 +18,19 @@ def get_model_path(model_repo: str): return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) except Exception: print("Downloading LTX-2 model weights...") - return Path(snapshot_download( - repo_id=model_repo, - local_files_only=False, - resume_download=True, - allow_patterns=["*.safetensors", "*.json"], - )) + return Path( + snapshot_download( + repo_id=model_repo, + local_files_only=False, + resume_download=True, + allow_patterns=["*.safetensors", "*.json"], + ) + ) + def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict): if quantization is not None: + def get_class_predicate(p, m): # Handle custom per layer quantizations if p in quantization: @@ -46,17 +51,15 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict): class_predicate=get_class_predicate, ) -@partial(mx.compile, shapeless=True) + +@partial(mx.compile, shapeless=True) def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps) - @partial(mx.compile, shapeless=True) def to_denoised( - noisy: mx.array, - velocity: mx.array, - sigma: mx.array | float + noisy: mx.array, velocity: mx.array, sigma: mx.array | float ) -> mx.array: """Convert velocity prediction to denoised output. @@ -284,7 +287,9 @@ def prepare_image_for_encoding( if image_np.max() <= 1.0: image_np = (image_np * 255).astype(np.uint8) pil_image = Image.fromarray(image_np) - pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS) + pil_image = pil_image.resize( + (target_width, target_height), Image.Resampling.LANCZOS + ) image = mx.array(np.array(pil_image).astype(np.float32) / 255.0) # Normalize to [-1, 1] diff --git a/mlx_video/version.py b/mlx_video/version.py index b3c06d4..f102a9c 100644 --- a/mlx_video/version.py +++ b/mlx_video/version.py @@ -1 +1 @@ -__version__ = "0.0.1" \ No newline at end of file +__version__ = "0.0.1" diff --git a/scripts/video/compare_videos.py b/scripts/video/compare_videos.py index 1d18804..462e282 100644 --- a/scripts/video/compare_videos.py +++ b/scripts/video/compare_videos.py @@ -170,19 +170,33 @@ def print_report(results, ref_path, test_path): print("AGGREGATE METRICS") print("-" * 40) - print(f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}") - print(f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}") - print(f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}") - print(f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}") - print(f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}") + print( + f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}" + ) + print( + f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}" + ) + print( + f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}" + ) + print( + f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}" + ) + print( + f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}" + ) print() print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)") print("-" * 40) print(f" Reference: {results['ref_temporal_coherence']:.2f}") print(f" Test: {results['test_temporal_coherence']:.2f}") - ratio = results["test_temporal_coherence"] / (results["ref_temporal_coherence"] + 1e-10) - print(f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}") + ratio = results["test_temporal_coherence"] / ( + results["ref_temporal_coherence"] + 1e-10 + ) + print( + f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}" + ) print() # Identify worst frames @@ -190,7 +204,9 @@ def print_report(results, ref_path, test_path): print("-" * 40) worst_idx = np.argsort(psnr)[:5] for i in worst_idx: - print(f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}") + print( + f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}" + ) print() # Quality assessment @@ -210,7 +226,9 @@ def print_report(results, ref_path, test_path): grade = "Very different" print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})") if mean_psnr < 30: - print(" ⚠ Videos differ significantly — likely a bug or different generation seed") + print( + " ⚠ Videos differ significantly — likely a bug or different generation seed" + ) print("=" * 72) @@ -242,9 +260,7 @@ def main(): parser.add_argument( "--diff-video", help="Save side-by-side diff visualization to this path" ) - parser.add_argument( - "--max-frames", type=int, help="Compare only first N frames" - ) + parser.add_argument("--max-frames", type=int, help="Compare only first N frames") parser.add_argument( "--ssim-win", type=int, default=7, help="SSIM window size (default: 7)" ) @@ -254,26 +270,29 @@ def main(): default=5.0, help="Diff heatmap amplification (default: 5.0)", ) - parser.add_argument( - "--csv", help="Export per-frame metrics to CSV file" - ) + parser.add_argument("--csv", help="Export per-frame metrics to CSV file") args = parser.parse_args() print(f"Loading reference: {args.reference}") ref_frames, ref_fps = load_video(args.reference, args.max_frames) - print(f" → {len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}") + print( + f" → {len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}" + ) print(f"Loading test: {args.test}") test_frames, test_fps = load_video(args.test, args.max_frames) - print(f" → {len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}") + print( + f" → {len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}" + ) if ref_frames[0].shape != test_frames[0].shape: - print(f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}") + print( + f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}" + ) print("Resizing test frames to match reference...") h, w = ref_frames[0].shape[:2] test_frames = [ - cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) - for f in test_frames + cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) for f in test_frames ] print("Computing metrics...") @@ -282,23 +301,29 @@ def main(): print_report(results, args.reference, args.test) if args.diff_video: - save_diff_video(ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale) + save_diff_video( + ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale + ) if args.csv: import csv with open(args.csv, "w", newline="") as f: writer = csv.writer(f) - writer.writerow(["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]) + writer.writerow( + ["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"] + ) for i in range(results["num_frames"]): - writer.writerow([ - i, - f"{results['psnr'][i]:.4f}", - f"{results['ssim'][i]:.6f}", - f"{results['mean_diff'][i]:.4f}", - f"{results['max_diff'][i]:.1f}", - f"{results['color_dist'][i]:.6f}", - ]) + writer.writerow( + [ + i, + f"{results['psnr'][i]:.4f}", + f"{results['ssim'][i]:.6f}", + f"{results['mean_diff'][i]:.4f}", + f"{results['max_diff'][i]:.1f}", + f"{results['color_dist'][i]:.6f}", + ] + ) print(f"Per-frame metrics saved to {args.csv}") diff --git a/scripts/video/video_quality.py b/scripts/video/video_quality.py index f756b5a..9ed287a 100644 --- a/scripts/video/video_quality.py +++ b/scripts/video/video_quality.py @@ -158,10 +158,14 @@ def analyze_video(frames, chunk_size=None, compute_flow=False): boundary_metrics = [] for b in boundaries: if b < n and b > 0: - pre = metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1] + pre = ( + metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1] + ) at = metrics["frame_diff"][b] ratio = at / (pre + 1e-10) - brightness_jump = metrics["brightness"][b] - metrics["brightness"][b - 1] + brightness_jump = ( + metrics["brightness"][b] - metrics["brightness"][b - 1] + ) contrast_jump = ( (metrics["contrast"][b] - metrics["contrast"][b - 1]) / (metrics["contrast"][b - 1] + 1e-10) @@ -198,7 +202,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed): print("VIDEO QUALITY REPORT") print("=" * 72) print(f" File: {path}") - print(f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}") + print( + f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}" + ) duration = total_frames / fps if fps > 0 else 0 print(f" Duration: {duration:.1f}s") print() @@ -211,52 +217,76 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed): print("-" * 40) if n_uniform: frames_list = np.where(metrics["is_uniform"])[0][:10] - print(f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}") + print( + f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}" + ) if n_noisy: frames_list = np.where(metrics["is_noisy"])[0][:10] - print(f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}") + print( + f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}" + ) print() print("SHARPNESS") print("-" * 40) - print(f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}") - print(f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}") + print( + f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}" + ) + print( + f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}" + ) if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3: print(" ⚠ High sharpness variation — possible blur artifacts") print() print("BRIGHTNESS & CONTRAST") print("-" * 40) - print(f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}") - print(f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}") + print( + f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}" + ) + print( + f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}" + ) if np.std(br) > 3.0: print(" ⚠ Brightness instability — may indicate chunk boundary artifacts") print() print("COLOR DISTRIBUTION (BGR)") print("-" * 40) - print(f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}") - print(f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}") - print(f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}") + print( + f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}" + ) + print( + f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}" + ) + print( + f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}" + ) print() print("TEMPORAL STABILITY") print("-" * 40) fd_nz = fd[1:] # skip first frame (always 0) if len(fd_nz) > 0: - print(f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}") + print( + f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}" + ) if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5: print(" ⚠ High diff variance — jitter or discontinuities") if "flow_mean" in metrics: fm = metrics["flow_mean"][1:] - print(f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}") + print( + f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}" + ) print() # Chunk boundaries if "boundaries" in metrics and metrics["boundaries"]: print("CHUNK BOUNDARIES") print("-" * 40) - print(f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}") + print( + f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}" + ) for bm in metrics["boundaries"]: print( f" {bm['frame']:6d}" @@ -267,7 +297,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed): ) avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]]) if avg_ratio > 2.0: - print(f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions") + print( + f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions" + ) print() # Overall grade @@ -303,9 +335,7 @@ def main(): type=int, help="Frames per chunk for boundary analysis (e.g., 32)", ) - parser.add_argument( - "--start", type=int, default=0, help="Start frame (default: 0)" - ) + parser.add_argument("--start", type=int, default=0, help="Start frame (default: 0)") parser.add_argument("--end", type=int, help="End frame (default: all)") parser.add_argument( "--flow", @@ -329,8 +359,14 @@ def main(): import csv keys = [ - "sharpness_lap", "sharpness_grad", "brightness", "contrast", - "color_mean_b", "color_mean_g", "color_mean_r", "frame_diff", + "sharpness_lap", + "sharpness_grad", + "brightness", + "contrast", + "color_mean_b", + "color_mean_g", + "color_mean_r", + "frame_diff", ] if args.flow: keys += ["flow_mean", "flow_max"] diff --git a/tests/test_generate_dev.py b/tests/test_generate_dev.py index e4fa17e..3e85f61 100644 --- a/tests/test_generate_dev.py +++ b/tests/test_generate_dev.py @@ -1,17 +1,17 @@ """Tests for LTX-2 dev model generation pipeline.""" -import pytest import mlx.core as mx +import pytest from mlx_video.generate_dev import ( - ltx2_scheduler, - create_position_grid, - create_audio_position_grid, - compute_audio_frames, - cfg_delta, - DEFAULT_NEGATIVE_PROMPT, - AUDIO_SAMPLE_RATE, AUDIO_LATENTS_PER_SECOND, + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + cfg_delta, + compute_audio_frames, + create_audio_position_grid, + create_position_grid, + ltx2_scheduler, ) @@ -22,12 +22,16 @@ class TestLTX2Scheduler: """Scheduler should return steps+1 sigma values.""" steps = 20 sigmas = ltx2_scheduler(steps=steps) - assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}" + assert sigmas.shape == ( + steps + 1, + ), f"Expected ({steps + 1},), got {sigmas.shape}" def test_scheduler_starts_at_one(self): """Sigma schedule should start at 1.0.""" sigmas = ltx2_scheduler(steps=20) - assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}" + assert ( + abs(sigmas[0].item() - 1.0) < 1e-5 + ), f"Expected 1.0, got {sigmas[0].item()}" def test_scheduler_ends_at_zero(self): """Sigma schedule should end at 0.0.""" @@ -39,8 +43,9 @@ class TestLTX2Scheduler: sigmas = ltx2_scheduler(steps=20) sigmas_list = sigmas.tolist() for i in range(len(sigmas_list) - 1): - assert sigmas_list[i] >= sigmas_list[i + 1], \ - f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}" + assert ( + sigmas_list[i] >= sigmas_list[i + 1] + ), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}" def test_scheduler_dtype(self): """Scheduler should return float32 array.""" @@ -84,14 +89,16 @@ class TestCreatePositionGrid: num_patches = num_frames * height * width expected_shape = (batch_size, 3, num_patches, 2) - assert positions.shape == expected_shape, \ - f"Expected {expected_shape}, got {positions.shape}" + assert ( + positions.shape == expected_shape + ), f"Expected {expected_shape}, got {positions.shape}" def test_position_grid_dtype(self): """Position grid should be float32 for RoPE precision.""" positions = create_position_grid(1, 5, 16, 24) - assert positions.dtype == mx.float32, \ - f"Expected float32 for RoPE precision, got {positions.dtype}" + assert ( + positions.dtype == mx.float32 + ), f"Expected float32 for RoPE precision, got {positions.dtype}" def test_position_grid_batch_size(self): """Position grid should respect batch size.""" @@ -165,7 +172,9 @@ class TestCFGDelta: mx.eval(delta) # Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0 - assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero" + assert ( + mx.max(mx.abs(delta)).item() < 1e-6 + ), "CFG delta with scale=1.0 should be zero" def test_cfg_delta_formula(self): """CFG delta should follow the formula: (scale-1) * (cond - uncond).""" @@ -204,8 +213,9 @@ class TestDefaultNegativePrompt: # Check for common negative quality terms assert "blurry" in prompt_lower, "Should contain 'blurry'" - assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \ - "Should contain quality-related terms" + assert ( + "low quality" in prompt_lower or "low contrast" in prompt_lower + ), "Should contain quality-related terms" class TestInputValidation: @@ -248,15 +258,16 @@ class TestInputValidation: (30, 33), # 30 -> nearest valid is 33 (35, 33), # 35 -> nearest valid is 33 (40, 41), # 40 -> nearest valid is 41 - (1, 1), # 1 is already valid + (1, 1), # 1 is already valid (33, 33), # 33 is already valid ] for input_frames, expected in test_cases: if input_frames % 8 != 1: adjusted = round((input_frames - 1) / 8) * 8 + 1 - assert adjusted == expected, \ - f"Expected {expected} for input {input_frames}, got {adjusted}" + assert ( + adjusted == expected + ), f"Expected {expected} for input {input_frames}, got {adjusted}" class TestDenoiseWithCFGMocked: @@ -277,14 +288,16 @@ class TestTilingDefault: def test_tiling_default_is_none(self): """Default tiling should be 'none' for performance.""" import inspect + from mlx_video.generate_dev import generate_video_dev sig = inspect.signature(generate_video_dev) - tiling_param = sig.parameters.get('tiling') + tiling_param = sig.parameters.get("tiling") assert tiling_param is not None - assert tiling_param.default == "none", \ - f"Expected default tiling='none', got '{tiling_param.default}'" + assert ( + tiling_param.default == "none" + ), f"Expected default tiling='none', got '{tiling_param.default}'" class TestLatentDimensions: @@ -296,8 +309,9 @@ class TestLatentDimensions: for height, expected_latent_h in test_cases: latent_h = height // 32 - assert latent_h == expected_latent_h, \ - f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}" + assert ( + latent_h == expected_latent_h + ), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}" def test_latent_width_calculation(self): """Latent width should be width // 32.""" @@ -305,8 +319,9 @@ class TestLatentDimensions: for width, expected_latent_w in test_cases: latent_w = width // 32 - assert latent_w == expected_latent_w, \ - f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}" + assert ( + latent_w == expected_latent_w + ), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}" def test_latent_frames_calculation(self): """Latent frames should be 1 + (num_frames - 1) // 8.""" @@ -314,8 +329,9 @@ class TestLatentDimensions: for num_frames, expected_latent_f in test_cases: latent_f = 1 + (num_frames - 1) // 8 - assert latent_f == expected_latent_f, \ - f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}" + assert ( + latent_f == expected_latent_f + ), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}" def test_num_tokens_calculation(self): """Number of tokens should be latent_f * latent_h * latent_w.""" @@ -343,14 +359,14 @@ class TestAudioPositionGrid: positions = create_audio_position_grid(batch_size, audio_frames) expected_shape = (batch_size, 1, audio_frames, 2) - assert positions.shape == expected_shape, \ - f"Expected {expected_shape}, got {positions.shape}" + assert ( + positions.shape == expected_shape + ), f"Expected {expected_shape}, got {positions.shape}" def test_audio_position_grid_dtype(self): """Audio position grid should be float32.""" positions = create_audio_position_grid(1, 34) - assert positions.dtype == mx.float32, \ - f"Expected float32, got {positions.dtype}" + assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}" def test_audio_position_grid_batch_size(self): """Audio position grid should respect batch size.""" @@ -371,8 +387,12 @@ class TestAudioPositionGrid: """Audio position grid should not contain NaN or Inf.""" positions = create_audio_position_grid(1, 34) - assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN" - assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf" + assert not mx.any( + mx.isnan(positions) + ).item(), "Audio position grid contains NaN" + assert not mx.any( + mx.isinf(positions) + ).item(), "Audio position grid contains Inf" class TestComputeAudioFrames: @@ -391,8 +411,9 @@ class TestComputeAudioFrames: audio_33 = compute_audio_frames(33, 24.0) audio_65 = compute_audio_frames(65, 24.0) - assert audio_65 > audio_33, \ - f"Expected more audio frames for longer video: {audio_65} <= {audio_33}" + assert ( + audio_65 > audio_33 + ), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}" def test_audio_frames_formula(self): """Audio frames should match expected formula.""" diff --git a/tests/test_rope.py b/tests/test_rope.py index 8590963..f05574c 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -1,11 +1,9 @@ -import pytest import mlx.core as mx import numpy as np +import pytest -from mlx_video.models.ltx_2.rope import ( - precompute_freqs_cis, -) from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType +from mlx_video.models.ltx_2.rope import precompute_freqs_cis def create_video_position_grid( @@ -20,7 +18,7 @@ def create_video_position_grid( h_coords = np.arange(0, height) w_coords = np.arange(0, width) - t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') + t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij") patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) patch_ends = patch_starts + 1 @@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads): scaled = fractional * 2 - 1 # [-1, 1] # Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices) - freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :] + freqs = ( + scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :] + ) # (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten freqs = np.swapaxes(freqs, -1, -2) - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims) + freqs = freqs.reshape( + freqs.shape[0], freqs.shape[1], -1 + ) # (B, T, num_indices * n_dims) cos_ref = np.cos(freqs) sin_ref = np.sin(freqs) @@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads): pad_size = expected - cos_ref.shape[-1] if pad_size > 0: # Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis() - cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1) - sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1) + cos_ref = np.concatenate( + [np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1 + ) + sin_ref = np.concatenate( + [np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1 + ) B, T, _ = cos_ref.shape dim_per_head = dim // num_heads @@ -124,10 +130,12 @@ class TestRoPEPositionPrecision: assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf" # Verify cos/sin are in valid range [-1, 1] - assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \ - "cos_freq values out of [-1, 1] range" - assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \ - "sin_freq values out of [-1, 1] range" + assert ( + mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item() + ), "cos_freq values out of [-1, 1] range" + assert ( + mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item() + ), "sin_freq values out of [-1, 1] range" def test_bfloat16_positions_cause_precision_loss(self): """bfloat16 positions should produce different (less precise) results than float32. @@ -175,7 +183,9 @@ class TestRoPEPositionPrecision: # The threshold here is intentionally low to catch the issue precision_threshold = 1e-6 - has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold + has_precision_loss = ( + max_cos_diff > precision_threshold or max_sin_diff > precision_threshold + ) # Document the precision loss (this is expected behavior) if has_precision_loss: @@ -184,8 +194,9 @@ class TestRoPEPositionPrecision: print(f" Max sin difference: {max_sin_diff:.6e}") # This assertion documents the issue - bfloat16 positions cause precision loss - assert has_precision_loss, \ - "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed" + assert ( + has_precision_loss + ), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed" def test_double_precision_converts_to_float32_internally(self): """Verify that double_precision mode converts bfloat16 to float32 first.""" @@ -215,20 +226,26 @@ class TestRoPEPositionPrecision: # Recommended: create positions in float32 positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) - assert positions.dtype == mx.float32, \ - "Position grids should be created in float32 for RoPE precision" + assert ( + positions.dtype == mx.float32 + ), "Position grids should be created in float32 for RoPE precision" # Verify the position values are reasonable # Temporal positions should be small (seconds) temporal_positions = positions[:, 0, :, :] - assert mx.max(temporal_positions).item() < 100, \ - "Temporal positions should be in seconds (small values)" + assert ( + mx.max(temporal_positions).item() < 100 + ), "Temporal positions should be in seconds (small values)" # Spatial positions should be larger (pixels) spatial_h = positions[:, 1, :, :] spatial_w = positions[:, 2, :, :] - assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" - assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" + assert ( + mx.max(spatial_h).item() > 0 + ), "Spatial height positions should be positive" + assert ( + mx.max(spatial_w).item() > 0 + ), "Spatial width positions should be positive" def test_float32_positions_match_numpy_float64_reference(self): """Regression test: float32 RoPE must closely match a NumPy float64 reference. @@ -259,7 +276,9 @@ class TestRoPEPositionPrecision: ) # NumPy float64 reference - cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads) + cos_ref, sin_ref = _numpy_reference_rope( + positions_np, dim, theta, max_pos, num_heads + ) cos_mlx_np = np.array(cos_mlx) sin_mlx_np = np.array(sin_mlx) @@ -270,16 +289,21 @@ class TestRoPEPositionPrecision: # Cosine similarity (flatten for single scalar) cos_flat = cos_mlx_np.flatten() ref_flat = cos_ref.flatten() - cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)) + cosine_sim = np.dot(cos_flat, ref_flat) / ( + np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat) + ) # float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa. # Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff). - assert max_cos_diff < 0.01, \ - f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" - assert max_sin_diff < 0.01, \ - f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" - assert cosine_sim > 0.9999, \ - f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999" + assert ( + max_cos_diff < 0.01 + ), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" + assert ( + max_sin_diff < 0.01 + ), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" + assert ( + cosine_sim > 0.9999 + ), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999" def test_high_frequency_amplification_regression(self): """Regression test for the specific failure mode: high-frequency index amplification. @@ -309,16 +333,20 @@ class TestRoPEPositionPrecision: double_precision=False, ) - cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads) + cos_ref, sin_ref = _numpy_reference_rope( + positions_np, dim, theta, max_pos, num_heads + ) max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref)) max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref)) # Float32 should keep errors well below the bfloat16 failure threshold of ~2.0 - assert max_cos_diff < 0.01, \ - f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected" - assert max_sin_diff < 0.01, \ - f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected" + assert ( + max_cos_diff < 0.01 + ), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected" + assert ( + max_sin_diff < 0.01 + ), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected" class TestRoPEInterleaved: @@ -359,9 +387,13 @@ class TestRoPEInputCasting: positions_bf16 = positions_f32.astype(mx.bfloat16) kwargs = dict( - dim=128, theta=10000.0, max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, num_attention_heads=32, - rope_type=LTXRopeType.SPLIT, double_precision=False, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=False, ) cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs) @@ -383,9 +415,13 @@ class TestRoPEInputCasting: positions_bf16 = positions_f32.astype(mx.bfloat16) kwargs = dict( - dim=128, theta=10000.0, max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, num_attention_heads=32, - rope_type=LTXRopeType.SPLIT, double_precision=True, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, ) cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs) @@ -405,9 +441,13 @@ class TestRoPEInputCasting: cos_freq, sin_freq = precompute_freqs_cis( indices_grid=positions_f16, - dim=128, theta=10000.0, max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, num_attention_heads=32, - rope_type=LTXRopeType.SPLIT, double_precision=False, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=False, ) assert cos_freq.dtype == mx.float32 @@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig: def test_ltx2_forces_double_precision_rope_false(self): """LTX-2 (no prompt adaln) must have double_precision_rope=False.""" config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True) - assert config.double_precision_rope is False, \ - "LTX-2 should force double_precision_rope=False regardless of input" + assert ( + config.double_precision_rope is False + ), "LTX-2 should force double_precision_rope=False regardless of input" def test_ltx23_preserves_double_precision_rope_true(self): """LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True.""" config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True) - assert config.double_precision_rope is True, \ - "LTX-2.3 should preserve double_precision_rope=True" + assert ( + config.double_precision_rope is True + ), "LTX-2.3 should preserve double_precision_rope=True" def test_ltx23_preserves_double_precision_rope_false(self): """LTX-2.3 with double_precision_rope=False should stay False.""" config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False) - assert config.double_precision_rope is False, \ - "LTX-2.3 should respect double_precision_rope=False when explicitly set" + assert ( + config.double_precision_rope is False + ), "LTX-2.3 should respect double_precision_rope=False when explicitly set" def test_ltx2_default_double_precision_rope(self): """LTX-2 default (double_precision_rope not set) should be False.""" @@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig: def test_config_from_dict_ltx2(self): """Config created from dict for LTX-2 should force double_precision_rope=False.""" - config = LTXModelConfig.from_dict({ - "has_prompt_adaln": False, - "double_precision_rope": True, - "rope_type": "split", - }) + config = LTXModelConfig.from_dict( + { + "has_prompt_adaln": False, + "double_precision_rope": True, + "rope_type": "split", + } + ) assert config.double_precision_rope is False def test_config_from_dict_ltx23(self): """Config created from dict for LTX-2.3 should preserve double_precision_rope.""" - config = LTXModelConfig.from_dict({ - "has_prompt_adaln": True, - "double_precision_rope": True, - "rope_type": "split", - }) + config = LTXModelConfig.from_dict( + { + "has_prompt_adaln": True, + "double_precision_rope": True, + "rope_type": "split", + } + ) assert config.double_precision_rope is True @@ -496,10 +543,12 @@ class TestRoPESplit: # dim=128, num_heads=32, so dim_per_head=4, and split uses half=2 dim_per_head = dim // num_heads expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2) - assert cos_freq.shape == expected_shape, \ - f"Expected shape {expected_shape}, got {cos_freq.shape}" - assert sin_freq.shape == expected_shape, \ - f"Expected shape {expected_shape}, got {sin_freq.shape}" + assert ( + cos_freq.shape == expected_shape + ), f"Expected shape {expected_shape}, got {cos_freq.shape}" + assert ( + sin_freq.shape == expected_shape + ), f"Expected shape {expected_shape}, got {sin_freq.shape}" if __name__ == "__main__": diff --git a/tests/test_vae_streaming.py b/tests/test_vae_streaming.py index 0f3abd8..13d1a82 100644 --- a/tests/test_vae_streaming.py +++ b/tests/test_vae_streaming.py @@ -1,8 +1,8 @@ """Tests for VAE streaming and chunked conv features.""" -import pytest import mlx.core as mx import numpy as np +import pytest from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx_2.video_vae.tiling import ( @@ -50,7 +50,7 @@ class TestChunkedConv: np.array(out_chunked), rtol=1e-5, atol=1e-5, - err_msg="Chunked conv output differs from regular output" + err_msg="Chunked conv output differs from regular output", ) def test_chunked_conv_small_input_passthrough(self): @@ -117,13 +117,17 @@ class TestProgressiveFrameSaving: frames_received = [] def on_frames_ready(frames: mx.array, start_idx: int): - frames_received.append({ - 'shape': frames.shape, - 'start_idx': start_idx, - }) + frames_received.append( + { + "shape": frames.shape, + "start_idx": start_idx, + } + ) # Create a mock decoder that just returns scaled input - def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False): + def mock_decoder( + x, causal=False, timestep=None, debug=False, chunked_conv=False + ): # Simulate VAE output: upsample 8x temporal, 32x spatial b, c, f, h, w = x.shape out_f = 1 + (f - 1) * 8 @@ -154,7 +158,9 @@ class TestProgressiveFrameSaving: # All received frames should have correct channel count for received in frames_received: - assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}" + assert ( + received["shape"][1] == 3 + ), f"Expected 3 channels, got {received['shape'][1]}" def test_on_frames_ready_covers_all_frames(self): """Verify all frames are emitted via callbacks.""" @@ -165,7 +171,9 @@ class TestProgressiveFrameSaving: for i in range(num_frames): all_frame_indices.add(start_idx + i) - def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False): + def mock_decoder( + x, causal=False, timestep=None, debug=False, chunked_conv=False + ): b, c, f, h, w = x.shape out_f = 1 + (f - 1) * 8 out_h = h * 32 @@ -191,24 +199,29 @@ class TestProgressiveFrameSaving: expected_frames = 1 + (12 - 1) * 8 # 89 frames # All frames should have been emitted - assert len(all_frame_indices) == expected_frames, \ - f"Expected {expected_frames} frames, got {len(all_frame_indices)}" - assert all_frame_indices == set(range(expected_frames)), \ - "Not all frame indices were covered" + assert ( + len(all_frame_indices) == expected_frames + ), f"Expected {expected_frames} frames, got {len(all_frame_indices)}" + assert all_frame_indices == set( + range(expected_frames) + ), "Not all frame indices were covered" class TestAutoChunkedConv: """Tests for auto-enabling chunked_conv based on tiling mode.""" - @pytest.mark.parametrize("tiling_mode,should_enable", [ - ("conservative", True), - ("none", True), - ("auto", True), - ("default", True), - ("spatial", True), - ("aggressive", False), - ("temporal", False), - ]) + @pytest.mark.parametrize( + "tiling_mode,should_enable", + [ + ("conservative", True), + ("none", True), + ("auto", True), + ("default", True), + ("spatial", True), + ("aggressive", False), + ("temporal", False), + ], + ) def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool): """Verify chunked_conv is auto-enabled for correct tiling modes.""" # The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial") @@ -216,8 +229,9 @@ class TestAutoChunkedConv: use_chunked_conv = tiling_mode in expected_modes - assert use_chunked_conv == should_enable, \ - f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}" + assert ( + use_chunked_conv == should_enable + ), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}" class TestTrapezoidalMask: @@ -250,7 +264,9 @@ class TestTrapezoidalMask: # Right ramp should be decreasing right_ramp = mask_np[-8:] - assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing" + assert np.all( + np.diff(right_ramp) <= 0 + ), "Right ramp not monotonically decreasing" def test_temporal_mask_starts_from_zero(self): """Verify temporal mask (left_starts_from_0=True) starts from 0.""" diff --git a/tests/test_wan_attention.py b/tests/test_wan_attention.py index 02e471b..700bb61 100644 --- a/tests/test_wan_attention.py +++ b/tests/test_wan_attention.py @@ -2,24 +2,25 @@ import mlx.core as mx import numpy as np -import pytest - # --------------------------------------------------------------------------- # RoPE Tests # --------------------------------------------------------------------------- + class TestRoPE: """Tests for 3-way factorized RoPE.""" def test_rope_params_shape(self): from mlx_video.models.wan.rope import rope_params + freqs = rope_params(1024, 64) mx.eval(freqs) assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2] def test_rope_params_different_dims(self): from mlx_video.models.wan.rope import rope_params + for dim in [32, 64, 128]: freqs = rope_params(512, dim) mx.eval(freqs) @@ -27,6 +28,7 @@ class TestRoPE: def test_rope_params_cos_sin_range(self): from mlx_video.models.wan.rope import rope_params + freqs = rope_params(256, 64) mx.eval(freqs) cos_vals = np.array(freqs[:, :, 0]) @@ -37,13 +39,15 @@ class TestRoPE: def test_rope_params_position_zero(self): """At position 0, cos should be 1 and sin should be 0.""" from mlx_video.models.wan.rope import rope_params + freqs = rope_params(10, 64) mx.eval(freqs) np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6) np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6) def test_rope_apply_output_shape(self): - from mlx_video.models.wan.rope import rope_params, rope_apply + from mlx_video.models.wan.rope import rope_apply, rope_params + B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim x = mx.random.normal((B, L, N, D)) freqs = rope_params(1024, D) @@ -54,7 +58,8 @@ class TestRoPE: def test_rope_apply_preserves_norm(self): """RoPE rotation should preserve vector norms.""" - from mlx_video.models.wan.rope import rope_params, rope_apply + from mlx_video.models.wan.rope import rope_apply, rope_params + B, N, D = 1, 2, 16 F, H, W = 2, 3, 4 L = F * H * W @@ -74,7 +79,8 @@ class TestRoPE: def test_rope_apply_with_padding(self): """When seq_len < L, extra tokens should be preserved unchanged.""" - from mlx_video.models.wan.rope import rope_params, rope_apply + from mlx_video.models.wan.rope import rope_apply, rope_params + B, N, D = 1, 2, 16 F, H, W = 2, 2, 2 seq_len = F * H * W # 8 @@ -94,7 +100,8 @@ class TestRoPE: def test_rope_apply_batch(self): """Test with batch_size > 1 and different grid sizes.""" - from mlx_video.models.wan.rope import rope_params, rope_apply + from mlx_video.models.wan.rope import rope_apply, rope_params + B, N, D = 2, 2, 16 grids = [(2, 3, 4), (2, 3, 4)] L = 2 * 3 * 4 @@ -122,9 +129,11 @@ class TestRoPE: # Attention Tests # --------------------------------------------------------------------------- + class TestWanRMSNorm: def test_output_shape(self): from mlx_video.models.wan.attention import WanRMSNorm + norm = WanRMSNorm(64) x = mx.random.normal((2, 10, 64)) out = norm(x) @@ -134,6 +143,7 @@ class TestWanRMSNorm: def test_zero_mean_variance(self): """RMS norm should make RMS ≈ 1 before scaling.""" from mlx_video.models.wan.attention import WanRMSNorm + norm = WanRMSNorm(64) x = mx.random.normal((1, 5, 64)) * 10.0 out = norm(x) @@ -147,6 +157,7 @@ class TestWanRMSNorm: def test_dtype_preservation(self): """RMSNorm weight is float32, so output is promoted to float32.""" from mlx_video.models.wan.attention import WanRMSNorm + norm = WanRMSNorm(32) x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16) out = norm(x) @@ -158,6 +169,7 @@ class TestWanRMSNorm: class TestWanLayerNorm: def test_output_shape(self): from mlx_video.models.wan.attention import WanLayerNorm + norm = WanLayerNorm(64) x = mx.random.normal((2, 10, 64)) out = norm(x) @@ -166,6 +178,7 @@ class TestWanLayerNorm: def test_without_affine(self): from mlx_video.models.wan.attention import WanLayerNorm + norm = WanLayerNorm(64, elementwise_affine=False) x = mx.random.normal((1, 4, 64)) out = norm(x) @@ -178,6 +191,7 @@ class TestWanLayerNorm: def test_with_affine(self): from mlx_video.models.wan.attention import WanLayerNorm + norm = WanLayerNorm(32, elementwise_affine=True) assert hasattr(norm, "weight") assert hasattr(norm, "bias") @@ -196,6 +210,7 @@ class TestWanSelfAttention: def test_output_shape(self): from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.rope import rope_params + attn = WanSelfAttention(self.dim, self.num_heads) B, L = 1, 24 F, H, W = 2, 3, 4 @@ -207,12 +222,14 @@ class TestWanSelfAttention: def test_with_qk_norm(self): from mlx_video.models.wan.attention import WanSelfAttention + attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True) assert attn.norm_q is not None assert attn.norm_k is not None def test_without_qk_norm(self): from mlx_video.models.wan.attention import WanSelfAttention + attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) assert attn.norm_q is None assert attn.norm_k is None @@ -221,6 +238,7 @@ class TestWanSelfAttention: """Test that masking works: shorter seq_lens should mask later tokens.""" from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.rope import rope_params + attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) B, L = 1, 24 F, H, W = 2, 3, 4 @@ -245,6 +263,7 @@ class TestWanCrossAttention: def test_output_shape(self): from mlx_video.models.wan.attention import WanCrossAttention + attn = WanCrossAttention(self.dim, self.num_heads) B, L_q, L_kv = 1, 24, 16 x = mx.random.normal((B, L_q, self.dim)) @@ -255,6 +274,7 @@ class TestWanCrossAttention: def test_with_context_mask(self): from mlx_video.models.wan.attention import WanCrossAttention + attn = WanCrossAttention(self.dim, self.num_heads) B, L_q, L_kv = 1, 12, 16 x = mx.random.normal((B, L_q, self.dim)) @@ -268,6 +288,7 @@ class TestWanCrossAttention: # bfloat16 Autocast Tests # --------------------------------------------------------------------------- + class TestBFloat16Autocast: """Tests that attention and FFN cast inputs to weight dtype (bfloat16) for efficient matmul, matching official PyTorch autocast behavior.""" @@ -292,6 +313,7 @@ class TestBFloat16Autocast: """Self-attention should cast input to weight dtype for QKV projections.""" from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.rope import rope_params + attn = WanSelfAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -305,6 +327,7 @@ class TestBFloat16Autocast: def test_cross_attn_casts_to_weight_dtype(self): """Cross-attention should cast input to weight dtype.""" from mlx_video.models.wan.attention import WanCrossAttention + attn = WanCrossAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -318,6 +341,7 @@ class TestBFloat16Autocast: def test_cross_attn_kv_cache_uses_weight_dtype(self): """prepare_kv should cast context to weight dtype.""" from mlx_video.models.wan.attention import WanCrossAttention + attn = WanCrossAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -330,6 +354,7 @@ class TestBFloat16Autocast: def test_ffn_casts_to_weight_dtype(self): """FFN should cast input to weight dtype for linear layers.""" from mlx_video.models.wan.transformer import WanFFN + ffn = WanFFN(self.dim, 128) ffn.update(self._to_bf16(ffn.parameters())) @@ -343,6 +368,7 @@ class TestBFloat16Autocast: """RoPE should be applied in float32 for precision, even with bf16 weights.""" from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.rope import rope_params + attn = WanSelfAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -355,8 +381,9 @@ class TestBFloat16Autocast: def test_block_float32_residual_with_bf16_weights(self): """Full block: residual stream stays float32, matmuls use bf16 weights.""" - from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True) block.update(self._to_bf16(block.parameters())) diff --git a/tests/test_wan_config.py b/tests/test_wan_config.py index 5b943df..2ffddcf 100644 --- a/tests/test_wan_config.py +++ b/tests/test_wan_config.py @@ -1,17 +1,17 @@ """Tests for Wan model configuration.""" -import pytest - # --------------------------------------------------------------------------- # Config Tests # --------------------------------------------------------------------------- + class TestWanModelConfig: """Tests for WanModelConfig dataclass.""" def test_default_values(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig() assert config.dim == 5120 assert config.ffn_dim == 13824 @@ -33,11 +33,13 @@ class TestWanModelConfig: def test_head_dim_property(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig() assert config.head_dim == 128 # 5120 // 40 def test_to_dict_roundtrip(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig() d = config.to_dict() assert isinstance(d, dict) @@ -47,6 +49,7 @@ class TestWanModelConfig: def test_t5_config_values(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig() assert config.t5_vocab_size == 256384 assert config.t5_dim == 4096 @@ -61,11 +64,13 @@ class TestWanModelConfig: # Wan2.1 Config Tests # --------------------------------------------------------------------------- + class TestWan21Config: """Tests for Wan2.1 config presets.""" def test_wan21_14b_factory(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_14b() assert config.model_version == "2.1" assert config.dual_model is False @@ -81,6 +86,7 @@ class TestWan21Config: def test_wan21_1_3b_factory(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_1_3b() assert config.model_version == "2.1" assert config.dual_model is False @@ -93,6 +99,7 @@ class TestWan21Config: def test_wan22_14b_factory(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan22_t2v_14b() assert config.model_version == "2.2" assert config.dual_model is True @@ -104,6 +111,7 @@ class TestWan21Config: def test_wan21_config_to_dict(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_14b() d = config.to_dict() assert d["model_version"] == "2.1" @@ -112,6 +120,7 @@ class TestWan21Config: def test_wan21_1_3b_config_to_dict(self): from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_1_3b() d = config.to_dict() assert d["dim"] == 1536 @@ -120,6 +129,7 @@ class TestWan21Config: def test_default_config_is_wan22(self): """Default WanModelConfig() should be Wan2.2 14B.""" from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig() assert config.model_version == "2.2" assert config.dual_model is True diff --git a/tests/test_wan_convert.py b/tests/test_wan_convert.py index 81630ce..69a8dd3 100644 --- a/tests/test_wan_convert.py +++ b/tests/test_wan_convert.py @@ -3,17 +3,16 @@ import logging import mlx.core as mx -import numpy as np -import pytest - # --------------------------------------------------------------------------- # Transformer Weight Conversion Tests # --------------------------------------------------------------------------- + class TestSanitizeTransformerWeights: def test_patch_embedding_reshape(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), "patch_embedding.bias": mx.random.normal((5120,)), @@ -25,6 +24,7 @@ class TestSanitizeTransformerWeights: def test_text_embedding_rename(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "text_embedding.0.weight": mx.zeros((64, 32)), "text_embedding.0.bias": mx.zeros((64,)), @@ -39,6 +39,7 @@ class TestSanitizeTransformerWeights: def test_time_embedding_rename(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "time_embedding.0.weight": mx.zeros((64, 32)), "time_embedding.2.weight": mx.zeros((64, 64)), @@ -49,6 +50,7 @@ class TestSanitizeTransformerWeights: def test_time_projection_rename(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "time_projection.1.weight": mx.zeros((384, 64)), "time_projection.1.bias": mx.zeros((384,)), @@ -59,6 +61,7 @@ class TestSanitizeTransformerWeights: def test_ffn_rename(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "blocks.0.ffn.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.0.bias": mx.zeros((128,)), @@ -73,6 +76,7 @@ class TestSanitizeTransformerWeights: def test_freqs_skipped(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "freqs": mx.zeros((1024, 64, 2)), "blocks.0.norm1.weight": mx.zeros((64,)), @@ -83,6 +87,7 @@ class TestSanitizeTransformerWeights: def test_passthrough_keys(self): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "blocks.0.self_attn.q.weight": mx.zeros((64, 64)), "blocks.0.self_attn.k.weight": mx.zeros((64, 64)), @@ -98,6 +103,7 @@ class TestSanitizeTransformerWeights: def test_no_unconsumed_keys(self, caplog): from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), "patch_embedding.bias": mx.random.normal((5120,)), @@ -121,6 +127,7 @@ class TestSanitizeTransformerWeights: class TestSanitizeT5Weights: def test_gate_rename(self): from mlx_video.convert_wan import sanitize_wan_t5_weights + weights = { "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.fc1.weight": mx.zeros((128, 64)), @@ -133,6 +140,7 @@ class TestSanitizeT5Weights: def test_passthrough(self): from mlx_video.convert_wan import sanitize_wan_t5_weights + weights = { "token_embedding.weight": mx.zeros((100, 64)), "blocks.0.attn.q.weight": mx.zeros((64, 64)), @@ -144,6 +152,7 @@ class TestSanitizeT5Weights: def test_no_unconsumed_keys(self, caplog): from mlx_video.convert_wan import sanitize_wan_t5_weights + weights = { "token_embedding.weight": mx.zeros((100, 64)), "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), @@ -159,6 +168,7 @@ class TestSanitizeT5Weights: class TestSanitizeVAEWeights: def test_conv3d_transpose(self): from mlx_video.convert_wan import sanitize_wan_vae_weights + weights = { "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W] } @@ -167,6 +177,7 @@ class TestSanitizeVAEWeights: def test_conv2d_transpose(self): from mlx_video.convert_wan import sanitize_wan_vae_weights + weights = { "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W] } @@ -175,6 +186,7 @@ class TestSanitizeVAEWeights: def test_non_conv_passthrough(self): from mlx_video.convert_wan import sanitize_wan_vae_weights + weights = { "decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose "decoder.bias": mx.zeros((16,)), @@ -185,6 +197,7 @@ class TestSanitizeVAEWeights: def test_mixed_weights(self): from mlx_video.convert_wan import sanitize_wan_vae_weights + weights = { "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D "conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D @@ -199,6 +212,7 @@ class TestSanitizeVAEWeights: def test_no_unconsumed_keys(self, caplog): from mlx_video.convert_wan import sanitize_wan_vae_weights + weights = { "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), @@ -214,6 +228,7 @@ class TestSanitizeVAEWeights: # Wan2.1 Conversion Tests # --------------------------------------------------------------------------- + class TestWan21Convert: """Tests for Wan2.1 conversion support.""" @@ -222,7 +237,7 @@ class TestWan21Convert: # Create a Wan2.1-style directory (no low_noise_model subdir) (tmp_path / "dummy.safetensors").touch() # The auto-detect logic: no low_noise_model dir → 2.1 - from pathlib import Path + low = tmp_path / "low_noise_model" assert not low.exists() # Simulates auto detection @@ -233,7 +248,7 @@ class TestWan21Convert: """Auto-detect dual-model directory as Wan2.2.""" (tmp_path / "low_noise_model").mkdir() (tmp_path / "high_noise_model").mkdir() - from pathlib import Path + low = tmp_path / "low_noise_model" assert low.exists() version = "2.2" if low.exists() else "2.1" @@ -242,6 +257,7 @@ class TestWan21Convert: def test_wan21_config_saved_correctly(self): """Verify config dict has correct fields for Wan2.1.""" from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_14b() d = config.to_dict() assert d["model_version"] == "2.1" @@ -254,6 +270,7 @@ class TestWan21Convert: # Encoder Weight Sanitization Tests # --------------------------------------------------------------------------- + class TestSanitizeEncoderWeights: """Tests for sanitize_wan22_vae_weights with include_encoder.""" diff --git a/tests/test_wan_generate.py b/tests/test_wan_generate.py index a643d9e..e42713c 100644 --- a/tests/test_wan_generate.py +++ b/tests/test_wan_generate.py @@ -2,15 +2,13 @@ import mlx.core as mx import numpy as np -import pytest - from wan_test_helpers import _make_tiny_config - # --------------------------------------------------------------------------- # Integration: end-to-end tiny model forward pass # --------------------------------------------------------------------------- + class TestEndToEnd: """End-to-end test with tiny model (no real weights needed).""" @@ -78,6 +76,7 @@ class TestEndToEnd: # I2V Mask Tests # --------------------------------------------------------------------------- + class TestI2VMask: """Tests for _build_i2v_mask.""" @@ -113,6 +112,7 @@ class TestI2VMaskAlignment: def test_mask_with_ti2v_dimensions(self): """Mask should work with TI2V-5B typical dimensions.""" from mlx_video.generate_wan import _build_i2v_mask + # TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2) # 704x1280 → latent 44x80, t_latent=21 for 81 frames z_shape = (48, 21, 44, 80) @@ -133,6 +133,7 @@ class TestI2VMaskAlignment: def test_mask_per_token_timestep(self): """Per-token timesteps: first-frame tokens get t=0, rest get t=sigma.""" from mlx_video.generate_wan import _build_i2v_mask + z_shape = (4, 3, 4, 4) patch_size = (1, 2, 2) _, mask_tokens = _build_i2v_mask(z_shape, patch_size) @@ -144,13 +145,16 @@ class TestI2VMaskAlignment: first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw) np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7) - np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7) + np.testing.assert_allclose( + np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7 + ) # --------------------------------------------------------------------------- # Dimension Alignment Tests # --------------------------------------------------------------------------- + class TestDimensionAlignment: """Tests for automatic dimension alignment in generate_wan.""" @@ -198,6 +202,7 @@ class TestDimensionAlignment: def test_patchify_valid_after_alignment(self): """After alignment, patchify should succeed without reshape errors.""" from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) @@ -222,11 +227,16 @@ class TestDimensionAlignment: patches, grid_size = model._patchify(vid) mx.eval(patches) assert patches.ndim == 3 # [1, L, dim] - assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2]) + assert grid_size == ( + t_latent, + h_latent // patch_size[1], + w_latent // patch_size[2], + ) def test_alignment_with_ti2v_config(self): """TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32.""" from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan22_ti2v_5b() align_h = config.patch_size[1] * config.vae_stride[1] align_w = config.patch_size[2] * config.vae_stride[2] diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py index 067d6c3..112e7cc 100644 --- a/tests/test_wan_i2v.py +++ b/tests/test_wan_i2v.py @@ -1,9 +1,6 @@ """Tests for Wan2.2 I2V-14B support.""" import mlx.core as mx -import numpy as np -import pytest - from wan_test_helpers import _make_tiny_config @@ -145,7 +142,10 @@ class TestModelYParameter: latents = mx.random.normal((C_noise, F, H, W)) y = mx.random.normal((C_y, F, H, W)) t = mx.array([500.0, 500.0]) - ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))] + ctx = [ + mx.random.normal((6, config.text_dim)), + mx.random.normal((6, config.text_dim)), + ] out = model([latents, latents], t, ctx, seq_len, y=[y, y]) mx.eval(out[0], out[1]) @@ -160,7 +160,9 @@ class TestVAEEncoder: def test_encoder3d_instantiation(self): from mlx_video.models.wan.vae import Encoder3d - enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2) + enc = Encoder3d( + dim=32, z_dim=8 + ) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2) assert enc.conv1 is not None assert len(enc.downsamples) > 0 assert len(enc.middle) == 3 @@ -199,10 +201,10 @@ class TestVAEEncoder: from mlx_video.models.wan.vae import WanVAE vae_no_enc = WanVAE(z_dim=4, encoder=False) - assert not hasattr(vae_no_enc, 'encoder') + assert not hasattr(vae_no_enc, "encoder") vae_enc = WanVAE(z_dim=4, encoder=True) - assert hasattr(vae_enc, 'encoder') + assert hasattr(vae_enc, "encoder") class TestResampleDownsample: @@ -258,7 +260,9 @@ class TestI2VMaskConstruction: # Build mask following reference logic msk = mx.ones((1, num_frames, h_latent, w_latent)) - msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) + msk = mx.concatenate( + [msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1 + ) msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] @@ -272,7 +276,9 @@ class TestI2VMaskConstruction: t_latent = (num_frames - 1) // 4 + 1 # = 3 msk = mx.ones((1, num_frames, h_latent, w_latent)) - msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) + msk = mx.concatenate( + [msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1 + ) msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.transpose(0, 2, 1, 3, 4)[0] @@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline: config = _make_tiny_i2v_config() config.vae_z_dim = 16 config.out_dim = 16 # must match VAE z_dim for decode - config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36 + config.in_dim = ( + 16 + 4 + 16 + ) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36 model = WanModel(config) # --- Tiny VAE (with encoder) --- @@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline: img = mx.random.uniform(-1, 1, (1, 3, 1, height, width)) # Build video: first frame = image, rest = zeros -> [1, 3, F, H, W] - video = mx.concatenate([ - img, - mx.zeros((1, 3, num_frames - 1, height, width)), - ], axis=2) + video = mx.concatenate( + [ + img, + mx.zeros((1, 3, num_frames - 1, height, width)), + ], + axis=2, + ) # --- VAE encode --- z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat] @@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline: # --- Build I2V mask (4 channels) --- msk = mx.ones((1, num_frames, h_latent, w_latent)) - msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) + msk = mx.concatenate( + [msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1 + ) msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] @@ -453,7 +466,9 @@ class TestDualModelSwitching: noise_pred_cond, noise_pred_uncond = preds[0], preds[1] noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) - latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) + latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze( + 0 + ) mx.eval(latents) # With shift=5.0, early timesteps should be high (>=900), later ones low @@ -461,9 +476,9 @@ class TestDualModelSwitching: assert len(low_used_steps) > 0, "Low-noise model was never selected" # High-noise steps should come before low-noise steps (timesteps decrease) if high_used_steps and low_used_steps: - assert max(high_used_steps) < min(low_used_steps) or \ - min(high_used_steps) < max(low_used_steps), \ - "Model switching should happen during the loop" + assert max(high_used_steps) < min(low_used_steps) or min( + high_used_steps + ) < max(low_used_steps), "Model switching should happen during the loop" assert latents.shape == (C_noise, F, H, W) assert not mx.any(mx.isnan(latents)).item() @@ -515,7 +530,9 @@ class TestDualModelSwitching: y=[y_i2v, y_i2v], ) noise_pred = pred[1] + gs * (pred[0] - pred[1]) - latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) + latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze( + 0 + ) mx.eval(latents) # Verify both guide scales were used diff --git a/tests/test_wan_lora.py b/tests/test_wan_lora.py index 1670d84..7dc8c4b 100644 --- a/tests/test_wan_lora.py +++ b/tests/test_wan_lora.py @@ -4,7 +4,6 @@ import tempfile from pathlib import Path import mlx.core as mx -import numpy as np import pytest @@ -40,7 +39,9 @@ class TestLoRATypes: lora_a = mx.ones((2, 4)) lora_b = mx.ones((8, 2)) - w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + w = LoRAWeights( + lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test" + ) applied = AppliedLoRA(weights=w, strength=0.5) delta = applied.compute_delta() # scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones) @@ -51,7 +52,9 @@ class TestLoRATypes: class TestLoRALoader: """Test LoRA weight loading from safetensors.""" - def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"): + def _make_lora_file( + self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB" + ): """Helper to create a mock LoRA safetensors file.""" weights = {} for name in module_names: @@ -133,8 +136,16 @@ class TestWanKeyNormalization: """Simulate typical Wan2.2 MLX model weight keys.""" keys = set() for i in range(2): - for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o", - "cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]: + for layer in [ + "self_attn.q", + "self_attn.k", + "self_attn.v", + "self_attn.o", + "cross_attn.q", + "cross_attn.k", + "cross_attn.v", + "cross_attn.o", + ]: keys.add(f"blocks.{i}.{layer}.weight") keys.add(f"blocks.{i}.ffn.fc1.weight") keys.add(f"blocks.{i}.ffn.fc2.weight") @@ -150,7 +161,10 @@ class TestWanKeyNormalization: from mlx_video.lora.apply import _normalize_wan_lora_key keys = self._wan_model_keys() - assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q" + assert ( + _normalize_wan_lora_key("blocks.0.self_attn.q", keys) + == "blocks.0.self_attn.q" + ) def test_strip_diffusion_model_prefix(self): from mlx_video.lora.apply import _normalize_wan_lora_key @@ -163,7 +177,9 @@ class TestWanKeyNormalization: from mlx_video.lora.apply import _normalize_wan_lora_key keys = self._wan_model_keys() - result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys) + result = _normalize_wan_lora_key( + "model.diffusion_model.blocks.0.self_attn.k", keys + ) assert result == "blocks.0.self_attn.k" def test_ffn_key_mapping(self): @@ -197,7 +213,9 @@ class TestWanKeyNormalization: from mlx_video.lora.apply import _normalize_wan_lora_key keys = self._wan_model_keys() - assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj" + assert ( + _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj" + ) def test_combined_prefix_and_ffn(self): from mlx_video.lora.apply import _normalize_wan_lora_key @@ -219,7 +237,9 @@ class TestApplyLoRA: # LoRA weights in float32 (typical when loaded from safetensors) lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1 lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1 - w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + w = LoRAWeights( + lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test" + ) result = apply_lora_to_linear(original, [(w, 1.0)]) assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}" @@ -230,7 +250,9 @@ class TestApplyLoRA: original = mx.ones((8, 4), dtype=mx.float16) lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1 lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1 - w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + w = LoRAWeights( + lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test" + ) result = apply_lora_to_linear(original, [(w, 1.0)]) assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}" @@ -241,7 +263,9 @@ class TestApplyLoRA: original = mx.ones((8, 4)) lora_a = mx.ones((2, 4)) * 0.1 lora_b = mx.ones((8, 2)) * 0.1 - w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + w = LoRAWeights( + lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test" + ) result = apply_lora_to_linear(original, [(w, 1.0)]) # delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4) expected = original + 0.02 * mx.ones((8, 4)) @@ -255,12 +279,16 @@ class TestApplyLoRA: w1 = LoRAWeights( lora_A=mx.ones((2, 4)), lora_B=mx.ones((8, 2)), - rank=2, alpha=2.0, module_name="a", + rank=2, + alpha=2.0, + module_name="a", ) w2 = LoRAWeights( lora_A=mx.ones((2, 4)) * 2, lora_B=mx.ones((8, 2)) * 2, - rank=2, alpha=4.0, module_name="b", + rank=2, + alpha=4.0, + module_name="b", ) result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)]) # w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4) @@ -282,7 +310,9 @@ class TestApplyLoRA: w = LoRAWeights( lora_A=mx.ones((4, 64)) * 0.01, lora_B=mx.ones((128, 4)) * 0.01, - rank=4, alpha=4.0, module_name="blocks.0.self_attn.q", + rank=4, + alpha=4.0, + module_name="blocks.0.self_attn.q", ) module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]} result = apply_loras_to_weights(model_weights, module_to_loras) @@ -319,9 +349,7 @@ class TestEndToEnd: "blocks.0.self_attn.k.weight": mx.ones((128, 64)), } - result = load_and_apply_loras( - model_weights, [(str(lora_path), 1.0)] - ) + result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)]) # q weight should be modified, k unchanged assert not mx.array_equal( diff --git a/tests/test_wan_model.py b/tests/test_wan_model.py index caaae89..96c564a 100644 --- a/tests/test_wan_model.py +++ b/tests/test_wan_model.py @@ -3,18 +3,17 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -import pytest - from wan_test_helpers import _make_tiny_config - # --------------------------------------------------------------------------- # Sinusoidal Embedding Tests # --------------------------------------------------------------------------- + class TestSinusoidalEmbedding: def test_output_shape(self): from mlx_video.models.wan.model import sinusoidal_embedding_1d + pos = mx.arange(10).astype(mx.float32) emb = sinusoidal_embedding_1d(256, pos) mx.eval(emb) @@ -23,6 +22,7 @@ class TestSinusoidalEmbedding: def test_position_zero(self): """Position 0 should have cos=1 for all dims and sin=0.""" from mlx_video.models.wan.model import sinusoidal_embedding_1d + pos = mx.array([0.0]) emb = sinusoidal_embedding_1d(64, pos) mx.eval(emb) @@ -34,6 +34,7 @@ class TestSinusoidalEmbedding: def test_different_positions_differ(self): from mlx_video.models.wan.model import sinusoidal_embedding_1d + pos = mx.array([0.0, 100.0, 999.0]) emb = sinusoidal_embedding_1d(128, pos) mx.eval(emb) @@ -46,9 +47,11 @@ class TestSinusoidalEmbedding: # Head Tests # --------------------------------------------------------------------------- + class TestHead: def test_output_shape(self): from mlx_video.models.wan.model import Head + head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) B, L = 1, 24 x = mx.random.normal((B, L, 64)) @@ -60,6 +63,7 @@ class TestHead: def test_modulation_shape(self): from mlx_video.models.wan.model import Head + head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) assert head.modulation.shape == (1, 2, 64) @@ -68,12 +72,14 @@ class TestHead: # WanModel (Tiny) Tests # --------------------------------------------------------------------------- + class TestWanModel: def setup_method(self): mx.random.seed(42) def test_instantiation(self): from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters())) @@ -81,6 +87,7 @@ class TestWanModel: def test_patchify_shape(self): from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) # Input: [C=4, F=1, H=4, W=4] @@ -93,6 +100,7 @@ class TestWanModel: def test_patchify_various_sizes(self): from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]: @@ -108,6 +116,7 @@ class TestWanModel: def test_unpatchify_inverse(self): """Patchify then unpatchify should reconstruct original spatial dims.""" from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 2, 4, 6 @@ -123,6 +132,7 @@ class TestWanModel: def test_forward_pass(self): from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 @@ -140,6 +150,7 @@ class TestWanModel: def test_forward_batch(self): from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 @@ -148,7 +159,10 @@ class TestWanModel: x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))] t = mx.array([500.0, 200.0]) - context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))] + context = [ + mx.random.normal((6, config.text_dim)), + mx.random.normal((4, config.text_dim)), + ] out = model(x_list, t, context, seq_len) mx.eval(out[0], out[1]) @@ -158,12 +172,17 @@ class TestWanModel: def test_output_is_float32(self): from mlx_video.models.wan.model import WanModel + config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 seq_len = (F // 1) * (H // 2) * (W // 2) - out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]), - [mx.random.normal((4, config.text_dim))], seq_len) + out = model( + [mx.random.normal((C, F, H, W))], + mx.array([100.0]), + [mx.random.normal((4, config.text_dim))], + seq_len, + ) mx.eval(out[0]) assert out[0].dtype == mx.float32 @@ -172,6 +191,7 @@ class TestWanModel: # Wan2.1 Model Tests # --------------------------------------------------------------------------- + class TestWan21Model: """Test tiny Wan2.1-style model (single model mode).""" @@ -181,6 +201,7 @@ class TestWan21Model: def _make_tiny_wan21_config(self): """Create a tiny config mimicking Wan2.1 (single model).""" from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_14b() # Override to tiny values config.dim = 64 @@ -197,6 +218,7 @@ class TestWan21Model: def _make_tiny_wan21_1_3b_config(self): """Create a tiny config mimicking Wan2.1 1.3B.""" from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig.wan21_t2v_1_3b() # Override to tiny values (preserve 1.3B head structure: 12 heads) config.dim = 48 @@ -271,7 +293,9 @@ class TestWan21Model: for i in range(3): t = sched.timesteps[i] pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0] - pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0] + pred_uncond = model( + [latents], mx.array([t.item()]), [context_null], seq_len + )[0] pred = pred_uncond + gs * (pred_cond - pred_uncond) latents = sched.step(pred[None], t, latents[None]).squeeze(0) mx.eval(latents) @@ -304,6 +328,7 @@ class TestWan21Model: # Per-Token Timestep Tests # --------------------------------------------------------------------------- + class TestPerTokenTimestep: """Tests for per-token sinusoidal embedding.""" diff --git a/tests/test_wan_quantization.py b/tests/test_wan_quantization.py index a219eb7..5ec7355 100644 --- a/tests/test_wan_quantization.py +++ b/tests/test_wan_quantization.py @@ -1,22 +1,22 @@ """Tests for Wan model quantization pipeline.""" import json + import mlx.core as mx import mlx.nn as nn import mlx.utils import numpy as np -import pytest - from wan_test_helpers import _make_tiny_config - # --------------------------------------------------------------------------- # Quantize Predicate Tests # --------------------------------------------------------------------------- + class TestQuantizePredicate: def test_matches_self_attention_layers(self): from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) for suffix in ["q", "k", "v", "o"]: path = f"blocks.0.self_attn.{suffix}" @@ -24,6 +24,7 @@ class TestQuantizePredicate: def test_matches_cross_attention_layers(self): from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) for suffix in ["q", "k", "v", "o"]: path = f"blocks.0.cross_attn.{suffix}" @@ -31,23 +32,31 @@ class TestQuantizePredicate: def test_matches_ffn_layers(self): from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear) def test_rejects_embeddings(self): from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) - for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]: + for path in [ + "patch_embedding_proj", + "text_embedding_fc1", + "time_embedding.fc1", + ]: assert not _quantize_predicate(path, mock_linear), f"Should reject {path}" def test_rejects_norms(self): from mlx_video.convert_wan import _quantize_predicate + mock_norm = nn.RMSNorm(64) assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm) def test_rejects_non_quantizable_modules(self): from mlx_video.convert_wan import _quantize_predicate + mock_norm = nn.RMSNorm(64) # Even if path matches, module must have to_quantized assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm) @@ -55,13 +64,19 @@ class TestQuantizePredicate: def test_all_10_patterns_covered(self): """Verify exactly 10 layer patterns are targeted.""" from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) patterns = [ - "blocks.0.self_attn.q", "blocks.0.self_attn.k", - "blocks.0.self_attn.v", "blocks.0.self_attn.o", - "blocks.0.cross_attn.q", "blocks.0.cross_attn.k", - "blocks.0.cross_attn.v", "blocks.0.cross_attn.o", - "blocks.0.ffn.fc1", "blocks.0.ffn.fc2", + "blocks.0.self_attn.q", + "blocks.0.self_attn.k", + "blocks.0.self_attn.v", + "blocks.0.self_attn.o", + "blocks.0.cross_attn.q", + "blocks.0.cross_attn.k", + "blocks.0.cross_attn.v", + "blocks.0.cross_attn.o", + "blocks.0.ffn.fc1", + "blocks.0.ffn.fc2", ] matched = [p for p in patterns if _quantize_predicate(p, mock_linear)] assert len(matched) == 10 @@ -71,11 +86,12 @@ class TestQuantizePredicate: # Quantize Round-Trip Tests # --------------------------------------------------------------------------- + class TestQuantizeRoundTrip: def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): """Helper: create model, quantize, save to tmp_path.""" - from mlx_video.models.wan.model import WanModel from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan.model import WanModel model = WanModel(config) nn.quantize( @@ -101,8 +117,10 @@ class TestQuantizeRoundTrip: model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4) from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model( - model_path, config, + model_path, + config, quantization={"bits": 4, "group_size": 64}, ) @@ -119,8 +137,10 @@ class TestQuantizeRoundTrip: model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8) from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model( - model_path, config, + model_path, + config, quantization={"bits": 8, "group_size": 64}, ) @@ -132,8 +152,10 @@ class TestQuantizeRoundTrip: model_path, _ = self._quantize_and_save(config, tmp_path, bits=4) from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model( - model_path, config, + model_path, + config, quantization={"bits": 4, "group_size": 64}, ) @@ -151,6 +173,7 @@ class TestQuantizeRoundTrip: mx.save_safetensors(str(model_path), weights_dict) from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model(model_path, config, quantization=None) assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear) @@ -161,10 +184,11 @@ class TestQuantizeRoundTrip: # Quantized Inference Tests # --------------------------------------------------------------------------- + class TestQuantizedInference: def _make_quantized_model(self, config, bits=4): - from mlx_video.models.wan.model import WanModel from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan.model import WanModel model = WanModel(config) nn.quantize( @@ -214,8 +238,8 @@ class TestQuantizedInference: def test_quantized_output_differs_from_unquantized(self): """Sanity check: quantization should change the weights.""" - from mlx_video.models.wan.model import WanModel from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan.model import WanModel config = _make_tiny_config() mx.random.seed(42) @@ -243,11 +267,12 @@ class TestQuantizedInference: # Config Metadata Tests # --------------------------------------------------------------------------- + class TestQuantizationConfig: def test_config_metadata_written(self, tmp_path): """Verify _quantize_saved_model writes quantization metadata to config.json.""" - from mlx_video.models.wan.model import WanModel from mlx_video.convert_wan import _quantize_saved_model + from mlx_video.models.wan.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -270,8 +295,8 @@ class TestQuantizationConfig: assert cfg["quantization"]["group_size"] == 64 def test_config_metadata_8bit(self, tmp_path): - from mlx_video.models.wan.model import WanModel from mlx_video.convert_wan import _quantize_saved_model + from mlx_video.models.wan.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -291,8 +316,8 @@ class TestQuantizationConfig: def test_dual_model_quantization(self, tmp_path): """Verify dual-model quantization writes both model files.""" - from mlx_video.models.wan.model import WanModel from mlx_video.convert_wan import _quantize_saved_model + from mlx_video.models.wan.model import WanModel config = _make_tiny_config() diff --git a/tests/test_wan_rope_freqs.py b/tests/test_wan_rope_freqs.py index 9e41c5a..b37d7b0 100644 --- a/tests/test_wan_rope_freqs.py +++ b/tests/test_wan_rope_freqs.py @@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction: d = 128 # head_dim for all Wan models # Reference: three separate calls - correct = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + correct = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) # Wrong: single call wrong = rope_params(1024, d) mx.eval(correct, wrong) assert correct.shape == wrong.shape diff = np.abs(np.array(correct) - np.array(wrong)).max() - assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}" + assert ( + diff > 0.1 + ), f"Three-call and single-call should differ significantly, got max diff {diff}" def test_each_axis_starts_at_frequency_one(self): """Each axis (temporal/height/width) should have cos=1, sin=0 at position 0. @@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction: from mlx_video.models.wan.rope import rope_params d = 128 - freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) mx.eval(freqs) f = np.array(freqs) @@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction: # At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1) # Temporal axis first freq - np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5, - err_msg="temporal[0] cos at pos 1") + np.testing.assert_allclose( + f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1" + ) # Height axis first freq (starts at index d_t) - np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5, - err_msg="height[0] cos at pos 1") + np.testing.assert_allclose( + f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1" + ) # Width axis first freq (starts at index d_t + d_h) - np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, - err_msg="width[0] cos at pos 1") + np.testing.assert_allclose( + f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1" + ) def test_height_width_frequencies_identical(self): """Height and width axes should have identical frequency tables. @@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction: d = 128 d_h_dim = 2 * (d // 6) # 42 - freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, d_h_dim), - rope_params(1024, d_h_dim), - ], axis=1) + freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, d_h_dim), + rope_params(1024, d_h_dim), + ], + axis=1, + ) mx.eval(freqs) f = np.array(freqs) @@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction: d_t = half_d - 2 * (half_d // 3) d_h = half_d // 3 - height_freqs = f[:, d_t:d_t + d_h] - width_freqs = f[:, d_t + d_h:] + height_freqs = f[:, d_t : d_t + d_h] + width_freqs = f[:, d_t + d_h :] np.testing.assert_array_equal(height_freqs, width_freqs) def test_frequency_range_per_axis(self): @@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction: from mlx_video.models.wan.rope import rope_params d = 128 - freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) mx.eval(freqs) f = np.array(freqs) @@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction: pos1_h = f[1, d_t, 0] # height first freq pos1_w = f[1, d_t + d_h, 0] # width first freq - assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}" + assert ( + pos1_t > 0.5 + ), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}" assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}" assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}" @@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction: freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4) d = head_dim # 16 - freqs_manual = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + freqs_manual = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) mx.eval(freqs_model, freqs_manual) np.testing.assert_array_equal( - np.array(freqs_model), np.array(freqs_manual), - err_msg="WanModel.freqs should use three-call construction" + np.array(freqs_model), + np.array(freqs_manual), + err_msg="WanModel.freqs should use three-call construction", ) def test_model_freqs_14b_dimensions(self): @@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction: from mlx_video.models.wan.rope import rope_params d = 128 - freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs - rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs - rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs - ], axis=1) + freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs + rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs + rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs + ], + axis=1, + ) mx.eval(freqs) assert freqs.shape == (1024, 64, 2) @@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference: @pytest.fixture def has_torch(self): try: - import torch + pass + return True except ImportError: pytest.skip("PyTorch not installed") @@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference: def test_freqs_match_pytorch_reference(self, has_torch): """Numerically compare MLX and PyTorch frequency tables.""" import torch + from mlx_video.models.wan.rope import rope_params d = 128 @@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference: def pt_rope_params(max_seq_len, dim, theta=10000): freqs = torch.outer( torch.arange(max_seq_len), - 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) + 1.0 + / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)), + ) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - ref = torch.cat([ - pt_rope_params(1024, d - 4 * (d // 6)), - pt_rope_params(1024, 2 * (d // 6)), - pt_rope_params(1024, 2 * (d // 6)), - ], dim=1) + ref = torch.cat( + [ + pt_rope_params(1024, d - 4 * (d // 6)), + pt_rope_params(1024, 2 * (d // 6)), + pt_rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) # MLX - ours = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + ours = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) mx.eval(ours) our_cos = np.array(ours[:, :, 0]) @@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference: ref_cos = ref.real.float().numpy() ref_sin = ref.imag.float().numpy() - np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6, - err_msg="cos mismatch vs PyTorch reference") - np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6, - err_msg="sin mismatch vs PyTorch reference") + np.testing.assert_allclose( + our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference" + ) + np.testing.assert_allclose( + our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference" + ) class TestRoPEApplyWithCorrectFreqs: @@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs: This is the key property that was broken by the single-call bug: height/width frequencies were too low to distinguish nearby positions. """ - from mlx_video.models.wan.rope import rope_params, rope_apply + from mlx_video.models.wan.rope import rope_apply, rope_params d = 128 - freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) B, N = 1, 4 F, H, W = 1, 4, 4 @@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs: # Max diff should be >0.5 for both axes. With the bug, height was ~0.04 # and width was ~0.002. With correct freqs, both are ~1.3. - assert height_diff > 0.5, ( - f"Adjacent height positions should differ significantly, got {height_diff:.4f}" - ) - assert width_diff > 0.5, ( - f"Adjacent width positions should differ significantly, got {width_diff:.4f}" - ) + assert ( + height_diff > 0.5 + ), f"Adjacent height positions should differ significantly, got {height_diff:.4f}" + assert ( + width_diff > 0.5 + ), f"Adjacent width positions should differ significantly, got {width_diff:.4f}" # Height and width should have identical frequency tables → same diffs - np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5, - err_msg="Height and width should use identical frequency tables") + np.testing.assert_allclose( + height_diff, + width_diff, + rtol=1e-5, + err_msg="Height and width should use identical frequency tables", + ) def test_precomputed_matches_online(self): """rope_precompute_cos_sin + rope_apply should match non-precomputed path.""" @@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs: ) d = 128 - freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - ], axis=1) + freqs = mx.concatenate( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + axis=1, + ) B, N = 2, 4 F, H, W = 2, 3, 4 @@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs: mx.eval(out_online, out_precomp) np.testing.assert_allclose( - np.array(out_online), np.array(out_precomp), atol=1e-5, - err_msg="Precomputed and online RoPE should match" + np.array(out_online), + np.array(out_precomp), + atol=1e-5, + err_msg="Precomputed and online RoPE should match", ) diff --git a/tests/test_wan_scheduler.py b/tests/test_wan_scheduler.py index d16ff49..19cdcd7 100644 --- a/tests/test_wan_scheduler.py +++ b/tests/test_wan_scheduler.py @@ -6,14 +6,15 @@ import mlx.core as mx import numpy as np import pytest - # --------------------------------------------------------------------------- # Euler Scheduler Tests # --------------------------------------------------------------------------- + class TestFlowMatchEulerScheduler: def test_initialization(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() assert sched.num_train_timesteps == 1000 assert sched.timesteps is None @@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler: def test_set_timesteps(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(40, shift=12.0) mx.eval(sched.timesteps, sched.sigmas) @@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler: def test_timesteps_decreasing(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(40, shift=12.0) mx.eval(sched.timesteps) @@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler: def test_sigmas_decreasing(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(20, shift=1.0) mx.eval(sched.sigmas) @@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler: def test_terminal_sigma_is_zero(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(20, shift=5.0) mx.eval(sched.sigmas) @@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler: def test_shift_effect(self): """Larger shift should push sigmas toward higher values.""" from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched1 = FlowMatchEulerScheduler() sched2 = FlowMatchEulerScheduler() sched1.set_timesteps(20, shift=1.0) @@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler: def test_step_euler(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(10, shift=1.0) mx.eval(sched.sigmas) @@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler: # Euler: x_next = x + (sigma_next - sigma) * v expected = 1.0 + (sigma_next - sigma) * 0.5 np.testing.assert_allclose( - np.array(result).flatten()[0], expected, rtol=1e-4, + np.array(result).flatten()[0], + expected, + rtol=1e-4, ) def test_step_index_increments(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) assert sched._step_index == 0 @@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler: def test_reset(self): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 1, 1, 1, 1)) @@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler: @pytest.mark.parametrize("steps", [10, 20, 40, 50]) def test_various_step_counts(self, steps): from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(steps, shift=12.0) mx.eval(sched.timesteps, sched.sigmas) @@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler: def test_full_denoise_loop(self): """Run a complete denoise loop with zero velocity -> sample unchanged.""" from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 2, 1, 2, 2)) @@ -141,22 +154,26 @@ class TestComputeSigmas: def test_length(self): from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(20, shift=5.0) assert len(sigmas) == 21 # num_steps + terminal def test_terminal_zero(self): from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(10, shift=1.0) assert sigmas[-1] == 0.0 def test_starts_near_one(self): from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(20, shift=5.0) # Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0) np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3) def test_decreasing(self): from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(20, shift=5.0) assert np.all(np.diff(sigmas) <= 0) @@ -169,6 +186,7 @@ class TestComputeSigmas: shift is applied only once (single-shift). """ from mlx_video.models.wan.scheduler import _compute_sigmas + steps, shift, N = 50, 5.0, 1000 sigmas = _compute_sigmas(steps, shift, N) # Official single-shift: unshifted bounds, then shift once @@ -183,6 +201,7 @@ class TestComputeSigmas: def test_shift_one_is_near_linear(self): from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(10, shift=1.0) # With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule) # so schedule is nearly linear from ~0.999 to 0 @@ -196,6 +215,7 @@ class TestComputeSigmas: FlowMatchEulerScheduler, FlowUniPCScheduler, ) + scheds = [ FlowMatchEulerScheduler(1000), FlowDPMPP2MScheduler(1000), @@ -214,6 +234,7 @@ class TestComputeSigmas: FlowMatchEulerScheduler, FlowUniPCScheduler, ) + scheds = [ FlowMatchEulerScheduler(1000), FlowDPMPP2MScheduler(1000), @@ -235,12 +256,14 @@ class TestComputeSigmas: class TestFlowDPMPP2MScheduler: def test_initialization(self): from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() assert sched.num_train_timesteps == 1000 assert sched.lower_order_final is True def test_set_timesteps(self): from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(20, shift=5.0) mx.eval(sched.timesteps, sched.sigmas) @@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler: def test_step_index_increments(self): from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 4, 1, 2, 2)) @@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler: def test_reset(self): from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 1, 1, 1, 1)) @@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler: def test_full_loop_finite(self): """Full loop with constant velocity should produce finite output.""" from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=1.0) sample = mx.ones((1, 2, 1, 2, 2)) @@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler: def test_first_step_is_first_order(self): """First step should use 1st-order (no prev_x0 available).""" from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=5.0) sample = mx.random.normal((1, 4, 2, 4, 4)) @@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler: def test_second_step_uses_correction(self): """After first step, DPM++ should have stored prev_x0 for correction.""" from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=5.0) sample = mx.random.normal((1, 4, 1, 2, 2)) @@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler: x0_after_second = sched._prev_x0 assert x0_after_second is not None # The stored x0 should differ from the first step's - assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6) + assert not np.allclose( + np.array(x0_after_first), np.array(x0_after_second), atol=1e-6 + ) def test_denoise_to_target(self): """Perfect oracle should denoise to target with any solver.""" from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(20, shift=5.0) target = mx.zeros((1, 2, 1, 4, 4)) @@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler: @pytest.mark.parametrize("steps", [5, 10, 20, 50]) def test_various_step_counts(self, steps): from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(steps, shift=5.0) mx.eval(sched.timesteps, sched.sigmas) @@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler: def test_terminal_sigma_produces_x0(self): """When sigma_next=0 the scheduler should return x0 directly.""" from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 1, 1, 1, 1)) * 3.0 @@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler: class TestFlowUniPCScheduler: def test_initialization(self): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() assert sched.num_train_timesteps == 1000 assert sched.solver_order == 2 @@ -369,6 +403,7 @@ class TestFlowUniPCScheduler: def test_set_timesteps(self): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() sched.set_timesteps(30, shift=12.0) mx.eval(sched.timesteps, sched.sigmas) @@ -377,6 +412,7 @@ class TestFlowUniPCScheduler: def test_step_index_increments(self): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 1, 1, 1, 1)) @@ -387,6 +423,7 @@ class TestFlowUniPCScheduler: def test_reset(self): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 1, 1, 1, 1)) @@ -399,6 +436,7 @@ class TestFlowUniPCScheduler: def test_full_loop_finite(self): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() sched.set_timesteps(10, shift=1.0) sample = mx.ones((1, 2, 1, 2, 2)) @@ -411,6 +449,7 @@ class TestFlowUniPCScheduler: def test_corrector_not_applied_first_step(self): """First step should skip the corrector (no history).""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(use_corrector=True) sched.set_timesteps(10, shift=5.0) sample = mx.random.normal((1, 4, 1, 2, 2)) @@ -424,6 +463,7 @@ class TestFlowUniPCScheduler: def test_corrector_applied_after_first_step(self): """Steps after the first should use the corrector when enabled.""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(use_corrector=True) sched.set_timesteps(10, shift=5.0) sample = mx.random.normal((1, 2, 1, 4, 4)) @@ -436,6 +476,7 @@ class TestFlowUniPCScheduler: def test_denoise_to_target(self): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() sched.set_timesteps(20, shift=5.0) target = mx.zeros((1, 2, 1, 4, 4)) @@ -450,6 +491,7 @@ class TestFlowUniPCScheduler: @pytest.mark.parametrize("steps", [5, 10, 20, 50]) def test_various_step_counts(self, steps): from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() sched.set_timesteps(steps, shift=5.0) mx.eval(sched.timesteps, sched.sigmas) @@ -459,6 +501,7 @@ class TestFlowUniPCScheduler: def test_disable_corrector(self): """Disabling corrector on step 0 should still work without error.""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0]) sched.set_timesteps(5, shift=1.0) sample = mx.ones((1, 1, 1, 2, 2)) @@ -471,6 +514,7 @@ class TestFlowUniPCScheduler: def test_solver_order_3(self): """Order 3 should work without error.""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(solver_order=3, use_corrector=True) sched.set_timesteps(10, shift=5.0) sample = mx.random.normal((1, 2, 1, 2, 2)) @@ -483,6 +527,7 @@ class TestFlowUniPCScheduler: def test_corrector_rhos_c_not_hardcoded(self): """Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5.""" import math + # For 50-step schedule with shift=5.0, order 2 corrector at step 5: # rhos_c[0] (history) should be ~0.07, NOT 0.5 # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5 @@ -525,16 +570,23 @@ class TestFlowUniPCScheduler: rhos_c = np.linalg.solve(R, b) # History weight should be small (~0.07-0.09), not 0.5 - assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large" - assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive" + assert ( + rhos_c[0] < 0.15 + ), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large" + assert ( + rhos_c[0] > 0.0 + ), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive" # D1_t weight should be ~0.42-0.45, not 0.5 - assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range" + assert ( + 0.3 < rhos_c[1] < 0.5 + ), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range" # --------------------------------------------------------------------------- # Scheduler Coherence Tests # --------------------------------------------------------------------------- + class TestSchedulerCoherence: """Tests that Euler, DPM++, and UniPC schedulers produce coherent results. @@ -599,11 +651,15 @@ class TestSchedulerCoherence: results[name] = np.array(r) np.testing.assert_allclose( - results["dpm++"], results["euler"], atol=1e-5, + results["dpm++"], + results["euler"], + atol=1e-5, err_msg="DPM++ step 0 should match Euler", ) np.testing.assert_allclose( - results["unipc"], results["euler"], atol=1e-5, + results["unipc"], + results["euler"], + atol=1e-5, err_msg="UniPC step 0 should match Euler", ) @@ -621,11 +677,15 @@ class TestSchedulerCoherence: unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise) mx.eval(euler_r, dpm_r, unipc_r) np.testing.assert_allclose( - np.array(dpm_r), np.array(euler_r), atol=1e-5, + np.array(dpm_r), + np.array(euler_r), + atol=1e-5, err_msg=f"DPM++ step 0 differs from Euler at shift={shift}", ) np.testing.assert_allclose( - np.array(unipc_r), np.array(euler_r), atol=1e-5, + np.array(unipc_r), + np.array(euler_r), + atol=1e-5, err_msg=f"UniPC step 0 differs from Euler at shift={shift}", ) @@ -644,7 +704,9 @@ class TestSchedulerCoherence: latents = sched.step(v, sched.timesteps[i], latents) mx.eval(latents) np.testing.assert_allclose( - np.array(latents), 0.0, atol=1e-3, + np.array(latents), + 0.0, + atol=1e-3, err_msg=f"{name} did not converge to target with oracle", ) @@ -669,12 +731,12 @@ class TestSchedulerCoherence: # Higher-order solvers should not be significantly worse than Euler # (add small epsilon to handle near-zero errors from floating point noise) eps = 1e-6 - assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, ( - f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}" - ) - assert errors["unipc"] <= errors["euler"] * 1.5 + eps, ( - f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}" - ) + assert ( + errors["dpm++"] <= errors["euler"] * 1.5 + eps + ), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}" + assert ( + errors["unipc"] <= errors["euler"] * 1.5 + eps + ), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}" def test_multistep_trajectory_similar_magnitude(self): """Over a full denoising loop with constant velocity, all solvers @@ -696,9 +758,9 @@ class TestSchedulerCoherence: # All solvers should produce results within the same order of magnitude vals = list(final_means.values()) ratio = max(vals) / max(min(vals), 1e-10) - assert ratio < 10.0, ( - f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}" - ) + assert ( + ratio < 10.0 + ), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}" def test_intermediate_values_finite(self): """Every intermediate latent value must be finite for all solvers.""" @@ -712,9 +774,9 @@ class TestSchedulerCoherence: vel = mx.random.normal(shape) latents = sched.step(vel, sched.timesteps[i], latents) mx.eval(latents) - assert np.isfinite(np.array(latents)).all(), ( - f"{name} produced non-finite values at step {i}" - ) + assert np.isfinite( + np.array(latents) + ).all(), f"{name} produced non-finite values at step {i}" def test_lambda_boundary_values(self): """_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0.""" @@ -724,17 +786,17 @@ class TestSchedulerCoherence: ) for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler): - assert cls._lambda(1.0) == -math.inf, ( - f"{cls.__name__}._lambda(1.0) should be -inf" - ) - assert cls._lambda(0.0) == math.inf, ( - f"{cls.__name__}._lambda(0.0) should be +inf" - ) + assert ( + cls._lambda(1.0) == -math.inf + ), f"{cls.__name__}._lambda(1.0) should be -inf" + assert ( + cls._lambda(0.0) == math.inf + ), f"{cls.__name__}._lambda(0.0) should be +inf" # Interior values should be finite lam = cls._lambda(0.5) - assert math.isfinite(lam) and lam == 0.0, ( - f"{cls.__name__}._lambda(0.5) should be 0.0" - ) + assert ( + math.isfinite(lam) and lam == 0.0 + ), f"{cls.__name__}._lambda(0.5) should be 0.0" def test_lambda_monotonically_decreasing(self): """_lambda(sigma) should decrease as sigma increases (more noise → lower SNR).""" @@ -770,7 +832,9 @@ class TestSchedulerCoherence: result = scheds[name].step(vel, scheds[name].timesteps[0], sample) mx.eval(result) np.testing.assert_allclose( - np.array(result), np.array(expected), atol=5e-4, + np.array(result), + np.array(expected), + atol=5e-4, err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})", ) @@ -790,10 +854,14 @@ class TestSchedulerCoherence: results[name] = np.array(r) np.testing.assert_allclose( - results["dpm++"], results["euler"], atol=1e-5, + results["dpm++"], + results["euler"], + atol=1e-5, ) np.testing.assert_allclose( - results["unipc"], results["euler"], atol=1e-5, + results["unipc"], + results["euler"], + atol=1e-5, ) def test_dpmpp_unipc_agree_on_step1(self): @@ -834,7 +902,10 @@ class TestSchedulerCoherence: shape = (1, 2, 1, 2, 2) noise = mx.random.normal(shape) - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler + from mlx_video.models.wan.scheduler import ( + FlowDPMPP2MScheduler, + FlowUniPCScheduler, + ) for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler): sched = cls() @@ -857,14 +928,19 @@ class TestSchedulerCoherence: mx.eval(latents) result2 = np.array(latents) - np.testing.assert_allclose(result1, result2, atol=1e-5, - err_msg=f"{cls.__name__} not reproducible after reset()") + np.testing.assert_allclose( + result1, + result2, + atol=1e-5, + err_msg=f"{cls.__name__} not reproducible after reset()", + ) # --------------------------------------------------------------------------- # UniPC Corrector Default Tests # --------------------------------------------------------------------------- + class TestUniPCCorrectorDefault: """Tests that the UniPC corrector is enabled by default, matching official FlowUniPCMultistepScheduler behavior.""" @@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault: def test_corrector_enabled_by_default(self): """Default construction should have corrector enabled.""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() assert sched._use_corrector is True def test_corrector_affects_output(self): """Corrector should produce different results than no corrector after step 1.""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + mx.random.seed(42) shape = (1, 4, 1, 4, 4) noise = mx.random.normal(shape) @@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault: def test_corrector_does_not_affect_first_step(self): """Step 0 should be identical regardless of corrector setting.""" from mlx_video.models.wan.scheduler import FlowUniPCScheduler + mx.random.seed(42) shape = (1, 4, 1, 4, 4) noise = mx.random.normal(shape) diff --git a/tests/test_wan_t5.py b/tests/test_wan_t5.py index 7cb064f..7bf0c18 100644 --- a/tests/test_wan_t5.py +++ b/tests/test_wan_t5.py @@ -3,16 +3,16 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -import pytest - # --------------------------------------------------------------------------- # T5 Encoder Tests # --------------------------------------------------------------------------- + class TestT5LayerNorm: def test_output_shape(self): from mlx_video.models.wan.text_encoder import T5LayerNorm + norm = T5LayerNorm(64) x = mx.random.normal((2, 10, 64)) out = norm(x) @@ -22,6 +22,7 @@ class TestT5LayerNorm: def test_rms_normalization(self): """After T5LayerNorm with weight=1, RMS should be ~1.""" from mlx_video.models.wan.text_encoder import T5LayerNorm + norm = T5LayerNorm(128) x = mx.random.normal((1, 5, 128)) * 5.0 out = norm(x) @@ -35,6 +36,7 @@ class TestT5LayerNorm: class TestT5RelativeEmbedding: def test_output_shape(self): from mlx_video.models.wan.text_encoder import T5RelativeEmbedding + rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) out = rel_emb(10, 10) mx.eval(out) @@ -42,6 +44,7 @@ class TestT5RelativeEmbedding: def test_asymmetric_lengths(self): from mlx_video.models.wan.text_encoder import T5RelativeEmbedding + rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) out = rel_emb(8, 12) mx.eval(out) @@ -50,6 +53,7 @@ class TestT5RelativeEmbedding: def test_symmetry(self): """Position bias should have structure (not all zeros/random).""" from mlx_video.models.wan.text_encoder import T5RelativeEmbedding + rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2) out = rel_emb(6, 6) mx.eval(out) @@ -64,6 +68,7 @@ class TestT5RelativeEmbedding: class TestT5Attention: def test_output_shape(self): from mlx_video.models.wan.text_encoder import T5Attention + attn = T5Attention(dim=64, dim_attn=64, num_heads=4) x = mx.random.normal((1, 10, 64)) out = attn(x) @@ -73,12 +78,14 @@ class TestT5Attention: def test_no_scaling(self): """T5 attention famously has no sqrt(d) scaling. Verify structure.""" from mlx_video.models.wan.text_encoder import T5Attention + attn = T5Attention(dim=64, dim_attn=64, num_heads=4) # No scale attribute (unlike standard attention) assert not hasattr(attn, "scale") def test_with_position_bias(self): from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding + attn = T5Attention(dim=64, dim_attn=64, num_heads=4) rel_emb = T5RelativeEmbedding(32, 4) x = mx.random.normal((1, 10, 64)) @@ -89,6 +96,7 @@ class TestT5Attention: def test_with_mask(self): from mlx_video.models.wan.text_encoder import T5Attention + attn = T5Attention(dim=64, dim_attn=64, num_heads=4) x = mx.random.normal((1, 10, 64)) mask = mx.ones((1, 10)) @@ -101,6 +109,7 @@ class TestT5Attention: class TestT5FeedForward: def test_output_shape(self): from mlx_video.models.wan.text_encoder import T5FeedForward + ffn = T5FeedForward(64, 256) x = mx.random.normal((1, 10, 64)) out = ffn(x) @@ -110,6 +119,7 @@ class TestT5FeedForward: def test_gated_structure(self): """T5 FFN is gated: gate(x) * fc1(x).""" from mlx_video.models.wan.text_encoder import T5FeedForward + ffn = T5FeedForward(32, 64) assert hasattr(ffn, "gate_proj") assert hasattr(ffn, "fc1") @@ -122,9 +132,16 @@ class TestT5Encoder: def test_output_shape(self): from mlx_video.models.wan.text_encoder import T5Encoder + encoder = T5Encoder( - vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, - num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, + vocab_size=100, + dim=64, + dim_attn=64, + dim_ffn=128, + num_heads=4, + num_layers=2, + num_buckets=32, + shared_pos=False, ) ids = mx.array([[1, 5, 10, 0, 0]]) mask = mx.array([[1, 1, 1, 0, 0]]) @@ -134,9 +151,16 @@ class TestT5Encoder: def test_shared_pos(self): from mlx_video.models.wan.text_encoder import T5Encoder + encoder = T5Encoder( - vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, - num_heads=4, num_layers=2, num_buckets=32, shared_pos=True, + vocab_size=100, + dim=64, + dim_attn=64, + dim_ffn=128, + num_heads=4, + num_layers=2, + num_buckets=32, + shared_pos=True, ) assert encoder.pos_embedding is not None for block in encoder.blocks: @@ -144,9 +168,16 @@ class TestT5Encoder: def test_per_layer_pos(self): from mlx_video.models.wan.text_encoder import T5Encoder + encoder = T5Encoder( - vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, - num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, + vocab_size=100, + dim=64, + dim_attn=64, + dim_ffn=128, + num_heads=4, + num_layers=2, + num_buckets=32, + shared_pos=False, ) assert encoder.pos_embedding is None for block in encoder.blocks: @@ -154,18 +185,32 @@ class TestT5Encoder: def test_param_count(self): from mlx_video.models.wan.text_encoder import T5Encoder + encoder = T5Encoder( - vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, - num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, + vocab_size=100, + dim=64, + dim_attn=64, + dim_ffn=128, + num_heads=4, + num_layers=2, + num_buckets=32, + shared_pos=False, ) num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters())) assert num_params > 0 def test_without_mask(self): from mlx_video.models.wan.text_encoder import T5Encoder + encoder = T5Encoder( - vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, - num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, + vocab_size=100, + dim=64, + dim_attn=64, + dim_ffn=128, + num_heads=4, + num_layers=2, + num_buckets=32, + shared_pos=False, ) ids = mx.array([[1, 5, 10]]) out = encoder(ids) diff --git a/tests/test_wan_tiling.py b/tests/test_wan_tiling.py index 3353dd4..303f048 100644 --- a/tests/test_wan_tiling.py +++ b/tests/test_wan_tiling.py @@ -2,13 +2,11 @@ import mlx.core as mx import numpy as np -import pytest from mlx_video.models.ltx.video_vae.tiling import ( TilingConfig, decode_with_tiling, split_in_spatial, - split_in_temporal, ) @@ -49,16 +47,24 @@ class TestNonCausalTemporal: # Causal: 1 + (4-1)*4 = 13 out_causal = decode_with_tiling( - dummy_decoder_causal, latents, config, - spatial_scale=scale, temporal_scale=scale, causal_temporal=True, + dummy_decoder_causal, + latents, + config, + spatial_scale=scale, + temporal_scale=scale, + causal_temporal=True, ) mx.eval(out_causal) assert out_causal.shape[2] == 1 + (t - 1) * scale # 13 # Non-causal: 4*4 = 16 out_noncausal = decode_with_tiling( - dummy_decoder_noncausal, latents, config, - spatial_scale=scale, temporal_scale=scale, causal_temporal=False, + dummy_decoder_noncausal, + latents, + config, + spatial_scale=scale, + temporal_scale=scale, + causal_temporal=False, ) mx.eval(out_noncausal) assert out_noncausal.shape[2] == t * scale # 16 @@ -100,9 +106,9 @@ class TestWan22TiledDecoding: mx.eval(out_tiled) # Both should produce the same shape - assert out_regular.shape == out_tiled.shape, ( - f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" - ) + assert ( + out_regular.shape == out_tiled.shape + ), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" def test_decode_tiled_falls_through_when_small(self): """When input is smaller than tile size, decode_tiled should produce same output as __call__.""" @@ -120,8 +126,10 @@ class TestWan22TiledDecoding: mx.eval(out_tiled) np.testing.assert_allclose( - np.array(out_regular), np.array(out_tiled), - rtol=1e-4, atol=1e-4, + np.array(out_regular), + np.array(out_tiled), + rtol=1e-4, + atol=1e-4, err_msg="Tiled decode should match regular decode for small inputs", ) @@ -152,9 +160,9 @@ class TestWan21TiledDecoding: out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default()) mx.eval(out_tiled) - assert out_regular.shape == out_tiled.shape, ( - f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" - ) + assert ( + out_regular.shape == out_tiled.shape + ), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" def test_decode_tiled_falls_through_when_small(self): """When input is smaller than tile size, decode_tiled should produce same output as decode.""" @@ -171,8 +179,10 @@ class TestWan21TiledDecoding: mx.eval(out_tiled) np.testing.assert_allclose( - np.array(out_regular), np.array(out_tiled), - rtol=1e-4, atol=1e-4, + np.array(out_regular), + np.array(out_tiled), + rtol=1e-4, + atol=1e-4, err_msg="Tiled decode should match regular decode for small inputs", ) @@ -185,8 +195,13 @@ class TestWan21TemporalScale: from mlx_video.models.wan.vae import Decoder3d # Small decoder for fast test - dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1, - temporal_upsample=[True, True, False]) + dec = Decoder3d( + dim=16, + z_dim=4, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temporal_upsample=[True, True, False], + ) mx.eval(dec.parameters()) x = mx.random.normal((1, 4, 3, 4, 4)) # T=3 diff --git a/tests/test_wan_transformer.py b/tests/test_wan_transformer.py index dd9acec..8cbfb67 100644 --- a/tests/test_wan_transformer.py +++ b/tests/test_wan_transformer.py @@ -2,16 +2,16 @@ import mlx.core as mx import numpy as np -import pytest - # --------------------------------------------------------------------------- # Transformer Block Tests # --------------------------------------------------------------------------- + class TestWanFFN: def test_output_shape(self): from mlx_video.models.wan.transformer import WanFFN + ffn = WanFFN(64, 256) x = mx.random.normal((2, 10, 64)) out = ffn(x) @@ -21,6 +21,7 @@ class TestWanFFN: def test_gelu_activation(self): """FFN should use GELU activation (non-linearity).""" from mlx_video.models.wan.transformer import WanFFN + ffn = WanFFN(32, 128) x = mx.ones((1, 1, 32)) * 2.0 out1 = ffn(x) @@ -39,10 +40,13 @@ class TestWanAttentionBlock: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock( - self.dim, self.ffn_dim, self.num_heads, + self.dim, + self.ffn_dim, + self.num_heads, cross_attn_norm=True, ) B, L = 1, 24 @@ -53,37 +57,49 @@ class TestWanAttentionBlock: freqs = rope_params(1024, self.dim // self.num_heads) out = block( - x, e, seq_lens=[L], grid_sizes=[(F, H, W)], - freqs=freqs, context=context, + x, + e, + seq_lens=[L], + grid_sizes=[(F, H, W)], + freqs=freqs, + context=context, ) mx.eval(out) assert out.shape == (B, L, self.dim) def test_modulation_shape(self): from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) assert block.modulation.shape == (1, 6, self.dim) def test_with_cross_attn_norm(self): from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock( - self.dim, self.ffn_dim, self.num_heads, + self.dim, + self.ffn_dim, + self.num_heads, cross_attn_norm=True, ) assert block.norm3 is not None def test_without_cross_attn_norm(self): from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock( - self.dim, self.ffn_dim, self.num_heads, + self.dim, + self.ffn_dim, + self.num_heads, cross_attn_norm=False, ) assert block.norm3 is None def test_residual_connection(self): """Output should differ from zero even with small random init.""" - from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) B, L = 1, 8 F, H, W = 2, 2, 2 @@ -102,6 +118,7 @@ class TestWanAttentionBlock: # Float32 Modulation Precision Tests # --------------------------------------------------------------------------- + class TestFloat32Modulation: """Tests that modulation/gate operations are computed in float32, matching official torch.amp.autocast('cuda', dtype=torch.float32).""" @@ -113,13 +130,15 @@ class TestFloat32Modulation: def test_block_modulation_in_float32(self): """Modulation param starts random but should be usable as float32.""" from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True) assert block.modulation.dtype == mx.float32 def test_block_output_float32_with_bf16_modulation_input(self): """Even if e (time embedding) arrives as bf16, modulation should cast to f32.""" - from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan.transformer import WanAttentionBlock + block = WanAttentionBlock(self.dim, 128, 4) B, L = 1, 8 x = mx.random.normal((B, L, self.dim)) @@ -135,6 +154,7 @@ class TestFloat32Modulation: def test_head_modulation_float32(self): """Head modulation should be float32 even with bf16 e input.""" from mlx_video.models.wan.model import Head + head = Head(self.dim, 4, (1, 2, 2)) x = mx.random.normal((1, 8, self.dim)) e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16) @@ -145,6 +165,7 @@ class TestFloat32Modulation: def test_model_time_embedding_float32(self): """sinusoidal_embedding_1d output must be float32.""" from mlx_video.models.wan.model import sinusoidal_embedding_1d + t = mx.array([500.0]) emb = sinusoidal_embedding_1d(256, t) mx.eval(emb) @@ -153,6 +174,7 @@ class TestFloat32Modulation: def test_model_per_token_time_embedding_float32(self): """Per-token time embeddings (I2V) should also be float32.""" from mlx_video.models.wan.model import sinusoidal_embedding_1d + t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4] emb = sinusoidal_embedding_1d(256, t) mx.eval(emb) diff --git a/tests/test_wan_vae.py b/tests/test_wan_vae.py index cd2cf94..c604e74 100644 --- a/tests/test_wan_vae.py +++ b/tests/test_wan_vae.py @@ -4,16 +4,16 @@ import math import mlx.core as mx import numpy as np -import pytest - # --------------------------------------------------------------------------- # VAE 2.1 Tests # --------------------------------------------------------------------------- + class TestCausalConv3d: def test_output_shape_stride1(self): from mlx_video.models.wan.vae import CausalConv3d + conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1) # Initialize weights conv.weight = mx.random.normal(conv.weight.shape) * 0.02 @@ -29,6 +29,7 @@ class TestCausalConv3d: def test_output_shape_kernel1(self): from mlx_video.models.wan.vae import CausalConv3d + conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0) conv.weight = mx.random.normal(conv.weight.shape) * 0.02 x = mx.random.normal((1, 4, 2, 4, 4)) @@ -39,6 +40,7 @@ class TestCausalConv3d: def test_causal_padding(self): """Causal conv should only use past/current frames, not future.""" from mlx_video.models.wan.vae import CausalConv3d + conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1) conv.weight = mx.random.normal(conv.weight.shape) * 0.1 conv.bias = mx.zeros((2,)) @@ -55,6 +57,7 @@ class TestCausalConv3d: class TestResidualBlock: def test_same_dim(self): from mlx_video.models.wan.vae import ResidualBlock + block = ResidualBlock(8, 8) x = mx.random.normal((1, 8, 2, 4, 4)) out = block(x) @@ -63,6 +66,7 @@ class TestResidualBlock: def test_different_dim(self): from mlx_video.models.wan.vae import ResidualBlock + block = ResidualBlock(8, 16) x = mx.random.normal((1, 8, 2, 4, 4)) out = block(x) @@ -71,11 +75,13 @@ class TestResidualBlock: def test_shortcut_exists_when_dims_differ(self): from mlx_video.models.wan.vae import ResidualBlock + block = ResidualBlock(8, 16) assert block.shortcut is not None def test_no_shortcut_when_dims_same(self): from mlx_video.models.wan.vae import ResidualBlock + block = ResidualBlock(8, 8) assert block.shortcut is None @@ -83,6 +89,7 @@ class TestResidualBlock: class TestAttentionBlock: def test_output_shape(self): from mlx_video.models.wan.vae import AttentionBlock + block = AttentionBlock(8) x = mx.random.normal((1, 8, 2, 4, 4)) out = block(x) @@ -91,6 +98,7 @@ class TestAttentionBlock: def test_residual_connection(self): from mlx_video.models.wan.vae import AttentionBlock + block = AttentionBlock(8) x = mx.random.normal((1, 8, 1, 3, 3)) out = block(x) @@ -102,13 +110,15 @@ class TestAttentionBlock: class TestWanVAE: def test_instantiation(self): from mlx_video.models.wan.vae import WanVAE + vae = WanVAE(z_dim=16) assert vae.z_dim == 16 assert vae.mean.shape == (16,) assert vae.std.shape == (16,) def test_normalization_stats(self): - from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD + from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD + assert len(VAE_MEAN) == 16 assert len(VAE_STD) == 16 assert all(s > 0 for s in VAE_STD) @@ -124,6 +134,7 @@ class TestVAE22CausalConv3d: def test_output_shape_k3(self): from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(8, 16, kernel_size=3, padding=1) x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C] out = conv(x) @@ -132,6 +143,7 @@ class TestVAE22CausalConv3d: def test_output_shape_k1(self): from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(8, 16, kernel_size=1) x = mx.random.normal((1, 2, 4, 4, 8)) out = conv(x) @@ -141,6 +153,7 @@ class TestVAE22CausalConv3d: def test_temporal_causal(self): """Output at t=0 should not depend on t>0.""" from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(2, 2, kernel_size=3, padding=1) conv.weight = mx.random.normal(conv.weight.shape) * 0.1 conv.bias = mx.zeros(conv.bias.shape) @@ -151,10 +164,13 @@ class TestVAE22CausalConv3d: t0_ref = np.array(out_zero[0, 0]) # Modify t=2..3; output at t=0 should be unchanged - x_mod = mx.concatenate([ - x[:, :2], - mx.ones((1, 2, 4, 4, 2)), - ], axis=1) + x_mod = mx.concatenate( + [ + x[:, :2], + mx.ones((1, 2, 4, 4, 2)), + ], + axis=1, + ) out_mod = conv(x_mod) mx.eval(out_mod) t0_mod = np.array(out_mod[0, 0]) @@ -163,6 +179,7 @@ class TestVAE22CausalConv3d: def test_channels_last_format(self): """Verify input/output are channels-last [B, T, H, W, C].""" from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(4, 8, kernel_size=3, padding=1) x = mx.random.normal((2, 3, 6, 6, 4)) out = conv(x) @@ -175,6 +192,7 @@ class TestRMSNorm: def test_output_shape(self): from mlx_video.models.wan.vae22 import RMS_norm + norm = RMS_norm(16) x = mx.random.normal((2, 4, 4, 4, 16)) out = norm(x) @@ -184,6 +202,7 @@ class TestRMSNorm: def test_l2_normalization(self): """RMS_norm should normalize to unit L2 norm * sqrt(dim).""" from mlx_video.models.wan.vae22 import RMS_norm + dim = 32 norm = RMS_norm(dim) x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values @@ -197,6 +216,7 @@ class TestRMSNorm: def test_scale_invariant(self): """Scaling input by constant should not change output (L2 norm property).""" from mlx_video.models.wan.vae22 import RMS_norm + norm = RMS_norm(8) x = mx.random.normal((1, 1, 1, 1, 8)) out1 = norm(x) @@ -207,6 +227,7 @@ class TestRMSNorm: def test_gamma_effect(self): """Non-unit gamma should scale output.""" from mlx_video.models.wan.vae22 import RMS_norm + norm = RMS_norm(4) norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0]) x = mx.ones((1, 1, 1, 1, 4)) @@ -221,6 +242,7 @@ class TestDupUp3D: def test_spatial_only(self): from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(8, 4, factor_t=1, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) out = up(x) @@ -229,6 +251,7 @@ class TestDupUp3D: def test_temporal_and_spatial(self): from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(16, 8, factor_t=2, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 16)) out = up(x) @@ -237,6 +260,7 @@ class TestDupUp3D: def test_first_chunk_trims(self): from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(8, 4, factor_t=2, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) out_normal = up(x, first_chunk=False) @@ -248,6 +272,7 @@ class TestDupUp3D: def test_no_temporal_first_chunk_noop(self): from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(8, 4, factor_t=1, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) out_normal = up(x, first_chunk=False) @@ -262,6 +287,7 @@ class TestVAE22Resample: def test_upsample2d_shape(self): from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample2d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 x = mx.random.normal((1, 2, 4, 4, 8)) @@ -271,6 +297,7 @@ class TestVAE22Resample: def test_upsample3d_shape(self): from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 x = mx.random.normal((1, 2, 4, 4, 8)) @@ -280,6 +307,7 @@ class TestVAE22Resample: def test_upsample3d_first_chunk(self): from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 x = mx.random.normal((1, 2, 4, 4, 8)) @@ -291,6 +319,7 @@ class TestVAE22Resample: def test_upsample3d_first_chunk_single_frame(self): """Single-frame input with first_chunk: no temporal upsample.""" from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 x = mx.random.normal((1, 1, 4, 4, 8)) @@ -308,6 +337,7 @@ class TestVAE22Resample: the first input frame (not on time_conv parameters). """ from mlx_video.models.wan.vae22 import Resample + C = 8 r = Resample(C, "upsample3d") # Set time_conv weights to large values so its effect is detectable @@ -334,8 +364,9 @@ class TestVAE22Resample: # Compare first output frame to reference first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C) mx.eval(first_out) - assert mx.allclose(first_out, ref, atol=1e-5).item(), \ - "First frame should bypass time_conv and match spatial-only upsample" + assert mx.allclose( + first_out, ref, atol=1e-5 + ).item(), "First frame should bypass time_conv and match spatial-only upsample" class TestVAE22ResidualBlock: @@ -343,6 +374,7 @@ class TestVAE22ResidualBlock: def test_same_dim(self): from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 8) x = mx.random.normal((1, 2, 4, 4, 8)) out = block(x) @@ -351,6 +383,7 @@ class TestVAE22ResidualBlock: def test_different_dim(self): from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 16) x = mx.random.normal((1, 2, 4, 4, 8)) out = block(x) @@ -359,11 +392,13 @@ class TestVAE22ResidualBlock: def test_shortcut_when_dims_differ(self): from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 16) assert block.shortcut is not None def test_no_shortcut_same_dim(self): from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 8) assert block.shortcut is None @@ -374,6 +409,7 @@ class TestResidualBlockLayers: def test_layer_names_no_underscore_prefix(self): """Layer names must NOT start with underscore (MLX ignores them).""" from mlx_video.models.wan.vae22 import ResidualBlockLayers + block = ResidualBlockLayers(8, 8) params = dict(block.parameters()) # All param keys should use layer_N, not _layer_N @@ -382,6 +418,7 @@ class TestResidualBlockLayers: def test_has_expected_layers(self): from mlx_video.models.wan.vae22 import ResidualBlockLayers + block = ResidualBlockLayers(8, 16) assert hasattr(block, "layer_0") # first RMS_norm assert hasattr(block, "layer_2") # first CausalConv3d @@ -390,6 +427,7 @@ class TestResidualBlockLayers: def test_forward_shape(self): from mlx_video.models.wan.vae22 import ResidualBlockLayers + block = ResidualBlockLayers(8, 16) x = mx.random.normal((1, 2, 4, 4, 8)) out = block(x) @@ -402,6 +440,7 @@ class TestVAE22AttentionBlock: def test_output_shape(self): from mlx_video.models.wan.vae22 import AttentionBlock + block = AttentionBlock(16) block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01 block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01 @@ -412,6 +451,7 @@ class TestVAE22AttentionBlock: def test_residual_connection(self): from mlx_video.models.wan.vae22 import AttentionBlock + block = AttentionBlock(8) block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape) block.proj_weight = mx.zeros(block.proj_weight.shape) @@ -427,6 +467,7 @@ class TestHead22: def test_output_shape(self): from mlx_video.models.wan.vae22 import Head22 + head = Head22(16, out_channels=12) x = mx.random.normal((1, 2, 4, 4, 16)) out = head(x) @@ -436,6 +477,7 @@ class TestHead22: def test_layer_names_no_underscore(self): """Head layers must not use underscore prefix.""" from mlx_video.models.wan.vae22 import Head22 + head = Head22(8) assert hasattr(head, "layer_0") # RMS_norm assert hasattr(head, "layer_2") # CausalConv3d @@ -449,6 +491,7 @@ class TestUnpatchify: def test_basic_shape(self): from mlx_video.models.wan.vae22 import _unpatchify + x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2 out = _unpatchify(x, patch_size=2) mx.eval(out) @@ -456,6 +499,7 @@ class TestUnpatchify: def test_patch_size_1_noop(self): from mlx_video.models.wan.vae22 import _unpatchify + x = mx.random.normal((1, 2, 4, 4, 3)) out = _unpatchify(x, patch_size=1) mx.eval(out) @@ -464,6 +508,7 @@ class TestUnpatchify: def test_preserves_content(self): """Unpatchify should be a lossless rearrangement.""" from mlx_video.models.wan.vae22 import _unpatchify + x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32) out = _unpatchify(x, patch_size=2) mx.eval(out) @@ -477,6 +522,7 @@ class TestDenormalizeLatents: def test_output_shape(self): from mlx_video.models.wan.vae22 import denormalize_latents + z = mx.random.normal((1, 2, 4, 4, 48)) out = denormalize_latents(z) mx.eval(out) @@ -484,16 +530,23 @@ class TestDenormalizeLatents: def test_custom_mean_std(self): from mlx_video.models.wan.vae22 import denormalize_latents + z = mx.ones((1, 1, 1, 1, 4)) mean = mx.array([1.0, 2.0, 3.0, 4.0]) std = mx.array([0.5, 0.5, 0.5, 0.5]) out = denormalize_latents(z, mean=mean, std=std) mx.eval(out) # z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5] - np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5) + np.testing.assert_allclose( + np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5 + ) def test_uses_default_constants(self): - from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents + from mlx_video.models.wan.vae22 import ( + VAE22_MEAN, + denormalize_latents, + ) + # Should not raise with default constants z = mx.zeros((1, 1, 1, 1, 48)) out = denormalize_latents(z) @@ -511,12 +564,14 @@ class TestVAE22NormConstants: def test_dimensions(self): from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD + mx.eval(VAE22_MEAN, VAE22_STD) assert VAE22_MEAN.shape == (48,) assert VAE22_STD.shape == (48,) def test_std_positive(self): from mlx_video.models.wan.vae22 import VAE22_STD + mx.eval(VAE22_STD) assert (np.array(VAE22_STD) > 0).all() @@ -527,6 +582,7 @@ class TestWan22VAEDecoder: def test_output_shape_small(self): """Tiny decoder should produce correct spatial/temporal output.""" from mlx_video.models.wan.vae22 import Wan22VAEDecoder + # Use very small dims to keep test fast dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) # Latent: [B=1, T=3, H=2, W=2, C=4] @@ -542,6 +598,7 @@ class TestWan22VAEDecoder: def test_output_clipped(self): from mlx_video.models.wan.vae22 import Wan22VAEDecoder + dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values out = dec(z) @@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights: def test_skip_encoder(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { "encoder.layer.weight": mx.zeros((4,)), "conv1.weight": mx.zeros((4,)), @@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights: def test_sequential_index_remapping(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)), "decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)), @@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights: def test_resample_conv_remapping(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)), "decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)), @@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights: def test_attention_remapping(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)), "decoder.middle.1.to_qkv.bias": mx.zeros((24,)), @@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights: def test_conv3d_transpose(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I] w = mx.zeros((16, 8, 3, 3, 3)) weights = {"decoder.conv1.weight": w} @@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights: def test_conv2d_transpose(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + # Conv2d weight: [O, I, H, W] → [O, H, W, I] w = mx.zeros((8, 8, 3, 3)) weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w} @@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights: def test_gamma_squeeze(self): from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + # gamma: (dim, 1, 1, 1) → (dim,) w = mx.ones((16, 1, 1, 1)) weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w} @@ -635,7 +699,10 @@ class TestUpResidualBlock: def test_no_upsample(self): from mlx_video.models.wan.vae22 import Up_ResidualBlock - block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False) + + block = Up_ResidualBlock( + 8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False + ) x = mx.random.normal((1, 2, 4, 4, 8)) out = block(x) mx.eval(out) @@ -644,7 +711,10 @@ class TestUpResidualBlock: def test_spatial_upsample(self): from mlx_video.models.wan.vae22 import Up_ResidualBlock - block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True) + + block = Up_ResidualBlock( + 8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True + ) x = mx.random.normal((1, 2, 4, 4, 8)) out = block(x) mx.eval(out) @@ -653,7 +723,10 @@ class TestUpResidualBlock: def test_spatial_temporal_upsample(self): from mlx_video.models.wan.vae22 import Up_ResidualBlock - block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True) + + block = Up_ResidualBlock( + 8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True + ) x = mx.random.normal((1, 2, 4, 4, 8)) out = block(x) mx.eval(out) @@ -720,7 +793,9 @@ class TestDownResidualBlock: def test_no_downsample(self): from mlx_video.models.wan.vae22 import Down_ResidualBlock - block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False) + block = Down_ResidualBlock( + 8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False + ) x = mx.random.normal((1, 2, 8, 8, 8)) out = block(x) mx.eval(out) @@ -729,7 +804,9 @@ class TestDownResidualBlock: def test_spatial_downsample(self): from mlx_video.models.wan.vae22 import Down_ResidualBlock - block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True) + block = Down_ResidualBlock( + 8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True + ) x = mx.random.normal((1, 2, 8, 8, 8)) out = block(x) mx.eval(out) @@ -738,7 +815,9 @@ class TestDownResidualBlock: def test_spatial_temporal_downsample(self): from mlx_video.models.wan.vae22 import Down_ResidualBlock - block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True) + block = Down_ResidualBlock( + 8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True + ) x = mx.random.normal((1, 4, 8, 8, 8)) out = block(x) mx.eval(out) @@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder: def test_encoder_temporal_downsample_pattern(self): """Encoder3d with (False, True, True): T=5→5→3→2.""" from mlx_video.models.wan.vae22 import Encoder3d + enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True)) x = mx.random.normal((1, 5, 16, 16, 12)) mx.eval(enc.parameters()) @@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder: def test_wrapper_uses_correct_pattern(self): """Wan22VAEEncoder should use (False, True, True) temporal downsample.""" - from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample + from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder + enc = Wan22VAEEncoder(z_dim=48, dim=16) down_blocks = enc.encoder.downsamples found_modes = [] @@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder: def test_single_frame_encoder(self): """Single frame (T=1) should work with (False, True, True) pattern.""" from mlx_video.models.wan.vae22 import Wan22VAEEncoder + enc = Wan22VAEEncoder(z_dim=48, dim=16) img = mx.random.normal((1, 1, 32, 32, 3)) mx.eval(enc.parameters()) @@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder: def test_wrong_order_gives_different_result(self): """(True, True, False) vs (False, True, True) produce different outputs.""" from mlx_video.models.wan.vae22 import Encoder3d - enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True)) + + enc_correct = Encoder3d( + dim=16, z_dim=8, temperal_downsample=(False, True, True) + ) enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False)) x = mx.random.normal((1, 5, 16, 16, 12)) @@ -883,12 +968,8 @@ class TestVAE21RoundTrip: z_dim = 4 dim = 8 # No temporal up/downsampling to keep the test simple - enc = Encoder3d( - dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False] - ) - dec = Decoder3d( - dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False] - ) + enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]) + dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]) mx.eval(enc.parameters(), dec.parameters()) # [B=1, C=3, T=1, H=8, W=8] @@ -937,15 +1018,12 @@ class TestVAE22RoundTrip: mx.eval(out) # 3 spatial upsamples(×8) + unpatchify(×2) = ×16 - assert out.shape[0] == 1 # batch - assert out.shape[2] == 32 # H recovered - assert out.shape[3] == 32 # W recovered - assert out.shape[-1] == 3 # RGB + assert out.shape[0] == 1 # batch + assert out.shape[2] == 32 # H recovered + assert out.shape[3] == 32 # W recovered + assert out.shape[-1] == 3 # RGB out_np = np.array(out) assert np.all(np.isfinite(out_np)) assert out_np.min() >= -1.0 - 1e-6 assert out_np.max() <= 1.0 + 1e-6 - - - diff --git a/tests/wan_test_helpers.py b/tests/wan_test_helpers.py index 6999af1..0d1a2b1 100644 --- a/tests/wan_test_helpers.py +++ b/tests/wan_test_helpers.py @@ -4,6 +4,7 @@ def _make_tiny_config(): """Create a tiny WanModelConfig for testing.""" from mlx_video.models.wan.config import WanModelConfig + config = WanModelConfig() # Override to tiny values config.dim = 64 diff --git a/uv.lock b/uv.lock index 09489ee..b4cda06 100644 --- a/uv.lock +++ b/uv.lock @@ -622,6 +622,18 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "ftfy" +version = "6.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -996,7 +1008,10 @@ wheels = [ name = "mlx-video" source = { editable = "." } dependencies = [ + { name = "ftfy" }, { name = "huggingface-hub" }, + { name = "imageio" }, + { name = "imageio-ffmpeg" }, { name = "librosa" }, { name = "mlx" }, { name = "mlx-vlm" }, @@ -1016,7 +1031,10 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ftfy" }, { name = "huggingface-hub" }, + { name = "imageio", specifier = ">=2.37.2" }, + { name = "imageio-ffmpeg", specifier = ">=0.6.0" }, { name = "librosa", specifier = ">=0.10.0" }, { name = "mlx", specifier = ">=0.22.0" }, { name = "mlx-vlm" }, @@ -2509,6 +2527,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" }, ] +[[package]] +name = "wcwidth" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, +] + [[package]] name = "xxhash" version = "3.6.0"