This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,17 +1,17 @@
"""Tests for LTX-2 dev model generation pipeline."""
import pytest
import mlx.core as mx
import pytest
from mlx_video.generate_dev import (
ltx2_scheduler,
create_position_grid,
create_audio_position_grid,
compute_audio_frames,
cfg_delta,
DEFAULT_NEGATIVE_PROMPT,
AUDIO_SAMPLE_RATE,
AUDIO_LATENTS_PER_SECOND,
AUDIO_SAMPLE_RATE,
DEFAULT_NEGATIVE_PROMPT,
cfg_delta,
compute_audio_frames,
create_audio_position_grid,
create_position_grid,
ltx2_scheduler,
)
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
"""Scheduler should return steps+1 sigma values."""
steps = 20
sigmas = ltx2_scheduler(steps=steps)
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"
assert sigmas.shape == (
steps + 1,
), f"Expected ({steps + 1},), got {sigmas.shape}"
def test_scheduler_starts_at_one(self):
"""Sigma schedule should start at 1.0."""
sigmas = ltx2_scheduler(steps=20)
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"
assert (
abs(sigmas[0].item() - 1.0) < 1e-5
), f"Expected 1.0, got {sigmas[0].item()}"
def test_scheduler_ends_at_zero(self):
"""Sigma schedule should end at 0.0."""
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
sigmas = ltx2_scheduler(steps=20)
sigmas_list = sigmas.tolist()
for i in range(len(sigmas_list) - 1):
assert sigmas_list[i] >= sigmas_list[i + 1], \
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
assert (
sigmas_list[i] >= sigmas_list[i + 1]
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
def test_scheduler_dtype(self):
"""Scheduler should return float32 array."""
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
num_patches = num_frames * height * width
expected_shape = (batch_size, 3, num_patches, 2)
assert positions.shape == expected_shape, \
f"Expected {expected_shape}, got {positions.shape}"
assert (
positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_position_grid_dtype(self):
"""Position grid should be float32 for RoPE precision."""
positions = create_position_grid(1, 5, 16, 24)
assert positions.dtype == mx.float32, \
f"Expected float32 for RoPE precision, got {positions.dtype}"
assert (
positions.dtype == mx.float32
), f"Expected float32 for RoPE precision, got {positions.dtype}"
def test_position_grid_batch_size(self):
"""Position grid should respect batch size."""
@@ -165,7 +172,9 @@ class TestCFGDelta:
mx.eval(delta)
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero"
assert (
mx.max(mx.abs(delta)).item() < 1e-6
), "CFG delta with scale=1.0 should be zero"
def test_cfg_delta_formula(self):
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
# Check for common negative quality terms
assert "blurry" in prompt_lower, "Should contain 'blurry'"
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \
"Should contain quality-related terms"
assert (
"low quality" in prompt_lower or "low contrast" in prompt_lower
), "Should contain quality-related terms"
class TestInputValidation:
@@ -248,15 +258,16 @@ class TestInputValidation:
(30, 33), # 30 -> nearest valid is 33
(35, 33), # 35 -> nearest valid is 33
(40, 41), # 40 -> nearest valid is 41
(1, 1), # 1 is already valid
(1, 1), # 1 is already valid
(33, 33), # 33 is already valid
]
for input_frames, expected in test_cases:
if input_frames % 8 != 1:
adjusted = round((input_frames - 1) / 8) * 8 + 1
assert adjusted == expected, \
f"Expected {expected} for input {input_frames}, got {adjusted}"
assert (
adjusted == expected
), f"Expected {expected} for input {input_frames}, got {adjusted}"
class TestDenoiseWithCFGMocked:
@@ -277,14 +288,16 @@ class TestTilingDefault:
def test_tiling_default_is_none(self):
"""Default tiling should be 'none' for performance."""
import inspect
from mlx_video.generate_dev import generate_video_dev
sig = inspect.signature(generate_video_dev)
tiling_param = sig.parameters.get('tiling')
tiling_param = sig.parameters.get("tiling")
assert tiling_param is not None
assert tiling_param.default == "none", \
f"Expected default tiling='none', got '{tiling_param.default}'"
assert (
tiling_param.default == "none"
), f"Expected default tiling='none', got '{tiling_param.default}'"
class TestLatentDimensions:
@@ -296,8 +309,9 @@ class TestLatentDimensions:
for height, expected_latent_h in test_cases:
latent_h = height // 32
assert latent_h == expected_latent_h, \
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
assert (
latent_h == expected_latent_h
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
def test_latent_width_calculation(self):
"""Latent width should be width // 32."""
@@ -305,8 +319,9 @@ class TestLatentDimensions:
for width, expected_latent_w in test_cases:
latent_w = width // 32
assert latent_w == expected_latent_w, \
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
assert (
latent_w == expected_latent_w
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
def test_latent_frames_calculation(self):
"""Latent frames should be 1 + (num_frames - 1) // 8."""
@@ -314,8 +329,9 @@ class TestLatentDimensions:
for num_frames, expected_latent_f in test_cases:
latent_f = 1 + (num_frames - 1) // 8
assert latent_f == expected_latent_f, \
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
assert (
latent_f == expected_latent_f
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
def test_num_tokens_calculation(self):
"""Number of tokens should be latent_f * latent_h * latent_w."""
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
positions = create_audio_position_grid(batch_size, audio_frames)
expected_shape = (batch_size, 1, audio_frames, 2)
assert positions.shape == expected_shape, \
f"Expected {expected_shape}, got {positions.shape}"
assert (
positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_audio_position_grid_dtype(self):
"""Audio position grid should be float32."""
positions = create_audio_position_grid(1, 34)
assert positions.dtype == mx.float32, \
f"Expected float32, got {positions.dtype}"
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
def test_audio_position_grid_batch_size(self):
"""Audio position grid should respect batch size."""
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
"""Audio position grid should not contain NaN or Inf."""
positions = create_audio_position_grid(1, 34)
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
assert not mx.any(
mx.isnan(positions)
).item(), "Audio position grid contains NaN"
assert not mx.any(
mx.isinf(positions)
).item(), "Audio position grid contains Inf"
class TestComputeAudioFrames:
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
audio_33 = compute_audio_frames(33, 24.0)
audio_65 = compute_audio_frames(65, 24.0)
assert audio_65 > audio_33, \
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
assert (
audio_65 > audio_33
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
def test_audio_frames_formula(self):
"""Audio frames should match expected formula."""

View File

@@ -1,11 +1,9 @@
import pytest
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx_2.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
def create_video_position_grid(
@@ -20,7 +18,7 @@ def create_video_position_grid(
h_coords = np.arange(0, height)
w_coords = np.arange(0, width)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_ends = patch_starts + 1
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
scaled = fractional * 2 - 1 # [-1, 1]
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
freqs = (
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
)
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
freqs = freqs.reshape(
freqs.shape[0], freqs.shape[1], -1
) # (B, T, num_indices * n_dims)
cos_ref = np.cos(freqs)
sin_ref = np.sin(freqs)
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
pad_size = expected - cos_ref.shape[-1]
if pad_size > 0:
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
cos_ref = np.concatenate(
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
)
sin_ref = np.concatenate(
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
)
B, T, _ = cos_ref.shape
dim_per_head = dim // num_heads
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
# Verify cos/sin are in valid range [-1, 1]
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
"cos_freq values out of [-1, 1] range"
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
"sin_freq values out of [-1, 1] range"
assert (
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
), "cos_freq values out of [-1, 1] range"
assert (
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
), "sin_freq values out of [-1, 1] range"
def test_bfloat16_positions_cause_precision_loss(self):
"""bfloat16 positions should produce different (less precise) results than float32.
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
# The threshold here is intentionally low to catch the issue
precision_threshold = 1e-6
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
has_precision_loss = (
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
)
# Document the precision loss (this is expected behavior)
if has_precision_loss:
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
print(f" Max sin difference: {max_sin_diff:.6e}")
# This assertion documents the issue - bfloat16 positions cause precision loss
assert has_precision_loss, \
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
assert (
has_precision_loss
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
def test_double_precision_converts_to_float32_internally(self):
"""Verify that double_precision mode converts bfloat16 to float32 first."""
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
# Recommended: create positions in float32
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
assert positions.dtype == mx.float32, \
"Position grids should be created in float32 for RoPE precision"
assert (
positions.dtype == mx.float32
), "Position grids should be created in float32 for RoPE precision"
# Verify the position values are reasonable
# Temporal positions should be small (seconds)
temporal_positions = positions[:, 0, :, :]
assert mx.max(temporal_positions).item() < 100, \
"Temporal positions should be in seconds (small values)"
assert (
mx.max(temporal_positions).item() < 100
), "Temporal positions should be in seconds (small values)"
# Spatial positions should be larger (pixels)
spatial_h = positions[:, 1, :, :]
spatial_w = positions[:, 2, :, :]
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
assert (
mx.max(spatial_h).item() > 0
), "Spatial height positions should be positive"
assert (
mx.max(spatial_w).item() > 0
), "Spatial width positions should be positive"
def test_float32_positions_match_numpy_float64_reference(self):
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
)
# NumPy float64 reference
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
cos_mlx_np = np.array(cos_mlx)
sin_mlx_np = np.array(sin_mlx)
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
# Cosine similarity (flatten for single scalar)
cos_flat = cos_mlx_np.flatten()
ref_flat = cos_ref.flatten()
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
cosine_sim = np.dot(cos_flat, ref_flat) / (
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
)
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
assert max_cos_diff < 0.01, \
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert max_sin_diff < 0.01, \
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert cosine_sim > 0.9999, \
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
assert (
max_cos_diff < 0.01
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
max_sin_diff < 0.01
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
cosine_sim > 0.9999
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
def test_high_frequency_amplification_regression(self):
"""Regression test for the specific failure mode: high-frequency index amplification.
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
double_precision=False,
)
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
assert max_cos_diff < 0.01, \
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
assert max_sin_diff < 0.01, \
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
assert (
max_cos_diff < 0.01
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
assert (
max_sin_diff < 0.01
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
class TestRoPEInterleaved:
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict(
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, num_attention_heads=32,
rope_type=LTXRopeType.SPLIT, double_precision=False,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict(
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, num_attention_heads=32,
rope_type=LTXRopeType.SPLIT, double_precision=True,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_f16,
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, num_attention_heads=32,
rope_type=LTXRopeType.SPLIT, double_precision=False,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
assert cos_freq.dtype == mx.float32
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
def test_ltx2_forces_double_precision_rope_false(self):
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
assert config.double_precision_rope is False, \
"LTX-2 should force double_precision_rope=False regardless of input"
assert (
config.double_precision_rope is False
), "LTX-2 should force double_precision_rope=False regardless of input"
def test_ltx23_preserves_double_precision_rope_true(self):
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
assert config.double_precision_rope is True, \
"LTX-2.3 should preserve double_precision_rope=True"
assert (
config.double_precision_rope is True
), "LTX-2.3 should preserve double_precision_rope=True"
def test_ltx23_preserves_double_precision_rope_false(self):
"""LTX-2.3 with double_precision_rope=False should stay False."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
assert config.double_precision_rope is False, \
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
assert (
config.double_precision_rope is False
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
def test_ltx2_default_double_precision_rope(self):
"""LTX-2 default (double_precision_rope not set) should be False."""
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
def test_config_from_dict_ltx2(self):
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
config = LTXModelConfig.from_dict({
"has_prompt_adaln": False,
"double_precision_rope": True,
"rope_type": "split",
})
config = LTXModelConfig.from_dict(
{
"has_prompt_adaln": False,
"double_precision_rope": True,
"rope_type": "split",
}
)
assert config.double_precision_rope is False
def test_config_from_dict_ltx23(self):
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
config = LTXModelConfig.from_dict({
"has_prompt_adaln": True,
"double_precision_rope": True,
"rope_type": "split",
})
config = LTXModelConfig.from_dict(
{
"has_prompt_adaln": True,
"double_precision_rope": True,
"rope_type": "split",
}
)
assert config.double_precision_rope is True
@@ -496,10 +543,12 @@ class TestRoPESplit:
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
dim_per_head = dim // num_heads
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
assert cos_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert sin_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {sin_freq.shape}"
assert (
cos_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert (
sin_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
if __name__ == "__main__":

View File

@@ -1,8 +1,8 @@
"""Tests for VAE streaming and chunked conv features."""
import pytest
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import (
@@ -50,7 +50,7 @@ class TestChunkedConv:
np.array(out_chunked),
rtol=1e-5,
atol=1e-5,
err_msg="Chunked conv output differs from regular output"
err_msg="Chunked conv output differs from regular output",
)
def test_chunked_conv_small_input_passthrough(self):
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
frames_received = []
def on_frames_ready(frames: mx.array, start_idx: int):
frames_received.append({
'shape': frames.shape,
'start_idx': start_idx,
})
frames_received.append(
{
"shape": frames.shape,
"start_idx": start_idx,
}
)
# Create a mock decoder that just returns scaled input
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
# Simulate VAE output: upsample 8x temporal, 32x spatial
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
# All received frames should have correct channel count
for received in frames_received:
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
assert (
received["shape"][1] == 3
), f"Expected 3 channels, got {received['shape'][1]}"
def test_on_frames_ready_covers_all_frames(self):
"""Verify all frames are emitted via callbacks."""
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
for i in range(num_frames):
all_frame_indices.add(start_idx + i)
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
out_h = h * 32
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
expected_frames = 1 + (12 - 1) * 8 # 89 frames
# All frames should have been emitted
assert len(all_frame_indices) == expected_frames, \
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(range(expected_frames)), \
"Not all frame indices were covered"
assert (
len(all_frame_indices) == expected_frames
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(
range(expected_frames)
), "Not all frame indices were covered"
class TestAutoChunkedConv:
"""Tests for auto-enabling chunked_conv based on tiling mode."""
@pytest.mark.parametrize("tiling_mode,should_enable", [
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
])
@pytest.mark.parametrize(
"tiling_mode,should_enable",
[
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
],
)
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
use_chunked_conv = tiling_mode in expected_modes
assert use_chunked_conv == should_enable, \
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
assert (
use_chunked_conv == should_enable
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
class TestTrapezoidalMask:
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
# Right ramp should be decreasing
right_ramp = mask_np[-8:]
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
assert np.all(
np.diff(right_ramp) <= 0
), "Right ramp not monotonically decreasing"
def test_temporal_mask_starts_from_zero(self):
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""

View File

@@ -2,24 +2,25 @@
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# RoPE Tests
# ---------------------------------------------------------------------------
class TestRoPE:
"""Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self):
from mlx_video.models.wan.rope import rope_params
freqs = rope_params(1024, 64)
mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self):
from mlx_video.models.wan.rope import rope_params
for dim in [32, 64, 128]:
freqs = rope_params(512, dim)
mx.eval(freqs)
@@ -27,6 +28,7 @@ class TestRoPE:
def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan.rope import rope_params
freqs = rope_params(256, 64)
mx.eval(freqs)
cos_vals = np.array(freqs[:, :, 0])
@@ -37,13 +39,15 @@ class TestRoPE:
def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan.rope import rope_params
freqs = rope_params(10, 64)
mx.eval(freqs)
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self):
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D)
@@ -54,7 +58,8 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms."""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 3, 4
L = F * H * W
@@ -74,7 +79,8 @@ class TestRoPE:
def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 2, 2
seq_len = F * H * W # 8
@@ -94,7 +100,8 @@ class TestRoPE:
def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan.rope import rope_apply, rope_params
B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)]
L = 2 * 3 * 4
@@ -122,9 +129,11 @@ class TestRoPE:
# Attention Tests
# ---------------------------------------------------------------------------
class TestWanRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -134,6 +143,7 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0
out = norm(x)
@@ -147,6 +157,7 @@ class TestWanRMSNorm:
def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
out = norm(x)
@@ -158,6 +169,7 @@ class TestWanRMSNorm:
class TestWanLayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -166,6 +178,7 @@ class TestWanLayerNorm:
def test_without_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64))
out = norm(x)
@@ -178,6 +191,7 @@ class TestWanLayerNorm:
def test_with_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight")
assert hasattr(norm, "bias")
@@ -196,6 +210,7 @@ class TestWanSelfAttention:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24
F, H, W = 2, 3, 4
@@ -207,12 +222,14 @@ class TestWanSelfAttention:
def test_with_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None
assert attn.norm_k is not None
def test_without_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None
assert attn.norm_k is None
@@ -221,6 +238,7 @@ class TestWanSelfAttention:
"""Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24
F, H, W = 2, 3, 4
@@ -245,6 +263,7 @@ class TestWanCrossAttention:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16
x = mx.random.normal((B, L_q, self.dim))
@@ -255,6 +274,7 @@ class TestWanCrossAttention:
def test_with_context_mask(self):
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16
x = mx.random.normal((B, L_q, self.dim))
@@ -268,6 +288,7 @@ class TestWanCrossAttention:
# bfloat16 Autocast Tests
# ---------------------------------------------------------------------------
class TestBFloat16Autocast:
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
for efficient matmul, matching official PyTorch autocast behavior."""
@@ -292,6 +313,7 @@ class TestBFloat16Autocast:
"""Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -305,6 +327,7 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -318,6 +341,7 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -330,6 +354,7 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters()))
@@ -343,6 +368,7 @@ class TestBFloat16Autocast:
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -355,8 +381,9 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters()))

View File

@@ -1,17 +1,17 @@
"""Tests for Wan model configuration."""
import pytest
# ---------------------------------------------------------------------------
# Config Tests
# ---------------------------------------------------------------------------
class TestWanModelConfig:
"""Tests for WanModelConfig dataclass."""
def test_default_values(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.dim == 5120
assert config.ffn_dim == 13824
@@ -33,11 +33,13 @@ class TestWanModelConfig:
def test_head_dim_property(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
d = config.to_dict()
assert isinstance(d, dict)
@@ -47,6 +49,7 @@ class TestWanModelConfig:
def test_t5_config_values(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.t5_vocab_size == 256384
assert config.t5_dim == 4096
@@ -61,11 +64,13 @@ class TestWanModelConfig:
# Wan2.1 Config Tests
# ---------------------------------------------------------------------------
class TestWan21Config:
"""Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1"
assert config.dual_model is False
@@ -81,6 +86,7 @@ class TestWan21Config:
def test_wan21_1_3b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1"
assert config.dual_model is False
@@ -93,6 +99,7 @@ class TestWan21Config:
def test_wan22_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2"
assert config.dual_model is True
@@ -104,6 +111,7 @@ class TestWan21Config:
def test_wan21_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
@@ -112,6 +120,7 @@ class TestWan21Config:
def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict()
assert d["dim"] == 1536
@@ -120,6 +129,7 @@ class TestWan21Config:
def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.model_version == "2.2"
assert config.dual_model is True

View File

@@ -3,17 +3,16 @@
import logging
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Transformer Weight Conversion Tests
# ---------------------------------------------------------------------------
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
@@ -25,6 +24,7 @@ class TestSanitizeTransformerWeights:
def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.0.bias": mx.zeros((64,)),
@@ -39,6 +39,7 @@ class TestSanitizeTransformerWeights:
def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)),
@@ -49,6 +50,7 @@ class TestSanitizeTransformerWeights:
def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
"time_projection.1.bias": mx.zeros((384,)),
@@ -59,6 +61,7 @@ class TestSanitizeTransformerWeights:
def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.0.bias": mx.zeros((128,)),
@@ -73,6 +76,7 @@ class TestSanitizeTransformerWeights:
def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
"blocks.0.norm1.weight": mx.zeros((64,)),
@@ -83,6 +87,7 @@ class TestSanitizeTransformerWeights:
def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
@@ -98,6 +103,7 @@ class TestSanitizeTransformerWeights:
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
@@ -121,6 +127,7 @@ class TestSanitizeTransformerWeights:
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
@@ -133,6 +140,7 @@ class TestSanitizeT5Weights:
def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
@@ -144,6 +152,7 @@ class TestSanitizeT5Weights:
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -159,6 +168,7 @@ class TestSanitizeT5Weights:
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
}
@@ -167,6 +177,7 @@ class TestSanitizeVAEWeights:
def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
}
@@ -175,6 +186,7 @@ class TestSanitizeVAEWeights:
def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
"decoder.bias": mx.zeros((16,)),
@@ -185,6 +197,7 @@ class TestSanitizeVAEWeights:
def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
@@ -199,6 +212,7 @@ class TestSanitizeVAEWeights:
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
@@ -214,6 +228,7 @@ class TestSanitizeVAEWeights:
# Wan2.1 Conversion Tests
# ---------------------------------------------------------------------------
class TestWan21Convert:
"""Tests for Wan2.1 conversion support."""
@@ -222,7 +237,7 @@ class TestWan21Convert:
# Create a Wan2.1-style directory (no low_noise_model subdir)
(tmp_path / "dummy.safetensors").touch()
# The auto-detect logic: no low_noise_model dir → 2.1
from pathlib import Path
low = tmp_path / "low_noise_model"
assert not low.exists()
# Simulates auto detection
@@ -233,7 +248,7 @@ class TestWan21Convert:
"""Auto-detect dual-model directory as Wan2.2."""
(tmp_path / "low_noise_model").mkdir()
(tmp_path / "high_noise_model").mkdir()
from pathlib import Path
low = tmp_path / "low_noise_model"
assert low.exists()
version = "2.2" if low.exists() else "2.1"
@@ -242,6 +257,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
@@ -254,6 +270,7 @@ class TestWan21Convert:
# Encoder Weight Sanitization Tests
# ---------------------------------------------------------------------------
class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""

View File

@@ -2,15 +2,13 @@
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass
# ---------------------------------------------------------------------------
class TestEndToEnd:
"""End-to-end test with tiny model (no real weights needed)."""
@@ -78,6 +76,7 @@ class TestEndToEnd:
# I2V Mask Tests
# ---------------------------------------------------------------------------
class TestI2VMask:
"""Tests for _build_i2v_mask."""
@@ -113,6 +112,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
z_shape = (48, 21, 44, 80)
@@ -133,6 +133,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
np.testing.assert_allclose(
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
)
# ---------------------------------------------------------------------------
# Dimension Alignment Tests
# ---------------------------------------------------------------------------
class TestDimensionAlignment:
"""Tests for automatic dimension alignment in generate_wan."""
@@ -198,6 +202,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
patches, grid_size = model._patchify(vid)
mx.eval(patches)
assert patches.ndim == 3 # [1, L, dim]
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
assert grid_size == (
t_latent,
h_latent // patch_size[1],
w_latent // patch_size[2],
)
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]
align_w = config.patch_size[2] * config.vae_stride[2]

View File

@@ -1,9 +1,6 @@
"""Tests for Wan2.2 I2V-14B support."""
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
@@ -145,7 +142,10 @@ class TestModelYParameter:
latents = mx.random.normal((C_noise, F, H, W))
y = mx.random.normal((C_y, F, H, W))
t = mx.array([500.0, 500.0])
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
ctx = [
mx.random.normal((6, config.text_dim)),
mx.random.normal((6, config.text_dim)),
]
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
mx.eval(out[0], out[1])
@@ -160,7 +160,9 @@ class TestVAEEncoder:
def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
enc = Encoder3d(
dim=32, z_dim=8
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
assert enc.conv1 is not None
assert len(enc.downsamples) > 0
assert len(enc.middle) == 3
@@ -199,10 +201,10 @@ class TestVAEEncoder:
from mlx_video.models.wan.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, 'encoder')
assert not hasattr(vae_no_enc, "encoder")
vae_enc = WanVAE(z_dim=4, encoder=True)
assert hasattr(vae_enc, 'encoder')
assert hasattr(vae_enc, "encoder")
class TestResampleDownsample:
@@ -258,7 +260,9 @@ class TestI2VMaskConstruction:
# Build mask following reference logic
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -272,7 +276,9 @@ class TestI2VMaskConstruction:
t_latent = (num_frames - 1) // 4 + 1 # = 3
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0]
@@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline:
config = _make_tiny_i2v_config()
config.vae_z_dim = 16
config.out_dim = 16 # must match VAE z_dim for decode
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
config.in_dim = (
16 + 4 + 16
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
model = WanModel(config)
# --- Tiny VAE (with encoder) ---
@@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline:
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
video = mx.concatenate([
img,
mx.zeros((1, 3, num_frames - 1, height, width)),
], axis=2)
video = mx.concatenate(
[
img,
mx.zeros((1, 3, num_frames - 1, height, width)),
],
axis=2,
)
# --- VAE encode ---
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
@@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline:
# --- Build I2V mask (4 channels) ---
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -453,7 +466,9 @@ class TestDualModelSwitching:
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
0
)
mx.eval(latents)
# With shift=5.0, early timesteps should be high (>=900), later ones low
@@ -461,9 +476,9 @@ class TestDualModelSwitching:
assert len(low_used_steps) > 0, "Low-noise model was never selected"
# High-noise steps should come before low-noise steps (timesteps decrease)
if high_used_steps and low_used_steps:
assert max(high_used_steps) < min(low_used_steps) or \
min(high_used_steps) < max(low_used_steps), \
"Model switching should happen during the loop"
assert max(high_used_steps) < min(low_used_steps) or min(
high_used_steps
) < max(low_used_steps), "Model switching should happen during the loop"
assert latents.shape == (C_noise, F, H, W)
assert not mx.any(mx.isnan(latents)).item()
@@ -515,7 +530,9 @@ class TestDualModelSwitching:
y=[y_i2v, y_i2v],
)
noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
0
)
mx.eval(latents)
# Verify both guide scales were used

View File

@@ -4,7 +4,6 @@ import tempfile
from pathlib import Path
import mlx.core as mx
import numpy as np
import pytest
@@ -40,7 +39,9 @@ class TestLoRATypes:
lora_a = mx.ones((2, 4))
lora_b = mx.ones((8, 2))
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
applied = AppliedLoRA(weights=w, strength=0.5)
delta = applied.compute_delta()
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
@@ -51,7 +52,9 @@ class TestLoRATypes:
class TestLoRALoader:
"""Test LoRA weight loading from safetensors."""
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
def _make_lora_file(
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
):
"""Helper to create a mock LoRA safetensors file."""
weights = {}
for name in module_names:
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
"""Simulate typical Wan2.2 MLX model weight keys."""
keys = set()
for i in range(2):
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
for layer in [
"self_attn.q",
"self_attn.k",
"self_attn.v",
"self_attn.o",
"cross_attn.q",
"cross_attn.k",
"cross_attn.v",
"cross_attn.o",
]:
keys.add(f"blocks.{i}.{layer}.weight")
keys.add(f"blocks.{i}.ffn.fc1.weight")
keys.add(f"blocks.{i}.ffn.fc2.weight")
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
assert (
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
== "blocks.0.self_attn.q"
)
def test_strip_diffusion_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
result = _normalize_wan_lora_key(
"model.diffusion_model.blocks.0.self_attn.k", keys
)
assert result == "blocks.0.self_attn.k"
def test_ffn_key_mapping(self):
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
assert (
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
)
def test_combined_prefix_and_ffn(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -219,7 +237,9 @@ class TestApplyLoRA:
# LoRA weights in float32 (typical when loaded from safetensors)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
@@ -230,7 +250,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4), dtype=mx.float16)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
@@ -241,7 +263,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4))
lora_a = mx.ones((2, 4)) * 0.1
lora_b = mx.ones((8, 2)) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
expected = original + 0.02 * mx.ones((8, 4))
@@ -255,12 +279,16 @@ class TestApplyLoRA:
w1 = LoRAWeights(
lora_A=mx.ones((2, 4)),
lora_B=mx.ones((8, 2)),
rank=2, alpha=2.0, module_name="a",
rank=2,
alpha=2.0,
module_name="a",
)
w2 = LoRAWeights(
lora_A=mx.ones((2, 4)) * 2,
lora_B=mx.ones((8, 2)) * 2,
rank=2, alpha=4.0, module_name="b",
rank=2,
alpha=4.0,
module_name="b",
)
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
@@ -282,7 +310,9 @@ class TestApplyLoRA:
w = LoRAWeights(
lora_A=mx.ones((4, 64)) * 0.01,
lora_B=mx.ones((128, 4)) * 0.01,
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
rank=4,
alpha=4.0,
module_name="blocks.0.self_attn.q",
)
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
result = apply_loras_to_weights(model_weights, module_to_loras)
@@ -319,9 +349,7 @@ class TestEndToEnd:
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
}
result = load_and_apply_loras(
model_weights, [(str(lora_path), 1.0)]
)
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
# q weight should be modified, k unchanged
assert not mx.array_equal(

View File

@@ -3,18 +3,17 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests
# ---------------------------------------------------------------------------
class TestSinusoidalEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos)
mx.eval(emb)
@@ -23,6 +22,7 @@ class TestSinusoidalEmbedding:
def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos)
mx.eval(emb)
@@ -34,6 +34,7 @@ class TestSinusoidalEmbedding:
def test_different_positions_differ(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos)
mx.eval(emb)
@@ -46,9 +47,11 @@ class TestSinusoidalEmbedding:
# Head Tests
# ---------------------------------------------------------------------------
class TestHead:
def test_output_shape(self):
from mlx_video.models.wan.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24
x = mx.random.normal((B, L, 64))
@@ -60,6 +63,7 @@ class TestHead:
def test_modulation_shape(self):
from mlx_video.models.wan.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64)
@@ -68,12 +72,14 @@ class TestHead:
# WanModel (Tiny) Tests
# ---------------------------------------------------------------------------
class TestWanModel:
def setup_method(self):
mx.random.seed(42)
def test_instantiation(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
@@ -81,6 +87,7 @@ class TestWanModel:
def test_patchify_shape(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
# Input: [C=4, F=1, H=4, W=4]
@@ -93,6 +100,7 @@ class TestWanModel:
def test_patchify_various_sizes(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
@@ -108,6 +116,7 @@ class TestWanModel:
def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 2, 4, 6
@@ -123,6 +132,7 @@ class TestWanModel:
def test_forward_pass(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
@@ -140,6 +150,7 @@ class TestWanModel:
def test_forward_batch(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
@@ -148,7 +159,10 @@ class TestWanModel:
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
t = mx.array([500.0, 200.0])
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
context = [
mx.random.normal((6, config.text_dim)),
mx.random.normal((4, config.text_dim)),
]
out = model(x_list, t, context, seq_len)
mx.eval(out[0], out[1])
@@ -158,12 +172,17 @@ class TestWanModel:
def test_output_is_float32(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2)
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
[mx.random.normal((4, config.text_dim))], seq_len)
out = model(
[mx.random.normal((C, F, H, W))],
mx.array([100.0]),
[mx.random.normal((4, config.text_dim))],
seq_len,
)
mx.eval(out[0])
assert out[0].dtype == mx.float32
@@ -172,6 +191,7 @@ class TestWanModel:
# Wan2.1 Model Tests
# ---------------------------------------------------------------------------
class TestWan21Model:
"""Test tiny Wan2.1-style model (single model mode)."""
@@ -181,6 +201,7 @@ class TestWan21Model:
def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values
config.dim = 64
@@ -197,6 +218,7 @@ class TestWan21Model:
def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads)
config.dim = 48
@@ -271,7 +293,9 @@ class TestWan21Model:
for i in range(3):
t = sched.timesteps[i]
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
pred_uncond = model(
[latents], mx.array([t.item()]), [context_null], seq_len
)[0]
pred = pred_uncond + gs * (pred_cond - pred_uncond)
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
mx.eval(latents)
@@ -304,6 +328,7 @@ class TestWan21Model:
# Per-Token Timestep Tests
# ---------------------------------------------------------------------------
class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding."""

View File

@@ -1,22 +1,22 @@
"""Tests for Wan model quantization pipeline."""
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Quantize Predicate Tests
# ---------------------------------------------------------------------------
class TestQuantizePredicate:
def test_matches_self_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.self_attn.{suffix}"
@@ -24,6 +24,7 @@ class TestQuantizePredicate:
def test_matches_cross_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.cross_attn.{suffix}"
@@ -31,23 +32,31 @@ class TestQuantizePredicate:
def test_matches_ffn_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
for path in [
"patch_embedding_proj",
"text_embedding_fc1",
"time_embedding.fc1",
]:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self):
from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self):
from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
@@ -55,13 +64,19 @@ class TestQuantizePredicate:
def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted."""
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
patterns = [
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
"blocks.0.self_attn.q",
"blocks.0.self_attn.k",
"blocks.0.self_attn.v",
"blocks.0.self_attn.o",
"blocks.0.cross_attn.q",
"blocks.0.cross_attn.k",
"blocks.0.cross_attn.v",
"blocks.0.cross_attn.o",
"blocks.0.ffn.fc1",
"blocks.0.ffn.fc2",
]
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
assert len(matched) == 10
@@ -71,11 +86,12 @@ class TestQuantizePredicate:
# Quantize Round-Trip Tests
# ---------------------------------------------------------------------------
class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
nn.quantize(
@@ -101,8 +117,10 @@ class TestQuantizeRoundTrip:
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 4, "group_size": 64},
)
@@ -119,8 +137,10 @@ class TestQuantizeRoundTrip:
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 8, "group_size": 64},
)
@@ -132,8 +152,10 @@ class TestQuantizeRoundTrip:
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 4, "group_size": 64},
)
@@ -151,6 +173,7 @@ class TestQuantizeRoundTrip:
mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None)
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
@@ -161,10 +184,11 @@ class TestQuantizeRoundTrip:
# Quantized Inference Tests
# ---------------------------------------------------------------------------
class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
nn.quantize(
@@ -214,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
mx.random.seed(42)
@@ -243,11 +267,12 @@ class TestQuantizedInference:
# Config Metadata Tests
# ---------------------------------------------------------------------------
class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -270,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -291,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()

View File

@@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction:
d = 128 # head_dim for all Wan models
# Reference: three separate calls
correct = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
correct = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Wrong: single call
wrong = rope_params(1024, d)
mx.eval(correct, wrong)
assert correct.shape == wrong.shape
diff = np.abs(np.array(correct) - np.array(wrong)).max()
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
assert (
diff > 0.1
), f"Three-call and single-call should differ significantly, got max diff {diff}"
def test_each_axis_starts_at_frequency_one(self):
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
@@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction:
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
# Temporal axis first freq
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
err_msg="temporal[0] cos at pos 1")
np.testing.assert_allclose(
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
)
# Height axis first freq (starts at index d_t)
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
err_msg="height[0] cos at pos 1")
np.testing.assert_allclose(
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
)
# Width axis first freq (starts at index d_t + d_h)
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
err_msg="width[0] cos at pos 1")
np.testing.assert_allclose(
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
)
def test_height_width_frequencies_identical(self):
"""Height and width axes should have identical frequency tables.
@@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction:
d = 128
d_h_dim = 2 * (d // 6) # 42
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
height_freqs = f[:, d_t:d_t + d_h]
width_freqs = f[:, d_t + d_h:]
height_freqs = f[:, d_t : d_t + d_h]
width_freqs = f[:, d_t + d_h :]
np.testing.assert_array_equal(height_freqs, width_freqs)
def test_frequency_range_per_axis(self):
@@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction:
pos1_h = f[1, d_t, 0] # height first freq
pos1_w = f[1, d_t + d_h, 0] # width first freq
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert (
pos1_t > 0.5
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
@@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction:
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
freqs_manual = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs_manual = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs_model, freqs_manual)
np.testing.assert_array_equal(
np.array(freqs_model), np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction"
np.array(freqs_model),
np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction",
)
def test_model_freqs_14b_dimensions(self):
@@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
],
axis=1,
)
mx.eval(freqs)
assert freqs.shape == (1024, 64, 2)
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
@pytest.fixture
def has_torch(self):
try:
import torch
pass
return True
except ImportError:
pytest.skip("PyTorch not installed")
@@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference:
def test_freqs_match_pytorch_reference(self, has_torch):
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan.rope import rope_params
d = 128
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
def pt_rope_params(max_seq_len, dim, theta=10000):
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
1.0
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
ref = torch.cat([
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
], dim=1)
ref = torch.cat(
[
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
# MLX
ours = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
ours = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(ours)
our_cos = np.array(ours[:, :, 0])
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
ref_cos = ref.real.float().numpy()
ref_sin = ref.imag.float().numpy()
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
err_msg="cos mismatch vs PyTorch reference")
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
err_msg="sin mismatch vs PyTorch reference")
np.testing.assert_allclose(
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
)
np.testing.assert_allclose(
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
)
class TestRoPEApplyWithCorrectFreqs:
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan.rope import rope_apply, rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 1, 4
F, H, W = 1, 4, 4
@@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs:
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
# and width was ~0.002. With correct freqs, both are ~1.3.
assert height_diff > 0.5, (
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
)
assert width_diff > 0.5, (
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
)
assert (
height_diff > 0.5
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
assert (
width_diff > 0.5
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
# Height and width should have identical frequency tables → same diffs
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
err_msg="Height and width should use identical frequency tables")
np.testing.assert_allclose(
height_diff,
width_diff,
rtol=1e-5,
err_msg="Height and width should use identical frequency tables",
)
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
@@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs:
)
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 2, 4
F, H, W = 2, 3, 4
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
mx.eval(out_online, out_precomp)
np.testing.assert_allclose(
np.array(out_online), np.array(out_precomp), atol=1e-5,
err_msg="Precomputed and online RoPE should match"
np.array(out_online),
np.array(out_precomp),
atol=1e-5,
err_msg="Precomputed and online RoPE should match",
)

View File

@@ -6,14 +6,15 @@ import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Euler Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
assert sched.timesteps is None
@@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler:
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler:
def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps)
@@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler:
def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
mx.eval(sched.sigmas)
@@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler:
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.sigmas)
@@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
sched1.set_timesteps(20, shift=1.0)
@@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler:
def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
mx.eval(sched.sigmas)
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
# Euler: x_next = x + (sigma_next - sigma) * v
expected = 1.0 + (sigma_next - sigma) * 0.5
np.testing.assert_allclose(
np.array(result).flatten()[0], expected, rtol=1e-4,
np.array(result).flatten()[0],
expected,
rtol=1e-4,
)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
assert sched._step_index == 0
@@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler:
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -141,22 +154,26 @@ class TestComputeSigmas:
def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
@@ -169,6 +186,7 @@ class TestComputeSigmas:
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
# Official single-shift: unshifted bounds, then shift once
@@ -183,6 +201,7 @@ class TestComputeSigmas:
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
# so schedule is nearly linear from ~0.999 to 0
@@ -196,6 +215,7 @@ class TestComputeSigmas:
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
@@ -214,6 +234,7 @@ class TestComputeSigmas:
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
@@ -235,12 +256,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler:
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 4, 1, 2, 2))
@@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler:
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 2, 4, 4))
@@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
x0_after_second = sched._prev_x0
assert x0_after_second is not None
# The stored x0 should differ from the first step's
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
assert not np.allclose(
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
)
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
@@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
@@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
assert sched.solver_order == 2
@@ -369,6 +403,7 @@ class TestFlowUniPCScheduler:
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -377,6 +412,7 @@ class TestFlowUniPCScheduler:
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -387,6 +423,7 @@ class TestFlowUniPCScheduler:
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -399,6 +436,7 @@ class TestFlowUniPCScheduler:
def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -411,6 +449,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -424,6 +463,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 4, 4))
@@ -436,6 +476,7 @@ class TestFlowUniPCScheduler:
def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
@@ -450,6 +491,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -459,6 +501,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 2, 2))
@@ -471,6 +514,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 2, 2))
@@ -483,6 +527,7 @@ class TestFlowUniPCScheduler:
def test_corrector_rhos_c_not_hardcoded(self):
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
import math
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
rhos_c = np.linalg.solve(R, b)
# History weight should be small (~0.07-0.09), not 0.5
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
assert (
rhos_c[0] < 0.15
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert (
rhos_c[0] > 0.0
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
# D1_t weight should be ~0.42-0.45, not 0.5
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
assert (
0.3 < rhos_c[1] < 0.5
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
# ---------------------------------------------------------------------------
# Scheduler Coherence Tests
# ---------------------------------------------------------------------------
class TestSchedulerCoherence:
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
results["dpm++"],
results["euler"],
atol=1e-5,
err_msg="DPM++ step 0 should match Euler",
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
results["unipc"],
results["euler"],
atol=1e-5,
err_msg="UniPC step 0 should match Euler",
)
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
mx.eval(euler_r, dpm_r, unipc_r)
np.testing.assert_allclose(
np.array(dpm_r), np.array(euler_r), atol=1e-5,
np.array(dpm_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
)
np.testing.assert_allclose(
np.array(unipc_r), np.array(euler_r), atol=1e-5,
np.array(unipc_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
)
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
np.testing.assert_allclose(
np.array(latents), 0.0, atol=1e-3,
np.array(latents),
0.0,
atol=1e-3,
err_msg=f"{name} did not converge to target with oracle",
)
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
# Higher-order solvers should not be significantly worse than Euler
# (add small epsilon to handle near-zero errors from floating point noise)
eps = 1e-6
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert (
errors["dpm++"] <= errors["euler"] * 1.5 + eps
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
assert (
errors["unipc"] <= errors["euler"] * 1.5 + eps
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
def test_multistep_trajectory_similar_magnitude(self):
"""Over a full denoising loop with constant velocity, all solvers
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
# All solvers should produce results within the same order of magnitude
vals = list(final_means.values())
ratio = max(vals) / max(min(vals), 1e-10)
assert ratio < 10.0, (
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
)
assert (
ratio < 10.0
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
def test_intermediate_values_finite(self):
"""Every intermediate latent value must be finite for all solvers."""
@@ -712,9 +774,9 @@ class TestSchedulerCoherence:
vel = mx.random.normal(shape)
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
assert np.isfinite(np.array(latents)).all(), (
f"{name} produced non-finite values at step {i}"
)
assert np.isfinite(
np.array(latents)
).all(), f"{name} produced non-finite values at step {i}"
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
@@ -724,17 +786,17 @@ class TestSchedulerCoherence:
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
assert cls._lambda(1.0) == -math.inf, (
f"{cls.__name__}._lambda(1.0) should be -inf"
)
assert cls._lambda(0.0) == math.inf, (
f"{cls.__name__}._lambda(0.0) should be +inf"
)
assert (
cls._lambda(1.0) == -math.inf
), f"{cls.__name__}._lambda(1.0) should be -inf"
assert (
cls._lambda(0.0) == math.inf
), f"{cls.__name__}._lambda(0.0) should be +inf"
# Interior values should be finite
lam = cls._lambda(0.5)
assert math.isfinite(lam) and lam == 0.0, (
f"{cls.__name__}._lambda(0.5) should be 0.0"
)
assert (
math.isfinite(lam) and lam == 0.0
), f"{cls.__name__}._lambda(0.5) should be 0.0"
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
mx.eval(result)
np.testing.assert_allclose(
np.array(result), np.array(expected), atol=5e-4,
np.array(result),
np.array(expected),
atol=5e-4,
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
)
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
results["dpm++"],
results["euler"],
atol=1e-5,
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
results["unipc"],
results["euler"],
atol=1e-5,
)
def test_dpmpp_unipc_agree_on_step1(self):
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
sched = cls()
@@ -857,14 +928,19 @@ class TestSchedulerCoherence:
mx.eval(latents)
result2 = np.array(latents)
np.testing.assert_allclose(result1, result2, atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()")
np.testing.assert_allclose(
result1,
result2,
atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()",
)
# ---------------------------------------------------------------------------
# UniPC Corrector Default Tests
# ---------------------------------------------------------------------------
class TestUniPCCorrectorDefault:
"""Tests that the UniPC corrector is enabled by default,
matching official FlowUniPCMultistepScheduler behavior."""
@@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
@@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)

View File

@@ -3,16 +3,16 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# T5 Encoder Tests
# ---------------------------------------------------------------------------
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -22,6 +22,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
out = norm(x)
@@ -35,6 +36,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
mx.eval(out)
@@ -42,6 +44,7 @@ class TestT5RelativeEmbedding:
def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
mx.eval(out)
@@ -50,6 +53,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
mx.eval(out)
@@ -64,6 +68,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
out = attn(x)
@@ -73,12 +78,14 @@ class TestT5Attention:
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
x = mx.random.normal((1, 10, 64))
@@ -89,6 +96,7 @@ class TestT5Attention:
def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
mask = mx.ones((1, 10))
@@ -101,6 +109,7 @@ class TestT5Attention:
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
out = ffn(x)
@@ -110,6 +119,7 @@ class TestT5FeedForward:
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
assert hasattr(ffn, "fc1")
@@ -122,9 +132,16 @@ class TestT5Encoder:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
ids = mx.array([[1, 5, 10, 0, 0]])
mask = mx.array([[1, 1, 1, 0, 0]])
@@ -134,9 +151,16 @@ class TestT5Encoder:
def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=True,
)
assert encoder.pos_embedding is not None
for block in encoder.blocks:
@@ -144,9 +168,16 @@ class TestT5Encoder:
def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
assert encoder.pos_embedding is None
for block in encoder.blocks:
@@ -154,18 +185,32 @@ class TestT5Encoder:
def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
ids = mx.array([[1, 5, 10]])
out = encoder(ids)

View File

@@ -2,13 +2,11 @@
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
decode_with_tiling,
split_in_spatial,
split_in_temporal,
)
@@ -49,16 +47,24 @@ class TestNonCausalTemporal:
# Causal: 1 + (4-1)*4 = 13
out_causal = decode_with_tiling(
dummy_decoder_causal, latents, config,
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
dummy_decoder_causal,
latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=True,
)
mx.eval(out_causal)
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
# Non-causal: 4*4 = 16
out_noncausal = decode_with_tiling(
dummy_decoder_noncausal, latents, config,
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
dummy_decoder_noncausal,
latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=False,
)
mx.eval(out_noncausal)
assert out_noncausal.shape[2] == t * scale # 16
@@ -100,9 +106,9 @@ class TestWan22TiledDecoding:
mx.eval(out_tiled)
# Both should produce the same shape
assert out_regular.shape == out_tiled.shape, (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
)
assert (
out_regular.shape == out_tiled.shape
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
@@ -120,8 +126,10 @@ class TestWan22TiledDecoding:
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled),
rtol=1e-4, atol=1e-4,
np.array(out_regular),
np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
@@ -152,9 +160,9 @@ class TestWan21TiledDecoding:
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
assert out_regular.shape == out_tiled.shape, (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
)
assert (
out_regular.shape == out_tiled.shape
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
@@ -171,8 +179,10 @@ class TestWan21TiledDecoding:
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled),
rtol=1e-4, atol=1e-4,
np.array(out_regular),
np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
@@ -185,8 +195,13 @@ class TestWan21TemporalScale:
from mlx_video.models.wan.vae import Decoder3d
# Small decoder for fast test
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
temporal_upsample=[True, True, False])
dec = Decoder3d(
dim=16,
z_dim=4,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temporal_upsample=[True, True, False],
)
mx.eval(dec.parameters())
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3

View File

@@ -2,16 +2,16 @@
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Transformer Block Tests
# ---------------------------------------------------------------------------
class TestWanFFN:
def test_output_shape(self):
from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64))
out = ffn(x)
@@ -21,6 +21,7 @@ class TestWanFFN:
def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0
out1 = ffn(x)
@@ -39,10 +40,13 @@ class TestWanAttentionBlock:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads,
self.dim,
self.ffn_dim,
self.num_heads,
cross_attn_norm=True,
)
B, L = 1, 24
@@ -53,37 +57,49 @@ class TestWanAttentionBlock:
freqs = rope_params(1024, self.dim // self.num_heads)
out = block(
x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
freqs=freqs, context=context,
x,
e,
seq_lens=[L],
grid_sizes=[(F, H, W)],
freqs=freqs,
context=context,
)
mx.eval(out)
assert out.shape == (B, L, self.dim)
def test_modulation_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads,
self.dim,
self.ffn_dim,
self.num_heads,
cross_attn_norm=True,
)
assert block.norm3 is not None
def test_without_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads,
self.dim,
self.ffn_dim,
self.num_heads,
cross_attn_norm=False,
)
assert block.norm3 is None
def test_residual_connection(self):
"""Output should differ from zero even with small random init."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8
F, H, W = 2, 2, 2
@@ -102,6 +118,7 @@ class TestWanAttentionBlock:
# Float32 Modulation Precision Tests
# ---------------------------------------------------------------------------
class TestFloat32Modulation:
"""Tests that modulation/gate operations are computed in float32,
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
@@ -113,13 +130,15 @@ class TestFloat32Modulation:
def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8
x = mx.random.normal((B, L, self.dim))
@@ -135,6 +154,7 @@ class TestFloat32Modulation:
def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan.model import Head
head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim))
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
@@ -145,6 +165,7 @@ class TestFloat32Modulation:
def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t)
mx.eval(emb)
@@ -153,6 +174,7 @@ class TestFloat32Modulation:
def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t)
mx.eval(emb)

View File

@@ -4,16 +4,16 @@ import math
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# VAE 2.1 Tests
# ---------------------------------------------------------------------------
class TestCausalConv3d:
def test_output_shape_stride1(self):
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -29,6 +29,7 @@ class TestCausalConv3d:
def test_output_shape_kernel1(self):
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
x = mx.random.normal((1, 4, 2, 4, 4))
@@ -39,6 +40,7 @@ class TestCausalConv3d:
def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros((2,))
@@ -55,6 +57,7 @@ class TestCausalConv3d:
class TestResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
@@ -63,6 +66,7 @@ class TestResidualBlock:
def test_different_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
@@ -71,11 +75,13 @@ class TestResidualBlock:
def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -83,6 +89,7 @@ class TestResidualBlock:
class TestAttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
@@ -91,6 +98,7 @@ class TestAttentionBlock:
def test_residual_connection(self):
from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3))
out = block(x)
@@ -102,13 +110,15 @@ class TestAttentionBlock:
class TestWanVAE:
def test_instantiation(self):
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
assert vae.z_dim == 16
assert vae.mean.shape == (16,)
assert vae.std.shape == (16,)
def test_normalization_stats(self):
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16
assert all(s > 0 for s in VAE_STD)
@@ -124,6 +134,7 @@ class TestVAE22CausalConv3d:
def test_output_shape_k3(self):
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
out = conv(x)
@@ -132,6 +143,7 @@ class TestVAE22CausalConv3d:
def test_output_shape_k1(self):
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8))
out = conv(x)
@@ -141,6 +153,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0."""
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros(conv.bias.shape)
@@ -151,10 +164,13 @@ class TestVAE22CausalConv3d:
t0_ref = np.array(out_zero[0, 0])
# Modify t=2..3; output at t=0 should be unchanged
x_mod = mx.concatenate([
x[:, :2],
mx.ones((1, 2, 4, 4, 2)),
], axis=1)
x_mod = mx.concatenate(
[
x[:, :2],
mx.ones((1, 2, 4, 4, 2)),
],
axis=1,
)
out_mod = conv(x_mod)
mx.eval(out_mod)
t0_mod = np.array(out_mod[0, 0])
@@ -163,6 +179,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4))
out = conv(x)
@@ -175,6 +192,7 @@ class TestRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16))
out = norm(x)
@@ -184,6 +202,7 @@ class TestRMSNorm:
def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan.vae22 import RMS_norm
dim = 32
norm = RMS_norm(dim)
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
@@ -197,6 +216,7 @@ class TestRMSNorm:
def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8))
out1 = norm(x)
@@ -207,6 +227,7 @@ class TestRMSNorm:
def test_gamma_effect(self):
"""Non-unit gamma should scale output."""
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
x = mx.ones((1, 1, 1, 1, 4))
@@ -221,6 +242,7 @@ class TestDupUp3D:
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out = up(x)
@@ -229,6 +251,7 @@ class TestDupUp3D:
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16))
out = up(x)
@@ -237,6 +260,7 @@ class TestDupUp3D:
def test_first_chunk_trims(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False)
@@ -248,6 +272,7 @@ class TestDupUp3D:
def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False)
@@ -262,6 +287,7 @@ class TestVAE22Resample:
def test_upsample2d_shape(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -271,6 +297,7 @@ class TestVAE22Resample:
def test_upsample3d_shape(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -280,6 +307,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -291,6 +319,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 1, 4, 4, 8))
@@ -308,6 +337,7 @@ class TestVAE22Resample:
the first input frame (not on time_conv parameters).
"""
from mlx_video.models.wan.vae22 import Resample
C = 8
r = Resample(C, "upsample3d")
# Set time_conv weights to large values so its effect is detectable
@@ -334,8 +364,9 @@ class TestVAE22Resample:
# Compare first output frame to reference
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
mx.eval(first_out)
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
"First frame should bypass time_conv and match spatial-only upsample"
assert mx.allclose(
first_out, ref, atol=1e-5
).item(), "First frame should bypass time_conv and match spatial-only upsample"
class TestVAE22ResidualBlock:
@@ -343,6 +374,7 @@ class TestVAE22ResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
@@ -351,6 +383,7 @@ class TestVAE22ResidualBlock:
def test_different_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
@@ -359,11 +392,13 @@ class TestVAE22ResidualBlock:
def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -374,6 +409,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8)
params = dict(block.parameters())
# All param keys should use layer_N, not _layer_N
@@ -382,6 +418,7 @@ class TestResidualBlockLayers:
def test_has_expected_layers(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm
assert hasattr(block, "layer_2") # first CausalConv3d
@@ -390,6 +427,7 @@ class TestResidualBlockLayers:
def test_forward_shape(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
@@ -402,6 +440,7 @@ class TestVAE22AttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
@@ -412,6 +451,7 @@ class TestVAE22AttentionBlock:
def test_residual_connection(self):
from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
block.proj_weight = mx.zeros(block.proj_weight.shape)
@@ -427,6 +467,7 @@ class TestHead22:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Head22
head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16))
out = head(x)
@@ -436,6 +477,7 @@ class TestHead22:
def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix."""
from mlx_video.models.wan.vae22 import Head22
head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm
assert hasattr(head, "layer_2") # CausalConv3d
@@ -449,6 +491,7 @@ class TestUnpatchify:
def test_basic_shape(self):
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2)
mx.eval(out)
@@ -456,6 +499,7 @@ class TestUnpatchify:
def test_patch_size_1_noop(self):
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1)
mx.eval(out)
@@ -464,6 +508,7 @@ class TestUnpatchify:
def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2)
mx.eval(out)
@@ -477,6 +522,7 @@ class TestDenormalizeLatents:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z)
mx.eval(out)
@@ -484,16 +530,23 @@ class TestDenormalizeLatents:
def test_custom_mean_std(self):
from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0])
std = mx.array([0.5, 0.5, 0.5, 0.5])
out = denormalize_latents(z, mean=mean, std=std)
mx.eval(out)
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
np.testing.assert_allclose(
np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5
)
def test_uses_default_constants(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
from mlx_video.models.wan.vae22 import (
VAE22_MEAN,
denormalize_latents,
)
# Should not raise with default constants
z = mx.zeros((1, 1, 1, 1, 48))
out = denormalize_latents(z)
@@ -511,12 +564,14 @@ class TestVAE22NormConstants:
def test_dimensions(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,)
def test_std_positive(self):
from mlx_video.models.wan.vae22 import VAE22_STD
mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all()
@@ -527,6 +582,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
# Latent: [B=1, T=3, H=2, W=2, C=4]
@@ -542,6 +598,7 @@ class TestWan22VAEDecoder:
def test_output_clipped(self):
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
out = dec(z)
@@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights:
def test_skip_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.layer.weight": mx.zeros((4,)),
"conv1.weight": mx.zeros((4,)),
@@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights:
def test_sequential_index_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
@@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights:
def test_resample_conv_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
@@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights:
def test_attention_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
@@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights:
def test_conv3d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3))
weights = {"decoder.conv1.weight": w}
@@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights:
def test_conv2d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3))
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
@@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights:
def test_gamma_squeeze(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1))
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
@@ -635,7 +699,10 @@ class TestUpResidualBlock:
def test_no_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
@@ -644,7 +711,10 @@ class TestUpResidualBlock:
def test_spatial_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
@@ -653,7 +723,10 @@ class TestUpResidualBlock:
def test_spatial_temporal_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
@@ -720,7 +793,9 @@ class TestDownResidualBlock:
def test_no_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
)
x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x)
mx.eval(out)
@@ -729,7 +804,9 @@ class TestDownResidualBlock:
def test_spatial_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
)
x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x)
mx.eval(out)
@@ -738,7 +815,9 @@ class TestDownResidualBlock:
def test_spatial_temporal_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
)
x = mx.random.normal((1, 4, 8, 8, 8))
out = block(x)
mx.eval(out)
@@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12))
mx.eval(enc.parameters())
@@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples
found_modes = []
@@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3))
mx.eval(enc.parameters())
@@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan.vae22 import Encoder3d
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True)
)
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -883,12 +968,8 @@ class TestVAE21RoundTrip:
z_dim = 4
dim = 8
# No temporal up/downsampling to keep the test simple
enc = Encoder3d(
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
)
dec = Decoder3d(
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
)
enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False])
dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False])
mx.eval(enc.parameters(), dec.parameters())
# [B=1, C=3, T=1, H=8, W=8]
@@ -937,15 +1018,12 @@ class TestVAE22RoundTrip:
mx.eval(out)
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
out_np = np.array(out)
assert np.all(np.isfinite(out_np))
assert out_np.min() >= -1.0 - 1e-6
assert out_np.max() <= 1.0 + 1e-6

View File

@@ -4,6 +4,7 @@
def _make_tiny_config():
"""Create a tiny WanModelConfig for testing."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
# Override to tiny values
config.dim = 64