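Move the Wan2 model files — including configuration, attention mechanisms, and utility functions — from the `wan2` package to `wan_2`: the old files are removed and re-added under the new name, and the imports, CLI entry point, README commands, and tests are updated to match. This cleanup streamlines the codebase, enhances maintainability, and keeps the focus on the core functionality of the Wan2 module.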

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -93,18 +93,18 @@ Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.c
### Step 0: Download and Convert Weights ### Step 0: Download and Convert Weights
See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan/README.md) for details. See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan_2/README.md) for details.
### Step 1: Generate Video ### Step 1: Generate Video
```bash ```bash
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0) # Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
python -m mlx_video.wan2.generate \ python -m mlx_video.wan_2.generate \
--model-dir wan21_mlx \ --model-dir wan21_mlx \
--prompt "A cat playing piano in a cozy room" --prompt "A cat playing piano in a cozy room"
# Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0) # Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0)
python -m mlx_video.wan2.generate \ python -m mlx_video.wan_2.generate \
--model-dir wan22_mlx \ --model-dir wan22_mlx \
--prompt "A cat playing piano in a cozy room" --prompt "A cat playing piano in a cozy room"
``` ```
@@ -112,7 +112,7 @@ python -m mlx_video.wan2.generate \
With custom settings: With custom settings:
```bash ```bash
python -m mlx_video.wan2.generate \ python -m mlx_video.wan_2.generate \
--model-dir wan21_mlx \ --model-dir wan21_mlx \
--prompt "Ocean waves at sunset, cinematic, 4K" \ --prompt "Ocean waves at sunset, cinematic, 4K" \
--negative-prompt "blurry, low quality" \ --negative-prompt "blurry, low quality" \
@@ -131,7 +131,7 @@ The pipeline auto-detects the model version from `config.json` and selects the r
### Image-to-Video (I2V-14B) ### Image-to-Video (I2V-14B)
```bash ```bash
python -m mlx_video.wan2.generate \ python -m mlx_video.wan_2.generate \
--model-dir wan22_i2v_mlx \ --model-dir wan22_i2v_mlx \
--prompt "The camera slowly zooms in as the subject begins to move" \ --prompt "The camera slowly zooms in as the subject begins to move" \
--image start.png \ --image start.png \
@@ -146,7 +146,7 @@ LoRAs can be used with the `--lora-high` and `--lora-low` command line switches.
For example, using the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA for 4-step generation: For example, using the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA for 4-step generation:
```bash ```bash
python -m mlx_vide.wan2.generate \ python -m mlx_video.wan_2.generate \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \ --model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \ --width 480 \
--height 704 \ --height 704 \

View File

@@ -22,7 +22,7 @@ from mlx_video.models.ltx_2.utils import (
load_safetensors, load_safetensors,
save_weights, save_weights,
) )
from mlx_video.models.wan2 import WanModel, WanModelConfig from mlx_video.models.wan_2 import WanModel, WanModelConfig
__all__ = [ __all__ = [
# Models # Models

View File

@@ -1,2 +1,2 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan2 import WanModel, WanModelConfig from mlx_video.models.wan_2 import WanModel, WanModelConfig

View File

@@ -1,2 +0,0 @@
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.wan2 import WanModel

View File

@@ -0,0 +1,2 @@
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.wan_2 import WanModel

View File

@@ -247,7 +247,7 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths. Shared between weight-merging and runtime-wrapping paths.
""" """
from mlx_video.models.wan2.generate import Colors from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -282,7 +282,7 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
""" """
from mlx_video.models.wan2.generate import Colors from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import apply_loras_to_weights from mlx_video.lora import apply_loras_to_weights
if not lora_configs: if not lora_configs:
@@ -411,7 +411,7 @@ def convert_wan_checkpoint(
print(" Warning: No transformer weights found!") print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights # Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
def _detect_config(): def _detect_config():
"""Detect config from source config.json or transformer weight shapes.""" """Detect config from source config.json or transformer weight shapes."""
@@ -522,7 +522,7 @@ def convert_wan_checkpoint(
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...") print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path)) weights = load_torch_weights(str(vae_path))
if is_wan22_vae: if is_wan22_vae:
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v") include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights( weights = sanitize_wan22_vae_weights(
@@ -594,7 +594,7 @@ def _quantize_saved_model(
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
if source_dir is None: if source_dir is None:
source_dir = output_dir source_dir = output_dir
@@ -704,7 +704,7 @@ def quantize_mlx_model(
).exists() ).exists()
# Build model config # Build model config
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config_dict = { config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__ k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__

View File

@@ -11,15 +11,15 @@ import mlx.core as mx
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan2.utils import ( from mlx_video.models.wan_2.utils import (
encode_text, encode_text,
load_t5_encoder, load_t5_encoder,
load_vae_decoder, load_vae_decoder,
load_vae_encoder, load_vae_encoder,
load_wan_model, load_wan_model,
) )
from mlx_video.models.wan2.postprocess import save_video from mlx_video.models.wan_2.postprocess import save_video
class Colors: class Colors:
@@ -121,8 +121,8 @@ def generate_video(
""" """
import json import json
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan2.scheduler import ( from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler, FlowDPMPP2MScheduler,
FlowMatchEulerScheduler, FlowMatchEulerScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
@@ -767,7 +767,7 @@ def generate_video(
) )
if is_wan22_vae: if is_wan22_vae:
from mlx_video.models.wan2.vae22 import denormalize_latents from mlx_video.models.wan_2.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None] z = latents.transpose(1, 2, 3, 0)[None]

View File

@@ -21,12 +21,12 @@ def load_wan_model(
If provided, creates QuantizedLinear stubs before loading. If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply. loras: Optional list of (lora_path, strength) tuples to apply.
""" """
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config) model = WanModel(config)
if quantization: if quantization:
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
nn.quantize( nn.quantize(
model, model,
@@ -42,7 +42,7 @@ def load_wan_model(
if quantization: if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear. # Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead. # Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.models.wan2.convert import _load_lora_configs from mlx_video.models.wan_2.convert import _load_lora_configs
from mlx_video.lora import apply_loras_to_model from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False) model.load_weights(list(weights.items()), strict=False)
@@ -53,7 +53,7 @@ def load_wan_model(
return model return model
else: else:
# Weight merging: fold LoRA into bf16 weights before loading # Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.models.wan2.convert import load_and_apply_loras from mlx_video.models.wan_2.convert import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras) weights = load_and_apply_loras(dict(weights), loras)
@@ -69,7 +69,7 @@ def load_t5_encoder(model_path: Path, config):
only runs once per generation, so performance impact is negligible. only runs once per generation, so performance impact is negligible.
This matches the official implementation, which computes softmax in float32 explicitly. This matches the official implementation, which computes softmax in float32 explicitly.
""" """
from mlx_video.models.wan2.text_encoder import T5Encoder from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=config.t5_vocab_size, vocab_size=config.t5_vocab_size,
@@ -97,11 +97,11 @@ def load_vae_decoder(model_path: Path, config=None):
is_wan22 = config is not None and config.vae_z_dim == 48 is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22: if is_wan22:
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48) vae = Wan22VAEDecoder(z_dim=48)
else: else:
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16) vae = WanVAE(z_dim=16)
@@ -120,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None):
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True. For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
""" """
if config is not None and config.vae_z_dim == 16: if config is not None and config.vae_z_dim == 16:
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True) vae = WanVAE(z_dim=16, encoder=True)
else: else:
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48) vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)

View File

@@ -589,7 +589,7 @@ class WanVAE(nn.Module):
Returns: Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
""" """
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
if tiling_config is None: if tiling_config is None:
tiling_config = TilingConfig.default() tiling_config = TilingConfig.default()

View File

@@ -966,7 +966,7 @@ class Wan22VAEDecoder(nn.Module):
Returns: Returns:
video: [B, T', H', W', 3] decoded RGB in [-1, 1] video: [B, T', H', W', 3] decoded RGB in [-1, 1]
""" """
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
if tiling_config is None: if tiling_config is None:
tiling_config = TilingConfig.default() tiling_config = TilingConfig.default()

View File

@@ -47,7 +47,7 @@ Issues = "https://github.com/Blaizzy/mlx-video/issues"
[project.scripts] [project.scripts]
"mlx_video.ltx_2.generate" = "mlx_video.models.ltx_2.generate:main" "mlx_video.ltx_2.generate" = "mlx_video.models.ltx_2.generate:main"
"mlx_video.wan2.generate" = "mlx_video.models.wan2.generate:main" "mlx_video.wan_2.generate" = "mlx_video.models.wan_2.generate:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
include = ["mlx_video*"] include = ["mlx_video*"]

View File

@@ -12,14 +12,14 @@ class TestRoPE:
"""Tests for 3-way factorized RoPE.""" """Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self): def test_rope_params_shape(self):
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(1024, 64) freqs = rope_params(1024, 64)
mx.eval(freqs) mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2] assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self): def test_rope_params_different_dims(self):
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
for dim in [32, 64, 128]: for dim in [32, 64, 128]:
freqs = rope_params(512, dim) freqs = rope_params(512, dim)
@@ -27,7 +27,7 @@ class TestRoPE:
assert freqs.shape == (512, dim // 2, 2) assert freqs.shape == (512, dim // 2, 2)
def test_rope_params_cos_sin_range(self): def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(256, 64) freqs = rope_params(256, 64)
mx.eval(freqs) mx.eval(freqs)
@@ -38,7 +38,7 @@ class TestRoPE:
def test_rope_params_position_zero(self): def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0.""" """At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(10, 64) freqs = rope_params(10, 64)
mx.eval(freqs) mx.eval(freqs)
@@ -46,7 +46,7 @@ class TestRoPE:
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6) np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self): def test_rope_apply_output_shape(self):
from mlx_video.models.wan2.rope import rope_apply, rope_params from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D)) x = mx.random.normal((B, L, N, D))
@@ -58,7 +58,7 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self): def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms.""" """RoPE rotation should preserve vector norms."""
from mlx_video.models.wan2.rope import rope_apply, rope_params from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16 B, N, D = 1, 2, 16
F, H, W = 2, 3, 4 F, H, W = 2, 3, 4
@@ -79,7 +79,7 @@ class TestRoPE:
def test_rope_apply_with_padding(self): def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged.""" """When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan2.rope import rope_apply, rope_params from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16 B, N, D = 1, 2, 16
F, H, W = 2, 2, 2 F, H, W = 2, 2, 2
@@ -100,7 +100,7 @@ class TestRoPE:
def test_rope_apply_batch(self): def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes.""" """Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan2.rope import rope_apply, rope_params from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 2, 2, 16 B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)] grids = [(2, 3, 4), (2, 3, 4)]
@@ -132,7 +132,7 @@ class TestRoPE:
class TestWanRMSNorm: class TestWanRMSNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.attention import WanRMSNorm from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(64) norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
@@ -142,7 +142,7 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self): def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling.""" """RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan2.attention import WanRMSNorm from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(64) norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0 x = mx.random.normal((1, 5, 64)) * 10.0
@@ -156,7 +156,7 @@ class TestWanRMSNorm:
def test_dtype_preservation(self): def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32.""" """RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan2.attention import WanRMSNorm from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(32) norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16) x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
@@ -168,7 +168,7 @@ class TestWanRMSNorm:
class TestWanLayerNorm: class TestWanLayerNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.attention import WanLayerNorm from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(64) norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
@@ -177,7 +177,7 @@ class TestWanLayerNorm:
assert out.shape == (2, 10, 64) assert out.shape == (2, 10, 64)
def test_without_affine(self): def test_without_affine(self):
from mlx_video.models.wan2.attention import WanLayerNorm from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False) norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64)) x = mx.random.normal((1, 4, 64))
@@ -190,7 +190,7 @@ class TestWanLayerNorm:
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1) np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
def test_with_affine(self): def test_with_affine(self):
from mlx_video.models.wan2.attention import WanLayerNorm from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True) norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight") assert hasattr(norm, "weight")
@@ -208,8 +208,8 @@ class TestWanSelfAttention:
self.num_heads = 4 self.num_heads = 4
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.attention import WanSelfAttention from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads) attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24 B, L = 1, 24
@@ -221,14 +221,14 @@ class TestWanSelfAttention:
assert out.shape == (B, L, self.dim) assert out.shape == (B, L, self.dim)
def test_with_qk_norm(self): def test_with_qk_norm(self):
from mlx_video.models.wan2.attention import WanSelfAttention from mlx_video.models.wan_2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True) attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None assert attn.norm_q is not None
assert attn.norm_k is not None assert attn.norm_k is not None
def test_without_qk_norm(self): def test_without_qk_norm(self):
from mlx_video.models.wan2.attention import WanSelfAttention from mlx_video.models.wan_2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None assert attn.norm_q is None
@@ -236,8 +236,8 @@ class TestWanSelfAttention:
def test_masking(self): def test_masking(self):
"""Test that masking works: shorter seq_lens should mask later tokens.""" """Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan2.attention import WanSelfAttention from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24 B, L = 1, 24
@@ -262,7 +262,7 @@ class TestWanCrossAttention:
self.num_heads = 4 self.num_heads = 4
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.attention import WanCrossAttention from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16 B, L_q, L_kv = 1, 24, 16
@@ -273,7 +273,7 @@ class TestWanCrossAttention:
assert out.shape == (B, L_q, self.dim) assert out.shape == (B, L_q, self.dim)
def test_with_context_mask(self): def test_with_context_mask(self):
from mlx_video.models.wan2.attention import WanCrossAttention from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16 B, L_q, L_kv = 1, 12, 16
@@ -311,8 +311,8 @@ class TestBFloat16Autocast:
def test_self_attn_casts_to_weight_dtype(self): def test_self_attn_casts_to_weight_dtype(self):
"""Self-attention should cast input to weight dtype for QKV projections.""" """Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan2.attention import WanSelfAttention from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads) attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -326,7 +326,7 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self): def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype.""" """Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan2.attention import WanCrossAttention from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -340,7 +340,7 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self): def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype.""" """prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan2.attention import WanCrossAttention from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads) attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -353,7 +353,7 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self): def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers.""" """FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan2.transformer import WanFFN from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(self.dim, 128) ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters())) ffn.update(self._to_bf16(ffn.parameters()))
@@ -366,8 +366,8 @@ class TestBFloat16Autocast:
def test_self_attn_rope_in_float32(self): def test_self_attn_rope_in_float32(self):
"""RoPE should be applied in float32 for precision, even with bf16 weights.""" """RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan2.attention import WanSelfAttention from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads) attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters())) attn.update(self._to_bf16(attn.parameters()))
@@ -381,8 +381,8 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self): def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights.""" """Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True) block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters())) block.update(self._to_bf16(block.parameters()))

View File

@@ -10,7 +10,7 @@ class TestWanModelConfig:
"""Tests for WanModelConfig dataclass.""" """Tests for WanModelConfig dataclass."""
def test_default_values(self): def test_default_values(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.dim == 5120 assert config.dim == 5120
@@ -32,13 +32,13 @@ class TestWanModelConfig:
assert config.text_len == 512 assert config.text_len == 512
def test_head_dim_property(self): def test_head_dim_property(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40 assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self): def test_to_dict_roundtrip(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
d = config.to_dict() d = config.to_dict()
@@ -48,7 +48,7 @@ class TestWanModelConfig:
assert d["boundary"] == 0.875 assert d["boundary"] == 0.875
def test_t5_config_values(self): def test_t5_config_values(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.t5_vocab_size == 256384 assert config.t5_vocab_size == 256384
@@ -69,7 +69,7 @@ class TestWan21Config:
"""Tests for Wan2.1 config presets.""" """Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self): def test_wan21_14b_factory(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1" assert config.model_version == "2.1"
@@ -85,7 +85,7 @@ class TestWan21Config:
assert config.boundary == 0.0 assert config.boundary == 0.0
def test_wan21_1_3b_factory(self): def test_wan21_1_3b_factory(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1" assert config.model_version == "2.1"
@@ -98,7 +98,7 @@ class TestWan21Config:
assert config.sample_guide_scale == 5.0 assert config.sample_guide_scale == 5.0
def test_wan22_14b_factory(self): def test_wan22_14b_factory(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b() config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2" assert config.model_version == "2.2"
@@ -110,7 +110,7 @@ class TestWan21Config:
assert config.boundary == 0.875 assert config.boundary == 0.875
def test_wan21_config_to_dict(self): def test_wan21_config_to_dict(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict() d = config.to_dict()
@@ -119,7 +119,7 @@ class TestWan21Config:
assert d["sample_guide_scale"] == 5.0 assert d["sample_guide_scale"] == 5.0
def test_wan21_1_3b_config_to_dict(self): def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict() d = config.to_dict()
@@ -128,7 +128,7 @@ class TestWan21Config:
def test_default_config_is_wan22(self): def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B.""" """Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
assert config.model_version == "2.2" assert config.model_version == "2.2"

View File

@@ -11,7 +11,7 @@ import mlx.core as mx
class TestSanitizeTransformerWeights: class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self): def test_patch_embedding_reshape(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights:
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2) assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
def test_text_embedding_rename(self): def test_text_embedding_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"text_embedding.0.weight": mx.zeros((64, 32)), "text_embedding.0.weight": mx.zeros((64, 32)),
@@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights:
assert "text_embedding_1.bias" in out assert "text_embedding_1.bias" in out
def test_time_embedding_rename(self): def test_time_embedding_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"time_embedding.0.weight": mx.zeros((64, 32)), "time_embedding.0.weight": mx.zeros((64, 32)),
@@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights:
assert "time_embedding_1.weight" in out assert "time_embedding_1.weight" in out
def test_time_projection_rename(self): def test_time_projection_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"time_projection.1.weight": mx.zeros((384, 64)), "time_projection.1.weight": mx.zeros((384, 64)),
@@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights:
assert "time_projection.bias" in out assert "time_projection.bias" in out
def test_ffn_rename(self): def test_ffn_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.0.weight": mx.zeros((128, 64)),
@@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.ffn.fc2.bias" in out assert "blocks.0.ffn.fc2.bias" in out
def test_freqs_skipped(self): def test_freqs_skipped(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"freqs": mx.zeros((1024, 64, 2)), "freqs": mx.zeros((1024, 64, 2)),
@@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.norm1.weight" in out assert "blocks.0.norm1.weight" in out
def test_passthrough_keys(self): def test_passthrough_keys(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)), "blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
@@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights:
assert key in out assert key in out
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = { weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights:
"head.head.weight": mx.zeros((64, 64)), "head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)), "freqs": mx.zeros((1024, 64, 2)),
} }
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_transformer_weights(weights) sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights: class TestSanitizeT5Weights:
def test_gate_rename(self): def test_gate_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = { weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -139,7 +139,7 @@ class TestSanitizeT5Weights:
assert "blocks.0.ffn.fc2.weight" in out assert "blocks.0.ffn.fc2.weight" in out
def test_passthrough(self): def test_passthrough(self):
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = { weights = {
"token_embedding.weight": mx.zeros((100, 64)), "token_embedding.weight": mx.zeros((100, 64)),
@@ -151,7 +151,7 @@ class TestSanitizeT5Weights:
assert key in out assert key in out
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = { weights = {
"token_embedding.weight": mx.zeros((100, 64)), "token_embedding.weight": mx.zeros((100, 64)),
@@ -160,14 +160,14 @@ class TestSanitizeT5Weights:
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)), "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)), "norm.weight": mx.zeros((64,)),
} }
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_t5_weights(weights) sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights: class TestSanitizeVAEWeights:
def test_conv3d_transpose(self): def test_conv3d_transpose(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = { weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W] "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
@@ -176,7 +176,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I] assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
def test_conv2d_transpose(self): def test_conv2d_transpose(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = { weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W] "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
@@ -185,7 +185,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I] assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
def test_non_conv_passthrough(self): def test_non_conv_passthrough(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = { weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose "decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
@@ -196,7 +196,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.bias"].shape == (16,) assert out["decoder.bias"].shape == (16,)
def test_mixed_weights(self): def test_mixed_weights(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = { weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
@@ -211,7 +211,7 @@ class TestSanitizeVAEWeights:
assert out["norm.weight"].shape == (8,) assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = { weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
@@ -219,7 +219,7 @@ class TestSanitizeVAEWeights:
"decoder.norm.weight": mx.zeros((64,)), "decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)), "decoder.bias": mx.zeros((16,)),
} }
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_vae_weights(weights) sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text assert "Unconsumed" not in caplog.text
@@ -256,7 +256,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self): def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1.""" """Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict() d = config.to_dict()
@@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder.""" """Tests for sanitize_wan22_vae_weights with include_encoder."""
def test_exclude_encoder_by_default(self): def test_exclude_encoder_by_default(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights:
assert not any("encoder" in k or k.startswith("conv1") for k in out) assert not any("encoder" in k or k.startswith("conv1") for k in out)
def test_include_encoder(self): def test_include_encoder(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights:
assert "conv2.weight" in out assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog): def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)), "conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)), "conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
} }
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"): with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True) sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog): def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)), "conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)), "conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
} }
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"): with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False) sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text assert "Unconsumed" not in caplog.text

View File

@@ -14,8 +14,8 @@ class TestEndToEnd:
def test_tiny_model_denoise_step(self): def test_tiny_model_denoise_step(self):
"""Simulate one denoising step with tiny model.""" """Simulate one denoising step with tiny model."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(42) mx.random.seed(42)
config = _make_tiny_config() config = _make_tiny_config()
@@ -43,8 +43,8 @@ class TestEndToEnd:
def test_tiny_model_full_loop(self): def test_tiny_model_full_loop(self):
"""Run a complete (tiny) diffusion loop.""" """Run a complete (tiny) diffusion loop."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(123) mx.random.seed(123)
config = _make_tiny_config() config = _make_tiny_config()
@@ -81,7 +81,7 @@ class TestI2VMask:
"""Tests for _build_i2v_mask.""" """Tests for _build_i2v_mask."""
def test_mask_shapes(self): def test_mask_shapes(self):
from mlx_video.models.wan2.generate import _build_i2v_mask from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4) # C, T, H, W z_shape = (48, 5, 4, 4) # C, T, H, W
patch_size = (1, 2, 2) patch_size = (1, 2, 2)
@@ -91,7 +91,7 @@ class TestI2VMask:
assert mask_tokens.shape == (1, 20) assert mask_tokens.shape == (1, 20)
def test_first_frame_zero(self): def test_first_frame_zero(self):
from mlx_video.models.wan2.generate import _build_i2v_mask from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4) z_shape = (48, 5, 4, 4)
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2)) mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
@@ -111,7 +111,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self): def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions.""" """Mask should work with TI2V-5B typical dimensions."""
from mlx_video.models.wan2.generate import _build_i2v_mask from mlx_video.models.wan_2.generate import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2) # TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames # 704x1280 → latent 44x80, t_latent=21 for 81 frames
@@ -132,7 +132,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self): def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma.""" """Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.models.wan2.generate import _build_i2v_mask from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (4, 3, 4, 4) z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2) patch_size = (1, 2, 2)
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self): def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors.""" """After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -235,7 +235,7 @@ class TestDimensionAlignment:
def test_alignment_with_ti2v_config(self): def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32.""" """TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b() config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1] align_h = config.patch_size[1] * config.vae_stride[1]

View File

@@ -23,7 +23,7 @@ class TestI2VConfig:
"""Test I2V-14B config preset.""" """Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self): def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b() config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v" assert config.model_type == "i2v"
@@ -39,7 +39,7 @@ class TestI2VConfig:
assert config.vae_z_dim == 16 assert config.vae_z_dim == 16
def test_i2v_vs_t2v_differences(self): def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b() i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b() t2v = WanModelConfig.wan22_t2v_14b()
@@ -51,7 +51,7 @@ class TestI2VConfig:
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0 assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self): def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b() config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict() d = config.to_dict()
@@ -66,7 +66,7 @@ class TestModelYParameter:
def test_forward_without_y(self): def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works.""" """Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -85,7 +85,7 @@ class TestModelYParameter:
def test_forward_with_y(self): def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation.""" """I2V forward pass with y channel concatenation."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_i2v_config() config = _make_tiny_i2v_config()
model = WanModel(config) model = WanModel(config)
@@ -108,7 +108,7 @@ class TestModelYParameter:
def test_y_none_is_noop(self): def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y.""" """Passing y=None should be identical to not passing y."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -129,7 +129,7 @@ class TestModelYParameter:
def test_batched_cfg_with_y(self): def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work.""" """Batched CFG (B=2) with y should work."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_i2v_config() config = _make_tiny_i2v_config()
model = WanModel(config) model = WanModel(config)
@@ -158,7 +158,7 @@ class TestVAEEncoder:
"""Test Wan2.1 VAE encoder.""" """Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self): def test_encoder3d_instantiation(self):
from mlx_video.models.wan2.vae import Encoder3d from mlx_video.models.wan_2.vae import Encoder3d
enc = Encoder3d( enc = Encoder3d(
dim=32, z_dim=8 dim=32, z_dim=8
@@ -169,7 +169,7 @@ class TestVAEEncoder:
def test_encoder3d_output_shape(self): def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x.""" """Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan2.vae import Encoder3d from mlx_video.models.wan_2.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8) enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32] # Random input: [B=1, 3, T=5, H=32, W=32]
@@ -186,7 +186,7 @@ class TestVAEEncoder:
def test_wan_vae_encode(self): def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents.""" """WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True) vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32] # Input: [B=1, 3, T=5, H=32, W=32]
@@ -198,7 +198,7 @@ class TestVAEEncoder:
def test_wan_vae_encoder_flag(self): def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute.""" """WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False) vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, "encoder") assert not hasattr(vae_no_enc, "encoder")
@@ -211,7 +211,7 @@ class TestResampleDownsample:
"""Test downsample modes in Resample.""" """Test downsample modes in Resample."""
def test_downsample2d(self): def test_downsample2d(self):
from mlx_video.models.wan2.vae import Resample from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="downsample2d") r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8)) x = mx.random.normal((1, 16, 2, 8, 8))
@@ -221,7 +221,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4) assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self): def test_downsample3d(self):
from mlx_video.models.wan2.vae import Resample from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="downsample3d") r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8)) x = mx.random.normal((1, 16, 4, 8, 8))
@@ -231,7 +231,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4) assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self): def test_upsample2d_still_works(self):
from mlx_video.models.wan2.vae import Resample from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="upsample2d") r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4)) x = mx.random.normal((1, 16, 2, 4, 4))
@@ -240,7 +240,7 @@ class TestResampleDownsample:
assert out.shape == (1, 8, 2, 8, 8) assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self): def test_upsample3d_still_works(self):
from mlx_video.models.wan2.vae import Resample from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="upsample3d") r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4)) x = mx.random.normal((1, 16, 2, 4, 4))
@@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline:
def test_full_i2v_pipeline(self): def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode.""" """End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
mx.random.seed(0) mx.random.seed(0)
@@ -410,8 +410,8 @@ class TestDualModelSwitching:
def test_model_selection_by_timestep(self): def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise.""" """Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(1) mx.random.seed(1)
config = _make_tiny_i2v_config() config = _make_tiny_i2v_config()
@@ -485,8 +485,8 @@ class TestDualModelSwitching:
def test_guide_scale_tuple_applied_per_model(self): def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model.""" """Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(2) mx.random.seed(2)
config = _make_tiny_i2v_config() config = _make_tiny_i2v_config()
@@ -545,8 +545,8 @@ class TestDualModelSwitching:
def test_single_model_fallback_with_tuple_guide_scale(self): def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element.""" """When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(3) mx.random.seed(3)
config = _make_tiny_config() config = _make_tiny_config()

View File

@@ -331,7 +331,7 @@ class TestEndToEnd:
"""End-to-end LoRA loading and application.""" """End-to-end LoRA loading and application."""
def test_load_and_apply_loras(self): def test_load_and_apply_loras(self):
from mlx_video.models.wan2.convert import load_and_apply_loras from mlx_video.models.wan_2.convert import load_and_apply_loras
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
# Create mock LoRA safetensors # Create mock LoRA safetensors

View File

@@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config
class TestSinusoidalEmbedding: class TestSinusoidalEmbedding:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32) pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos) emb = sinusoidal_embedding_1d(256, pos)
@@ -21,7 +21,7 @@ class TestSinusoidalEmbedding:
def test_position_zero(self): def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0.""" """Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0]) pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos) emb = sinusoidal_embedding_1d(64, pos)
@@ -33,7 +33,7 @@ class TestSinusoidalEmbedding:
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5) np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
def test_different_positions_differ(self): def test_different_positions_differ(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0]) pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos) emb = sinusoidal_embedding_1d(128, pos)
@@ -50,7 +50,7 @@ class TestSinusoidalEmbedding:
class TestHead: class TestHead:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.wan2 import Head from mlx_video.models.wan_2.wan_2 import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24 B, L = 1, 24
@@ -62,7 +62,7 @@ class TestHead:
assert out.shape == (B, L, expected_proj_dim) assert out.shape == (B, L, expected_proj_dim)
def test_modulation_shape(self): def test_modulation_shape(self):
from mlx_video.models.wan2.wan2 import Head from mlx_video.models.wan_2.wan_2 import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64) assert head.modulation.shape == (1, 2, 64)
@@ -78,7 +78,7 @@ class TestWanModel:
mx.random.seed(42) mx.random.seed(42)
def test_instantiation(self): def test_instantiation(self):
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -86,7 +86,7 @@ class TestWanModel:
assert num_params > 0 assert num_params > 0
def test_patchify_shape(self): def test_patchify_shape(self):
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -99,7 +99,7 @@ class TestWanModel:
assert patches.shape == (1, 1 * 2 * 2, config.dim) assert patches.shape == (1, 1 * 2 * 2, config.dim)
def test_patchify_various_sizes(self): def test_patchify_various_sizes(self):
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -115,7 +115,7 @@ class TestWanModel:
def test_unpatchify_inverse(self): def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims.""" """Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -131,7 +131,7 @@ class TestWanModel:
assert out[0].shape == (config.out_dim, F, H, W) assert out[0].shape == (config.out_dim, F, H, W)
def test_forward_pass(self): def test_forward_pass(self):
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -149,7 +149,7 @@ class TestWanModel:
assert out[0].shape == (C, F, H, W) assert out[0].shape == (C, F, H, W)
def test_forward_batch(self): def test_forward_batch(self):
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -171,7 +171,7 @@ class TestWanModel:
assert o.shape == (C, F, H, W) assert o.shape == (C, F, H, W)
def test_output_is_float32(self): def test_output_is_float32(self):
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -200,7 +200,7 @@ class TestWan21Model:
def _make_tiny_wan21_config(self): def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model).""" """Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b() config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values # Override to tiny values
@@ -217,7 +217,7 @@ class TestWan21Model:
def _make_tiny_wan21_1_3b_config(self): def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B.""" """Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b() config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads) # Override to tiny values (preserve 1.3B head structure: 12 heads)
@@ -234,7 +234,7 @@ class TestWan21Model:
def test_wan21_tiny_model_forward(self): def test_wan21_tiny_model_forward(self):
"""Forward pass with Wan2.1 tiny config.""" """Forward pass with Wan2.1 tiny config."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = self._make_tiny_wan21_config() config = self._make_tiny_wan21_config()
model = WanModel(config) model = WanModel(config)
@@ -252,7 +252,7 @@ class TestWan21Model:
def test_wan21_1_3b_tiny_model_forward(self): def test_wan21_1_3b_tiny_model_forward(self):
"""Forward pass with Wan2.1 1.3B tiny config.""" """Forward pass with Wan2.1 1.3B tiny config."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = self._make_tiny_wan21_1_3b_config() config = self._make_tiny_wan21_1_3b_config()
model = WanModel(config) model = WanModel(config)
@@ -270,8 +270,8 @@ class TestWan21Model:
def test_wan21_single_model_loop(self): def test_wan21_single_model_loop(self):
"""Full diffusion loop with single model (Wan2.1 style).""" """Full diffusion loop with single model (Wan2.1 style)."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
config = self._make_tiny_wan21_config() config = self._make_tiny_wan21_config()
model = WanModel(config) model = WanModel(config)
@@ -305,7 +305,7 @@ class TestWan21Model:
def test_wan21_vs_wan22_config_differences(self): def test_wan21_vs_wan22_config_differences(self):
"""Verify key differences between Wan2.1 and Wan2.2 configs.""" """Verify key differences between Wan2.1 and Wan2.2 configs."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
c21 = WanModelConfig.wan21_t2v_14b() c21 = WanModelConfig.wan21_t2v_14b()
c22 = WanModelConfig.wan22_t2v_14b() c22 = WanModelConfig.wan22_t2v_14b()
@@ -333,21 +333,21 @@ class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding.""" """Tests for per-token sinusoidal embedding."""
def test_1d_unchanged(self): def test_1d_unchanged(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 500.0]) pos = mx.array([0.0, 100.0, 500.0])
emb = sinusoidal_embedding_1d(256, pos) emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (3, 256) assert emb.shape == (3, 256)
def test_2d_per_token(self): def test_2d_per_token(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]]) pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
emb = sinusoidal_embedding_1d(256, pos) emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (2, 3, 256) assert emb.shape == (2, 3, 256)
def test_consistency(self): def test_consistency(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos_1d = mx.array([0.0, 100.0]) pos_1d = mx.array([0.0, 100.0])
emb_1d = sinusoidal_embedding_1d(256, pos_1d) emb_1d = sinusoidal_embedding_1d(256, pos_1d)

View File

@@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config
class TestQuantizePredicate: class TestQuantizePredicate:
def test_matches_self_attention_layers(self): def test_matches_self_attention_layers(self):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]: for suffix in ["q", "k", "v", "o"]:
@@ -23,7 +23,7 @@ class TestQuantizePredicate:
assert _quantize_predicate(path, mock_linear), f"Should match {path}" assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_cross_attention_layers(self): def test_matches_cross_attention_layers(self):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]: for suffix in ["q", "k", "v", "o"]:
@@ -31,14 +31,14 @@ class TestQuantizePredicate:
assert _quantize_predicate(path, mock_linear), f"Should match {path}" assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_ffn_layers(self): def test_matches_ffn_layers(self):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self): def test_rejects_embeddings(self):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
for path in [ for path in [
@@ -49,13 +49,13 @@ class TestQuantizePredicate:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}" assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self): def test_rejects_norms(self):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64) mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm) assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self): def test_rejects_non_quantizable_modules(self):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64) mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized # Even if path matches, module must have to_quantized
@@ -63,7 +63,7 @@ class TestQuantizePredicate:
def test_all_10_patterns_covered(self): def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted.""" """Verify exactly 10 layer patterns are targeted."""
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64) mock_linear = nn.Linear(64, 64)
patterns = [ patterns = [
@@ -90,8 +90,8 @@ class TestQuantizePredicate:
class TestQuantizeRoundTrip: class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path.""" """Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config) model = WanModel(config)
nn.quantize( nn.quantize(
@@ -116,7 +116,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config() config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4) model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan2.utils import load_wan_model from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model( loaded = load_wan_model(
model_path, model_path,
@@ -136,7 +136,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config() config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8) model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan2.utils import load_wan_model from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model( loaded = load_wan_model(
model_path, model_path,
@@ -151,7 +151,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config() config = _make_tiny_config()
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4) model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan2.utils import load_wan_model from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model( loaded = load_wan_model(
model_path, model_path,
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
def test_loading_without_quantization_flag(self, tmp_path): def test_loading_without_quantization_flag(self, tmp_path):
"""Loading a non-quantized model should have standard Linear layers.""" """Loading a non-quantized model should have standard Linear layers."""
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -172,7 +172,7 @@ class TestQuantizeRoundTrip:
model_path = tmp_path / "model.safetensors" model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict) mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan2.utils import load_wan_model from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None) loaded = load_wan_model(model_path, config, quantization=None)
@@ -187,8 +187,8 @@ class TestQuantizeRoundTrip:
class TestQuantizedInference: class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4): def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config) model = WanModel(config)
nn.quantize( nn.quantize(
@@ -238,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self): def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights.""" """Sanity check: quantization should change the weights."""
from mlx_video.models.wan2.convert import _quantize_predicate from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
mx.random.seed(42) mx.random.seed(42)
@@ -271,8 +271,8 @@ class TestQuantizedInference:
class TestQuantizationConfig: class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path): def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json.""" """Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan2.convert import _quantize_saved_model from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -295,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64 assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path): def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan2.convert import _quantize_saved_model from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()
model = WanModel(config) model = WanModel(config)
@@ -316,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path): def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files.""" """Verify dual-model quantization writes both model files."""
from mlx_video.models.wan2.convert import _quantize_saved_model from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config() config = _make_tiny_config()

View File

@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
def _get_model_freqs(self, dim=64, num_heads=4): def _get_model_freqs(self, dim=64, num_heads=4):
"""Instantiate a tiny WanModel and return its .freqs tensor.""" """Instantiate a tiny WanModel and return its .freqs tensor."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan_2.wan_2 import WanModel
config = WanModelConfig() config = WanModelConfig()
config.dim = dim config.dim = dim
@@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction:
def test_three_call_vs_single_call_differ(self): def test_three_call_vs_single_call_differ(self):
"""Three separate rope_params calls must differ from single call.""" """Three separate rope_params calls must differ from single call."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
d = 128 # head_dim for all Wan models d = 128 # head_dim for all Wan models
# Reference: three separate calls # Reference: three separate calls
@@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction:
This verifies each axis gets its own independent frequency range This verifies each axis gets its own independent frequency range
starting from theta^0 = 1.0 (i.e., exponent 0/dim). starting from theta^0 = 1.0 (i.e., exponent 0/dim).
""" """
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
d = 128 d = 128
freqs = mx.concatenate( freqs = mx.concatenate(
@@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction:
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42). Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
""" """
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
d = 128 d = 128
d_h_dim = 2 * (d // 6) # 42 d_h_dim = 2 * (d // 6) # 42
@@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction:
axis should be 1.0 (theta^0). A single-call approach would give height axis should be 1.0 (theta^0). A single-call approach would give height
starting at ~0.04 and width at ~0.002 instead of 1.0. starting at ~0.04 and width at ~0.002 instead of 1.0.
""" """
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
d = 128 d = 128
freqs = mx.concatenate( freqs = mx.concatenate(
@@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_match_manual_construction(self): def test_model_freqs_match_manual_construction(self):
"""WanModel.freqs should match manually constructed three-call freqs.""" """WanModel.freqs should match manually constructed three-call freqs."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4) freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16 d = head_dim # 16
@@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_14b_dimensions(self): def test_model_freqs_14b_dimensions(self):
"""Verify freq dimensions for 14B-scale head_dim=128.""" """Verify freq dimensions for 14B-scale head_dim=128."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
d = 128 d = 128
freqs = mx.concatenate( freqs = mx.concatenate(
@@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference:
"""Numerically compare MLX and PyTorch frequency tables.""" """Numerically compare MLX and PyTorch frequency tables."""
import torch import torch
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
d = 128 d = 128
@@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug: This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions. height/width frequencies were too low to distinguish nearby positions.
""" """
from mlx_video.models.wan2.rope import rope_apply, rope_params from mlx_video.models.wan_2.rope import rope_apply, rope_params
d = 128 d = 128
freqs = mx.concatenate( freqs = mx.concatenate(
@@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs:
def test_precomputed_matches_online(self): def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path.""" """rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
from mlx_video.models.wan2.rope import ( from mlx_video.models.wan_2.rope import (
rope_apply, rope_apply,
rope_params, rope_params,
rope_precompute_cos_sin, rope_precompute_cos_sin,

View File

@@ -13,7 +13,7 @@ import pytest
class TestFlowMatchEulerScheduler: class TestFlowMatchEulerScheduler:
def test_initialization(self): def test_initialization(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000 assert sched.num_train_timesteps == 1000
@@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas is None assert sched.sigmas is None
def test_set_timesteps(self): def test_set_timesteps(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0) sched.set_timesteps(40, shift=12.0)
@@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas.shape == (41,) # 40 steps + terminal assert sched.sigmas.shape == (41,) # 40 steps + terminal
def test_timesteps_decreasing(self): def test_timesteps_decreasing(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0) sched.set_timesteps(40, shift=12.0)
@@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..." assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
def test_sigmas_decreasing(self): def test_sigmas_decreasing(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0) sched.set_timesteps(20, shift=1.0)
@@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing" assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
def test_terminal_sigma_is_zero(self): def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
@@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self): def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values.""" """Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler() sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler() sched2 = FlowMatchEulerScheduler()
@@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler:
assert mean2 > mean1, "Higher shift should push sigmas higher" assert mean2 > mean1, "Higher shift should push sigmas higher"
def test_step_euler(self): def test_step_euler(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0) sched.set_timesteps(10, shift=1.0)
@@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler:
) )
def test_step_index_increments(self): def test_step_index_increments(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler:
assert sched._step_index == 2 assert sched._step_index == 2
def test_reset(self): def test_reset(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50]) @pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps): def test_various_step_counts(self, steps):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0) sched.set_timesteps(steps, shift=12.0)
@@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self): def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged.""" """Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler() sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -153,26 +153,26 @@ class TestComputeSigmas:
"""Tests for the shared _compute_sigmas helper.""" """Tests for the shared _compute_sigmas helper."""
def test_length(self): def test_length(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0) sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self): def test_terminal_zero(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0) sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0 assert sigmas[-1] == 0.0
def test_starts_near_one(self): def test_starts_near_one(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0) sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0) # Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3) np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self): def test_decreasing(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0) sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0) assert np.all(np.diff(sigmas) <= 0)
@@ -185,7 +185,7 @@ class TestComputeSigmas:
sigma_max/sigma_min come from the *unshifted* training schedule, and the sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift). shift is applied only once (single-shift).
""" """
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000 steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N) sigmas = _compute_sigmas(steps, shift, N)
@@ -200,7 +200,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(sigmas, official, atol=1e-6) np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_near_linear(self): def test_shift_one_is_near_linear(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0) sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule) # With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
@@ -210,7 +210,7 @@ class TestComputeSigmas:
def test_all_schedulers_same_sigmas(self): def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules.""" """All three schedulers should produce identical sigma schedules."""
from mlx_video.models.wan2.scheduler import ( from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler, FlowDPMPP2MScheduler,
FlowMatchEulerScheduler, FlowMatchEulerScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
@@ -229,7 +229,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6) np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
def test_all_schedulers_same_timesteps(self): def test_all_schedulers_same_timesteps(self):
from mlx_video.models.wan2.scheduler import ( from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler, FlowDPMPP2MScheduler,
FlowMatchEulerScheduler, FlowMatchEulerScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
@@ -255,14 +255,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler: class TestFlowDPMPP2MScheduler:
def test_initialization(self): def test_initialization(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000 assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True assert sched.lower_order_final is True
def test_set_timesteps(self): def test_set_timesteps(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
@@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler:
assert sched.sigmas.shape == (21,) assert sched.sigmas.shape == (21,)
def test_step_index_increments(self): def test_step_index_increments(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler:
assert sched._step_index == 2 assert sched._step_index == 2
def test_reset(self): def test_reset(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self): def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output.""" """Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0) sched.set_timesteps(10, shift=1.0)
@@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self): def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available).""" """First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
@@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self): def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction.""" """After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
@@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler:
def test_denoise_to_target(self): def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver.""" """Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
@@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50]) @pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps): def test_various_step_counts(self, steps):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0) sched.set_timesteps(steps, shift=5.0)
@@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self): def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly.""" """When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler() sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler: class TestFlowUniPCScheduler:
def test_initialization(self): def test_initialization(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000 assert sched.num_train_timesteps == 1000
@@ -402,7 +402,7 @@ class TestFlowUniPCScheduler:
assert sched.lower_order_final is True assert sched.lower_order_final is True
def test_set_timesteps(self): def test_set_timesteps(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0) sched.set_timesteps(30, shift=12.0)
@@ -411,7 +411,7 @@ class TestFlowUniPCScheduler:
assert sched.sigmas.shape == (31,) assert sched.sigmas.shape == (31,)
def test_step_index_increments(self): def test_step_index_increments(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -422,7 +422,7 @@ class TestFlowUniPCScheduler:
assert sched._step_index == 1 assert sched._step_index == 1
def test_reset(self): def test_reset(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -435,7 +435,7 @@ class TestFlowUniPCScheduler:
assert all(m is None for m in sched._model_outputs) assert all(m is None for m in sched._model_outputs)
def test_full_loop_finite(self): def test_full_loop_finite(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0) sched.set_timesteps(10, shift=1.0)
@@ -448,7 +448,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self): def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history).""" """First step should skip the corrector (no history)."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True) sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
@@ -462,7 +462,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self): def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled.""" """Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True) sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
@@ -475,7 +475,7 @@ class TestFlowUniPCScheduler:
assert sched._lower_order_nums >= 2 assert sched._lower_order_nums >= 2
def test_denoise_to_target(self): def test_denoise_to_target(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0) sched.set_timesteps(20, shift=5.0)
@@ -490,7 +490,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50]) @pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps): def test_various_step_counts(self, steps):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0) sched.set_timesteps(steps, shift=5.0)
@@ -500,7 +500,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self): def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error.""" """Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0]) sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0) sched.set_timesteps(5, shift=1.0)
@@ -513,7 +513,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self): def test_solver_order_3(self):
"""Order 3 should work without error.""" """Order 3 should work without error."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True) sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0) sched.set_timesteps(10, shift=5.0)
@@ -531,7 +531,7 @@ class TestFlowUniPCScheduler:
# For 50-step schedule with shift=5.0, order 2 corrector at step 5: # For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5 # rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5 # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
from mlx_video.models.wan2.scheduler import _compute_sigmas from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(50, shift=5.0) sigmas = _compute_sigmas(50, shift=5.0)
@@ -597,7 +597,7 @@ class TestSchedulerCoherence:
@staticmethod @staticmethod
def _make_schedulers(steps=10, shift=5.0): def _make_schedulers(steps=10, shift=5.0):
from mlx_video.models.wan2.scheduler import ( from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler, FlowDPMPP2MScheduler,
FlowMatchEulerScheduler, FlowMatchEulerScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
@@ -780,7 +780,7 @@ class TestSchedulerCoherence:
def test_lambda_boundary_values(self): def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0.""" """_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
from mlx_video.models.wan2.scheduler import ( from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler, FlowDPMPP2MScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
) )
@@ -800,7 +800,7 @@ class TestSchedulerCoherence:
def test_lambda_monotonically_decreasing(self): def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR).""" """_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas] lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
@@ -902,7 +902,7 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2) shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape) noise = mx.random.normal(shape)
from mlx_video.models.wan2.scheduler import ( from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler, FlowDPMPP2MScheduler,
FlowUniPCScheduler, FlowUniPCScheduler,
) )
@@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self): def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled.""" """Default construction should have corrector enabled."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler() sched = FlowUniPCScheduler()
assert sched._use_corrector is True assert sched._use_corrector is True
def test_corrector_affects_output(self): def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1.""" """Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
mx.random.seed(42) mx.random.seed(42)
shape = (1, 4, 1, 4, 4) shape = (1, 4, 1, 4, 4)
@@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self): def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting.""" """Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
mx.random.seed(42) mx.random.seed(42)
shape = (1, 4, 1, 4, 4) shape = (1, 4, 1, 4, 4)

View File

@@ -11,7 +11,7 @@ import numpy as np
class TestT5LayerNorm: class TestT5LayerNorm:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5LayerNorm from mlx_video.models.wan_2.text_encoder import T5LayerNorm
norm = T5LayerNorm(64) norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
@@ -21,7 +21,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self): def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1.""" """After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan2.text_encoder import T5LayerNorm from mlx_video.models.wan_2.text_encoder import T5LayerNorm
norm = T5LayerNorm(128) norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0 x = mx.random.normal((1, 5, 128)) * 5.0
@@ -35,7 +35,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding: class TestT5RelativeEmbedding:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10) out = rel_emb(10, 10)
@@ -43,7 +43,7 @@ class TestT5RelativeEmbedding:
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk] assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
def test_asymmetric_lengths(self): def test_asymmetric_lengths(self):
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12) out = rel_emb(8, 12)
@@ -52,7 +52,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self): def test_symmetry(self):
"""Position bias should have structure (not all zeros/random).""" """Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2) rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6) out = rel_emb(6, 6)
@@ -67,7 +67,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention: class TestT5Attention:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5Attention from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
@@ -77,14 +77,14 @@ class TestT5Attention:
def test_no_scaling(self): def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure.""" """T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan2.text_encoder import T5Attention from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention) # No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale") assert not hasattr(attn, "scale")
def test_with_position_bias(self): def test_with_position_bias(self):
from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding from mlx_video.models.wan_2.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4) rel_emb = T5RelativeEmbedding(32, 4)
@@ -95,7 +95,7 @@ class TestT5Attention:
assert out.shape == (1, 10, 64) assert out.shape == (1, 10, 64)
def test_with_mask(self): def test_with_mask(self):
from mlx_video.models.wan2.text_encoder import T5Attention from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4) attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
@@ -108,7 +108,7 @@ class TestT5Attention:
class TestT5FeedForward: class TestT5FeedForward:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5FeedForward from mlx_video.models.wan_2.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256) ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64)) x = mx.random.normal((1, 10, 64))
@@ -118,7 +118,7 @@ class TestT5FeedForward:
def test_gated_structure(self): def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x).""" """T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan2.text_encoder import T5FeedForward from mlx_video.models.wan_2.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64) ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj") assert hasattr(ffn, "gate_proj")
@@ -131,7 +131,7 @@ class TestT5Encoder:
mx.random.seed(42) mx.random.seed(42)
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5Encoder from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, vocab_size=100,
@@ -150,7 +150,7 @@ class TestT5Encoder:
assert out.shape == (1, 5, 64) assert out.shape == (1, 5, 64)
def test_shared_pos(self): def test_shared_pos(self):
from mlx_video.models.wan2.text_encoder import T5Encoder from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, vocab_size=100,
@@ -167,7 +167,7 @@ class TestT5Encoder:
assert block.pos_embedding is None assert block.pos_embedding is None
def test_per_layer_pos(self): def test_per_layer_pos(self):
from mlx_video.models.wan2.text_encoder import T5Encoder from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, vocab_size=100,
@@ -184,7 +184,7 @@ class TestT5Encoder:
assert block.pos_embedding is not None assert block.pos_embedding is not None
def test_param_count(self): def test_param_count(self):
from mlx_video.models.wan2.text_encoder import T5Encoder from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, vocab_size=100,
@@ -200,7 +200,7 @@ class TestT5Encoder:
assert num_params > 0 assert num_params > 0
def test_without_mask(self): def test_without_mask(self):
from mlx_video.models.wan2.text_encoder import T5Encoder from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
vocab_size=100, vocab_size=100,

View File

@@ -75,7 +75,7 @@ class TestWan22TiledDecoding:
def _make_small_wan22_decoder(self): def _make_small_wan22_decoder(self):
"""Create a small Wan2.2 decoder for testing.""" """Create a small Wan2.2 decoder for testing."""
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
# Use very small dimensions for fast testing # Use very small dimensions for fast testing
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16) vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
@@ -139,7 +139,7 @@ class TestWan21TiledDecoding:
def _make_small_wan21_vae(self): def _make_small_wan21_vae(self):
"""Create a small Wan2.1 VAE for testing.""" """Create a small Wan2.1 VAE for testing."""
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16) vae = WanVAE(z_dim=16)
mx.eval(vae.parameters()) mx.eval(vae.parameters())
@@ -192,7 +192,7 @@ class TestWan21TemporalScale:
def test_wan21_decoder_temporal_output(self): def test_wan21_decoder_temporal_output(self):
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling).""" """Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
from mlx_video.models.wan2.vae import Decoder3d from mlx_video.models.wan_2.vae import Decoder3d
# Small decoder for fast test # Small decoder for fast test
dec = Decoder3d( dec = Decoder3d(

View File

@@ -10,7 +10,7 @@ import numpy as np
class TestWanFFN: class TestWanFFN:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.transformer import WanFFN from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(64, 256) ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64)) x = mx.random.normal((2, 10, 64))
@@ -20,7 +20,7 @@ class TestWanFFN:
def test_gelu_activation(self): def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity).""" """FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan2.transformer import WanFFN from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(32, 128) ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0 x = mx.ones((1, 1, 32)) * 2.0
@@ -40,8 +40,8 @@ class TestWanAttentionBlock:
self.num_heads = 4 self.num_heads = 4
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock( block = WanAttentionBlock(
self.dim, self.dim,
@@ -68,13 +68,13 @@ class TestWanAttentionBlock:
assert out.shape == (B, L, self.dim) assert out.shape == (B, L, self.dim)
def test_modulation_shape(self): def test_modulation_shape(self):
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim) assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self): def test_with_cross_attn_norm(self):
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock( block = WanAttentionBlock(
self.dim, self.dim,
@@ -85,7 +85,7 @@ class TestWanAttentionBlock:
assert block.norm3 is not None assert block.norm3 is not None
def test_without_cross_attn_norm(self): def test_without_cross_attn_norm(self):
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock( block = WanAttentionBlock(
self.dim, self.dim,
@@ -97,8 +97,8 @@ class TestWanAttentionBlock:
def test_residual_connection(self): def test_residual_connection(self):
"""Output should differ from zero even with small random init.""" """Output should differ from zero even with small random init."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8 B, L = 1, 8
@@ -129,15 +129,15 @@ class TestFloat32Modulation:
def test_block_modulation_in_float32(self): def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32.""" """Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True) block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32 assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self): def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32.""" """Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan2.rope import rope_params from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4) block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8 B, L = 1, 8
@@ -153,7 +153,7 @@ class TestFloat32Modulation:
def test_head_modulation_float32(self): def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input.""" """Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan2.wan2 import Head from mlx_video.models.wan_2.wan_2 import Head
head = Head(self.dim, 4, (1, 2, 2)) head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim)) x = mx.random.normal((1, 8, self.dim))
@@ -164,7 +164,7 @@ class TestFloat32Modulation:
def test_model_time_embedding_float32(self): def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32.""" """sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
t = mx.array([500.0]) t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t) emb = sinusoidal_embedding_1d(256, t)
@@ -173,7 +173,7 @@ class TestFloat32Modulation:
def test_model_per_token_time_embedding_float32(self): def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32.""" """Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4] t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t) emb = sinusoidal_embedding_1d(256, t)

View File

@@ -12,7 +12,7 @@ import numpy as np
class TestCausalConv3d: class TestCausalConv3d:
def test_output_shape_stride1(self): def test_output_shape_stride1(self):
from mlx_video.models.wan2.vae import CausalConv3d from mlx_video.models.wan_2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1) conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights # Initialize weights
@@ -28,7 +28,7 @@ class TestCausalConv3d:
assert out.shape[4] == 8 # W preserved assert out.shape[4] == 8 # W preserved
def test_output_shape_kernel1(self): def test_output_shape_kernel1(self):
from mlx_video.models.wan2.vae import CausalConv3d from mlx_video.models.wan_2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0) conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02 conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -39,7 +39,7 @@ class TestCausalConv3d:
def test_causal_padding(self): def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future.""" """Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan2.vae import CausalConv3d from mlx_video.models.wan_2.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1) conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1 conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -56,7 +56,7 @@ class TestCausalConv3d:
class TestResidualBlock: class TestResidualBlock:
def test_same_dim(self): def test_same_dim(self):
from mlx_video.models.wan2.vae import ResidualBlock from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4)) x = mx.random.normal((1, 8, 2, 4, 4))
@@ -65,7 +65,7 @@ class TestResidualBlock:
assert out.shape == (1, 8, 2, 4, 4) assert out.shape == (1, 8, 2, 4, 4)
def test_different_dim(self): def test_different_dim(self):
from mlx_video.models.wan2.vae import ResidualBlock from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4)) x = mx.random.normal((1, 8, 2, 4, 4))
@@ -74,13 +74,13 @@ class TestResidualBlock:
assert out.shape == (1, 16, 2, 4, 4) assert out.shape == (1, 16, 2, 4, 4)
def test_shortcut_exists_when_dims_differ(self): def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan2.vae import ResidualBlock from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
assert block.shortcut is not None assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self): def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan2.vae import ResidualBlock from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
assert block.shortcut is None assert block.shortcut is None
@@ -88,7 +88,7 @@ class TestResidualBlock:
class TestAttentionBlock: class TestAttentionBlock:
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae import AttentionBlock from mlx_video.models.wan_2.vae import AttentionBlock
block = AttentionBlock(8) block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4)) x = mx.random.normal((1, 8, 2, 4, 4))
@@ -97,7 +97,7 @@ class TestAttentionBlock:
assert out.shape == (1, 8, 2, 4, 4) assert out.shape == (1, 8, 2, 4, 4)
def test_residual_connection(self): def test_residual_connection(self):
from mlx_video.models.wan2.vae import AttentionBlock from mlx_video.models.wan_2.vae import AttentionBlock
block = AttentionBlock(8) block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3)) x = mx.random.normal((1, 8, 1, 3, 3))
@@ -109,7 +109,7 @@ class TestAttentionBlock:
class TestWanVAE: class TestWanVAE:
def test_instantiation(self): def test_instantiation(self):
from mlx_video.models.wan2.vae import WanVAE from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16) vae = WanVAE(z_dim=16)
assert vae.z_dim == 16 assert vae.z_dim == 16
@@ -117,7 +117,7 @@ class TestWanVAE:
assert vae.std.shape == (16,) assert vae.std.shape == (16,)
def test_normalization_stats(self): def test_normalization_stats(self):
from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD from mlx_video.models.wan_2.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16 assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16 assert len(VAE_STD) == 16
@@ -133,7 +133,7 @@ class TestVAE22CausalConv3d:
"""Tests for vae22.CausalConv3d (channels-last).""" """Tests for vae22.CausalConv3d (channels-last)."""
def test_output_shape_k3(self): def test_output_shape_k3(self):
from mlx_video.models.wan2.vae22 import CausalConv3d from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1) conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C] x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
@@ -142,7 +142,7 @@ class TestVAE22CausalConv3d:
assert out.shape == (1, 4, 8, 8, 16) assert out.shape == (1, 4, 8, 8, 16)
def test_output_shape_k1(self): def test_output_shape_k1(self):
from mlx_video.models.wan2.vae22 import CausalConv3d from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1) conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -152,7 +152,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self): def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0.""" """Output at t=0 should not depend on t>0."""
from mlx_video.models.wan2.vae22 import CausalConv3d from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1) conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1 conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -178,7 +178,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self): def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C].""" """Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan2.vae22 import CausalConv3d from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1) conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4)) x = mx.random.normal((2, 3, 6, 6, 4))
@@ -191,7 +191,7 @@ class TestRMSNorm:
"""Tests for vae22.RMS_norm (actually L2 normalization).""" """Tests for vae22.RMS_norm (actually L2 normalization)."""
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae22 import RMS_norm from mlx_video.models.wan_2.vae22 import RMS_norm
norm = RMS_norm(16) norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16)) x = mx.random.normal((2, 4, 4, 4, 16))
@@ -201,7 +201,7 @@ class TestRMSNorm:
def test_l2_normalization(self): def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim).""" """RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan2.vae22 import RMS_norm from mlx_video.models.wan_2.vae22 import RMS_norm
dim = 32 dim = 32
norm = RMS_norm(dim) norm = RMS_norm(dim)
@@ -215,7 +215,7 @@ class TestRMSNorm:
def test_scale_invariant(self): def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property).""" """Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan2.vae22 import RMS_norm from mlx_video.models.wan_2.vae22 import RMS_norm
norm = RMS_norm(8) norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8)) x = mx.random.normal((1, 1, 1, 1, 8))
@@ -226,7 +226,7 @@ class TestRMSNorm:
def test_gamma_effect(self): def test_gamma_effect(self):
"""Non-unit gamma should scale output.""" """Non-unit gamma should scale output."""
from mlx_video.models.wan2.vae22 import RMS_norm from mlx_video.models.wan_2.vae22 import RMS_norm
norm = RMS_norm(4) norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0]) norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
@@ -241,7 +241,7 @@ class TestDupUp3D:
"""Tests for vae22.DupUp3D spatial/temporal upsampling.""" """Tests for vae22.DupUp3D spatial/temporal upsampling."""
def test_spatial_only(self): def test_spatial_only(self):
from mlx_video.models.wan2.vae22 import DupUp3D from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2) up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8)) x = mx.random.normal((1, 3, 4, 4, 8))
@@ -250,7 +250,7 @@ class TestDupUp3D:
assert out.shape == (1, 3, 8, 8, 4) assert out.shape == (1, 3, 8, 8, 4)
def test_temporal_and_spatial(self): def test_temporal_and_spatial(self):
from mlx_video.models.wan2.vae22 import DupUp3D from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2) up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16)) x = mx.random.normal((1, 3, 4, 4, 16))
@@ -259,7 +259,7 @@ class TestDupUp3D:
assert out.shape == (1, 6, 8, 8, 8) assert out.shape == (1, 6, 8, 8, 8)
def test_first_chunk_trims(self): def test_first_chunk_trims(self):
from mlx_video.models.wan2.vae22 import DupUp3D from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2) up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8)) x = mx.random.normal((1, 3, 4, 4, 8))
@@ -271,7 +271,7 @@ class TestDupUp3D:
assert out_trimmed.shape[1] == 5 assert out_trimmed.shape[1] == 5
def test_no_temporal_first_chunk_noop(self): def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan2.vae22 import DupUp3D from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2) up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8)) x = mx.random.normal((1, 3, 4, 4, 8))
@@ -286,7 +286,7 @@ class TestVAE22Resample:
"""Tests for vae22.Resample (spatial/temporal upsampling).""" """Tests for vae22.Resample (spatial/temporal upsampling)."""
def test_upsample2d_shape(self): def test_upsample2d_shape(self):
from mlx_video.models.wan2.vae22 import Resample from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample2d") r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -296,7 +296,7 @@ class TestVAE22Resample:
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
def test_upsample3d_shape(self): def test_upsample3d_shape(self):
from mlx_video.models.wan2.vae22 import Resample from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample3d") r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -306,7 +306,7 @@ class TestVAE22Resample:
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
def test_upsample3d_first_chunk(self): def test_upsample3d_first_chunk(self):
from mlx_video.models.wan2.vae22 import Resample from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample3d") r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -318,7 +318,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self): def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample.""" """Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan2.vae22 import Resample from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample3d") r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -336,7 +336,7 @@ class TestVAE22Resample:
We verify this by checking that the first output frame depends only on We verify this by checking that the first output frame depends only on
the first input frame (not on time_conv parameters). the first input frame (not on time_conv parameters).
""" """
from mlx_video.models.wan2.vae22 import Resample from mlx_video.models.wan_2.vae22 import Resample
C = 8 C = 8
r = Resample(C, "upsample3d") r = Resample(C, "upsample3d")
@@ -373,7 +373,7 @@ class TestVAE22ResidualBlock:
"""Tests for vae22.ResidualBlock.""" """Tests for vae22.ResidualBlock."""
def test_same_dim(self): def test_same_dim(self):
from mlx_video.models.wan2.vae22 import ResidualBlock from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -382,7 +382,7 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 8) assert out.shape == (1, 2, 4, 4, 8)
def test_different_dim(self): def test_different_dim(self):
from mlx_video.models.wan2.vae22 import ResidualBlock from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -391,13 +391,13 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 16) assert out.shape == (1, 2, 4, 4, 16)
def test_shortcut_when_dims_differ(self): def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan2.vae22 import ResidualBlock from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 16) block = ResidualBlock(8, 16)
assert block.shortcut is not None assert block.shortcut is not None
def test_no_shortcut_same_dim(self): def test_no_shortcut_same_dim(self):
from mlx_video.models.wan2.vae22 import ResidualBlock from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 8) block = ResidualBlock(8, 8)
assert block.shortcut is None assert block.shortcut is None
@@ -408,7 +408,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self): def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them).""" """Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan2.vae22 import ResidualBlockLayers from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8) block = ResidualBlockLayers(8, 8)
params = dict(block.parameters()) params = dict(block.parameters())
@@ -417,7 +417,7 @@ class TestResidualBlockLayers:
assert not key.startswith("_"), f"Parameter {key} starts with underscore" assert not key.startswith("_"), f"Parameter {key} starts with underscore"
def test_has_expected_layers(self): def test_has_expected_layers(self):
from mlx_video.models.wan2.vae22 import ResidualBlockLayers from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16) block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm assert hasattr(block, "layer_0") # first RMS_norm
@@ -426,7 +426,7 @@ class TestResidualBlockLayers:
assert hasattr(block, "layer_6") # second CausalConv3d assert hasattr(block, "layer_6") # second CausalConv3d
def test_forward_shape(self): def test_forward_shape(self):
from mlx_video.models.wan2.vae22 import ResidualBlockLayers from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16) block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8)) x = mx.random.normal((1, 2, 4, 4, 8))
@@ -439,7 +439,7 @@ class TestVAE22AttentionBlock:
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention).""" """Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae22 import AttentionBlock from mlx_video.models.wan_2.vae22 import AttentionBlock
block = AttentionBlock(16) block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01 block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
@@ -450,7 +450,7 @@ class TestVAE22AttentionBlock:
assert out.shape == (1, 2, 4, 4, 16) assert out.shape == (1, 2, 4, 4, 16)
def test_residual_connection(self): def test_residual_connection(self):
from mlx_video.models.wan2.vae22 import AttentionBlock from mlx_video.models.wan_2.vae22 import AttentionBlock
block = AttentionBlock(8) block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape) block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
@@ -466,7 +466,7 @@ class TestHead22:
"""Tests for vae22.Head22 output head.""" """Tests for vae22.Head22 output head."""
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae22 import Head22 from mlx_video.models.wan_2.vae22 import Head22
head = Head22(16, out_channels=12) head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16)) x = mx.random.normal((1, 2, 4, 4, 16))
@@ -476,7 +476,7 @@ class TestHead22:
def test_layer_names_no_underscore(self): def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix.""" """Head layers must not use underscore prefix."""
from mlx_video.models.wan2.vae22 import Head22 from mlx_video.models.wan_2.vae22 import Head22
head = Head22(8) head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm assert hasattr(head, "layer_0") # RMS_norm
@@ -490,7 +490,7 @@ class TestUnpatchify:
"""Tests for vae22._unpatchify.""" """Tests for vae22._unpatchify."""
def test_basic_shape(self): def test_basic_shape(self):
from mlx_video.models.wan2.vae22 import _unpatchify from mlx_video.models.wan_2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2 x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2) out = _unpatchify(x, patch_size=2)
@@ -498,7 +498,7 @@ class TestUnpatchify:
assert out.shape == (1, 2, 8, 8, 3) assert out.shape == (1, 2, 8, 8, 3)
def test_patch_size_1_noop(self): def test_patch_size_1_noop(self):
from mlx_video.models.wan2.vae22 import _unpatchify from mlx_video.models.wan_2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3)) x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1) out = _unpatchify(x, patch_size=1)
@@ -507,7 +507,7 @@ class TestUnpatchify:
def test_preserves_content(self): def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement.""" """Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan2.vae22 import _unpatchify from mlx_video.models.wan_2.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32) x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2) out = _unpatchify(x, patch_size=2)
@@ -521,7 +521,7 @@ class TestDenormalizeLatents:
"""Tests for vae22.denormalize_latents.""" """Tests for vae22.denormalize_latents."""
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae22 import denormalize_latents from mlx_video.models.wan_2.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48)) z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z) out = denormalize_latents(z)
@@ -529,7 +529,7 @@ class TestDenormalizeLatents:
assert out.shape == (1, 2, 4, 4, 48) assert out.shape == (1, 2, 4, 4, 48)
def test_custom_mean_std(self): def test_custom_mean_std(self):
from mlx_video.models.wan2.vae22 import denormalize_latents from mlx_video.models.wan_2.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4)) z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0]) mean = mx.array([1.0, 2.0, 3.0, 4.0])
@@ -542,7 +542,7 @@ class TestDenormalizeLatents:
) )
def test_uses_default_constants(self): def test_uses_default_constants(self):
from mlx_video.models.wan2.vae22 import ( from mlx_video.models.wan_2.vae22 import (
VAE22_MEAN, VAE22_MEAN,
denormalize_latents, denormalize_latents,
) )
@@ -563,14 +563,14 @@ class TestVAE22NormConstants:
"""Tests for VAE22_MEAN and VAE22_STD constants.""" """Tests for VAE22_MEAN and VAE22_STD constants."""
def test_dimensions(self): def test_dimensions(self):
from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD from mlx_video.models.wan_2.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD) mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,) assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,) assert VAE22_STD.shape == (48,)
def test_std_positive(self): def test_std_positive(self):
from mlx_video.models.wan2.vae22 import VAE22_STD from mlx_video.models.wan_2.vae22 import VAE22_STD
mx.eval(VAE22_STD) mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all() assert (np.array(VAE22_STD) > 0).all()
@@ -581,7 +581,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self): def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output.""" """Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast # Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
@@ -597,7 +597,7 @@ class TestWan22VAEDecoder:
assert np.array(out).max() <= 1.0 assert np.array(out).max() <= 1.0
def test_output_clipped(self): def test_output_clipped(self):
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
@@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights:
"""Tests for vae22.sanitize_wan22_vae_weights.""" """Tests for vae22.sanitize_wan22_vae_weights."""
def test_skip_encoder(self): def test_skip_encoder(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"encoder.layer.weight": mx.zeros((4,)), "encoder.layer.weight": mx.zeros((4,)),
@@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.conv1.bias" in out assert "decoder.conv1.bias" in out
def test_sequential_index_remapping(self): def test_sequential_index_remapping(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)), "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
@@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.head.layer_2.bias" in out assert "decoder.head.layer_2.bias" in out
def test_resample_conv_remapping(self): def test_resample_conv_remapping(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)), "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
@@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
def test_attention_remapping(self): def test_attention_remapping(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = { weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)), "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
@@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.middle.1.proj_bias" in out assert "decoder.middle.1.proj_bias" in out
def test_conv3d_transpose(self): def test_conv3d_transpose(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I] # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3)) w = mx.zeros((16, 8, 3, 3, 3))
@@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights:
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8) assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
def test_conv2d_transpose(self): def test_conv2d_transpose(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I] # Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3)) w = mx.zeros((8, 8, 3, 3))
@@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights:
assert out[key].shape == (8, 3, 3, 8) assert out[key].shape == (8, 3, 3, 8)
def test_gamma_squeeze(self): def test_gamma_squeeze(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,) # gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1)) w = mx.ones((16, 1, 1, 1))
@@ -698,7 +698,7 @@ class TestUpResidualBlock:
"""Tests for vae22.Up_ResidualBlock.""" """Tests for vae22.Up_ResidualBlock."""
def test_no_upsample(self): def test_no_upsample(self):
from mlx_video.models.wan2.vae22 import Up_ResidualBlock from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock( block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False 8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
@@ -710,7 +710,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 4, 4, 8) assert out.shape == (1, 2, 4, 4, 8)
def test_spatial_upsample(self): def test_spatial_upsample(self):
from mlx_video.models.wan2.vae22 import Up_ResidualBlock from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock( block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True 8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
@@ -722,7 +722,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 8, 8, 4) assert out.shape == (1, 2, 8, 8, 4)
def test_spatial_temporal_upsample(self): def test_spatial_temporal_upsample(self):
from mlx_video.models.wan2.vae22 import Up_ResidualBlock from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock( block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True 8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
@@ -738,7 +738,7 @@ class TestPatchify:
"""Tests for _patchify and _unpatchify round-trip.""" """Tests for _patchify and _unpatchify round-trip."""
def test_roundtrip(self): def test_roundtrip(self):
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 1, 64, 64, 3)) x = mx.random.normal((1, 1, 64, 64, 3))
p = _patchify(x, patch_size=2) p = _patchify(x, patch_size=2)
@@ -748,7 +748,7 @@ class TestPatchify:
assert float(mx.abs(x - back).max()) == 0.0 assert float(mx.abs(x - back).max()) == 0.0
def test_identity_patch_1(self): def test_identity_patch_1(self):
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 2, 8, 8, 3)) x = mx.random.normal((1, 2, 8, 8, 3))
assert _patchify(x, patch_size=1).shape == x.shape assert _patchify(x, patch_size=1).shape == x.shape
@@ -759,7 +759,7 @@ class TestAvgDown3D:
"""Tests for AvgDown3D downsampling.""" """Tests for AvgDown3D downsampling."""
def test_spatial_only(self): def test_spatial_only(self):
from mlx_video.models.wan2.vae22 import AvgDown3D from mlx_video.models.wan_2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=1, factor_s=2) down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
x = mx.random.normal((1, 2, 8, 8, 8)) x = mx.random.normal((1, 2, 8, 8, 8))
@@ -768,7 +768,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16) assert out.shape == (1, 2, 4, 4, 16)
def test_temporal_and_spatial(self): def test_temporal_and_spatial(self):
from mlx_video.models.wan2.vae22 import AvgDown3D from mlx_video.models.wan_2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=2, factor_s=2) down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
x = mx.random.normal((1, 4, 8, 8, 8)) x = mx.random.normal((1, 4, 8, 8, 8))
@@ -777,7 +777,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16) assert out.shape == (1, 2, 4, 4, 16)
def test_single_frame(self): def test_single_frame(self):
from mlx_video.models.wan2.vae22 import AvgDown3D from mlx_video.models.wan_2.vae22 import AvgDown3D
down = AvgDown3D(8, 8, factor_t=2, factor_s=2) down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 1, 8, 8, 8)) x = mx.random.normal((1, 1, 8, 8, 8))
@@ -791,7 +791,7 @@ class TestDownResidualBlock:
"""Tests for Down_ResidualBlock.""" """Tests for Down_ResidualBlock."""
def test_no_downsample(self): def test_no_downsample(self):
from mlx_video.models.wan2.vae22 import Down_ResidualBlock from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock( block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False 8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
@@ -802,7 +802,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 8, 8, 8) assert out.shape == (1, 2, 8, 8, 8)
def test_spatial_downsample(self): def test_spatial_downsample(self):
from mlx_video.models.wan2.vae22 import Down_ResidualBlock from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock( block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True 8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
@@ -813,7 +813,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 4, 4, 16) assert out.shape == (1, 2, 4, 4, 16)
def test_spatial_temporal_downsample(self): def test_spatial_temporal_downsample(self):
from mlx_video.models.wan2.vae22 import Down_ResidualBlock from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock( block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True 8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
@@ -828,7 +828,7 @@ class TestEncoder3d:
"""Tests for Encoder3d.""" """Tests for Encoder3d."""
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae22 import Encoder3d from mlx_video.models.wan_2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8) enc = Encoder3d(dim=16, z_dim=8)
x = mx.random.normal((1, 1, 16, 16, 12)) x = mx.random.normal((1, 1, 16, 16, 12))
@@ -839,7 +839,7 @@ class TestEncoder3d:
assert out.shape == (1, 1, 2, 2, 8) assert out.shape == (1, 1, 2, 2, 8)
def test_multi_frame(self): def test_multi_frame(self):
from mlx_video.models.wan2.vae22 import Encoder3d from mlx_video.models.wan_2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False)) enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12)) x = mx.random.normal((1, 5, 16, 16, 12))
@@ -854,7 +854,7 @@ class TestWan22VAEEncoder:
"""Tests for Wan22VAEEncoder wrapper.""" """Tests for Wan22VAEEncoder wrapper."""
def test_output_shape(self): def test_output_shape(self):
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16) enc = Wan22VAEEncoder(z_dim=48, dim=16)
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2) # Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
@@ -865,7 +865,7 @@ class TestWan22VAEEncoder:
assert z.shape == (1, 1, 2, 2, 48) assert z.shape == (1, 1, 2, 2, 48)
def test_full_dim(self): def test_full_dim(self):
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=160) enc = Wan22VAEEncoder(z_dim=48, dim=160)
img = mx.random.normal((1, 1, 64, 64, 3)) img = mx.random.normal((1, 1, 64, 64, 3))
@@ -880,7 +880,7 @@ class TestNormalizeLatents:
"""Tests for normalize/denormalize latent roundtrip.""" """Tests for normalize/denormalize latent roundtrip."""
def test_roundtrip(self): def test_roundtrip(self):
from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents from mlx_video.models.wan_2.vae22 import denormalize_latents, normalize_latents
z = mx.random.normal((1, 2, 4, 4, 48)) z = mx.random.normal((1, 2, 4, 4, 48))
z_norm = normalize_latents(z) z_norm = normalize_latents(z)
@@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self): def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2.""" """Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan2.vae22 import Encoder3d from mlx_video.models.wan_2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True)) enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12)) x = mx.random.normal((1, 5, 16, 16, 12))
@@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self): def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample.""" """Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder from mlx_video.models.wan_2.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16) enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples down_blocks = enc.encoder.downsamples
@@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self): def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern.""" """Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16) enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3)) img = mx.random.normal((1, 1, 32, 32, 3))
@@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self): def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs.""" """(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan2.vae22 import Encoder3d from mlx_video.models.wan_2.vae22 import Encoder3d
enc_correct = Encoder3d( enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True) dim=16, z_dim=8, temperal_downsample=(False, True, True)
@@ -963,7 +963,7 @@ class TestVAE21RoundTrip:
def test_encode_decode_shape_and_values(self): def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite.""" """Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan2.vae import Decoder3d, Encoder3d from mlx_video.models.wan_2.vae import Decoder3d, Encoder3d
z_dim = 4 z_dim = 4
dim = 8 dim = 8
@@ -995,7 +995,7 @@ class TestVAE22RoundTrip:
def test_encode_decode_shape_and_values(self): def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range.""" """Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan2.vae22 import ( from mlx_video.models.wan_2.vae22 import (
Wan22VAEDecoder, Wan22VAEDecoder,
Wan22VAEEncoder, Wan22VAEEncoder,
denormalize_latents, denormalize_latents,

View File

@@ -3,7 +3,7 @@
def _make_tiny_config(): def _make_tiny_config():
"""Create a tiny WanModelConfig for testing.""" """Create a tiny WanModelConfig for testing."""
from mlx_video.models.wan2.config import WanModelConfig from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig() config = WanModelConfig()
# Override to tiny values # Override to tiny values