This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,36 +1,33 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig
# Audio VAE components # Audio VAE components
from mlx_video.models.ltx_2.audio_vae import ( from mlx_video.models.ltx_2.audio_vae import (
AudioDecoder, AudioDecoder,
AudioEncoder, AudioEncoder,
AudioLatentShape,
AudioPatchifier,
PerChannelStatistics,
Vocoder, Vocoder,
decode_audio, decode_audio,
AudioPatchifier,
AudioLatentShape,
PerChannelStatistics,
) )
# Conditioning # Conditioning
from mlx_video.models.ltx_2.conditioning import ( from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex
VideoConditionByLatentIndex,
)
# Utilities # Utilities
from mlx_video.models.ltx_2.utils import ( from mlx_video.models.ltx_2.utils import (
convert_audio_encoder, convert_audio_encoder,
get_model_path, get_model_path,
load_safetensors,
load_config, load_config,
load_safetensors,
save_weights, save_weights,
) )
from mlx_video.models.wan import WanModel, WanModelConfig
__all__ = [ __all__ = [
# Models # Models
"LTXModel", "LTXModel",
"LTXModelConfig", "LTXModelConfig",
# Audio VAE # Audio VAE
"AudioDecoder", "AudioDecoder",
"AudioEncoder", "AudioEncoder",

View File

@@ -6,10 +6,7 @@ from mlx_video.lora.apply import (
apply_loras_to_model, apply_loras_to_model,
apply_loras_to_weights, apply_loras_to_weights,
) )
from mlx_video.lora.loader import ( from mlx_video.lora.loader import load_lora_weights, load_multiple_loras
load_lora_weights,
load_multiple_loras,
)
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
__all__ = [ __all__ = [

View File

@@ -66,7 +66,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
candidates = [lora_key] candidates = [lora_key]
for prefix in prefixes_to_strip: for prefix in prefixes_to_strip:
if lora_key.startswith(prefix): if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):]) candidates.append(lora_key[len(prefix) :])
for candidate in candidates: for candidate in candidates:
# Try as-is # Try as-is
@@ -80,33 +80,36 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.") transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.") transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"): if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1" transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"): if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2" transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0 # Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.") transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.") transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"): if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0" transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"): if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1" transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0 # Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.") transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.") transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"): if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0" transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"): if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1" transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection # Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.") transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"): if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection" transformed = transformed[: -len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj # Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed: if (
"patch_embedding" in transformed
and "patch_embedding_proj" not in transformed
):
transformed = transformed.replace("patch_embedding", "patch_embedding_proj") transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys: if f"{transformed}.weight" in model_keys or transformed in model_keys:
@@ -115,7 +118,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
# Return best attempt with prefix stripped # Return best attempt with prefix stripped
for prefix in prefixes_to_strip: for prefix in prefixes_to_strip:
if lora_key.startswith(prefix): if lora_key.startswith(prefix):
return lora_key[len(prefix):] return lora_key[len(prefix) :]
return lora_key return lora_key
@@ -134,21 +137,25 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
for prefix in prefixes_to_strip: for prefix in prefixes_to_strip:
if lora_key.startswith(prefix): if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):] normalized = lora_key[len(prefix) :]
if f"{normalized}.weight" in model_keys or normalized in model_keys: if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized return normalized
transformed = normalized transformed = normalized
if transformed.endswith(".to_out.0"): if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out" transformed = transformed[: -len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.") transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.") transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in") transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.") transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out") transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") transformed = transformed.replace(
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in") ".audio_ff.net.0.proj.", ".audio_ff.proj_in."
)
transformed = transformed.replace(
".audio_ff.net.0.proj", ".audio_ff.proj_in"
)
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out") transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
@@ -158,7 +165,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
# Try transformations on the original key # Try transformations on the original key
transformed = lora_key transformed = lora_key
if transformed.endswith(".to_out.0"): if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out" transformed = transformed[: -len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.") transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.") transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in") transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
@@ -170,7 +177,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
for prefix in prefixes_to_strip: for prefix in prefixes_to_strip:
if lora_key.startswith(prefix): if lora_key.startswith(prefix):
return lora_key[len(prefix):] return lora_key[len(prefix) :]
return lora_key return lora_key
@@ -226,7 +233,9 @@ def apply_loras_to_weights(
skipped_count += 1 skipped_count += 1
skipped_modules.append(module_name) skipped_modules.append(module_name)
if verbose and skipped_count <= 5: if verbose and skipped_count <= 5:
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND") print(
f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND"
)
similar = [ similar = [
k k
for k in list(model_keys)[:1000] for k in list(model_keys)[:1000]
@@ -251,13 +260,21 @@ def apply_loras_to_weights(
if is_quantized: if is_quantized:
scales = modified_weights[scales_key] scales = modified_weights[scales_key]
biases = modified_weights[biases_key] biases = modified_weights[biases_key]
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits) group_size = (original_weight.shape[-1] * 32) // (
scales.shape[-1] * quantization_bits
)
dequantized = mx.dequantize( dequantized = mx.dequantize(
original_weight, scales, biases, group_size=group_size, bits=quantization_bits original_weight,
scales,
biases,
group_size=group_size,
bits=quantization_bits,
) )
modified = apply_lora_to_linear(dequantized, loras) modified = apply_lora_to_linear(dequantized, loras)
# Re-quantize with same parameters # Re-quantize with same parameters
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits) new_w, new_scales, new_biases = mx.quantize(
modified, group_size=group_size, bits=quantization_bits
)
modified_weights[weight_key] = new_w modified_weights[weight_key] = new_w
modified_weights[scales_key] = new_scales modified_weights[scales_key] = new_scales
modified_weights[biases_key] = new_biases modified_weights[biases_key] = new_biases
@@ -346,9 +363,15 @@ def apply_loras_to_model(
parent = model parent = model
try: try:
for part in parts[:-1]: for part in parts[:-1]:
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)] parent = (
getattr(parent, part) if not part.isdigit() else parent[int(part)]
)
leaf_name = parts[-1] leaf_name = parts[-1]
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)] target = (
getattr(parent, leaf_name)
if not leaf_name.isdigit()
else parent[int(leaf_name)]
)
except (AttributeError, IndexError, TypeError): except (AttributeError, IndexError, TypeError):
skipped.append(lora_key) skipped.append(lora_key)
if verbose: if verbose:
@@ -358,8 +381,11 @@ def apply_loras_to_model(
if isinstance(target, nn.QuantizedLinear): if isinstance(target, nn.QuantizedLinear):
# Dequantize → merge LoRA → replace with bf16 Linear # Dequantize → merge LoRA → replace with bf16 Linear
weight = mx.dequantize( weight = mx.dequantize(
target.weight, target.scales, target.biases, target.weight,
group_size=target.group_size, bits=target.bits, target.scales,
target.biases,
group_size=target.group_size,
bits=target.bits,
) )
merged = apply_lora_to_linear(weight, loras) merged = apply_lora_to_linear(weight, loras)
new_linear = nn.Linear(merged.shape[1], merged.shape[0]) new_linear = nn.Linear(merged.shape[1], merged.shape[0])
@@ -379,7 +405,9 @@ def apply_loras_to_model(
else: else:
skipped.append(lora_key) skipped.append(lora_key)
if verbose: if verbose:
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear") print(
f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear"
)
continue continue
if applied_count > 0: if applied_count > 0:

View File

@@ -2,7 +2,7 @@
import re import re
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List
import mlx.core as mx import mlx.core as mx

View File

@@ -1,3 +1,2 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig from mlx_video.models.wan import WanModel, WanModelConfig

View File

@@ -1,8 +1,7 @@
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
from mlx_video.models.ltx_2.config import ( from mlx_video.models.ltx_2.config import (
LTXModelConfig, LTXModelConfig,
TransformerConfig,
LTXModelType, LTXModelType,
TransformerConfig,
) )
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module): class AdaLayerNormSingle(nn.Module):
def __init__( def __init__(
self, self,
embedding_dim: int, embedding_dim: int,
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
) )
self.silu = nn.SiLU() self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) self.linear = nn.Linear(
embedding_dim, embedding_coefficient * embedding_dim, bias=True
)
def __call__( def __call__(
self, self,
@@ -63,8 +64,12 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
self.size_emb_dim = size_emb_dim self.size_emb_dim = size_emb_dim
self.use_additional_conditions = use_additional_conditions self.use_additional_conditions = use_additional_conditions
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_proj = Timesteps(
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim) timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.timestep_embedder = TimestepEmbedding(
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
)
if use_additional_conditions and size_emb_dim > 0: if use_additional_conditions and size_emb_dim > 0:
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim) self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
# Add additional conditions if enabled # Add additional conditions if enabled
if self.use_additional_conditions and self.size_emb_dim > 0: if self.use_additional_conditions and self.size_emb_dim > 0:
if resolution is not None and aspect_ratio is not None: if resolution is not None and aspect_ratio is not None:
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype) additional_embeds = self.additional_embedder(
resolution, aspect_ratio, hidden_dtype
)
timesteps_emb = timesteps_emb + additional_embeds timesteps_emb = timesteps_emb + additional_embeds
return timesteps_emb return timesteps_emb

View File

@@ -1,10 +1,10 @@
"""Audio VAE module for LTX-2 audio generation.""" """Audio VAE module for LTX-2 audio generation."""
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
from .causal_conv_2d import CausalConv2d, make_conv2d
from ..config import CausalityAxis from ..config import CausalityAxis
from .attention import AttentionType, AttnBlock, make_attn
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d
from .downsample import Downsample, build_downsampling_path from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics

View File

@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
""" """
@@ -103,6 +105,8 @@ def make_attn(
elif attn_type == AttentionType.NONE: elif attn_type == AttentionType.NONE:
return Identity() return Identity()
elif attn_type == AttentionType.LINEAR: elif attn_type == AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") raise NotImplementedError(
f"Attention type {attn_type.value} is not supported yet."
)
else: else:
raise ValueError(f"Unknown attention type: {attn_type}") raise ValueError(f"Unknown attention type: {attn_type}")

View File

@@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog
using librosa for macOS/MLX compatibility. using librosa for macOS/MLX compatibility.
""" """
from pathlib import Path
import numpy as np
import mlx.core as mx import mlx.core as mx
import numpy as np
def load_audio( def load_audio(
@@ -99,14 +98,16 @@ def waveform_to_mel(
for ch in range(channels): for ch in range(channels):
# Magnitude spectrogram (power=1.0) # Magnitude spectrogram (power=1.0)
S = np.abs(librosa.stft( S = np.abs(
waveform[ch], librosa.stft(
n_fft=n_fft, waveform[ch],
hop_length=hop_length, n_fft=n_fft,
win_length=win_length, hop_length=hop_length,
center=True, win_length=win_length,
pad_mode="reflect", center=True,
)) pad_mode="reflect",
)
)
# Mel filterbank with slaney normalization # Mel filterbank with slaney normalization
mel_basis = librosa.filters.mel( mel_basis = librosa.filters.mel(

View File

@@ -1,15 +1,15 @@
"""Audio VAE encoder and decoder for LTX-2.""" """Audio VAE encoder and decoder for LTX-2."""
from typing import Dict
from pathlib import Path from pathlib import Path
from typing import Dict
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
from .attention import AttentionType, make_attn from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .downsample import build_downsampling_path from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
@@ -39,7 +39,9 @@ def build_mid_block(
causality_axis=causality_axis, causality_axis=causality_axis,
) )
mid["attn_1"] = ( mid["attn_1"] = (
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None make_attn(channels, attn_type=attn_type, norm_type=norm_type)
if add_attention
else None
) )
mid["block_2"] = ResnetBlock( mid["block_2"] = ResnetBlock(
in_channels=channels, in_channels=channels,
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
self.attn_type = config.attn_type self.attn_type = config.attn_type
self.conv_in = make_conv2d( self.conv_in = make_conv2d(
config.in_channels, self.ch, kernel_size=3, stride=1, config.in_channels,
self.ch,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis, causality_axis=self.causality_axis,
) )
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
out_channels = 2 * config.z_channels if config.double_z else config.z_channels out_channels = 2 * config.z_channels if config.double_z else config.z_channels
self.conv_out = make_conv2d( self.conv_out = make_conv2d(
block_in, out_channels, kernel_size=3, stride=1, block_in,
out_channels,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis, causality_axis=self.causality_axis,
) )
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
continue continue
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1)) value = (
value
if check_array_shape(value)
else mx.transpose(value, (0, 2, 3, 1))
)
sanitized[new_key] = value sanitized[new_key] = value
return sanitized return sanitized
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
@classmethod @classmethod
def from_pretrained(cls, model_path: Path) -> "AudioEncoder": def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
"""Load audio encoder from pretrained weights.""" """Load audio encoder from pretrained weights."""
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
import json import json
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
model_path = Path(model_path) model_path = Path(model_path)
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) config = AudioEncoderModelConfig.from_dict(
json.load(open(model_path / "config.json"))
)
encoder = cls(config) encoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors")) weights = mx.load(str(model_path / "model.safetensors"))
encoder.load_weights(list(weights.items()), strict=True) encoder.load_weights(list(weights.items()), strict=True)
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
""" """
super().__init__() super().__init__()
# Per-channel statistics for denormalizing latents # Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension # Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16) # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
self.z_shape = (1, config.z_channels, base_resolution, base_resolution) self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d( self.conv_in = make_conv2d(
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis config.z_channels,
base_block_channels,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
) )
self.mid = build_mid_block( self.mid = build_mid_block(
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
initial_block_channels=base_block_channels, initial_block_channels=base_block_channels,
) )
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) self.norm_out = build_normalization_layer(
final_block_channels, normtype=self.norm_type
)
self.conv_out = make_conv2d( self.conv_out = make_conv2d(
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis final_block_channels,
config.out_ch,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
) )
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
# PyTorch: (out_channels, in_channels, H, W) # PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels) # MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1)) value = (
value
if check_array_shape(value)
else mx.transpose(value, (0, 2, 3, 1))
)
sanitized[new_key] = value sanitized[new_key] = value
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
@classmethod @classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder": def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
"""Load audio VAE decoder from pretrained model.""" """Load audio VAE decoder from pretrained model."""
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
import json import json
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
config = AudioDecoderModelConfig.from_dict(
json.load(open(model_path / "config.json"))
)
decoder = cls(config) decoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors")) weights = mx.load(str(model_path / "model.safetensors"))
# weights = decoder.sanitize(weights) # weights = decoder.sanitize(weights)
decoder.load_weights(list(weights.items()), strict=True) decoder.load_weights(list(weights.items()), strict=True)
return decoder return decoder
def __call__(self, sample: mx.array) -> mx.array: def __call__(self, sample: mx.array) -> mx.array:
""" """
Decode latent features back to audio spectrograms. Decode latent features back to audio spectrograms.
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
return self._adjust_output_shape(h, target_shape) return self._adjust_output_shape(h, target_shape)
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]: def _denormalize_latents(
self, sample: mx.array
) -> tuple[mx.array, AudioLatentShape]:
"""Denormalize latents using per-channel statistics.""" """Denormalize latents using per-channel statistics."""
# sample shape: (B, H, W, C) in MLX format # sample shape: (B, H, W, C) in MLX format
latent_shape = AudioLatentShape( latent_shape = AudioLatentShape(
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
batch=latent_shape.batch, batch=latent_shape.batch,
channels=self.out_ch, channels=self.out_ch,
frames=target_frames, frames=target_frames,
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, mel_bins=(
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
),
) )
return sample, target_shape return sample, target_shape
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
# Step 1: Crop first to avoid exceeding target dimensions # Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[ decoded_output = decoded_output[
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels :,
: min(current_time, target_time),
: min(current_freq, target_freq),
:target_channels,
] ]
# Step 2: Calculate padding needed for time and frequency dimensions # Step 2: Calculate padding needed for time and frequency dimensions
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
return mx.tanh(h) if self.tanh_out else h return mx.tanh(h) if self.tanh_out else h
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array: def decode_audio(
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
) -> mx.array:
""" """
Decode an audio latent representation using the provided audio decoder and vocoder. Decode an audio latent representation using the provided audio decoder and vocoder.
Args: Args:

View File

@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width) # For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
if self.causality_axis == CausalityAxis.NONE: if self.causality_axis == CausalityAxis.NONE:
# Non-causal: symmetric padding # Non-causal: symmetric padding
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2) self.padding = (
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY): pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
)
elif self.causality_axis in (
CausalityAxis.WIDTH,
CausalityAxis.WIDTH_COMPATIBILITY,
):
# Causal on width: pad left (before width axis) # Causal on width: pad left (before width axis)
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0) self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
elif self.causality_axis == CausalityAxis.HEIGHT: elif self.causality_axis == CausalityAxis.HEIGHT:
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
if any(p > 0 for p in self.padding): if any(p > 0 for p in self.padding):
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...] # MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C # For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)]) x = mx.pad(
x,
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
)
return self.conv(x) return self.conv(x)
@@ -124,7 +135,14 @@ def make_conv2d(
if causality_axis is not None: if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d # For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d( return CausalConv2d(
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis in_channels,
out_channels,
kernel_size,
stride,
dilation,
groups,
bias,
causality_axis,
) )
else: else:
# For non-causal convolution, use symmetric padding if not specified # For non-causal convolution, use symmetric padding if not specified

View File

@@ -5,8 +5,8 @@ from typing import Set, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .attention import AttentionType, make_attn
from ..config import CausalityAxis from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .normalization import NormType from .normalization import NormType
from .resnet import ResnetBlock from .resnet import ResnetBlock
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
if self.with_conv: if self.with_conv:
# Do time downsampling here # Do time downsampling here
# no asymmetric padding in MLX conv, must do it ourselves # no asymmetric padding in MLX conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
""" """
@@ -116,10 +118,14 @@ def build_downsampling_path(
) )
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type) stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if i_level != num_resolutions - 1: if i_level != num_resolutions - 1:
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) stage["downsample"] = Downsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res = curr_res // 2 curr_res = curr_res // 2
down_modules[i_level] = stage down_modules[i_level] = stage

View File

@@ -51,7 +51,9 @@ def build_normalization_layer(
A normalization layer A normalization layer
""" """
if normtype == NormType.GROUP: if normtype == NormType.GROUP:
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True) return nn.GroupNorm(
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
)
if normtype == NormType.PIXEL: if normtype == NormType.PIXEL:
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1) # For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
# PyTorch uses dim=1 for channels-first format (B, C, H, W) # PyTorch uses dim=1 for channels-first format (B, C, H, W)

View File

@@ -1,12 +1,12 @@
"""ResNet blocks for audio VAE and vocoder.""" """ResNet blocks for audio VAE and vocoder."""
from typing import List, Tuple from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis from ..config import CausalityAxis
from .causal_conv_2d import make_conv2d
from .normalization import NormType, build_normalization_layer from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
self.conv1 = make_conv2d( self.conv1 = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
) )
if temb_channels > 0: if temb_channels > 0:
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
self.dropout_rate = dropout self.dropout_rate = dropout
self.conv2 = make_conv2d( self.conv2 = make_conv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis out_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
) )
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d( self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
) )
else: else:
self.nin_shortcut = make_conv2d( self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis in_channels,
out_channels,
kernel_size=1,
stride=1,
causality_axis=causality_axis,
) )
def __call__( def __call__(
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
if temb is not None and self.temb_channels > 0: if temb is not None and self.temb_channels > 0:
# temb: (B, temb_channels) -> (B, out_channels) # temb: (B, temb_channels) -> (B, out_channels)
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting # Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1) h = h + mx.expand_dims(
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
)
h = self.norm2(h) h = self.norm2(h)
h = nn.silu(h) h = nn.silu(h)

View File

@@ -5,9 +5,9 @@ from typing import Set, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .normalization import NormType from .normalization import NormType
from .resnet import ResnetBlock from .resnet import ResnetBlock
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
self.causality_axis = causality_axis self.causality_axis = causality_axis
if self.with_conv: if self.with_conv:
self.conv = make_conv2d( self.conv = make_conv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis in_channels,
in_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
) )
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
@@ -124,10 +128,14 @@ def build_upsampling_path(
) )
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type) stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if level != 0: if level != 0:
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) stage["upsample"] = Upsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res *= 2 curr_res *= 2
up_modules[level] = stage up_modules[level] = stage

View File

@@ -7,8 +7,8 @@ Supports:
""" """
import math import math
from typing import List, Tuple
from pathlib import Path from pathlib import Path
from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -32,7 +32,9 @@ class Snake(nn.Module):
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None: def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__() super().__init__()
self.alpha_logscale = alpha_logscale self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format # x: (N, L, C) in MLX format
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None: def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__() super().__init__()
self.alpha_logscale = alpha_logscale self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) self.alpha = (
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
self.beta = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
alpha = self.alpha alpha = self.alpha
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
) )
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array: def kaiser_sinc_filter1d(
cutoff: float, half_width: float, kernel_size: int
) -> mx.array:
"""Compute a Kaiser-windowed sinc filter.""" """Compute a Kaiser-windowed sinc filter."""
even = kernel_size % 2 == 0 even = kernel_size % 2 == 0
half_size = kernel_size // 2 half_size = kernel_size // 2
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
# Kaiser window - compute using scipy-compatible formula # Kaiser window - compute using scipy-compatible formula
import numpy as np import numpy as np
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32)) window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
if even: if even:
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]: def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler).""" """Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
import numpy as np import numpy as np
rolloff = 0.99 rolloff = 0.99
lowpass_filter_width = 6 lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff) width = math.ceil(lowpass_filter_width / rolloff)
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
self.kernel_size = filt.shape[2] self.kernel_size = filt.shape[2]
self.filter = filt self.filter = filt
else: else:
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1 self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 self.pad_left = (
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 self.pad * self.stride + (self.kernel_size - self.stride) // 2
)
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
self.filter = kaiser_sinc_filter1d( self.filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, cutoff=0.5 / ratio,
half_width=0.6 / ratio, half_width=0.6 / ratio,
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
filt = self.filter.astype(x.dtype) # (1, 1, K) filt = self.filter.astype(x.dtype) # (1, 1, K)
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1) filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1) x = self.ratio * mx.conv_transpose1d(
x, filt, stride=self.stride
) # (N*C, L', 1)
# Trim padding # Trim padding
x = x[:, self.pad_left:-self.pad_right, :] x = x[:, self.pad_left : -self.pad_right, :]
x = x.reshape(n, c, -1) # (N, C, L') x = x.reshape(n, c, -1) # (N, C, L')
x = mx.transpose(x, (0, 2, 1)) # (N, L', C) x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
self.convs1 = { self.convs1 = {
i: nn.Conv1d( i: nn.Conv1d(
channels, channels, kernel_size, stride=1, channels,
dilation=d, padding=get_padding(kernel_size, d), channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
) )
for i, d in enumerate(dilation) for i, d in enumerate(dilation)
} }
self.convs2 = { self.convs2 = {
i: nn.Conv1d( i: nn.Conv1d(
channels, channels, kernel_size, stride=1, channels,
dilation=1, padding=get_padding(kernel_size, 1), channels,
kernel_size,
stride=1,
dilation=1,
padding=get_padding(kernel_size, 1),
) )
for i in range(len(dilation)) for i in range(len(dilation))
} }
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
y = mx.concatenate([first, y], axis=1) y = mx.concatenate([first, y], axis=1)
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX # forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1) basis = mx.transpose(
self.forward_basis.astype(y.dtype), (0, 2, 1)
) # (514, K, 1)
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514) # Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
spec = mx.conv1d(y, basis, stride=self.hop_length) spec = mx.conv1d(y, basis, stride=self.hop_length)
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
real = spec[..., :n_freqs] real = spec[..., :n_freqs]
imag = spec[..., n_freqs:] imag = spec[..., n_freqs:]
magnitude = mx.sqrt(real ** 2 + imag ** 2) magnitude = mx.sqrt(real**2 + imag**2)
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype) phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
real.dtype
)
# Output: (B, T_frames, n_freqs) in MLX channels-last # Output: (B, T_frames, n_freqs) in MLX channels-last
return magnitude, phase return magnitude, phase
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
class MelSTFT(nn.Module): class MelSTFT(nn.Module):
"""Causal log-mel spectrogram from precomputed STFT bases.""" """Causal log-mel spectrogram from precomputed STFT bases."""
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None: def __init__(
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
) -> None:
super().__init__() super().__init__()
self.stft_fn = STFTFn(filter_length, hop_length, win_length) self.stft_fn = STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1 n_freqs = filter_length // 2 + 1
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
""" """
magnitude, phase = self.stft_fn(y) magnitude, phase = self.stft_fn(y)
# magnitude: (B, T_frames, n_freqs) # magnitude: (B, T_frames, n_freqs)
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels) mel = (
magnitude @ self.mel_basis.astype(magnitude.dtype).T
) # (B, T_frames, n_mels)
log_mel = mx.log(mx.clip(mel, 1e-5, None)) log_mel = mx.log(mx.clip(mel, 1e-5, None))
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format # Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
return mx.transpose(log_mel, (0, 2, 1)) return mx.transpose(log_mel, (0, 2, 1))
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
in_channels = 128 if config.stereo else 64 in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d( self.conv_pre = nn.Conv1d(
in_channels, config.upsample_initial_channel, in_channels,
kernel_size=7, stride=1, padding=3, config.upsample_initial_channel,
kernel_size=7,
stride=1,
padding=3,
) )
# Upsampling layers # Upsampling layers
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
for i, (stride, kernel_size) in enumerate( for i, (stride, kernel_size) in enumerate(
zip(config.upsample_rates, config.upsample_kernel_sizes) zip(config.upsample_rates, config.upsample_kernel_sizes)
): ):
in_ch = config.upsample_initial_channel // (2 ** i) in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1)) out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d( self.ups[i] = nn.ConvTranspose1d(
in_ch, out_ch, in_ch,
kernel_size=kernel_size, stride=stride, out_ch,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2, padding=(kernel_size - stride) // 2,
) )
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
config.resblock_kernel_sizes, config.resblock_dilation_sizes config.resblock_kernel_sizes, config.resblock_dilation_sizes
): ):
self.resblocks[block_idx] = AMPBlock1( self.resblocks[block_idx] = AMPBlock1(
ch, kernel_size, tuple(dilations), ch,
kernel_size,
tuple(dilations),
activation=config.activation, activation=config.activation,
) )
block_idx += 1 block_idx += 1
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
for kernel_size, dilations in zip( for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes config.resblock_kernel_sizes, config.resblock_dilation_sizes
): ):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) self.resblocks[block_idx] = resblock_class(
ch, kernel_size, tuple(dilations)
)
block_idx += 1 block_idx += 1
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates)) final_channels = config.upsample_initial_channel // (
2 ** len(config.upsample_rates)
)
# Post-activation # Post-activation
if self.is_amp: if self.is_amp:
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
# Final conv # Final conv
out_channels = 2 if config.stereo else 1 out_channels = 2 if config.stereo else 1
self.conv_post = nn.Conv1d( self.conv_post = nn.Conv1d(
final_channels, out_channels, final_channels,
kernel_size=7, stride=1, padding=3, out_channels,
kernel_size=7,
stride=1,
padding=3,
bias=config.use_bias_at_final, bias=config.use_bias_at_final,
) )
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
""" """
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
_, _, length_low_rate = x.shape _, _, length_low_rate = x.shape
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate output_length = (
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
)
# Pad to hop_length multiple # Pad to hop_length multiple
remainder = length_low_rate % self.hop_length remainder = length_low_rate % self.hop_length
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
model.load_weights(list(weights.items()), strict=False) model.load_weights(list(weights.items()), strict=False)
return model return model

View File

@@ -1,3 +1,6 @@
"""Conditioning modules for LTX-2 video generation.""" """Conditioning modules for LTX-2 video generation."""
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning from mlx_video.models.ltx_2.conditioning.latent import (
VideoConditionByLatentIndex,
apply_conditioning,
)

View File

@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, List, Tuple from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
frame_idx: Frame index to condition (0 = first frame) frame_idx: Frame index to condition (0 = first frame)
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original) strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
""" """
latent: mx.array latent: mx.array
frame_idx: int = 0 frame_idx: int = 0
strength: float = 1.0 strength: float = 1.0
@@ -41,6 +42,7 @@ class LatentState:
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
1.0 = full denoise, 0.0 = keep clean 1.0 = full denoise, 0.0 = keep clean
""" """
latent: mx.array latent: mx.array
clean_latent: mx.array clean_latent: mx.array
denoise_mask: mx.array denoise_mask: mx.array
@@ -130,15 +132,15 @@ def apply_conditioning(
if frame_idx <= i < end_idx: if frame_idx <= i < end_idx:
# Use conditioning latent # Use conditioning latent
cond_idx = i - frame_idx cond_idx = i - frame_idx
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
# Set mask: 1.0 - strength means less denoising for conditioned frames # Set mask: 1.0 - strength means less denoising for conditioned frames
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype)) mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
else: else:
# Keep original # Keep original
latent_list.append(state.latent[:, :, i:i+1]) latent_list.append(state.latent[:, :, i : i + 1])
clean_list.append(state.clean_latent[:, :, i:i+1]) clean_list.append(state.clean_latent[:, :, i : i + 1])
mask_list.append(state.denoise_mask[:, :, i:i+1]) mask_list.append(state.denoise_mask[:, :, i : i + 1])
state.latent = mx.concatenate(latent_list, axis=2) state.latent = mx.concatenate(latent_list, axis=2)
state.clean_latent = mx.concatenate(clean_list, axis=2) state.clean_latent = mx.concatenate(clean_list, axis=2)

View File

@@ -1,4 +1,3 @@
import inspect import inspect
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@@ -22,9 +21,11 @@ class LTXRopeType(Enum):
SPLIT = "split" SPLIT = "split"
TWO_D = "2d" TWO_D = "2d"
class AttentionType(Enum): class AttentionType(Enum):
DEFAULT = "default" DEFAULT = "default"
@dataclass @dataclass
class BaseModelConfig: class BaseModelConfig:
@@ -46,7 +47,7 @@ class BaseModelConfig:
if v is not None: if v is not None:
if isinstance(v, Enum): if isinstance(v, Enum):
result[k] = v.value result[k] = v.value
elif hasattr(v, 'to_dict'): elif hasattr(v, "to_dict"):
result[k] = v.to_dict() result[k] = v.to_dict()
else: else:
result[k] = v result[k] = v
@@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig):
out_channels: int = 128 out_channels: int = 128
latent_channels: int = 128 latent_channels: int = 128
patch_size: int = 4 patch_size: int = 4
encoder_blocks: List[tuple] = field(default_factory=lambda: [ encoder_blocks: List[tuple] = field(
("res_x", {"num_layers": 4}), default_factory=lambda: [
("compress_space_res", {"multiplier": 2}), ("res_x", {"num_layers": 4}),
("res_x", {"num_layers": 6}), ("compress_space_res", {"multiplier": 2}),
("compress_time_res", {"multiplier": 2}), ("res_x", {"num_layers": 6}),
("res_x", {"num_layers": 6}), ("compress_time_res", {"multiplier": 2}),
("compress_all_res", {"multiplier": 2}), ("res_x", {"num_layers": 6}),
("res_x", {"num_layers": 2}), ("compress_all_res", {"multiplier": 2}),
("compress_all_res", {"multiplier": 2}), ("res_x", {"num_layers": 2}),
("res_x", {"num_layers": 2}), ("compress_all_res", {"multiplier": 2}),
]) ("res_x", {"num_layers": 2}),
decoder_blocks: List[tuple] = field(default_factory=lambda: [ ]
("res_x", {"num_layers": 5, "inject_noise": False}), )
("compress_all", {"residual": True, "multiplier": 2}), decoder_blocks: List[tuple] = field(
("res_x", {"num_layers": 5, "inject_noise": False}), default_factory=lambda: [
("compress_all", {"residual": True, "multiplier": 2}), ("res_x", {"num_layers": 5, "inject_noise": False}),
("res_x", {"num_layers": 5, "inject_noise": False}), ("compress_all", {"residual": True, "multiplier": 2}),
("compress_all", {"residual": True, "multiplier": 2}), ("res_x", {"num_layers": 5, "inject_noise": False}),
("res_x", {"num_layers": 5, "inject_noise": False}), ("compress_all", {"residual": True, "multiplier": 2}),
]) ("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
]
)
@dataclass @dataclass
@@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig):
audio_in_channels: int = 128 audio_in_channels: int = 128
audio_out_channels: int = 128 audio_out_channels: int = 128
audio_cross_attention_dim: int = 2048 audio_cross_attention_dim: int = 2048
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video) audio_caption_channels: int = (
3840 # Input dim for audio text embeddings (same as video)
)
# Positional embedding config # Positional embedding config
positional_embedding_theta: float = 10000.0 positional_embedding_theta: float = 10000.0
@@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig):
) )
class CausalityAxis(Enum): class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions.""" """Enum for specifying the causality axis in causal convolutions."""
@@ -237,8 +243,8 @@ class AudioDecoderModelConfig(BaseModelConfig):
def __post_init__(self): def __post_init__(self):
"""Convert string enum values to proper enum types.""" """Convert string enum values to proper enum types."""
# Import here to avoid circular imports # Import here to avoid circular imports
from .audio_vae.normalization import NormType
from .audio_vae.attention import AttentionType from .audio_vae.attention import AttentionType
from .audio_vae.normalization import NormType
# Convert causality_axis string to enum # Convert causality_axis string to enum
if isinstance(self.causality_axis, str): if isinstance(self.causality_axis, str):
@@ -252,6 +258,7 @@ class AudioDecoderModelConfig(BaseModelConfig):
if isinstance(self.attn_type, str): if isinstance(self.attn_type, str):
self.attn_type = AttentionType(self.attn_type) self.attn_type = AttentionType(self.attn_type)
@dataclass @dataclass
class AudioEncoderModelConfig(BaseModelConfig): class AudioEncoderModelConfig(BaseModelConfig):
ch: int = 128 ch: int = 128
@@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig):
def __post_init__(self): def __post_init__(self):
"""Convert string enum values to proper enum types.""" """Convert string enum values to proper enum types."""
from .audio_vae.normalization import NormType
from .audio_vae.attention import AttentionType from .audio_vae.attention import AttentionType
from .audio_vae.normalization import NormType
if isinstance(self.causality_axis, str): if isinstance(self.causality_axis, str):
self.causality_axis = CausalityAxis(self.causality_axis) self.causality_axis = CausalityAxis(self.causality_axis)
@@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
dropout: float = 0.0 dropout: float = 0.0
timestep_conditioning: bool = False timestep_conditioning: bool = False
@dataclass @dataclass
class VideoEncoderModelConfig(BaseModelConfig): class VideoEncoderModelConfig(BaseModelConfig):
convolution_dimensions: int = 3 convolution_dimensions: int = 3
@@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig):
norm_layer: Enum = None norm_layer: Enum = None
latent_log_var: Enum = None latent_log_var: Enum = None
encoder_spatial_padding_mode: Enum = None encoder_spatial_padding_mode: Enum = None
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}), encoder_blocks: List[tuple] = field(
("compress_space_res", {"multiplier": 2}), default_factory=lambda: [
("res_x", {"num_layers": 6}), ("res_x", {"num_layers": 4}),
("compress_time_res", {"multiplier": 2}), ("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}), ("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}), ("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}), ("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}), ("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}) ("res_x", {"num_layers": 2}),
]) ("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
)
def __post_init__(self): def __post_init__(self):
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
if self.norm_layer is None: if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM self.norm_layer = NormLayerType.PIXEL_NORM
@@ -371,7 +382,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
if isinstance(self.latent_log_var, str): if isinstance(self.latent_log_var, str):
self.latent_log_var = LogVarianceType(self.latent_log_var) self.latent_log_var = LogVarianceType(self.latent_log_var)
if isinstance(self.encoder_spatial_padding_mode, str): if isinstance(self.encoder_spatial_padding_mode, str):
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode) self.encoder_spatial_padding_mode = PaddingModeType(
self.encoder_spatial_padding_mode
)
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
result = super().to_dict() result = super().to_dict()

View File

@@ -49,7 +49,6 @@ from typing import Dict
import mlx.core as mx import mlx.core as mx
# ─── Key prefix routing ────────────────────────────────────────────────────── # ─── Key prefix routing ──────────────────────────────────────────────────────
TRANSFORMER_PREFIX = "model.diffusion_model." TRANSFORMER_PREFIX = "model.diffusion_model."
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key: if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue continue
new_key = key[len(TRANSFORMER_PREFIX):] new_key = key[len(TRANSFORMER_PREFIX) :]
new_key = new_key.replace(".to_out.0.", ".to_out.") new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
else: else:
continue continue
elif key.startswith(VAE_DECODER_PREFIX): elif key.startswith(VAE_DECODER_PREFIX):
new_key = key[len(VAE_DECODER_PREFIX):] new_key = key[len(VAE_DECODER_PREFIX) :]
else: else:
continue continue
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if value.dtype != mx.float32: if value.dtype != mx.float32:
value = value.astype(mx.float32) value = value.astype(mx.float32)
elif key.startswith(VAE_ENCODER_PREFIX): elif key.startswith(VAE_ENCODER_PREFIX):
new_key = key[len(VAE_ENCODER_PREFIX):] new_key = key[len(VAE_ENCODER_PREFIX) :]
else: else:
continue continue
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
new_key = None new_key = None
if key.startswith(AUDIO_DECODER_PREFIX): if key.startswith(AUDIO_DECODER_PREFIX):
new_key = key[len(AUDIO_DECODER_PREFIX):] new_key = key[len(AUDIO_DECODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX): elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key: if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means" new_key = "per_channel_statistics.mean_of_means"
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
new_key = None new_key = None
if key.startswith(AUDIO_ENCODER_PREFIX): if key.startswith(AUDIO_ENCODER_PREFIX):
new_key = key[len(AUDIO_ENCODER_PREFIX):] new_key = key[len(AUDIO_ENCODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX): elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key: if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means" new_key = "per_channel_statistics.mean_of_means"
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if not key.startswith(VOCODER_PREFIX): if not key.startswith(VOCODER_PREFIX):
continue continue
new_key = key[len(VOCODER_PREFIX):] new_key = key[len(VOCODER_PREFIX) :]
# Handle Conv1d/ConvTranspose1d weight shape conversion # Handle Conv1d/ConvTranspose1d weight shape conversion
if "weight" in new_key and value.ndim == 3: if "weight" in new_key and value.ndim == 3:
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
# aggregate_embed weights (text_embedding_projection.*) # aggregate_embed weights (text_embedding_projection.*)
for key, value in weights.items(): for key, value in weights.items():
if key.startswith(TEXT_PROJ_PREFIX): if key.startswith(TEXT_PROJ_PREFIX):
new_key = key[len(TEXT_PROJ_PREFIX):] new_key = key[len(TEXT_PROJ_PREFIX) :]
extracted[new_key] = value extracted[new_key] = value
# video_embeddings_connector # video_embeddings_connector
for key, value in weights.items(): for key, value in weights.items():
if key.startswith(VIDEO_CONNECTOR_PREFIX): if key.startswith(VIDEO_CONNECTOR_PREFIX):
suffix = key[len(VIDEO_CONNECTOR_PREFIX):] suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix) new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value extracted[new_key] = value
# audio_embeddings_connector # audio_embeddings_connector
for key, value in weights.items(): for key, value in weights.items():
if key.startswith(AUDIO_CONNECTOR_PREFIX): if key.startswith(AUDIO_CONNECTOR_PREFIX):
suffix = key[len(AUDIO_CONNECTOR_PREFIX):] suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix) new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value extracted[new_key] = value
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
# ─── Source resolution ───────────────────────────────────────────────────────── # ─── Source resolution ─────────────────────────────────────────────────────────
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc. # Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$") MONOLITHIC_PATTERN = re.compile(
r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
)
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors, # Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc. # ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$") UPSCALER_PATTERN = re.compile(
r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
)
def resolve_source(source: str, variant: str) -> Path: def resolve_source(source: str, variant: str) -> Path:
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict: def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
"""Infer VAE decoder config from weights.""" """Infer VAE decoder config from weights."""
# Check for timestep conditioning keys # Check for timestep conditioning keys
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights) has_timestep = any(
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
)
# Count channel multipliers from up_blocks # Count channel multipliers from up_blocks
max_block = -1 max_block = -1
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
config = infer_transformer_config(transformer_weights) config = infer_transformer_config(transformer_weights)
save_config(config, output_path / "transformer") save_config(config, output_path / "transformer")
t_params = sum(v.size for v in transformer_weights.values()) t_params = sum(v.size for v in transformer_weights.values())
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards") print(
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
)
# 2. VAE Decoder # 2. VAE Decoder
print(" [2/7] VAE Decoder...") print(" [2/7] VAE Decoder...")
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
] ]
else: else:
upscaler_files = [ upscaler_files = [
f.name for f in source_dir.iterdir() f.name
for f in source_dir.iterdir()
if f.is_file() and UPSCALER_PATTERN.match(f.name) if f.is_file() and UPSCALER_PATTERN.match(f.name)
] ]
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print(f"\nDone! Converted {all_converted}/{total_keys} keys") print(f"\nDone! Converted {all_converted}/{total_keys} keys")
if all_converted < total_keys: if all_converted < total_keys:
known_prefixes = ( known_prefixes = (
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX, TRANSFORMER_PREFIX,
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX, VAE_DECODER_PREFIX,
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX, VAE_ENCODER_PREFIX,
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX, VAE_STATS_PREFIX,
AUDIO_DECODER_PREFIX,
AUDIO_ENCODER_PREFIX,
AUDIO_STATS_PREFIX,
VOCODER_PREFIX,
TEXT_PROJ_PREFIX,
VIDEO_CONNECTOR_PREFIX,
AUDIO_CONNECTOR_PREFIX,
) )
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)] skipped = [
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
]
if skipped: if skipped:
print(f" Skipped {len(skipped)} keys:") print(f" Skipped {len(skipped)} keys:")
for k in sorted(skipped)[:20]: for k in sorted(skipped)[:20]:

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,14 @@
from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.config import ( from mlx_video.models.ltx_2.config import (
LTXModelConfig, LTXModelConfig,
LTXModelType,
LTXRopeType, LTXRopeType,
TransformerConfig,
) )
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.rope import precompute_freqs_cis from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx_2.transformer import ( from mlx_video.models.ltx_2.transformer import (
@@ -58,11 +57,17 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) timestep_emb, embedded_timestep = self.adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
# Reshape to (batch, tokens, dim) # Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) timestep_emb = mx.reshape(
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])) timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep return timestep_emb, embedded_timestep
@@ -74,9 +79,15 @@ class TransformerArgsPreprocessor:
hidden_dtype: mx.Dtype = None, hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) timestep_emb, embedded_timestep = adaln(
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) timestep.reshape(-1), hidden_dtype=hidden_dtype
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])) )
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep return timestep_emb, embedded_timestep
def _prepare_context( def _prepare_context(
@@ -107,7 +118,9 @@ class TransformerArgsPreprocessor:
# Convert boolean/int mask to float mask # Convert boolean/int mask to float mask
# 0 -> -inf (masked), 1 -> 0 (not masked) # 0 -> -inf (masked), 1 -> 0 (not masked)
mask = (attention_mask.astype(x_dtype) - 1) * 1e9 mask = (attention_mask.astype(x_dtype) - 1) * 1e9
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) mask = mx.reshape(
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
)
return mask return mask
def _prepare_positional_embeddings( def _prepare_positional_embeddings(
@@ -132,9 +145,15 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs: def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent) x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) timestep, embedded_timestep = self._prepare_timestep(
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) modality.timesteps, x.shape[0], hidden_dtype=x.dtype
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) )
context, attention_mask = self._prepare_context(
modality.context, x, modality.context_mask
)
attention_mask = self._prepare_attention_mask(
attention_mask, modality.latent.dtype
)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation) # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None: if modality.positional_embeddings is not None:
@@ -152,8 +171,13 @@ class TransformerArgsPreprocessor:
prompt_timestep = None prompt_timestep = None
prompt_embedded_timestep = None prompt_embedded_timestep = None
if self.prompt_adaln is not None and modality.sigma is not None: if self.prompt_adaln is not None and modality.sigma is not None:
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln( prompt_timestep, prompt_embedded_timestep = (
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype, self._prepare_timestep_with_adaln(
self.prompt_adaln,
modality.sigma,
x.shape[0],
hidden_dtype=x.dtype,
)
) )
return TransformerArgs( return TransformerArgs(
@@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
) )
# Prepare cross-attention timestep embeddings # Prepare cross-attention timestep embeddings
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( cross_scale_shift_timestep, cross_gate_timestep = (
timestep=modality.timesteps, self._prepare_cross_attention_timestep(
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, timestep=modality.timesteps,
batch_size=transformer_args.x.shape[0], timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
hidden_dtype=transformer_args.x.dtype, batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
)
) )
return replace( return replace(
@@ -254,11 +280,19 @@ class MultiModalTransformerArgsPreprocessor:
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) scale_shift_timestep, _ = self.cross_scale_shift_adaln(
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) timestep.reshape(-1), hidden_dtype=hidden_dtype
)
scale_shift_timestep = mx.reshape(
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
)
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) gate_timestep, _ = self.cross_gate_adaln(
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
)
gate_timestep = mx.reshape(
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
)
return scale_shift_timestep, gate_timestep return scale_shift_timestep, gate_timestep
@@ -285,18 +319,25 @@ class LTXModel(nn.Module):
self._init_video(config) self._init_video(config)
if config.model_type.is_audio_enabled(): if config.model_type.is_audio_enabled():
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos self.audio_positional_embedding_max_pos = (
config.audio_positional_embedding_max_pos
)
self.audio_num_attention_heads = config.audio_num_attention_heads self.audio_num_attention_heads = config.audio_num_attention_heads
self.audio_inner_dim = config.audio_inner_dim self.audio_inner_dim = config.audio_inner_dim
self._init_audio(config) self._init_audio(config)
# Initialize cross-modal components # Initialize cross-modal components
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled(): if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
cross_pe_max_pos = max( cross_pe_max_pos = max(
config.positional_embedding_max_pos[0], config.positional_embedding_max_pos[0],
config.audio_positional_embedding_max_pos[0], config.audio_positional_embedding_max_pos[0],
) )
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier self.av_ca_timestep_scale_multiplier = (
config.av_ca_timestep_scale_multiplier
)
self.audio_cross_attention_dim = config.audio_cross_attention_dim self.audio_cross_attention_dim = config.audio_cross_attention_dim
self._init_audio_video(config) self._init_audio_video(config)
@@ -308,10 +349,14 @@ class LTXModel(nn.Module):
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True) self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
adaln_coefficient = 9 if config.has_prompt_adaln else 6 adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient) self.adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=adaln_coefficient
)
if config.has_prompt_adaln: if config.has_prompt_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2
)
else: else:
self.caption_projection = PixArtAlphaTextProjection( self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels, in_features=config.caption_channels,
@@ -323,13 +368,19 @@ class LTXModel(nn.Module):
self.proj_out = nn.Linear(self.inner_dim, config.out_channels) self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None: def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) self.audio_patchify_proj = nn.Linear(
config.audio_in_channels, self.audio_inner_dim, bias=True
)
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6 audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient) self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
)
if config.has_prompt_adaln: if config.has_prompt_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=2
)
else: else:
self.audio_caption_projection = PixArtAlphaTextProjection( self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels, in_features=config.audio_caption_channels,
@@ -338,7 +389,9 @@ class LTXModel(nn.Module):
# Output components # Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim)) self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False) self.audio_norm_out = nn.LayerNorm(
self.audio_inner_dim, eps=config.norm_eps, affine=False
)
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels) self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None: def _init_audio_video(self, config: LTXModelConfig) -> None:
@@ -361,8 +414,13 @@ class LTXModel(nn.Module):
embedding_coefficient=1, embedding_coefficient=1,
) )
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None: def _init_preprocessors(
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled(): self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
) -> None:
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
# Multi-modal preprocessors # Multi-modal preprocessors
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj, patchify_proj=self.patchify_proj,
@@ -468,7 +526,8 @@ class LTXModel(nn.Module):
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set() stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
for idx, block in self.transformer_blocks.items(): for idx, block in self.transformer_blocks.items():
video, audio = block( video, audio = block(
video=video, audio=audio, video=video,
audio=audio,
skip_video_self_attn=(idx in stg_v_set), skip_video_self_attn=(idx in stg_v_set),
skip_audio_self_attn=(idx in stg_a_set), skip_audio_self_attn=(idx in stg_a_set),
skip_cross_modal=skip_cross_modal, skip_cross_modal=skip_cross_modal,
@@ -526,8 +585,12 @@ class LTXModel(nn.Module):
raise ValueError("Audio is not enabled for this model") raise ValueError("Audio is not enabled for this model")
# Preprocess arguments # Preprocess arguments
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None video_args = (
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None self.video_args_preprocessor.prepare(video) if video is not None else None
)
audio_args = (
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
)
# Process transformer blocks # Process transformer blocks
video_out, audio_out = self._process_transformer_blocks( video_out, audio_out = self._process_transformer_blocks(
@@ -577,7 +640,10 @@ class LTXModel(nn.Module):
if not key.startswith("model.diffusion_model."): if not key.startswith("model.diffusion_model."):
continue continue
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key: if (
"audio_embeddings_connector" in key
or "video_embeddings_connector" in key
):
continue continue
# Remove 'model.diffusion_model.' prefix # Remove 'model.diffusion_model.' prefix
@@ -612,9 +678,11 @@ class LTXModel(nn.Module):
for weight_file in model_path.glob("*.safetensors"): for weight_file in model_path.glob("*.safetensors"):
weights.update(mx.load(str(weight_file))) weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights) sanitized = model.sanitize(weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} sanitized = {
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
for k, v in sanitized.items()
}
model.load_weights(list(sanitized.items()), strict=strict) model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters()) mx.eval(model.parameters())
@@ -639,13 +707,18 @@ class X0Model(nn.Module):
) -> Tuple[Optional[mx.array], Optional[mx.array]]: ) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model( vx, ax = self.velocity_model(
video, audio, video,
audio,
stg_video_blocks=stg_video_blocks, stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks, stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal, skip_cross_modal=skip_cross_modal,
) )
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None denoised_video = (
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
)
denoised_audio = (
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
)
return denoised_video, denoised_audio return denoised_video, denoised_audio

View File

@@ -1,9 +1,10 @@
import numpy as np import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray: def bilateral_filter(
image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75
) -> np.ndarray:
"""Apply bilateral filter to reduce grid artifacts while preserving edges. """Apply bilateral filter to reduce grid artifacts while preserving edges.
Args: Args:
@@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig
""" """
try: try:
import cv2 import cv2
return cv2.bilateralFilter(image, d, sigma_color, sigma_space) return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
except ImportError: except ImportError:
# Fallback to simple Gaussian blur if cv2 not available # Fallback to simple Gaussian blur if cv2 not available
@@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
""" """
try: try:
import cv2 import cv2
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0) return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
except ImportError: except ImportError:
# Simple box blur fallback # Simple box blur fallback
from scipy.ndimage import uniform_filter from scipy.ndimage import uniform_filter
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(
np.uint8
)
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray: def unsharp_mask(
image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0
) -> np.ndarray:
"""Apply unsharp masking to enhance edges after blur. """Apply unsharp masking to enhance edges after blur.
Args: Args:
@@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am
""" """
try: try:
import cv2 import cv2
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma) blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0) sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
return np.clip(sharpened, 0, 255).astype(np.uint8) return np.clip(sharpened, 0, 255).astype(np.uint8)
@@ -81,23 +90,23 @@ def reduce_grid_artifacts(
if method == "bilateral": if method == "bilateral":
d = max(3, int(5 * strength)) d = max(3, int(5 * strength))
sigma = 50 + 50 * strength sigma = 50 + 50 * strength
processed = np.stack([ processed = np.stack(
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma) [
for frame in video bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
]) for frame in video
]
)
elif method == "gaussian": elif method == "gaussian":
kernel_size = max(3, int(3 + 4 * strength)) kernel_size = max(3, int(3 + 4 * strength))
if kernel_size % 2 == 0: if kernel_size % 2 == 0:
kernel_size += 1 kernel_size += 1
processed = np.stack([ processed = np.stack(
gaussian_blur(frame, kernel_size=kernel_size) [gaussian_blur(frame, kernel_size=kernel_size) for frame in video]
for frame in video )
])
elif method == "frequency": elif method == "frequency":
processed = np.stack([ processed = np.stack(
remove_grid_frequency(frame, grid_size=8) [remove_grid_frequency(frame, grid_size=8) for frame in video]
for frame in video )
])
else: else:
raise ValueError(f"Unknown method: {method}") raise ValueError(f"Unknown method: {method}")
@@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8) result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
return result return result

View File

@@ -1,4 +1,3 @@
import math import math
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
""" """
# x: (..., dim) where dim is even # x: (..., dim) where dim is even
x_even = x[..., 0::2] # [x0, x2, x4, ...] x_even = x[..., 0::2] # [x0, x2, x4, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...] x_odd = x[..., 1::2] # [x1, x3, x5, ...]
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...] # Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
rotated = mx.stack([-x_odd, x_even], axis=-1) rotated = mx.stack([-x_odd, x_even], axis=-1)
return mx.reshape(rotated, x.shape) return mx.reshape(rotated, x.shape)
def apply_rotary_emb_1d( def apply_rotary_emb_1d(
q: mx.array, q: mx.array,
k: mx.array, k: mx.array,
@@ -228,9 +228,9 @@ def get_fractional_positions(
Fractional positions in range [-1, 1] after scaling Fractional positions in range [-1, 1] after scaling
""" """
n_pos_dims = indices_grid.shape[1] n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), ( assert n_pos_dims == len(
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" max_pos
) ), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
# Divide each dimension by its max position # Divide each dimension by its max position
fractional_positions = [] fractional_positions = []
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
if max_pos is None: if max_pos is None:
max_pos = [20, 2048, 2048] max_pos = [20, 2048, 2048]
if double_precision: if double_precision:
return _precompute_freqs_cis_double_precision( return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid, indices_grid,
num_attention_heads, rope_type dim,
theta,
max_pos,
use_middle_indices_grid,
num_attention_heads,
rope_type,
) )
# Keep positions in float32 for RoPE computation. # Keep positions in float32 for RoPE computation.
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
# Compute frequencies: outer product # Compute frequencies: outer product
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1) # scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices) # freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1)) freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
freq_indices, (1, 1, 1, -1)
)
# freqs: (B, T, n_dims, num_indices) # freqs: (B, T, n_dims, num_indices)
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims) # Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)

View File

@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
""" """
import math import math
from typing import Optional
import mlx.core as mx import mlx.core as mx
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed) # Phi functions and RK coefficients (pure Python math, no MLX needed)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float: def phi(j: int, neg_h: float) -> float:
"""Compute phi_j(z) where z = -h (negative step size in log-space). """Compute phi_j(z) where z = -h (negative step size in log-space).
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
Returns: Returns:
(a21, b1, b2): RK coefficients. (a21, b1, b2): RK coefficients.
""" """
def get_phi(j: int, neg_h: float) -> float: def get_phi(j: int, neg_h: float) -> float:
cache_key = (j, neg_h) cache_key = (j, neg_h)
if cache_key in phi_cache: if cache_key in phi_cache:
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
# SDE noise injection # SDE noise injection
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def get_sde_coeff( def get_sde_coeff(
sigma_next: float, sigma_next: float,
) -> tuple[float, float, float]: ) -> tuple[float, float, float]:
@@ -139,7 +140,9 @@ def sde_noise_step(
denoised_next = sample_f32 - sigma * eps_next denoised_next = sample_f32 - sigma * eps_next
# Mix deterministic and stochastic components # Mix deterministic and stochastic components
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32 x_noised = (
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
)
return x_noised return x_noised
@@ -148,6 +151,7 @@ def sde_noise_step(
# Noise generation # Noise generation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array: def channelwise_normalize(x: mx.array) -> mx.array:
"""Normalize each channel to zero mean and unit variance over spatial dims. """Normalize each channel to zero mean and unit variance over spatial dims.

View File

@@ -1,25 +1,25 @@
import functools import functools
import logging import logging
import math import math
import re import re
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig from mlx_vlm.models.gemma3.config import TextConfig
from mlx_vlm.models.gemma3.language import Gemma3Model
from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
)
from mlx_video.utils import apply_quantization, rms_norm
# Path to system prompts # Path to system prompts
PROMPTS_DIR = Path(__file__).parent / "prompts" PROMPTS_DIR = Path(__file__).parent / "prompts"
@@ -36,7 +36,6 @@ def _load_system_prompt(prompt_name: str) -> str:
class LanguageModel(nn.Module): class LanguageModel(nn.Module):
def __init__(self, config: TextConfig): def __init__(self, config: TextConfig):
super().__init__() super().__init__()
# Create config matching LTX-2 text encoder requirements # Create config matching LTX-2 text encoder requirements
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len) padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
combined = causal_mask[None, :, :] & padding_mask[:, None, :] combined = causal_mask[None, :, :] & padding_mask[:, None, :]
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 min_val = (
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype), mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mx.full(combined.shape, min_val, dtype=dtype)) )
mask = mx.where(
combined,
mx.zeros(combined.shape, dtype=dtype),
mx.full(combined.shape, min_val, dtype=dtype),
)
return mask[:, None, :, :] return mask[:, None, :, :]
else: else:
# No padding mask, just causal # No padding mask, just causal
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 min_val = (
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype), mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mx.full((seq_len, seq_len), min_val, dtype=dtype)) )
mask = mx.where(
causal_mask,
mx.zeros((seq_len, seq_len), dtype=dtype),
mx.full((seq_len, seq_len), min_val, dtype=dtype),
)
return mask[None, None, :, :] # (1, 1, seq, seq) return mask[None, None, :, :] # (1, 1, seq, seq)
def __call__( def __call__(
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
batch_size, seq_len = inputs.shape batch_size, seq_len = inputs.shape
# Get embeddings # Get embeddings
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs) h = (
input_embeddings
if input_embeddings is not None
else self.model.embed_tokens(inputs)
)
# Apply Gemma scaling # Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype) h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
if cache is None: if cache is None:
cache = [None] * len(self.model.layers) cache = [None] * len(self.model.layers)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype) full_causal_mask = self._create_causal_mask_with_padding(
seq_len, attention_mask, h.dtype
)
sliding_mask = full_causal_mask sliding_mask = full_causal_mask
num_layers = len(self.model.layers) num_layers = len(self.model.layers)
for i, layer in enumerate(self.model.layers): for i, layer in enumerate(self.model.layers):
is_global = ( is_global = (
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
for key, value in weights.items(): for key, value in weights.items():
if key.startswith(prefix): if key.startswith(prefix):
if hasattr(value, "dtype") and value.dtype == mx.float32: if hasattr(value, "dtype") and value.dtype == mx.float32:
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16) sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
else: else:
sanitized[key[len(prefix):]] = value sanitized[key[len(prefix) :]] = value
return sanitized return sanitized
@property @property
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
def make_cache(self): def make_cache(self):
from mlx_vlm.models.cache import KVCache, RotatingKVCache from mlx_vlm.models.cache import KVCache, RotatingKVCache
caches = [] caches = []
for i in range(len(self.layers)): for i in range(len(self.layers)):
if ( if (
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
@classmethod @classmethod
def from_pretrained(cls, model_path: str): def from_pretrained(cls, model_path: str):
import json import json
weight_files = sorted(Path(model_path).glob("*.safetensors")) weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json" config_file = Path(model_path) / "config.json"
config_dict = {} config_dict = {}
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
with open(config_file, "r") as f: with open(config_file, "r") as f:
config_dict = json.load(f) config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"])) language_model = cls(
config=TextConfig.from_dict(config_dict["text_config"])
)
else: else:
raise ValueError(f"Config file not found at {model_path}") raise ValueError(f"Config file not found at {model_path}")
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
for i, wf in enumerate(weight_files): for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf))) weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"): if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights) weights = language_model.sanitize(weights=weights)
apply_quantization(
apply_quantization(model=language_model, weights=weights, quantization=quantization) model=language_model, weights=weights, quantization=quantization
)
language_model.load_weights(list(weights.items()), strict=False) language_model.load_weights(list(weights.items()), strict=False)
return language_model return language_model
class ConnectorAttention(nn.Module): class ConnectorAttention(nn.Module):
def __init__( def __init__(
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
k = self.k_norm(k) k = self.k_norm(k)
# Reshape to (B, H, T, D) for SPLIT RoPE # Reshape to (B, H, T, D) for SPLIT RoPE
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) q = mx.reshape(
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) q, (batch_size, seq_len, self.num_heads, self.head_dim)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) ).transpose(0, 2, 1, 3)
k = mx.reshape(
k, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
v = mx.reshape(
v, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
if pe is not None: if pe is not None:
q = self._apply_split_rope(q, pe[0], pe[1]) q = self._apply_split_rope(q, pe[0], pe[1])
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
class ConnectorTransformerBlock(nn.Module): class ConnectorTransformerBlock(nn.Module):
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False): def __init__(
self,
dim: int = 3840,
num_heads: int = 30,
head_dim: int = 128,
has_gate_logits: bool = False,
):
super().__init__() super().__init__()
self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) self.attn1 = ConnectorAttention(
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
)
self.ff = ConnectorFeedForward(dim) self.ff = ConnectorFeedForward(dim)
def __call__( def __call__(
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
self.positional_embedding_max_pos = positional_embedding_max_pos or [1] self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
self.transformer_1d_blocks = { self.transformer_1d_blocks = {
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) i: ConnectorTransformerBlock(
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
)
for i in range(num_layers) for i in range(num_layers)
} }
if num_learnable_registers > 0: if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim)) self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]: def _precompute_freqs_cis(
self, seq_len: int, dtype: mx.Dtype
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch). """Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2). Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
# Binary mask: 1 for valid tokens, 0 for padded # Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded # attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq) mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
mx.int32
) # (batch, seq)
# Tile registers to match sequence length, cast to hidden_states dtype # Tile registers to match sequence length, cast to hidden_states dtype
num_tiles = seq_len // self.num_learnable_registers num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim) registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
dtype
) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing) # Process each batch item (PyTorch uses advanced indexing)
result_list = [] result_list = []
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
# Extract valid tokens (where mask is 1) # Extract valid tokens (where mask is 1)
# Since we have left-padded input, valid tokens are at the end # Since we have left-padded input, valid tokens are at the end
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim) valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
# Pad with zeros on the right to get back to seq_len # Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid pad_length = seq_len - num_valid
if pad_length > 0: if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=dtype) padding = mx.zeros((pad_length, dim), dtype=dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim) adjusted = mx.concatenate(
[valid_tokens, padding], axis=0
) # (seq_len, dim)
else: else:
adjusted = valid_tokens adjusted = valid_tokens
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back # Create flipped mask: 1s at front (where valid tokens now are), 0s at back
flipped_mask = mx.concatenate([ flipped_mask = mx.concatenate(
mx.ones((num_valid,), dtype=mx.int32), [
mx.zeros((pad_length,), dtype=mx.int32) mx.ones((num_valid,), dtype=mx.int32),
], axis=0) # (seq,) mx.zeros((pad_length,), dtype=mx.int32),
],
axis=0,
) # (seq,)
# Combine: valid tokens at front, registers at back # Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1) flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers combined = (
flipped_mask_expanded * adjusted
+ (1 - flipped_mask_expanded) * registers
)
result_list.append(combined) result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim) hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
# Process through transformer blocks # Process through transformer blocks
for i in range(len(self.transformer_1d_blocks)): for i in range(len(self.transformer_1d_blocks)):
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis) hidden_states = self.transformer_1d_blocks[i](
hidden_states, attention_mask, freqs_cis
)
# Final RMS norm # Final RMS norm
hidden_states = rms_norm(hidden_states) hidden_states = rms_norm(hidden_states)
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
return hidden_states, attention_mask return hidden_states, attention_mask
def norm_and_concat_hidden_states( def norm_and_concat_hidden_states(
hidden_states: List[mx.array], hidden_states: List[mx.array],
attention_mask: mx.array, attention_mask: mx.array,
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer # Compute masked min/max per layer
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype)) x_for_min = mx.where(
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype)) mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
)
x_for_max = mx.where(
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
)
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min range_val = x_max - x_min
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
dtype = encoded_text.dtype dtype = encoded_text.dtype
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D # Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L) variance = mx.mean(
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
) # (B, T, 1, L)
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6) normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
normed = normed.astype(dtype) normed = normed.astype(dtype)
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
class GemmaFeaturesExtractor(nn.Module): class GemmaFeaturesExtractor(nn.Module):
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization.""" """V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False): def __init__(
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
):
super().__init__() super().__init__()
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias) self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
if mode == "video": if mode == "video":
target_dim = self.video_aggregate_embed.weight.shape[0] target_dim = self.video_aggregate_embed.weight.shape[0]
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) return self.video_aggregate_embed(
_rescale_norm(normed, target_dim, self.embedding_dim)
)
else: else:
target_dim = self.audio_aggregate_embed.weight.shape[0] target_dim = self.audio_aggregate_embed.weight.shape[0]
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) return self.audio_aggregate_embed(
_rescale_norm(normed, target_dim, self.embedding_dim)
)
class AudioEmbeddingsConnector(nn.Module): class AudioEmbeddingsConnector(nn.Module):
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
video_output_dim = 4096 video_output_dim = 4096
audio_output_dim = 2048 audio_output_dim = 2048
self.feature_extractor_v2 = GemmaFeaturesExtractorV2( self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated) flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale) embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
video_output_dim=video_output_dim, video_output_dim=video_output_dim,
audio_output_dim=audio_output_dim, audio_output_dim=audio_output_dim,
bias=True, bias=True,
@@ -728,33 +785,53 @@ class LTX2TextEncoder(nn.Module):
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors # connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
# config (nested under config.transformer.connector_positional_embedding_max_pos) # config (nested under config.transformer.connector_positional_embedding_max_pos)
self.video_embeddings_connector = Embeddings1DConnector( self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128, dim=video_output_dim,
num_layers=8, num_learnable_registers=128, num_heads=32,
positional_embedding_max_pos=[4096], has_gate_logits=True, head_dim=128,
num_layers=8,
num_learnable_registers=128,
positional_embedding_max_pos=[4096],
has_gate_logits=True,
) )
self.audio_embeddings_connector = Embeddings1DConnector( self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64, dim=audio_output_dim,
num_layers=8, num_learnable_registers=128, num_heads=32,
positional_embedding_max_pos=[4096], has_gate_logits=True, head_dim=64,
num_layers=8,
num_learnable_registers=128,
positional_embedding_max_pos=[4096],
has_gate_logits=True,
) )
else: else:
# LTX-2: shared feature extractor, 3840-dim connectors # LTX-2: shared feature extractor, 3840-dim connectors
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim) self.feature_extractor = GemmaFeaturesExtractor(
feature_input_dim, hidden_dim
)
self.video_embeddings_connector = Embeddings1DConnector( self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128, dim=hidden_dim,
num_layers=2, num_learnable_registers=128, num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[1], positional_embedding_max_pos=[1],
) )
self.audio_embeddings_connector = Embeddings1DConnector( self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128, dim=hidden_dim,
num_layers=2, num_learnable_registers=128, num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[1], positional_embedding_max_pos=[1],
) )
self.processor = None self.processor = None
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"): def load(
self,
model_path: Optional[str] = None,
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
):
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir(): if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
text_encoder_path = str(Path(text_encoder_path) / "text_encoder") text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
if transformer_weights: if transformer_weights:
self._load_feature_extractors(transformer_weights, is_reformatted) self._load_feature_extractors(transformer_weights, is_reformatted)
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted) self._load_connector(
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted) "video_embeddings_connector", transformer_weights, is_reformatted
)
self._load_connector(
"audio_embeddings_connector", transformer_weights, is_reformatted
)
else: else:
print("WARNING: No transformer weights found for text projection connectors. " print(
"Text conditioning will use uninitialized weights!") "WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!"
)
# Load tokenizer # Load tokenizer
from transformers import AutoTokenizer from transformers import AutoTokenizer
tokenizer_path = model_path / "tokenizer" tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists(): if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(
str(tokenizer_path), trust_remote_code=True
)
else: else:
try: try:
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(
text_encoder_path, trust_remote_code=True
)
except Exception: except Exception:
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(
"google/gemma-3-12b-it", trust_remote_code=True
)
# Set left padding to match official LTX-2 text encoder # Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left" self.processor.padding_side = "left"
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
submodule.bias = weights[b_key] submodule.bias = weights[b_key]
else: else:
# LTX-2: single aggregate_embed # LTX-2: single aggregate_embed
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" agg_key = (
"aggregate_embed.weight"
if is_reformatted
else "text_embedding_projection.aggregate_embed.weight"
)
if agg_key in weights: if agg_key in weights:
self.feature_extractor.aggregate_embed.weight = weights[agg_key] self.feature_extractor.aggregate_embed.weight = weights[agg_key]
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
prefix = f"{name}." prefix = f"{name}."
for key, value in weights.items(): for key, value in weights.items():
if key.startswith(prefix): if key.startswith(prefix):
connector_weights[key[len(prefix):]] = value connector_weights[key[len(prefix) :]] = value
else: else:
mono_prefix = f"model.diffusion_model.{name}." mono_prefix = f"model.diffusion_model.{name}."
for key, value in weights.items(): for key, value in weights.items():
if key.startswith(mono_prefix): if key.startswith(mono_prefix):
connector_weights[key[len(mono_prefix):]] = value connector_weights[key[len(mono_prefix) :]] = value
if not connector_weights: if not connector_weights:
return return
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) _, all_hidden_states = self.language_model(
inputs=input_ids,
input_embeddings=None,
attention_mask=attention_mask,
output_hidden_states=True,
)
if self.has_prompt_adaln: if self.has_prompt_adaln:
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale) # LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video") video_features = self.feature_extractor_v2(
all_hidden_states, attention_mask, mode="video"
)
additive_mask = (attention_mask - 1).astype(video_features.dtype) additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 additive_mask = (
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
)
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) video_embeddings, _ = self.video_embeddings_connector(
video_features, additive_mask
)
if return_audio_embeddings: if return_audio_embeddings:
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio") audio_features = self.feature_extractor_v2(
all_hidden_states, attention_mask, mode="audio"
)
audio_mask = (attention_mask - 1).astype(audio_features.dtype) audio_mask = (attention_mask - 1).astype(audio_features.dtype)
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask) audio_embeddings, _ = self.audio_embeddings_connector(
audio_features, audio_mask
)
return video_embeddings, audio_embeddings return video_embeddings, audio_embeddings
else: else:
return video_embeddings, attention_mask return video_embeddings, attention_mask
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
video_features = self.feature_extractor(concat_hidden) video_features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(video_features.dtype) additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 additive_mask = (
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
)
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) video_embeddings, _ = self.video_embeddings_connector(
video_features, additive_mask
)
if return_audio_embeddings: if return_audio_embeddings:
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask) audio_embeddings, _ = self.audio_embeddings_connector(
video_features, additive_mask
)
return video_embeddings, audio_embeddings return video_embeddings, audio_embeddings
else: else:
return video_embeddings, attention_mask return video_embeddings, attention_mask
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
# Remove leading/trailing whitespace # Remove leading/trailing whitespace
response = response.strip() response = response.strip()
# Remove any leading punctuation # Remove any leading punctuation
response = re.sub(r'^[^\w\s]+', '', response) response = re.sub(r"^[^\w\s]+", "", response)
return response return response
def _apply_chat_template( def _apply_chat_template(
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
elif isinstance(content, list): elif isinstance(content, list):
# Handle multimodal content (image + text) # Handle multimodal content (image + text)
text_parts = [c["text"] for c in content if c.get("type") == "text"] text_parts = [c["text"] for c in content if c.get("type") == "text"]
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n" formatted += (
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
)
elif role == "assistant": elif role == "assistant":
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n" formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
# Add generation prompt # Add generation prompt
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
from mlx_lm import stream_generate from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.sample_utils import make_logits_processors, make_sampler
except ImportError: except ImportError:
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.") logging.warning(
"mlx-lm not available for prompt enhancement. Using original prompt."
)
return prompt return prompt
if self.processor is None: if self.processor is None:
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
) )
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1)) sampler = make_sampler(
kwargs.get("temperature", 0.7),
kwargs.get("top_p", 1.0),
top_k=kwargs.get("top_k", -1),
)
logits_processors = make_logits_processors( logits_processors = make_logits_processors(
kwargs.get("logit_bias", None), kwargs.get("logit_bias", None),
kwargs.get("repetition_penalty", 1.3), kwargs.get("repetition_penalty", 1.3),
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
mx.clear_cache() mx.clear_cache()
# Decode only the new tokens # Decode only the new tokens
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True) enhanced_prompt = self.processor.decode(
generated_tokens, skip_special_tokens=True
)
enhanced_prompt = self._clean_response(enhanced_prompt) enhanced_prompt = self._clean_response(enhanced_prompt)
logging.info(f"Enhanced prompt: {enhanced_prompt}") logging.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt return enhanced_prompt
def enhance_i2v( def enhance_i2v(
self, self,
prompt: str, prompt: str,
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder() encoder = LTX2TextEncoder()
encoder.load(model_path=model_path) encoder.load(model_path=model_path)
return encoder return encoder

View File

@@ -4,8 +4,8 @@ from typing import Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.attention import Attention from mlx_video.models.ltx_2.attention import Attention
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.feed_forward import FeedForward from mlx_video.models.ltx_2.feed_forward import FeedForward
from mlx_video.utils import rms_norm from mlx_video.utils import rms_norm
@@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim) # timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
timestep_reshaped = mx.reshape( timestep_reshaped = mx.reshape(
timestep, timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
(batch_size, timestep.shape[1], num_ada_params, -1)
) )
# Extract the relevant indices # Extract the relevant indices
@@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
) )
# Squeeze the sequence dimension if it's 1 # Squeeze the sequence dimension if it's 1
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada) scale_shift_squeezed = tuple(
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada) mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
)
gate_squeezed = tuple(
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
)
return (*scale_shift_squeezed, *gate_squeezed) return (*scale_shift_squeezed, *gate_squeezed)
@@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
# Check which modalities to run # Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0 run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0 run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal run_a2v = (
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal run_vx
and (audio is not None and audio.enabled and ax.size > 0)
and not skip_cross_modal
)
run_v2a = (
run_ax
and (video is not None and video.enabled and vx.size > 0)
and not skip_cross_modal
)
# Process video self-attention and cross-attention with text # Process video self-attention and cross-attention with text
if run_vx: if run_vx:
@@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module):
# Self-attention with RoPE (skip_attention=True for STG perturbation) # Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa vx = (
vx
+ self.attn1(
norm_vx,
pe=video.positional_embeddings,
skip_attention=skip_video_self_attn,
)
* vgate_msa
)
# Cross-attention with text context # Cross-attention with text context
if self.has_prompt_adaln: if self.has_prompt_adaln:
@@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module):
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9) self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
) )
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values( vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2) self.prompt_scale_shift_table,
vx.shape[0],
video.prompt_timesteps,
slice(0, 2),
) )
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv encoder_hidden_states = (
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
)
vx = (
vx
+ self.attn2(
attn_input,
context=encoder_hidden_states,
mask=video.context_mask,
)
* vgate_q
)
else: else:
vx = vx + self.attn2( vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps), rms_norm(vx, eps=self.norm_eps),
@@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module):
# Self-attention with RoPE (skip_attention=True for STG perturbation) # Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa ax = (
ax
+ self.audio_attn1(
norm_ax,
pe=audio.positional_embeddings,
skip_attention=skip_audio_self_attn,
)
* agate_msa
)
# Cross-attention with text context # Cross-attention with text context
if self.has_prompt_adaln: if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
ashift_q, ascale_q, agate_q = self.get_ada_values( ashift_q, ascale_q, agate_q = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9) self.audio_scale_shift_table,
ax.shape[0],
audio.timesteps,
slice(6, 9),
) )
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values( aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2) self.audio_prompt_scale_shift_table,
ax.shape[0],
audio.prompt_timesteps,
slice(0, 2),
)
attn_input_a = (
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
)
encoder_hidden_states_a = (
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
)
ax = (
ax
+ self.audio_attn2(
attn_input_a,
context=encoder_hidden_states_a,
mask=audio.context_mask,
)
* agate_q
) )
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
else: else:
ax = ax + self.audio_attn2( ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps), rms_norm(ax, eps=self.norm_eps),

View File

@@ -1,4 +1,5 @@
from typing import Tuple, Union from typing import Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
self.groups = groups self.groups = groups
# Weight shape: (C_out, KD, KH, KW, C_in) # Weight shape: (C_out, KD, KH, KW, C_in)
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5 scale = (
1.0
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
)
self.weight = mx.random.uniform( self.weight = mx.random.uniform(
low=-scale, low=-scale,
high=scale, high=scale,
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels), shape=(
out_channels,
kernel_size[0],
kernel_size[1],
kernel_size[2],
in_channels,
),
) )
if bias: if bias:
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
n, d, h, w, c = x.shape n, d, h, w, c = x.shape
input_dtype = x.dtype input_dtype = x.dtype
x = x.astype(mx.float32) x = x.astype(mx.float32)
# Reshape to (N, D*H*W, num_groups, C//num_groups) # Reshape to (N, D*H*W, num_groups, C//num_groups)
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
self.den = den self.den = den
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num) # Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1) self.conv = nn.Conv2d(
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffle2D(num, num) self.pixel_shuffle = PixelShuffle2D(num, num)
self.blur_down = BlurDownsample(stride=den) self.blur_down = BlurDownsample(stride=den)
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
x = self.conv(x) x = self.conv(x)
x = self.pixel_shuffle(x) # H*num, W*num x = self.pixel_shuffle(x) # H*num, W*num
x = self.blur_down(x) # H*num/den, W*num/den x = self.blur_down(x) # H*num/den, W*num/den
_, h_out, w_out, _ = x.shape _, h_out, w_out, _ = x.shape
x = mx.reshape(x, (n, d, h_out, w_out, c)) x = mx.reshape(x, (n, d, h_out, w_out, c))
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
def _rational_for_scale(scale: float) -> Tuple[int, int]: def _rational_for_scale(scale: float) -> Tuple[int, int]:
"""Convert a float scale to a rational fraction (numerator, denominator).""" """Convert a float scale to a rational fraction (numerator, denominator)."""
from fractions import Fraction from fractions import Fraction
frac = Fraction(scale).limit_denominator(10) frac = Fraction(scale).limit_denominator(10)
return frac.numerator, frac.denominator return frac.numerator, frac.denominator
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
self.initial_norm = GroupNorm3d(32, mid_channels) self.initial_norm = GroupNorm3d(32, mid_channels)
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking # Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)} self.res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Upsampler: 2D spatial upsampling (frame-by-frame) # Upsampler: 2D spatial upsampling (frame-by-frame)
if rational_resampler: if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale) self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=spatial_scale
)
else: else:
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels) self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking # Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)} self.post_upsample_res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Final projection # Final projection
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1) self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
Returns: Returns:
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
""" """
def debug_stats(name, t): def debug_stats(name, t):
if debug: if debug:
mx.eval(t) mx.eval(t)
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}") print(
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
)
if debug: if debug:
print(" [DEBUG] LatentUpsampler forward pass:") print(" [DEBUG] LatentUpsampler forward pass:")
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2)) # x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample # x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
# Both formats may have upsampler.blur_down.kernel, so use channel count # Both formats may have upsampler.blur_down.kernel, so use channel count
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight" conv_key = (
"upsampler.conv.weight"
if "upsampler.conv.weight" in raw_weights
else "upsampler.0.weight"
)
if conv_key in raw_weights: if conv_key in raw_weights:
out_channels = raw_weights[conv_key].shape[0] out_channels = raw_weights[conv_key].shape[0]
ratio = out_channels // mid_channels ratio = out_channels // mid_channels
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
rational_resampler = False rational_resampler = False
spatial_scale = 2.0 spatial_scale = 2.0
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}") print(
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
)
# Create model # Create model
upsampler = LatentUpsampler( upsampler = LatentUpsampler(

View File

@@ -109,6 +109,7 @@ def convert_audio_encoder(
return encoder_dir return encoder_dir
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
vae_path = hf_hub_download( vae_path = hf_hub_download(
source_repo, source_repo,
"audio_vae/diffusion_pytorch_model.safetensors", "audio_vae/diffusion_pytorch_model.safetensors",

View File

@@ -1,8 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.tiling import ( from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig, SpatialTilingConfig,
TemporalTilingConfig, TemporalTilingConfig,
TilingConfig,
) )
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder

View File

@@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import List, Optional, Tuple, Union from typing import Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
# Height padding (axis 2) # Height padding (axis 2)
if pad_h > 0: if pad_h > 0:
# Get reflection indices - exclude boundary # Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
:, :, ::-1, :, :
] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2) x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3) # Width padding (axis 3)
if pad_w > 0: if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion right_pad = x[:, :, :, -pad_w - 1 : -1, :][
:, :, :, ::-1, :
] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3) x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x return x
@@ -126,7 +130,9 @@ class CausalConv3d(nn.Module):
if self.time_kernel_size > 1: if self.time_kernel_size > 1:
if use_causal: if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning # Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2) first_frame_pad = mx.repeat(
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
)
x = mx.concatenate([first_frame_pad, x], axis=2) x = mx.concatenate([first_frame_pad, x], axis=2)
else: else:
# Non-causal: replicate first frame at start, last frame at end # Non-causal: replicate first frame at start, last frame at end
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
""" """
b, d, h, w, c = x.shape b, d, h, w, c = x.shape
total_elements = d * h * w * c total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
@@ -191,7 +196,6 @@ class CausalConv3d(nn.Module):
overlap = kernel_t - 1 overlap = kernel_t - 1
expected_output_frames = d - overlap expected_output_frames = d - overlap
outputs = [] outputs = []

View File

@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
""" """
import math import math
from typing import Optional, Dict
from pathlib import Path from pathlib import Path
from typing import Dict, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
def __init__(self, embedding_dim: int): def __init__(self, embedding_dim: int):
super().__init__() super().__init__()
self.timestep_embedder = TimestepEmbedding( self.timestep_embedder = TimestepEmbedding(
in_channels=256, in_channels=256, time_embed_dim=embedding_dim
time_embed_dim=embedding_dim
) )
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array: def __call__(
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
) -> mx.array:
timesteps_proj = get_timestep_embedding( timesteps_proj = get_timestep_embedding(
timestep, timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
) )
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb return timesteps_emb
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode): def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming.""" """Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module): class ConvWrapper(nn.Module):
def __init__(self_inner): def __init__(self_inner):
super().__init__() super().__init__()
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
padding=1, padding=1,
spatial_padding_mode=padding_mode, spatial_padding_mode=padding_mode,
) )
def __call__(self_inner, x, causal=False): def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal) return self_inner.conv(x, causal=causal)
return ConvWrapper() return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization.""" """Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps) return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__( def __call__(
self, self,
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
if self.timestep_conditioning and timestep_embed is not None: if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1) # scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding # Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1) ada_values = self.scale_shift_table[
None, :, :, None, None, None
] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1) # Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1] channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1) ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
# Time embedder for this block group: embed_dim = 4 * channels # Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning: if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder( self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
embedding_dim=channels * 4
)
# Use dict with int keys for MLX to track parameters properly # Use dict with int keys for MLX to track parameters properly
self.res_blocks = { self.res_blocks = {
i: ResnetBlock3DSimple( i: ResnetBlock3DSimple(
channels, channels,
spatial_padding_mode, spatial_padding_mode,
timestep_conditioning=timestep_conditioning timestep_conditioning=timestep_conditioning,
) )
for i in range(num_layers) for i in range(num_layers)
} }
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
if self.timestep_conditioning and timestep is not None: if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0] batch_size = x.shape[0]
timestep_embed = self.time_embedder( timestep_embed = self.time_embedder(
timestep.flatten(), timestep.flatten(), hidden_dtype=x.dtype
hidden_dtype=x.dtype
) )
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting # Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1) timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1, padding=1,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
def __call__(self_inner, x, causal=False): def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal) return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper() self.conv_in = ConvInWrapper()
# Build up blocks from config # Build up blocks from config
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
block_type = block_def[0] block_type = block_def[0]
ch = block_def[1] ch = block_def[1]
if block_type == "res": if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block num_layers = (
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning) block_def[2] if len(block_def) > 2 else num_layers_per_block
)
self.up_blocks[idx] = ResBlockGroup(
ch, num_layers, spatial_padding_mode, timestep_conditioning
)
elif block_type == "d2s": elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2 reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2) stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
) )
final_out_channels = out_channels * patch_size * patch_size final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module): class ConvOutWrapper(nn.Module):
def __init__(self_inner): def __init__(self_inner):
super().__init__() super().__init__()
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1, padding=1,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
def __call__(self_inner, x, causal=False): def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal) return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper() self.conv_out = ConvOutWrapper()
self.act = nn.SiLU() self.act = nn.SiLU()
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
if key.startswith("vae.decoder."): if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "") new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I) # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5: if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1)) value = mx.transpose(value, (0, 2, 3, 4, 1))
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
if ".conv.weight" in new_key or ".conv.bias" in new_key: if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key: if (
".conv.conv.weight" not in new_key
and ".conv.conv.bias" not in new_key
):
new_key = new_key.replace(".conv.weight", ".conv.conv.weight") new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias") new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
return sanitized return sanitized
@classmethod @classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder": def from_pretrained(
cls, model_path: Path, strict: bool = True
) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights. """Load a pretrained decoder from a directory with config.json and weights.
Args: Args:
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
for wf in weight_files: for wf in weight_files:
weights.update(mx.load(str(wf))) weights.update(mx.load(str(wf)))
# Infer block structure from weights # Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights) decoder_blocks = cls._infer_blocks(weights)
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
return final_blocks return final_blocks
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization.""" """Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps) return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__( def __call__(
self, self,
@@ -552,20 +562,15 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv: bool = False, chunked_conv: bool = False,
) -> mx.array: ) -> mx.array:
batch_size = sample.shape[0] batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled # Add noise if timestep conditioning is enabled
if self.timestep_conditioning: if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample sample = noise + (1.0 - self.decode_noise_scale) * sample
sample = self.per_channel_statistics.un_normalize(sample) sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning: if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep) timestep = mx.full((batch_size,), self.decode_timestep)
@@ -575,7 +580,6 @@ class LTX2VideoDecoder(nn.Module):
x = self.conv_in(sample, causal=causal) x = self.conv_in(sample, causal=causal)
for i, block in self.up_blocks.items(): for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup): if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep) x = block(x, causal=causal, timestep=scaled_timestep)
@@ -584,18 +588,17 @@ class LTX2VideoDecoder(nn.Module):
else: else:
x = block(x, causal=causal) x = block(x, causal=causal)
x = self.pixel_norm(x) x = self.pixel_norm(x)
if self.timestep_conditioning and scaled_timestep is not None: if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder( embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(), scaled_timestep.flatten(), hidden_dtype=x.dtype
hidden_dtype=x.dtype
) )
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1) embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1) ada_values = self.last_scale_shift_table[
None, :, :, None, None, None
] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1) ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped ada_values = ada_values + ts_reshaped
@@ -604,16 +607,13 @@ class LTX2VideoDecoder(nn.Module):
x = x * (1 + scale) + shift x = x * (1 + scale) + shift
x = self.act(x) x = self.act(x)
x = self.conv_out(x, causal=causal) x = self.conv_out(x, causal=causal)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4) # Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1) x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
return x return x
def decode_tiled( def decode_tiled(
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
# Auto-enable chunked conv for modes where it helps (larger tiles) # Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks # Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial") use_chunked_conv = tiling_mode in (
"conservative",
"none",
"auto",
"default",
"spatial",
)
if not needs_spatial_tiling and not needs_temporal_tiling: if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode # No tiling needed, use regular decode
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv) return self(
sample,
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
)
return decode_with_tiling( return decode_with_tiling(
decoder_fn=self, decoder_fn=self,

View File

@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
""" """
import mlx.core as mx import mlx.core as mx
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
def encode_image( def encode_image(

View File

@@ -1,6 +1,5 @@
"""Operations for Video VAE.""" """Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
new_c = c * patch_size_hw * patch_size_hw * patch_size_t new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw) # Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)) x = mx.reshape(
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
)
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W') # Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph # PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph

View File

@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
return x return x
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array: def __call__(
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
) -> mx.array:
b, c, d, h, w = x.shape b, c, d, h, w = x.shape
st, sh, sw = self.stride st, sh, sw = self.stride
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
return x return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array: def _chunked_conv_depth_to_space(
self, x: mx.array, causal: bool = True
) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks. """Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor. This reduces peak memory by avoiding the full high-channel intermediate tensor.

View File

@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
# Apply right ramp (fade out) # Apply right ramp (fade out)
if ramp_right > 0: if ramp_right > 0:
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1] # Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)] fade_out = [
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
]
for i in range(ramp_right): for i in range(ramp_right):
mask[length - ramp_right + i] *= fade_out[i] mask[length - ramp_right + i] *= fade_out[i]
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.tile_size_in_pixels < 64: if self.tile_size_in_pixels < 64:
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") raise ValueError(
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
)
if self.tile_size_in_pixels % 32 != 0: if self.tile_size_in_pixels % 32 != 0:
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") raise ValueError(
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
)
if self.tile_overlap_in_pixels % 32 != 0: if self.tile_overlap_in_pixels % 32 != 0:
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") raise ValueError(
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
)
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
raise ValueError( raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.tile_size_in_frames < 16: if self.tile_size_in_frames < 16:
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") raise ValueError(
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
)
if self.tile_size_in_frames % 8 != 0: if self.tile_size_in_frames % 8 != 0:
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") raise ValueError(
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
)
if self.tile_overlap_in_frames % 8 != 0: if self.tile_overlap_in_frames % 8 != 0:
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") raise ValueError(
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
)
if self.tile_overlap_in_frames >= self.tile_size_in_frames: if self.tile_overlap_in_frames >= self.tile_size_in_frames:
raise ValueError( raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
@@ -113,15 +127,21 @@ class TilingConfig:
def default(cls) -> "TilingConfig": def default(cls) -> "TilingConfig":
"""Default tiling: 512px spatial, 64 frame temporal.""" """Default tiling: 512px spatial, 64 frame temporal."""
return cls( return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), spatial_config=SpatialTilingConfig(
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), tile_size_in_pixels=512, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
),
) )
@classmethod @classmethod
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig": def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
"""Spatial tiling only (for short videos with large resolution).""" """Spatial tiling only (for short videos with large resolution)."""
return cls( return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap), spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
),
temporal_config=None, temporal_config=None,
) )
@@ -130,23 +150,33 @@ class TilingConfig:
"""Temporal tiling only (for long videos with small resolution).""" """Temporal tiling only (for long videos with small resolution)."""
return cls( return cls(
spatial_config=None, spatial_config=None,
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap), temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
),
) )
@classmethod @classmethod
def aggressive(cls) -> "TilingConfig": def aggressive(cls) -> "TilingConfig":
"""Aggressive tiling for very large videos (smaller tiles, much lower memory).""" """Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
return cls( return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64), spatial_config=SpatialTilingConfig(
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8), tile_size_in_pixels=256, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=32, tile_overlap_in_frames=8
),
) )
@classmethod @classmethod
def conservative(cls) -> "TilingConfig": def conservative(cls) -> "TilingConfig":
"""Conservative tiling (larger tiles, less memory savings but faster).""" """Conservative tiling (larger tiles, less memory savings but faster)."""
return cls( return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64), spatial_config=SpatialTilingConfig(
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24), tile_size_in_pixels=768, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=96, tile_overlap_in_frames=24
),
) )
@classmethod @classmethod
@@ -186,10 +216,14 @@ class TilingConfig:
temporal_config = None temporal_config = None
if needs_spatial: if needs_spatial:
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64) spatial_config = SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
)
if needs_temporal: if needs_temporal:
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24) temporal_config = TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
)
return cls(spatial_config=spatial_config, temporal_config=temporal_config) return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@@ -197,16 +231,21 @@ class TilingConfig:
@dataclass @dataclass
class DimensionIntervals: class DimensionIntervals:
"""Intervals for splitting a single dimension.""" """Intervals for splitting a single dimension."""
starts: List[int] starts: List[int]
ends: List[int] ends: List[int]
left_ramps: List[int] left_ramps: List[int]
right_ramps: List[int] right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: def split_in_spatial(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a spatial dimension into intervals.""" """Split a spatial dimension into intervals."""
if dimension_size <= size: if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
starts = [i * (size - overlap) for i in range(amount)] starts = [i * (size - overlap) for i in range(amount)]
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
left_ramps = [0] + [overlap] * (amount - 1) left_ramps = [0] + [overlap] * (amount - 1)
right_ramps = [overlap] * (amount - 1) + [0] right_ramps = [overlap] * (amount - 1) + [0]
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) return DimensionIntervals(
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: def split_in_temporal(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment.""" """Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size: if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
# Start with spatial split # Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size) intervals = split_in_spatial(size, overlap, dimension_size)
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
starts[i] = starts[i] - 1 starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1 left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps) return DimensionIntervals(
starts=starts,
ends=intervals.ends,
left_ramps=left_ramps,
right_ramps=intervals.right_ramps,
)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: def map_temporal_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map temporal latent interval to output coordinates and mask.""" """Map temporal latent interval to output coordinates and mask."""
start = begin * scale start = begin * scale
stop = 1 + (end - 1) * scale stop = 1 + (end - 1) * scale
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0 left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
right_ramp_scaled = right_ramp * scale right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True) mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, True
)
return slice(start, stop), mask return slice(start, stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: def map_spatial_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map spatial latent interval to output coordinates and mask.""" """Map spatial latent interval to output coordinates and mask."""
start = begin * scale start = begin * scale
stop = end * scale stop = end * scale
left_ramp_scaled = left_ramp * scale left_ramp_scaled = left_ramp * scale
right_ramp_scaled = right_ramp * scale right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False) mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, False
)
return slice(start, stop), mask return slice(start, stop), mask
@@ -315,7 +373,9 @@ def decode_with_tiling(
temporal_overlap = 0 temporal_overlap = 0
# Compute intervals for each dimension # Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -338,7 +398,9 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx] t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates # Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles): for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx] h_start = height_intervals.starts[h_idx]
@@ -347,7 +409,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx] h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates # Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles): for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx] w_start = width_intervals.starts[w_idx]
@@ -356,13 +420,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx] w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates # Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice) # Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile # Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output) mx.eval(tile_output)
# Clear tile_latents reference # Clear tile_latents reference
@@ -385,13 +459,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = ( blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) * t_mask_slice.reshape(1, 1, -1, 1, 1)
h_mask_slice.reshape(1, 1, 1, -1, 1) * * h_mask_slice.reshape(1, 1, 1, -1, 1)
w_mask_slice.reshape(1, 1, 1, 1, -1) * w_mask_slice.reshape(1, 1, 1, 1, -1)
) )
# Slice tile output to match # Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output # Clear full tile_output
del tile_output del tile_output
@@ -409,11 +485,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment # Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( output[
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile :,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
) )
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( weights[
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask :,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
) )
# Force evaluation to free memory # Force evaluation to free memory
@@ -445,10 +547,12 @@ def decode_with_tiling(
if next_tile_start_latent == 0: if next_tile_start_latent == 0:
next_tile_start_out = 0 next_tile_start_out = 0
else: else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
# We need to track how many frames we've already emitted # We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'): if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0 decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames emitted = decode_with_tiling._emitted_frames
@@ -456,7 +560,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out) # Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8) finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype) finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output) mx.eval(finalized_output)
@@ -473,7 +580,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided # Emit remaining frames if callback provided
if on_frames_ready is not None: if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0) emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f: if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output) mx.eval(remaining_output)
@@ -481,7 +588,7 @@ def decode_with_tiling(
del remaining_output del remaining_output
# Reset emitted frames counter for next call # Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'): if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames del decode_with_tiling._emitted_frames
# Clean up weights # Clean up weights

View File

@@ -8,12 +8,15 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify from mlx_video.models.ltx_2.video_vae.ops import (
PerChannelStatistics,
patchify,
unpatchify,
)
from mlx_video.models.ltx_2.video_vae.resnet import ( from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType, NormLayerType,
ResnetBlock3D, ResnetBlock3D,
UNetMidBlock3D, UNetMidBlock3D,
get_norm_layer,
) )
from mlx_video.models.ltx_2.video_vae.sampling import ( from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample, DepthToSpaceUpsample,
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
class LogVarianceType(Enum): class LogVarianceType(Enum):
"""Log variance mode for VAE.""" """Log variance mode for VAE."""
PER_CHANNEL = "per_channel" PER_CHANNEL = "per_channel"
UNIFORM = "uniform" UNIFORM = "uniform"
CONSTANT = "constant" CONSTANT = "constant"
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
config: VideoEncoderModelConfig with encoder parameters config: VideoEncoderModelConfig with encoder parameters
""" """
super().__init__() super().__init__()
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.norm_layer = config.norm_layer self.norm_layer = config.norm_layer
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents # Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels) self.per_channel_statistics = PerChannelStatistics(
latent_channels=config.out_channels
)
# After patchify, channels increase by patch_size^2 # After patchify, channels increase by patch_size^2
in_channels = config.in_channels * config.patch_size ** 2 in_channels = config.in_channels * config.patch_size**2
feature_channels = config.out_channels feature_channels = config.out_channels
# Initial convolution # Initial convolution
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked) # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {} self.down_blocks = {}
for idx, (block_name, block_params) in enumerate(encoder_blocks): for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_encoder_block( block, feature_channels = _make_encoder_block(
block_name=block_name, block_name=block_name,
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
conv_out_channels = config.out_channels conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL: if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2 conv_out_channels *= 2
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: elif config.latent_log_var in {
LogVarianceType.UNIFORM,
LogVarianceType.CONSTANT,
}:
conv_out_channels += 1 conv_out_channels += 1
self.conv_out = CausalConv3d( self.conv_out = CausalConv3d(
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
elif self.latent_log_var == LogVarianceType.CONSTANT: elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...] sample = sample[:, :-1, ...]
approx_ln_0 = -30 approx_ln_0 = -30
sample = mx.concatenate([ sample = mx.concatenate(
sample, [
mx.full_like(sample, approx_ln_0), sample,
], axis=1) mx.full_like(sample, approx_ln_0),
],
axis=1,
)
# Split into means and logvar, normalize means # Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...] means = sample[:, : self.latent_channels, ...]
return self.per_channel_statistics.normalize(means) return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
Loaded VideoEncoder instance Loaded VideoEncoder instance
""" """
import json import json
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config # Load config
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
decoder_blocks = [] decoder_blocks = []
self.patch_size = patch_size self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2 out_channels = out_channels * patch_size**2
self.causal = causal self.causal = causal
self.timestep_conditioning = timestep_conditioning self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked) # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {} self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)): for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_decoder_block( block, feature_channels = _make_decoder_block(
block_name=block_name, block_name=block_name,

View File

@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
v = self.v(x_w).reshape(b, s, n, d) v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64) # RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin) q = rope_apply(
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin) q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
k = rope_apply(
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype)) # Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3) q = q.astype(w_dtype).transpose(0, 2, 1, 3)
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask q, k, v, scale=self.scale, mask=mask
) )
else: else:
out = mx.fast.scaled_dot_product_attention( out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1) out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out) return self.o(out)
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask q, k, v, scale=self.scale, mask=mask
) )
else: else:
out = mx.fast.scaled_dot_product_attention( out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d) out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out) return self.o(out)

View File

@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.utils import mlx.utils
import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
return weights return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: def sanitize_wan_transformer_weights(
weights: Dict[str, mx.array]
) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure. """Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern: Wan2.2 keys follow the pattern:
@@ -246,8 +247,8 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths. Shared between weight-merging and runtime-wrapping paths.
""" """
from mlx_video.lora import LoRAConfig, load_multiple_loras
from mlx_video.generate_wan import Colors from mlx_video.generate_wan import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -264,7 +265,9 @@ def _load_lora_configs(
module_to_loras = load_multiple_loras(configs) module_to_loras = load_multiple_loras(configs)
if not module_to_loras: if not module_to_loras:
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}") print(
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
)
return module_to_loras return module_to_loras
@@ -279,8 +282,8 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
""" """
from mlx_video.lora import apply_loras_to_weights
from mlx_video.generate_wan import Colors from mlx_video.generate_wan import Colors
from mlx_video.lora import apply_loras_to_weights
if not lora_configs: if not lora_configs:
return model_weights return model_weights
@@ -289,12 +292,17 @@ def load_and_apply_loras(
if not module_to_loras: if not module_to_loras:
return model_weights return model_weights
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}") print(
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
)
if verbose: if verbose:
print(f" Model has {len(model_weights)} weight keys") print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights( modified_weights = apply_loras_to_weights(
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits model_weights,
module_to_loras,
verbose=verbose,
quantization_bits=quantization_bits,
) )
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}") print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
src_model_type = src_config.get("model_type", "t2v") src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512) src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, " print(
f"heads={src_num_heads}, type={src_model_type}") f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}"
)
# Use preset for known TI2V 5B configuration # Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072: if src_model_type == "ti2v" and src_dim == 3072:
@@ -513,8 +523,11 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path)) weights = load_torch_weights(str(vae_path))
if is_wan22_vae: if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v") include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder) weights = sanitize_wan22_vae_weights(
weights, include_encoder=include_encoder
)
else: else:
weights = sanitize_wan_vae_weights(weights) weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in # Always save VAE in float32 — official Wan2.2 runs VAE decode in
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
# Quantize transformer weights if requested # Quantize transformer weights if requested
if quantize: if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...") print(
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
)
_quantize_saved_model(output_dir, config, is_dual, bits, group_size) _quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}") print(f"\nConversion complete! Output: {output_dir}")
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
return False return False
# Quantize attention Q/K/V/O and FFN fc1/fc2 # Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = ( quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o", ".self_attn.q",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o", ".self_attn.k",
".ffn.fc1", ".ffn.fc2", ".self_attn.v",
".self_attn.o",
".cross_attn.q",
".cross_attn.k",
".cross_attn.v",
".cross_attn.o",
".ffn.fc1",
".ffn.fc2",
) )
return any(path.endswith(p) for p in quantize_patterns) return any(path.endswith(p) for p in quantize_patterns)
@@ -684,14 +706,20 @@ def quantize_mlx_model(
# Build model config # Build model config
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__} config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"): for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list): if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key]) config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict) config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights) # Copy non-transformer files to output dir (skip large model weights)
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"} transformer_files = {
"low_noise_model.safetensors",
"high_noise_model.safetensors",
"model.safetensors",
}
if dst.resolve() != src.resolve(): if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True) dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir(): for f in src.iterdir():
@@ -763,11 +791,18 @@ if __name__ == "__main__":
if args.quantize_only: if args.quantize_only:
quantize_mlx_model( quantize_mlx_model(
args.checkpoint_dir, args.output_dir, args.checkpoint_dir,
bits=args.bits, group_size=args.group_size, args.output_dir,
bits=args.bits,
group_size=args.group_size,
) )
else: else:
convert_wan_checkpoint( convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version, args.checkpoint_dir,
quantize=args.quantize, bits=args.bits, group_size=args.group_size, args.output_dir,
args.dtype,
args.model_version,
quantize=args.quantize,
bits=args.bits,
group_size=args.group_size,
) )

View File

@@ -4,18 +4,15 @@ import argparse
import gc import gc
import math import math
import random import random
import sys
import time import time
from pathlib import Path from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import ( from mlx_video.models.wan.loading import (
_clean_text,
encode_text, encode_text,
load_t5_encoder, load_t5_encoder,
load_vae_decoder, load_vae_decoder,
@@ -24,6 +21,7 @@ from mlx_video.models.wan.loading import (
) )
from mlx_video.models.wan.postprocess import save_video from mlx_video.models.wan.postprocess import save_video
class Colors: class Colors:
"""ANSI color codes for terminal output.""" """ANSI color codes for terminal output."""
@@ -37,6 +35,7 @@ class Colors:
DIM = "\033[2m" DIM = "\033[2m"
RESET = "\033[0m" RESET = "\033[0m"
# Backward-compat alias (tests and external code may use the old name) # Backward-compat alias (tests and external code may use the old name)
_build_i2v_mask = build_i2v_mask _build_i2v_mask = build_i2v_mask
@@ -143,10 +142,13 @@ def generate_video(
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"): for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list): if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key]) config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{ config = WanModelConfig(
k: v for k, v in config_dict.items() **{
if k in WanModelConfig.__dataclass_fields__ k: v
}) for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
}
)
else: else:
# Auto-detect: dual model files → 2.2, single model → 2.1 # Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists(): if (model_dir / "low_noise_model.safetensors").exists():
@@ -182,7 +184,9 @@ def generate_video(
if "patch_embedding_proj.weight" in k: if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0] actual_dim = v.shape[0]
if actual_dim != config.dim: if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}") print(
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
)
if actual_dim <= 2048: if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
else: else:
@@ -192,13 +196,20 @@ def generate_video(
# Auto-correct Wan2.2 VAE params from stale configs # Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48: if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}") print(
config = WanModelConfig(**{ f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()}, )
"vae_z_dim": 48, config = WanModelConfig(
"vae_stride": (4, 16, 16), **{
"sample_fps": 24, **{
}) f.name: getattr(config, f.name)
for f in config.__dataclass_fields__.values()
},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
}
)
# Apply defaults from config if not overridden # Apply defaults from config if not overridden
if steps is None: if steps is None:
@@ -227,7 +238,9 @@ def generate_video(
gen_frames = num_frames gen_frames = num_frames
if trim_first_frames > 0: if trim_first_frames > 0:
gen_frames = num_frames + trim_first_frames * 4 gen_frames = num_frames + trim_first_frames * 4
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}") print(
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
)
version_str = f"Wan{config.model_version}" version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model" mode_str = "dual-model" if is_dual else "single-model"
@@ -247,10 +260,16 @@ def generate_video(
if is_i2v: if is_i2v:
print(f" Image: {image}") print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip(): if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved neg_display = (
neg_prompt_resolved[:60] + "..."
if len(neg_prompt_resolved) > 60
else neg_prompt_resolved
)
print(f" Neg prompt: {neg_display}") print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}") print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}") print(
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
)
if cfg_disabled: if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)") print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}") print(f"{Colors.RESET}")
@@ -275,12 +294,16 @@ def generate_video(
height = align_h height = align_h
if width == 0: if width == 0:
width = align_w width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}") print(
f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
)
# Enforce max_area constraint (model-specific resolution limit) # Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area: if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area) width, height = _best_output_size(
width, height, align_w, align_h, config.max_area
)
print( print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area " f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}" f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
@@ -309,6 +332,7 @@ def generate_video(
# Load tokenizer # Load tokenizer
from transformers import AutoTokenizer from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts # Encode prompts
@@ -318,12 +342,15 @@ def generate_video(
context_null = None context_null = None
mx.eval(context) mx.eval(context)
else: else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len) context_null = encode_text(
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
)
mx.eval(context, context_null) mx.eval(context, context_null)
# Free T5 from memory # Free T5 from memory
del t5_encoder del t5_encoder
gc.collect(); mx.clear_cache() gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}") print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space # I2V: encode image to latent space
@@ -346,18 +373,25 @@ def generate_video(
img = Image.open(image).convert("RGB") img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height) scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS) img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2 x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height)) img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3] img_arr = mx.array(
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W] img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W] # Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching # Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([ video = mx.concatenate(
img_chw[:, None, :, :], [
mx.zeros((3, num_frames - 1, height, width)), img_chw[:, None, :, :],
], axis=1) mx.zeros((3, num_frames - 1, height, width)),
],
axis=1,
)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat] # Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config) vae_enc = load_vae_encoder(vae_path, config)
@@ -367,12 +401,17 @@ def generate_video(
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W] # Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent)) msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat] # Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([ msk = mx.concatenate(
mx.repeat(msk[:, :1], 4, axis=1), [
msk[:, 1:], mx.repeat(msk[:, :1], 4, axis=1),
], axis=1) msk[:, 1:],
],
axis=1,
)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat] # Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -395,13 +434,16 @@ def generate_video(
del vae_enc, img_tensor del vae_enc, img_tensor
gc.collect(); mx.clear_cache() gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}") print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models # Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}") print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization: if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}") print(
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
)
t2 = time.time() t2 = time.time()
# Merge per-model LoRAs with shared LoRAs # Merge per-model LoRAs with shared LoRAs
@@ -412,10 +454,16 @@ def generate_video(
if is_dual: if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors" low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors" high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low) low_noise_model = load_wan_model(
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high) low_noise_path, config, quantization, loras=_loras_low
)
high_noise_model = load_wan_model(
high_noise_path, config, quantization, loras=_loras_high
)
else: else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single) single_model = load_wan_model(
model_dir / "model.safetensors", config, quantization, loras=_loras_single
)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}") print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step) # Precompute text embeddings once (avoids redundant MLP in every step)
@@ -437,8 +485,12 @@ def generate_video(
context_emb_low = low_noise_model.embed_text([context, context_null]) context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null]) context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high) mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0) context_cfg_low = mx.concatenate(
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0) [context_emb_low[0:1], context_emb_low[1:2]], axis=0
)
context_cfg_high = mx.concatenate(
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
)
else: else:
context_emb = single_model.embed_text([context, context_null]) context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb) mx.eval(context_emb)
@@ -534,7 +586,7 @@ def generate_video(
rcs = rope_cos_sin rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace) # Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model) _call = getattr(model, "_compiled", model)
if cfg_disabled: if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch) # No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
@@ -552,7 +604,9 @@ def generate_video(
y_arg = [y_i2v] if is_i2v_channel_concat else None y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual: if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low ctx = (
context_cond_high if timestep_val >= boundary else context_cond_low
)
else: else:
ctx = context_cond ctx = context_cond
preds = _call( preds = _call(
@@ -571,7 +625,11 @@ def generate_video(
if is_dual: if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0] gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else: else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] gs = (
guide_scale
if isinstance(guide_scale, (int, float))
else guide_scale[0]
)
if is_i2v_mask_blend: if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val t_tokens = i2v_mask_tokens * timestep_val
@@ -586,8 +644,10 @@ def generate_video(
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
ctx = context_cfg if not is_dual else ( ctx = (
context_cfg_high if timestep_val >= boundary else context_cfg_low context_cfg
if not is_dual
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
) )
preds = _call( preds = _call(
[latents, latents], [latents, latents],
@@ -618,16 +678,24 @@ def generate_video(
if debug_latents: if debug_latents:
lat_np = np.array(latents) # [C, T, H, W] lat_np = np.array(latents) # [C, T, H, W]
n_t = lat_np.shape[1] n_t = lat_np.shape[1]
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}") print(
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}") f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
)
print(
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
)
for t_pos in range(min(n_t, 8)): for t_pos in range(min(n_t, 8)):
frame = lat_np[:, t_pos, :, :] frame = lat_np[:, t_pos, :, :]
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} " print(
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}") f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
)
if n_t > 8: if n_t > 8:
interior = lat_np[:, 4:, :, :] interior = lat_np[:, 4:, :, :]
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} " print(
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}") f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
)
print() print()
# Free transformer models and text embeddings # Free transformer models and text embeddings
@@ -646,7 +714,8 @@ def generate_video(
del model, kv, context del model, kv, context
if context_null is not None: if context_null is not None:
del context_null del context_null
gc.collect(); mx.clear_cache() gc.collect()
mx.clear_cache()
# Load VAE and decode # Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}") print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
@@ -677,13 +746,25 @@ def generate_video(
elif tiling == "temporal": elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only() tiling_config = TilingConfig.temporal_only()
else: else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") print(
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
)
tiling_config = TilingConfig.auto(height, width, num_frames) tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None: if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" spatial_info = (
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" f"{tiling_config.spatial_config.tile_size_in_pixels}px"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") if tiling_config.spatial_config
else "none"
)
temporal_info = (
f"{tiling_config.temporal_config.tile_size_in_frames}f"
if tiling_config.temporal_config
else "none"
)
print(
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
)
if is_wan22_vae: if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents from mlx_video.models.wan.vae22 import denormalize_latents
@@ -718,7 +799,9 @@ def generate_video(
if trim_first_frames > 0: if trim_first_frames > 0:
trim_pixels = trim_first_frames * 4 trim_pixels = trim_first_frames * 4
video = video[trim_pixels:] video = video[trim_pixels:]
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}") print(
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
)
save_video(video, output_path, fps=config.sample_fps) save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}") print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
@@ -727,58 +810,124 @@ def generate_video(
def main(): def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)") parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--image", type=str, default=None,
help="Path to input image for I2V (omit for T2V mode)")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument( parser.add_argument(
"--scheduler", type=str, default="unipc", "--model-dir",
type=str,
required=True,
help="Path to converted MLX model directory",
)
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument(
"--image",
type=str,
default=None,
help="Path to input image for I2V (omit for T2V mode)",
)
parser.add_argument(
"--negative-prompt",
type=str,
default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)",
)
parser.add_argument(
"--no-negative-prompt",
action="store_true",
help="Disable negative prompt (use empty string instead of config default)",
)
parser.add_argument(
"--width", type=int, default=1280, help="Video width (default: 1280)"
)
parser.add_argument(
"--height",
type=int,
default=704,
help="Video height (default: 704; 720p models use 704)",
)
parser.add_argument(
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
)
parser.add_argument(
"--steps",
type=int,
default=None,
help="Number of diffusion steps (default: from config)",
)
parser.add_argument(
"--guide-scale",
type=str,
default=None,
help="Guidance scale: single float or low,high pair",
)
parser.add_argument(
"--shift",
type=float,
default=None,
help="Noise schedule shift (default: from config)",
)
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument(
"--output-path", type=str, default="output.mp4", help="Output video path"
)
parser.add_argument(
"--scheduler",
type=str,
default="unipc",
choices=["euler", "dpm++", "unipc"], choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)", help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
) )
parser.add_argument( parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"), "--lora",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8", help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
) )
parser.add_argument( parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"), "--lora-high",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)", help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
) )
parser.add_argument( parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"), "--lora-low",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)", help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
) )
parser.add_argument( parser.add_argument(
"--tiling", "--tiling",
type=str, type=str,
default="auto", default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], choices=[
"auto",
"none",
"default",
"aggressive",
"conservative",
"spatial",
"temporal",
],
help="VAE tiling mode to reduce memory during decoding (default: auto)", help="VAE tiling mode to reduce memory during decoding (default: auto)",
) )
parser.add_argument( parser.add_argument(
"--no-compile", action="store_true", "--no-compile",
action="store_true",
help="Disable mx.compile on models (for debugging)", help="Disable mx.compile on models (for debugging)",
) )
parser.add_argument( parser.add_argument(
"--trim-first-frames", type=int, default=0, metavar="N", "--trim-first-frames",
type=int,
default=0,
metavar="N",
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. " help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). " "Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)", "Default: 0 (disabled)",
) )
parser.add_argument( parser.add_argument(
"--debug-latents", action="store_true", "--debug-latents",
action="store_true",
help="Print per-temporal-position latent statistics after denoising (diagnostic)", help="Print per-temporal-position latent statistics after denoising (diagnostic)",
) )
args = parser.parse_args() args = parser.parse_args()

View File

@@ -21,7 +21,9 @@ def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
# Resize so that the image covers the target size (LANCZOS) # Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height) scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS) img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
# Center crop # Center crop
x1 = (img.width - width) // 2 x1 = (img.width - width) // 2

View File

@@ -6,7 +6,12 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None): def load_wan_model(
model_path: Path,
config,
quantization: dict | None = None,
loras: list | None = None,
):
"""Load and initialize WanModel, with optional quantization and LoRA support. """Load and initialize WanModel, with optional quantization and LoRA support.
Args: Args:
@@ -93,9 +98,11 @@ def load_vae_decoder(model_path: Path, config=None):
if is_wan22: if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48) vae = Wan22VAEDecoder(z_dim=48)
else: else:
from mlx_video.models.wan.vae import WanVAE from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16) vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path)) weights = mx.load(str(model_path))
@@ -140,6 +147,7 @@ def _clean_text(text: str) -> str:
try: try:
import ftfy import ftfy
text = ftfy.fix_text(text) text = ftfy.fix_text(text)
except ImportError: except ImportError:
pass pass

View File

@@ -1,4 +1,5 @@
import math import math
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
@@ -37,7 +38,9 @@ class Head(nn.Module):
proj_dim = math.prod(patch_size) * out_dim proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps) self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim) self.head = nn.Linear(dim, proj_dim)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32) self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(self, x: mx.array, e: mx.array) -> mx.array: def __call__(self, x: mx.array, e: mx.array) -> mx.array:
""" """
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
# Reference computes three rope_params with different dim normalizations # Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range. # so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads d = dim // config.num_heads
self.freqs = mx.concatenate([ self.freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Precompute sinusoidal inv_freq for time embedding. # Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2 half = config.freq_dim // 2
self._inv_freq = mx.array( self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
).astype(np.float32) np.float32
)
) )
def _patchify(self, x: mx.array) -> tuple: def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings. """Convert video tensor to patch embeddings.
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
seq_lens_list.append(p.shape[1]) seq_lens_list.append(p.shape[1])
x = mx.concatenate( x = mx.concatenate(
[ [
mx.concatenate( (
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)], mx.concatenate(
axis=1, [
p,
mx.zeros(
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
),
],
axis=1,
)
if p.shape[1] < seq_len
else p
) )
if p.shape[1] < seq_len
else p
for p in patches for p in patches
], ],
axis=0, axis=0,
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
t = t[None] t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate( sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
if t.ndim == 1: if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B] # Standard T2V: scalar timestep per batch element [B]

View File

@@ -1,6 +1,8 @@
import numpy as np
from pathlib import Path from pathlib import Path
import numpy as np
def save_video(frames: np.ndarray, output_path: str, fps: int = 16): def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4. """Save video frames to MP4.
@@ -11,6 +13,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
""" """
try: try:
import imageio import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8) writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames: for frame in frames:
writer.append_data(frame) writer.append_data(frame)
@@ -18,6 +21,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
except ImportError: except ImportError:
try: try:
import cv2 import cv2
h, w = frames.shape[1], frames.shape[2] h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1") fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h)) writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
@@ -27,9 +31,11 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
except (ImportError, Exception): except (ImportError, Exception):
# Last resort: save as individual PNGs # Last resort: save as individual PNGs
from PIL import Image from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames): for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png") Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)") print(
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
)

View File

@@ -1,4 +1,3 @@
import math
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
Complex frequency tensor of shape [max_seq_len, dim // 2]. Complex frequency tensor of shape [max_seq_len, dim // 2].
""" """
assert dim % 2 == 0 assert dim % 2 == 0
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * ( freqs = (
1.0 np.arange(max_seq_len, dtype=np.float64)[:, None]
/ np.power( * (
theta, 1.0
np.arange(0, dim, 2, dtype=np.float64) / dim, / np.power(
) theta,
)[None, :] np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
)
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2] # Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
cos_freqs = np.cos(freqs).astype(np.float32) cos_freqs = np.cos(freqs).astype(np.float32)
sin_freqs = np.sin(freqs).astype(np.float32) sin_freqs = np.sin(freqs).astype(np.float32)
@@ -46,9 +48,9 @@ def rope_apply(
# Check if all batch elements have the same grid (common for CFG B=2) # Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0] f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0 seq_len = f0 * h0 * w0
all_same_grid = all( all_same_grid = (
grid_sizes[i] == grid_sizes[0] for i in range(1, b) all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
) if b > 1 else True )
if all_same_grid: if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once # Vectorized path: apply RoPE to all batch elements at once
@@ -57,7 +59,9 @@ def rope_apply(
x_imag = x_seq[..., 1] x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d) x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
b, seq_len, n, d
)
if seq_len < s: if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1) x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated return x_rotated
@@ -102,17 +106,11 @@ def rope_apply(
# Build per-position frequencies by expanding along grid dims # Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2] # temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to( ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2] # height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to( fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2] # width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to( fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
# Concatenate: [f*h*w, half_d, 2] # Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2) freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)

View File

@@ -7,9 +7,8 @@ for the same quality as Euler.
import math import math
import numpy as np
import mlx.core as mx import mlx.core as mx
import numpy as np
def _compute_sigmas( def _compute_sigmas(
@@ -25,9 +24,7 @@ def _compute_sigmas(
Returns num_steps+1 values (the last being 0.0 for the terminal state). Returns num_steps+1 values (the last being 0.0 for the terminal state).
""" """
# sigma bounds from unshifted training schedule (constructor uses shift=1) # sigma bounds from unshifted training schedule (constructor uses shift=1)
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[ alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
::-1
]
sigmas_unshifted = 1.0 - alphas sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
sigma_min = float(sigmas_unshifted[-1]) # 0.0 sigma_min = float(sigmas_unshifted[-1]) # 0.0
@@ -65,7 +62,10 @@ class FlowMatchEulerScheduler:
sample: mx.array, sample: mx.array,
) -> mx.array: ) -> mx.array:
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v.""" """Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index] dt = (
self._sigmas_float[self._step_index + 1]
- self._sigmas_float[self._step_index]
)
x_next = sample + dt * model_output x_next = sample + dt * model_output
self._step_index += 1 self._step_index += 1
return x_next return x_next
@@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler:
# Decide order: 1st for first step, last step (if lower_order_final # Decide order: 1st for first step, last step (if lower_order_final
# and few steps), otherwise 2nd # and few steps), otherwise 2nd
use_first_order = ( use_first_order = self._prev_x0 is None or (
self._prev_x0 is None self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
or (
self.lower_order_final
and i == self._num_steps - 1
and self._num_steps < 15
)
) )
if use_first_order or sigma_next == 0.0: if use_first_order or sigma_next == 0.0:

View File

@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
is_small = rel_pos < max_exact is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32) rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = ( rel_pos_large = max_exact + (
max_exact mx.log(rel_pos_f / max_exact)
+ ( / math.log(self.max_dist / max_exact)
mx.log(rel_pos_f / max_exact) * (num_buckets - max_exact)
/ math.log(self.max_dist / max_exact) ).astype(mx.int32)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = mx.minimum( rel_pos_large = mx.minimum(
rel_pos_large, rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32), mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
) )
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large) rel_buckets = rel_buckets + mx.where(
is_small, rel_pos.astype(mx.int32), rel_pos_large
)
return rel_buckets return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array: def __call__(self, lq: int, lk: int) -> mx.array:
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
v = v.transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3)
# QK^T (no scaling) — compute in float32 for precision # QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)) attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
# Add position bias # Add position bias
if pos_bias is not None: if pos_bias is not None:

View File

@@ -75,7 +75,11 @@ def decode_with_tiling(
b, c, f_latent, h_latent, w_latent = latents.shape b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape # Compute output shape
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale) out_f = (
(1 + (f_latent - 1) * temporal_scale)
if causal_temporal
else (f_latent * temporal_scale)
)
out_h = h_latent * spatial_scale out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale out_w = w_latent * spatial_scale
@@ -98,9 +102,13 @@ def decode_with_tiling(
# Compute intervals for each dimension # Compute intervals for each dimension
if causal_temporal: if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
else: else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent) temporal_intervals = split_in_spatial(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -124,9 +132,13 @@ def decode_with_tiling(
# Map temporal coordinates # Map temporal coordinates
if causal_temporal: if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
else: else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale) out_t_slice, t_mask = map_spatial_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles): for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx] h_start = height_intervals.starts[h_idx]
@@ -135,7 +147,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx] h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates # Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles): for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx] w_start = width_intervals.starts[w_idx]
@@ -144,13 +158,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx] w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates # Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice) # Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile # Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output) mx.eval(tile_output)
# Clear tile_latents reference # Clear tile_latents reference
@@ -173,13 +197,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = ( blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) * t_mask_slice.reshape(1, 1, -1, 1, 1)
h_mask_slice.reshape(1, 1, 1, -1, 1) * * h_mask_slice.reshape(1, 1, 1, -1, 1)
w_mask_slice.reshape(1, 1, 1, 1, -1) * w_mask_slice.reshape(1, 1, 1, 1, -1)
) )
# Slice tile output to match # Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output # Clear full tile_output
del tile_output del tile_output
@@ -196,11 +222,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment # Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( output[
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile :,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
) )
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( weights[
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask :,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
) )
# Force evaluation to free memory # Force evaluation to free memory
@@ -232,12 +284,14 @@ def decode_with_tiling(
if next_tile_start_latent == 0: if next_tile_start_latent == 0:
next_tile_start_out = 0 next_tile_start_out = 0
elif causal_temporal: elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
else: else:
next_tile_start_out = next_tile_start_latent * temporal_scale next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted # We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'): if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0 decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames emitted = decode_with_tiling._emitted_frames
@@ -245,7 +299,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out) # Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8) finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype) finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output) mx.eval(finalized_output)
@@ -262,7 +319,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided # Emit remaining frames if callback provided
if on_frames_ready is not None: if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0) emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f: if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output) mx.eval(remaining_output)
@@ -270,7 +327,7 @@ def decode_with_tiling(
del remaining_output del remaining_output
# Reset emitted frames counter for next call # Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'): if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames del decode_with_tiling._emitted_frames
# Clean up weights # Clean up weights

View File

@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
# Cross-attention (with optional norm on context) # Cross-attention (with optional norm on context)
self.norm3 = ( self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True) WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
if cross_attn_norm
else None
) )
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps) self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
self.ffn = WanFFN(dim, ffn_dim) self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision) # Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32) self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__( def __call__(
self, self,
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
# Self-attention with modulation (hidden state stays in w_dtype) # Self-attention with modulation (hidden state stays in w_dtype)
x_mod = self.norm1(x) * (1 + e1) + e0 x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask) y = self.self_attn(
x_mod,
seq_lens,
grid_sizes,
freqs,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
x = x + y * e2 x = x + y * e2
# Cross-attention (no modulation, just norm) # Cross-attention (no modulation, just norm)

View File

@@ -6,19 +6,45 @@ so weights load directly without key sanitization.
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
CACHE_T = 2 CACHE_T = 2
# Per-channel normalization statistics for z_dim=16 # Per-channel normalization statistics for z_dim=16
VAE_MEAN = [ VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, -0.7571,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921, -0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
] ]
VAE_STD = [ VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 2.8184,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160, 1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
] ]
@@ -50,7 +76,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2] self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I] # MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)) self.weight = mx.zeros(
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)
self.bias = mx.zeros((out_channels,)) self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array: def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
@@ -67,8 +95,16 @@ class CausalConv3d(nn.Module):
x = mx.concatenate([pad_t, x], axis=2) x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0: if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0), x = mx.pad(
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)]) x,
[
(0, 0),
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
],
)
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C] x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x) out = self._conv3d(x)
@@ -118,7 +154,11 @@ class RMS_norm(nn.Module):
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1 norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize) # L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None)) norm = mx.sqrt(
mx.clip(
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
)
)
return (x / norm) * self.scale * self.gamma return (x / norm) * self.scale * self.gamma
@@ -133,12 +173,12 @@ class ResidualBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int): def __init__(self, in_dim: int, out_dim: int):
super().__init__() super().__init__()
self.residual = [ self.residual = [
RMS_norm(in_dim, images=False), # [0] RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2] CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3] RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU None, # [4] SiLU
None, # [5] Dropout None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6] CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
] ]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
@@ -226,13 +266,16 @@ class Resample(nn.Module):
# resample.0 = Upsample (no params), resample.1 = Conv2d # resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)] self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d": if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
)
else: else:
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2) # resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)] self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
if mode == "downsample3d": if mode == "downsample3d":
self.time_conv = CausalConv3d( self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, C, T, H, W]""" """x: [B, C, T, H, W]"""
@@ -272,8 +315,7 @@ class Resample(nn.Module):
else: else:
# Subsequent chunks: use cached frame as temporal context # Subsequent chunks: use cached frame as temporal context
cache_x = x[:, :, -1:] cache_x = x[:, :, -1:]
x = self.time_conv( x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
x, cache_x=feat_cache[idx][:, :, -1:])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
@@ -328,8 +370,8 @@ class Decoder3d(nn.Module):
# Output head: [RMS_norm, SiLU (no params), CausalConv3d] # Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [ self.head = [
RMS_norm(dims[-1], images=False), # [0] RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2] CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
] ]
@@ -405,8 +447,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:] cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate( cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.conv1(x, cache_x=feat_cache[idx]) x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -431,8 +472,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:] cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate( cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.head[2](x, cache_x=feat_cache[idx]) x = self.head[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -583,7 +623,7 @@ class WanVAE(nn.Module):
decoder_fn=tile_decode, decoder_fn=tile_decode,
latents=z_denorm, latents=z_denorm,
tiling_config=tiling_config, tiling_config=tiling_config,
spatial_scale=8, # 3× spatial 2× upsamples = 8× spatial_scale=8, # 3× spatial 2× upsamples = 8×
temporal_scale=4, # 2× temporal upsamples × 2 = 4× temporal_scale=4, # 2× temporal upsamples × 2 = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T) causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
) )

View File

@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
""" """
import logging import logging
import math
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
CACHE_T = 2 CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space # Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([ VAE22_MEAN = mx.array(
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, [
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, -0.2289,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.0052,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, -0.1323,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, -0.2339,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667, -0.2799,
]) 0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
]
)
VAE22_STD = mx.array([ VAE22_STD = mx.array(
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, [
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.4765,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 1.0364,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.4514,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 1.1677,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744, 0.5313,
]) 0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
]
)
class CausalConv3d(nn.Module): class CausalConv3d(nn.Module):
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2] self._pad_w = padding[2]
# Weight: [O, D, H, W, I] for MLX # Weight: [O, D, H, W, I] for MLX
self.weight = mx.zeros(( self.weight = mx.zeros(
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels (out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)) )
self.bias = mx.zeros((out_channels,)) self.bias = mx.zeros((out_channels,))
def __call__(self, x, cache_x=None): def __call__(self, x, cache_x=None):
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
# Spatial padding # Spatial padding
if self._pad_h > 0 or self._pad_w > 0: if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h), x = mx.pad(
(self._pad_w, self._pad_w), (0, 0)]) x,
[
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
(0, 0),
],
)
T_padded = x.shape[1] T_padded = x.shape[1]
H_padded, W_padded = x.shape[2], x.shape[3] H_padded, W_padded = x.shape[2], x.shape[3]
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
for d in range(kd): for d in range(kd):
frame = x[:, t_start + d] # [B, H_padded, W_padded, C] frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I] w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
conv_out = mx.conv_general(frame, w2d, conv_out = mx.conv_general(
stride=(self.stride[1], self.stride[2])) frame, w2d, stride=(self.stride[1], self.stride[2])
)
accum = conv_out if accum is None else accum + conv_out accum = conv_out if accum is None else accum + conv_out
outputs.append(accum + self.bias) outputs.append(accum + self.bias)
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
self.scale = dim ** 0.5 self.scale = dim**0.5
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze # Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
self.gamma = mx.ones((dim,)) self.gamma = mx.ones((dim,))
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
# x: [..., C] (channels-last) # x: [..., C] (channels-last)
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps) # PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
l2_sq = mx.sum(x * x, axis=-1, keepdims=True) l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma return (
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
)
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d] # Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
# We store as named layers matching PyTorch's indices # We store as named layers matching PyTorch's indices
self.residual = ResidualBlockLayers(in_dim, out_dim) self.residual = ResidualBlockLayers(in_dim, out_dim)
self.shortcut = ( self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim
else None
)
def __call__(self, x, feat_cache=None, feat_idx=None): def __call__(self, x, feat_cache=None, feat_idx=None):
h = self.shortcut(x) if self.shortcut is not None else x h = self.shortcut(x) if self.shortcut is not None else x
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
# Save last CACHE_T frames before conv (for next chunk's context) # Save last CACHE_T frames before conv (for next chunk's context)
cache_x = x[:, -CACHE_T:] cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None: if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate( cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
[feat_cache[idx][:, -1:], cache_x], axis=1
)
out = conv(x, cache_x=feat_cache[idx]) out = conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
x = self.norm(x) x = self.norm(x)
# QKV via 1x1 conv2d (equivalent to linear on last dim) # QKV via 1x1 conv2d (equivalent to linear on last dim)
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C] qkv = (
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
) # [BT, H, W, 3C]
qkv = qkv.reshape(B * T, H * W, 3 * C) qkv = qkv.reshape(B * T, H * W, 3 * C)
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C] q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
k = k[:, None, :, :] k = k[:, None, :, :]
v = v[:, None, :, :] v = v[:, None, :, :]
scale = C ** -0.5 scale = C**-0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C] out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale
) # [BT, 1, HW, C]
out = out.squeeze(1).reshape(B * T, H, W, C) out = out.squeeze(1).reshape(B * T, H, W, C)
# Project output # Project output
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats] x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s] # Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s) x = x.reshape(
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
)
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C] # Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4) x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C] # Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels) x = x.reshape(
B,
T * self.factor_t,
H * self.factor_s,
W * self.factor_s,
self.out_channels,
)
if first_chunk: if first_chunk:
x = x[:, self.factor_t - 1:, :, :, :] x = x[:, self.factor_t - 1 :, :, :, :]
return x return x
@@ -348,7 +452,9 @@ class Resample(nn.Module):
self.resample_weight = mx.zeros((dim, 3, 3, dim)) self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,)) self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1)) # time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else: else:
raise ValueError(f"Unsupported mode: {mode}") raise ValueError(f"Unsupported mode: {mode}")
@@ -369,7 +475,9 @@ class Resample(nn.Module):
"""Apply strided Conv2d for downsampling. x: [N, H, W, C].""" """Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1 # ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias return (
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
)
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None): def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, C] # x: [B, T, H, W, C]
@@ -444,14 +552,17 @@ class Resample(nn.Module):
class Up_ResidualBlock(nn.Module): class Up_ResidualBlock(nn.Module):
"""Upsampling residual block with optional DupUp3D shortcut.""" """Upsampling residual block with optional DupUp3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False): def __init__(
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
):
super().__init__() super().__init__()
self.up_flag = up_flag self.up_flag = up_flag
# DupUp3D shortcut (no learnable params) # DupUp3D shortcut (no learnable params)
if up_flag: if up_flag:
self.avg_shortcut = DupUp3D( self.avg_shortcut = DupUp3D(
in_dim, out_dim, in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1, factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1, factor_s=2 if up_flag else 1,
) )
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
class Down_ResidualBlock(nn.Module): class Down_ResidualBlock(nn.Module):
"""Downsampling residual block with AvgDown3D shortcut.""" """Downsampling residual block with AvgDown3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False): def __init__(
self,
in_dim,
out_dim,
num_res_blocks,
temperal_downsample=False,
down_flag=False,
):
super().__init__() super().__init__()
self.down_flag = down_flag self.down_flag = down_flag
# AvgDown3D shortcut (no learnable params, always present) # AvgDown3D shortcut (no learnable params, always present)
self.avg_shortcut = AvgDown3D( self.avg_shortcut = AvgDown3D(
in_dim, out_dim, in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1, factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1, factor_s=2 if down_flag else 1,
) )
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
self.upsamples = [] self.upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
self.upsamples.append(Up_ResidualBlock( self.upsamples.append(
in_dim=in_dim, Up_ResidualBlock(
out_dim=out_dim, in_dim=in_dim,
num_res_blocks=num_res_blocks + 1, out_dim=out_dim,
temperal_upsample=t_up, num_res_blocks=num_res_blocks + 1,
up_flag=(i != len(dim_mult) - 1), temperal_upsample=t_up,
)) up_flag=(i != len(dim_mult) - 1),
)
)
# Output head: [RMS_norm, SiLU, CausalConv3d] # Output head: [RMS_norm, SiLU, CausalConv3d]
self.head = Head22(dims[-1]) self.head = Head22(dims[-1])
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
for i in range(len(dim_mult)): for i in range(len(dim_mult)):
in_d, out_d = dims[i], dims[i + 1] in_d, out_d = dims[i], dims[i + 1]
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
self.downsamples.append(Down_ResidualBlock( self.downsamples.append(
in_dim=in_d, Down_ResidualBlock(
out_dim=out_d, in_dim=in_d,
num_res_blocks=num_res_blocks, out_dim=out_d,
temperal_downsample=t_down, num_res_blocks=num_res_blocks,
down_flag=(i < len(dim_mult) - 1), temperal_downsample=t_down,
)) down_flag=(i < len(dim_mult) - 1),
)
)
# Middle blocks (same as decoder) # Middle blocks (same as decoder)
out_dim = dims[-1] out_dim = dims[-1]
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, -CACHE_T:] cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None: if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate( cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
[feat_cache[idx][:, -1:], cache_x], axis=1
)
x = self.conv1(x, cache_x=feat_cache[idx]) x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -700,9 +821,7 @@ class Head22(nn.Module):
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, -CACHE_T:] cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None: if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate( cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
[feat_cache[idx][:, -1:], cache_x], axis=1
)
x = self.layer_2(x, cache_x=feat_cache[idx]) x = self.layer_2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
if i == 0: if i == 0:
chunk = x[:, :1] chunk = x[:, :1]
else: else:
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i] chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx) chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None: if out is None:
out = chunk_out out = chunk_out
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
# conv1 (pointwise) + split into mu, log_var # conv1 (pointwise) + split into mu, log_var
out = self.conv1(out) out = self.conv1(out)
mu = out[:, :, :, :, :self.z_dim] mu = out[:, :, :, :, : self.z_dim]
# Normalize # Normalize
mu = normalize_latents(mu) mu = normalize_latents(mu)
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
decoder_fn=tile_decode, decoder_fn=tile_decode,
latents=z_cf, latents=z_cf,
tiling_config=tiling_config, tiling_config=tiling_config,
spatial_scale=16, # 8× conv upsample + 2× unpatchify spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal) temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
causal_temporal=True, causal_temporal=True,
) )

View File

@@ -1,14 +1,15 @@
import math import math
from functools import partial
from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image from PIL import Image
def get_model_path(model_repo: str): def get_model_path(model_repo: str):
"""Get or download LTX-2 model path.""" """Get or download LTX-2 model path."""
try: try:
@@ -17,15 +18,19 @@ def get_model_path(model_repo: str):
return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception: except Exception:
print("Downloading LTX-2 model weights...") print("Downloading LTX-2 model weights...")
return Path(snapshot_download( return Path(
repo_id=model_repo, snapshot_download(
local_files_only=False, repo_id=model_repo,
resume_download=True, local_files_only=False,
allow_patterns=["*.safetensors", "*.json"], resume_download=True,
)) allow_patterns=["*.safetensors", "*.json"],
)
)
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict): def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None: if quantization is not None:
def get_class_predicate(p, m): def get_class_predicate(p, m):
# Handle custom per layer quantizations # Handle custom per layer quantizations
if p in quantization: if p in quantization:
@@ -46,17 +51,15 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
class_predicate=get_class_predicate, class_predicate=get_class_predicate,
) )
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps) return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def to_denoised( def to_denoised(
noisy: mx.array, noisy: mx.array, velocity: mx.array, sigma: mx.array | float
velocity: mx.array,
sigma: mx.array | float
) -> mx.array: ) -> mx.array:
"""Convert velocity prediction to denoised output. """Convert velocity prediction to denoised output.
@@ -284,7 +287,9 @@ def prepare_image_for_encoding(
if image_np.max() <= 1.0: if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8) image_np = (image_np * 255).astype(np.uint8)
pil_image = Image.fromarray(image_np) pil_image = Image.fromarray(image_np)
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS) pil_image = pil_image.resize(
(target_width, target_height), Image.Resampling.LANCZOS
)
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0) image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
# Normalize to [-1, 1] # Normalize to [-1, 1]

View File

@@ -170,19 +170,33 @@ def print_report(results, ref_path, test_path):
print("AGGREGATE METRICS") print("AGGREGATE METRICS")
print("-" * 40) print("-" * 40)
print(f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}") print(
print(f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}") f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}"
print(f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}") )
print(f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}") print(
print(f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}") f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}"
)
print(
f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}"
)
print(
f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}"
)
print(
f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}"
)
print() print()
print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)") print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)")
print("-" * 40) print("-" * 40)
print(f" Reference: {results['ref_temporal_coherence']:.2f}") print(f" Reference: {results['ref_temporal_coherence']:.2f}")
print(f" Test: {results['test_temporal_coherence']:.2f}") print(f" Test: {results['test_temporal_coherence']:.2f}")
ratio = results["test_temporal_coherence"] / (results["ref_temporal_coherence"] + 1e-10) ratio = results["test_temporal_coherence"] / (
print(f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}") results["ref_temporal_coherence"] + 1e-10
)
print(
f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}"
)
print() print()
# Identify worst frames # Identify worst frames
@@ -190,7 +204,9 @@ def print_report(results, ref_path, test_path):
print("-" * 40) print("-" * 40)
worst_idx = np.argsort(psnr)[:5] worst_idx = np.argsort(psnr)[:5]
for i in worst_idx: for i in worst_idx:
print(f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}") print(
f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}"
)
print() print()
# Quality assessment # Quality assessment
@@ -210,7 +226,9 @@ def print_report(results, ref_path, test_path):
grade = "Very different" grade = "Very different"
print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})") print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})")
if mean_psnr < 30: if mean_psnr < 30:
print(" ⚠ Videos differ significantly — likely a bug or different generation seed") print(
" ⚠ Videos differ significantly — likely a bug or different generation seed"
)
print("=" * 72) print("=" * 72)
@@ -242,9 +260,7 @@ def main():
parser.add_argument( parser.add_argument(
"--diff-video", help="Save side-by-side diff visualization to this path" "--diff-video", help="Save side-by-side diff visualization to this path"
) )
parser.add_argument( parser.add_argument("--max-frames", type=int, help="Compare only first N frames")
"--max-frames", type=int, help="Compare only first N frames"
)
parser.add_argument( parser.add_argument(
"--ssim-win", type=int, default=7, help="SSIM window size (default: 7)" "--ssim-win", type=int, default=7, help="SSIM window size (default: 7)"
) )
@@ -254,26 +270,29 @@ def main():
default=5.0, default=5.0,
help="Diff heatmap amplification (default: 5.0)", help="Diff heatmap amplification (default: 5.0)",
) )
parser.add_argument( parser.add_argument("--csv", help="Export per-frame metrics to CSV file")
"--csv", help="Export per-frame metrics to CSV file"
)
args = parser.parse_args() args = parser.parse_args()
print(f"Loading reference: {args.reference}") print(f"Loading reference: {args.reference}")
ref_frames, ref_fps = load_video(args.reference, args.max_frames) ref_frames, ref_fps = load_video(args.reference, args.max_frames)
print(f"{len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}") print(
f"{len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}"
)
print(f"Loading test: {args.test}") print(f"Loading test: {args.test}")
test_frames, test_fps = load_video(args.test, args.max_frames) test_frames, test_fps = load_video(args.test, args.max_frames)
print(f"{len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}") print(
f"{len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}"
)
if ref_frames[0].shape != test_frames[0].shape: if ref_frames[0].shape != test_frames[0].shape:
print(f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}") print(
f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}"
)
print("Resizing test frames to match reference...") print("Resizing test frames to match reference...")
h, w = ref_frames[0].shape[:2] h, w = ref_frames[0].shape[:2]
test_frames = [ test_frames = [
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) for f in test_frames
for f in test_frames
] ]
print("Computing metrics...") print("Computing metrics...")
@@ -282,23 +301,29 @@ def main():
print_report(results, args.reference, args.test) print_report(results, args.reference, args.test)
if args.diff_video: if args.diff_video:
save_diff_video(ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale) save_diff_video(
ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale
)
if args.csv: if args.csv:
import csv import csv
with open(args.csv, "w", newline="") as f: with open(args.csv, "w", newline="") as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerow(["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]) writer.writerow(
["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]
)
for i in range(results["num_frames"]): for i in range(results["num_frames"]):
writer.writerow([ writer.writerow(
i, [
f"{results['psnr'][i]:.4f}", i,
f"{results['ssim'][i]:.6f}", f"{results['psnr'][i]:.4f}",
f"{results['mean_diff'][i]:.4f}", f"{results['ssim'][i]:.6f}",
f"{results['max_diff'][i]:.1f}", f"{results['mean_diff'][i]:.4f}",
f"{results['color_dist'][i]:.6f}", f"{results['max_diff'][i]:.1f}",
]) f"{results['color_dist'][i]:.6f}",
]
)
print(f"Per-frame metrics saved to {args.csv}") print(f"Per-frame metrics saved to {args.csv}")

View File

@@ -158,10 +158,14 @@ def analyze_video(frames, chunk_size=None, compute_flow=False):
boundary_metrics = [] boundary_metrics = []
for b in boundaries: for b in boundaries:
if b < n and b > 0: if b < n and b > 0:
pre = metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1] pre = (
metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
)
at = metrics["frame_diff"][b] at = metrics["frame_diff"][b]
ratio = at / (pre + 1e-10) ratio = at / (pre + 1e-10)
brightness_jump = metrics["brightness"][b] - metrics["brightness"][b - 1] brightness_jump = (
metrics["brightness"][b] - metrics["brightness"][b - 1]
)
contrast_jump = ( contrast_jump = (
(metrics["contrast"][b] - metrics["contrast"][b - 1]) (metrics["contrast"][b] - metrics["contrast"][b - 1])
/ (metrics["contrast"][b - 1] + 1e-10) / (metrics["contrast"][b - 1] + 1e-10)
@@ -198,7 +202,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
print("VIDEO QUALITY REPORT") print("VIDEO QUALITY REPORT")
print("=" * 72) print("=" * 72)
print(f" File: {path}") print(f" File: {path}")
print(f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}") print(
f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}"
)
duration = total_frames / fps if fps > 0 else 0 duration = total_frames / fps if fps > 0 else 0
print(f" Duration: {duration:.1f}s") print(f" Duration: {duration:.1f}s")
print() print()
@@ -211,52 +217,76 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
print("-" * 40) print("-" * 40)
if n_uniform: if n_uniform:
frames_list = np.where(metrics["is_uniform"])[0][:10] frames_list = np.where(metrics["is_uniform"])[0][:10]
print(f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}") print(
f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}"
)
if n_noisy: if n_noisy:
frames_list = np.where(metrics["is_noisy"])[0][:10] frames_list = np.where(metrics["is_noisy"])[0][:10]
print(f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}") print(
f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}"
)
print() print()
print("SHARPNESS") print("SHARPNESS")
print("-" * 40) print("-" * 40)
print(f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}") print(
print(f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}") f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}"
)
print(
f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}"
)
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3: if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
print(" ⚠ High sharpness variation — possible blur artifacts") print(" ⚠ High sharpness variation — possible blur artifacts")
print() print()
print("BRIGHTNESS & CONTRAST") print("BRIGHTNESS & CONTRAST")
print("-" * 40) print("-" * 40)
print(f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}") print(
print(f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}") f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}"
)
print(
f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}"
)
if np.std(br) > 3.0: if np.std(br) > 3.0:
print(" ⚠ Brightness instability — may indicate chunk boundary artifacts") print(" ⚠ Brightness instability — may indicate chunk boundary artifacts")
print() print()
print("COLOR DISTRIBUTION (BGR)") print("COLOR DISTRIBUTION (BGR)")
print("-" * 40) print("-" * 40)
print(f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}") print(
print(f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}") f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}"
print(f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}") )
print(
f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}"
)
print(
f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}"
)
print() print()
print("TEMPORAL STABILITY") print("TEMPORAL STABILITY")
print("-" * 40) print("-" * 40)
fd_nz = fd[1:] # skip first frame (always 0) fd_nz = fd[1:] # skip first frame (always 0)
if len(fd_nz) > 0: if len(fd_nz) > 0:
print(f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}") print(
f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}"
)
if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5: if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5:
print(" ⚠ High diff variance — jitter or discontinuities") print(" ⚠ High diff variance — jitter or discontinuities")
if "flow_mean" in metrics: if "flow_mean" in metrics:
fm = metrics["flow_mean"][1:] fm = metrics["flow_mean"][1:]
print(f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}") print(
f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}"
)
print() print()
# Chunk boundaries # Chunk boundaries
if "boundaries" in metrics and metrics["boundaries"]: if "boundaries" in metrics and metrics["boundaries"]:
print("CHUNK BOUNDARIES") print("CHUNK BOUNDARIES")
print("-" * 40) print("-" * 40)
print(f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}") print(
f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}"
)
for bm in metrics["boundaries"]: for bm in metrics["boundaries"]:
print( print(
f" {bm['frame']:6d}" f" {bm['frame']:6d}"
@@ -267,7 +297,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
) )
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]]) avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
if avg_ratio > 2.0: if avg_ratio > 2.0:
print(f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions") print(
f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions"
)
print() print()
# Overall grade # Overall grade
@@ -303,9 +335,7 @@ def main():
type=int, type=int,
help="Frames per chunk for boundary analysis (e.g., 32)", help="Frames per chunk for boundary analysis (e.g., 32)",
) )
parser.add_argument( parser.add_argument("--start", type=int, default=0, help="Start frame (default: 0)")
"--start", type=int, default=0, help="Start frame (default: 0)"
)
parser.add_argument("--end", type=int, help="End frame (default: all)") parser.add_argument("--end", type=int, help="End frame (default: all)")
parser.add_argument( parser.add_argument(
"--flow", "--flow",
@@ -329,8 +359,14 @@ def main():
import csv import csv
keys = [ keys = [
"sharpness_lap", "sharpness_grad", "brightness", "contrast", "sharpness_lap",
"color_mean_b", "color_mean_g", "color_mean_r", "frame_diff", "sharpness_grad",
"brightness",
"contrast",
"color_mean_b",
"color_mean_g",
"color_mean_r",
"frame_diff",
] ]
if args.flow: if args.flow:
keys += ["flow_mean", "flow_max"] keys += ["flow_mean", "flow_max"]

View File

@@ -1,17 +1,17 @@
"""Tests for LTX-2 dev model generation pipeline.""" """Tests for LTX-2 dev model generation pipeline."""
import pytest
import mlx.core as mx import mlx.core as mx
import pytest
from mlx_video.generate_dev import ( from mlx_video.generate_dev import (
ltx2_scheduler,
create_position_grid,
create_audio_position_grid,
compute_audio_frames,
cfg_delta,
DEFAULT_NEGATIVE_PROMPT,
AUDIO_SAMPLE_RATE,
AUDIO_LATENTS_PER_SECOND, AUDIO_LATENTS_PER_SECOND,
AUDIO_SAMPLE_RATE,
DEFAULT_NEGATIVE_PROMPT,
cfg_delta,
compute_audio_frames,
create_audio_position_grid,
create_position_grid,
ltx2_scheduler,
) )
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
"""Scheduler should return steps+1 sigma values.""" """Scheduler should return steps+1 sigma values."""
steps = 20 steps = 20
sigmas = ltx2_scheduler(steps=steps) sigmas = ltx2_scheduler(steps=steps)
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}" assert sigmas.shape == (
steps + 1,
), f"Expected ({steps + 1},), got {sigmas.shape}"
def test_scheduler_starts_at_one(self): def test_scheduler_starts_at_one(self):
"""Sigma schedule should start at 1.0.""" """Sigma schedule should start at 1.0."""
sigmas = ltx2_scheduler(steps=20) sigmas = ltx2_scheduler(steps=20)
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}" assert (
abs(sigmas[0].item() - 1.0) < 1e-5
), f"Expected 1.0, got {sigmas[0].item()}"
def test_scheduler_ends_at_zero(self): def test_scheduler_ends_at_zero(self):
"""Sigma schedule should end at 0.0.""" """Sigma schedule should end at 0.0."""
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
sigmas = ltx2_scheduler(steps=20) sigmas = ltx2_scheduler(steps=20)
sigmas_list = sigmas.tolist() sigmas_list = sigmas.tolist()
for i in range(len(sigmas_list) - 1): for i in range(len(sigmas_list) - 1):
assert sigmas_list[i] >= sigmas_list[i + 1], \ assert (
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}" sigmas_list[i] >= sigmas_list[i + 1]
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
def test_scheduler_dtype(self): def test_scheduler_dtype(self):
"""Scheduler should return float32 array.""" """Scheduler should return float32 array."""
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
num_patches = num_frames * height * width num_patches = num_frames * height * width
expected_shape = (batch_size, 3, num_patches, 2) expected_shape = (batch_size, 3, num_patches, 2)
assert positions.shape == expected_shape, \ assert (
f"Expected {expected_shape}, got {positions.shape}" positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_position_grid_dtype(self): def test_position_grid_dtype(self):
"""Position grid should be float32 for RoPE precision.""" """Position grid should be float32 for RoPE precision."""
positions = create_position_grid(1, 5, 16, 24) positions = create_position_grid(1, 5, 16, 24)
assert positions.dtype == mx.float32, \ assert (
f"Expected float32 for RoPE precision, got {positions.dtype}" positions.dtype == mx.float32
), f"Expected float32 for RoPE precision, got {positions.dtype}"
def test_position_grid_batch_size(self): def test_position_grid_batch_size(self):
"""Position grid should respect batch size.""" """Position grid should respect batch size."""
@@ -165,7 +172,9 @@ class TestCFGDelta:
mx.eval(delta) mx.eval(delta)
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0 # Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero" assert (
mx.max(mx.abs(delta)).item() < 1e-6
), "CFG delta with scale=1.0 should be zero"
def test_cfg_delta_formula(self): def test_cfg_delta_formula(self):
"""CFG delta should follow the formula: (scale-1) * (cond - uncond).""" """CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
# Check for common negative quality terms # Check for common negative quality terms
assert "blurry" in prompt_lower, "Should contain 'blurry'" assert "blurry" in prompt_lower, "Should contain 'blurry'"
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \ assert (
"Should contain quality-related terms" "low quality" in prompt_lower or "low contrast" in prompt_lower
), "Should contain quality-related terms"
class TestInputValidation: class TestInputValidation:
@@ -248,15 +258,16 @@ class TestInputValidation:
(30, 33), # 30 -> nearest valid is 33 (30, 33), # 30 -> nearest valid is 33
(35, 33), # 35 -> nearest valid is 33 (35, 33), # 35 -> nearest valid is 33
(40, 41), # 40 -> nearest valid is 41 (40, 41), # 40 -> nearest valid is 41
(1, 1), # 1 is already valid (1, 1), # 1 is already valid
(33, 33), # 33 is already valid (33, 33), # 33 is already valid
] ]
for input_frames, expected in test_cases: for input_frames, expected in test_cases:
if input_frames % 8 != 1: if input_frames % 8 != 1:
adjusted = round((input_frames - 1) / 8) * 8 + 1 adjusted = round((input_frames - 1) / 8) * 8 + 1
assert adjusted == expected, \ assert (
f"Expected {expected} for input {input_frames}, got {adjusted}" adjusted == expected
), f"Expected {expected} for input {input_frames}, got {adjusted}"
class TestDenoiseWithCFGMocked: class TestDenoiseWithCFGMocked:
@@ -277,14 +288,16 @@ class TestTilingDefault:
def test_tiling_default_is_none(self): def test_tiling_default_is_none(self):
"""Default tiling should be 'none' for performance.""" """Default tiling should be 'none' for performance."""
import inspect import inspect
from mlx_video.generate_dev import generate_video_dev from mlx_video.generate_dev import generate_video_dev
sig = inspect.signature(generate_video_dev) sig = inspect.signature(generate_video_dev)
tiling_param = sig.parameters.get('tiling') tiling_param = sig.parameters.get("tiling")
assert tiling_param is not None assert tiling_param is not None
assert tiling_param.default == "none", \ assert (
f"Expected default tiling='none', got '{tiling_param.default}'" tiling_param.default == "none"
), f"Expected default tiling='none', got '{tiling_param.default}'"
class TestLatentDimensions: class TestLatentDimensions:
@@ -296,8 +309,9 @@ class TestLatentDimensions:
for height, expected_latent_h in test_cases: for height, expected_latent_h in test_cases:
latent_h = height // 32 latent_h = height // 32
assert latent_h == expected_latent_h, \ assert (
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}" latent_h == expected_latent_h
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
def test_latent_width_calculation(self): def test_latent_width_calculation(self):
"""Latent width should be width // 32.""" """Latent width should be width // 32."""
@@ -305,8 +319,9 @@ class TestLatentDimensions:
for width, expected_latent_w in test_cases: for width, expected_latent_w in test_cases:
latent_w = width // 32 latent_w = width // 32
assert latent_w == expected_latent_w, \ assert (
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}" latent_w == expected_latent_w
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
def test_latent_frames_calculation(self): def test_latent_frames_calculation(self):
"""Latent frames should be 1 + (num_frames - 1) // 8.""" """Latent frames should be 1 + (num_frames - 1) // 8."""
@@ -314,8 +329,9 @@ class TestLatentDimensions:
for num_frames, expected_latent_f in test_cases: for num_frames, expected_latent_f in test_cases:
latent_f = 1 + (num_frames - 1) // 8 latent_f = 1 + (num_frames - 1) // 8
assert latent_f == expected_latent_f, \ assert (
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}" latent_f == expected_latent_f
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
def test_num_tokens_calculation(self): def test_num_tokens_calculation(self):
"""Number of tokens should be latent_f * latent_h * latent_w.""" """Number of tokens should be latent_f * latent_h * latent_w."""
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
positions = create_audio_position_grid(batch_size, audio_frames) positions = create_audio_position_grid(batch_size, audio_frames)
expected_shape = (batch_size, 1, audio_frames, 2) expected_shape = (batch_size, 1, audio_frames, 2)
assert positions.shape == expected_shape, \ assert (
f"Expected {expected_shape}, got {positions.shape}" positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_audio_position_grid_dtype(self): def test_audio_position_grid_dtype(self):
"""Audio position grid should be float32.""" """Audio position grid should be float32."""
positions = create_audio_position_grid(1, 34) positions = create_audio_position_grid(1, 34)
assert positions.dtype == mx.float32, \ assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
f"Expected float32, got {positions.dtype}"
def test_audio_position_grid_batch_size(self): def test_audio_position_grid_batch_size(self):
"""Audio position grid should respect batch size.""" """Audio position grid should respect batch size."""
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
"""Audio position grid should not contain NaN or Inf.""" """Audio position grid should not contain NaN or Inf."""
positions = create_audio_position_grid(1, 34) positions = create_audio_position_grid(1, 34)
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN" assert not mx.any(
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf" mx.isnan(positions)
).item(), "Audio position grid contains NaN"
assert not mx.any(
mx.isinf(positions)
).item(), "Audio position grid contains Inf"
class TestComputeAudioFrames: class TestComputeAudioFrames:
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
audio_33 = compute_audio_frames(33, 24.0) audio_33 = compute_audio_frames(33, 24.0)
audio_65 = compute_audio_frames(65, 24.0) audio_65 = compute_audio_frames(65, 24.0)
assert audio_65 > audio_33, \ assert (
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}" audio_65 > audio_33
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
def test_audio_frames_formula(self): def test_audio_frames_formula(self):
"""Audio frames should match expected formula.""" """Audio frames should match expected formula."""

View File

@@ -1,11 +1,9 @@
import pytest
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
from mlx_video.models.ltx_2.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
def create_video_position_grid( def create_video_position_grid(
@@ -20,7 +18,7 @@ def create_video_position_grid(
h_coords = np.arange(0, height) h_coords = np.arange(0, height)
w_coords = np.arange(0, width) w_coords = np.arange(0, width)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_ends = patch_starts + 1 patch_ends = patch_starts + 1
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
scaled = fractional * 2 - 1 # [-1, 1] scaled = fractional * 2 - 1 # [-1, 1]
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices) # Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :] freqs = (
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
)
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten # (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
freqs = np.swapaxes(freqs, -1, -2) freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims) freqs = freqs.reshape(
freqs.shape[0], freqs.shape[1], -1
) # (B, T, num_indices * n_dims)
cos_ref = np.cos(freqs) cos_ref = np.cos(freqs)
sin_ref = np.sin(freqs) sin_ref = np.sin(freqs)
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
pad_size = expected - cos_ref.shape[-1] pad_size = expected - cos_ref.shape[-1]
if pad_size > 0: if pad_size > 0:
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis() # Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1) cos_ref = np.concatenate(
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1) [np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
)
sin_ref = np.concatenate(
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
)
B, T, _ = cos_ref.shape B, T, _ = cos_ref.shape
dim_per_head = dim // num_heads dim_per_head = dim // num_heads
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf" assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
# Verify cos/sin are in valid range [-1, 1] # Verify cos/sin are in valid range [-1, 1]
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \ assert (
"cos_freq values out of [-1, 1] range" mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \ ), "cos_freq values out of [-1, 1] range"
"sin_freq values out of [-1, 1] range" assert (
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
), "sin_freq values out of [-1, 1] range"
def test_bfloat16_positions_cause_precision_loss(self): def test_bfloat16_positions_cause_precision_loss(self):
"""bfloat16 positions should produce different (less precise) results than float32. """bfloat16 positions should produce different (less precise) results than float32.
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
# The threshold here is intentionally low to catch the issue # The threshold here is intentionally low to catch the issue
precision_threshold = 1e-6 precision_threshold = 1e-6
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold has_precision_loss = (
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
)
# Document the precision loss (this is expected behavior) # Document the precision loss (this is expected behavior)
if has_precision_loss: if has_precision_loss:
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
print(f" Max sin difference: {max_sin_diff:.6e}") print(f" Max sin difference: {max_sin_diff:.6e}")
# This assertion documents the issue - bfloat16 positions cause precision loss # This assertion documents the issue - bfloat16 positions cause precision loss
assert has_precision_loss, \ assert (
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed" has_precision_loss
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
def test_double_precision_converts_to_float32_internally(self): def test_double_precision_converts_to_float32_internally(self):
"""Verify that double_precision mode converts bfloat16 to float32 first.""" """Verify that double_precision mode converts bfloat16 to float32 first."""
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
# Recommended: create positions in float32 # Recommended: create positions in float32
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
assert positions.dtype == mx.float32, \ assert (
"Position grids should be created in float32 for RoPE precision" positions.dtype == mx.float32
), "Position grids should be created in float32 for RoPE precision"
# Verify the position values are reasonable # Verify the position values are reasonable
# Temporal positions should be small (seconds) # Temporal positions should be small (seconds)
temporal_positions = positions[:, 0, :, :] temporal_positions = positions[:, 0, :, :]
assert mx.max(temporal_positions).item() < 100, \ assert (
"Temporal positions should be in seconds (small values)" mx.max(temporal_positions).item() < 100
), "Temporal positions should be in seconds (small values)"
# Spatial positions should be larger (pixels) # Spatial positions should be larger (pixels)
spatial_h = positions[:, 1, :, :] spatial_h = positions[:, 1, :, :]
spatial_w = positions[:, 2, :, :] spatial_w = positions[:, 2, :, :]
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" assert (
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" mx.max(spatial_h).item() > 0
), "Spatial height positions should be positive"
assert (
mx.max(spatial_w).item() > 0
), "Spatial width positions should be positive"
def test_float32_positions_match_numpy_float64_reference(self): def test_float32_positions_match_numpy_float64_reference(self):
"""Regression test: float32 RoPE must closely match a NumPy float64 reference. """Regression test: float32 RoPE must closely match a NumPy float64 reference.
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
) )
# NumPy float64 reference # NumPy float64 reference
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads) cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
cos_mlx_np = np.array(cos_mlx) cos_mlx_np = np.array(cos_mlx)
sin_mlx_np = np.array(sin_mlx) sin_mlx_np = np.array(sin_mlx)
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
# Cosine similarity (flatten for single scalar) # Cosine similarity (flatten for single scalar)
cos_flat = cos_mlx_np.flatten() cos_flat = cos_mlx_np.flatten()
ref_flat = cos_ref.flatten() ref_flat = cos_ref.flatten()
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)) cosine_sim = np.dot(cos_flat, ref_flat) / (
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
)
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa. # float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff). # Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
assert max_cos_diff < 0.01, \ assert (
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" max_cos_diff < 0.01
assert max_sin_diff < 0.01, \ ), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" assert (
assert cosine_sim > 0.9999, \ max_sin_diff < 0.01
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999" ), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
cosine_sim > 0.9999
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
def test_high_frequency_amplification_regression(self): def test_high_frequency_amplification_regression(self):
"""Regression test for the specific failure mode: high-frequency index amplification. """Regression test for the specific failure mode: high-frequency index amplification.
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
double_precision=False, double_precision=False,
) )
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads) cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref)) max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref)) max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0 # Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
assert max_cos_diff < 0.01, \ assert (
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected" max_cos_diff < 0.01
assert max_sin_diff < 0.01, \ ), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected" assert (
max_sin_diff < 0.01
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
class TestRoPEInterleaved: class TestRoPEInterleaved:
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
positions_bf16 = positions_f32.astype(mx.bfloat16) positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict( kwargs = dict(
dim=128, theta=10000.0, max_pos=[20, 2048, 2048], dim=128,
use_middle_indices_grid=True, num_attention_heads=32, theta=10000.0,
rope_type=LTXRopeType.SPLIT, double_precision=False, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
) )
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs) cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
positions_bf16 = positions_f32.astype(mx.bfloat16) positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict( kwargs = dict(
dim=128, theta=10000.0, max_pos=[20, 2048, 2048], dim=128,
use_middle_indices_grid=True, num_attention_heads=32, theta=10000.0,
rope_type=LTXRopeType.SPLIT, double_precision=True, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
) )
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs) cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
cos_freq, sin_freq = precompute_freqs_cis( cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_f16, indices_grid=positions_f16,
dim=128, theta=10000.0, max_pos=[20, 2048, 2048], dim=128,
use_middle_indices_grid=True, num_attention_heads=32, theta=10000.0,
rope_type=LTXRopeType.SPLIT, double_precision=False, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
) )
assert cos_freq.dtype == mx.float32 assert cos_freq.dtype == mx.float32
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
def test_ltx2_forces_double_precision_rope_false(self): def test_ltx2_forces_double_precision_rope_false(self):
"""LTX-2 (no prompt adaln) must have double_precision_rope=False.""" """LTX-2 (no prompt adaln) must have double_precision_rope=False."""
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True) config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
assert config.double_precision_rope is False, \ assert (
"LTX-2 should force double_precision_rope=False regardless of input" config.double_precision_rope is False
), "LTX-2 should force double_precision_rope=False regardless of input"
def test_ltx23_preserves_double_precision_rope_true(self): def test_ltx23_preserves_double_precision_rope_true(self):
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True.""" """LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True) config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
assert config.double_precision_rope is True, \ assert (
"LTX-2.3 should preserve double_precision_rope=True" config.double_precision_rope is True
), "LTX-2.3 should preserve double_precision_rope=True"
def test_ltx23_preserves_double_precision_rope_false(self): def test_ltx23_preserves_double_precision_rope_false(self):
"""LTX-2.3 with double_precision_rope=False should stay False.""" """LTX-2.3 with double_precision_rope=False should stay False."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False) config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
assert config.double_precision_rope is False, \ assert (
"LTX-2.3 should respect double_precision_rope=False when explicitly set" config.double_precision_rope is False
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
def test_ltx2_default_double_precision_rope(self): def test_ltx2_default_double_precision_rope(self):
"""LTX-2 default (double_precision_rope not set) should be False.""" """LTX-2 default (double_precision_rope not set) should be False."""
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
def test_config_from_dict_ltx2(self): def test_config_from_dict_ltx2(self):
"""Config created from dict for LTX-2 should force double_precision_rope=False.""" """Config created from dict for LTX-2 should force double_precision_rope=False."""
config = LTXModelConfig.from_dict({ config = LTXModelConfig.from_dict(
"has_prompt_adaln": False, {
"double_precision_rope": True, "has_prompt_adaln": False,
"rope_type": "split", "double_precision_rope": True,
}) "rope_type": "split",
}
)
assert config.double_precision_rope is False assert config.double_precision_rope is False
def test_config_from_dict_ltx23(self): def test_config_from_dict_ltx23(self):
"""Config created from dict for LTX-2.3 should preserve double_precision_rope.""" """Config created from dict for LTX-2.3 should preserve double_precision_rope."""
config = LTXModelConfig.from_dict({ config = LTXModelConfig.from_dict(
"has_prompt_adaln": True, {
"double_precision_rope": True, "has_prompt_adaln": True,
"rope_type": "split", "double_precision_rope": True,
}) "rope_type": "split",
}
)
assert config.double_precision_rope is True assert config.double_precision_rope is True
@@ -496,10 +543,12 @@ class TestRoPESplit:
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2 # dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
dim_per_head = dim // num_heads dim_per_head = dim // num_heads
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2) expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
assert cos_freq.shape == expected_shape, \ assert (
f"Expected shape {expected_shape}, got {cos_freq.shape}" cos_freq.shape == expected_shape
assert sin_freq.shape == expected_shape, \ ), f"Expected shape {expected_shape}, got {cos_freq.shape}"
f"Expected shape {expected_shape}, got {sin_freq.shape}" assert (
sin_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,8 +1,8 @@
"""Tests for VAE streaming and chunked conv features.""" """Tests for VAE streaming and chunked conv features."""
import pytest
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import ( from mlx_video.models.ltx_2.video_vae.tiling import (
@@ -50,7 +50,7 @@ class TestChunkedConv:
np.array(out_chunked), np.array(out_chunked),
rtol=1e-5, rtol=1e-5,
atol=1e-5, atol=1e-5,
err_msg="Chunked conv output differs from regular output" err_msg="Chunked conv output differs from regular output",
) )
def test_chunked_conv_small_input_passthrough(self): def test_chunked_conv_small_input_passthrough(self):
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
frames_received = [] frames_received = []
def on_frames_ready(frames: mx.array, start_idx: int): def on_frames_ready(frames: mx.array, start_idx: int):
frames_received.append({ frames_received.append(
'shape': frames.shape, {
'start_idx': start_idx, "shape": frames.shape,
}) "start_idx": start_idx,
}
)
# Create a mock decoder that just returns scaled input # Create a mock decoder that just returns scaled input
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False): def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
# Simulate VAE output: upsample 8x temporal, 32x spatial # Simulate VAE output: upsample 8x temporal, 32x spatial
b, c, f, h, w = x.shape b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8 out_f = 1 + (f - 1) * 8
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
# All received frames should have correct channel count # All received frames should have correct channel count
for received in frames_received: for received in frames_received:
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}" assert (
received["shape"][1] == 3
), f"Expected 3 channels, got {received['shape'][1]}"
def test_on_frames_ready_covers_all_frames(self): def test_on_frames_ready_covers_all_frames(self):
"""Verify all frames are emitted via callbacks.""" """Verify all frames are emitted via callbacks."""
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
for i in range(num_frames): for i in range(num_frames):
all_frame_indices.add(start_idx + i) all_frame_indices.add(start_idx + i)
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False): def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
b, c, f, h, w = x.shape b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8 out_f = 1 + (f - 1) * 8
out_h = h * 32 out_h = h * 32
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
expected_frames = 1 + (12 - 1) * 8 # 89 frames expected_frames = 1 + (12 - 1) * 8 # 89 frames
# All frames should have been emitted # All frames should have been emitted
assert len(all_frame_indices) == expected_frames, \ assert (
f"Expected {expected_frames} frames, got {len(all_frame_indices)}" len(all_frame_indices) == expected_frames
assert all_frame_indices == set(range(expected_frames)), \ ), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
"Not all frame indices were covered" assert all_frame_indices == set(
range(expected_frames)
), "Not all frame indices were covered"
class TestAutoChunkedConv: class TestAutoChunkedConv:
"""Tests for auto-enabling chunked_conv based on tiling mode.""" """Tests for auto-enabling chunked_conv based on tiling mode."""
@pytest.mark.parametrize("tiling_mode,should_enable", [ @pytest.mark.parametrize(
("conservative", True), "tiling_mode,should_enable",
("none", True), [
("auto", True), ("conservative", True),
("default", True), ("none", True),
("spatial", True), ("auto", True),
("aggressive", False), ("default", True),
("temporal", False), ("spatial", True),
]) ("aggressive", False),
("temporal", False),
],
)
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool): def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
"""Verify chunked_conv is auto-enabled for correct tiling modes.""" """Verify chunked_conv is auto-enabled for correct tiling modes."""
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial") # The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
use_chunked_conv = tiling_mode in expected_modes use_chunked_conv = tiling_mode in expected_modes
assert use_chunked_conv == should_enable, \ assert (
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}" use_chunked_conv == should_enable
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
class TestTrapezoidalMask: class TestTrapezoidalMask:
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
# Right ramp should be decreasing # Right ramp should be decreasing
right_ramp = mask_np[-8:] right_ramp = mask_np[-8:]
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing" assert np.all(
np.diff(right_ramp) <= 0
), "Right ramp not monotonically decreasing"
def test_temporal_mask_starts_from_zero(self): def test_temporal_mask_starts_from_zero(self):
"""Verify temporal mask (left_starts_from_0=True) starts from 0.""" """Verify temporal mask (left_starts_from_0=True) starts from 0."""

View File

@@ -2,24 +2,25 @@
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# RoPE Tests # RoPE Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestRoPE: class TestRoPE:
"""Tests for 3-way factorized RoPE.""" """Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self): def test_rope_params_shape(self):
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
freqs = rope_params(1024, 64) freqs = rope_params(1024, 64)
mx.eval(freqs) mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2] assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self): def test_rope_params_different_dims(self):
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
for dim in [32, 64, 128]: for dim in [32, 64, 128]:
freqs = rope_params(512, dim) freqs = rope_params(512, dim)
mx.eval(freqs) mx.eval(freqs)
@@ -27,6 +28,7 @@ class TestRoPE:
def test_rope_params_cos_sin_range(self): def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
freqs = rope_params(256, 64) freqs = rope_params(256, 64)
mx.eval(freqs) mx.eval(freqs)
cos_vals = np.array(freqs[:, :, 0]) cos_vals = np.array(freqs[:, :, 0])
@@ -37,13 +39,15 @@ class TestRoPE:
def test_rope_params_position_zero(self): def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0.""" """At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
freqs = rope_params(10, 64) freqs = rope_params(10, 64)
mx.eval(freqs) mx.eval(freqs)
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6) np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6) np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self): def test_rope_apply_output_shape(self):
from mlx_video.models.wan.rope import rope_params, rope_apply from mlx_video.models.wan.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D)) x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D) freqs = rope_params(1024, D)
@@ -54,7 +58,8 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self): def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms.""" """RoPE rotation should preserve vector norms."""
from mlx_video.models.wan.rope import rope_params, rope_apply from mlx_video.models.wan.rope import rope_apply, rope_params
B, N, D = 1, 2, 16 B, N, D = 1, 2, 16
F, H, W = 2, 3, 4 F, H, W = 2, 3, 4
L = F * H * W L = F * H * W
@@ -74,7 +79,8 @@ class TestRoPE:
def test_rope_apply_with_padding(self): def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged.""" """When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan.rope import rope_params, rope_apply from mlx_video.models.wan.rope import rope_apply, rope_params
B, N, D = 1, 2, 16 B, N, D = 1, 2, 16
F, H, W = 2, 2, 2 F, H, W = 2, 2, 2
seq_len = F * H * W # 8 seq_len = F * H * W # 8
@@ -94,7 +100,8 @@ class TestRoPE:
def test_rope_apply_batch(self): def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes.""" """Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan.rope import rope_params, rope_apply from mlx_video.models.wan.rope import rope_apply, rope_params
B, N, D = 2, 2, 16 B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)] grids = [(2, 3, 4), (2, 3, 4)]
L = 2 * 3 * 4 L = 2 * 3 * 4
@@ -122,9 +129,11 @@ class TestRoPE:
# Attention Tests # Attention Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWanRMSNorm: class TestWanRMSNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.attention import WanRMSNorm from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(64) norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
out = norm(x) out = norm(x)
@@ -134,6 +143,7 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self): def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling.""" """RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan.attention import WanRMSNorm from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(64) norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0 x = mx.random.normal((1, 5, 64)) * 10.0
out = norm(x) out = norm(x)
@@ -147,6 +157,7 @@ class TestWanRMSNorm:
def test_dtype_preservation(self): def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32.""" """RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan.attention import WanRMSNorm from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(32) norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16) x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
out = norm(x) out = norm(x)
@@ -158,6 +169,7 @@ class TestWanRMSNorm:
class TestWanLayerNorm: class TestWanLayerNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.attention import WanLayerNorm from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(64) norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
out = norm(x) out = norm(x)
@@ -166,6 +178,7 @@ class TestWanLayerNorm:
def test_without_affine(self): def test_without_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False) norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64)) x = mx.random.normal((1, 4, 64))
out = norm(x) out = norm(x)
@@ -178,6 +191,7 @@ class TestWanLayerNorm:
def test_with_affine(self): def test_with_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True) norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight") assert hasattr(norm, "weight")
assert hasattr(norm, "bias") assert hasattr(norm, "bias")
@@ -196,6 +210,7 @@ class TestWanSelfAttention:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads) attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24 B, L = 1, 24
F, H, W = 2, 3, 4 F, H, W = 2, 3, 4
@@ -207,12 +222,14 @@ class TestWanSelfAttention:
def test_with_qk_norm(self): def test_with_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True) attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None assert attn.norm_q is not None
assert attn.norm_k is not None assert attn.norm_k is not None
def test_without_qk_norm(self): def test_without_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None assert attn.norm_q is None
assert attn.norm_k is None assert attn.norm_k is None
@@ -221,6 +238,7 @@ class TestWanSelfAttention:
"""Test that masking works: shorter seq_lens should mask later tokens.""" """Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24 B, L = 1, 24
F, H, W = 2, 3, 4 F, H, W = 2, 3, 4
@@ -245,6 +263,7 @@ class TestWanCrossAttention:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.attention import WanCrossAttention from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16 B, L_q, L_kv = 1, 24, 16
x = mx.random.normal((B, L_q, self.dim)) x = mx.random.normal((B, L_q, self.dim))
@@ -255,6 +274,7 @@ class TestWanCrossAttention:
def test_with_context_mask(self): def test_with_context_mask(self):
from mlx_video.models.wan.attention import WanCrossAttention from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16 B, L_q, L_kv = 1, 12, 16
x = mx.random.normal((B, L_q, self.dim)) x = mx.random.normal((B, L_q, self.dim))
@@ -268,6 +288,7 @@ class TestWanCrossAttention:
# bfloat16 Autocast Tests # bfloat16 Autocast Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestBFloat16Autocast: class TestBFloat16Autocast:
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16) """Tests that attention and FFN cast inputs to weight dtype (bfloat16)
for efficient matmul, matching official PyTorch autocast behavior.""" for efficient matmul, matching official PyTorch autocast behavior."""
@@ -292,6 +313,7 @@ class TestBFloat16Autocast:
"""Self-attention should cast input to weight dtype for QKV projections.""" """Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads) attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -305,6 +327,7 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self): def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype.""" """Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -318,6 +341,7 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self): def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype.""" """prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -330,6 +354,7 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self): def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers.""" """FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan.transformer import WanFFN from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(self.dim, 128) ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters())) ffn.update(self._to_bf16(ffn.parameters()))
@@ -343,6 +368,7 @@ class TestBFloat16Autocast:
"""RoPE should be applied in float32 for precision, even with bf16 weights.""" """RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan.attention import WanSelfAttention from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads) attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -355,8 +381,9 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self): def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights.""" """Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True) block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters())) block.update(self._to_bf16(block.parameters()))

View File

@@ -1,17 +1,17 @@
"""Tests for Wan model configuration.""" """Tests for Wan model configuration."""
import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Config Tests # Config Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWanModelConfig: class TestWanModelConfig:
"""Tests for WanModelConfig dataclass.""" """Tests for WanModelConfig dataclass."""
def test_default_values(self): def test_default_values(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.dim == 5120 assert config.dim == 5120
assert config.ffn_dim == 13824 assert config.ffn_dim == 13824
@@ -33,11 +33,13 @@ class TestWanModelConfig:
def test_head_dim_property(self): def test_head_dim_property(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40 assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self): def test_to_dict_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
d = config.to_dict() d = config.to_dict()
assert isinstance(d, dict) assert isinstance(d, dict)
@@ -47,6 +49,7 @@ class TestWanModelConfig:
def test_t5_config_values(self): def test_t5_config_values(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.t5_vocab_size == 256384 assert config.t5_vocab_size == 256384
assert config.t5_dim == 4096 assert config.t5_dim == 4096
@@ -61,11 +64,13 @@ class TestWanModelConfig:
# Wan2.1 Config Tests # Wan2.1 Config Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWan21Config: class TestWan21Config:
"""Tests for Wan2.1 config presets.""" """Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self): def test_wan21_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1" assert config.model_version == "2.1"
assert config.dual_model is False assert config.dual_model is False
@@ -81,6 +86,7 @@ class TestWan21Config:
def test_wan21_1_3b_factory(self): def test_wan21_1_3b_factory(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1" assert config.model_version == "2.1"
assert config.dual_model is False assert config.dual_model is False
@@ -93,6 +99,7 @@ class TestWan21Config:
def test_wan22_14b_factory(self): def test_wan22_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b() config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2" assert config.model_version == "2.2"
assert config.dual_model is True assert config.dual_model is True
@@ -104,6 +111,7 @@ class TestWan21Config:
def test_wan21_config_to_dict(self): def test_wan21_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict() d = config.to_dict()
assert d["model_version"] == "2.1" assert d["model_version"] == "2.1"
@@ -112,6 +120,7 @@ class TestWan21Config:
def test_wan21_1_3b_config_to_dict(self): def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict() d = config.to_dict()
assert d["dim"] == 1536 assert d["dim"] == 1536
@@ -120,6 +129,7 @@ class TestWan21Config:
def test_default_config_is_wan22(self): def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B.""" """Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.model_version == "2.2" assert config.model_version == "2.2"
assert config.dual_model is True assert config.dual_model is True

View File

@@ -3,17 +3,16 @@
import logging import logging
import mlx.core as mx import mlx.core as mx
import numpy as np
import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Transformer Weight Conversion Tests # Transformer Weight Conversion Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestSanitizeTransformerWeights: class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self): def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)), "patch_embedding.bias": mx.random.normal((5120,)),
@@ -25,6 +24,7 @@ class TestSanitizeTransformerWeights:
def test_text_embedding_rename(self): def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"text_embedding.0.weight": mx.zeros((64, 32)), "text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.0.bias": mx.zeros((64,)), "text_embedding.0.bias": mx.zeros((64,)),
@@ -39,6 +39,7 @@ class TestSanitizeTransformerWeights:
def test_time_embedding_rename(self): def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"time_embedding.0.weight": mx.zeros((64, 32)), "time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)), "time_embedding.2.weight": mx.zeros((64, 64)),
@@ -49,6 +50,7 @@ class TestSanitizeTransformerWeights:
def test_time_projection_rename(self): def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"time_projection.1.weight": mx.zeros((384, 64)), "time_projection.1.weight": mx.zeros((384, 64)),
"time_projection.1.bias": mx.zeros((384,)), "time_projection.1.bias": mx.zeros((384,)),
@@ -59,6 +61,7 @@ class TestSanitizeTransformerWeights:
def test_ffn_rename(self): def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.0.bias": mx.zeros((128,)), "blocks.0.ffn.0.bias": mx.zeros((128,)),
@@ -73,6 +76,7 @@ class TestSanitizeTransformerWeights:
def test_freqs_skipped(self): def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"freqs": mx.zeros((1024, 64, 2)), "freqs": mx.zeros((1024, 64, 2)),
"blocks.0.norm1.weight": mx.zeros((64,)), "blocks.0.norm1.weight": mx.zeros((64,)),
@@ -83,6 +87,7 @@ class TestSanitizeTransformerWeights:
def test_passthrough_keys(self): def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)), "blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)), "blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
@@ -98,6 +103,7 @@ class TestSanitizeTransformerWeights:
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = { weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)), "patch_embedding.bias": mx.random.normal((5120,)),
@@ -121,6 +127,7 @@ class TestSanitizeTransformerWeights:
class TestSanitizeT5Weights: class TestSanitizeT5Weights:
def test_gate_rename(self): def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = { weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)), "blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
@@ -133,6 +140,7 @@ class TestSanitizeT5Weights:
def test_passthrough(self): def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = { weights = {
"token_embedding.weight": mx.zeros((100, 64)), "token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.attn.q.weight": mx.zeros((64, 64)), "blocks.0.attn.q.weight": mx.zeros((64, 64)),
@@ -144,6 +152,7 @@ class TestSanitizeT5Weights:
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = { weights = {
"token_embedding.weight": mx.zeros((100, 64)), "token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -159,6 +168,7 @@ class TestSanitizeT5Weights:
class TestSanitizeVAEWeights: class TestSanitizeVAEWeights:
def test_conv3d_transpose(self): def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = { weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W] "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
} }
@@ -167,6 +177,7 @@ class TestSanitizeVAEWeights:
def test_conv2d_transpose(self): def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = { weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W] "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
} }
@@ -175,6 +186,7 @@ class TestSanitizeVAEWeights:
def test_non_conv_passthrough(self): def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = { weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose "decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
"decoder.bias": mx.zeros((16,)), "decoder.bias": mx.zeros((16,)),
@@ -185,6 +197,7 @@ class TestSanitizeVAEWeights:
def test_mixed_weights(self): def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = { weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D "conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
@@ -199,6 +212,7 @@ class TestSanitizeVAEWeights:
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = { weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), "decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
@@ -214,6 +228,7 @@ class TestSanitizeVAEWeights:
# Wan2.1 Conversion Tests # Wan2.1 Conversion Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWan21Convert: class TestWan21Convert:
"""Tests for Wan2.1 conversion support.""" """Tests for Wan2.1 conversion support."""
@@ -222,7 +237,7 @@ class TestWan21Convert:
# Create a Wan2.1-style directory (no low_noise_model subdir) # Create a Wan2.1-style directory (no low_noise_model subdir)
(tmp_path / "dummy.safetensors").touch() (tmp_path / "dummy.safetensors").touch()
# The auto-detect logic: no low_noise_model dir → 2.1 # The auto-detect logic: no low_noise_model dir → 2.1
from pathlib import Path
low = tmp_path / "low_noise_model" low = tmp_path / "low_noise_model"
assert not low.exists() assert not low.exists()
# Simulates auto detection # Simulates auto detection
@@ -233,7 +248,7 @@ class TestWan21Convert:
"""Auto-detect dual-model directory as Wan2.2.""" """Auto-detect dual-model directory as Wan2.2."""
(tmp_path / "low_noise_model").mkdir() (tmp_path / "low_noise_model").mkdir()
(tmp_path / "high_noise_model").mkdir() (tmp_path / "high_noise_model").mkdir()
from pathlib import Path
low = tmp_path / "low_noise_model" low = tmp_path / "low_noise_model"
assert low.exists() assert low.exists()
version = "2.2" if low.exists() else "2.1" version = "2.2" if low.exists() else "2.1"
@@ -242,6 +257,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self): def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1.""" """Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict() d = config.to_dict()
assert d["model_version"] == "2.1" assert d["model_version"] == "2.1"
@@ -254,6 +270,7 @@ class TestWan21Convert:
# Encoder Weight Sanitization Tests # Encoder Weight Sanitization Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestSanitizeEncoderWeights: class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder.""" """Tests for sanitize_wan22_vae_weights with include_encoder."""

View File

@@ -2,15 +2,13 @@
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config from wan_test_helpers import _make_tiny_config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass # Integration: end-to-end tiny model forward pass
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestEndToEnd: class TestEndToEnd:
"""End-to-end test with tiny model (no real weights needed).""" """End-to-end test with tiny model (no real weights needed)."""
@@ -78,6 +76,7 @@ class TestEndToEnd:
# I2V Mask Tests # I2V Mask Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestI2VMask: class TestI2VMask:
"""Tests for _build_i2v_mask.""" """Tests for _build_i2v_mask."""
@@ -113,6 +112,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self): def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions.""" """Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask from mlx_video.generate_wan import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2) # TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames # 704x1280 → latent 44x80, t_latent=21 for 81 frames
z_shape = (48, 21, 44, 80) z_shape = (48, 21, 44, 80)
@@ -133,6 +133,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self): def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma.""" """Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask from mlx_video.generate_wan import _build_i2v_mask
z_shape = (4, 3, 4, 4) z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2) patch_size = (1, 2, 2)
_, mask_tokens = _build_i2v_mask(z_shape, patch_size) _, mask_tokens = _build_i2v_mask(z_shape, patch_size)
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw) first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7) np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7) np.testing.assert_allclose(
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Dimension Alignment Tests # Dimension Alignment Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestDimensionAlignment: class TestDimensionAlignment:
"""Tests for automatic dimension alignment in generate_wan.""" """Tests for automatic dimension alignment in generate_wan."""
@@ -198,6 +202,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self): def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors.""" """After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
patches, grid_size = model._patchify(vid) patches, grid_size = model._patchify(vid)
mx.eval(patches) mx.eval(patches)
assert patches.ndim == 3 # [1, L, dim] assert patches.ndim == 3 # [1, L, dim]
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2]) assert grid_size == (
t_latent,
h_latent // patch_size[1],
w_latent // patch_size[2],
)
def test_alignment_with_ti2v_config(self): def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32.""" """TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b() config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1] align_h = config.patch_size[1] * config.vae_stride[1]
align_w = config.patch_size[2] * config.vae_stride[2] align_w = config.patch_size[2] * config.vae_stride[2]

View File

@@ -1,9 +1,6 @@
"""Tests for Wan2.2 I2V-14B support.""" """Tests for Wan2.2 I2V-14B support."""
import mlx.core as mx import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config from wan_test_helpers import _make_tiny_config
@@ -145,7 +142,10 @@ class TestModelYParameter:
latents = mx.random.normal((C_noise, F, H, W)) latents = mx.random.normal((C_noise, F, H, W))
y = mx.random.normal((C_y, F, H, W)) y = mx.random.normal((C_y, F, H, W))
t = mx.array([500.0, 500.0]) t = mx.array([500.0, 500.0])
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))] ctx = [
mx.random.normal((6, config.text_dim)),
mx.random.normal((6, config.text_dim)),
]
out = model([latents, latents], t, ctx, seq_len, y=[y, y]) out = model([latents, latents], t, ctx, seq_len, y=[y, y])
mx.eval(out[0], out[1]) mx.eval(out[0], out[1])
@@ -160,7 +160,9 @@ class TestVAEEncoder:
def test_encoder3d_instantiation(self): def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d from mlx_video.models.wan.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2) enc = Encoder3d(
dim=32, z_dim=8
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
assert enc.conv1 is not None assert enc.conv1 is not None
assert len(enc.downsamples) > 0 assert len(enc.downsamples) > 0
assert len(enc.middle) == 3 assert len(enc.middle) == 3
@@ -199,10 +201,10 @@ class TestVAEEncoder:
from mlx_video.models.wan.vae import WanVAE from mlx_video.models.wan.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False) vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, 'encoder') assert not hasattr(vae_no_enc, "encoder")
vae_enc = WanVAE(z_dim=4, encoder=True) vae_enc = WanVAE(z_dim=4, encoder=True)
assert hasattr(vae_enc, 'encoder') assert hasattr(vae_enc, "encoder")
class TestResampleDownsample: class TestResampleDownsample:
@@ -258,7 +260,9 @@ class TestI2VMaskConstruction:
# Build mask following reference logic # Build mask following reference logic
msk = mx.ones((1, num_frames, h_latent, w_latent)) msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -272,7 +276,9 @@ class TestI2VMaskConstruction:
t_latent = (num_frames - 1) // 4 + 1 # = 3 t_latent = (num_frames - 1) // 4 + 1 # = 3
msk = mx.ones((1, num_frames, h_latent, w_latent)) msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] msk = msk.transpose(0, 2, 1, 3, 4)[0]
@@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline:
config = _make_tiny_i2v_config() config = _make_tiny_i2v_config()
config.vae_z_dim = 16 config.vae_z_dim = 16
config.out_dim = 16 # must match VAE z_dim for decode config.out_dim = 16 # must match VAE z_dim for decode
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36 config.in_dim = (
16 + 4 + 16
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
model = WanModel(config) model = WanModel(config)
# --- Tiny VAE (with encoder) --- # --- Tiny VAE (with encoder) ---
@@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline:
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width)) img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W] # Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
video = mx.concatenate([ video = mx.concatenate(
img, [
mx.zeros((1, 3, num_frames - 1, height, width)), img,
], axis=2) mx.zeros((1, 3, num_frames - 1, height, width)),
],
axis=2,
)
# --- VAE encode --- # --- VAE encode ---
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat] z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
@@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline:
# --- Build I2V mask (4 channels) --- # --- Build I2V mask (4 channels) ---
msk = mx.ones((1, num_frames, h_latent, w_latent)) msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -453,7 +466,9 @@ class TestDualModelSwitching:
noise_pred_cond, noise_pred_uncond = preds[0], preds[1] noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
0
)
mx.eval(latents) mx.eval(latents)
# With shift=5.0, early timesteps should be high (>=900), later ones low # With shift=5.0, early timesteps should be high (>=900), later ones low
@@ -461,9 +476,9 @@ class TestDualModelSwitching:
assert len(low_used_steps) > 0, "Low-noise model was never selected" assert len(low_used_steps) > 0, "Low-noise model was never selected"
# High-noise steps should come before low-noise steps (timesteps decrease) # High-noise steps should come before low-noise steps (timesteps decrease)
if high_used_steps and low_used_steps: if high_used_steps and low_used_steps:
assert max(high_used_steps) < min(low_used_steps) or \ assert max(high_used_steps) < min(low_used_steps) or min(
min(high_used_steps) < max(low_used_steps), \ high_used_steps
"Model switching should happen during the loop" ) < max(low_used_steps), "Model switching should happen during the loop"
assert latents.shape == (C_noise, F, H, W) assert latents.shape == (C_noise, F, H, W)
assert not mx.any(mx.isnan(latents)).item() assert not mx.any(mx.isnan(latents)).item()
@@ -515,7 +530,9 @@ class TestDualModelSwitching:
y=[y_i2v, y_i2v], y=[y_i2v, y_i2v],
) )
noise_pred = pred[1] + gs * (pred[0] - pred[1]) noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
0
)
mx.eval(latents) mx.eval(latents)
# Verify both guide scales were used # Verify both guide scales were used

View File

@@ -4,7 +4,6 @@ import tempfile
from pathlib import Path from pathlib import Path
import mlx.core as mx import mlx.core as mx
import numpy as np
import pytest import pytest
@@ -40,7 +39,9 @@ class TestLoRATypes:
lora_a = mx.ones((2, 4)) lora_a = mx.ones((2, 4))
lora_b = mx.ones((8, 2)) lora_b = mx.ones((8, 2))
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
applied = AppliedLoRA(weights=w, strength=0.5) applied = AppliedLoRA(weights=w, strength=0.5)
delta = applied.compute_delta() delta = applied.compute_delta()
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones) # scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
@@ -51,7 +52,9 @@ class TestLoRATypes:
class TestLoRALoader: class TestLoRALoader:
"""Test LoRA weight loading from safetensors.""" """Test LoRA weight loading from safetensors."""
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"): def _make_lora_file(
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
):
"""Helper to create a mock LoRA safetensors file.""" """Helper to create a mock LoRA safetensors file."""
weights = {} weights = {}
for name in module_names: for name in module_names:
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
"""Simulate typical Wan2.2 MLX model weight keys.""" """Simulate typical Wan2.2 MLX model weight keys."""
keys = set() keys = set()
for i in range(2): for i in range(2):
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o", for layer in [
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]: "self_attn.q",
"self_attn.k",
"self_attn.v",
"self_attn.o",
"cross_attn.q",
"cross_attn.k",
"cross_attn.v",
"cross_attn.o",
]:
keys.add(f"blocks.{i}.{layer}.weight") keys.add(f"blocks.{i}.{layer}.weight")
keys.add(f"blocks.{i}.ffn.fc1.weight") keys.add(f"blocks.{i}.ffn.fc1.weight")
keys.add(f"blocks.{i}.ffn.fc2.weight") keys.add(f"blocks.{i}.ffn.fc2.weight")
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys() keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q" assert (
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
== "blocks.0.self_attn.q"
)
def test_strip_diffusion_model_prefix(self): def test_strip_diffusion_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys() keys = self._wan_model_keys()
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys) result = _normalize_wan_lora_key(
"model.diffusion_model.blocks.0.self_attn.k", keys
)
assert result == "blocks.0.self_attn.k" assert result == "blocks.0.self_attn.k"
def test_ffn_key_mapping(self): def test_ffn_key_mapping(self):
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys() keys = self._wan_model_keys()
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj" assert (
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
)
def test_combined_prefix_and_ffn(self): def test_combined_prefix_and_ffn(self):
from mlx_video.lora.apply import _normalize_wan_lora_key from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -219,7 +237,9 @@ class TestApplyLoRA:
# LoRA weights in float32 (typical when loaded from safetensors) # LoRA weights in float32 (typical when loaded from safetensors)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1 lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1 lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)]) result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}" assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
@@ -230,7 +250,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4), dtype=mx.float16) original = mx.ones((8, 4), dtype=mx.float16)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1 lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1 lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)]) result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}" assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
@@ -241,7 +263,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4)) original = mx.ones((8, 4))
lora_a = mx.ones((2, 4)) * 0.1 lora_a = mx.ones((2, 4)) * 0.1
lora_b = mx.ones((8, 2)) * 0.1 lora_b = mx.ones((8, 2)) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)]) result = apply_lora_to_linear(original, [(w, 1.0)])
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4) # delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
expected = original + 0.02 * mx.ones((8, 4)) expected = original + 0.02 * mx.ones((8, 4))
@@ -255,12 +279,16 @@ class TestApplyLoRA:
w1 = LoRAWeights( w1 = LoRAWeights(
lora_A=mx.ones((2, 4)), lora_A=mx.ones((2, 4)),
lora_B=mx.ones((8, 2)), lora_B=mx.ones((8, 2)),
rank=2, alpha=2.0, module_name="a", rank=2,
alpha=2.0,
module_name="a",
) )
w2 = LoRAWeights( w2 = LoRAWeights(
lora_A=mx.ones((2, 4)) * 2, lora_A=mx.ones((2, 4)) * 2,
lora_B=mx.ones((8, 2)) * 2, lora_B=mx.ones((8, 2)) * 2,
rank=2, alpha=4.0, module_name="b", rank=2,
alpha=4.0,
module_name="b",
) )
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)]) result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4) # w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
@@ -282,7 +310,9 @@ class TestApplyLoRA:
w = LoRAWeights( w = LoRAWeights(
lora_A=mx.ones((4, 64)) * 0.01, lora_A=mx.ones((4, 64)) * 0.01,
lora_B=mx.ones((128, 4)) * 0.01, lora_B=mx.ones((128, 4)) * 0.01,
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q", rank=4,
alpha=4.0,
module_name="blocks.0.self_attn.q",
) )
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]} module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
result = apply_loras_to_weights(model_weights, module_to_loras) result = apply_loras_to_weights(model_weights, module_to_loras)
@@ -319,9 +349,7 @@ class TestEndToEnd:
"blocks.0.self_attn.k.weight": mx.ones((128, 64)), "blocks.0.self_attn.k.weight": mx.ones((128, 64)),
} }
result = load_and_apply_loras( result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
model_weights, [(str(lora_path), 1.0)]
)
# q weight should be modified, k unchanged # q weight should be modified, k unchanged
assert not mx.array_equal( assert not mx.array_equal(

View File

@@ -3,18 +3,17 @@
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config from wan_test_helpers import _make_tiny_config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests # Sinusoidal Embedding Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestSinusoidalEmbedding: class TestSinusoidalEmbedding:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32) pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos) emb = sinusoidal_embedding_1d(256, pos)
mx.eval(emb) mx.eval(emb)
@@ -23,6 +22,7 @@ class TestSinusoidalEmbedding:
def test_position_zero(self): def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0.""" """Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0]) pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos) emb = sinusoidal_embedding_1d(64, pos)
mx.eval(emb) mx.eval(emb)
@@ -34,6 +34,7 @@ class TestSinusoidalEmbedding:
def test_different_positions_differ(self): def test_different_positions_differ(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0]) pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos) emb = sinusoidal_embedding_1d(128, pos)
mx.eval(emb) mx.eval(emb)
@@ -46,9 +47,11 @@ class TestSinusoidalEmbedding:
# Head Tests # Head Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestHead: class TestHead:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.model import Head from mlx_video.models.wan.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24 B, L = 1, 24
x = mx.random.normal((B, L, 64)) x = mx.random.normal((B, L, 64))
@@ -60,6 +63,7 @@ class TestHead:
def test_modulation_shape(self): def test_modulation_shape(self):
from mlx_video.models.wan.model import Head from mlx_video.models.wan.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64) assert head.modulation.shape == (1, 2, 64)
@@ -68,12 +72,14 @@ class TestHead:
# WanModel (Tiny) Tests # WanModel (Tiny) Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWanModel: class TestWanModel:
def setup_method(self): def setup_method(self):
mx.random.seed(42) mx.random.seed(42)
def test_instantiation(self): def test_instantiation(self):
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters())) num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
@@ -81,6 +87,7 @@ class TestWanModel:
def test_patchify_shape(self): def test_patchify_shape(self):
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
# Input: [C=4, F=1, H=4, W=4] # Input: [C=4, F=1, H=4, W=4]
@@ -93,6 +100,7 @@ class TestWanModel:
def test_patchify_various_sizes(self): def test_patchify_various_sizes(self):
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]: for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
@@ -108,6 +116,7 @@ class TestWanModel:
def test_unpatchify_inverse(self): def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims.""" """Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
C, F, H, W = config.in_dim, 2, 4, 6 C, F, H, W = config.in_dim, 2, 4, 6
@@ -123,6 +132,7 @@ class TestWanModel:
def test_forward_pass(self): def test_forward_pass(self):
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4 C, F, H, W = config.in_dim, 1, 4, 4
@@ -140,6 +150,7 @@ class TestWanModel:
def test_forward_batch(self): def test_forward_batch(self):
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4 C, F, H, W = config.in_dim, 1, 4, 4
@@ -148,7 +159,10 @@ class TestWanModel:
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))] x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
t = mx.array([500.0, 200.0]) t = mx.array([500.0, 200.0])
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))] context = [
mx.random.normal((6, config.text_dim)),
mx.random.normal((4, config.text_dim)),
]
out = model(x_list, t, context, seq_len) out = model(x_list, t, context, seq_len)
mx.eval(out[0], out[1]) mx.eval(out[0], out[1])
@@ -158,12 +172,17 @@ class TestWanModel:
def test_output_is_float32(self): def test_output_is_float32(self):
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4 C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2) seq_len = (F // 1) * (H // 2) * (W // 2)
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]), out = model(
[mx.random.normal((4, config.text_dim))], seq_len) [mx.random.normal((C, F, H, W))],
mx.array([100.0]),
[mx.random.normal((4, config.text_dim))],
seq_len,
)
mx.eval(out[0]) mx.eval(out[0])
assert out[0].dtype == mx.float32 assert out[0].dtype == mx.float32
@@ -172,6 +191,7 @@ class TestWanModel:
# Wan2.1 Model Tests # Wan2.1 Model Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWan21Model: class TestWan21Model:
"""Test tiny Wan2.1-style model (single model mode).""" """Test tiny Wan2.1-style model (single model mode)."""
@@ -181,6 +201,7 @@ class TestWan21Model:
def _make_tiny_wan21_config(self): def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model).""" """Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values # Override to tiny values
config.dim = 64 config.dim = 64
@@ -197,6 +218,7 @@ class TestWan21Model:
def _make_tiny_wan21_1_3b_config(self): def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B.""" """Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads) # Override to tiny values (preserve 1.3B head structure: 12 heads)
config.dim = 48 config.dim = 48
@@ -271,7 +293,9 @@ class TestWan21Model:
for i in range(3): for i in range(3):
t = sched.timesteps[i] t = sched.timesteps[i]
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0] pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0] pred_uncond = model(
[latents], mx.array([t.item()]), [context_null], seq_len
)[0]
pred = pred_uncond + gs * (pred_cond - pred_uncond) pred = pred_uncond + gs * (pred_cond - pred_uncond)
latents = sched.step(pred[None], t, latents[None]).squeeze(0) latents = sched.step(pred[None], t, latents[None]).squeeze(0)
mx.eval(latents) mx.eval(latents)
@@ -304,6 +328,7 @@ class TestWan21Model:
# Per-Token Timestep Tests # Per-Token Timestep Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestPerTokenTimestep: class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding.""" """Tests for per-token sinusoidal embedding."""

View File

@@ -1,22 +1,22 @@
"""Tests for Wan model quantization pipeline.""" """Tests for Wan model quantization pipeline."""
import json import json
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.utils import mlx.utils
import numpy as np import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config from wan_test_helpers import _make_tiny_config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Quantize Predicate Tests # Quantize Predicate Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestQuantizePredicate: class TestQuantizePredicate:
def test_matches_self_attention_layers(self): def test_matches_self_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]: for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.self_attn.{suffix}" path = f"blocks.0.self_attn.{suffix}"
@@ -24,6 +24,7 @@ class TestQuantizePredicate:
def test_matches_cross_attention_layers(self): def test_matches_cross_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]: for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.cross_attn.{suffix}" path = f"blocks.0.cross_attn.{suffix}"
@@ -31,23 +32,31 @@ class TestQuantizePredicate:
def test_matches_ffn_layers(self): def test_matches_ffn_layers(self):
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self): def test_rejects_embeddings(self):
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]: for path in [
"patch_embedding_proj",
"text_embedding_fc1",
"time_embedding.fc1",
]:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}" assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self): def test_rejects_norms(self):
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64) mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm) assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self): def test_rejects_non_quantizable_modules(self):
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64) mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized # Even if path matches, module must have to_quantized
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm) assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
@@ -55,13 +64,19 @@ class TestQuantizePredicate:
def test_all_10_patterns_covered(self): def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted.""" """Verify exactly 10 layer patterns are targeted."""
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
patterns = [ patterns = [
"blocks.0.self_attn.q", "blocks.0.self_attn.k", "blocks.0.self_attn.q",
"blocks.0.self_attn.v", "blocks.0.self_attn.o", "blocks.0.self_attn.k",
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k", "blocks.0.self_attn.v",
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o", "blocks.0.self_attn.o",
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2", "blocks.0.cross_attn.q",
"blocks.0.cross_attn.k",
"blocks.0.cross_attn.v",
"blocks.0.cross_attn.o",
"blocks.0.ffn.fc1",
"blocks.0.ffn.fc2",
] ]
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)] matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
assert len(matched) == 10 assert len(matched) == 10
@@ -71,11 +86,12 @@ class TestQuantizePredicate:
# Quantize Round-Trip Tests # Quantize Round-Trip Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestQuantizeRoundTrip: class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path.""" """Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
model = WanModel(config) model = WanModel(config)
nn.quantize( nn.quantize(
@@ -101,8 +117,10 @@ class TestQuantizeRoundTrip:
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4) model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model( loaded = load_wan_model(
model_path, config, model_path,
config,
quantization={"bits": 4, "group_size": 64}, quantization={"bits": 4, "group_size": 64},
) )
@@ -119,8 +137,10 @@ class TestQuantizeRoundTrip:
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8) model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan.loading import load_wan_model from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model( loaded = load_wan_model(
model_path, config, model_path,
config,
quantization={"bits": 8, "group_size": 64}, quantization={"bits": 8, "group_size": 64},
) )
@@ -132,8 +152,10 @@ class TestQuantizeRoundTrip:
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4) model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model( loaded = load_wan_model(
model_path, config, model_path,
config,
quantization={"bits": 4, "group_size": 64}, quantization={"bits": 4, "group_size": 64},
) )
@@ -151,6 +173,7 @@ class TestQuantizeRoundTrip:
mx.save_safetensors(str(model_path), weights_dict) mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan.loading import load_wan_model from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None) loaded = load_wan_model(model_path, config, quantization=None)
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear) assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
@@ -161,10 +184,11 @@ class TestQuantizeRoundTrip:
# Quantized Inference Tests # Quantized Inference Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestQuantizedInference: class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4): def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
model = WanModel(config) model = WanModel(config)
nn.quantize( nn.quantize(
@@ -214,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self): def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights.""" """Sanity check: quantization should change the weights."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
mx.random.seed(42) mx.random.seed(42)
@@ -243,11 +267,12 @@ class TestQuantizedInference:
# Config Metadata Tests # Config Metadata Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestQuantizationConfig: class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path): def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json.""" """Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -270,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64 assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path): def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -291,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path): def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files.""" """Verify dual-model quantization writes both model files."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config() config = _make_tiny_config()

View File

@@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction:
d = 128 # head_dim for all Wan models d = 128 # head_dim for all Wan models
# Reference: three separate calls # Reference: three separate calls
correct = mx.concatenate([ correct = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Wrong: single call # Wrong: single call
wrong = rope_params(1024, d) wrong = rope_params(1024, d)
mx.eval(correct, wrong) mx.eval(correct, wrong)
assert correct.shape == wrong.shape assert correct.shape == wrong.shape
diff = np.abs(np.array(correct) - np.array(wrong)).max() diff = np.abs(np.array(correct) - np.array(wrong)).max()
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}" assert (
diff > 0.1
), f"Three-call and single-call should differ significantly, got max diff {diff}"
def test_each_axis_starts_at_frequency_one(self): def test_each_axis_starts_at_frequency_one(self):
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0. """Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
@@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
d = 128 d = 128
freqs = mx.concatenate([ freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs) mx.eval(freqs)
f = np.array(freqs) f = np.array(freqs)
@@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction:
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1) # At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
# Temporal axis first freq # Temporal axis first freq
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5, np.testing.assert_allclose(
err_msg="temporal[0] cos at pos 1") f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
)
# Height axis first freq (starts at index d_t) # Height axis first freq (starts at index d_t)
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5, np.testing.assert_allclose(
err_msg="height[0] cos at pos 1") f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
)
# Width axis first freq (starts at index d_t + d_h) # Width axis first freq (starts at index d_t + d_h)
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, np.testing.assert_allclose(
err_msg="width[0] cos at pos 1") f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
)
def test_height_width_frequencies_identical(self): def test_height_width_frequencies_identical(self):
"""Height and width axes should have identical frequency tables. """Height and width axes should have identical frequency tables.
@@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction:
d = 128 d = 128
d_h_dim = 2 * (d // 6) # 42 d_h_dim = 2 * (d // 6) # 42
freqs = mx.concatenate([ freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, d_h_dim), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim), rope_params(1024, d_h_dim),
], axis=1) rope_params(1024, d_h_dim),
],
axis=1,
)
mx.eval(freqs) mx.eval(freqs)
f = np.array(freqs) f = np.array(freqs)
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
d_t = half_d - 2 * (half_d // 3) d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3 d_h = half_d // 3
height_freqs = f[:, d_t:d_t + d_h] height_freqs = f[:, d_t : d_t + d_h]
width_freqs = f[:, d_t + d_h:] width_freqs = f[:, d_t + d_h :]
np.testing.assert_array_equal(height_freqs, width_freqs) np.testing.assert_array_equal(height_freqs, width_freqs)
def test_frequency_range_per_axis(self): def test_frequency_range_per_axis(self):
@@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
d = 128 d = 128
freqs = mx.concatenate([ freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs) mx.eval(freqs)
f = np.array(freqs) f = np.array(freqs)
@@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction:
pos1_h = f[1, d_t, 0] # height first freq pos1_h = f[1, d_t, 0] # height first freq
pos1_w = f[1, d_t + d_h, 0] # width first freq pos1_w = f[1, d_t + d_h, 0] # width first freq
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}" assert (
pos1_t > 0.5
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}" assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}" assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
@@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction:
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4) freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16 d = head_dim # 16
freqs_manual = mx.concatenate([ freqs_manual = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs_model, freqs_manual) mx.eval(freqs_model, freqs_manual)
np.testing.assert_array_equal( np.testing.assert_array_equal(
np.array(freqs_model), np.array(freqs_manual), np.array(freqs_model),
err_msg="WanModel.freqs should use three-call construction" np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction",
) )
def test_model_freqs_14b_dimensions(self): def test_model_freqs_14b_dimensions(self):
@@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
d = 128 d = 128
freqs = mx.concatenate([ freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs [
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
], axis=1) rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
],
axis=1,
)
mx.eval(freqs) mx.eval(freqs)
assert freqs.shape == (1024, 64, 2) assert freqs.shape == (1024, 64, 2)
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
@pytest.fixture @pytest.fixture
def has_torch(self): def has_torch(self):
try: try:
import torch pass
return True return True
except ImportError: except ImportError:
pytest.skip("PyTorch not installed") pytest.skip("PyTorch not installed")
@@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference:
def test_freqs_match_pytorch_reference(self, has_torch): def test_freqs_match_pytorch_reference(self, has_torch):
"""Numerically compare MLX and PyTorch frequency tables.""" """Numerically compare MLX and PyTorch frequency tables."""
import torch import torch
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
d = 128 d = 128
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
def pt_rope_params(max_seq_len, dim, theta=10000): def pt_rope_params(max_seq_len, dim, theta=10000):
freqs = torch.outer( freqs = torch.outer(
torch.arange(max_seq_len), torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) 1.0
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs) freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs return freqs
ref = torch.cat([ ref = torch.cat(
pt_rope_params(1024, d - 4 * (d // 6)), [
pt_rope_params(1024, 2 * (d // 6)), pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)), pt_rope_params(1024, 2 * (d // 6)),
], dim=1) pt_rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
# MLX # MLX
ours = mx.concatenate([ ours = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(ours) mx.eval(ours)
our_cos = np.array(ours[:, :, 0]) our_cos = np.array(ours[:, :, 0])
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
ref_cos = ref.real.float().numpy() ref_cos = ref.real.float().numpy()
ref_sin = ref.imag.float().numpy() ref_sin = ref.imag.float().numpy()
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6, np.testing.assert_allclose(
err_msg="cos mismatch vs PyTorch reference") our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6, )
err_msg="sin mismatch vs PyTorch reference") np.testing.assert_allclose(
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
)
class TestRoPEApplyWithCorrectFreqs: class TestRoPEApplyWithCorrectFreqs:
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug: This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions. height/width frequencies were too low to distinguish nearby positions.
""" """
from mlx_video.models.wan.rope import rope_params, rope_apply from mlx_video.models.wan.rope import rope_apply, rope_params
d = 128 d = 128
freqs = mx.concatenate([ freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 1, 4 B, N = 1, 4
F, H, W = 1, 4, 4 F, H, W = 1, 4, 4
@@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs:
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04 # Max diff should be >0.5 for both axes. With the bug, height was ~0.04
# and width was ~0.002. With correct freqs, both are ~1.3. # and width was ~0.002. With correct freqs, both are ~1.3.
assert height_diff > 0.5, ( assert (
f"Adjacent height positions should differ significantly, got {height_diff:.4f}" height_diff > 0.5
) ), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
assert width_diff > 0.5, ( assert (
f"Adjacent width positions should differ significantly, got {width_diff:.4f}" width_diff > 0.5
) ), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
# Height and width should have identical frequency tables → same diffs # Height and width should have identical frequency tables → same diffs
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5, np.testing.assert_allclose(
err_msg="Height and width should use identical frequency tables") height_diff,
width_diff,
rtol=1e-5,
err_msg="Height and width should use identical frequency tables",
)
def test_precomputed_matches_online(self): def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path.""" """rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
@@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs:
) )
d = 128 d = 128
freqs = mx.concatenate([ freqs = mx.concatenate(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)),
], axis=1) rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 2, 4 B, N = 2, 4
F, H, W = 2, 3, 4 F, H, W = 2, 3, 4
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
mx.eval(out_online, out_precomp) mx.eval(out_online, out_precomp)
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(out_online), np.array(out_precomp), atol=1e-5, np.array(out_online),
err_msg="Precomputed and online RoPE should match" np.array(out_precomp),
atol=1e-5,
err_msg="Precomputed and online RoPE should match",
) )

View File

@@ -6,14 +6,15 @@ import mlx.core as mx
import numpy as np import numpy as np
import pytest import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Euler Scheduler Tests # Euler Scheduler Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestFlowMatchEulerScheduler: class TestFlowMatchEulerScheduler:
def test_initialization(self): def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000 assert sched.num_train_timesteps == 1000
assert sched.timesteps is None assert sched.timesteps is None
@@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler:
def test_set_timesteps(self): def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0) sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas) mx.eval(sched.timesteps, sched.sigmas)
@@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler:
def test_timesteps_decreasing(self): def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0) sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps) mx.eval(sched.timesteps)
@@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler:
def test_sigmas_decreasing(self): def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0) sched.set_timesteps(20, shift=1.0)
mx.eval(sched.sigmas) mx.eval(sched.sigmas)
@@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler:
def test_terminal_sigma_is_zero(self): def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
mx.eval(sched.sigmas) mx.eval(sched.sigmas)
@@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self): def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values.""" """Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler() sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler() sched2 = FlowMatchEulerScheduler()
sched1.set_timesteps(20, shift=1.0) sched1.set_timesteps(20, shift=1.0)
@@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler:
def test_step_euler(self): def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0) sched.set_timesteps(10, shift=1.0)
mx.eval(sched.sigmas) mx.eval(sched.sigmas)
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
# Euler: x_next = x + (sigma_next - sigma) * v # Euler: x_next = x + (sigma_next - sigma) * v
expected = 1.0 + (sigma_next - sigma) * 0.5 expected = 1.0 + (sigma_next - sigma) * 0.5
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(result).flatten()[0], expected, rtol=1e-4, np.array(result).flatten()[0],
expected,
rtol=1e-4,
) )
def test_step_index_increments(self): def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
assert sched._step_index == 0 assert sched._step_index == 0
@@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler:
def test_reset(self): def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) sample = mx.ones((1, 1, 1, 1, 1))
@@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50]) @pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps): def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0) sched.set_timesteps(steps, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas) mx.eval(sched.timesteps, sched.sigmas)
@@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self): def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged.""" """Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2)) sample = mx.ones((1, 2, 1, 2, 2))
@@ -141,22 +154,26 @@ class TestComputeSigmas:
def test_length(self): def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0) sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self): def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0) sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0 assert sigmas[-1] == 0.0
def test_starts_near_one(self): def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0) sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0) # Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3) np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self): def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0) sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0) assert np.all(np.diff(sigmas) <= 0)
@@ -169,6 +186,7 @@ class TestComputeSigmas:
shift is applied only once (single-shift). shift is applied only once (single-shift).
""" """
from mlx_video.models.wan.scheduler import _compute_sigmas from mlx_video.models.wan.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000 steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N) sigmas = _compute_sigmas(steps, shift, N)
# Official single-shift: unshifted bounds, then shift once # Official single-shift: unshifted bounds, then shift once
@@ -183,6 +201,7 @@ class TestComputeSigmas:
def test_shift_one_is_near_linear(self): def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0) sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule) # With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
# so schedule is nearly linear from ~0.999 to 0 # so schedule is nearly linear from ~0.999 to 0
@@ -196,6 +215,7 @@ class TestComputeSigmas:
FlowMatchEulerScheduler, FlowMatchEulerScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
) )
scheds = [ scheds = [
FlowMatchEulerScheduler(1000), FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000), FlowDPMPP2MScheduler(1000),
@@ -214,6 +234,7 @@ class TestComputeSigmas:
FlowMatchEulerScheduler, FlowMatchEulerScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
) )
scheds = [ scheds = [
FlowMatchEulerScheduler(1000), FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000), FlowDPMPP2MScheduler(1000),
@@ -235,12 +256,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler: class TestFlowDPMPP2MScheduler:
def test_initialization(self): def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000 assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True assert sched.lower_order_final is True
def test_set_timesteps(self): def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas) mx.eval(sched.timesteps, sched.sigmas)
@@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler:
def test_step_index_increments(self): def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 4, 1, 2, 2)) sample = mx.ones((1, 4, 1, 2, 2))
@@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler:
def test_reset(self): def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) sample = mx.ones((1, 1, 1, 1, 1))
@@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self): def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output.""" """Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0) sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2)) sample = mx.ones((1, 2, 1, 2, 2))
@@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self): def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available).""" """First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 2, 4, 4)) sample = mx.random.normal((1, 4, 2, 4, 4))
@@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self): def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction.""" """After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2)) sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
x0_after_second = sched._prev_x0 x0_after_second = sched._prev_x0
assert x0_after_second is not None assert x0_after_second is not None
# The stored x0 should differ from the first step's # The stored x0 should differ from the first step's
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6) assert not np.allclose(
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
)
def test_denoise_to_target(self): def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver.""" """Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4)) target = mx.zeros((1, 2, 1, 4, 4))
@@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50]) @pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps): def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0) sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas) mx.eval(sched.timesteps, sched.sigmas)
@@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self): def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly.""" """When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0 sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
@@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler: class TestFlowUniPCScheduler:
def test_initialization(self): def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000 assert sched.num_train_timesteps == 1000
assert sched.solver_order == 2 assert sched.solver_order == 2
@@ -369,6 +403,7 @@ class TestFlowUniPCScheduler:
def test_set_timesteps(self): def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0) sched.set_timesteps(30, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas) mx.eval(sched.timesteps, sched.sigmas)
@@ -377,6 +412,7 @@ class TestFlowUniPCScheduler:
def test_step_index_increments(self): def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) sample = mx.ones((1, 1, 1, 1, 1))
@@ -387,6 +423,7 @@ class TestFlowUniPCScheduler:
def test_reset(self): def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) sample = mx.ones((1, 1, 1, 1, 1))
@@ -399,6 +436,7 @@ class TestFlowUniPCScheduler:
def test_full_loop_finite(self): def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0) sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2)) sample = mx.ones((1, 2, 1, 2, 2))
@@ -411,6 +449,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self): def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history).""" """First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True) sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2)) sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -424,6 +463,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self): def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled.""" """Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True) sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 4, 4)) sample = mx.random.normal((1, 2, 1, 4, 4))
@@ -436,6 +476,7 @@ class TestFlowUniPCScheduler:
def test_denoise_to_target(self): def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4)) target = mx.zeros((1, 2, 1, 4, 4))
@@ -450,6 +491,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50]) @pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps): def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0) sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas) mx.eval(sched.timesteps, sched.sigmas)
@@ -459,6 +501,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self): def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error.""" """Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0]) sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 2, 2)) sample = mx.ones((1, 1, 1, 2, 2))
@@ -471,6 +514,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self): def test_solver_order_3(self):
"""Order 3 should work without error.""" """Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True) sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 2, 2)) sample = mx.random.normal((1, 2, 1, 2, 2))
@@ -483,6 +527,7 @@ class TestFlowUniPCScheduler:
def test_corrector_rhos_c_not_hardcoded(self): def test_corrector_rhos_c_not_hardcoded(self):
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5.""" """Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
import math import math
# For 50-step schedule with shift=5.0, order 2 corrector at step 5: # For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5 # rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5 # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
rhos_c = np.linalg.solve(R, b) rhos_c = np.linalg.solve(R, b)
# History weight should be small (~0.07-0.09), not 0.5 # History weight should be small (~0.07-0.09), not 0.5
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large" assert (
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive" rhos_c[0] < 0.15
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert (
rhos_c[0] > 0.0
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
# D1_t weight should be ~0.42-0.45, not 0.5 # D1_t weight should be ~0.42-0.45, not 0.5
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range" assert (
0.3 < rhos_c[1] < 0.5
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Scheduler Coherence Tests # Scheduler Coherence Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestSchedulerCoherence: class TestSchedulerCoherence:
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results. """Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
results[name] = np.array(r) results[name] = np.array(r)
np.testing.assert_allclose( np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5, results["dpm++"],
results["euler"],
atol=1e-5,
err_msg="DPM++ step 0 should match Euler", err_msg="DPM++ step 0 should match Euler",
) )
np.testing.assert_allclose( np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5, results["unipc"],
results["euler"],
atol=1e-5,
err_msg="UniPC step 0 should match Euler", err_msg="UniPC step 0 should match Euler",
) )
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise) unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
mx.eval(euler_r, dpm_r, unipc_r) mx.eval(euler_r, dpm_r, unipc_r)
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(dpm_r), np.array(euler_r), atol=1e-5, np.array(dpm_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}", err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
) )
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(unipc_r), np.array(euler_r), atol=1e-5, np.array(unipc_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"UniPC step 0 differs from Euler at shift={shift}", err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
) )
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
latents = sched.step(v, sched.timesteps[i], latents) latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents) mx.eval(latents)
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(latents), 0.0, atol=1e-3, np.array(latents),
0.0,
atol=1e-3,
err_msg=f"{name} did not converge to target with oracle", err_msg=f"{name} did not converge to target with oracle",
) )
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
# Higher-order solvers should not be significantly worse than Euler # Higher-order solvers should not be significantly worse than Euler
# (add small epsilon to handle near-zero errors from floating point noise) # (add small epsilon to handle near-zero errors from floating point noise)
eps = 1e-6 eps = 1e-6
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, ( assert (
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}" errors["dpm++"] <= errors["euler"] * 1.5 + eps
) ), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, ( assert (
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}" errors["unipc"] <= errors["euler"] * 1.5 + eps
) ), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
def test_multistep_trajectory_similar_magnitude(self): def test_multistep_trajectory_similar_magnitude(self):
"""Over a full denoising loop with constant velocity, all solvers """Over a full denoising loop with constant velocity, all solvers
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
# All solvers should produce results within the same order of magnitude # All solvers should produce results within the same order of magnitude
vals = list(final_means.values()) vals = list(final_means.values())
ratio = max(vals) / max(min(vals), 1e-10) ratio = max(vals) / max(min(vals), 1e-10)
assert ratio < 10.0, ( assert (
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}" ratio < 10.0
) ), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
def test_intermediate_values_finite(self): def test_intermediate_values_finite(self):
"""Every intermediate latent value must be finite for all solvers.""" """Every intermediate latent value must be finite for all solvers."""
@@ -712,9 +774,9 @@ class TestSchedulerCoherence:
vel = mx.random.normal(shape) vel = mx.random.normal(shape)
latents = sched.step(vel, sched.timesteps[i], latents) latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents) mx.eval(latents)
assert np.isfinite(np.array(latents)).all(), ( assert np.isfinite(
f"{name} produced non-finite values at step {i}" np.array(latents)
) ).all(), f"{name} produced non-finite values at step {i}"
def test_lambda_boundary_values(self): def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0.""" """_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
@@ -724,17 +786,17 @@ class TestSchedulerCoherence:
) )
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler): for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
assert cls._lambda(1.0) == -math.inf, ( assert (
f"{cls.__name__}._lambda(1.0) should be -inf" cls._lambda(1.0) == -math.inf
) ), f"{cls.__name__}._lambda(1.0) should be -inf"
assert cls._lambda(0.0) == math.inf, ( assert (
f"{cls.__name__}._lambda(0.0) should be +inf" cls._lambda(0.0) == math.inf
) ), f"{cls.__name__}._lambda(0.0) should be +inf"
# Interior values should be finite # Interior values should be finite
lam = cls._lambda(0.5) lam = cls._lambda(0.5)
assert math.isfinite(lam) and lam == 0.0, ( assert (
f"{cls.__name__}._lambda(0.5) should be 0.0" math.isfinite(lam) and lam == 0.0
) ), f"{cls.__name__}._lambda(0.5) should be 0.0"
def test_lambda_monotonically_decreasing(self): def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR).""" """_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
result = scheds[name].step(vel, scheds[name].timesteps[0], sample) result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
mx.eval(result) mx.eval(result)
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(result), np.array(expected), atol=5e-4, np.array(result),
np.array(expected),
atol=5e-4,
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})", err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
) )
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
results[name] = np.array(r) results[name] = np.array(r)
np.testing.assert_allclose( np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5, results["dpm++"],
results["euler"],
atol=1e-5,
) )
np.testing.assert_allclose( np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5, results["unipc"],
results["euler"],
atol=1e-5,
) )
def test_dpmpp_unipc_agree_on_step1(self): def test_dpmpp_unipc_agree_on_step1(self):
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2) shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape) noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler): for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
sched = cls() sched = cls()
@@ -857,14 +928,19 @@ class TestSchedulerCoherence:
mx.eval(latents) mx.eval(latents)
result2 = np.array(latents) result2 = np.array(latents)
np.testing.assert_allclose(result1, result2, atol=1e-5, np.testing.assert_allclose(
err_msg=f"{cls.__name__} not reproducible after reset()") result1,
result2,
atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()",
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# UniPC Corrector Default Tests # UniPC Corrector Default Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestUniPCCorrectorDefault: class TestUniPCCorrectorDefault:
"""Tests that the UniPC corrector is enabled by default, """Tests that the UniPC corrector is enabled by default,
matching official FlowUniPCMultistepScheduler behavior.""" matching official FlowUniPCMultistepScheduler behavior."""
@@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self): def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled.""" """Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
assert sched._use_corrector is True assert sched._use_corrector is True
def test_corrector_affects_output(self): def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1.""" """Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42) mx.random.seed(42)
shape = (1, 4, 1, 4, 4) shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape) noise = mx.random.normal(shape)
@@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self): def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting.""" """Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42) mx.random.seed(42)
shape = (1, 4, 1, 4, 4) shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape) noise = mx.random.normal(shape)

View File

@@ -3,16 +3,16 @@
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# T5 Encoder Tests # T5 Encoder Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestT5LayerNorm: class TestT5LayerNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(64) norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
out = norm(x) out = norm(x)
@@ -22,6 +22,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self): def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1.""" """After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(128) norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0 x = mx.random.normal((1, 5, 128)) * 5.0
out = norm(x) out = norm(x)
@@ -35,6 +36,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding: class TestT5RelativeEmbedding:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10) out = rel_emb(10, 10)
mx.eval(out) mx.eval(out)
@@ -42,6 +44,7 @@ class TestT5RelativeEmbedding:
def test_asymmetric_lengths(self): def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12) out = rel_emb(8, 12)
mx.eval(out) mx.eval(out)
@@ -50,6 +53,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self): def test_symmetry(self):
"""Position bias should have structure (not all zeros/random).""" """Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2) rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6) out = rel_emb(6, 6)
mx.eval(out) mx.eval(out)
@@ -64,6 +68,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention: class TestT5Attention:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
out = attn(x) out = attn(x)
@@ -73,12 +78,14 @@ class TestT5Attention:
def test_no_scaling(self): def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure.""" """T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention) # No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale") assert not hasattr(attn, "scale")
def test_with_position_bias(self): def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4) rel_emb = T5RelativeEmbedding(32, 4)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
@@ -89,6 +96,7 @@ class TestT5Attention:
def test_with_mask(self): def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
mask = mx.ones((1, 10)) mask = mx.ones((1, 10))
@@ -101,6 +109,7 @@ class TestT5Attention:
class TestT5FeedForward: class TestT5FeedForward:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256) ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
out = ffn(x) out = ffn(x)
@@ -110,6 +119,7 @@ class TestT5FeedForward:
def test_gated_structure(self): def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x).""" """T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64) ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj") assert hasattr(ffn, "gate_proj")
assert hasattr(ffn, "fc1") assert hasattr(ffn, "fc1")
@@ -122,9 +132,16 @@ class TestT5Encoder:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, vocab_size=100,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
) )
ids = mx.array([[1, 5, 10, 0, 0]]) ids = mx.array([[1, 5, 10, 0, 0]])
mask = mx.array([[1, 1, 1, 0, 0]]) mask = mx.array([[1, 1, 1, 0, 0]])
@@ -134,9 +151,16 @@ class TestT5Encoder:
def test_shared_pos(self): def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, vocab_size=100,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True, dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=True,
) )
assert encoder.pos_embedding is not None assert encoder.pos_embedding is not None
for block in encoder.blocks: for block in encoder.blocks:
@@ -144,9 +168,16 @@ class TestT5Encoder:
def test_per_layer_pos(self): def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, vocab_size=100,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
) )
assert encoder.pos_embedding is None assert encoder.pos_embedding is None
for block in encoder.blocks: for block in encoder.blocks:
@@ -154,18 +185,32 @@ class TestT5Encoder:
def test_param_count(self): def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, vocab_size=100,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
) )
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters())) num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
assert num_params > 0 assert num_params > 0
def test_without_mask(self): def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128, vocab_size=100,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False, dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
) )
ids = mx.array([[1, 5, 10]]) ids = mx.array([[1, 5, 10]])
out = encoder(ids) out = encoder(ids)

View File

@@ -2,13 +2,11 @@
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
from mlx_video.models.ltx.video_vae.tiling import ( from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig, TilingConfig,
decode_with_tiling, decode_with_tiling,
split_in_spatial, split_in_spatial,
split_in_temporal,
) )
@@ -49,16 +47,24 @@ class TestNonCausalTemporal:
# Causal: 1 + (4-1)*4 = 13 # Causal: 1 + (4-1)*4 = 13
out_causal = decode_with_tiling( out_causal = decode_with_tiling(
dummy_decoder_causal, latents, config, dummy_decoder_causal,
spatial_scale=scale, temporal_scale=scale, causal_temporal=True, latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=True,
) )
mx.eval(out_causal) mx.eval(out_causal)
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13 assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
# Non-causal: 4*4 = 16 # Non-causal: 4*4 = 16
out_noncausal = decode_with_tiling( out_noncausal = decode_with_tiling(
dummy_decoder_noncausal, latents, config, dummy_decoder_noncausal,
spatial_scale=scale, temporal_scale=scale, causal_temporal=False, latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=False,
) )
mx.eval(out_noncausal) mx.eval(out_noncausal)
assert out_noncausal.shape[2] == t * scale # 16 assert out_noncausal.shape[2] == t * scale # 16
@@ -100,9 +106,9 @@ class TestWan22TiledDecoding:
mx.eval(out_tiled) mx.eval(out_tiled)
# Both should produce the same shape # Both should produce the same shape
assert out_regular.shape == out_tiled.shape, ( assert (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" out_regular.shape == out_tiled.shape
) ), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self): def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as __call__.""" """When input is smaller than tile size, decode_tiled should produce same output as __call__."""
@@ -120,8 +126,10 @@ class TestWan22TiledDecoding:
mx.eval(out_tiled) mx.eval(out_tiled)
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled), np.array(out_regular),
rtol=1e-4, atol=1e-4, np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs", err_msg="Tiled decode should match regular decode for small inputs",
) )
@@ -152,9 +160,9 @@ class TestWan21TiledDecoding:
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default()) out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled) mx.eval(out_tiled)
assert out_regular.shape == out_tiled.shape, ( assert (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" out_regular.shape == out_tiled.shape
) ), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self): def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as decode.""" """When input is smaller than tile size, decode_tiled should produce same output as decode."""
@@ -171,8 +179,10 @@ class TestWan21TiledDecoding:
mx.eval(out_tiled) mx.eval(out_tiled)
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled), np.array(out_regular),
rtol=1e-4, atol=1e-4, np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs", err_msg="Tiled decode should match regular decode for small inputs",
) )
@@ -185,8 +195,13 @@ class TestWan21TemporalScale:
from mlx_video.models.wan.vae import Decoder3d from mlx_video.models.wan.vae import Decoder3d
# Small decoder for fast test # Small decoder for fast test
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1, dec = Decoder3d(
temporal_upsample=[True, True, False]) dim=16,
z_dim=4,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temporal_upsample=[True, True, False],
)
mx.eval(dec.parameters()) mx.eval(dec.parameters())
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3 x = mx.random.normal((1, 4, 3, 4, 4)) # T=3

View File

@@ -2,16 +2,16 @@
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Transformer Block Tests # Transformer Block Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestWanFFN: class TestWanFFN:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.transformer import WanFFN from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(64, 256) ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
out = ffn(x) out = ffn(x)
@@ -21,6 +21,7 @@ class TestWanFFN:
def test_gelu_activation(self): def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity).""" """FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan.transformer import WanFFN from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(32, 128) ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0 x = mx.ones((1, 1, 32)) * 2.0
out1 = ffn(x) out1 = ffn(x)
@@ -39,10 +40,13 @@ class TestWanAttentionBlock:
self.num_heads = 4 self.num_heads = 4
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock( block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads, self.dim,
self.ffn_dim,
self.num_heads,
cross_attn_norm=True, cross_attn_norm=True,
) )
B, L = 1, 24 B, L = 1, 24
@@ -53,37 +57,49 @@ class TestWanAttentionBlock:
freqs = rope_params(1024, self.dim // self.num_heads) freqs = rope_params(1024, self.dim // self.num_heads)
out = block( out = block(
x, e, seq_lens=[L], grid_sizes=[(F, H, W)], x,
freqs=freqs, context=context, e,
seq_lens=[L],
grid_sizes=[(F, H, W)],
freqs=freqs,
context=context,
) )
mx.eval(out) mx.eval(out)
assert out.shape == (B, L, self.dim) assert out.shape == (B, L, self.dim)
def test_modulation_shape(self): def test_modulation_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim) assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self): def test_with_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock( block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads, self.dim,
self.ffn_dim,
self.num_heads,
cross_attn_norm=True, cross_attn_norm=True,
) )
assert block.norm3 is not None assert block.norm3 is not None
def test_without_cross_attn_norm(self): def test_without_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock( block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads, self.dim,
self.ffn_dim,
self.num_heads,
cross_attn_norm=False, cross_attn_norm=False,
) )
assert block.norm3 is None assert block.norm3 is None
def test_residual_connection(self): def test_residual_connection(self):
"""Output should differ from zero even with small random init.""" """Output should differ from zero even with small random init."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8 B, L = 1, 8
F, H, W = 2, 2, 2 F, H, W = 2, 2, 2
@@ -102,6 +118,7 @@ class TestWanAttentionBlock:
# Float32 Modulation Precision Tests # Float32 Modulation Precision Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestFloat32Modulation: class TestFloat32Modulation:
"""Tests that modulation/gate operations are computed in float32, """Tests that modulation/gate operations are computed in float32,
matching official torch.amp.autocast('cuda', dtype=torch.float32).""" matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
@@ -113,13 +130,15 @@ class TestFloat32Modulation:
def test_block_modulation_in_float32(self): def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32.""" """Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan.transformer import WanAttentionBlock from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True) block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32 assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self): def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32.""" """Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4) block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8 B, L = 1, 8
x = mx.random.normal((B, L, self.dim)) x = mx.random.normal((B, L, self.dim))
@@ -135,6 +154,7 @@ class TestFloat32Modulation:
def test_head_modulation_float32(self): def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input.""" """Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan.model import Head from mlx_video.models.wan.model import Head
head = Head(self.dim, 4, (1, 2, 2)) head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim)) x = mx.random.normal((1, 8, self.dim))
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16) e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
@@ -145,6 +165,7 @@ class TestFloat32Modulation:
def test_model_time_embedding_float32(self): def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32.""" """sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d from mlx_video.models.wan.model import sinusoidal_embedding_1d
t = mx.array([500.0]) t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t) emb = sinusoidal_embedding_1d(256, t)
mx.eval(emb) mx.eval(emb)
@@ -153,6 +174,7 @@ class TestFloat32Modulation:
def test_model_per_token_time_embedding_float32(self): def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32.""" """Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d from mlx_video.models.wan.model import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4] t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t) emb = sinusoidal_embedding_1d(256, t)
mx.eval(emb) mx.eval(emb)

View File

@@ -4,16 +4,16 @@ import math
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# VAE 2.1 Tests # VAE 2.1 Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestCausalConv3d: class TestCausalConv3d:
def test_output_shape_stride1(self): def test_output_shape_stride1(self):
from mlx_video.models.wan.vae import CausalConv3d from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1) conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights # Initialize weights
conv.weight = mx.random.normal(conv.weight.shape) * 0.02 conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -29,6 +29,7 @@ class TestCausalConv3d:
def test_output_shape_kernel1(self): def test_output_shape_kernel1(self):
from mlx_video.models.wan.vae import CausalConv3d from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0) conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02 conv.weight = mx.random.normal(conv.weight.shape) * 0.02
x = mx.random.normal((1, 4, 2, 4, 4)) x = mx.random.normal((1, 4, 2, 4, 4))
@@ -39,6 +40,7 @@ class TestCausalConv3d:
def test_causal_padding(self): def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future.""" """Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan.vae import CausalConv3d from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1) conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1 conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros((2,)) conv.bias = mx.zeros((2,))
@@ -55,6 +57,7 @@ class TestCausalConv3d:
class TestResidualBlock: class TestResidualBlock:
def test_same_dim(self): def test_same_dim(self):
from mlx_video.models.wan.vae import ResidualBlock from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4)) x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x) out = block(x)
@@ -63,6 +66,7 @@ class TestResidualBlock:
def test_different_dim(self): def test_different_dim(self):
from mlx_video.models.wan.vae import ResidualBlock from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4)) x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x) out = block(x)
@@ -71,11 +75,13 @@ class TestResidualBlock:
def test_shortcut_exists_when_dims_differ(self): def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan.vae import ResidualBlock from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
assert block.shortcut is not None assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self): def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan.vae import ResidualBlock from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
assert block.shortcut is None assert block.shortcut is None
@@ -83,6 +89,7 @@ class TestResidualBlock:
class TestAttentionBlock: class TestAttentionBlock:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.vae import AttentionBlock from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8) block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4)) x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x) out = block(x)
@@ -91,6 +98,7 @@ class TestAttentionBlock:
def test_residual_connection(self): def test_residual_connection(self):
from mlx_video.models.wan.vae import AttentionBlock from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8) block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3)) x = mx.random.normal((1, 8, 1, 3, 3))
out = block(x) out = block(x)
@@ -102,13 +110,15 @@ class TestAttentionBlock:
class TestWanVAE: class TestWanVAE:
def test_instantiation(self): def test_instantiation(self):
from mlx_video.models.wan.vae import WanVAE from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16) vae = WanVAE(z_dim=16)
assert vae.z_dim == 16 assert vae.z_dim == 16
assert vae.mean.shape == (16,) assert vae.mean.shape == (16,)
assert vae.std.shape == (16,) assert vae.std.shape == (16,)
def test_normalization_stats(self): def test_normalization_stats(self):
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16 assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16 assert len(VAE_STD) == 16
assert all(s > 0 for s in VAE_STD) assert all(s > 0 for s in VAE_STD)
@@ -124,6 +134,7 @@ class TestVAE22CausalConv3d:
def test_output_shape_k3(self): def test_output_shape_k3(self):
from mlx_video.models.wan.vae22 import CausalConv3d from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1) conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C] x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
out = conv(x) out = conv(x)
@@ -132,6 +143,7 @@ class TestVAE22CausalConv3d:
def test_output_shape_k1(self): def test_output_shape_k1(self):
from mlx_video.models.wan.vae22 import CausalConv3d from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1) conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = conv(x) out = conv(x)
@@ -141,6 +153,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self): def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0.""" """Output at t=0 should not depend on t>0."""
from mlx_video.models.wan.vae22 import CausalConv3d from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1) conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1 conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros(conv.bias.shape) conv.bias = mx.zeros(conv.bias.shape)
@@ -151,10 +164,13 @@ class TestVAE22CausalConv3d:
t0_ref = np.array(out_zero[0, 0]) t0_ref = np.array(out_zero[0, 0])
# Modify t=2..3; output at t=0 should be unchanged # Modify t=2..3; output at t=0 should be unchanged
x_mod = mx.concatenate([ x_mod = mx.concatenate(
x[:, :2], [
mx.ones((1, 2, 4, 4, 2)), x[:, :2],
], axis=1) mx.ones((1, 2, 4, 4, 2)),
],
axis=1,
)
out_mod = conv(x_mod) out_mod = conv(x_mod)
mx.eval(out_mod) mx.eval(out_mod)
t0_mod = np.array(out_mod[0, 0]) t0_mod = np.array(out_mod[0, 0])
@@ -163,6 +179,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self): def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C].""" """Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan.vae22 import CausalConv3d from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1) conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4)) x = mx.random.normal((2, 3, 6, 6, 4))
out = conv(x) out = conv(x)
@@ -175,6 +192,7 @@ class TestRMSNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.vae22 import RMS_norm from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(16) norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16)) x = mx.random.normal((2, 4, 4, 4, 16))
out = norm(x) out = norm(x)
@@ -184,6 +202,7 @@ class TestRMSNorm:
def test_l2_normalization(self): def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim).""" """RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan.vae22 import RMS_norm from mlx_video.models.wan.vae22 import RMS_norm
dim = 32 dim = 32
norm = RMS_norm(dim) norm = RMS_norm(dim)
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
@@ -197,6 +216,7 @@ class TestRMSNorm:
def test_scale_invariant(self): def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property).""" """Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan.vae22 import RMS_norm from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(8) norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8)) x = mx.random.normal((1, 1, 1, 1, 8))
out1 = norm(x) out1 = norm(x)
@@ -207,6 +227,7 @@ class TestRMSNorm:
def test_gamma_effect(self): def test_gamma_effect(self):
"""Non-unit gamma should scale output.""" """Non-unit gamma should scale output."""
from mlx_video.models.wan.vae22 import RMS_norm from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(4) norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0]) norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
x = mx.ones((1, 1, 1, 1, 4)) x = mx.ones((1, 1, 1, 1, 4))
@@ -221,6 +242,7 @@ class TestDupUp3D:
def test_spatial_only(self): def test_spatial_only(self):
from mlx_video.models.wan.vae22 import DupUp3D from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2) up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8)) x = mx.random.normal((1, 3, 4, 4, 8))
out = up(x) out = up(x)
@@ -229,6 +251,7 @@ class TestDupUp3D:
def test_temporal_and_spatial(self): def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import DupUp3D from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2) up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16)) x = mx.random.normal((1, 3, 4, 4, 16))
out = up(x) out = up(x)
@@ -237,6 +260,7 @@ class TestDupUp3D:
def test_first_chunk_trims(self): def test_first_chunk_trims(self):
from mlx_video.models.wan.vae22 import DupUp3D from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2) up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8)) x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False) out_normal = up(x, first_chunk=False)
@@ -248,6 +272,7 @@ class TestDupUp3D:
def test_no_temporal_first_chunk_noop(self): def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan.vae22 import DupUp3D from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2) up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8)) x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False) out_normal = up(x, first_chunk=False)
@@ -262,6 +287,7 @@ class TestVAE22Resample:
def test_upsample2d_shape(self): def test_upsample2d_shape(self):
from mlx_video.models.wan.vae22 import Resample from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample2d") r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -271,6 +297,7 @@ class TestVAE22Resample:
def test_upsample3d_shape(self): def test_upsample3d_shape(self):
from mlx_video.models.wan.vae22 import Resample from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d") r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -280,6 +307,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk(self): def test_upsample3d_first_chunk(self):
from mlx_video.models.wan.vae22 import Resample from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d") r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -291,6 +319,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self): def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample.""" """Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan.vae22 import Resample from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d") r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 1, 4, 4, 8)) x = mx.random.normal((1, 1, 4, 4, 8))
@@ -308,6 +337,7 @@ class TestVAE22Resample:
the first input frame (not on time_conv parameters). the first input frame (not on time_conv parameters).
""" """
from mlx_video.models.wan.vae22 import Resample from mlx_video.models.wan.vae22 import Resample
C = 8 C = 8
r = Resample(C, "upsample3d") r = Resample(C, "upsample3d")
# Set time_conv weights to large values so its effect is detectable # Set time_conv weights to large values so its effect is detectable
@@ -334,8 +364,9 @@ class TestVAE22Resample:
# Compare first output frame to reference # Compare first output frame to reference
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C) first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
mx.eval(first_out) mx.eval(first_out)
assert mx.allclose(first_out, ref, atol=1e-5).item(), \ assert mx.allclose(
"First frame should bypass time_conv and match spatial-only upsample" first_out, ref, atol=1e-5
).item(), "First frame should bypass time_conv and match spatial-only upsample"
class TestVAE22ResidualBlock: class TestVAE22ResidualBlock:
@@ -343,6 +374,7 @@ class TestVAE22ResidualBlock:
def test_same_dim(self): def test_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x) out = block(x)
@@ -351,6 +383,7 @@ class TestVAE22ResidualBlock:
def test_different_dim(self): def test_different_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x) out = block(x)
@@ -359,11 +392,13 @@ class TestVAE22ResidualBlock:
def test_shortcut_when_dims_differ(self): def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan.vae22 import ResidualBlock from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
assert block.shortcut is not None assert block.shortcut is not None
def test_no_shortcut_same_dim(self): def test_no_shortcut_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
assert block.shortcut is None assert block.shortcut is None
@@ -374,6 +409,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self): def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them).""" """Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan.vae22 import ResidualBlockLayers from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8) block = ResidualBlockLayers(8, 8)
params = dict(block.parameters()) params = dict(block.parameters())
# All param keys should use layer_N, not _layer_N # All param keys should use layer_N, not _layer_N
@@ -382,6 +418,7 @@ class TestResidualBlockLayers:
def test_has_expected_layers(self): def test_has_expected_layers(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16) block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm assert hasattr(block, "layer_0") # first RMS_norm
assert hasattr(block, "layer_2") # first CausalConv3d assert hasattr(block, "layer_2") # first CausalConv3d
@@ -390,6 +427,7 @@ class TestResidualBlockLayers:
def test_forward_shape(self): def test_forward_shape(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16) block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x) out = block(x)
@@ -402,6 +440,7 @@ class TestVAE22AttentionBlock:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.vae22 import AttentionBlock from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(16) block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01 block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01 block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
@@ -412,6 +451,7 @@ class TestVAE22AttentionBlock:
def test_residual_connection(self): def test_residual_connection(self):
from mlx_video.models.wan.vae22 import AttentionBlock from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(8) block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape) block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
block.proj_weight = mx.zeros(block.proj_weight.shape) block.proj_weight = mx.zeros(block.proj_weight.shape)
@@ -427,6 +467,7 @@ class TestHead22:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.vae22 import Head22 from mlx_video.models.wan.vae22 import Head22
head = Head22(16, out_channels=12) head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16)) x = mx.random.normal((1, 2, 4, 4, 16))
out = head(x) out = head(x)
@@ -436,6 +477,7 @@ class TestHead22:
def test_layer_names_no_underscore(self): def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix.""" """Head layers must not use underscore prefix."""
from mlx_video.models.wan.vae22 import Head22 from mlx_video.models.wan.vae22 import Head22
head = Head22(8) head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm assert hasattr(head, "layer_0") # RMS_norm
assert hasattr(head, "layer_2") # CausalConv3d assert hasattr(head, "layer_2") # CausalConv3d
@@ -449,6 +491,7 @@ class TestUnpatchify:
def test_basic_shape(self): def test_basic_shape(self):
from mlx_video.models.wan.vae22 import _unpatchify from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2 x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2) out = _unpatchify(x, patch_size=2)
mx.eval(out) mx.eval(out)
@@ -456,6 +499,7 @@ class TestUnpatchify:
def test_patch_size_1_noop(self): def test_patch_size_1_noop(self):
from mlx_video.models.wan.vae22 import _unpatchify from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3)) x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1) out = _unpatchify(x, patch_size=1)
mx.eval(out) mx.eval(out)
@@ -464,6 +508,7 @@ class TestUnpatchify:
def test_preserves_content(self): def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement.""" """Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan.vae22 import _unpatchify from mlx_video.models.wan.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32) x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2) out = _unpatchify(x, patch_size=2)
mx.eval(out) mx.eval(out)
@@ -477,6 +522,7 @@ class TestDenormalizeLatents:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan.vae22 import denormalize_latents from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48)) z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z) out = denormalize_latents(z)
mx.eval(out) mx.eval(out)
@@ -484,16 +530,23 @@ class TestDenormalizeLatents:
def test_custom_mean_std(self): def test_custom_mean_std(self):
from mlx_video.models.wan.vae22 import denormalize_latents from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4)) z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0]) mean = mx.array([1.0, 2.0, 3.0, 4.0])
std = mx.array([0.5, 0.5, 0.5, 0.5]) std = mx.array([0.5, 0.5, 0.5, 0.5])
out = denormalize_latents(z, mean=mean, std=std) out = denormalize_latents(z, mean=mean, std=std)
mx.eval(out) mx.eval(out)
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5] # z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5) np.testing.assert_allclose(
np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5
)
def test_uses_default_constants(self): def test_uses_default_constants(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents from mlx_video.models.wan.vae22 import (
VAE22_MEAN,
denormalize_latents,
)
# Should not raise with default constants # Should not raise with default constants
z = mx.zeros((1, 1, 1, 1, 48)) z = mx.zeros((1, 1, 1, 1, 48))
out = denormalize_latents(z) out = denormalize_latents(z)
@@ -511,12 +564,14 @@ class TestVAE22NormConstants:
def test_dimensions(self): def test_dimensions(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD) mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,) assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,) assert VAE22_STD.shape == (48,)
def test_std_positive(self): def test_std_positive(self):
from mlx_video.models.wan.vae22 import VAE22_STD from mlx_video.models.wan.vae22 import VAE22_STD
mx.eval(VAE22_STD) mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all() assert (np.array(VAE22_STD) > 0).all()
@@ -527,6 +582,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self): def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output.""" """Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder from mlx_video.models.wan.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast # Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
# Latent: [B=1, T=3, H=2, W=2, C=4] # Latent: [B=1, T=3, H=2, W=2, C=4]
@@ -542,6 +598,7 @@ class TestWan22VAEDecoder:
def test_output_clipped(self): def test_output_clipped(self):
from mlx_video.models.wan.vae22 import Wan22VAEDecoder from mlx_video.models.wan.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
out = dec(z) out = dec(z)
@@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights:
def test_skip_encoder(self): def test_skip_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"encoder.layer.weight": mx.zeros((4,)), "encoder.layer.weight": mx.zeros((4,)),
"conv1.weight": mx.zeros((4,)), "conv1.weight": mx.zeros((4,)),
@@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights:
def test_sequential_index_remapping(self): def test_sequential_index_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)), "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)), "decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
@@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights:
def test_resample_conv_remapping(self): def test_resample_conv_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)), "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)), "decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
@@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights:
def test_attention_remapping(self): def test_attention_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)), "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)), "decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
@@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights:
def test_conv3d_transpose(self): def test_conv3d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I] # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3)) w = mx.zeros((16, 8, 3, 3, 3))
weights = {"decoder.conv1.weight": w} weights = {"decoder.conv1.weight": w}
@@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights:
def test_conv2d_transpose(self): def test_conv2d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I] # Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3)) w = mx.zeros((8, 8, 3, 3))
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w} weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
@@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights:
def test_gamma_squeeze(self): def test_gamma_squeeze(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,) # gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1)) w = mx.ones((16, 1, 1, 1))
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w} weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
@@ -635,7 +699,10 @@ class TestUpResidualBlock:
def test_no_upsample(self): def test_no_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x) out = block(x)
mx.eval(out) mx.eval(out)
@@ -644,7 +711,10 @@ class TestUpResidualBlock:
def test_spatial_upsample(self): def test_spatial_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x) out = block(x)
mx.eval(out) mx.eval(out)
@@ -653,7 +723,10 @@ class TestUpResidualBlock:
def test_spatial_temporal_upsample(self): def test_spatial_temporal_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x) out = block(x)
mx.eval(out) mx.eval(out)
@@ -720,7 +793,9 @@ class TestDownResidualBlock:
def test_no_downsample(self): def test_no_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False) block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
)
x = mx.random.normal((1, 2, 8, 8, 8)) x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x) out = block(x)
mx.eval(out) mx.eval(out)
@@ -729,7 +804,9 @@ class TestDownResidualBlock:
def test_spatial_downsample(self): def test_spatial_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True) block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
)
x = mx.random.normal((1, 2, 8, 8, 8)) x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x) out = block(x)
mx.eval(out) mx.eval(out)
@@ -738,7 +815,9 @@ class TestDownResidualBlock:
def test_spatial_temporal_downsample(self): def test_spatial_temporal_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True) block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
)
x = mx.random.normal((1, 4, 8, 8, 8)) x = mx.random.normal((1, 4, 8, 8, 8))
out = block(x) out = block(x)
mx.eval(out) mx.eval(out)
@@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self): def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2.""" """Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan.vae22 import Encoder3d from mlx_video.models.wan.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True)) enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12)) x = mx.random.normal((1, 5, 16, 16, 12))
mx.eval(enc.parameters()) mx.eval(enc.parameters())
@@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self): def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample.""" """Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16) enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples down_blocks = enc.encoder.downsamples
found_modes = [] found_modes = []
@@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self): def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern.""" """Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder from mlx_video.models.wan.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16) enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3)) img = mx.random.normal((1, 1, 32, 32, 3))
mx.eval(enc.parameters()) mx.eval(enc.parameters())
@@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self): def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs.""" """(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan.vae22 import Encoder3d from mlx_video.models.wan.vae22 import Encoder3d
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True)
)
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False)) enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12)) x = mx.random.normal((1, 5, 16, 16, 12))
@@ -883,12 +968,8 @@ class TestVAE21RoundTrip:
z_dim = 4 z_dim = 4
dim = 8 dim = 8
# No temporal up/downsampling to keep the test simple # No temporal up/downsampling to keep the test simple
enc = Encoder3d( enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False])
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False] dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False])
)
dec = Decoder3d(
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
)
mx.eval(enc.parameters(), dec.parameters()) mx.eval(enc.parameters(), dec.parameters())
# [B=1, C=3, T=1, H=8, W=8] # [B=1, C=3, T=1, H=8, W=8]
@@ -937,15 +1018,12 @@ class TestVAE22RoundTrip:
mx.eval(out) mx.eval(out)
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16 # 3 spatial upsamples(×8) + unpatchify(×2) = ×16
assert out.shape[0] == 1 # batch assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB assert out.shape[-1] == 3 # RGB
out_np = np.array(out) out_np = np.array(out)
assert np.all(np.isfinite(out_np)) assert np.all(np.isfinite(out_np))
assert out_np.min() >= -1.0 - 1e-6 assert out_np.min() >= -1.0 - 1e-6
assert out_np.max() <= 1.0 + 1e-6 assert out_np.max() <= 1.0 + 1e-6

View File

@@ -4,6 +4,7 @@
def _make_tiny_config(): def _make_tiny_config():
"""Create a tiny WanModelConfig for testing.""" """Create a tiny WanModelConfig for testing."""
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
# Override to tiny values # Override to tiny values
config.dim = 64 config.dim = 64

27
uv.lock generated
View File

@@ -622,6 +622,18 @@ http = [
{ name = "aiohttp" }, { name = "aiohttp" },
] ]
[[package]]
name = "ftfy"
version = "6.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "wcwidth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
]
[[package]] [[package]]
name = "h11" name = "h11"
version = "0.16.0" version = "0.16.0"
@@ -996,7 +1008,10 @@ wheels = [
name = "mlx-video" name = "mlx-video"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "ftfy" },
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
{ name = "imageio" },
{ name = "imageio-ffmpeg" },
{ name = "librosa" }, { name = "librosa" },
{ name = "mlx" }, { name = "mlx" },
{ name = "mlx-vlm" }, { name = "mlx-vlm" },
@@ -1016,7 +1031,10 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "ftfy" },
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
{ name = "imageio", specifier = ">=2.37.2" },
{ name = "imageio-ffmpeg", specifier = ">=0.6.0" },
{ name = "librosa", specifier = ">=0.10.0" }, { name = "librosa", specifier = ">=0.10.0" },
{ name = "mlx", specifier = ">=0.22.0" }, { name = "mlx", specifier = ">=0.22.0" },
{ name = "mlx-vlm" }, { name = "mlx-vlm" },
@@ -2509,6 +2527,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" }, { url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" },
] ]
[[package]]
name = "wcwidth"
version = "0.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" },
]
[[package]] [[package]]
name = "xxhash" name = "xxhash"
version = "3.6.0" version = "3.6.0"