format
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
"""Tests for LTX-2 dev model generation pipeline."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from mlx_video.generate_dev import (
|
||||
ltx2_scheduler,
|
||||
create_position_grid,
|
||||
create_audio_position_grid,
|
||||
compute_audio_frames,
|
||||
cfg_delta,
|
||||
DEFAULT_NEGATIVE_PROMPT,
|
||||
AUDIO_SAMPLE_RATE,
|
||||
AUDIO_LATENTS_PER_SECOND,
|
||||
AUDIO_SAMPLE_RATE,
|
||||
DEFAULT_NEGATIVE_PROMPT,
|
||||
cfg_delta,
|
||||
compute_audio_frames,
|
||||
create_audio_position_grid,
|
||||
create_position_grid,
|
||||
ltx2_scheduler,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
|
||||
"""Scheduler should return steps+1 sigma values."""
|
||||
steps = 20
|
||||
sigmas = ltx2_scheduler(steps=steps)
|
||||
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||
assert sigmas.shape == (
|
||||
steps + 1,
|
||||
), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||
|
||||
def test_scheduler_starts_at_one(self):
|
||||
"""Sigma schedule should start at 1.0."""
|
||||
sigmas = ltx2_scheduler(steps=20)
|
||||
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"
|
||||
assert (
|
||||
abs(sigmas[0].item() - 1.0) < 1e-5
|
||||
), f"Expected 1.0, got {sigmas[0].item()}"
|
||||
|
||||
def test_scheduler_ends_at_zero(self):
|
||||
"""Sigma schedule should end at 0.0."""
|
||||
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
|
||||
sigmas = ltx2_scheduler(steps=20)
|
||||
sigmas_list = sigmas.tolist()
|
||||
for i in range(len(sigmas_list) - 1):
|
||||
assert sigmas_list[i] >= sigmas_list[i + 1], \
|
||||
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||
assert (
|
||||
sigmas_list[i] >= sigmas_list[i + 1]
|
||||
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||
|
||||
def test_scheduler_dtype(self):
|
||||
"""Scheduler should return float32 array."""
|
||||
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
|
||||
num_patches = num_frames * height * width
|
||||
|
||||
expected_shape = (batch_size, 3, num_patches, 2)
|
||||
assert positions.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {positions.shape}"
|
||||
assert (
|
||||
positions.shape == expected_shape
|
||||
), f"Expected {expected_shape}, got {positions.shape}"
|
||||
|
||||
def test_position_grid_dtype(self):
|
||||
"""Position grid should be float32 for RoPE precision."""
|
||||
positions = create_position_grid(1, 5, 16, 24)
|
||||
assert positions.dtype == mx.float32, \
|
||||
f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||
assert (
|
||||
positions.dtype == mx.float32
|
||||
), f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||
|
||||
def test_position_grid_batch_size(self):
|
||||
"""Position grid should respect batch size."""
|
||||
@@ -165,7 +172,9 @@ class TestCFGDelta:
|
||||
mx.eval(delta)
|
||||
|
||||
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
|
||||
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero"
|
||||
assert (
|
||||
mx.max(mx.abs(delta)).item() < 1e-6
|
||||
), "CFG delta with scale=1.0 should be zero"
|
||||
|
||||
def test_cfg_delta_formula(self):
|
||||
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
|
||||
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
|
||||
|
||||
# Check for common negative quality terms
|
||||
assert "blurry" in prompt_lower, "Should contain 'blurry'"
|
||||
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \
|
||||
"Should contain quality-related terms"
|
||||
assert (
|
||||
"low quality" in prompt_lower or "low contrast" in prompt_lower
|
||||
), "Should contain quality-related terms"
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
@@ -248,15 +258,16 @@ class TestInputValidation:
|
||||
(30, 33), # 30 -> nearest valid is 33
|
||||
(35, 33), # 35 -> nearest valid is 33
|
||||
(40, 41), # 40 -> nearest valid is 41
|
||||
(1, 1), # 1 is already valid
|
||||
(1, 1), # 1 is already valid
|
||||
(33, 33), # 33 is already valid
|
||||
]
|
||||
|
||||
for input_frames, expected in test_cases:
|
||||
if input_frames % 8 != 1:
|
||||
adjusted = round((input_frames - 1) / 8) * 8 + 1
|
||||
assert adjusted == expected, \
|
||||
f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||
assert (
|
||||
adjusted == expected
|
||||
), f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||
|
||||
|
||||
class TestDenoiseWithCFGMocked:
|
||||
@@ -277,14 +288,16 @@ class TestTilingDefault:
|
||||
def test_tiling_default_is_none(self):
|
||||
"""Default tiling should be 'none' for performance."""
|
||||
import inspect
|
||||
|
||||
from mlx_video.generate_dev import generate_video_dev
|
||||
|
||||
sig = inspect.signature(generate_video_dev)
|
||||
|
||||
tiling_param = sig.parameters.get('tiling')
|
||||
tiling_param = sig.parameters.get("tiling")
|
||||
assert tiling_param is not None
|
||||
assert tiling_param.default == "none", \
|
||||
f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||
assert (
|
||||
tiling_param.default == "none"
|
||||
), f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||
|
||||
|
||||
class TestLatentDimensions:
|
||||
@@ -296,8 +309,9 @@ class TestLatentDimensions:
|
||||
|
||||
for height, expected_latent_h in test_cases:
|
||||
latent_h = height // 32
|
||||
assert latent_h == expected_latent_h, \
|
||||
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||
assert (
|
||||
latent_h == expected_latent_h
|
||||
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||
|
||||
def test_latent_width_calculation(self):
|
||||
"""Latent width should be width // 32."""
|
||||
@@ -305,8 +319,9 @@ class TestLatentDimensions:
|
||||
|
||||
for width, expected_latent_w in test_cases:
|
||||
latent_w = width // 32
|
||||
assert latent_w == expected_latent_w, \
|
||||
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||
assert (
|
||||
latent_w == expected_latent_w
|
||||
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||
|
||||
def test_latent_frames_calculation(self):
|
||||
"""Latent frames should be 1 + (num_frames - 1) // 8."""
|
||||
@@ -314,8 +329,9 @@ class TestLatentDimensions:
|
||||
|
||||
for num_frames, expected_latent_f in test_cases:
|
||||
latent_f = 1 + (num_frames - 1) // 8
|
||||
assert latent_f == expected_latent_f, \
|
||||
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||
assert (
|
||||
latent_f == expected_latent_f
|
||||
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||
|
||||
def test_num_tokens_calculation(self):
|
||||
"""Number of tokens should be latent_f * latent_h * latent_w."""
|
||||
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
|
||||
positions = create_audio_position_grid(batch_size, audio_frames)
|
||||
expected_shape = (batch_size, 1, audio_frames, 2)
|
||||
|
||||
assert positions.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {positions.shape}"
|
||||
assert (
|
||||
positions.shape == expected_shape
|
||||
), f"Expected {expected_shape}, got {positions.shape}"
|
||||
|
||||
def test_audio_position_grid_dtype(self):
|
||||
"""Audio position grid should be float32."""
|
||||
positions = create_audio_position_grid(1, 34)
|
||||
assert positions.dtype == mx.float32, \
|
||||
f"Expected float32, got {positions.dtype}"
|
||||
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
|
||||
|
||||
def test_audio_position_grid_batch_size(self):
|
||||
"""Audio position grid should respect batch size."""
|
||||
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
|
||||
"""Audio position grid should not contain NaN or Inf."""
|
||||
positions = create_audio_position_grid(1, 34)
|
||||
|
||||
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
|
||||
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
|
||||
assert not mx.any(
|
||||
mx.isnan(positions)
|
||||
).item(), "Audio position grid contains NaN"
|
||||
assert not mx.any(
|
||||
mx.isinf(positions)
|
||||
).item(), "Audio position grid contains Inf"
|
||||
|
||||
|
||||
class TestComputeAudioFrames:
|
||||
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
|
||||
audio_33 = compute_audio_frames(33, 24.0)
|
||||
audio_65 = compute_audio_frames(65, 24.0)
|
||||
|
||||
assert audio_65 > audio_33, \
|
||||
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||
assert (
|
||||
audio_65 > audio_33
|
||||
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||
|
||||
def test_audio_frames_formula(self):
|
||||
"""Audio frames should match expected formula."""
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx_2.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
@@ -20,7 +18,7 @@ def create_video_position_grid(
|
||||
h_coords = np.arange(0, height)
|
||||
w_coords = np.arange(0, width)
|
||||
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
|
||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||
patch_ends = patch_starts + 1
|
||||
|
||||
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
scaled = fractional * 2 - 1 # [-1, 1]
|
||||
|
||||
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
|
||||
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
freqs = (
|
||||
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
)
|
||||
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
|
||||
freqs = np.swapaxes(freqs, -1, -2)
|
||||
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
|
||||
freqs = freqs.reshape(
|
||||
freqs.shape[0], freqs.shape[1], -1
|
||||
) # (B, T, num_indices * n_dims)
|
||||
|
||||
cos_ref = np.cos(freqs)
|
||||
sin_ref = np.sin(freqs)
|
||||
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
pad_size = expected - cos_ref.shape[-1]
|
||||
if pad_size > 0:
|
||||
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
|
||||
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
||||
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
||||
cos_ref = np.concatenate(
|
||||
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
|
||||
)
|
||||
sin_ref = np.concatenate(
|
||||
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
|
||||
)
|
||||
|
||||
B, T, _ = cos_ref.shape
|
||||
dim_per_head = dim // num_heads
|
||||
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
|
||||
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||
|
||||
# Verify cos/sin are in valid range [-1, 1]
|
||||
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
||||
"cos_freq values out of [-1, 1] range"
|
||||
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
||||
"sin_freq values out of [-1, 1] range"
|
||||
assert (
|
||||
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
|
||||
), "cos_freq values out of [-1, 1] range"
|
||||
assert (
|
||||
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
|
||||
), "sin_freq values out of [-1, 1] range"
|
||||
|
||||
def test_bfloat16_positions_cause_precision_loss(self):
|
||||
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
|
||||
# The threshold here is intentionally low to catch the issue
|
||||
precision_threshold = 1e-6
|
||||
|
||||
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
has_precision_loss = (
|
||||
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
)
|
||||
|
||||
# Document the precision loss (this is expected behavior)
|
||||
if has_precision_loss:
|
||||
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
|
||||
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||
|
||||
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||
assert has_precision_loss, \
|
||||
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
assert (
|
||||
has_precision_loss
|
||||
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
|
||||
def test_double_precision_converts_to_float32_internally(self):
|
||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
|
||||
# Recommended: create positions in float32
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
assert positions.dtype == mx.float32, \
|
||||
"Position grids should be created in float32 for RoPE precision"
|
||||
assert (
|
||||
positions.dtype == mx.float32
|
||||
), "Position grids should be created in float32 for RoPE precision"
|
||||
|
||||
# Verify the position values are reasonable
|
||||
# Temporal positions should be small (seconds)
|
||||
temporal_positions = positions[:, 0, :, :]
|
||||
assert mx.max(temporal_positions).item() < 100, \
|
||||
"Temporal positions should be in seconds (small values)"
|
||||
assert (
|
||||
mx.max(temporal_positions).item() < 100
|
||||
), "Temporal positions should be in seconds (small values)"
|
||||
|
||||
# Spatial positions should be larger (pixels)
|
||||
spatial_h = positions[:, 1, :, :]
|
||||
spatial_w = positions[:, 2, :, :]
|
||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||
assert (
|
||||
mx.max(spatial_h).item() > 0
|
||||
), "Spatial height positions should be positive"
|
||||
assert (
|
||||
mx.max(spatial_w).item() > 0
|
||||
), "Spatial width positions should be positive"
|
||||
|
||||
def test_float32_positions_match_numpy_float64_reference(self):
|
||||
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
|
||||
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
|
||||
)
|
||||
|
||||
# NumPy float64 reference
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
cos_ref, sin_ref = _numpy_reference_rope(
|
||||
positions_np, dim, theta, max_pos, num_heads
|
||||
)
|
||||
|
||||
cos_mlx_np = np.array(cos_mlx)
|
||||
sin_mlx_np = np.array(sin_mlx)
|
||||
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
|
||||
# Cosine similarity (flatten for single scalar)
|
||||
cos_flat = cos_mlx_np.flatten()
|
||||
ref_flat = cos_ref.flatten()
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (
|
||||
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
|
||||
)
|
||||
|
||||
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
|
||||
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert cosine_sim > 0.9999, \
|
||||
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
assert (
|
||||
max_cos_diff < 0.01
|
||||
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert (
|
||||
max_sin_diff < 0.01
|
||||
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert (
|
||||
cosine_sim > 0.9999
|
||||
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
|
||||
def test_high_frequency_amplification_regression(self):
|
||||
"""Regression test for the specific failure mode: high-frequency index amplification.
|
||||
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
cos_ref, sin_ref = _numpy_reference_rope(
|
||||
positions_np, dim, theta, max_pos, num_heads
|
||||
)
|
||||
|
||||
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
|
||||
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
|
||||
|
||||
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
assert (
|
||||
max_cos_diff < 0.01
|
||||
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert (
|
||||
max_sin_diff < 0.01
|
||||
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
|
||||
|
||||
class TestRoPEInterleaved:
|
||||
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=True,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_f16,
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
assert cos_freq.dtype == mx.float32
|
||||
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
|
||||
def test_ltx2_forces_double_precision_rope_false(self):
|
||||
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2 should force double_precision_rope=False regardless of input"
|
||||
assert (
|
||||
config.double_precision_rope is False
|
||||
), "LTX-2 should force double_precision_rope=False regardless of input"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_true(self):
|
||||
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
|
||||
assert config.double_precision_rope is True, \
|
||||
"LTX-2.3 should preserve double_precision_rope=True"
|
||||
assert (
|
||||
config.double_precision_rope is True
|
||||
), "LTX-2.3 should preserve double_precision_rope=True"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_false(self):
|
||||
"""LTX-2.3 with double_precision_rope=False should stay False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
assert (
|
||||
config.double_precision_rope is False
|
||||
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
|
||||
def test_ltx2_default_double_precision_rope(self):
|
||||
"""LTX-2 default (double_precision_rope not set) should be False."""
|
||||
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
|
||||
|
||||
def test_config_from_dict_ltx2(self):
|
||||
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
config = LTXModelConfig.from_dict(
|
||||
{
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
}
|
||||
)
|
||||
assert config.double_precision_rope is False
|
||||
|
||||
def test_config_from_dict_ltx23(self):
|
||||
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
config = LTXModelConfig.from_dict(
|
||||
{
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
}
|
||||
)
|
||||
assert config.double_precision_rope is True
|
||||
|
||||
|
||||
@@ -496,10 +543,12 @@ class TestRoPESplit:
|
||||
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||
dim_per_head = dim // num_heads
|
||||
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||
assert cos_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert sin_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
assert (
|
||||
cos_freq.shape == expected_shape
|
||||
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert (
|
||||
sin_freq.shape == expected_shape
|
||||
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Tests for VAE streaming and chunked conv features."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
@@ -50,7 +50,7 @@ class TestChunkedConv:
|
||||
np.array(out_chunked),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
err_msg="Chunked conv output differs from regular output"
|
||||
err_msg="Chunked conv output differs from regular output",
|
||||
)
|
||||
|
||||
def test_chunked_conv_small_input_passthrough(self):
|
||||
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
|
||||
frames_received = []
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
frames_received.append({
|
||||
'shape': frames.shape,
|
||||
'start_idx': start_idx,
|
||||
})
|
||||
frames_received.append(
|
||||
{
|
||||
"shape": frames.shape,
|
||||
"start_idx": start_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# Create a mock decoder that just returns scaled input
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
def mock_decoder(
|
||||
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||
):
|
||||
# Simulate VAE output: upsample 8x temporal, 32x spatial
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
|
||||
|
||||
# All received frames should have correct channel count
|
||||
for received in frames_received:
|
||||
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
|
||||
assert (
|
||||
received["shape"][1] == 3
|
||||
), f"Expected 3 channels, got {received['shape'][1]}"
|
||||
|
||||
def test_on_frames_ready_covers_all_frames(self):
|
||||
"""Verify all frames are emitted via callbacks."""
|
||||
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
|
||||
for i in range(num_frames):
|
||||
all_frame_indices.add(start_idx + i)
|
||||
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
def mock_decoder(
|
||||
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||
):
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
out_h = h * 32
|
||||
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
|
||||
expected_frames = 1 + (12 - 1) * 8 # 89 frames
|
||||
|
||||
# All frames should have been emitted
|
||||
assert len(all_frame_indices) == expected_frames, \
|
||||
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(range(expected_frames)), \
|
||||
"Not all frame indices were covered"
|
||||
assert (
|
||||
len(all_frame_indices) == expected_frames
|
||||
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(
|
||||
range(expected_frames)
|
||||
), "Not all frame indices were covered"
|
||||
|
||||
|
||||
class TestAutoChunkedConv:
|
||||
"""Tests for auto-enabling chunked_conv based on tiling mode."""
|
||||
|
||||
@pytest.mark.parametrize("tiling_mode,should_enable", [
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"tiling_mode,should_enable",
|
||||
[
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
],
|
||||
)
|
||||
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
|
||||
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
|
||||
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
|
||||
|
||||
use_chunked_conv = tiling_mode in expected_modes
|
||||
|
||||
assert use_chunked_conv == should_enable, \
|
||||
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
assert (
|
||||
use_chunked_conv == should_enable
|
||||
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
|
||||
|
||||
class TestTrapezoidalMask:
|
||||
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
|
||||
|
||||
# Right ramp should be decreasing
|
||||
right_ramp = mask_np[-8:]
|
||||
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
|
||||
assert np.all(
|
||||
np.diff(right_ramp) <= 0
|
||||
), "Right ramp not monotonically decreasing"
|
||||
|
||||
def test_temporal_mask_starts_from_zero(self):
|
||||
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
|
||||
|
||||
@@ -2,24 +2,25 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRoPE:
|
||||
"""Tests for 3-way factorized RoPE."""
|
||||
|
||||
def test_rope_params_shape(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs = rope_params(1024, 64)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||
|
||||
def test_rope_params_different_dims(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
for dim in [32, 64, 128]:
|
||||
freqs = rope_params(512, dim)
|
||||
mx.eval(freqs)
|
||||
@@ -27,6 +28,7 @@ class TestRoPE:
|
||||
|
||||
def test_rope_params_cos_sin_range(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs = rope_params(256, 64)
|
||||
mx.eval(freqs)
|
||||
cos_vals = np.array(freqs[:, :, 0])
|
||||
@@ -37,13 +39,15 @@ class TestRoPE:
|
||||
def test_rope_params_position_zero(self):
|
||||
"""At position 0, cos should be 1 and sin should be 0."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs = rope_params(10, 64)
|
||||
mx.eval(freqs)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_rope_apply_output_shape(self):
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
@@ -54,7 +58,8 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_preserves_norm(self):
|
||||
"""RoPE rotation should preserve vector norms."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 3, 4
|
||||
L = F * H * W
|
||||
@@ -74,7 +79,8 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_with_padding(self):
|
||||
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 2, 2
|
||||
seq_len = F * H * W # 8
|
||||
@@ -94,7 +100,8 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_batch(self):
|
||||
"""Test with batch_size > 1 and different grid sizes."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 2, 2, 16
|
||||
grids = [(2, 3, 4), (2, 3, 4)]
|
||||
L = 2 * 3 * 4
|
||||
@@ -122,9 +129,11 @@ class TestRoPE:
|
||||
# Attention Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanRMSNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
@@ -134,6 +143,7 @@ class TestWanRMSNorm:
|
||||
def test_zero_mean_variance(self):
|
||||
"""RMS norm should make RMS ≈ 1 before scaling."""
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||
out = norm(x)
|
||||
@@ -147,6 +157,7 @@ class TestWanRMSNorm:
|
||||
def test_dtype_preservation(self):
|
||||
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(32)
|
||||
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||
out = norm(x)
|
||||
@@ -158,6 +169,7 @@ class TestWanRMSNorm:
|
||||
class TestWanLayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
@@ -166,6 +178,7 @@ class TestWanLayerNorm:
|
||||
|
||||
def test_without_affine(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||
x = mx.random.normal((1, 4, 64))
|
||||
out = norm(x)
|
||||
@@ -178,6 +191,7 @@ class TestWanLayerNorm:
|
||||
|
||||
def test_with_affine(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||
assert hasattr(norm, "weight")
|
||||
assert hasattr(norm, "bias")
|
||||
@@ -196,6 +210,7 @@ class TestWanSelfAttention:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -207,12 +222,14 @@ class TestWanSelfAttention:
|
||||
|
||||
def test_with_qk_norm(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||
assert attn.norm_q is not None
|
||||
assert attn.norm_k is not None
|
||||
|
||||
def test_without_qk_norm(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
assert attn.norm_q is None
|
||||
assert attn.norm_k is None
|
||||
@@ -221,6 +238,7 @@ class TestWanSelfAttention:
|
||||
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -245,6 +263,7 @@ class TestWanCrossAttention:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 24, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
@@ -255,6 +274,7 @@ class TestWanCrossAttention:
|
||||
|
||||
def test_with_context_mask(self):
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 12, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
@@ -268,6 +288,7 @@ class TestWanCrossAttention:
|
||||
# bfloat16 Autocast Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBFloat16Autocast:
|
||||
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
||||
for efficient matmul, matching official PyTorch autocast behavior."""
|
||||
@@ -292,6 +313,7 @@ class TestBFloat16Autocast:
|
||||
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -305,6 +327,7 @@ class TestBFloat16Autocast:
|
||||
def test_cross_attn_casts_to_weight_dtype(self):
|
||||
"""Cross-attention should cast input to weight dtype."""
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -318,6 +341,7 @@ class TestBFloat16Autocast:
|
||||
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||
"""prepare_kv should cast context to weight dtype."""
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -330,6 +354,7 @@ class TestBFloat16Autocast:
|
||||
def test_ffn_casts_to_weight_dtype(self):
|
||||
"""FFN should cast input to weight dtype for linear layers."""
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(self.dim, 128)
|
||||
ffn.update(self._to_bf16(ffn.parameters()))
|
||||
|
||||
@@ -343,6 +368,7 @@ class TestBFloat16Autocast:
|
||||
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -355,8 +381,9 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_block_float32_residual_with_bf16_weights(self):
|
||||
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||
block.update(self._to_bf16(block.parameters()))
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Tests for Wan model configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanModelConfig:
|
||||
"""Tests for WanModelConfig dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.dim == 5120
|
||||
assert config.ffn_dim == 13824
|
||||
@@ -33,11 +33,13 @@ class TestWanModelConfig:
|
||||
|
||||
def test_head_dim_property(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.head_dim == 128 # 5120 // 40
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
d = config.to_dict()
|
||||
assert isinstance(d, dict)
|
||||
@@ -47,6 +49,7 @@ class TestWanModelConfig:
|
||||
|
||||
def test_t5_config_values(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.t5_vocab_size == 256384
|
||||
assert config.t5_dim == 4096
|
||||
@@ -61,11 +64,13 @@ class TestWanModelConfig:
|
||||
# Wan2.1 Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Config:
|
||||
"""Tests for Wan2.1 config presets."""
|
||||
|
||||
def test_wan21_14b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
@@ -81,6 +86,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan21_1_3b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
@@ -93,6 +99,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan22_14b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
@@ -104,6 +111,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan21_config_to_dict(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
@@ -112,6 +120,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan21_1_3b_config_to_dict(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
d = config.to_dict()
|
||||
assert d["dim"] == 1536
|
||||
@@ -120,6 +129,7 @@ class TestWan21Config:
|
||||
def test_default_config_is_wan22(self):
|
||||
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
|
||||
@@ -3,17 +3,16 @@
|
||||
import logging
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer Weight Conversion Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeTransformerWeights:
|
||||
def test_patch_embedding_reshape(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
@@ -25,6 +24,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_text_embedding_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.0.bias": mx.zeros((64,)),
|
||||
@@ -39,6 +39,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_time_embedding_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
@@ -49,6 +50,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_time_projection_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"time_projection.1.bias": mx.zeros((384,)),
|
||||
@@ -59,6 +61,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_ffn_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.0.bias": mx.zeros((128,)),
|
||||
@@ -73,6 +76,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_freqs_skipped(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
"blocks.0.norm1.weight": mx.zeros((64,)),
|
||||
@@ -83,6 +87,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_passthrough_keys(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
|
||||
@@ -98,6 +103,7 @@ class TestSanitizeTransformerWeights:
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
@@ -121,6 +127,7 @@ class TestSanitizeTransformerWeights:
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
@@ -133,6 +140,7 @@ class TestSanitizeT5Weights:
|
||||
|
||||
def test_passthrough(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
|
||||
@@ -144,6 +152,7 @@ class TestSanitizeT5Weights:
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
@@ -159,6 +168,7 @@ class TestSanitizeT5Weights:
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
||||
}
|
||||
@@ -167,6 +177,7 @@ class TestSanitizeVAEWeights:
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
||||
}
|
||||
@@ -175,6 +186,7 @@ class TestSanitizeVAEWeights:
|
||||
|
||||
def test_non_conv_passthrough(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
@@ -185,6 +197,7 @@ class TestSanitizeVAEWeights:
|
||||
|
||||
def test_mixed_weights(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
||||
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
|
||||
@@ -199,6 +212,7 @@ class TestSanitizeVAEWeights:
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
||||
@@ -214,6 +228,7 @@ class TestSanitizeVAEWeights:
|
||||
# Wan2.1 Conversion Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Convert:
|
||||
"""Tests for Wan2.1 conversion support."""
|
||||
|
||||
@@ -222,7 +237,7 @@ class TestWan21Convert:
|
||||
# Create a Wan2.1-style directory (no low_noise_model subdir)
|
||||
(tmp_path / "dummy.safetensors").touch()
|
||||
# The auto-detect logic: no low_noise_model dir → 2.1
|
||||
from pathlib import Path
|
||||
|
||||
low = tmp_path / "low_noise_model"
|
||||
assert not low.exists()
|
||||
# Simulates auto detection
|
||||
@@ -233,7 +248,7 @@ class TestWan21Convert:
|
||||
"""Auto-detect dual-model directory as Wan2.2."""
|
||||
(tmp_path / "low_noise_model").mkdir()
|
||||
(tmp_path / "high_noise_model").mkdir()
|
||||
from pathlib import Path
|
||||
|
||||
low = tmp_path / "low_noise_model"
|
||||
assert low.exists()
|
||||
version = "2.2" if low.exists() else "2.1"
|
||||
@@ -242,6 +257,7 @@ class TestWan21Convert:
|
||||
def test_wan21_config_saved_correctly(self):
|
||||
"""Verify config dict has correct fields for Wan2.1."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
@@ -254,6 +270,7 @@ class TestWan21Convert:
|
||||
# Encoder Weight Sanitization Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeEncoderWeights:
|
||||
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
||||
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: end-to-end tiny model forward pass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end test with tiny model (no real weights needed)."""
|
||||
|
||||
@@ -78,6 +76,7 @@ class TestEndToEnd:
|
||||
# I2V Mask Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestI2VMask:
|
||||
"""Tests for _build_i2v_mask."""
|
||||
|
||||
@@ -113,6 +112,7 @@ class TestI2VMaskAlignment:
|
||||
def test_mask_with_ti2v_dimensions(self):
|
||||
"""Mask should work with TI2V-5B typical dimensions."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
|
||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||
z_shape = (48, 21, 44, 80)
|
||||
@@ -133,6 +133,7 @@ class TestI2VMaskAlignment:
|
||||
def test_mask_per_token_timestep(self):
|
||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
|
||||
z_shape = (4, 3, 4, 4)
|
||||
patch_size = (1, 2, 2)
|
||||
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
|
||||
|
||||
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension Alignment Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDimensionAlignment:
|
||||
"""Tests for automatic dimension alignment in generate_wan."""
|
||||
|
||||
@@ -198,6 +202,7 @@ class TestDimensionAlignment:
|
||||
def test_patchify_valid_after_alignment(self):
|
||||
"""After alignment, patchify should succeed without reshape errors."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
|
||||
patches, grid_size = model._patchify(vid)
|
||||
mx.eval(patches)
|
||||
assert patches.ndim == 3 # [1, L, dim]
|
||||
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
|
||||
assert grid_size == (
|
||||
t_latent,
|
||||
h_latent // patch_size[1],
|
||||
w_latent // patch_size[2],
|
||||
)
|
||||
|
||||
def test_alignment_with_ti2v_config(self):
|
||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_ti2v_5b()
|
||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||
align_w = config.patch_size[2] * config.vae_stride[2]
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
"""Tests for Wan2.2 I2V-14B support."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
@@ -145,7 +142,10 @@ class TestModelYParameter:
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y = mx.random.normal((C_y, F, H, W))
|
||||
t = mx.array([500.0, 500.0])
|
||||
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
|
||||
ctx = [
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
]
|
||||
|
||||
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
||||
mx.eval(out[0], out[1])
|
||||
@@ -160,7 +160,9 @@ class TestVAEEncoder:
|
||||
def test_encoder3d_instantiation(self):
|
||||
from mlx_video.models.wan.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||
enc = Encoder3d(
|
||||
dim=32, z_dim=8
|
||||
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||
assert enc.conv1 is not None
|
||||
assert len(enc.downsamples) > 0
|
||||
assert len(enc.middle) == 3
|
||||
@@ -199,10 +201,10 @@ class TestVAEEncoder:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||
assert not hasattr(vae_no_enc, 'encoder')
|
||||
assert not hasattr(vae_no_enc, "encoder")
|
||||
|
||||
vae_enc = WanVAE(z_dim=4, encoder=True)
|
||||
assert hasattr(vae_enc, 'encoder')
|
||||
assert hasattr(vae_enc, "encoder")
|
||||
|
||||
|
||||
class TestResampleDownsample:
|
||||
@@ -258,7 +260,9 @@ class TestI2VMaskConstruction:
|
||||
|
||||
# Build mask following reference logic
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
@@ -272,7 +276,9 @@ class TestI2VMaskConstruction:
|
||||
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
||||
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
||||
@@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline:
|
||||
config = _make_tiny_i2v_config()
|
||||
config.vae_z_dim = 16
|
||||
config.out_dim = 16 # must match VAE z_dim for decode
|
||||
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
config.in_dim = (
|
||||
16 + 4 + 16
|
||||
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
model = WanModel(config)
|
||||
|
||||
# --- Tiny VAE (with encoder) ---
|
||||
@@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline:
|
||||
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
||||
video = mx.concatenate([
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
], axis=2)
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
|
||||
# --- VAE encode ---
|
||||
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
||||
@@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline:
|
||||
|
||||
# --- Build I2V mask (4 channels) ---
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
@@ -453,7 +466,9 @@ class TestDualModelSwitching:
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||
0
|
||||
)
|
||||
mx.eval(latents)
|
||||
|
||||
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
||||
@@ -461,9 +476,9 @@ class TestDualModelSwitching:
|
||||
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
||||
# High-noise steps should come before low-noise steps (timesteps decrease)
|
||||
if high_used_steps and low_used_steps:
|
||||
assert max(high_used_steps) < min(low_used_steps) or \
|
||||
min(high_used_steps) < max(low_used_steps), \
|
||||
"Model switching should happen during the loop"
|
||||
assert max(high_used_steps) < min(low_used_steps) or min(
|
||||
high_used_steps
|
||||
) < max(low_used_steps), "Model switching should happen during the loop"
|
||||
|
||||
assert latents.shape == (C_noise, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
@@ -515,7 +530,9 @@ class TestDualModelSwitching:
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||
0
|
||||
)
|
||||
mx.eval(latents)
|
||||
|
||||
# Verify both guide scales were used
|
||||
|
||||
@@ -4,7 +4,6 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -40,7 +39,9 @@ class TestLoRATypes:
|
||||
|
||||
lora_a = mx.ones((2, 4))
|
||||
lora_b = mx.ones((8, 2))
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
applied = AppliedLoRA(weights=w, strength=0.5)
|
||||
delta = applied.compute_delta()
|
||||
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
||||
@@ -51,7 +52,9 @@ class TestLoRATypes:
|
||||
class TestLoRALoader:
|
||||
"""Test LoRA weight loading from safetensors."""
|
||||
|
||||
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
|
||||
def _make_lora_file(
|
||||
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
|
||||
):
|
||||
"""Helper to create a mock LoRA safetensors file."""
|
||||
weights = {}
|
||||
for name in module_names:
|
||||
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
|
||||
"""Simulate typical Wan2.2 MLX model weight keys."""
|
||||
keys = set()
|
||||
for i in range(2):
|
||||
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
|
||||
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
|
||||
for layer in [
|
||||
"self_attn.q",
|
||||
"self_attn.k",
|
||||
"self_attn.v",
|
||||
"self_attn.o",
|
||||
"cross_attn.q",
|
||||
"cross_attn.k",
|
||||
"cross_attn.v",
|
||||
"cross_attn.o",
|
||||
]:
|
||||
keys.add(f"blocks.{i}.{layer}.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
||||
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
|
||||
assert (
|
||||
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
|
||||
== "blocks.0.self_attn.q"
|
||||
)
|
||||
|
||||
def test_strip_diffusion_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
|
||||
result = _normalize_wan_lora_key(
|
||||
"model.diffusion_model.blocks.0.self_attn.k", keys
|
||||
)
|
||||
assert result == "blocks.0.self_attn.k"
|
||||
|
||||
def test_ffn_key_mapping(self):
|
||||
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
assert (
|
||||
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
)
|
||||
|
||||
def test_combined_prefix_and_ffn(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
@@ -219,7 +237,9 @@ class TestApplyLoRA:
|
||||
# LoRA weights in float32 (typical when loaded from safetensors)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
||||
|
||||
@@ -230,7 +250,9 @@ class TestApplyLoRA:
|
||||
original = mx.ones((8, 4), dtype=mx.float16)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
||||
|
||||
@@ -241,7 +263,9 @@ class TestApplyLoRA:
|
||||
original = mx.ones((8, 4))
|
||||
lora_a = mx.ones((2, 4)) * 0.1
|
||||
lora_b = mx.ones((8, 2)) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
||||
expected = original + 0.02 * mx.ones((8, 4))
|
||||
@@ -255,12 +279,16 @@ class TestApplyLoRA:
|
||||
w1 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)),
|
||||
lora_B=mx.ones((8, 2)),
|
||||
rank=2, alpha=2.0, module_name="a",
|
||||
rank=2,
|
||||
alpha=2.0,
|
||||
module_name="a",
|
||||
)
|
||||
w2 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)) * 2,
|
||||
lora_B=mx.ones((8, 2)) * 2,
|
||||
rank=2, alpha=4.0, module_name="b",
|
||||
rank=2,
|
||||
alpha=4.0,
|
||||
module_name="b",
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
||||
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
||||
@@ -282,7 +310,9 @@ class TestApplyLoRA:
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.ones((4, 64)) * 0.01,
|
||||
lora_B=mx.ones((128, 4)) * 0.01,
|
||||
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
|
||||
rank=4,
|
||||
alpha=4.0,
|
||||
module_name="blocks.0.self_attn.q",
|
||||
)
|
||||
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
||||
result = apply_loras_to_weights(model_weights, module_to_loras)
|
||||
@@ -319,9 +349,7 @@ class TestEndToEnd:
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
}
|
||||
|
||||
result = load_and_apply_loras(
|
||||
model_weights, [(str(lora_path), 1.0)]
|
||||
)
|
||||
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
|
||||
|
||||
# q weight should be modified, k unchanged
|
||||
assert not mx.array_equal(
|
||||
|
||||
@@ -3,18 +3,17 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sinusoidal Embedding Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSinusoidalEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.arange(10).astype(mx.float32)
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
mx.eval(emb)
|
||||
@@ -23,6 +22,7 @@ class TestSinusoidalEmbedding:
|
||||
def test_position_zero(self):
|
||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0])
|
||||
emb = sinusoidal_embedding_1d(64, pos)
|
||||
mx.eval(emb)
|
||||
@@ -34,6 +34,7 @@ class TestSinusoidalEmbedding:
|
||||
|
||||
def test_different_positions_differ(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 999.0])
|
||||
emb = sinusoidal_embedding_1d(128, pos)
|
||||
mx.eval(emb)
|
||||
@@ -46,9 +47,11 @@ class TestSinusoidalEmbedding:
|
||||
# Head Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHead:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
B, L = 1, 24
|
||||
x = mx.random.normal((B, L, 64))
|
||||
@@ -60,6 +63,7 @@ class TestHead:
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
assert head.modulation.shape == (1, 2, 64)
|
||||
|
||||
@@ -68,12 +72,14 @@ class TestHead:
|
||||
# WanModel (Tiny) Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanModel:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
|
||||
@@ -81,6 +87,7 @@ class TestWanModel:
|
||||
|
||||
def test_patchify_shape(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
# Input: [C=4, F=1, H=4, W=4]
|
||||
@@ -93,6 +100,7 @@ class TestWanModel:
|
||||
|
||||
def test_patchify_various_sizes(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
|
||||
@@ -108,6 +116,7 @@ class TestWanModel:
|
||||
def test_unpatchify_inverse(self):
|
||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 2, 4, 6
|
||||
@@ -123,6 +132,7 @@ class TestWanModel:
|
||||
|
||||
def test_forward_pass(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
@@ -140,6 +150,7 @@ class TestWanModel:
|
||||
|
||||
def test_forward_batch(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
@@ -148,7 +159,10 @@ class TestWanModel:
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0, 200.0])
|
||||
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
|
||||
context = [
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
mx.random.normal((4, config.text_dim)),
|
||||
]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0], out[1])
|
||||
@@ -158,12 +172,17 @@ class TestWanModel:
|
||||
|
||||
def test_output_is_float32(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
|
||||
[mx.random.normal((4, config.text_dim))], seq_len)
|
||||
out = model(
|
||||
[mx.random.normal((C, F, H, W))],
|
||||
mx.array([100.0]),
|
||||
[mx.random.normal((4, config.text_dim))],
|
||||
seq_len,
|
||||
)
|
||||
mx.eval(out[0])
|
||||
assert out[0].dtype == mx.float32
|
||||
|
||||
@@ -172,6 +191,7 @@ class TestWanModel:
|
||||
# Wan2.1 Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Model:
|
||||
"""Test tiny Wan2.1-style model (single model mode)."""
|
||||
|
||||
@@ -181,6 +201,7 @@ class TestWan21Model:
|
||||
def _make_tiny_wan21_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
@@ -197,6 +218,7 @@ class TestWan21Model:
|
||||
def _make_tiny_wan21_1_3b_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||
config.dim = 48
|
||||
@@ -271,7 +293,9 @@ class TestWan21Model:
|
||||
for i in range(3):
|
||||
t = sched.timesteps[i]
|
||||
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
|
||||
pred_uncond = model(
|
||||
[latents], mx.array([t.item()]), [context_null], seq_len
|
||||
)[0]
|
||||
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
@@ -304,6 +328,7 @@ class TestWan21Model:
|
||||
# Per-Token Timestep Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerTokenTimestep:
|
||||
"""Tests for per-token sinusoidal embedding."""
|
||||
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
"""Tests for Wan model quantization pipeline."""
|
||||
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantize Predicate Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizePredicate:
|
||||
def test_matches_self_attention_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
path = f"blocks.0.self_attn.{suffix}"
|
||||
@@ -24,6 +24,7 @@ class TestQuantizePredicate:
|
||||
|
||||
def test_matches_cross_attention_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
path = f"blocks.0.cross_attn.{suffix}"
|
||||
@@ -31,23 +32,31 @@ class TestQuantizePredicate:
|
||||
|
||||
def test_matches_ffn_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
||||
|
||||
def test_rejects_embeddings(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
|
||||
for path in [
|
||||
"patch_embedding_proj",
|
||||
"text_embedding_fc1",
|
||||
"time_embedding.fc1",
|
||||
]:
|
||||
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
||||
|
||||
def test_rejects_norms(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
||||
|
||||
def test_rejects_non_quantizable_modules(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
# Even if path matches, module must have to_quantized
|
||||
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
|
||||
@@ -55,13 +64,19 @@ class TestQuantizePredicate:
|
||||
def test_all_10_patterns_covered(self):
|
||||
"""Verify exactly 10 layer patterns are targeted."""
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
patterns = [
|
||||
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
|
||||
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
|
||||
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
|
||||
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
|
||||
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
|
||||
"blocks.0.self_attn.q",
|
||||
"blocks.0.self_attn.k",
|
||||
"blocks.0.self_attn.v",
|
||||
"blocks.0.self_attn.o",
|
||||
"blocks.0.cross_attn.q",
|
||||
"blocks.0.cross_attn.k",
|
||||
"blocks.0.cross_attn.v",
|
||||
"blocks.0.cross_attn.o",
|
||||
"blocks.0.ffn.fc1",
|
||||
"blocks.0.ffn.fc2",
|
||||
]
|
||||
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
|
||||
assert len(matched) == 10
|
||||
@@ -71,11 +86,12 @@ class TestQuantizePredicate:
|
||||
# Quantize Round-Trip Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizeRoundTrip:
|
||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||
"""Helper: create model, quantize, save to tmp_path."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -101,8 +117,10 @@ class TestQuantizeRoundTrip:
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path, config,
|
||||
model_path,
|
||||
config,
|
||||
quantization={"bits": 4, "group_size": 64},
|
||||
)
|
||||
|
||||
@@ -119,8 +137,10 @@ class TestQuantizeRoundTrip:
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path, config,
|
||||
model_path,
|
||||
config,
|
||||
quantization={"bits": 8, "group_size": 64},
|
||||
)
|
||||
|
||||
@@ -132,8 +152,10 @@ class TestQuantizeRoundTrip:
|
||||
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path, config,
|
||||
model_path,
|
||||
config,
|
||||
quantization={"bits": 4, "group_size": 64},
|
||||
)
|
||||
|
||||
@@ -151,6 +173,7 @@ class TestQuantizeRoundTrip:
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
|
||||
loaded = load_wan_model(model_path, config, quantization=None)
|
||||
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
|
||||
@@ -161,10 +184,11 @@ class TestQuantizeRoundTrip:
|
||||
# Quantized Inference Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizedInference:
|
||||
def _make_quantized_model(self, config, bits=4):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -214,8 +238,8 @@ class TestQuantizedInference:
|
||||
|
||||
def test_quantized_output_differs_from_unquantized(self):
|
||||
"""Sanity check: quantization should change the weights."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
mx.random.seed(42)
|
||||
@@ -243,11 +267,12 @@ class TestQuantizedInference:
|
||||
# Config Metadata Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizationConfig:
|
||||
def test_config_metadata_written(self, tmp_path):
|
||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -270,8 +295,8 @@ class TestQuantizationConfig:
|
||||
assert cfg["quantization"]["group_size"] == 64
|
||||
|
||||
def test_config_metadata_8bit(self, tmp_path):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -291,8 +316,8 @@ class TestQuantizationConfig:
|
||||
|
||||
def test_dual_model_quantization(self, tmp_path):
|
||||
"""Verify dual-model quantization writes both model files."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
|
||||
|
||||
@@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
d = 128 # head_dim for all Wan models
|
||||
# Reference: three separate calls
|
||||
correct = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
correct = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Wrong: single call
|
||||
wrong = rope_params(1024, d)
|
||||
mx.eval(correct, wrong)
|
||||
|
||||
assert correct.shape == wrong.shape
|
||||
diff = np.abs(np.array(correct) - np.array(wrong)).max()
|
||||
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||
assert (
|
||||
diff > 0.1
|
||||
), f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||
|
||||
def test_each_axis_starts_at_frequency_one(self):
|
||||
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
|
||||
@@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction:
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
@@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
|
||||
# Temporal axis first freq
|
||||
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="temporal[0] cos at pos 1")
|
||||
np.testing.assert_allclose(
|
||||
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
|
||||
)
|
||||
# Height axis first freq (starts at index d_t)
|
||||
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="height[0] cos at pos 1")
|
||||
np.testing.assert_allclose(
|
||||
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
|
||||
)
|
||||
# Width axis first freq (starts at index d_t + d_h)
|
||||
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="width[0] cos at pos 1")
|
||||
np.testing.assert_allclose(
|
||||
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
|
||||
)
|
||||
|
||||
def test_height_width_frequencies_identical(self):
|
||||
"""Height and width axes should have identical frequency tables.
|
||||
@@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
d = 128
|
||||
d_h_dim = 2 * (d // 6) # 42
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, d_h_dim),
|
||||
rope_params(1024, d_h_dim),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, d_h_dim),
|
||||
rope_params(1024, d_h_dim),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
|
||||
height_freqs = f[:, d_t:d_t + d_h]
|
||||
width_freqs = f[:, d_t + d_h:]
|
||||
height_freqs = f[:, d_t : d_t + d_h]
|
||||
width_freqs = f[:, d_t + d_h :]
|
||||
np.testing.assert_array_equal(height_freqs, width_freqs)
|
||||
|
||||
def test_frequency_range_per_axis(self):
|
||||
@@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction:
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
@@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction:
|
||||
pos1_h = f[1, d_t, 0] # height first freq
|
||||
pos1_w = f[1, d_t + d_h, 0] # width first freq
|
||||
|
||||
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||
assert (
|
||||
pos1_t > 0.5
|
||||
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
|
||||
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
|
||||
|
||||
@@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
d = head_dim # 16
|
||||
freqs_manual = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs_manual = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs_model, freqs_manual)
|
||||
np.testing.assert_array_equal(
|
||||
np.array(freqs_model), np.array(freqs_manual),
|
||||
err_msg="WanModel.freqs should use three-call construction"
|
||||
np.array(freqs_model),
|
||||
np.array(freqs_manual),
|
||||
err_msg="WanModel.freqs should use three-call construction",
|
||||
)
|
||||
|
||||
def test_model_freqs_14b_dimensions(self):
|
||||
@@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction:
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
|
||||
assert freqs.shape == (1024, 64, 2)
|
||||
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
|
||||
@pytest.fixture
|
||||
def has_torch(self):
|
||||
try:
|
||||
import torch
|
||||
pass
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
pytest.skip("PyTorch not installed")
|
||||
@@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference:
|
||||
def test_freqs_match_pytorch_reference(self, has_torch):
|
||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||
import torch
|
||||
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
|
||||
def pt_rope_params(max_seq_len, dim, theta=10000):
|
||||
freqs = torch.outer(
|
||||
torch.arange(max_seq_len),
|
||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
||||
1.0
|
||||
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
ref = torch.cat([
|
||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
], dim=1)
|
||||
ref = torch.cat(
|
||||
[
|
||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# MLX
|
||||
ours = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
ours = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(ours)
|
||||
|
||||
our_cos = np.array(ours[:, :, 0])
|
||||
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
|
||||
ref_cos = ref.real.float().numpy()
|
||||
ref_sin = ref.imag.float().numpy()
|
||||
|
||||
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
|
||||
err_msg="cos mismatch vs PyTorch reference")
|
||||
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
|
||||
err_msg="sin mismatch vs PyTorch reference")
|
||||
np.testing.assert_allclose(
|
||||
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
|
||||
)
|
||||
|
||||
|
||||
class TestRoPEApplyWithCorrectFreqs:
|
||||
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
This is the key property that was broken by the single-call bug:
|
||||
height/width frequencies were too low to distinguish nearby positions.
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
B, N = 1, 4
|
||||
F, H, W = 1, 4, 4
|
||||
@@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
|
||||
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
|
||||
# and width was ~0.002. With correct freqs, both are ~1.3.
|
||||
assert height_diff > 0.5, (
|
||||
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||
)
|
||||
assert width_diff > 0.5, (
|
||||
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||
)
|
||||
assert (
|
||||
height_diff > 0.5
|
||||
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||
assert (
|
||||
width_diff > 0.5
|
||||
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||
# Height and width should have identical frequency tables → same diffs
|
||||
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
|
||||
err_msg="Height and width should use identical frequency tables")
|
||||
np.testing.assert_allclose(
|
||||
height_diff,
|
||||
width_diff,
|
||||
rtol=1e-5,
|
||||
err_msg="Height and width should use identical frequency tables",
|
||||
)
|
||||
|
||||
def test_precomputed_matches_online(self):
|
||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||
@@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
)
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
B, N = 2, 4
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
mx.eval(out_online, out_precomp)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_online), np.array(out_precomp), atol=1e-5,
|
||||
err_msg="Precomputed and online RoPE should match"
|
||||
np.array(out_online),
|
||||
np.array(out_precomp),
|
||||
atol=1e-5,
|
||||
err_msg="Precomputed and online RoPE should match",
|
||||
)
|
||||
|
||||
@@ -6,14 +6,15 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Euler Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowMatchEulerScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.timesteps is None
|
||||
@@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_timesteps_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps)
|
||||
@@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_sigmas_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
@@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_terminal_sigma_is_zero(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.sigmas)
|
||||
@@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler:
|
||||
def test_shift_effect(self):
|
||||
"""Larger shift should push sigmas toward higher values."""
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched1 = FlowMatchEulerScheduler()
|
||||
sched2 = FlowMatchEulerScheduler()
|
||||
sched1.set_timesteps(20, shift=1.0)
|
||||
@@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_step_euler(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
|
||||
# Euler: x_next = x + (sigma_next - sigma) * v
|
||||
expected = 1.0 + (sigma_next - sigma) * 0.5
|
||||
np.testing.assert_allclose(
|
||||
np.array(result).flatten()[0], expected, rtol=1e-4,
|
||||
np.array(result).flatten()[0],
|
||||
expected,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
assert sched._step_index == 0
|
||||
@@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler:
|
||||
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(steps, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler:
|
||||
def test_full_denoise_loop(self):
|
||||
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
@@ -141,22 +154,26 @@ class TestComputeSigmas:
|
||||
|
||||
def test_length(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert len(sigmas) == 21 # num_steps + terminal
|
||||
|
||||
def test_terminal_zero(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
assert sigmas[-1] == 0.0
|
||||
|
||||
def test_starts_near_one(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||
|
||||
def test_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert np.all(np.diff(sigmas) <= 0)
|
||||
|
||||
@@ -169,6 +186,7 @@ class TestComputeSigmas:
|
||||
shift is applied only once (single-shift).
|
||||
"""
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
steps, shift, N = 50, 5.0, 1000
|
||||
sigmas = _compute_sigmas(steps, shift, N)
|
||||
# Official single-shift: unshifted bounds, then shift once
|
||||
@@ -183,6 +201,7 @@ class TestComputeSigmas:
|
||||
|
||||
def test_shift_one_is_near_linear(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||
# so schedule is nearly linear from ~0.999 to 0
|
||||
@@ -196,6 +215,7 @@ class TestComputeSigmas:
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
@@ -214,6 +234,7 @@ class TestComputeSigmas:
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
@@ -235,12 +256,14 @@ class TestComputeSigmas:
|
||||
class TestFlowDPMPP2MScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 4, 1, 2, 2))
|
||||
@@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_full_loop_finite(self):
|
||||
"""Full loop with constant velocity should produce finite output."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
@@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_first_step_is_first_order(self):
|
||||
"""First step should use 1st-order (no prev_x0 available)."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 2, 4, 4))
|
||||
@@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_second_step_uses_correction(self):
|
||||
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
|
||||
x0_after_second = sched._prev_x0
|
||||
assert x0_after_second is not None
|
||||
# The stored x0 should differ from the first step's
|
||||
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
|
||||
assert not np.allclose(
|
||||
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
|
||||
)
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
"""Perfect oracle should denoise to target with any solver."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
@@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_terminal_sigma_produces_x0(self):
|
||||
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
||||
@@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
class TestFlowUniPCScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.solver_order == 2
|
||||
@@ -369,6 +403,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(30, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -377,6 +412,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -387,6 +423,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -399,6 +436,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
@@ -411,6 +449,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_corrector_not_applied_first_step(self):
|
||||
"""First step should skip the corrector (no history)."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
@@ -424,6 +463,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_corrector_applied_after_first_step(self):
|
||||
"""Steps after the first should use the corrector when enabled."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 4, 4))
|
||||
@@ -436,6 +476,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
@@ -450,6 +491,7 @@ class TestFlowUniPCScheduler:
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -459,6 +501,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_disable_corrector(self):
|
||||
"""Disabling corrector on step 0 should still work without error."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 2, 2))
|
||||
@@ -471,6 +514,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_solver_order_3(self):
|
||||
"""Order 3 should work without error."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 2, 2))
|
||||
@@ -483,6 +527,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_corrector_rhos_c_not_hardcoded(self):
|
||||
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
||||
import math
|
||||
|
||||
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
||||
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
|
||||
rhos_c = np.linalg.solve(R, b)
|
||||
|
||||
# History weight should be small (~0.07-0.09), not 0.5
|
||||
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||
assert (
|
||||
rhos_c[0] < 0.15
|
||||
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||
assert (
|
||||
rhos_c[0] > 0.0
|
||||
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||
# D1_t weight should be ~0.42-0.45, not 0.5
|
||||
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||
assert (
|
||||
0.3 < rhos_c[1] < 0.5
|
||||
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler Coherence Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchedulerCoherence:
|
||||
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
||||
|
||||
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"], results["euler"], atol=1e-5,
|
||||
results["dpm++"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
err_msg="DPM++ step 0 should match Euler",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"], results["euler"], atol=1e-5,
|
||||
results["unipc"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
err_msg="UniPC step 0 should match Euler",
|
||||
)
|
||||
|
||||
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
|
||||
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
||||
mx.eval(euler_r, dpm_r, unipc_r)
|
||||
np.testing.assert_allclose(
|
||||
np.array(dpm_r), np.array(euler_r), atol=1e-5,
|
||||
np.array(dpm_r),
|
||||
np.array(euler_r),
|
||||
atol=1e-5,
|
||||
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
np.array(unipc_r), np.array(euler_r), atol=1e-5,
|
||||
np.array(unipc_r),
|
||||
np.array(euler_r),
|
||||
atol=1e-5,
|
||||
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
|
||||
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(
|
||||
np.array(latents), 0.0, atol=1e-3,
|
||||
np.array(latents),
|
||||
0.0,
|
||||
atol=1e-3,
|
||||
err_msg=f"{name} did not converge to target with oracle",
|
||||
)
|
||||
|
||||
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
|
||||
# Higher-order solvers should not be significantly worse than Euler
|
||||
# (add small epsilon to handle near-zero errors from floating point noise)
|
||||
eps = 1e-6
|
||||
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
assert (
|
||||
errors["dpm++"] <= errors["euler"] * 1.5 + eps
|
||||
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
assert (
|
||||
errors["unipc"] <= errors["euler"] * 1.5 + eps
|
||||
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
|
||||
def test_multistep_trajectory_similar_magnitude(self):
|
||||
"""Over a full denoising loop with constant velocity, all solvers
|
||||
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
|
||||
# All solvers should produce results within the same order of magnitude
|
||||
vals = list(final_means.values())
|
||||
ratio = max(vals) / max(min(vals), 1e-10)
|
||||
assert ratio < 10.0, (
|
||||
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||
)
|
||||
assert (
|
||||
ratio < 10.0
|
||||
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||
|
||||
def test_intermediate_values_finite(self):
|
||||
"""Every intermediate latent value must be finite for all solvers."""
|
||||
@@ -712,9 +774,9 @@ class TestSchedulerCoherence:
|
||||
vel = mx.random.normal(shape)
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
assert np.isfinite(np.array(latents)).all(), (
|
||||
f"{name} produced non-finite values at step {i}"
|
||||
)
|
||||
assert np.isfinite(
|
||||
np.array(latents)
|
||||
).all(), f"{name} produced non-finite values at step {i}"
|
||||
|
||||
def test_lambda_boundary_values(self):
|
||||
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||
@@ -724,17 +786,17 @@ class TestSchedulerCoherence:
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
assert cls._lambda(1.0) == -math.inf, (
|
||||
f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||
)
|
||||
assert cls._lambda(0.0) == math.inf, (
|
||||
f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||
)
|
||||
assert (
|
||||
cls._lambda(1.0) == -math.inf
|
||||
), f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||
assert (
|
||||
cls._lambda(0.0) == math.inf
|
||||
), f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||
# Interior values should be finite
|
||||
lam = cls._lambda(0.5)
|
||||
assert math.isfinite(lam) and lam == 0.0, (
|
||||
f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||
)
|
||||
assert (
|
||||
math.isfinite(lam) and lam == 0.0
|
||||
), f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||
|
||||
def test_lambda_monotonically_decreasing(self):
|
||||
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
|
||||
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
np.testing.assert_allclose(
|
||||
np.array(result), np.array(expected), atol=5e-4,
|
||||
np.array(result),
|
||||
np.array(expected),
|
||||
atol=5e-4,
|
||||
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||
)
|
||||
|
||||
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"], results["euler"], atol=1e-5,
|
||||
results["dpm++"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"], results["euler"], atol=1e-5,
|
||||
results["unipc"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
def test_dpmpp_unipc_agree_on_step1(self):
|
||||
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
sched = cls()
|
||||
@@ -857,14 +928,19 @@ class TestSchedulerCoherence:
|
||||
mx.eval(latents)
|
||||
result2 = np.array(latents)
|
||||
|
||||
np.testing.assert_allclose(result1, result2, atol=1e-5,
|
||||
err_msg=f"{cls.__name__} not reproducible after reset()")
|
||||
np.testing.assert_allclose(
|
||||
result1,
|
||||
result2,
|
||||
atol=1e-5,
|
||||
err_msg=f"{cls.__name__} not reproducible after reset()",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniPC Corrector Default Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUniPCCorrectorDefault:
|
||||
"""Tests that the UniPC corrector is enabled by default,
|
||||
matching official FlowUniPCMultistepScheduler behavior."""
|
||||
@@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault:
|
||||
def test_corrector_enabled_by_default(self):
|
||||
"""Default construction should have corrector enabled."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched._use_corrector is True
|
||||
|
||||
def test_corrector_affects_output(self):
|
||||
"""Corrector should produce different results than no corrector after step 1."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
@@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault:
|
||||
def test_corrector_does_not_affect_first_step(self):
|
||||
"""Step 0 should be identical regardless of corrector setting."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T5 Encoder Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestT5LayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||
|
||||
norm = T5LayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
@@ -22,6 +22,7 @@ class TestT5LayerNorm:
|
||||
def test_rms_normalization(self):
|
||||
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
||||
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||
|
||||
norm = T5LayerNorm(128)
|
||||
x = mx.random.normal((1, 5, 128)) * 5.0
|
||||
out = norm(x)
|
||||
@@ -35,6 +36,7 @@ class TestT5LayerNorm:
|
||||
class TestT5RelativeEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(10, 10)
|
||||
mx.eval(out)
|
||||
@@ -42,6 +44,7 @@ class TestT5RelativeEmbedding:
|
||||
|
||||
def test_asymmetric_lengths(self):
|
||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(8, 12)
|
||||
mx.eval(out)
|
||||
@@ -50,6 +53,7 @@ class TestT5RelativeEmbedding:
|
||||
def test_symmetry(self):
|
||||
"""Position bias should have structure (not all zeros/random)."""
|
||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
||||
out = rel_emb(6, 6)
|
||||
mx.eval(out)
|
||||
@@ -64,6 +68,7 @@ class TestT5RelativeEmbedding:
|
||||
class TestT5Attention:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
out = attn(x)
|
||||
@@ -73,12 +78,14 @@ class TestT5Attention:
|
||||
def test_no_scaling(self):
|
||||
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
||||
from mlx_video.models.wan.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
# No scale attribute (unlike standard attention)
|
||||
assert not hasattr(attn, "scale")
|
||||
|
||||
def test_with_position_bias(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
rel_emb = T5RelativeEmbedding(32, 4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
@@ -89,6 +96,7 @@ class TestT5Attention:
|
||||
|
||||
def test_with_mask(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
mask = mx.ones((1, 10))
|
||||
@@ -101,6 +109,7 @@ class TestT5Attention:
|
||||
class TestT5FeedForward:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||
|
||||
ffn = T5FeedForward(64, 256)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
out = ffn(x)
|
||||
@@ -110,6 +119,7 @@ class TestT5FeedForward:
|
||||
def test_gated_structure(self):
|
||||
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
||||
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||
|
||||
ffn = T5FeedForward(32, 64)
|
||||
assert hasattr(ffn, "gate_proj")
|
||||
assert hasattr(ffn, "fc1")
|
||||
@@ -122,9 +132,16 @@ class TestT5Encoder:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
ids = mx.array([[1, 5, 10, 0, 0]])
|
||||
mask = mx.array([[1, 1, 1, 0, 0]])
|
||||
@@ -134,9 +151,16 @@ class TestT5Encoder:
|
||||
|
||||
def test_shared_pos(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=True,
|
||||
)
|
||||
assert encoder.pos_embedding is not None
|
||||
for block in encoder.blocks:
|
||||
@@ -144,9 +168,16 @@ class TestT5Encoder:
|
||||
|
||||
def test_per_layer_pos(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
assert encoder.pos_embedding is None
|
||||
for block in encoder.blocks:
|
||||
@@ -154,18 +185,32 @@ class TestT5Encoder:
|
||||
|
||||
def test_param_count(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
|
||||
assert num_params > 0
|
||||
|
||||
def test_without_mask(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
ids = mx.array([[1, 5, 10]])
|
||||
out = encoder(ids)
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
decode_with_tiling,
|
||||
split_in_spatial,
|
||||
split_in_temporal,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,16 +47,24 @@ class TestNonCausalTemporal:
|
||||
|
||||
# Causal: 1 + (4-1)*4 = 13
|
||||
out_causal = decode_with_tiling(
|
||||
dummy_decoder_causal, latents, config,
|
||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
|
||||
dummy_decoder_causal,
|
||||
latents,
|
||||
config,
|
||||
spatial_scale=scale,
|
||||
temporal_scale=scale,
|
||||
causal_temporal=True,
|
||||
)
|
||||
mx.eval(out_causal)
|
||||
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
||||
|
||||
# Non-causal: 4*4 = 16
|
||||
out_noncausal = decode_with_tiling(
|
||||
dummy_decoder_noncausal, latents, config,
|
||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
|
||||
dummy_decoder_noncausal,
|
||||
latents,
|
||||
config,
|
||||
spatial_scale=scale,
|
||||
temporal_scale=scale,
|
||||
causal_temporal=False,
|
||||
)
|
||||
mx.eval(out_noncausal)
|
||||
assert out_noncausal.shape[2] == t * scale # 16
|
||||
@@ -100,9 +106,9 @@ class TestWan22TiledDecoding:
|
||||
mx.eval(out_tiled)
|
||||
|
||||
# Both should produce the same shape
|
||||
assert out_regular.shape == out_tiled.shape, (
|
||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
)
|
||||
assert (
|
||||
out_regular.shape == out_tiled.shape
|
||||
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
||||
@@ -120,8 +126,10 @@ class TestWan22TiledDecoding:
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular), np.array(out_tiled),
|
||||
rtol=1e-4, atol=1e-4,
|
||||
np.array(out_regular),
|
||||
np.array(out_tiled),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
@@ -152,9 +160,9 @@ class TestWan21TiledDecoding:
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
assert out_regular.shape == out_tiled.shape, (
|
||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
)
|
||||
assert (
|
||||
out_regular.shape == out_tiled.shape
|
||||
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
||||
@@ -171,8 +179,10 @@ class TestWan21TiledDecoding:
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular), np.array(out_tiled),
|
||||
rtol=1e-4, atol=1e-4,
|
||||
np.array(out_regular),
|
||||
np.array(out_tiled),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
@@ -185,8 +195,13 @@ class TestWan21TemporalScale:
|
||||
from mlx_video.models.wan.vae import Decoder3d
|
||||
|
||||
# Small decoder for fast test
|
||||
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
|
||||
temporal_upsample=[True, True, False])
|
||||
dec = Decoder3d(
|
||||
dim=16,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temporal_upsample=[True, True, False],
|
||||
)
|
||||
mx.eval(dec.parameters())
|
||||
|
||||
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
||||
|
||||
@@ -2,16 +2,16 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer Block Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanFFN:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(64, 256)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = ffn(x)
|
||||
@@ -21,6 +21,7 @@ class TestWanFFN:
|
||||
def test_gelu_activation(self):
|
||||
"""FFN should use GELU activation (non-linearity)."""
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(32, 128)
|
||||
x = mx.ones((1, 1, 32)) * 2.0
|
||||
out1 = ffn(x)
|
||||
@@ -39,10 +40,13 @@ class TestWanAttentionBlock:
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim, self.ffn_dim, self.num_heads,
|
||||
self.dim,
|
||||
self.ffn_dim,
|
||||
self.num_heads,
|
||||
cross_attn_norm=True,
|
||||
)
|
||||
B, L = 1, 24
|
||||
@@ -53,37 +57,49 @@ class TestWanAttentionBlock:
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(
|
||||
x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
|
||||
freqs=freqs, context=context,
|
||||
x,
|
||||
e,
|
||||
seq_lens=[L],
|
||||
grid_sizes=[(F, H, W)],
|
||||
freqs=freqs,
|
||||
context=context,
|
||||
)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
assert block.modulation.shape == (1, 6, self.dim)
|
||||
|
||||
def test_with_cross_attn_norm(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim, self.ffn_dim, self.num_heads,
|
||||
self.dim,
|
||||
self.ffn_dim,
|
||||
self.num_heads,
|
||||
cross_attn_norm=True,
|
||||
)
|
||||
assert block.norm3 is not None
|
||||
|
||||
def test_without_cross_attn_norm(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim, self.ffn_dim, self.num_heads,
|
||||
self.dim,
|
||||
self.ffn_dim,
|
||||
self.num_heads,
|
||||
cross_attn_norm=False,
|
||||
)
|
||||
assert block.norm3 is None
|
||||
|
||||
def test_residual_connection(self):
|
||||
"""Output should differ from zero even with small random init."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
B, L = 1, 8
|
||||
F, H, W = 2, 2, 2
|
||||
@@ -102,6 +118,7 @@ class TestWanAttentionBlock:
|
||||
# Float32 Modulation Precision Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFloat32Modulation:
|
||||
"""Tests that modulation/gate operations are computed in float32,
|
||||
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
|
||||
@@ -113,13 +130,15 @@ class TestFloat32Modulation:
|
||||
def test_block_modulation_in_float32(self):
|
||||
"""Modulation param starts random but should be usable as float32."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
||||
assert block.modulation.dtype == mx.float32
|
||||
|
||||
def test_block_output_float32_with_bf16_modulation_input(self):
|
||||
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, 4)
|
||||
B, L = 1, 8
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
@@ -135,6 +154,7 @@ class TestFloat32Modulation:
|
||||
def test_head_modulation_float32(self):
|
||||
"""Head modulation should be float32 even with bf16 e input."""
|
||||
from mlx_video.models.wan.model import Head
|
||||
|
||||
head = Head(self.dim, 4, (1, 2, 2))
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
|
||||
@@ -145,6 +165,7 @@ class TestFloat32Modulation:
|
||||
def test_model_time_embedding_float32(self):
|
||||
"""sinusoidal_embedding_1d output must be float32."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
t = mx.array([500.0])
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
mx.eval(emb)
|
||||
@@ -153,6 +174,7 @@ class TestFloat32Modulation:
|
||||
def test_model_per_token_time_embedding_float32(self):
|
||||
"""Per-token time embeddings (I2V) should also be float32."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
mx.eval(emb)
|
||||
|
||||
@@ -4,16 +4,16 @@ import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE 2.1 Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCausalConv3d:
|
||||
def test_output_shape_stride1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||
# Initialize weights
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
@@ -29,6 +29,7 @@ class TestCausalConv3d:
|
||||
|
||||
def test_output_shape_kernel1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
x = mx.random.normal((1, 4, 2, 4, 4))
|
||||
@@ -39,6 +40,7 @@ class TestCausalConv3d:
|
||||
def test_causal_padding(self):
|
||||
"""Causal conv should only use past/current frames, not future."""
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros((2,))
|
||||
@@ -55,6 +57,7 @@ class TestCausalConv3d:
|
||||
class TestResidualBlock:
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
@@ -63,6 +66,7 @@ class TestResidualBlock:
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
@@ -71,11 +75,13 @@ class TestResidualBlock:
|
||||
|
||||
def test_shortcut_exists_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_when_dims_same(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
@@ -83,6 +89,7 @@ class TestResidualBlock:
|
||||
class TestAttentionBlock:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
@@ -91,6 +98,7 @@ class TestAttentionBlock:
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||
out = block(x)
|
||||
@@ -102,13 +110,15 @@ class TestAttentionBlock:
|
||||
class TestWanVAE:
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
assert vae.z_dim == 16
|
||||
assert vae.mean.shape == (16,)
|
||||
assert vae.std.shape == (16,)
|
||||
|
||||
def test_normalization_stats(self):
|
||||
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
|
||||
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
|
||||
|
||||
assert len(VAE_MEAN) == 16
|
||||
assert len(VAE_STD) == 16
|
||||
assert all(s > 0 for s in VAE_STD)
|
||||
@@ -124,6 +134,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_output_shape_k3(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||
out = conv(x)
|
||||
@@ -132,6 +143,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_output_shape_k1(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = conv(x)
|
||||
@@ -141,6 +153,7 @@ class TestVAE22CausalConv3d:
|
||||
def test_temporal_causal(self):
|
||||
"""Output at t=0 should not depend on t>0."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros(conv.bias.shape)
|
||||
@@ -151,10 +164,13 @@ class TestVAE22CausalConv3d:
|
||||
t0_ref = np.array(out_zero[0, 0])
|
||||
|
||||
# Modify t=2..3; output at t=0 should be unchanged
|
||||
x_mod = mx.concatenate([
|
||||
x[:, :2],
|
||||
mx.ones((1, 2, 4, 4, 2)),
|
||||
], axis=1)
|
||||
x_mod = mx.concatenate(
|
||||
[
|
||||
x[:, :2],
|
||||
mx.ones((1, 2, 4, 4, 2)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
out_mod = conv(x_mod)
|
||||
mx.eval(out_mod)
|
||||
t0_mod = np.array(out_mod[0, 0])
|
||||
@@ -163,6 +179,7 @@ class TestVAE22CausalConv3d:
|
||||
def test_channels_last_format(self):
|
||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||
out = conv(x)
|
||||
@@ -175,6 +192,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(16)
|
||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||
out = norm(x)
|
||||
@@ -184,6 +202,7 @@ class TestRMSNorm:
|
||||
def test_l2_normalization(self):
|
||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
dim = 32
|
||||
norm = RMS_norm(dim)
|
||||
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
||||
@@ -197,6 +216,7 @@ class TestRMSNorm:
|
||||
def test_scale_invariant(self):
|
||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(8)
|
||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||
out1 = norm(x)
|
||||
@@ -207,6 +227,7 @@ class TestRMSNorm:
|
||||
def test_gamma_effect(self):
|
||||
"""Non-unit gamma should scale output."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(4)
|
||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||
x = mx.ones((1, 1, 1, 1, 4))
|
||||
@@ -221,6 +242,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out = up(x)
|
||||
@@ -229,6 +251,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||
out = up(x)
|
||||
@@ -237,6 +260,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_first_chunk_trims(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
@@ -248,6 +272,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_no_temporal_first_chunk_noop(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
@@ -262,6 +287,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample2d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample2d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -271,6 +297,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample3d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -280,6 +307,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample3d_first_chunk(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -291,6 +319,7 @@ class TestVAE22Resample:
|
||||
def test_upsample3d_first_chunk_single_frame(self):
|
||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 1, 4, 4, 8))
|
||||
@@ -308,6 +337,7 @@ class TestVAE22Resample:
|
||||
the first input frame (not on time_conv parameters).
|
||||
"""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
C = 8
|
||||
r = Resample(C, "upsample3d")
|
||||
# Set time_conv weights to large values so its effect is detectable
|
||||
@@ -334,8 +364,9 @@ class TestVAE22Resample:
|
||||
# Compare first output frame to reference
|
||||
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
||||
mx.eval(first_out)
|
||||
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
|
||||
"First frame should bypass time_conv and match spatial-only upsample"
|
||||
assert mx.allclose(
|
||||
first_out, ref, atol=1e-5
|
||||
).item(), "First frame should bypass time_conv and match spatial-only upsample"
|
||||
|
||||
|
||||
class TestVAE22ResidualBlock:
|
||||
@@ -343,6 +374,7 @@ class TestVAE22ResidualBlock:
|
||||
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
@@ -351,6 +383,7 @@ class TestVAE22ResidualBlock:
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
@@ -359,11 +392,13 @@ class TestVAE22ResidualBlock:
|
||||
|
||||
def test_shortcut_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
@@ -374,6 +409,7 @@ class TestResidualBlockLayers:
|
||||
def test_layer_names_no_underscore_prefix(self):
|
||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 8)
|
||||
params = dict(block.parameters())
|
||||
# All param keys should use layer_N, not _layer_N
|
||||
@@ -382,6 +418,7 @@ class TestResidualBlockLayers:
|
||||
|
||||
def test_has_expected_layers(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
assert hasattr(block, "layer_0") # first RMS_norm
|
||||
assert hasattr(block, "layer_2") # first CausalConv3d
|
||||
@@ -390,6 +427,7 @@ class TestResidualBlockLayers:
|
||||
|
||||
def test_forward_shape(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
@@ -402,6 +440,7 @@ class TestVAE22AttentionBlock:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(16)
|
||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
||||
@@ -412,6 +451,7 @@ class TestVAE22AttentionBlock:
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
||||
@@ -427,6 +467,7 @@ class TestHead22:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
|
||||
head = Head22(16, out_channels=12)
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
out = head(x)
|
||||
@@ -436,6 +477,7 @@ class TestHead22:
|
||||
def test_layer_names_no_underscore(self):
|
||||
"""Head layers must not use underscore prefix."""
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
|
||||
head = Head22(8)
|
||||
assert hasattr(head, "layer_0") # RMS_norm
|
||||
assert hasattr(head, "layer_2") # CausalConv3d
|
||||
@@ -449,6 +491,7 @@ class TestUnpatchify:
|
||||
|
||||
def test_basic_shape(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
@@ -456,6 +499,7 @@ class TestUnpatchify:
|
||||
|
||||
def test_patch_size_1_noop(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||
out = _unpatchify(x, patch_size=1)
|
||||
mx.eval(out)
|
||||
@@ -464,6 +508,7 @@ class TestUnpatchify:
|
||||
def test_preserves_content(self):
|
||||
"""Unpatchify should be a lossless rearrangement."""
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
|
||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
@@ -477,6 +522,7 @@ class TestDenormalizeLatents:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
out = denormalize_latents(z)
|
||||
mx.eval(out)
|
||||
@@ -484,16 +530,23 @@ class TestDenormalizeLatents:
|
||||
|
||||
def test_custom_mean_std(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
z = mx.ones((1, 1, 1, 1, 4))
|
||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
||||
out = denormalize_latents(z, mean=mean, std=std)
|
||||
mx.eval(out)
|
||||
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
||||
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
|
||||
np.testing.assert_allclose(
|
||||
np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5
|
||||
)
|
||||
|
||||
def test_uses_default_constants(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
VAE22_MEAN,
|
||||
denormalize_latents,
|
||||
)
|
||||
|
||||
# Should not raise with default constants
|
||||
z = mx.zeros((1, 1, 1, 1, 48))
|
||||
out = denormalize_latents(z)
|
||||
@@ -511,12 +564,14 @@ class TestVAE22NormConstants:
|
||||
|
||||
def test_dimensions(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||
|
||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||
assert VAE22_MEAN.shape == (48,)
|
||||
assert VAE22_STD.shape == (48,)
|
||||
|
||||
def test_std_positive(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||
|
||||
mx.eval(VAE22_STD)
|
||||
assert (np.array(VAE22_STD) > 0).all()
|
||||
|
||||
@@ -527,6 +582,7 @@ class TestWan22VAEDecoder:
|
||||
def test_output_shape_small(self):
|
||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dims to keep test fast
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
||||
@@ -542,6 +598,7 @@ class TestWan22VAEDecoder:
|
||||
|
||||
def test_output_clipped(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||
out = dec(z)
|
||||
@@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_skip_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.layer.weight": mx.zeros((4,)),
|
||||
"conv1.weight": mx.zeros((4,)),
|
||||
@@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_sequential_index_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
||||
@@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_resample_conv_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
||||
@@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_attention_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
||||
@@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
w = mx.zeros((16, 8, 3, 3, 3))
|
||||
weights = {"decoder.conv1.weight": w}
|
||||
@@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||
w = mx.zeros((8, 8, 3, 3))
|
||||
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
||||
@@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_gamma_squeeze(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||
w = mx.ones((16, 1, 1, 1))
|
||||
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
||||
@@ -635,7 +699,10 @@ class TestUpResidualBlock:
|
||||
|
||||
def test_no_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
|
||||
)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -644,7 +711,10 @@ class TestUpResidualBlock:
|
||||
|
||||
def test_spatial_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -653,7 +723,10 @@ class TestUpResidualBlock:
|
||||
|
||||
def test_spatial_temporal_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -720,7 +793,9 @@ class TestDownResidualBlock:
|
||||
def test_no_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
|
||||
block = Down_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
|
||||
)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -729,7 +804,9 @@ class TestDownResidualBlock:
|
||||
def test_spatial_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -738,7 +815,9 @@ class TestDownResidualBlock:
|
||||
def test_spatial_temporal_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
def test_encoder_temporal_downsample_pattern(self):
|
||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
@@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_wrapper_uses_correct_pattern(self):
|
||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
|
||||
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
down_blocks = enc.encoder.downsamples
|
||||
found_modes = []
|
||||
@@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
def test_single_frame_encoder(self):
|
||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
mx.eval(enc.parameters())
|
||||
@@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder:
|
||||
def test_wrong_order_gives_different_result(self):
|
||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
|
||||
enc_correct = Encoder3d(
|
||||
dim=16, z_dim=8, temperal_downsample=(False, True, True)
|
||||
)
|
||||
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
@@ -883,12 +968,8 @@ class TestVAE21RoundTrip:
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
# No temporal up/downsampling to keep the test simple
|
||||
enc = Encoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
|
||||
)
|
||||
dec = Decoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
|
||||
)
|
||||
enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False])
|
||||
dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False])
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, C=3, T=1, H=8, W=8]
|
||||
@@ -937,15 +1018,12 @@ class TestVAE22RoundTrip:
|
||||
mx.eval(out)
|
||||
|
||||
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
|
||||
out_np = np.array(out)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert out_np.min() >= -1.0 - 1e-6
|
||||
assert out_np.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
def _make_tiny_config():
|
||||
"""Create a tiny WanModelConfig for testing."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
|
||||
Reference in New Issue
Block a user