diff --git a/README.md b/README.md index dbcf7b9..5e8b1dc 100644 --- a/README.md +++ b/README.md @@ -93,18 +93,18 @@ Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.c ### Step 0: Download and Convert Weights -See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan/README.md) for details. +See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan_2/README.md) for details. ### Step 1: Generate Video ```bash # Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0) -python -m mlx_video.wan2.generate \ +python -m mlx_video.wan_2.generate \ --model-dir wan21_mlx \ --prompt "A cat playing piano in a cozy room" # Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0) -python -m mlx_video.wan2.generate \ +python -m mlx_video.wan_2.generate \ --model-dir wan22_mlx \ --prompt "A cat playing piano in a cozy room" ``` @@ -112,7 +112,7 @@ python -m mlx_video.wan2.generate \ With custom settings: ```bash -python -m mlx_video.wan2.generate \ +python -m mlx_video.wan_2.generate \ --model-dir wan21_mlx \ --prompt "Ocean waves at sunset, cinematic, 4K" \ --negative-prompt "blurry, low quality" \ @@ -131,7 +131,7 @@ The pipeline auto-detects the model version from `config.json` and selects the r ### Image-to-Video (I2V-14B) ```bash -python -m mlx_video.wan2.generate \ +python -m mlx_video.wan_2.generate \ --model-dir wan22_i2v_mlx \ --prompt "The camera slowly zooms in as the subject begins to move" \ --image start.png \ @@ -146,7 +146,7 @@ LoRAs can be used with the `--lora-high` and `--lora-low` command line switches. For example, using the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA for 4-step generation: ```bash -python -m mlx_vide.wan2.generate \ +python -m mlx_video.wan_2.generate \ --model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \ --width 480 \ --height 704 \ diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index 7c50343..a04ec64 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -22,7 +22,7 @@ from mlx_video.models.ltx_2.utils import ( load_safetensors, save_weights, ) -from mlx_video.models.wan2 import WanModel, WanModelConfig +from mlx_video.models.wan_2 import WanModel, WanModelConfig __all__ = [ # Models diff --git a/mlx_video/models/__init__.py b/mlx_video/models/__init__.py index b54c40d..c730f1d 100644 --- a/mlx_video/models/__init__.py +++ b/mlx_video/models/__init__.py @@ -1,2 +1,2 @@ from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig -from mlx_video.models.wan2 import WanModel, WanModelConfig +from mlx_video.models.wan_2 import WanModel, WanModelConfig diff --git a/mlx_video/models/wan2/__init__.py b/mlx_video/models/wan2/__init__.py deleted file mode 100644 index 90390e9..0000000 --- a/mlx_video/models/wan2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from mlx_video.models.wan2.config import WanModelConfig -from mlx_video.models.wan2.wan2 import WanModel diff --git a/mlx_video/models/wan2/README.md b/mlx_video/models/wan_2/README.md similarity index 100% rename from mlx_video/models/wan2/README.md rename to mlx_video/models/wan_2/README.md diff --git a/mlx_video/models/wan_2/__init__.py b/mlx_video/models/wan_2/__init__.py new file mode 100644 index 0000000..6b96519 --- /dev/null +++ b/mlx_video/models/wan_2/__init__.py @@ -0,0 +1,2 @@ +from mlx_video.models.wan_2.config import WanModelConfig +from mlx_video.models.wan_2.wan_2 import WanModel diff --git a/mlx_video/models/wan2/attention.py b/mlx_video/models/wan_2/attention.py similarity index 100% rename from mlx_video/models/wan2/attention.py rename to mlx_video/models/wan_2/attention.py diff --git a/mlx_video/models/wan2/config.py b/mlx_video/models/wan_2/config.py similarity index 100% rename from mlx_video/models/wan2/config.py rename to mlx_video/models/wan_2/config.py diff --git a/mlx_video/models/wan2/convert.py b/mlx_video/models/wan_2/convert.py similarity index 98% rename from mlx_video/models/wan2/convert.py rename to mlx_video/models/wan_2/convert.py index ba2b79a..1bd61cb 100644 --- a/mlx_video/models/wan2/convert.py +++ b/mlx_video/models/wan_2/convert.py @@ -247,7 +247,7 @@ def _load_lora_configs( Shared between weight-merging and runtime-wrapping paths. """ - from mlx_video.models.wan2.generate import Colors + from mlx_video.models.wan_2.generate import Colors from mlx_video.lora import LoRAConfig, load_multiple_loras print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") @@ -282,7 +282,7 @@ def load_and_apply_loras( For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). """ - from mlx_video.models.wan2.generate import Colors + from mlx_video.models.wan_2.generate import Colors from mlx_video.lora import apply_loras_to_weights if not lora_configs: @@ -411,7 +411,7 @@ def convert_wan_checkpoint( print(" Warning: No transformer weights found!") # Save config — detect model size from source config.json or transformer weights - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig def _detect_config(): """Detect config from source config.json or transformer weight shapes.""" @@ -522,7 +522,7 @@ def convert_wan_checkpoint( print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...") weights = load_torch_weights(str(vae_path)) if is_wan22_vae: - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights include_encoder = config.model_type in ("ti2v", "i2v") weights = sanitize_wan22_vae_weights( @@ -594,7 +594,7 @@ def _quantize_saved_model( import mlx.nn as nn - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel if source_dir is None: source_dir = output_dir @@ -704,7 +704,7 @@ def quantize_mlx_model( ).exists() # Build model config - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config_dict = { k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__ diff --git a/mlx_video/models/wan2/generate.py b/mlx_video/models/wan_2/generate.py similarity index 99% rename from mlx_video/models/wan2/generate.py rename to mlx_video/models/wan_2/generate.py index f173d9a..f455911 100644 --- a/mlx_video/models/wan2/generate.py +++ b/mlx_video/models/wan_2/generate.py @@ -11,15 +11,15 @@ import mlx.core as mx import numpy as np from tqdm import tqdm -from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image -from mlx_video.models.wan2.utils import ( +from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image +from mlx_video.models.wan_2.utils import ( encode_text, load_t5_encoder, load_vae_decoder, load_vae_encoder, load_wan_model, ) -from mlx_video.models.wan2.postprocess import save_video +from mlx_video.models.wan_2.postprocess import save_video class Colors: @@ -121,8 +121,8 @@ def generate_video( """ import json - from mlx_video.models.wan2.config import WanModelConfig - from mlx_video.models.wan2.scheduler import ( + from mlx_video.models.wan_2.config import WanModelConfig + from mlx_video.models.wan_2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -767,7 +767,7 @@ def generate_video( ) if is_wan22_vae: - from mlx_video.models.wan2.vae22 import denormalize_latents + from mlx_video.models.wan_2.vae22 import denormalize_latents # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) z = latents.transpose(1, 2, 3, 0)[None] diff --git a/mlx_video/models/wan2/i2v_utils.py b/mlx_video/models/wan_2/i2v_utils.py similarity index 100% rename from mlx_video/models/wan2/i2v_utils.py rename to mlx_video/models/wan_2/i2v_utils.py diff --git a/mlx_video/models/wan2/postprocess.py b/mlx_video/models/wan_2/postprocess.py similarity index 100% rename from mlx_video/models/wan2/postprocess.py rename to mlx_video/models/wan_2/postprocess.py diff --git a/mlx_video/models/wan2/rope.py b/mlx_video/models/wan_2/rope.py similarity index 100% rename from mlx_video/models/wan2/rope.py rename to mlx_video/models/wan_2/rope.py diff --git a/mlx_video/models/wan2/scheduler.py b/mlx_video/models/wan_2/scheduler.py similarity index 100% rename from mlx_video/models/wan2/scheduler.py rename to mlx_video/models/wan_2/scheduler.py diff --git a/mlx_video/models/wan2/text_encoder.py b/mlx_video/models/wan_2/text_encoder.py similarity index 100% rename from mlx_video/models/wan2/text_encoder.py rename to mlx_video/models/wan_2/text_encoder.py diff --git a/mlx_video/models/wan2/tiling.py b/mlx_video/models/wan_2/tiling.py similarity index 100% rename from mlx_video/models/wan2/tiling.py rename to mlx_video/models/wan_2/tiling.py diff --git a/mlx_video/models/wan2/transformer.py b/mlx_video/models/wan_2/transformer.py similarity index 100% rename from mlx_video/models/wan2/transformer.py rename to mlx_video/models/wan_2/transformer.py diff --git a/mlx_video/models/wan2/utils.py b/mlx_video/models/wan_2/utils.py similarity index 90% rename from mlx_video/models/wan2/utils.py rename to mlx_video/models/wan_2/utils.py index 45964fe..262e41d 100644 --- a/mlx_video/models/wan2/utils.py +++ b/mlx_video/models/wan_2/utils.py @@ -21,12 +21,12 @@ def load_wan_model( If provided, creates QuantizedLinear stubs before loading. loras: Optional list of (lora_path, strength) tuples to apply. """ - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel model = WanModel(config) if quantization: - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate nn.quantize( model, @@ -42,7 +42,7 @@ def load_wan_model( if quantization: # Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear. # Non-LoRA layers stay 4-bit. Zero per-step overhead. - from mlx_video.models.wan2.convert import _load_lora_configs + from mlx_video.models.wan_2.convert import _load_lora_configs from mlx_video.lora import apply_loras_to_model model.load_weights(list(weights.items()), strict=False) @@ -53,7 +53,7 @@ def load_wan_model( return model else: # Weight merging: fold LoRA into bf16 weights before loading - from mlx_video.models.wan2.convert import load_and_apply_loras + from mlx_video.models.wan_2.convert import load_and_apply_loras weights = load_and_apply_loras(dict(weights), loras) @@ -69,7 +69,7 @@ def load_t5_encoder(model_path: Path, config): only runs once per generation, so performance impact is negligible. This matches the official which computes softmax in float32 explicitly. """ - from mlx_video.models.wan2.text_encoder import T5Encoder + from mlx_video.models.wan_2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=config.t5_vocab_size, @@ -97,11 +97,11 @@ def load_vae_decoder(model_path: Path, config=None): is_wan22 = config is not None and config.vae_z_dim == 48 if is_wan22: - from mlx_video.models.wan2.vae22 import Wan22VAEDecoder + from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder vae = Wan22VAEDecoder(z_dim=48) else: - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.vae import WanVAE vae = WanVAE(z_dim=16) @@ -120,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None): For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True. """ if config is not None and config.vae_z_dim == 16: - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.vae import WanVAE vae = WanVAE(z_dim=16, encoder=True) else: - from mlx_video.models.wan2.vae22 import Wan22VAEEncoder + from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48) diff --git a/mlx_video/models/wan2/vae.py b/mlx_video/models/wan_2/vae.py similarity index 99% rename from mlx_video/models/wan2/vae.py rename to mlx_video/models/wan_2/vae.py index b713ac7..379ec24 100644 --- a/mlx_video/models/wan2/vae.py +++ b/mlx_video/models/wan_2/vae.py @@ -589,7 +589,7 @@ class WanVAE(nn.Module): Returns: Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] """ - from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling + from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default() diff --git a/mlx_video/models/wan2/vae22.py b/mlx_video/models/wan_2/vae22.py similarity index 99% rename from mlx_video/models/wan2/vae22.py rename to mlx_video/models/wan_2/vae22.py index 0b99aef..7063746 100644 --- a/mlx_video/models/wan2/vae22.py +++ b/mlx_video/models/wan_2/vae22.py @@ -966,7 +966,7 @@ class Wan22VAEDecoder(nn.Module): Returns: video: [B, T', H', W', 3] decoded RGB in [-1, 1] """ - from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling + from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default() diff --git a/mlx_video/models/wan2/wan2.py b/mlx_video/models/wan_2/wan_2.py similarity index 100% rename from mlx_video/models/wan2/wan2.py rename to mlx_video/models/wan_2/wan_2.py diff --git a/pyproject.toml b/pyproject.toml index bf535c0..6a4d3ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ Issues = "https://github.com/Blaizzy/mlx-video/issues" [project.scripts] "mlx_video.ltx_2.generate" = "mlx_video.models.ltx_2.generate:main" -"mlx_video.wan2.generate" = "mlx_video.models.wan2.generate:main" +"mlx_video.wan_2.generate" = "mlx_video.models.wan_2.generate:main" [tool.setuptools.packages.find] include = ["mlx_video*"] diff --git a/tests/test_wan_attention.py b/tests/test_wan_attention.py index e94851e..0b48bf9 100644 --- a/tests/test_wan_attention.py +++ b/tests/test_wan_attention.py @@ -12,14 +12,14 @@ class TestRoPE: """Tests for 3-way factorized RoPE.""" def test_rope_params_shape(self): - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params freqs = rope_params(1024, 64) mx.eval(freqs) assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2] def test_rope_params_different_dims(self): - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params for dim in [32, 64, 128]: freqs = rope_params(512, dim) @@ -27,7 +27,7 @@ class TestRoPE: assert freqs.shape == (512, dim // 2, 2) def test_rope_params_cos_sin_range(self): - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params freqs = rope_params(256, 64) mx.eval(freqs) @@ -38,7 +38,7 @@ class TestRoPE: def test_rope_params_position_zero(self): """At position 0, cos should be 1 and sin should be 0.""" - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params freqs = rope_params(10, 64) mx.eval(freqs) @@ -46,7 +46,7 @@ class TestRoPE: np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6) def test_rope_apply_output_shape(self): - from mlx_video.models.wan2.rope import rope_apply, rope_params + from mlx_video.models.wan_2.rope import rope_apply, rope_params B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim x = mx.random.normal((B, L, N, D)) @@ -58,7 +58,7 @@ class TestRoPE: def test_rope_apply_preserves_norm(self): """RoPE rotation should preserve vector norms.""" - from mlx_video.models.wan2.rope import rope_apply, rope_params + from mlx_video.models.wan_2.rope import rope_apply, rope_params B, N, D = 1, 2, 16 F, H, W = 2, 3, 4 @@ -79,7 +79,7 @@ class TestRoPE: def test_rope_apply_with_padding(self): """When seq_len < L, extra tokens should be preserved unchanged.""" - from mlx_video.models.wan2.rope import rope_apply, rope_params + from mlx_video.models.wan_2.rope import rope_apply, rope_params B, N, D = 1, 2, 16 F, H, W = 2, 2, 2 @@ -100,7 +100,7 @@ class TestRoPE: def test_rope_apply_batch(self): """Test with batch_size > 1 and different grid sizes.""" - from mlx_video.models.wan2.rope import rope_apply, rope_params + from mlx_video.models.wan_2.rope import rope_apply, rope_params B, N, D = 2, 2, 16 grids = [(2, 3, 4), (2, 3, 4)] @@ -132,7 +132,7 @@ class TestRoPE: class TestWanRMSNorm: def test_output_shape(self): - from mlx_video.models.wan2.attention import WanRMSNorm + from mlx_video.models.wan_2.attention import WanRMSNorm norm = WanRMSNorm(64) x = mx.random.normal((2, 10, 64)) @@ -142,7 +142,7 @@ class TestWanRMSNorm: def test_zero_mean_variance(self): """RMS norm should make RMS ≈ 1 before scaling.""" - from mlx_video.models.wan2.attention import WanRMSNorm + from mlx_video.models.wan_2.attention import WanRMSNorm norm = WanRMSNorm(64) x = mx.random.normal((1, 5, 64)) * 10.0 @@ -156,7 +156,7 @@ class TestWanRMSNorm: def test_dtype_preservation(self): """RMSNorm weight is float32, so output is promoted to float32.""" - from mlx_video.models.wan2.attention import WanRMSNorm + from mlx_video.models.wan_2.attention import WanRMSNorm norm = WanRMSNorm(32) x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16) @@ -168,7 +168,7 @@ class TestWanRMSNorm: class TestWanLayerNorm: def test_output_shape(self): - from mlx_video.models.wan2.attention import WanLayerNorm + from mlx_video.models.wan_2.attention import WanLayerNorm norm = WanLayerNorm(64) x = mx.random.normal((2, 10, 64)) @@ -177,7 +177,7 @@ class TestWanLayerNorm: assert out.shape == (2, 10, 64) def test_without_affine(self): - from mlx_video.models.wan2.attention import WanLayerNorm + from mlx_video.models.wan_2.attention import WanLayerNorm norm = WanLayerNorm(64, elementwise_affine=False) x = mx.random.normal((1, 4, 64)) @@ -190,7 +190,7 @@ class TestWanLayerNorm: np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1) def test_with_affine(self): - from mlx_video.models.wan2.attention import WanLayerNorm + from mlx_video.models.wan_2.attention import WanLayerNorm norm = WanLayerNorm(32, elementwise_affine=True) assert hasattr(norm, "weight") @@ -208,8 +208,8 @@ class TestWanSelfAttention: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan2.attention import WanSelfAttention - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.attention import WanSelfAttention + from mlx_video.models.wan_2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads) B, L = 1, 24 @@ -221,14 +221,14 @@ class TestWanSelfAttention: assert out.shape == (B, L, self.dim) def test_with_qk_norm(self): - from mlx_video.models.wan2.attention import WanSelfAttention + from mlx_video.models.wan_2.attention import WanSelfAttention attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True) assert attn.norm_q is not None assert attn.norm_k is not None def test_without_qk_norm(self): - from mlx_video.models.wan2.attention import WanSelfAttention + from mlx_video.models.wan_2.attention import WanSelfAttention attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) assert attn.norm_q is None @@ -236,8 +236,8 @@ class TestWanSelfAttention: def test_masking(self): """Test that masking works: shorter seq_lens should mask later tokens.""" - from mlx_video.models.wan2.attention import WanSelfAttention - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.attention import WanSelfAttention + from mlx_video.models.wan_2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) B, L = 1, 24 @@ -262,7 +262,7 @@ class TestWanCrossAttention: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan2.attention import WanCrossAttention + from mlx_video.models.wan_2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) B, L_q, L_kv = 1, 24, 16 @@ -273,7 +273,7 @@ class TestWanCrossAttention: assert out.shape == (B, L_q, self.dim) def test_with_context_mask(self): - from mlx_video.models.wan2.attention import WanCrossAttention + from mlx_video.models.wan_2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) B, L_q, L_kv = 1, 12, 16 @@ -311,8 +311,8 @@ class TestBFloat16Autocast: def test_self_attn_casts_to_weight_dtype(self): """Self-attention should cast input to weight dtype for QKV projections.""" - from mlx_video.models.wan2.attention import WanSelfAttention - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.attention import WanSelfAttention + from mlx_video.models.wan_2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -326,7 +326,7 @@ class TestBFloat16Autocast: def test_cross_attn_casts_to_weight_dtype(self): """Cross-attention should cast input to weight dtype.""" - from mlx_video.models.wan2.attention import WanCrossAttention + from mlx_video.models.wan_2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -340,7 +340,7 @@ class TestBFloat16Autocast: def test_cross_attn_kv_cache_uses_weight_dtype(self): """prepare_kv should cast context to weight dtype.""" - from mlx_video.models.wan2.attention import WanCrossAttention + from mlx_video.models.wan_2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -353,7 +353,7 @@ class TestBFloat16Autocast: def test_ffn_casts_to_weight_dtype(self): """FFN should cast input to weight dtype for linear layers.""" - from mlx_video.models.wan2.transformer import WanFFN + from mlx_video.models.wan_2.transformer import WanFFN ffn = WanFFN(self.dim, 128) ffn.update(self._to_bf16(ffn.parameters())) @@ -366,8 +366,8 @@ class TestBFloat16Autocast: def test_self_attn_rope_in_float32(self): """RoPE should be applied in float32 for precision, even with bf16 weights.""" - from mlx_video.models.wan2.attention import WanSelfAttention - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.attention import WanSelfAttention + from mlx_video.models.wan_2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -381,8 +381,8 @@ class TestBFloat16Autocast: def test_block_float32_residual_with_bf16_weights(self): """Full block: residual stream stays float32, matmuls use bf16 weights.""" - from mlx_video.models.wan2.rope import rope_params - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.rope import rope_params + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True) block.update(self._to_bf16(block.parameters())) diff --git a/tests/test_wan_config.py b/tests/test_wan_config.py index b37c722..a5f19ed 100644 --- a/tests/test_wan_config.py +++ b/tests/test_wan_config.py @@ -10,7 +10,7 @@ class TestWanModelConfig: """Tests for WanModelConfig dataclass.""" def test_default_values(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig() assert config.dim == 5120 @@ -32,13 +32,13 @@ class TestWanModelConfig: assert config.text_len == 512 def test_head_dim_property(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig() assert config.head_dim == 128 # 5120 // 40 def test_to_dict_roundtrip(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig() d = config.to_dict() @@ -48,7 +48,7 @@ class TestWanModelConfig: assert d["boundary"] == 0.875 def test_t5_config_values(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig() assert config.t5_vocab_size == 256384 @@ -69,7 +69,7 @@ class TestWan21Config: """Tests for Wan2.1 config presets.""" def test_wan21_14b_factory(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() assert config.model_version == "2.1" @@ -85,7 +85,7 @@ class TestWan21Config: assert config.boundary == 0.0 def test_wan21_1_3b_factory(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() assert config.model_version == "2.1" @@ -98,7 +98,7 @@ class TestWan21Config: assert config.sample_guide_scale == 5.0 def test_wan22_14b_factory(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan22_t2v_14b() assert config.model_version == "2.2" @@ -110,7 +110,7 @@ class TestWan21Config: assert config.boundary == 0.875 def test_wan21_config_to_dict(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() d = config.to_dict() @@ -119,7 +119,7 @@ class TestWan21Config: assert d["sample_guide_scale"] == 5.0 def test_wan21_1_3b_config_to_dict(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() d = config.to_dict() @@ -128,7 +128,7 @@ class TestWan21Config: def test_default_config_is_wan22(self): """Default WanModelConfig() should be Wan2.2 14B.""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig() assert config.model_version == "2.2" diff --git a/tests/test_wan_convert.py b/tests/test_wan_convert.py index 0e5e48d..1483f9c 100644 --- a/tests/test_wan_convert.py +++ b/tests/test_wan_convert.py @@ -11,7 +11,7 @@ import mlx.core as mx class TestSanitizeTransformerWeights: def test_patch_embedding_reshape(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), @@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights: assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2) def test_text_embedding_rename(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "text_embedding.0.weight": mx.zeros((64, 32)), @@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights: assert "text_embedding_1.bias" in out def test_time_embedding_rename(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "time_embedding.0.weight": mx.zeros((64, 32)), @@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights: assert "time_embedding_1.weight" in out def test_time_projection_rename(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "time_projection.1.weight": mx.zeros((384, 64)), @@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights: assert "time_projection.bias" in out def test_ffn_rename(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "blocks.0.ffn.0.weight": mx.zeros((128, 64)), @@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights: assert "blocks.0.ffn.fc2.bias" in out def test_freqs_skipped(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "freqs": mx.zeros((1024, 64, 2)), @@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights: assert "blocks.0.norm1.weight" in out def test_passthrough_keys(self): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "blocks.0.self_attn.q.weight": mx.zeros((64, 64)), @@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights: assert key in out def test_no_unconsumed_keys(self, caplog): - from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights + from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights weights = { "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), @@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights: "head.head.weight": mx.zeros((64, 64)), "freqs": mx.zeros((1024, 64, 2)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"): sanitize_wan_transformer_weights(weights) assert "Unconsumed" not in caplog.text class TestSanitizeT5Weights: def test_gate_rename(self): - from mlx_video.models.wan2.convert import sanitize_wan_t5_weights + from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights weights = { "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), @@ -139,7 +139,7 @@ class TestSanitizeT5Weights: assert "blocks.0.ffn.fc2.weight" in out def test_passthrough(self): - from mlx_video.models.wan2.convert import sanitize_wan_t5_weights + from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights weights = { "token_embedding.weight": mx.zeros((100, 64)), @@ -151,7 +151,7 @@ class TestSanitizeT5Weights: assert key in out def test_no_unconsumed_keys(self, caplog): - from mlx_video.models.wan2.convert import sanitize_wan_t5_weights + from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights weights = { "token_embedding.weight": mx.zeros((100, 64)), @@ -160,14 +160,14 @@ class TestSanitizeT5Weights: "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)), "norm.weight": mx.zeros((64,)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"): sanitize_wan_t5_weights(weights) assert "Unconsumed" not in caplog.text class TestSanitizeVAEWeights: def test_conv3d_transpose(self): - from mlx_video.models.wan2.convert import sanitize_wan_vae_weights + from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights weights = { "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W] @@ -176,7 +176,7 @@ class TestSanitizeVAEWeights: assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I] def test_conv2d_transpose(self): - from mlx_video.models.wan2.convert import sanitize_wan_vae_weights + from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights weights = { "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W] @@ -185,7 +185,7 @@ class TestSanitizeVAEWeights: assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I] def test_non_conv_passthrough(self): - from mlx_video.models.wan2.convert import sanitize_wan_vae_weights + from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights weights = { "decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose @@ -196,7 +196,7 @@ class TestSanitizeVAEWeights: assert out["decoder.bias"].shape == (16,) def test_mixed_weights(self): - from mlx_video.models.wan2.convert import sanitize_wan_vae_weights + from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights weights = { "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D @@ -211,7 +211,7 @@ class TestSanitizeVAEWeights: assert out["norm.weight"].shape == (8,) def test_no_unconsumed_keys(self, caplog): - from mlx_video.models.wan2.convert import sanitize_wan_vae_weights + from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights weights = { "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), @@ -219,7 +219,7 @@ class TestSanitizeVAEWeights: "decoder.norm.weight": mx.zeros((64,)), "decoder.bias": mx.zeros((16,)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"): sanitize_wan_vae_weights(weights) assert "Unconsumed" not in caplog.text @@ -256,7 +256,7 @@ class TestWan21Convert: def test_wan21_config_saved_correctly(self): """Verify config dict has correct fields for Wan2.1.""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() d = config.to_dict() @@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights: """Tests for sanitize_wan22_vae_weights with include_encoder.""" def test_exclude_encoder_by_default(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), @@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights: assert not any("encoder" in k or k.startswith("conv1") for k in out) def test_include_encoder(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), @@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights: assert "conv2.weight" in out def test_no_unconsumed_keys(self, caplog): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "conv1.weight": mx.zeros((8, 1, 1, 1, 8)), "conv2.weight": mx.zeros((8, 1, 1, 1, 8)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"): sanitize_wan22_vae_weights(weights, include_encoder=True) assert "Unconsumed" not in caplog.text def test_no_unconsumed_keys_exclude_encoder(self, caplog): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "conv1.weight": mx.zeros((8, 1, 1, 1, 8)), "conv2.weight": mx.zeros((8, 1, 1, 1, 8)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"): sanitize_wan22_vae_weights(weights, include_encoder=False) assert "Unconsumed" not in caplog.text diff --git a/tests/test_wan_generate.py b/tests/test_wan_generate.py index e586cce..2972d1f 100644 --- a/tests/test_wan_generate.py +++ b/tests/test_wan_generate.py @@ -14,8 +14,8 @@ class TestEndToEnd: def test_tiny_model_denoise_step(self): """Simulate one denoising step with tiny model.""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler mx.random.seed(42) config = _make_tiny_config() @@ -43,8 +43,8 @@ class TestEndToEnd: def test_tiny_model_full_loop(self): """Run a complete (tiny) diffusion loop.""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler mx.random.seed(123) config = _make_tiny_config() @@ -81,7 +81,7 @@ class TestI2VMask: """Tests for _build_i2v_mask.""" def test_mask_shapes(self): - from mlx_video.models.wan2.generate import _build_i2v_mask + from mlx_video.models.wan_2.generate import _build_i2v_mask z_shape = (48, 5, 4, 4) # C, T, H, W patch_size = (1, 2, 2) @@ -91,7 +91,7 @@ class TestI2VMask: assert mask_tokens.shape == (1, 20) def test_first_frame_zero(self): - from mlx_video.models.wan2.generate import _build_i2v_mask + from mlx_video.models.wan_2.generate import _build_i2v_mask z_shape = (48, 5, 4, 4) mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2)) @@ -111,7 +111,7 @@ class TestI2VMaskAlignment: def test_mask_with_ti2v_dimensions(self): """Mask should work with TI2V-5B typical dimensions.""" - from mlx_video.models.wan2.generate import _build_i2v_mask + from mlx_video.models.wan_2.generate import _build_i2v_mask # TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2) # 704x1280 → latent 44x80, t_latent=21 for 81 frames @@ -132,7 +132,7 @@ class TestI2VMaskAlignment: def test_mask_per_token_timestep(self): """Per-token timesteps: first-frame tokens get t=0, rest get t=sigma.""" - from mlx_video.models.wan2.generate import _build_i2v_mask + from mlx_video.models.wan_2.generate import _build_i2v_mask z_shape = (4, 3, 4, 4) patch_size = (1, 2, 2) @@ -201,7 +201,7 @@ class TestDimensionAlignment: def test_patchify_valid_after_alignment(self): """After alignment, patchify should succeed without reshape errors.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -235,7 +235,7 @@ class TestDimensionAlignment: def test_alignment_with_ti2v_config(self): """TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32.""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan22_ti2v_5b() align_h = config.patch_size[1] * config.vae_stride[1] diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py index 7c5e0cd..2b4789d 100644 --- a/tests/test_wan_i2v.py +++ b/tests/test_wan_i2v.py @@ -23,7 +23,7 @@ class TestI2VConfig: """Test I2V-14B config preset.""" def test_wan22_i2v_14b_preset(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan22_i2v_14b() assert config.model_type == "i2v" @@ -39,7 +39,7 @@ class TestI2VConfig: assert config.vae_z_dim == 16 def test_i2v_vs_t2v_differences(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig i2v = WanModelConfig.wan22_i2v_14b() t2v = WanModelConfig.wan22_t2v_14b() @@ -51,7 +51,7 @@ class TestI2VConfig: assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0 def test_i2v_serialization_roundtrip(self): - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan22_i2v_14b() d = config.to_dict() @@ -66,7 +66,7 @@ class TestModelYParameter: def test_forward_without_y(self): """Standard T2V forward pass (no y) still works.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -85,7 +85,7 @@ class TestModelYParameter: def test_forward_with_y(self): """I2V forward pass with y channel concatenation.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_i2v_config() model = WanModel(config) @@ -108,7 +108,7 @@ class TestModelYParameter: def test_y_none_is_noop(self): """Passing y=None should be identical to not passing y.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -129,7 +129,7 @@ class TestModelYParameter: def test_batched_cfg_with_y(self): """Batched CFG (B=2) with y should work.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_i2v_config() model = WanModel(config) @@ -158,7 +158,7 @@ class TestVAEEncoder: """Test Wan2.1 VAE encoder.""" def test_encoder3d_instantiation(self): - from mlx_video.models.wan2.vae import Encoder3d + from mlx_video.models.wan_2.vae import Encoder3d enc = Encoder3d( dim=32, z_dim=8 @@ -169,7 +169,7 @@ class TestVAEEncoder: def test_encoder3d_output_shape(self): """Encoder should downsample spatially by 8x and temporally by 4x.""" - from mlx_video.models.wan2.vae import Encoder3d + from mlx_video.models.wan_2.vae import Encoder3d enc = Encoder3d(dim=32, z_dim=8) # Random input: [B=1, 3, T=5, H=32, W=32] @@ -186,7 +186,7 @@ class TestVAEEncoder: def test_wan_vae_encode(self): """WanVAE with encoder=True should produce normalized latents.""" - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.vae import WanVAE vae = WanVAE(z_dim=16, encoder=True) # Input: [B=1, 3, T=5, H=32, W=32] @@ -198,7 +198,7 @@ class TestVAEEncoder: def test_wan_vae_encoder_flag(self): """WanVAE without encoder flag should not have encoder attribute.""" - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.vae import WanVAE vae_no_enc = WanVAE(z_dim=4, encoder=False) assert not hasattr(vae_no_enc, "encoder") @@ -211,7 +211,7 @@ class TestResampleDownsample: """Test downsample modes in Resample.""" def test_downsample2d(self): - from mlx_video.models.wan2.vae import Resample + from mlx_video.models.wan_2.vae import Resample r = Resample(dim=16, mode="downsample2d") x = mx.random.normal((1, 16, 2, 8, 8)) @@ -221,7 +221,7 @@ class TestResampleDownsample: assert out.shape == (1, 16, 2, 4, 4) def test_downsample3d(self): - from mlx_video.models.wan2.vae import Resample + from mlx_video.models.wan_2.vae import Resample r = Resample(dim=16, mode="downsample3d") x = mx.random.normal((1, 16, 4, 8, 8)) @@ -231,7 +231,7 @@ class TestResampleDownsample: assert out.shape == (1, 16, 2, 4, 4) def test_upsample2d_still_works(self): - from mlx_video.models.wan2.vae import Resample + from mlx_video.models.wan_2.vae import Resample r = Resample(dim=16, mode="upsample2d") x = mx.random.normal((1, 16, 2, 4, 4)) @@ -240,7 +240,7 @@ class TestResampleDownsample: assert out.shape == (1, 8, 2, 8, 8) def test_upsample3d_still_works(self): - from mlx_video.models.wan2.vae import Resample + from mlx_video.models.wan_2.vae import Resample r = Resample(dim=16, mode="upsample3d") x = mx.random.normal((1, 16, 2, 4, 4)) @@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline: def test_full_i2v_pipeline(self): """End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode.""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.vae import WanVAE mx.random.seed(0) @@ -410,8 +410,8 @@ class TestDualModelSwitching: def test_model_selection_by_timestep(self): """Verify high_noise model used for timesteps >= boundary, low_noise otherwise.""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler mx.random.seed(1) config = _make_tiny_i2v_config() @@ -485,8 +485,8 @@ class TestDualModelSwitching: def test_guide_scale_tuple_applied_per_model(self): """Verify (low_gs, high_gs) tuple applies different scales per model.""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler mx.random.seed(2) config = _make_tiny_i2v_config() @@ -545,8 +545,8 @@ class TestDualModelSwitching: def test_single_model_fallback_with_tuple_guide_scale(self): """When dual_model=False, guide_scale tuple should use first element.""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler mx.random.seed(3) config = _make_tiny_config() diff --git a/tests/test_wan_lora.py b/tests/test_wan_lora.py index 1c4b84c..9d5e57d 100644 --- a/tests/test_wan_lora.py +++ b/tests/test_wan_lora.py @@ -331,7 +331,7 @@ class TestEndToEnd: """End-to-end LoRA loading and application.""" def test_load_and_apply_loras(self): - from mlx_video.models.wan2.convert import load_and_apply_loras + from mlx_video.models.wan_2.convert import load_and_apply_loras with tempfile.TemporaryDirectory() as tmp: # Create mock LoRA safetensors diff --git a/tests/test_wan_model.py b/tests/test_wan_model.py index e415052..b386fb3 100644 --- a/tests/test_wan_model.py +++ b/tests/test_wan_model.py @@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config class TestSinusoidalEmbedding: def test_output_shape(self): - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.arange(10).astype(mx.float32) emb = sinusoidal_embedding_1d(256, pos) @@ -21,7 +21,7 @@ class TestSinusoidalEmbedding: def test_position_zero(self): """Position 0 should have cos=1 for all dims and sin=0.""" - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([0.0]) emb = sinusoidal_embedding_1d(64, pos) @@ -33,7 +33,7 @@ class TestSinusoidalEmbedding: np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5) def test_different_positions_differ(self): - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 999.0]) emb = sinusoidal_embedding_1d(128, pos) @@ -50,7 +50,7 @@ class TestSinusoidalEmbedding: class TestHead: def test_output_shape(self): - from mlx_video.models.wan2.wan2 import Head + from mlx_video.models.wan_2.wan_2 import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) B, L = 1, 24 @@ -62,7 +62,7 @@ class TestHead: assert out.shape == (B, L, expected_proj_dim) def test_modulation_shape(self): - from mlx_video.models.wan2.wan2 import Head + from mlx_video.models.wan_2.wan_2 import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) assert head.modulation.shape == (1, 2, 64) @@ -78,7 +78,7 @@ class TestWanModel: mx.random.seed(42) def test_instantiation(self): - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -86,7 +86,7 @@ class TestWanModel: assert num_params > 0 def test_patchify_shape(self): - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -99,7 +99,7 @@ class TestWanModel: assert patches.shape == (1, 1 * 2 * 2, config.dim) def test_patchify_various_sizes(self): - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -115,7 +115,7 @@ class TestWanModel: def test_unpatchify_inverse(self): """Patchify then unpatchify should reconstruct original spatial dims.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -131,7 +131,7 @@ class TestWanModel: assert out[0].shape == (config.out_dim, F, H, W) def test_forward_pass(self): - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -149,7 +149,7 @@ class TestWanModel: assert out[0].shape == (C, F, H, W) def test_forward_batch(self): - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -171,7 +171,7 @@ class TestWanModel: assert o.shape == (C, F, H, W) def test_output_is_float32(self): - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -200,7 +200,7 @@ class TestWan21Model: def _make_tiny_wan21_config(self): """Create a tiny config mimicking Wan2.1 (single model).""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() # Override to tiny values @@ -217,7 +217,7 @@ class TestWan21Model: def _make_tiny_wan21_1_3b_config(self): """Create a tiny config mimicking Wan2.1 1.3B.""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() # Override to tiny values (preserve 1.3B head structure: 12 heads) @@ -234,7 +234,7 @@ class TestWan21Model: def test_wan21_tiny_model_forward(self): """Forward pass with Wan2.1 tiny config.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = self._make_tiny_wan21_config() model = WanModel(config) @@ -252,7 +252,7 @@ class TestWan21Model: def test_wan21_1_3b_tiny_model_forward(self): """Forward pass with Wan2.1 1.3B tiny config.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = self._make_tiny_wan21_1_3b_config() model = WanModel(config) @@ -270,8 +270,8 @@ class TestWan21Model: def test_wan21_single_model_loop(self): """Full diffusion loop with single model (Wan2.1 style).""" - from mlx_video.models.wan2.wan2 import WanModel - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.wan_2 import WanModel + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler config = self._make_tiny_wan21_config() model = WanModel(config) @@ -305,7 +305,7 @@ class TestWan21Model: def test_wan21_vs_wan22_config_differences(self): """Verify key differences between Wan2.1 and Wan2.2 configs.""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig c21 = WanModelConfig.wan21_t2v_14b() c22 = WanModelConfig.wan22_t2v_14b() @@ -333,21 +333,21 @@ class TestPerTokenTimestep: """Tests for per-token sinusoidal embedding.""" def test_1d_unchanged(self): - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 500.0]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (3, 256) def test_2d_per_token(self): - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (2, 3, 256) def test_consistency(self): - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos_1d = mx.array([0.0, 100.0]) emb_1d = sinusoidal_embedding_1d(256, pos_1d) diff --git a/tests/test_wan_quantization.py b/tests/test_wan_quantization.py index 14fe3ca..fda29e3 100644 --- a/tests/test_wan_quantization.py +++ b/tests/test_wan_quantization.py @@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config class TestQuantizePredicate: def test_matches_self_attention_layers(self): - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) for suffix in ["q", "k", "v", "o"]: @@ -23,7 +23,7 @@ class TestQuantizePredicate: assert _quantize_predicate(path, mock_linear), f"Should match {path}" def test_matches_cross_attention_layers(self): - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) for suffix in ["q", "k", "v", "o"]: @@ -31,14 +31,14 @@ class TestQuantizePredicate: assert _quantize_predicate(path, mock_linear), f"Should match {path}" def test_matches_ffn_layers(self): - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear) def test_rejects_embeddings(self): - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) for path in [ @@ -49,13 +49,13 @@ class TestQuantizePredicate: assert not _quantize_predicate(path, mock_linear), f"Should reject {path}" def test_rejects_norms(self): - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_norm = nn.RMSNorm(64) assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm) def test_rejects_non_quantizable_modules(self): - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_norm = nn.RMSNorm(64) # Even if path matches, module must have to_quantized @@ -63,7 +63,7 @@ class TestQuantizePredicate: def test_all_10_patterns_covered(self): """Verify exactly 10 layer patterns are targeted.""" - from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan_2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) patterns = [ @@ -90,8 +90,8 @@ class TestQuantizePredicate: class TestQuantizeRoundTrip: def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): """Helper: create model, quantize, save to tmp_path.""" - from mlx_video.models.wan2.convert import _quantize_predicate - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.convert import _quantize_predicate + from mlx_video.models.wan_2.wan_2 import WanModel model = WanModel(config) nn.quantize( @@ -116,7 +116,7 @@ class TestQuantizeRoundTrip: config = _make_tiny_config() model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4) - from mlx_video.models.wan2.utils import load_wan_model + from mlx_video.models.wan_2.utils import load_wan_model loaded = load_wan_model( model_path, @@ -136,7 +136,7 @@ class TestQuantizeRoundTrip: config = _make_tiny_config() model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8) - from mlx_video.models.wan2.utils import load_wan_model + from mlx_video.models.wan_2.utils import load_wan_model loaded = load_wan_model( model_path, @@ -151,7 +151,7 @@ class TestQuantizeRoundTrip: config = _make_tiny_config() model_path, _ = self._quantize_and_save(config, tmp_path, bits=4) - from mlx_video.models.wan2.utils import load_wan_model + from mlx_video.models.wan_2.utils import load_wan_model loaded = load_wan_model( model_path, @@ -164,7 +164,7 @@ class TestQuantizeRoundTrip: def test_loading_without_quantization_flag(self, tmp_path): """Loading a non-quantized model should have standard Linear layers.""" - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -172,7 +172,7 @@ class TestQuantizeRoundTrip: model_path = tmp_path / "model.safetensors" mx.save_safetensors(str(model_path), weights_dict) - from mlx_video.models.wan2.utils import load_wan_model + from mlx_video.models.wan_2.utils import load_wan_model loaded = load_wan_model(model_path, config, quantization=None) @@ -187,8 +187,8 @@ class TestQuantizeRoundTrip: class TestQuantizedInference: def _make_quantized_model(self, config, bits=4): - from mlx_video.models.wan2.convert import _quantize_predicate - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.convert import _quantize_predicate + from mlx_video.models.wan_2.wan_2 import WanModel model = WanModel(config) nn.quantize( @@ -238,8 +238,8 @@ class TestQuantizedInference: def test_quantized_output_differs_from_unquantized(self): """Sanity check: quantization should change the weights.""" - from mlx_video.models.wan2.convert import _quantize_predicate - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.convert import _quantize_predicate + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() mx.random.seed(42) @@ -271,8 +271,8 @@ class TestQuantizedInference: class TestQuantizationConfig: def test_config_metadata_written(self, tmp_path): """Verify _quantize_saved_model writes quantization metadata to config.json.""" - from mlx_video.models.wan2.convert import _quantize_saved_model - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.convert import _quantize_saved_model + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -295,8 +295,8 @@ class TestQuantizationConfig: assert cfg["quantization"]["group_size"] == 64 def test_config_metadata_8bit(self, tmp_path): - from mlx_video.models.wan2.convert import _quantize_saved_model - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.convert import _quantize_saved_model + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -316,8 +316,8 @@ class TestQuantizationConfig: def test_dual_model_quantization(self, tmp_path): """Verify dual-model quantization writes both model files.""" - from mlx_video.models.wan2.convert import _quantize_saved_model - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.convert import _quantize_saved_model + from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() diff --git a/tests/test_wan_rope_freqs.py b/tests/test_wan_rope_freqs.py index 93324a5..0b64bdb 100644 --- a/tests/test_wan_rope_freqs.py +++ b/tests/test_wan_rope_freqs.py @@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction: def _get_model_freqs(self, dim=64, num_heads=4): """Instantiate a tiny WanModel and return its .freqs tensor.""" - from mlx_video.models.wan2.config import WanModelConfig - from mlx_video.models.wan2.wan2 import WanModel + from mlx_video.models.wan_2.config import WanModelConfig + from mlx_video.models.wan_2.wan_2 import WanModel config = WanModelConfig() config.dim = dim @@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction: def test_three_call_vs_single_call_differ(self): """Three separate rope_params calls must differ from single call.""" - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params d = 128 # head_dim for all Wan models # Reference: three separate calls @@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction: This verifies each axis gets its own independent frequency range starting from theta^0 = 1.0 (i.e., exponent 0/dim). """ - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params d = 128 freqs = mx.concatenate( @@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction: Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42). """ - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params d = 128 d_h_dim = 2 * (d // 6) # 42 @@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction: axis should be 1.0 (theta^0). A single-call approach would give height starting at ~0.04 and width at ~0.002 instead of 1.0. """ - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params d = 128 freqs = mx.concatenate( @@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction: def test_model_freqs_match_manual_construction(self): """WanModel.freqs should match manually constructed three-call freqs.""" - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4) d = head_dim # 16 @@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction: def test_model_freqs_14b_dimensions(self): """Verify freq dimensions for 14B-scale head_dim=128.""" - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params d = 128 freqs = mx.concatenate( @@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference: """Numerically compare MLX and PyTorch frequency tables.""" import torch - from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan_2.rope import rope_params d = 128 @@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs: This is the key property that was broken by the single-call bug: height/width frequencies were too low to distinguish nearby positions. """ - from mlx_video.models.wan2.rope import rope_apply, rope_params + from mlx_video.models.wan_2.rope import rope_apply, rope_params d = 128 freqs = mx.concatenate( @@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs: def test_precomputed_matches_online(self): """rope_precompute_cos_sin + rope_apply should match non-precomputed path.""" - from mlx_video.models.wan2.rope import ( + from mlx_video.models.wan_2.rope import ( rope_apply, rope_params, rope_precompute_cos_sin, diff --git a/tests/test_wan_scheduler.py b/tests/test_wan_scheduler.py index df5405c..3789a8d 100644 --- a/tests/test_wan_scheduler.py +++ b/tests/test_wan_scheduler.py @@ -13,7 +13,7 @@ import pytest class TestFlowMatchEulerScheduler: def test_initialization(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() assert sched.num_train_timesteps == 1000 @@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler: assert sched.sigmas is None def test_set_timesteps(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(40, shift=12.0) @@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler: assert sched.sigmas.shape == (41,) # 40 steps + terminal def test_timesteps_decreasing(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(40, shift=12.0) @@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler: assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..." def test_sigmas_decreasing(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(20, shift=1.0) @@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler: assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing" def test_terminal_sigma_is_zero(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(20, shift=5.0) @@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler: def test_shift_effect(self): """Larger shift should push sigmas toward higher values.""" - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched1 = FlowMatchEulerScheduler() sched2 = FlowMatchEulerScheduler() @@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler: assert mean2 > mean1, "Higher shift should push sigmas higher" def test_step_euler(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(10, shift=1.0) @@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler: ) def test_step_index_increments(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) @@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler: assert sched._step_index == 2 def test_reset(self): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) @@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler: @pytest.mark.parametrize("steps", [10, 20, 40, 50]) def test_various_step_counts(self, steps): - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(steps, shift=12.0) @@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler: def test_full_denoise_loop(self): """Run a complete denoise loop with zero velocity -> sample unchanged.""" - from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) @@ -153,26 +153,26 @@ class TestComputeSigmas: """Tests for the shared _compute_sigmas helper.""" def test_length(self): - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas sigmas = _compute_sigmas(20, shift=5.0) assert len(sigmas) == 21 # num_steps + terminal def test_terminal_zero(self): - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas sigmas = _compute_sigmas(10, shift=1.0) assert sigmas[-1] == 0.0 def test_starts_near_one(self): - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas sigmas = _compute_sigmas(20, shift=5.0) # Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0) np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3) def test_decreasing(self): - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas sigmas = _compute_sigmas(20, shift=5.0) assert np.all(np.diff(sigmas) <= 0) @@ -185,7 +185,7 @@ class TestComputeSigmas: sigma_max/sigma_min come from the *unshifted* training schedule, and the shift is applied only once (single-shift). """ - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas steps, shift, N = 50, 5.0, 1000 sigmas = _compute_sigmas(steps, shift, N) @@ -200,7 +200,7 @@ class TestComputeSigmas: np.testing.assert_allclose(sigmas, official, atol=1e-6) def test_shift_one_is_near_linear(self): - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas sigmas = _compute_sigmas(10, shift=1.0) # With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule) @@ -210,7 +210,7 @@ class TestComputeSigmas: def test_all_schedulers_same_sigmas(self): """All three schedulers should produce identical sigma schedules.""" - from mlx_video.models.wan2.scheduler import ( + from mlx_video.models.wan_2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -229,7 +229,7 @@ class TestComputeSigmas: np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6) def test_all_schedulers_same_timesteps(self): - from mlx_video.models.wan2.scheduler import ( + from mlx_video.models.wan_2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -255,14 +255,14 @@ class TestComputeSigmas: class TestFlowDPMPP2MScheduler: def test_initialization(self): - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() assert sched.num_train_timesteps == 1000 assert sched.lower_order_final is True def test_set_timesteps(self): - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(20, shift=5.0) @@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler: assert sched.sigmas.shape == (21,) def test_step_index_increments(self): - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) @@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler: assert sched._step_index == 2 def test_reset(self): - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) @@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler: def test_full_loop_finite(self): """Full loop with constant velocity should produce finite output.""" - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=1.0) @@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler: def test_first_step_is_first_order(self): """First step should use 1st-order (no prev_x0 available).""" - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=5.0) @@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler: def test_second_step_uses_correction(self): """After first step, DPM++ should have stored prev_x0 for correction.""" - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=5.0) @@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler: def test_denoise_to_target(self): """Perfect oracle should denoise to target with any solver.""" - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(20, shift=5.0) @@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler: @pytest.mark.parametrize("steps", [5, 10, 20, 50]) def test_various_step_counts(self, steps): - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(steps, shift=5.0) @@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler: def test_terminal_sigma_produces_x0(self): """When sigma_next=0 the scheduler should return x0 directly.""" - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) @@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler: class TestFlowUniPCScheduler: def test_initialization(self): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() assert sched.num_train_timesteps == 1000 @@ -402,7 +402,7 @@ class TestFlowUniPCScheduler: assert sched.lower_order_final is True def test_set_timesteps(self): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(30, shift=12.0) @@ -411,7 +411,7 @@ class TestFlowUniPCScheduler: assert sched.sigmas.shape == (31,) def test_step_index_increments(self): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(5, shift=1.0) @@ -422,7 +422,7 @@ class TestFlowUniPCScheduler: assert sched._step_index == 1 def test_reset(self): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(5, shift=1.0) @@ -435,7 +435,7 @@ class TestFlowUniPCScheduler: assert all(m is None for m in sched._model_outputs) def test_full_loop_finite(self): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(10, shift=1.0) @@ -448,7 +448,7 @@ class TestFlowUniPCScheduler: def test_corrector_not_applied_first_step(self): """First step should skip the corrector (no history).""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(use_corrector=True) sched.set_timesteps(10, shift=5.0) @@ -462,7 +462,7 @@ class TestFlowUniPCScheduler: def test_corrector_applied_after_first_step(self): """Steps after the first should use the corrector when enabled.""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(use_corrector=True) sched.set_timesteps(10, shift=5.0) @@ -475,7 +475,7 @@ class TestFlowUniPCScheduler: assert sched._lower_order_nums >= 2 def test_denoise_to_target(self): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(20, shift=5.0) @@ -490,7 +490,7 @@ class TestFlowUniPCScheduler: @pytest.mark.parametrize("steps", [5, 10, 20, 50]) def test_various_step_counts(self, steps): - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(steps, shift=5.0) @@ -500,7 +500,7 @@ class TestFlowUniPCScheduler: def test_disable_corrector(self): """Disabling corrector on step 0 should still work without error.""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0]) sched.set_timesteps(5, shift=1.0) @@ -513,7 +513,7 @@ class TestFlowUniPCScheduler: def test_solver_order_3(self): """Order 3 should work without error.""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(solver_order=3, use_corrector=True) sched.set_timesteps(10, shift=5.0) @@ -531,7 +531,7 @@ class TestFlowUniPCScheduler: # For 50-step schedule with shift=5.0, order 2 corrector at step 5: # rhos_c[0] (history) should be ~0.07, NOT 0.5 # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5 - from mlx_video.models.wan2.scheduler import _compute_sigmas + from mlx_video.models.wan_2.scheduler import _compute_sigmas sigmas = _compute_sigmas(50, shift=5.0) @@ -597,7 +597,7 @@ class TestSchedulerCoherence: @staticmethod def _make_schedulers(steps=10, shift=5.0): - from mlx_video.models.wan2.scheduler import ( + from mlx_video.models.wan_2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -780,7 +780,7 @@ class TestSchedulerCoherence: def test_lambda_boundary_values(self): """_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0.""" - from mlx_video.models.wan2.scheduler import ( + from mlx_video.models.wan_2.scheduler import ( FlowDPMPP2MScheduler, FlowUniPCScheduler, ) @@ -800,7 +800,7 @@ class TestSchedulerCoherence: def test_lambda_monotonically_decreasing(self): """_lambda(sigma) should decrease as sigma increases (more noise → lower SNR).""" - from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas] @@ -902,7 +902,7 @@ class TestSchedulerCoherence: shape = (1, 2, 1, 2, 2) noise = mx.random.normal(shape) - from mlx_video.models.wan2.scheduler import ( + from mlx_video.models.wan_2.scheduler import ( FlowDPMPP2MScheduler, FlowUniPCScheduler, ) @@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault: def test_corrector_enabled_by_default(self): """Default construction should have corrector enabled.""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() assert sched._use_corrector is True def test_corrector_affects_output(self): """Corrector should produce different results than no corrector after step 1.""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler mx.random.seed(42) shape = (1, 4, 1, 4, 4) @@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault: def test_corrector_does_not_affect_first_step(self): """Step 0 should be identical regardless of corrector setting.""" - from mlx_video.models.wan2.scheduler import FlowUniPCScheduler + from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler mx.random.seed(42) shape = (1, 4, 1, 4, 4) diff --git a/tests/test_wan_t5.py b/tests/test_wan_t5.py index df103f7..0c606d8 100644 --- a/tests/test_wan_t5.py +++ b/tests/test_wan_t5.py @@ -11,7 +11,7 @@ import numpy as np class TestT5LayerNorm: def test_output_shape(self): - from mlx_video.models.wan2.text_encoder import T5LayerNorm + from mlx_video.models.wan_2.text_encoder import T5LayerNorm norm = T5LayerNorm(64) x = mx.random.normal((2, 10, 64)) @@ -21,7 +21,7 @@ class TestT5LayerNorm: def test_rms_normalization(self): """After T5LayerNorm with weight=1, RMS should be ~1.""" - from mlx_video.models.wan2.text_encoder import T5LayerNorm + from mlx_video.models.wan_2.text_encoder import T5LayerNorm norm = T5LayerNorm(128) x = mx.random.normal((1, 5, 128)) * 5.0 @@ -35,7 +35,7 @@ class TestT5LayerNorm: class TestT5RelativeEmbedding: def test_output_shape(self): - from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding + from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) out = rel_emb(10, 10) @@ -43,7 +43,7 @@ class TestT5RelativeEmbedding: assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk] def test_asymmetric_lengths(self): - from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding + from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) out = rel_emb(8, 12) @@ -52,7 +52,7 @@ class TestT5RelativeEmbedding: def test_symmetry(self): """Position bias should have structure (not all zeros/random).""" - from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding + from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2) out = rel_emb(6, 6) @@ -67,7 +67,7 @@ class TestT5RelativeEmbedding: class TestT5Attention: def test_output_shape(self): - from mlx_video.models.wan2.text_encoder import T5Attention + from mlx_video.models.wan_2.text_encoder import T5Attention attn = T5Attention(dim=64, dim_attn=64, num_heads=4) x = mx.random.normal((1, 10, 64)) @@ -77,14 +77,14 @@ class TestT5Attention: def test_no_scaling(self): """T5 attention famously has no sqrt(d) scaling. Verify structure.""" - from mlx_video.models.wan2.text_encoder import T5Attention + from mlx_video.models.wan_2.text_encoder import T5Attention attn = T5Attention(dim=64, dim_attn=64, num_heads=4) # No scale attribute (unlike standard attention) assert not hasattr(attn, "scale") def test_with_position_bias(self): - from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding + from mlx_video.models.wan_2.text_encoder import T5Attention, T5RelativeEmbedding attn = T5Attention(dim=64, dim_attn=64, num_heads=4) rel_emb = T5RelativeEmbedding(32, 4) @@ -95,7 +95,7 @@ class TestT5Attention: assert out.shape == (1, 10, 64) def test_with_mask(self): - from mlx_video.models.wan2.text_encoder import T5Attention + from mlx_video.models.wan_2.text_encoder import T5Attention attn = T5Attention(dim=64, dim_attn=64, num_heads=4) x = mx.random.normal((1, 10, 64)) @@ -108,7 +108,7 @@ class TestT5Attention: class TestT5FeedForward: def test_output_shape(self): - from mlx_video.models.wan2.text_encoder import T5FeedForward + from mlx_video.models.wan_2.text_encoder import T5FeedForward ffn = T5FeedForward(64, 256) x = mx.random.normal((1, 10, 64)) @@ -118,7 +118,7 @@ class TestT5FeedForward: def test_gated_structure(self): """T5 FFN is gated: gate(x) * fc1(x).""" - from mlx_video.models.wan2.text_encoder import T5FeedForward + from mlx_video.models.wan_2.text_encoder import T5FeedForward ffn = T5FeedForward(32, 64) assert hasattr(ffn, "gate_proj") @@ -131,7 +131,7 @@ class TestT5Encoder: mx.random.seed(42) def test_output_shape(self): - from mlx_video.models.wan2.text_encoder import T5Encoder + from mlx_video.models.wan_2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -150,7 +150,7 @@ class TestT5Encoder: assert out.shape == (1, 5, 64) def test_shared_pos(self): - from mlx_video.models.wan2.text_encoder import T5Encoder + from mlx_video.models.wan_2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -167,7 +167,7 @@ class TestT5Encoder: assert block.pos_embedding is None def test_per_layer_pos(self): - from mlx_video.models.wan2.text_encoder import T5Encoder + from mlx_video.models.wan_2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -184,7 +184,7 @@ class TestT5Encoder: assert block.pos_embedding is not None def test_param_count(self): - from mlx_video.models.wan2.text_encoder import T5Encoder + from mlx_video.models.wan_2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -200,7 +200,7 @@ class TestT5Encoder: assert num_params > 0 def test_without_mask(self): - from mlx_video.models.wan2.text_encoder import T5Encoder + from mlx_video.models.wan_2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, diff --git a/tests/test_wan_tiling.py b/tests/test_wan_tiling.py index e55baac..b90eab9 100644 --- a/tests/test_wan_tiling.py +++ b/tests/test_wan_tiling.py @@ -75,7 +75,7 @@ class TestWan22TiledDecoding: def _make_small_wan22_decoder(self): """Create a small Wan2.2 decoder for testing.""" - from mlx_video.models.wan2.vae22 import Wan22VAEDecoder + from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder # Use very small dimensions for fast testing vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16) @@ -139,7 +139,7 @@ class TestWan21TiledDecoding: def _make_small_wan21_vae(self): """Create a small Wan2.1 VAE for testing.""" - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.vae import WanVAE vae = WanVAE(z_dim=16) mx.eval(vae.parameters()) @@ -192,7 +192,7 @@ class TestWan21TemporalScale: def test_wan21_decoder_temporal_output(self): """Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling).""" - from mlx_video.models.wan2.vae import Decoder3d + from mlx_video.models.wan_2.vae import Decoder3d # Small decoder for fast test dec = Decoder3d( diff --git a/tests/test_wan_transformer.py b/tests/test_wan_transformer.py index 66df8c5..0722958 100644 --- a/tests/test_wan_transformer.py +++ b/tests/test_wan_transformer.py @@ -10,7 +10,7 @@ import numpy as np class TestWanFFN: def test_output_shape(self): - from mlx_video.models.wan2.transformer import WanFFN + from mlx_video.models.wan_2.transformer import WanFFN ffn = WanFFN(64, 256) x = mx.random.normal((2, 10, 64)) @@ -20,7 +20,7 @@ class TestWanFFN: def test_gelu_activation(self): """FFN should use GELU activation (non-linearity).""" - from mlx_video.models.wan2.transformer import WanFFN + from mlx_video.models.wan_2.transformer import WanFFN ffn = WanFFN(32, 128) x = mx.ones((1, 1, 32)) * 2.0 @@ -40,8 +40,8 @@ class TestWanAttentionBlock: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan2.rope import rope_params - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.rope import rope_params + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock( self.dim, @@ -68,13 +68,13 @@ class TestWanAttentionBlock: assert out.shape == (B, L, self.dim) def test_modulation_shape(self): - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) assert block.modulation.shape == (1, 6, self.dim) def test_with_cross_attn_norm(self): - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock( self.dim, @@ -85,7 +85,7 @@ class TestWanAttentionBlock: assert block.norm3 is not None def test_without_cross_attn_norm(self): - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock( self.dim, @@ -97,8 +97,8 @@ class TestWanAttentionBlock: def test_residual_connection(self): """Output should differ from zero even with small random init.""" - from mlx_video.models.wan2.rope import rope_params - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.rope import rope_params + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) B, L = 1, 8 @@ -129,15 +129,15 @@ class TestFloat32Modulation: def test_block_modulation_in_float32(self): """Modulation param starts random but should be usable as float32.""" - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True) assert block.modulation.dtype == mx.float32 def test_block_output_float32_with_bf16_modulation_input(self): """Even if e (time embedding) arrives as bf16, modulation should cast to f32.""" - from mlx_video.models.wan2.rope import rope_params - from mlx_video.models.wan2.transformer import WanAttentionBlock + from mlx_video.models.wan_2.rope import rope_params + from mlx_video.models.wan_2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, 128, 4) B, L = 1, 8 @@ -153,7 +153,7 @@ class TestFloat32Modulation: def test_head_modulation_float32(self): """Head modulation should be float32 even with bf16 e input.""" - from mlx_video.models.wan2.wan2 import Head + from mlx_video.models.wan_2.wan_2 import Head head = Head(self.dim, 4, (1, 2, 2)) x = mx.random.normal((1, 8, self.dim)) @@ -164,7 +164,7 @@ class TestFloat32Modulation: def test_model_time_embedding_float32(self): """sinusoidal_embedding_1d output must be float32.""" - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d t = mx.array([500.0]) emb = sinusoidal_embedding_1d(256, t) @@ -173,7 +173,7 @@ class TestFloat32Modulation: def test_model_per_token_time_embedding_float32(self): """Per-token time embeddings (I2V) should also be float32.""" - from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d + from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4] emb = sinusoidal_embedding_1d(256, t) diff --git a/tests/test_wan_vae.py b/tests/test_wan_vae.py index 85c8381..255ef71 100644 --- a/tests/test_wan_vae.py +++ b/tests/test_wan_vae.py @@ -12,7 +12,7 @@ import numpy as np class TestCausalConv3d: def test_output_shape_stride1(self): - from mlx_video.models.wan2.vae import CausalConv3d + from mlx_video.models.wan_2.vae import CausalConv3d conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1) # Initialize weights @@ -28,7 +28,7 @@ class TestCausalConv3d: assert out.shape[4] == 8 # W preserved def test_output_shape_kernel1(self): - from mlx_video.models.wan2.vae import CausalConv3d + from mlx_video.models.wan_2.vae import CausalConv3d conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0) conv.weight = mx.random.normal(conv.weight.shape) * 0.02 @@ -39,7 +39,7 @@ class TestCausalConv3d: def test_causal_padding(self): """Causal conv should only use past/current frames, not future.""" - from mlx_video.models.wan2.vae import CausalConv3d + from mlx_video.models.wan_2.vae import CausalConv3d conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1) conv.weight = mx.random.normal(conv.weight.shape) * 0.1 @@ -56,7 +56,7 @@ class TestCausalConv3d: class TestResidualBlock: def test_same_dim(self): - from mlx_video.models.wan2.vae import ResidualBlock + from mlx_video.models.wan_2.vae import ResidualBlock block = ResidualBlock(8, 8) x = mx.random.normal((1, 8, 2, 4, 4)) @@ -65,7 +65,7 @@ class TestResidualBlock: assert out.shape == (1, 8, 2, 4, 4) def test_different_dim(self): - from mlx_video.models.wan2.vae import ResidualBlock + from mlx_video.models.wan_2.vae import ResidualBlock block = ResidualBlock(8, 16) x = mx.random.normal((1, 8, 2, 4, 4)) @@ -74,13 +74,13 @@ class TestResidualBlock: assert out.shape == (1, 16, 2, 4, 4) def test_shortcut_exists_when_dims_differ(self): - from mlx_video.models.wan2.vae import ResidualBlock + from mlx_video.models.wan_2.vae import ResidualBlock block = ResidualBlock(8, 16) assert block.shortcut is not None def test_no_shortcut_when_dims_same(self): - from mlx_video.models.wan2.vae import ResidualBlock + from mlx_video.models.wan_2.vae import ResidualBlock block = ResidualBlock(8, 8) assert block.shortcut is None @@ -88,7 +88,7 @@ class TestResidualBlock: class TestAttentionBlock: def test_output_shape(self): - from mlx_video.models.wan2.vae import AttentionBlock + from mlx_video.models.wan_2.vae import AttentionBlock block = AttentionBlock(8) x = mx.random.normal((1, 8, 2, 4, 4)) @@ -97,7 +97,7 @@ class TestAttentionBlock: assert out.shape == (1, 8, 2, 4, 4) def test_residual_connection(self): - from mlx_video.models.wan2.vae import AttentionBlock + from mlx_video.models.wan_2.vae import AttentionBlock block = AttentionBlock(8) x = mx.random.normal((1, 8, 1, 3, 3)) @@ -109,7 +109,7 @@ class TestAttentionBlock: class TestWanVAE: def test_instantiation(self): - from mlx_video.models.wan2.vae import WanVAE + from mlx_video.models.wan_2.vae import WanVAE vae = WanVAE(z_dim=16) assert vae.z_dim == 16 @@ -117,7 +117,7 @@ class TestWanVAE: assert vae.std.shape == (16,) def test_normalization_stats(self): - from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD + from mlx_video.models.wan_2.vae import VAE_MEAN, VAE_STD assert len(VAE_MEAN) == 16 assert len(VAE_STD) == 16 @@ -133,7 +133,7 @@ class TestVAE22CausalConv3d: """Tests for vae22.CausalConv3d (channels-last).""" def test_output_shape_k3(self): - from mlx_video.models.wan2.vae22 import CausalConv3d + from mlx_video.models.wan_2.vae22 import CausalConv3d conv = CausalConv3d(8, 16, kernel_size=3, padding=1) x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C] @@ -142,7 +142,7 @@ class TestVAE22CausalConv3d: assert out.shape == (1, 4, 8, 8, 16) def test_output_shape_k1(self): - from mlx_video.models.wan2.vae22 import CausalConv3d + from mlx_video.models.wan_2.vae22 import CausalConv3d conv = CausalConv3d(8, 16, kernel_size=1) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -152,7 +152,7 @@ class TestVAE22CausalConv3d: def test_temporal_causal(self): """Output at t=0 should not depend on t>0.""" - from mlx_video.models.wan2.vae22 import CausalConv3d + from mlx_video.models.wan_2.vae22 import CausalConv3d conv = CausalConv3d(2, 2, kernel_size=3, padding=1) conv.weight = mx.random.normal(conv.weight.shape) * 0.1 @@ -178,7 +178,7 @@ class TestVAE22CausalConv3d: def test_channels_last_format(self): """Verify input/output are channels-last [B, T, H, W, C].""" - from mlx_video.models.wan2.vae22 import CausalConv3d + from mlx_video.models.wan_2.vae22 import CausalConv3d conv = CausalConv3d(4, 8, kernel_size=3, padding=1) x = mx.random.normal((2, 3, 6, 6, 4)) @@ -191,7 +191,7 @@ class TestRMSNorm: """Tests for vae22.RMS_norm (actually L2 normalization).""" def test_output_shape(self): - from mlx_video.models.wan2.vae22 import RMS_norm + from mlx_video.models.wan_2.vae22 import RMS_norm norm = RMS_norm(16) x = mx.random.normal((2, 4, 4, 4, 16)) @@ -201,7 +201,7 @@ class TestRMSNorm: def test_l2_normalization(self): """RMS_norm should normalize to unit L2 norm * sqrt(dim).""" - from mlx_video.models.wan2.vae22 import RMS_norm + from mlx_video.models.wan_2.vae22 import RMS_norm dim = 32 norm = RMS_norm(dim) @@ -215,7 +215,7 @@ class TestRMSNorm: def test_scale_invariant(self): """Scaling input by constant should not change output (L2 norm property).""" - from mlx_video.models.wan2.vae22 import RMS_norm + from mlx_video.models.wan_2.vae22 import RMS_norm norm = RMS_norm(8) x = mx.random.normal((1, 1, 1, 1, 8)) @@ -226,7 +226,7 @@ class TestRMSNorm: def test_gamma_effect(self): """Non-unit gamma should scale output.""" - from mlx_video.models.wan2.vae22 import RMS_norm + from mlx_video.models.wan_2.vae22 import RMS_norm norm = RMS_norm(4) norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0]) @@ -241,7 +241,7 @@ class TestDupUp3D: """Tests for vae22.DupUp3D spatial/temporal upsampling.""" def test_spatial_only(self): - from mlx_video.models.wan2.vae22 import DupUp3D + from mlx_video.models.wan_2.vae22 import DupUp3D up = DupUp3D(8, 4, factor_t=1, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) @@ -250,7 +250,7 @@ class TestDupUp3D: assert out.shape == (1, 3, 8, 8, 4) def test_temporal_and_spatial(self): - from mlx_video.models.wan2.vae22 import DupUp3D + from mlx_video.models.wan_2.vae22 import DupUp3D up = DupUp3D(16, 8, factor_t=2, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 16)) @@ -259,7 +259,7 @@ class TestDupUp3D: assert out.shape == (1, 6, 8, 8, 8) def test_first_chunk_trims(self): - from mlx_video.models.wan2.vae22 import DupUp3D + from mlx_video.models.wan_2.vae22 import DupUp3D up = DupUp3D(8, 4, factor_t=2, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) @@ -271,7 +271,7 @@ class TestDupUp3D: assert out_trimmed.shape[1] == 5 def test_no_temporal_first_chunk_noop(self): - from mlx_video.models.wan2.vae22 import DupUp3D + from mlx_video.models.wan_2.vae22 import DupUp3D up = DupUp3D(8, 4, factor_t=1, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) @@ -286,7 +286,7 @@ class TestVAE22Resample: """Tests for vae22.Resample (spatial/temporal upsampling).""" def test_upsample2d_shape(self): - from mlx_video.models.wan2.vae22 import Resample + from mlx_video.models.wan_2.vae22 import Resample r = Resample(8, "upsample2d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -296,7 +296,7 @@ class TestVAE22Resample: assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal def test_upsample3d_shape(self): - from mlx_video.models.wan2.vae22 import Resample + from mlx_video.models.wan_2.vae22 import Resample r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -306,7 +306,7 @@ class TestVAE22Resample: assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal def test_upsample3d_first_chunk(self): - from mlx_video.models.wan2.vae22 import Resample + from mlx_video.models.wan_2.vae22 import Resample r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -318,7 +318,7 @@ class TestVAE22Resample: def test_upsample3d_first_chunk_single_frame(self): """Single-frame input with first_chunk: no temporal upsample.""" - from mlx_video.models.wan2.vae22 import Resample + from mlx_video.models.wan_2.vae22 import Resample r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -336,7 +336,7 @@ class TestVAE22Resample: We verify this by checking that the first output frame depends only on the first input frame (not on time_conv parameters). """ - from mlx_video.models.wan2.vae22 import Resample + from mlx_video.models.wan_2.vae22 import Resample C = 8 r = Resample(C, "upsample3d") @@ -373,7 +373,7 @@ class TestVAE22ResidualBlock: """Tests for vae22.ResidualBlock.""" def test_same_dim(self): - from mlx_video.models.wan2.vae22 import ResidualBlock + from mlx_video.models.wan_2.vae22 import ResidualBlock block = ResidualBlock(8, 8) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -382,7 +382,7 @@ class TestVAE22ResidualBlock: assert out.shape == (1, 2, 4, 4, 8) def test_different_dim(self): - from mlx_video.models.wan2.vae22 import ResidualBlock + from mlx_video.models.wan_2.vae22 import ResidualBlock block = ResidualBlock(8, 16) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -391,13 +391,13 @@ class TestVAE22ResidualBlock: assert out.shape == (1, 2, 4, 4, 16) def test_shortcut_when_dims_differ(self): - from mlx_video.models.wan2.vae22 import ResidualBlock + from mlx_video.models.wan_2.vae22 import ResidualBlock block = ResidualBlock(8, 16) assert block.shortcut is not None def test_no_shortcut_same_dim(self): - from mlx_video.models.wan2.vae22 import ResidualBlock + from mlx_video.models.wan_2.vae22 import ResidualBlock block = ResidualBlock(8, 8) assert block.shortcut is None @@ -408,7 +408,7 @@ class TestResidualBlockLayers: def test_layer_names_no_underscore_prefix(self): """Layer names must NOT start with underscore (MLX ignores them).""" - from mlx_video.models.wan2.vae22 import ResidualBlockLayers + from mlx_video.models.wan_2.vae22 import ResidualBlockLayers block = ResidualBlockLayers(8, 8) params = dict(block.parameters()) @@ -417,7 +417,7 @@ class TestResidualBlockLayers: assert not key.startswith("_"), f"Parameter {key} starts with underscore" def test_has_expected_layers(self): - from mlx_video.models.wan2.vae22 import ResidualBlockLayers + from mlx_video.models.wan_2.vae22 import ResidualBlockLayers block = ResidualBlockLayers(8, 16) assert hasattr(block, "layer_0") # first RMS_norm @@ -426,7 +426,7 @@ class TestResidualBlockLayers: assert hasattr(block, "layer_6") # second CausalConv3d def test_forward_shape(self): - from mlx_video.models.wan2.vae22 import ResidualBlockLayers + from mlx_video.models.wan_2.vae22 import ResidualBlockLayers block = ResidualBlockLayers(8, 16) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -439,7 +439,7 @@ class TestVAE22AttentionBlock: """Tests for vae22.AttentionBlock (per-frame 2D self-attention).""" def test_output_shape(self): - from mlx_video.models.wan2.vae22 import AttentionBlock + from mlx_video.models.wan_2.vae22 import AttentionBlock block = AttentionBlock(16) block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01 @@ -450,7 +450,7 @@ class TestVAE22AttentionBlock: assert out.shape == (1, 2, 4, 4, 16) def test_residual_connection(self): - from mlx_video.models.wan2.vae22 import AttentionBlock + from mlx_video.models.wan_2.vae22 import AttentionBlock block = AttentionBlock(8) block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape) @@ -466,7 +466,7 @@ class TestHead22: """Tests for vae22.Head22 output head.""" def test_output_shape(self): - from mlx_video.models.wan2.vae22 import Head22 + from mlx_video.models.wan_2.vae22 import Head22 head = Head22(16, out_channels=12) x = mx.random.normal((1, 2, 4, 4, 16)) @@ -476,7 +476,7 @@ class TestHead22: def test_layer_names_no_underscore(self): """Head layers must not use underscore prefix.""" - from mlx_video.models.wan2.vae22 import Head22 + from mlx_video.models.wan_2.vae22 import Head22 head = Head22(8) assert hasattr(head, "layer_0") # RMS_norm @@ -490,7 +490,7 @@ class TestUnpatchify: """Tests for vae22._unpatchify.""" def test_basic_shape(self): - from mlx_video.models.wan2.vae22 import _unpatchify + from mlx_video.models.wan_2.vae22 import _unpatchify x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2 out = _unpatchify(x, patch_size=2) @@ -498,7 +498,7 @@ class TestUnpatchify: assert out.shape == (1, 2, 8, 8, 3) def test_patch_size_1_noop(self): - from mlx_video.models.wan2.vae22 import _unpatchify + from mlx_video.models.wan_2.vae22 import _unpatchify x = mx.random.normal((1, 2, 4, 4, 3)) out = _unpatchify(x, patch_size=1) @@ -507,7 +507,7 @@ class TestUnpatchify: def test_preserves_content(self): """Unpatchify should be a lossless rearrangement.""" - from mlx_video.models.wan2.vae22 import _unpatchify + from mlx_video.models.wan_2.vae22 import _unpatchify x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32) out = _unpatchify(x, patch_size=2) @@ -521,7 +521,7 @@ class TestDenormalizeLatents: """Tests for vae22.denormalize_latents.""" def test_output_shape(self): - from mlx_video.models.wan2.vae22 import denormalize_latents + from mlx_video.models.wan_2.vae22 import denormalize_latents z = mx.random.normal((1, 2, 4, 4, 48)) out = denormalize_latents(z) @@ -529,7 +529,7 @@ class TestDenormalizeLatents: assert out.shape == (1, 2, 4, 4, 48) def test_custom_mean_std(self): - from mlx_video.models.wan2.vae22 import denormalize_latents + from mlx_video.models.wan_2.vae22 import denormalize_latents z = mx.ones((1, 1, 1, 1, 4)) mean = mx.array([1.0, 2.0, 3.0, 4.0]) @@ -542,7 +542,7 @@ class TestDenormalizeLatents: ) def test_uses_default_constants(self): - from mlx_video.models.wan2.vae22 import ( + from mlx_video.models.wan_2.vae22 import ( VAE22_MEAN, denormalize_latents, ) @@ -563,14 +563,14 @@ class TestVAE22NormConstants: """Tests for VAE22_MEAN and VAE22_STD constants.""" def test_dimensions(self): - from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD + from mlx_video.models.wan_2.vae22 import VAE22_MEAN, VAE22_STD mx.eval(VAE22_MEAN, VAE22_STD) assert VAE22_MEAN.shape == (48,) assert VAE22_STD.shape == (48,) def test_std_positive(self): - from mlx_video.models.wan2.vae22 import VAE22_STD + from mlx_video.models.wan_2.vae22 import VAE22_STD mx.eval(VAE22_STD) assert (np.array(VAE22_STD) > 0).all() @@ -581,7 +581,7 @@ class TestWan22VAEDecoder: def test_output_shape_small(self): """Tiny decoder should produce correct spatial/temporal output.""" - from mlx_video.models.wan2.vae22 import Wan22VAEDecoder + from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder # Use very small dims to keep test fast dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) @@ -597,7 +597,7 @@ class TestWan22VAEDecoder: assert np.array(out).max() <= 1.0 def test_output_clipped(self): - from mlx_video.models.wan2.vae22 import Wan22VAEDecoder + from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values @@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights: """Tests for vae22.sanitize_wan22_vae_weights.""" def test_skip_encoder(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.layer.weight": mx.zeros((4,)), @@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.conv1.bias" in out def test_sequential_index_remapping(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)), @@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.head.layer_2.bias" in out def test_resample_conv_remapping(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)), @@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.upsamples.1.upsamples.3.resample_bias" in out def test_attention_remapping(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights weights = { "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)), @@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.middle.1.proj_bias" in out def test_conv3d_transpose(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I] w = mx.zeros((16, 8, 3, 3, 3)) @@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights: assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8) def test_conv2d_transpose(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights # Conv2d weight: [O, I, H, W] → [O, H, W, I] w = mx.zeros((8, 8, 3, 3)) @@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights: assert out[key].shape == (8, 3, 3, 8) def test_gamma_squeeze(self): - from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights # gamma: (dim, 1, 1, 1) → (dim,) w = mx.ones((16, 1, 1, 1)) @@ -698,7 +698,7 @@ class TestUpResidualBlock: """Tests for vae22.Up_ResidualBlock.""" def test_no_upsample(self): - from mlx_video.models.wan2.vae22 import Up_ResidualBlock + from mlx_video.models.wan_2.vae22 import Up_ResidualBlock block = Up_ResidualBlock( 8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False @@ -710,7 +710,7 @@ class TestUpResidualBlock: assert out.shape == (1, 2, 4, 4, 8) def test_spatial_upsample(self): - from mlx_video.models.wan2.vae22 import Up_ResidualBlock + from mlx_video.models.wan_2.vae22 import Up_ResidualBlock block = Up_ResidualBlock( 8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True @@ -722,7 +722,7 @@ class TestUpResidualBlock: assert out.shape == (1, 2, 8, 8, 4) def test_spatial_temporal_upsample(self): - from mlx_video.models.wan2.vae22 import Up_ResidualBlock + from mlx_video.models.wan_2.vae22 import Up_ResidualBlock block = Up_ResidualBlock( 8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True @@ -738,7 +738,7 @@ class TestPatchify: """Tests for _patchify and _unpatchify round-trip.""" def test_roundtrip(self): - from mlx_video.models.wan2.vae22 import _patchify, _unpatchify + from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify x = mx.random.normal((1, 1, 64, 64, 3)) p = _patchify(x, patch_size=2) @@ -748,7 +748,7 @@ class TestPatchify: assert float(mx.abs(x - back).max()) == 0.0 def test_identity_patch_1(self): - from mlx_video.models.wan2.vae22 import _patchify, _unpatchify + from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify x = mx.random.normal((1, 2, 8, 8, 3)) assert _patchify(x, patch_size=1).shape == x.shape @@ -759,7 +759,7 @@ class TestAvgDown3D: """Tests for AvgDown3D downsampling.""" def test_spatial_only(self): - from mlx_video.models.wan2.vae22 import AvgDown3D + from mlx_video.models.wan_2.vae22 import AvgDown3D down = AvgDown3D(8, 16, factor_t=1, factor_s=2) x = mx.random.normal((1, 2, 8, 8, 8)) @@ -768,7 +768,7 @@ class TestAvgDown3D: assert out.shape == (1, 2, 4, 4, 16) def test_temporal_and_spatial(self): - from mlx_video.models.wan2.vae22 import AvgDown3D + from mlx_video.models.wan_2.vae22 import AvgDown3D down = AvgDown3D(8, 16, factor_t=2, factor_s=2) x = mx.random.normal((1, 4, 8, 8, 8)) @@ -777,7 +777,7 @@ class TestAvgDown3D: assert out.shape == (1, 2, 4, 4, 16) def test_single_frame(self): - from mlx_video.models.wan2.vae22 import AvgDown3D + from mlx_video.models.wan_2.vae22 import AvgDown3D down = AvgDown3D(8, 8, factor_t=2, factor_s=2) x = mx.random.normal((1, 1, 8, 8, 8)) @@ -791,7 +791,7 @@ class TestDownResidualBlock: """Tests for Down_ResidualBlock.""" def test_no_downsample(self): - from mlx_video.models.wan2.vae22 import Down_ResidualBlock + from mlx_video.models.wan_2.vae22 import Down_ResidualBlock block = Down_ResidualBlock( 8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False @@ -802,7 +802,7 @@ class TestDownResidualBlock: assert out.shape == (1, 2, 8, 8, 8) def test_spatial_downsample(self): - from mlx_video.models.wan2.vae22 import Down_ResidualBlock + from mlx_video.models.wan_2.vae22 import Down_ResidualBlock block = Down_ResidualBlock( 8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True @@ -813,7 +813,7 @@ class TestDownResidualBlock: assert out.shape == (1, 2, 4, 4, 16) def test_spatial_temporal_downsample(self): - from mlx_video.models.wan2.vae22 import Down_ResidualBlock + from mlx_video.models.wan_2.vae22 import Down_ResidualBlock block = Down_ResidualBlock( 8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True @@ -828,7 +828,7 @@ class TestEncoder3d: """Tests for Encoder3d.""" def test_output_shape(self): - from mlx_video.models.wan2.vae22 import Encoder3d + from mlx_video.models.wan_2.vae22 import Encoder3d enc = Encoder3d(dim=16, z_dim=8) x = mx.random.normal((1, 1, 16, 16, 12)) @@ -839,7 +839,7 @@ class TestEncoder3d: assert out.shape == (1, 1, 2, 2, 8) def test_multi_frame(self): - from mlx_video.models.wan2.vae22 import Encoder3d + from mlx_video.models.wan_2.vae22 import Encoder3d enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False)) x = mx.random.normal((1, 5, 16, 16, 12)) @@ -854,7 +854,7 @@ class TestWan22VAEEncoder: """Tests for Wan22VAEEncoder wrapper.""" def test_output_shape(self): - from mlx_video.models.wan2.vae22 import Wan22VAEEncoder + from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=16) # Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2) @@ -865,7 +865,7 @@ class TestWan22VAEEncoder: assert z.shape == (1, 1, 2, 2, 48) def test_full_dim(self): - from mlx_video.models.wan2.vae22 import Wan22VAEEncoder + from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=160) img = mx.random.normal((1, 1, 64, 64, 3)) @@ -880,7 +880,7 @@ class TestNormalizeLatents: """Tests for normalize/denormalize latent roundtrip.""" def test_roundtrip(self): - from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents + from mlx_video.models.wan_2.vae22 import denormalize_latents, normalize_latents z = mx.random.normal((1, 2, 4, 4, 48)) z_norm = normalize_latents(z) @@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder: def test_encoder_temporal_downsample_pattern(self): """Encoder3d with (False, True, True): T=5→5→3→2.""" - from mlx_video.models.wan2.vae22 import Encoder3d + from mlx_video.models.wan_2.vae22 import Encoder3d enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True)) x = mx.random.normal((1, 5, 16, 16, 12)) @@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder: def test_wrapper_uses_correct_pattern(self): """Wan22VAEEncoder should use (False, True, True) temporal downsample.""" - from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder + from mlx_video.models.wan_2.vae22 import Resample, Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=16) down_blocks = enc.encoder.downsamples @@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder: def test_single_frame_encoder(self): """Single frame (T=1) should work with (False, True, True) pattern.""" - from mlx_video.models.wan2.vae22 import Wan22VAEEncoder + from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=16) img = mx.random.normal((1, 1, 32, 32, 3)) @@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder: def test_wrong_order_gives_different_result(self): """(True, True, False) vs (False, True, True) produce different outputs.""" - from mlx_video.models.wan2.vae22 import Encoder3d + from mlx_video.models.wan_2.vae22 import Encoder3d enc_correct = Encoder3d( dim=16, z_dim=8, temperal_downsample=(False, True, True) @@ -963,7 +963,7 @@ class TestVAE21RoundTrip: def test_encode_decode_shape_and_values(self): """Encoder3d → Decoder3d: output shape matches input, values are finite.""" - from mlx_video.models.wan2.vae import Decoder3d, Encoder3d + from mlx_video.models.wan_2.vae import Decoder3d, Encoder3d z_dim = 4 dim = 8 @@ -995,7 +995,7 @@ class TestVAE22RoundTrip: def test_encode_decode_shape_and_values(self): """Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range.""" - from mlx_video.models.wan2.vae22 import ( + from mlx_video.models.wan_2.vae22 import ( Wan22VAEDecoder, Wan22VAEEncoder, denormalize_latents, diff --git a/tests/wan_test_helpers.py b/tests/wan_test_helpers.py index 2b67ada..cdaab1e 100644 --- a/tests/wan_test_helpers.py +++ b/tests/wan_test_helpers.py @@ -3,7 +3,7 @@ def _make_tiny_config(): """Create a tiny WanModelConfig for testing.""" - from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig() # Override to tiny values