diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index 985ac87..7c50343 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -22,7 +22,7 @@ from mlx_video.models.ltx_2.utils import ( load_safetensors, save_weights, ) -from mlx_video.models.wan import WanModel, WanModelConfig +from mlx_video.models.wan2 import WanModel, WanModelConfig __all__ = [ # Models diff --git a/mlx_video/models/__init__.py b/mlx_video/models/__init__.py index 4c49754..b54c40d 100644 --- a/mlx_video/models/__init__.py +++ b/mlx_video/models/__init__.py @@ -1,2 +1,2 @@ from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig -from mlx_video.models.wan import WanModel, WanModelConfig +from mlx_video.models.wan2 import WanModel, WanModelConfig diff --git a/mlx_video/models/wan2/__init__.py b/mlx_video/models/wan2/__init__.py index c0f37a8..b9c08ac 100644 --- a/mlx_video/models/wan2/__init__.py +++ b/mlx_video/models/wan2/__init__.py @@ -1,2 +1,2 @@ -from mlx_video.models.wan.config import WanModelConfig -from mlx_video.models.wan.model import WanModel +from mlx_video.models.wan2.config import WanModelConfig +from mlx_video.models.wan2.model import WanModel diff --git a/mlx_video/models/wan2/config.py b/mlx_video/models/wan2/config.py index deb0d78..b3b2019 100644 --- a/mlx_video/models/wan2/config.py +++ b/mlx_video/models/wan2/config.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Tuple, Union -from mlx_video.models.ltx.config import BaseModelConfig +from mlx_video.models.ltx_2.config import BaseModelConfig @dataclass diff --git a/mlx_video/models/wan2/convert.py b/mlx_video/models/wan2/convert.py index 657eee7..8ae510f 100644 --- a/mlx_video/models/wan2/convert.py +++ b/mlx_video/models/wan2/convert.py @@ -247,7 +247,7 @@ def _load_lora_configs( Shared between weight-merging and runtime-wrapping paths. """ - from mlx_video.generate_wan import Colors + from mlx_video.models.wan2.generate import Colors from mlx_video.lora import LoRAConfig, load_multiple_loras print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") @@ -282,7 +282,7 @@ def load_and_apply_loras( For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). """ - from mlx_video.generate_wan import Colors + from mlx_video.models.wan2.generate import Colors from mlx_video.lora import apply_loras_to_weights if not lora_configs: @@ -411,7 +411,7 @@ def convert_wan_checkpoint( print(" Warning: No transformer weights found!") # Save config — detect model size from source config.json or transformer weights - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig def _detect_config(): """Detect config from source config.json or transformer weight shapes.""" @@ -522,7 +522,7 @@ def convert_wan_checkpoint( print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...") weights = load_torch_weights(str(vae_path)) if is_wan22_vae: - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights include_encoder = config.model_type in ("ti2v", "i2v") weights = sanitize_wan22_vae_weights( @@ -594,7 +594,7 @@ def _quantize_saved_model( import mlx.nn as nn - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel if source_dir is None: source_dir = output_dir @@ -704,7 +704,7 @@ def quantize_mlx_model( ).exists() # Build model config - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config_dict = { k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__ diff --git a/mlx_video/models/wan2/docs/DIAGNOSTICS.md b/mlx_video/models/wan2/docs/DIAGNOSTICS.md deleted file mode 100644 index 3b6c456..0000000 --- a/mlx_video/models/wan2/docs/DIAGNOSTICS.md +++ /dev/null @@ -1,394 +0,0 @@ -# Wan2.2 I2V-14B Diagnostic Report - -This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix. - -## Table of Contents - -- [Overview](#overview) -- [Architecture Summary](#architecture-summary) -- [Diagnostic Methodology](#diagnostic-methodology) -- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination) -- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion) -- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original) -- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong) -- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order) -- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding) -- [Verified Correct Components](#verified-correct-components) -- [Performance Optimizations](#performance-optimizations) -- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation) -- [Reference Implementation](#reference-implementation) -- [Useful Diagnostic Commands](#useful-diagnostic-commands) - ---- - -## Overview - -The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output — first frame showed the image, subsequent frames degraded to noise, checkerboard artifacts, or flat grey. - -Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component in isolation numerically, then narrow down failures to the subsystem level. - -### Timeline of Symptoms - -| Stage | Symptom | Root Cause | -|-------|---------|------------| -| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) | -| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) | -| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) | -| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) | - ---- - -## Architecture Summary - -``` -I2V-14B Pipeline: - Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat] - ↓ - Mask Construction → [4, T_lat, H_lat, W_lat] - ↓ - y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat] - ↓ - Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat] - ↓ - Dual DiT (40 layers, 5120 dim) × 40 denoising steps - ↓ - Denoised Latent [16, T_lat, H_lat, W_lat] - ↓ - VAE Decoder → Video [3, F, H, W] -``` - -**Key parameters:** -- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16` -- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900) -- 40 steps, shift=5.0, guide_scale=(3.5, 3.5) -- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8) - ---- - -## Diagnostic Methodology - -### 1. Component-Level Numerical Verification - -Each component was tested in isolation against the reference PyTorch implementation: - -1. **Load identical inputs** (same random seed, same image, same prompt) -2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy` -3. **Run through MLX** with the same inputs -4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics - -Components tested this way: -- RoPE frequency parameters and rotation output -- Time embedding (sinusoidal → MLP → projection) -- Patchify (reshape+Linear vs Conv3d) -- Unpatchify (transpose-based vs einsum) -- Scheduler (UniPC) timesteps and step formulas -- VAE encoder output (frame-by-frame comparison) -- Text embeddings (per-model MLP output) -- Cross-attention K/V cache shapes -- Mask construction values - -### 2. Artifact Analysis - -When visual artifacts appeared, quantitative metrics were used to characterize them: - -- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard. -- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts. -- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation. -- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content. - -### 3. Isolation Testing - -- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source. -- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness. -- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering. -- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence. -- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs. - -### 4. Weight Comparison - -Direct comparison of converted MLX weights against original PyTorch weights: -```python -# Load both weight sets -pt_weights = torch.load("model.safetensors") -mlx_weights = mx.load("model.safetensors") -# Compare each key -for key in pt_weights: - diff = np.abs(np.array(pt_weights[key]) - np.array(mlx_weights[key])).max() -``` -Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys. - ---- - -## Bug 1: Text Embedding Cross-Contamination - -**Symptom:** Model ignores text prompt, generated frames lack semantic content. - -**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output. - -**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices. - -**Fix:** Compute separate text embeddings per model: -```python -# Before (broken): -context_emb = low_noise_model.embed_text([context, context_null]) -cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models - -# After (correct): -context_emb_low = low_noise_model.embed_text([context, context_null]) -context_emb_high = high_noise_model.embed_text([context, context_null]) -cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low) -cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high) -``` - -**File:** `mlx_video/generate_wan.py` (lines 333–349) -**Commit:** `a85b1c21` - ---- - -## Bug 2: VAE Encoder Weights Excluded from Conversion - -**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion). - -**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization. - -**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning). - -**Fix:** -```python -# Before: -include_encoder = config.model_type == "ti2v" -# After: -include_encoder = config.model_type in ("ti2v", "i2v") -``` - -**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions. - -**File:** `mlx_video/convert_wan.py` (line 424) - ---- - -## Bug 3: RoPE Frequency Computation (original) - -**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame. - -**Root Cause (original):** Our original code called `rope_params` three times but applied them incorrectly (per-axis in the model init, then rope_apply did NOT split). This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced checkerboard but introduced Bug 6 (see below). - -**File:** `mlx_video/models/wan/model.py` -**Commit:** `3da4a637` - ---- - -## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong) - -**Symptom:** I2V generates input image in frames 0–3, colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions nearly identical. Model cannot produce coherent motion. - -**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400–405) actually uses **three separate calls with different dimension normalizations**, concatenated: - -```python -# Reference (CORRECT): -d = dim // num_heads # 128 -self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), # rope_params(1024, 44) - rope_params(1024, 2 * (d // 6)), # rope_params(1024, 42) - rope_params(1024, 2 * (d // 6)) # rope_params(1024, 42) -], dim=1) -``` - -Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave: -- Temporal: low frequencies only [1.0 → 0.049] -- Height: medium frequencies only [0.042 → 0.002] (should start at 1.0!) -- Width: high frequencies only [0.002 → 0.0001] (should start at 1.0!) - -The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference). - -**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes. - -**Fix:** -```python -d = dim // config.num_heads -self.freqs = mx.concatenate([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)), -], axis=1) -``` - -**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match). - -**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions. - -**File:** `mlx_video/models/wan/model.py` (line 155) - ---- - -## Bug 4: VAE Encoder Temporal Downsample Order - -**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 1–4 showed decreasing std (0.37→1.19) while reference showed stable std (0.95→1.34). - -**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters: - -``` -Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d -Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG -``` - -This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation. - -**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output: -- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct -- Frames 1–4 had massive differences — proved temporal processing was broken - -Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`. - -**Fix:** -```python -# In Encoder3d.__init__, change default: -temporal_downsample = [False, True, True] # was [True, True, False] -``` - -**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 60–80 to 0.1–7.7. - -**File:** `mlx_video/models/wan/vae.py` (line 370) -**Commit:** `3da4a637` - ---- - -## Bug 5: Non-Chunked VAE Encoding - -**Symptom:** First 4–5 frames grey, then blurred version of image appears. - -**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`): -1. Encode first frame alone (1 frame) -2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks -3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input - -Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because: -- Chunked: real features from previous chunks serve as causal context -- Non-chunked: zeros serve as causal context for the start - -**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values. - -**Fix:** Implemented full chunked encoding with temporal caching: -- Added `cache_x` parameter to `CausalConv3d.__call__` -- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d` -- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks) -- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head) - -**File:** `mlx_video/models/wan/vae.py` (multiple methods) -**Commit:** `b6a94c4c` - ---- - -## Verified Correct Components - -These components were numerically verified against the reference and are **not** sources of bugs: - -| Component | Method | Max Diff | Notes | -|-----------|--------|----------|-------| -| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only | -| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent | -| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative | -| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative | -| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation | -| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match | -| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 | -| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order | -| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output | -| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly | - ---- - -## Performance Optimizations - -Applied alongside bug fixes to improve inference speed: - -### Pre-Computation (Before Diffusion Loop) -- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once -- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat -- **Attention mask precomputation**: Build padding mask once, pass via kwargs -- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing -- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync - -### Per-Step Optimizations -- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection -- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element -- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks) -- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync - ---- - -## Resolved: CFG Effectiveness (was Open Investigation) - -**Symptom:** Generated video shows the input image in frames 0–3 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical. - -**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences. - ---- - -## Reference Implementation - -The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`: - -| File | Contents | -|------|----------| -| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) | -| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) | -| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding | -| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler | -| `wan/configs/wan_i2v_A14B.py` | Model configuration | - -Key structural differences between reference and our implementation: -- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2 -- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype -- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear` -- Reference casts timesteps to `int64`; we keep as float (diff < 1.0) - ---- - -## Useful Diagnostic Commands - -### Run I2V-14B generation -```bash -python -m mlx_video.generate_wan \ - --prompt "A woman smiles at camera" \ - --image start.png \ - --model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \ - --num-frames 17 --steps 40 \ - --height 384 --width 640 \ - --output output_i2v.mp4 -``` - -### Check VAE encoder output -```python -import mlx.core as mx, numpy as np -from mlx_video.models.wan.vae import WanVAE -# Load VAE and encode an image -latents = vae.encode(video_tensor) # [1, 16, T_lat, H_lat, W_lat] -for t in range(latents.shape[2]): - frame = np.array(latents[0, :, t]) - print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}") -``` - -### Analyze video frame quality -```python -import cv2, numpy as np -cap = cv2.VideoCapture("output.mp4") -while True: - ret, frame = cap.read() - if not ret: break - # Checkerboard metric: high values indicate patch-boundary artifacts - checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean() - print(f"std={frame.std():.1f} checker={checker:.1f}") -``` - -### Compare weights between PyTorch and MLX -```python -import torch, mlx.core as mx, numpy as np -pt = torch.load("model.pt", map_location="cpu") -mlx_w = mx.load("model.safetensors") -for key in sorted(pt.keys()): - if key in mlx_w: - diff = np.abs(pt[key].float().numpy() - np.array(mlx_w[key])).max() - if diff > 0.01: - print(f"LARGE DIFF {key}: {diff:.6f}") -``` diff --git a/mlx_video/models/wan2/docs/IMPLEMENTATION_NOTES.md b/mlx_video/models/wan2/docs/IMPLEMENTATION_NOTES.md deleted file mode 100644 index 186aabb..0000000 --- a/mlx_video/models/wan2/docs/IMPLEMENTATION_NOTES.md +++ /dev/null @@ -1,285 +0,0 @@ -# Wan2.2 MLX Implementation Notes - -> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX. - -## Architecture Overview - -Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do **not** use Mixture-of-Experts — they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep). - -### Key Parameters - -| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim | -|-------|-----|-------|--------|----------|-----------|------------|--------| -| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 | -| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 | -| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 | -| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 | - -### Codebase Structure (~3900 lines of Wan2.2 code) - -``` -mlx_video/ -├── generate_wan.py # 483L - Generation pipeline (T2V + I2V) -├── convert_wan.py # 564L - Weight conversion from HuggingFace -└── models/wan/ - ├── config.py # 113L - Model configs (dataclass presets) - ├── model.py # 320L - DiT model (time embed, patchify, unpatchify) - ├── transformer.py # 91L - Attention block + FFN - ├── attention.py # 211L - Self-attention + cross-attention - ├── rope.py # 100L - 3D Rotary Position Embeddings - ├── text_encoder.py # 240L - T5 encoder (UMT5-XXL) - ├── scheduler.py # 428L - Euler, DPM++ 2M, UniPC schedulers - ├── vae.py # 315L - Wan2.1 VAE decoder (4×8×8) - ├── vae22.py # 836L - Wan2.2 VAE encoder + decoder (4×16×16) - ├── loading.py # 154L - Model loading utilities - └── i2v_utils.py # 58L - I2V mask/preprocessing -``` - ---- - -## Critical Bugs & Fixes - -### 1. MLX Underscore Attribute Gotcha - -**Problem**: MLX's `nn.Module` silently ignores underscore-prefixed attributes (`_layer_0`, `_layer_1`, etc.) in `parameters()` and `load_weights()`. The Wan2.2 VAE had layers named `_layer_N`, causing **87 out of 110 weights to be silently dropped** during loading. - -**Fix**: Rename all `_layer_N` attributes to `layer_N`. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree. - -**Lesson**: Never use underscore-prefixed names for `nn.Module` sub-modules in MLX. - -### 2. Patchify Channel Ordering - -**Problem**: The patchify/unpatchify operations transposed channels incorrectly — producing `[C fastest]` layout instead of `[C slowest]`, causing completely garbled video output. - -**Fix**: Changed reshape to produce correct `[B, T', H', W', pt*ph*pw*C]` ordering matching PyTorch's contiguous memory layout. - -**Lesson**: When porting PyTorch reshape/view operations to MLX, pay close attention to memory layout — PyTorch is row-major by default, and reshape semantics differ when dimensions are reordered. - -### 3. VAE AttentionBlock Reshape - -**Problem**: Attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in output. - -**Fix**: Correct reshape from `[B*C, T, H, W]` to `[B*T, C, H, W]` for spatial attention. - -### 4. RMS Norm vs L2 Norm - -**Problem**: The Wan2.2 VAE uses a class named `RMS_norm` in PyTorch, but it actually computes **L2 normalization** (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion. - -**Fix**: Implement as `x / ||x||₂` instead of `x / sqrt(mean(x²))`. - -**Lesson**: Don't trust class names in reference code — read the actual computation. - -### 5. Video Codec Green Output - -**Problem**: OpenCV's `mp4v` codec on macOS produces green-tinted video. - -**Fix**: Switch to `imageio` with `libx264` codec. Fallback chain: imageio → cv2 (avc1) → PNG frames. - ---- - -## Precision & Dtype Flow - -### The bfloat16 Autocast Pattern - -The official PyTorch implementation uses `torch.autocast("cuda", dtype=torch.bfloat16)` which automatically casts matmul inputs. In MLX, we replicate this manually: - -| Operation | Official (PyTorch) | MLX Implementation | -|---|---|---| -| Modulation/gates | float32 (explicit `autocast(enabled=False)`) | `x.astype(mx.float32)` before modulation | -| QKV projections | bfloat16 (outer autocast) | Cast input to `self.q.weight.dtype` | -| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) | -| Q/K after RoPE | bfloat16 (`q.to(v.dtype)`) | Cast back to weight dtype after RoPE | -| FFN matmuls | bfloat16 (outer autocast) | Cast input to `self.fc1.weight.dtype` | -| Residual stream | float32 | float32 (no cast) | - -**Result**: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression. - -**Key insight**: Modulation parameters (scale, shift, gate) must stay in float32 — they are small values (~0.01–0.1) that lose significant precision in bfloat16. The official code explicitly disables autocast for these computations. - -### T5 Encoder Precision - -The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible. - -### VAE Decoder Precision - -VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames. - ---- - -## Scheduler Implementation Details - -### Three Schedulers: Euler, DPM++ 2M, UniPC - -All operate in the flow-matching formulation where `sigma` represents the noise level (1.0 = pure noise, 0.0 = clean). - -**Euler**: Simple first-order ODE solver. Most stable, recommended for debugging. - -**DPM++ 2M**: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return `±inf` from `_lambda()` when sigma is 0 or 1). - -**UniPC** (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical — it refines each step using the already-computed model output at **zero additional model evaluation cost**. - -### UniPC Corrector: Must Be Enabled - -**Discovery**: Our implementation had `use_corrector=False` by default, but the official Wan2.2 code **always** enables it (there's no flag — the corrector runs whenever `step_index > 0`). - -**Impact**: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage. - -### UniPC Corrector Coefficients - -The corrector coefficients (`rhos_c`) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding `rhos_c[-1] = 0.5` introduces ~6–13% error in the correction term across 47+ steps. The fix uses `np.linalg.solve()` to compute exact coefficients. - -### Sigma Schedule - -```python -# Flow-matching sigma schedule with shift -sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) -sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) -``` - -Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0. - ---- - -## Image-to-Video (I2V) Pipelines - -Wan2.2 supports two distinct I2V approaches: - -### TI2V-5B: Per-Token Timestep Masking - -I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep: - -```python -# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest -t_tokens = mask_tokens * current_timestep # first-frame → t=0 -``` - -The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels. - -#### Mask Re-application - -After each scheduler step, the first-frame latent is re-injected to prevent drift: - -```python -latents = (1.0 - mask) * z_img + mask * latents -``` - -#### VAE Encoder Temporal Downsample Order - -The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`: -- Stage 0: Spatial-only downsampling -- Stages 1–2: Spatial + temporal downsampling - -This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths. - -### I2V-14B: Channel Concatenation - -The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor: - -1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]` -2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]` -3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]` -4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36` - -Key differences from TI2V-5B: -- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE -- Requires the **VAE encoder** (for encoding the reference image) -- Uses **scalar timesteps** (same as T2V) — no per-token masking -- **Dual model** pipeline with boundary=0.900 -- Both conditional and unconditional predictions receive the same `y` tensor - ---- - -## Dimension Constraints - -### Patchify Alignment - -Video dimensions must be divisible by `patch_size × vae_stride`: -- **TI2V-5B**: patch=(1,2,2), stride=(4,16,16) → alignment = **32** pixels -- **T2V-14B**: patch=(1,2,2), stride=(4,8,8) → alignment = **16** pixels - -Example: 720p (1280×720) → 720 % 32 ≠ 0, auto-aligns to **704**. - -### Frame Count - -Frames must satisfy `num_frames = 4n + 1` (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4. - ---- - -## Performance Optimizations - -### Batched CFG - -Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass: - -```python -preds = model([latents, latents], t=t_batch, context=context_cfg, ...) -noise_pred_cond, noise_pred_uncond = preds[0], preds[1] -``` - -**Result**: ~40% speedup by amortizing attention overhead. - -### Precomputed Text Embeddings & Cross-Attention KV Cache - -Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing as caches eliminates redundant computation. - -### Memory Management in Diffusion Loop - -```python -# Release temporaries before eval to free memory for graph execution -del noise_pred_cond, noise_pred_uncond, noise_pred, preds -mx.eval(latents) -``` - -MLX's lazy evaluation means `mx.eval()` triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution. - ---- - -## Weight Conversion - -### Key Mapping Patterns - -The PyTorch → MLX conversion (`convert_wan.py`) handles several systematic transforms: - -1. **Conv3d weight transposition**: PyTorch `(out, in, D, H, W)` → MLX `(out, D, H, W, in)` -2. **Linear weight transposition**: PyTorch `(out, in)` → MLX `(out, in)` (same convention for `nn.Linear`) -3. **Nested module paths**: `blocks.0.self_attn.q.weight` → same paths, MLX loads by dotted key - -### Dual-Model Splitting - -The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace. - ---- - -## Testing Strategy - -332 tests across 10 files, all running in ~5 seconds: - -| File | Focus | -|------|-------| -| test_wan_config.py | Config presets, field validation | -| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast | -| test_wan_transformer.py | FFN, attention block, float32 modulation | -| test_wan_model.py | Full DiT forward pass, per-token timesteps | -| test_wan_t5.py | T5 encoder layers and full encoding | -| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder | -| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence | -| test_wan_convert.py | Weight sanitization and conversion | -| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment | -| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction | - -Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise. - ---- - -## Known Issues - -### I2V Quality Degradation - -Frames 2–13 gradually degrade, and frame 14 often has a "flash" artifact. All implementation details have been verified against the official PyTorch code with no discrepancies found. Possible causes: -- Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU) -- MLX-specific attention precision behavior -- Better prompts and 720p resolution (the model's native resolution) help reduce artifacts - -### Chinese Negative Prompt - -The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires `ftfy.fix_text()` to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches. diff --git a/mlx_video/models/wan2/generate.py b/mlx_video/models/wan2/generate.py index 789a78d..f173d9a 100644 --- a/mlx_video/models/wan2/generate.py +++ b/mlx_video/models/wan2/generate.py @@ -11,15 +11,15 @@ import mlx.core as mx import numpy as np from tqdm import tqdm -from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image -from mlx_video.models.wan.loading import ( +from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image +from mlx_video.models.wan2.utils import ( encode_text, load_t5_encoder, load_vae_decoder, load_vae_encoder, load_wan_model, ) -from mlx_video.models.wan.postprocess import save_video +from mlx_video.models.wan2.postprocess import save_video class Colors: @@ -121,8 +121,8 @@ def generate_video( """ import json - from mlx_video.models.wan.config import WanModelConfig - from mlx_video.models.wan.scheduler import ( + from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -729,7 +729,7 @@ def generate_video( # the CausalConv3d zero-padding artifacts fall on the prefix (which we crop). # This gives the first real frame a full temporal receptive field of real data. # Select tiling configuration - from mlx_video.models.ltx.video_vae.tiling import TilingConfig + from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig if tiling == "none": tiling_config = None @@ -767,7 +767,7 @@ def generate_video( ) if is_wan22_vae: - from mlx_video.models.wan.vae22 import denormalize_latents + from mlx_video.models.wan2.vae22 import denormalize_latents # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) z = latents.transpose(1, 2, 3, 0)[None] diff --git a/mlx_video/models/wan2/tiling.py b/mlx_video/models/wan2/tiling.py index 9023c8d..1d144b7 100644 --- a/mlx_video/models/wan2/tiling.py +++ b/mlx_video/models/wan2/tiling.py @@ -6,7 +6,7 @@ for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale output frames rather than LTX's 1+(T-1)*scale mapping). # TODO: This function can be refactored to consolidate with -# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the +# mlx_video.models.ltx_2.video_vae.tiling.decode_with_tiling once the # causal_temporal generalisation is accepted upstream. """ @@ -14,7 +14,7 @@ from typing import Callable, Optional import mlx.core as mx -from mlx_video.models.ltx.video_vae.tiling import ( +from mlx_video.models.ltx_2.video_vae.tiling import ( SpatialTilingConfig, TemporalTilingConfig, TilingConfig, diff --git a/mlx_video/models/wan2/loading.py b/mlx_video/models/wan2/utils.py similarity index 90% rename from mlx_video/models/wan2/loading.py rename to mlx_video/models/wan2/utils.py index e83b0de..6c9be4f 100644 --- a/mlx_video/models/wan2/loading.py +++ b/mlx_video/models/wan2/utils.py @@ -21,12 +21,12 @@ def load_wan_model( If provided, creates QuantizedLinear stubs before loading. loras: Optional list of (lora_path, strength) tuples to apply. """ - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel model = WanModel(config) if quantization: - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate nn.quantize( model, @@ -42,7 +42,7 @@ def load_wan_model( if quantization: # Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear. # Non-LoRA layers stay 4-bit. Zero per-step overhead. - from mlx_video.convert_wan import _load_lora_configs + from mlx_video.models.wan2.convert import _load_lora_configs from mlx_video.lora import apply_loras_to_model model.load_weights(list(weights.items()), strict=False) @@ -53,7 +53,7 @@ def load_wan_model( return model else: # Weight merging: fold LoRA into bf16 weights before loading - from mlx_video.convert_wan import load_and_apply_loras + from mlx_video.models.wan2.convert import load_and_apply_loras weights = load_and_apply_loras(dict(weights), loras) @@ -69,7 +69,7 @@ def load_t5_encoder(model_path: Path, config): only runs once per generation, so performance impact is negligible. This matches the official which computes softmax in float32 explicitly. """ - from mlx_video.models.wan.text_encoder import T5Encoder + from mlx_video.models.wan2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=config.t5_vocab_size, @@ -97,11 +97,11 @@ def load_vae_decoder(model_path: Path, config=None): is_wan22 = config is not None and config.vae_z_dim == 48 if is_wan22: - from mlx_video.models.wan.vae22 import Wan22VAEDecoder + from mlx_video.models.wan2.vae22 import Wan22VAEDecoder vae = Wan22VAEDecoder(z_dim=48) else: - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.vae import WanVAE vae = WanVAE(z_dim=16) @@ -120,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None): For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True. """ if config is not None and config.vae_z_dim == 16: - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.vae import WanVAE vae = WanVAE(z_dim=16, encoder=True) else: - from mlx_video.models.wan.vae22 import Wan22VAEEncoder + from mlx_video.models.wan2.vae22 import Wan22VAEEncoder vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48) diff --git a/mlx_video/models/wan2/vae.py b/mlx_video/models/wan2/vae.py index ecc539a..b713ac7 100644 --- a/mlx_video/models/wan2/vae.py +++ b/mlx_video/models/wan2/vae.py @@ -589,7 +589,7 @@ class WanVAE(nn.Module): Returns: Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] """ - from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling + from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default() diff --git a/mlx_video/models/wan2/vae22.py b/mlx_video/models/wan2/vae22.py index 4d26b95..0b99aef 100644 --- a/mlx_video/models/wan2/vae22.py +++ b/mlx_video/models/wan2/vae22.py @@ -966,7 +966,7 @@ class Wan22VAEDecoder(nn.Module): Returns: video: [B, T', H', W', 3] decoded RGB in [-1, 1] """ - from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling + from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default() diff --git a/pyproject.toml b/pyproject.toml index 916f398..bf535c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,8 @@ Repository = "https://github.com/Blaizzy/mlx-video" Issues = "https://github.com/Blaizzy/mlx-video/issues" [project.scripts] -"mlx_video.generate" = "mlx_video.generate:main" -"mlx_video.generate_wan" = "mlx_video.generate_wan:main" +"mlx_video.ltx_2.generate" = "mlx_video.models.ltx_2.generate:main" +"mlx_video.wan2.generate" = "mlx_video.models.wan2.generate:main" [tool.setuptools.packages.find] include = ["mlx_video*"] diff --git a/tests/test_wan_attention.py b/tests/test_wan_attention.py index 700bb61..e94851e 100644 --- a/tests/test_wan_attention.py +++ b/tests/test_wan_attention.py @@ -12,14 +12,14 @@ class TestRoPE: """Tests for 3-way factorized RoPE.""" def test_rope_params_shape(self): - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params freqs = rope_params(1024, 64) mx.eval(freqs) assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2] def test_rope_params_different_dims(self): - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params for dim in [32, 64, 128]: freqs = rope_params(512, dim) @@ -27,7 +27,7 @@ class TestRoPE: assert freqs.shape == (512, dim // 2, 2) def test_rope_params_cos_sin_range(self): - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params freqs = rope_params(256, 64) mx.eval(freqs) @@ -38,7 +38,7 @@ class TestRoPE: def test_rope_params_position_zero(self): """At position 0, cos should be 1 and sin should be 0.""" - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params freqs = rope_params(10, 64) mx.eval(freqs) @@ -46,7 +46,7 @@ class TestRoPE: np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6) def test_rope_apply_output_shape(self): - from mlx_video.models.wan.rope import rope_apply, rope_params + from mlx_video.models.wan2.rope import rope_apply, rope_params B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim x = mx.random.normal((B, L, N, D)) @@ -58,7 +58,7 @@ class TestRoPE: def test_rope_apply_preserves_norm(self): """RoPE rotation should preserve vector norms.""" - from mlx_video.models.wan.rope import rope_apply, rope_params + from mlx_video.models.wan2.rope import rope_apply, rope_params B, N, D = 1, 2, 16 F, H, W = 2, 3, 4 @@ -79,7 +79,7 @@ class TestRoPE: def test_rope_apply_with_padding(self): """When seq_len < L, extra tokens should be preserved unchanged.""" - from mlx_video.models.wan.rope import rope_apply, rope_params + from mlx_video.models.wan2.rope import rope_apply, rope_params B, N, D = 1, 2, 16 F, H, W = 2, 2, 2 @@ -100,7 +100,7 @@ class TestRoPE: def test_rope_apply_batch(self): """Test with batch_size > 1 and different grid sizes.""" - from mlx_video.models.wan.rope import rope_apply, rope_params + from mlx_video.models.wan2.rope import rope_apply, rope_params B, N, D = 2, 2, 16 grids = [(2, 3, 4), (2, 3, 4)] @@ -132,7 +132,7 @@ class TestRoPE: class TestWanRMSNorm: def test_output_shape(self): - from mlx_video.models.wan.attention import WanRMSNorm + from mlx_video.models.wan2.attention import WanRMSNorm norm = WanRMSNorm(64) x = mx.random.normal((2, 10, 64)) @@ -142,7 +142,7 @@ class TestWanRMSNorm: def test_zero_mean_variance(self): """RMS norm should make RMS ≈ 1 before scaling.""" - from mlx_video.models.wan.attention import WanRMSNorm + from mlx_video.models.wan2.attention import WanRMSNorm norm = WanRMSNorm(64) x = mx.random.normal((1, 5, 64)) * 10.0 @@ -156,7 +156,7 @@ class TestWanRMSNorm: def test_dtype_preservation(self): """RMSNorm weight is float32, so output is promoted to float32.""" - from mlx_video.models.wan.attention import WanRMSNorm + from mlx_video.models.wan2.attention import WanRMSNorm norm = WanRMSNorm(32) x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16) @@ -168,7 +168,7 @@ class TestWanRMSNorm: class TestWanLayerNorm: def test_output_shape(self): - from mlx_video.models.wan.attention import WanLayerNorm + from mlx_video.models.wan2.attention import WanLayerNorm norm = WanLayerNorm(64) x = mx.random.normal((2, 10, 64)) @@ -177,7 +177,7 @@ class TestWanLayerNorm: assert out.shape == (2, 10, 64) def test_without_affine(self): - from mlx_video.models.wan.attention import WanLayerNorm + from mlx_video.models.wan2.attention import WanLayerNorm norm = WanLayerNorm(64, elementwise_affine=False) x = mx.random.normal((1, 4, 64)) @@ -190,7 +190,7 @@ class TestWanLayerNorm: np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1) def test_with_affine(self): - from mlx_video.models.wan.attention import WanLayerNorm + from mlx_video.models.wan2.attention import WanLayerNorm norm = WanLayerNorm(32, elementwise_affine=True) assert hasattr(norm, "weight") @@ -208,8 +208,8 @@ class TestWanSelfAttention: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan.attention import WanSelfAttention - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.attention import WanSelfAttention + from mlx_video.models.wan2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads) B, L = 1, 24 @@ -221,14 +221,14 @@ class TestWanSelfAttention: assert out.shape == (B, L, self.dim) def test_with_qk_norm(self): - from mlx_video.models.wan.attention import WanSelfAttention + from mlx_video.models.wan2.attention import WanSelfAttention attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True) assert attn.norm_q is not None assert attn.norm_k is not None def test_without_qk_norm(self): - from mlx_video.models.wan.attention import WanSelfAttention + from mlx_video.models.wan2.attention import WanSelfAttention attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) assert attn.norm_q is None @@ -236,8 +236,8 @@ class TestWanSelfAttention: def test_masking(self): """Test that masking works: shorter seq_lens should mask later tokens.""" - from mlx_video.models.wan.attention import WanSelfAttention - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.attention import WanSelfAttention + from mlx_video.models.wan2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False) B, L = 1, 24 @@ -262,7 +262,7 @@ class TestWanCrossAttention: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan.attention import WanCrossAttention + from mlx_video.models.wan2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) B, L_q, L_kv = 1, 24, 16 @@ -273,7 +273,7 @@ class TestWanCrossAttention: assert out.shape == (B, L_q, self.dim) def test_with_context_mask(self): - from mlx_video.models.wan.attention import WanCrossAttention + from mlx_video.models.wan2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) B, L_q, L_kv = 1, 12, 16 @@ -311,8 +311,8 @@ class TestBFloat16Autocast: def test_self_attn_casts_to_weight_dtype(self): """Self-attention should cast input to weight dtype for QKV projections.""" - from mlx_video.models.wan.attention import WanSelfAttention - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.attention import WanSelfAttention + from mlx_video.models.wan2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -326,7 +326,7 @@ class TestBFloat16Autocast: def test_cross_attn_casts_to_weight_dtype(self): """Cross-attention should cast input to weight dtype.""" - from mlx_video.models.wan.attention import WanCrossAttention + from mlx_video.models.wan2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -340,7 +340,7 @@ class TestBFloat16Autocast: def test_cross_attn_kv_cache_uses_weight_dtype(self): """prepare_kv should cast context to weight dtype.""" - from mlx_video.models.wan.attention import WanCrossAttention + from mlx_video.models.wan2.attention import WanCrossAttention attn = WanCrossAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -353,7 +353,7 @@ class TestBFloat16Autocast: def test_ffn_casts_to_weight_dtype(self): """FFN should cast input to weight dtype for linear layers.""" - from mlx_video.models.wan.transformer import WanFFN + from mlx_video.models.wan2.transformer import WanFFN ffn = WanFFN(self.dim, 128) ffn.update(self._to_bf16(ffn.parameters())) @@ -366,8 +366,8 @@ class TestBFloat16Autocast: def test_self_attn_rope_in_float32(self): """RoPE should be applied in float32 for precision, even with bf16 weights.""" - from mlx_video.models.wan.attention import WanSelfAttention - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.attention import WanSelfAttention + from mlx_video.models.wan2.rope import rope_params attn = WanSelfAttention(self.dim, self.num_heads) attn.update(self._to_bf16(attn.parameters())) @@ -381,8 +381,8 @@ class TestBFloat16Autocast: def test_block_float32_residual_with_bf16_weights(self): """Full block: residual stream stays float32, matmuls use bf16 weights.""" - from mlx_video.models.wan.rope import rope_params - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True) block.update(self._to_bf16(block.parameters())) diff --git a/tests/test_wan_config.py b/tests/test_wan_config.py index 2ffddcf..b37c722 100644 --- a/tests/test_wan_config.py +++ b/tests/test_wan_config.py @@ -10,7 +10,7 @@ class TestWanModelConfig: """Tests for WanModelConfig dataclass.""" def test_default_values(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig() assert config.dim == 5120 @@ -32,13 +32,13 @@ class TestWanModelConfig: assert config.text_len == 512 def test_head_dim_property(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig() assert config.head_dim == 128 # 5120 // 40 def test_to_dict_roundtrip(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig() d = config.to_dict() @@ -48,7 +48,7 @@ class TestWanModelConfig: assert d["boundary"] == 0.875 def test_t5_config_values(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig() assert config.t5_vocab_size == 256384 @@ -69,7 +69,7 @@ class TestWan21Config: """Tests for Wan2.1 config presets.""" def test_wan21_14b_factory(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() assert config.model_version == "2.1" @@ -85,7 +85,7 @@ class TestWan21Config: assert config.boundary == 0.0 def test_wan21_1_3b_factory(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() assert config.model_version == "2.1" @@ -98,7 +98,7 @@ class TestWan21Config: assert config.sample_guide_scale == 5.0 def test_wan22_14b_factory(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan22_t2v_14b() assert config.model_version == "2.2" @@ -110,7 +110,7 @@ class TestWan21Config: assert config.boundary == 0.875 def test_wan21_config_to_dict(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() d = config.to_dict() @@ -119,7 +119,7 @@ class TestWan21Config: assert d["sample_guide_scale"] == 5.0 def test_wan21_1_3b_config_to_dict(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() d = config.to_dict() @@ -128,7 +128,7 @@ class TestWan21Config: def test_default_config_is_wan22(self): """Default WanModelConfig() should be Wan2.2 14B.""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig() assert config.model_version == "2.2" diff --git a/tests/test_wan_convert.py b/tests/test_wan_convert.py index 69a8dd3..0e5e48d 100644 --- a/tests/test_wan_convert.py +++ b/tests/test_wan_convert.py @@ -11,7 +11,7 @@ import mlx.core as mx class TestSanitizeTransformerWeights: def test_patch_embedding_reshape(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), @@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights: assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2) def test_text_embedding_rename(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "text_embedding.0.weight": mx.zeros((64, 32)), @@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights: assert "text_embedding_1.bias" in out def test_time_embedding_rename(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "time_embedding.0.weight": mx.zeros((64, 32)), @@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights: assert "time_embedding_1.weight" in out def test_time_projection_rename(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "time_projection.1.weight": mx.zeros((384, 64)), @@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights: assert "time_projection.bias" in out def test_ffn_rename(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "blocks.0.ffn.0.weight": mx.zeros((128, 64)), @@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights: assert "blocks.0.ffn.fc2.bias" in out def test_freqs_skipped(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "freqs": mx.zeros((1024, 64, 2)), @@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights: assert "blocks.0.norm1.weight" in out def test_passthrough_keys(self): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "blocks.0.self_attn.q.weight": mx.zeros((64, 64)), @@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights: assert key in out def test_no_unconsumed_keys(self, caplog): - from mlx_video.convert_wan import sanitize_wan_transformer_weights + from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights weights = { "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), @@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights: "head.head.weight": mx.zeros((64, 64)), "freqs": mx.zeros((1024, 64, 2)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): sanitize_wan_transformer_weights(weights) assert "Unconsumed" not in caplog.text class TestSanitizeT5Weights: def test_gate_rename(self): - from mlx_video.convert_wan import sanitize_wan_t5_weights + from mlx_video.models.wan2.convert import sanitize_wan_t5_weights weights = { "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), @@ -139,7 +139,7 @@ class TestSanitizeT5Weights: assert "blocks.0.ffn.fc2.weight" in out def test_passthrough(self): - from mlx_video.convert_wan import sanitize_wan_t5_weights + from mlx_video.models.wan2.convert import sanitize_wan_t5_weights weights = { "token_embedding.weight": mx.zeros((100, 64)), @@ -151,7 +151,7 @@ class TestSanitizeT5Weights: assert key in out def test_no_unconsumed_keys(self, caplog): - from mlx_video.convert_wan import sanitize_wan_t5_weights + from mlx_video.models.wan2.convert import sanitize_wan_t5_weights weights = { "token_embedding.weight": mx.zeros((100, 64)), @@ -160,14 +160,14 @@ class TestSanitizeT5Weights: "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)), "norm.weight": mx.zeros((64,)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): sanitize_wan_t5_weights(weights) assert "Unconsumed" not in caplog.text class TestSanitizeVAEWeights: def test_conv3d_transpose(self): - from mlx_video.convert_wan import sanitize_wan_vae_weights + from mlx_video.models.wan2.convert import sanitize_wan_vae_weights weights = { "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W] @@ -176,7 +176,7 @@ class TestSanitizeVAEWeights: assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I] def test_conv2d_transpose(self): - from mlx_video.convert_wan import sanitize_wan_vae_weights + from mlx_video.models.wan2.convert import sanitize_wan_vae_weights weights = { "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W] @@ -185,7 +185,7 @@ class TestSanitizeVAEWeights: assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I] def test_non_conv_passthrough(self): - from mlx_video.convert_wan import sanitize_wan_vae_weights + from mlx_video.models.wan2.convert import sanitize_wan_vae_weights weights = { "decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose @@ -196,7 +196,7 @@ class TestSanitizeVAEWeights: assert out["decoder.bias"].shape == (16,) def test_mixed_weights(self): - from mlx_video.convert_wan import sanitize_wan_vae_weights + from mlx_video.models.wan2.convert import sanitize_wan_vae_weights weights = { "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D @@ -211,7 +211,7 @@ class TestSanitizeVAEWeights: assert out["norm.weight"].shape == (8,) def test_no_unconsumed_keys(self, caplog): - from mlx_video.convert_wan import sanitize_wan_vae_weights + from mlx_video.models.wan2.convert import sanitize_wan_vae_weights weights = { "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), @@ -219,7 +219,7 @@ class TestSanitizeVAEWeights: "decoder.norm.weight": mx.zeros((64,)), "decoder.bias": mx.zeros((16,)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"): sanitize_wan_vae_weights(weights) assert "Unconsumed" not in caplog.text @@ -256,7 +256,7 @@ class TestWan21Convert: def test_wan21_config_saved_correctly(self): """Verify config dict has correct fields for Wan2.1.""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() d = config.to_dict() @@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights: """Tests for sanitize_wan22_vae_weights with include_encoder.""" def test_exclude_encoder_by_default(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), @@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights: assert not any("encoder" in k or k.startswith("conv1") for k in out) def test_include_encoder(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), @@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights: assert "conv2.weight" in out def test_no_unconsumed_keys(self, caplog): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "conv1.weight": mx.zeros((8, 1, 1, 1, 8)), "conv2.weight": mx.zeros((8, 1, 1, 1, 8)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"): sanitize_wan22_vae_weights(weights, include_encoder=True) assert "Unconsumed" not in caplog.text def test_no_unconsumed_keys_exclude_encoder(self, caplog): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), "conv1.weight": mx.zeros((8, 1, 1, 1, 8)), "conv2.weight": mx.zeros((8, 1, 1, 1, 8)), } - with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"): + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"): sanitize_wan22_vae_weights(weights, include_encoder=False) assert "Unconsumed" not in caplog.text diff --git a/tests/test_wan_generate.py b/tests/test_wan_generate.py index e42713c..f4d1682 100644 --- a/tests/test_wan_generate.py +++ b/tests/test_wan_generate.py @@ -14,8 +14,8 @@ class TestEndToEnd: def test_tiny_model_denoise_step(self): """Simulate one denoising step with tiny model.""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(42) config = _make_tiny_config() @@ -43,8 +43,8 @@ class TestEndToEnd: def test_tiny_model_full_loop(self): """Run a complete (tiny) diffusion loop.""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(123) config = _make_tiny_config() @@ -81,7 +81,7 @@ class TestI2VMask: """Tests for _build_i2v_mask.""" def test_mask_shapes(self): - from mlx_video.generate_wan import _build_i2v_mask + from mlx_video.models.wan2.generate import _build_i2v_mask z_shape = (48, 5, 4, 4) # C, T, H, W patch_size = (1, 2, 2) @@ -91,7 +91,7 @@ class TestI2VMask: assert mask_tokens.shape == (1, 20) def test_first_frame_zero(self): - from mlx_video.generate_wan import _build_i2v_mask + from mlx_video.models.wan2.generate import _build_i2v_mask z_shape = (48, 5, 4, 4) mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2)) @@ -111,7 +111,7 @@ class TestI2VMaskAlignment: def test_mask_with_ti2v_dimensions(self): """Mask should work with TI2V-5B typical dimensions.""" - from mlx_video.generate_wan import _build_i2v_mask + from mlx_video.models.wan2.generate import _build_i2v_mask # TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2) # 704x1280 → latent 44x80, t_latent=21 for 81 frames @@ -132,7 +132,7 @@ class TestI2VMaskAlignment: def test_mask_per_token_timestep(self): """Per-token timesteps: first-frame tokens get t=0, rest get t=sigma.""" - from mlx_video.generate_wan import _build_i2v_mask + from mlx_video.models.wan2.generate import _build_i2v_mask z_shape = (4, 3, 4, 4) patch_size = (1, 2, 2) @@ -201,7 +201,7 @@ class TestDimensionAlignment: def test_patchify_valid_after_alignment(self): """After alignment, patchify should succeed without reshape errors.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -235,7 +235,7 @@ class TestDimensionAlignment: def test_alignment_with_ti2v_config(self): """TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32.""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan22_ti2v_5b() align_h = config.patch_size[1] * config.vae_stride[1] diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py index 112e7cc..b2a4bab 100644 --- a/tests/test_wan_i2v.py +++ b/tests/test_wan_i2v.py @@ -23,7 +23,7 @@ class TestI2VConfig: """Test I2V-14B config preset.""" def test_wan22_i2v_14b_preset(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan22_i2v_14b() assert config.model_type == "i2v" @@ -39,7 +39,7 @@ class TestI2VConfig: assert config.vae_z_dim == 16 def test_i2v_vs_t2v_differences(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig i2v = WanModelConfig.wan22_i2v_14b() t2v = WanModelConfig.wan22_t2v_14b() @@ -51,7 +51,7 @@ class TestI2VConfig: assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0 def test_i2v_serialization_roundtrip(self): - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan22_i2v_14b() d = config.to_dict() @@ -66,7 +66,7 @@ class TestModelYParameter: def test_forward_without_y(self): """Standard T2V forward pass (no y) still works.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -85,7 +85,7 @@ class TestModelYParameter: def test_forward_with_y(self): """I2V forward pass with y channel concatenation.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_i2v_config() model = WanModel(config) @@ -108,7 +108,7 @@ class TestModelYParameter: def test_y_none_is_noop(self): """Passing y=None should be identical to not passing y.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -129,7 +129,7 @@ class TestModelYParameter: def test_batched_cfg_with_y(self): """Batched CFG (B=2) with y should work.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_i2v_config() model = WanModel(config) @@ -158,7 +158,7 @@ class TestVAEEncoder: """Test Wan2.1 VAE encoder.""" def test_encoder3d_instantiation(self): - from mlx_video.models.wan.vae import Encoder3d + from mlx_video.models.wan2.vae import Encoder3d enc = Encoder3d( dim=32, z_dim=8 @@ -169,7 +169,7 @@ class TestVAEEncoder: def test_encoder3d_output_shape(self): """Encoder should downsample spatially by 8x and temporally by 4x.""" - from mlx_video.models.wan.vae import Encoder3d + from mlx_video.models.wan2.vae import Encoder3d enc = Encoder3d(dim=32, z_dim=8) # Random input: [B=1, 3, T=5, H=32, W=32] @@ -186,7 +186,7 @@ class TestVAEEncoder: def test_wan_vae_encode(self): """WanVAE with encoder=True should produce normalized latents.""" - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.vae import WanVAE vae = WanVAE(z_dim=16, encoder=True) # Input: [B=1, 3, T=5, H=32, W=32] @@ -198,7 +198,7 @@ class TestVAEEncoder: def test_wan_vae_encoder_flag(self): """WanVAE without encoder flag should not have encoder attribute.""" - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.vae import WanVAE vae_no_enc = WanVAE(z_dim=4, encoder=False) assert not hasattr(vae_no_enc, "encoder") @@ -211,7 +211,7 @@ class TestResampleDownsample: """Test downsample modes in Resample.""" def test_downsample2d(self): - from mlx_video.models.wan.vae import Resample + from mlx_video.models.wan2.vae import Resample r = Resample(dim=16, mode="downsample2d") x = mx.random.normal((1, 16, 2, 8, 8)) @@ -221,7 +221,7 @@ class TestResampleDownsample: assert out.shape == (1, 16, 2, 4, 4) def test_downsample3d(self): - from mlx_video.models.wan.vae import Resample + from mlx_video.models.wan2.vae import Resample r = Resample(dim=16, mode="downsample3d") x = mx.random.normal((1, 16, 4, 8, 8)) @@ -231,7 +231,7 @@ class TestResampleDownsample: assert out.shape == (1, 16, 2, 4, 4) def test_upsample2d_still_works(self): - from mlx_video.models.wan.vae import Resample + from mlx_video.models.wan2.vae import Resample r = Resample(dim=16, mode="upsample2d") x = mx.random.normal((1, 16, 2, 4, 4)) @@ -240,7 +240,7 @@ class TestResampleDownsample: assert out.shape == (1, 8, 2, 8, 8) def test_upsample3d_still_works(self): - from mlx_video.models.wan.vae import Resample + from mlx_video.models.wan2.vae import Resample r = Resample(dim=16, mode="upsample3d") x = mx.random.normal((1, 16, 2, 4, 4)) @@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline: def test_full_i2v_pipeline(self): """End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode.""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.vae import WanVAE mx.random.seed(0) @@ -410,8 +410,8 @@ class TestDualModelSwitching: def test_model_selection_by_timestep(self): """Verify high_noise model used for timesteps >= boundary, low_noise otherwise.""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(1) config = _make_tiny_i2v_config() @@ -485,8 +485,8 @@ class TestDualModelSwitching: def test_guide_scale_tuple_applied_per_model(self): """Verify (low_gs, high_gs) tuple applies different scales per model.""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(2) config = _make_tiny_i2v_config() @@ -545,8 +545,8 @@ class TestDualModelSwitching: def test_single_model_fallback_with_tuple_guide_scale(self): """When dual_model=False, guide_scale tuple should use first element.""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(3) config = _make_tiny_config() diff --git a/tests/test_wan_lora.py b/tests/test_wan_lora.py index 7dc8c4b..1c4b84c 100644 --- a/tests/test_wan_lora.py +++ b/tests/test_wan_lora.py @@ -331,7 +331,7 @@ class TestEndToEnd: """End-to-end LoRA loading and application.""" def test_load_and_apply_loras(self): - from mlx_video.convert_wan import load_and_apply_loras + from mlx_video.models.wan2.convert import load_and_apply_loras with tempfile.TemporaryDirectory() as tmp: # Create mock LoRA safetensors diff --git a/tests/test_wan_model.py b/tests/test_wan_model.py index 96c564a..650e0e5 100644 --- a/tests/test_wan_model.py +++ b/tests/test_wan_model.py @@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config class TestSinusoidalEmbedding: def test_output_shape(self): - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d pos = mx.arange(10).astype(mx.float32) emb = sinusoidal_embedding_1d(256, pos) @@ -21,7 +21,7 @@ class TestSinusoidalEmbedding: def test_position_zero(self): """Position 0 should have cos=1 for all dims and sin=0.""" - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d pos = mx.array([0.0]) emb = sinusoidal_embedding_1d(64, pos) @@ -33,7 +33,7 @@ class TestSinusoidalEmbedding: np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5) def test_different_positions_differ(self): - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 999.0]) emb = sinusoidal_embedding_1d(128, pos) @@ -50,7 +50,7 @@ class TestSinusoidalEmbedding: class TestHead: def test_output_shape(self): - from mlx_video.models.wan.model import Head + from mlx_video.models.wan2.model import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) B, L = 1, 24 @@ -62,7 +62,7 @@ class TestHead: assert out.shape == (B, L, expected_proj_dim) def test_modulation_shape(self): - from mlx_video.models.wan.model import Head + from mlx_video.models.wan2.model import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) assert head.modulation.shape == (1, 2, 64) @@ -78,7 +78,7 @@ class TestWanModel: mx.random.seed(42) def test_instantiation(self): - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -86,7 +86,7 @@ class TestWanModel: assert num_params > 0 def test_patchify_shape(self): - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -99,7 +99,7 @@ class TestWanModel: assert patches.shape == (1, 1 * 2 * 2, config.dim) def test_patchify_various_sizes(self): - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -115,7 +115,7 @@ class TestWanModel: def test_unpatchify_inverse(self): """Patchify then unpatchify should reconstruct original spatial dims.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -131,7 +131,7 @@ class TestWanModel: assert out[0].shape == (config.out_dim, F, H, W) def test_forward_pass(self): - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -149,7 +149,7 @@ class TestWanModel: assert out[0].shape == (C, F, H, W) def test_forward_batch(self): - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -171,7 +171,7 @@ class TestWanModel: assert o.shape == (C, F, H, W) def test_output_is_float32(self): - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -200,7 +200,7 @@ class TestWan21Model: def _make_tiny_wan21_config(self): """Create a tiny config mimicking Wan2.1 (single model).""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() # Override to tiny values @@ -217,7 +217,7 @@ class TestWan21Model: def _make_tiny_wan21_1_3b_config(self): """Create a tiny config mimicking Wan2.1 1.3B.""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() # Override to tiny values (preserve 1.3B head structure: 12 heads) @@ -234,7 +234,7 @@ class TestWan21Model: def test_wan21_tiny_model_forward(self): """Forward pass with Wan2.1 tiny config.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = self._make_tiny_wan21_config() model = WanModel(config) @@ -252,7 +252,7 @@ class TestWan21Model: def test_wan21_1_3b_tiny_model_forward(self): """Forward pass with Wan2.1 1.3B tiny config.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = self._make_tiny_wan21_1_3b_config() model = WanModel(config) @@ -270,8 +270,8 @@ class TestWan21Model: def test_wan21_single_model_loop(self): """Full diffusion loop with single model (Wan2.1 style).""" - from mlx_video.models.wan.model import WanModel - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler config = self._make_tiny_wan21_config() model = WanModel(config) @@ -305,7 +305,7 @@ class TestWan21Model: def test_wan21_vs_wan22_config_differences(self): """Verify key differences between Wan2.1 and Wan2.2 configs.""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig c21 = WanModelConfig.wan21_t2v_14b() c22 = WanModelConfig.wan22_t2v_14b() @@ -333,21 +333,21 @@ class TestPerTokenTimestep: """Tests for per-token sinusoidal embedding.""" def test_1d_unchanged(self): - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 500.0]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (3, 256) def test_2d_per_token(self): - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (2, 3, 256) def test_consistency(self): - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d pos_1d = mx.array([0.0, 100.0]) emb_1d = sinusoidal_embedding_1d(256, pos_1d) diff --git a/tests/test_wan_quantization.py b/tests/test_wan_quantization.py index 5ec7355..1eb9622 100644 --- a/tests/test_wan_quantization.py +++ b/tests/test_wan_quantization.py @@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config class TestQuantizePredicate: def test_matches_self_attention_layers(self): - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) for suffix in ["q", "k", "v", "o"]: @@ -23,7 +23,7 @@ class TestQuantizePredicate: assert _quantize_predicate(path, mock_linear), f"Should match {path}" def test_matches_cross_attention_layers(self): - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) for suffix in ["q", "k", "v", "o"]: @@ -31,14 +31,14 @@ class TestQuantizePredicate: assert _quantize_predicate(path, mock_linear), f"Should match {path}" def test_matches_ffn_layers(self): - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear) assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear) def test_rejects_embeddings(self): - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) for path in [ @@ -49,13 +49,13 @@ class TestQuantizePredicate: assert not _quantize_predicate(path, mock_linear), f"Should reject {path}" def test_rejects_norms(self): - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_norm = nn.RMSNorm(64) assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm) def test_rejects_non_quantizable_modules(self): - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_norm = nn.RMSNorm(64) # Even if path matches, module must have to_quantized @@ -63,7 +63,7 @@ class TestQuantizePredicate: def test_all_10_patterns_covered(self): """Verify exactly 10 layer patterns are targeted.""" - from mlx_video.convert_wan import _quantize_predicate + from mlx_video.models.wan2.convert import _quantize_predicate mock_linear = nn.Linear(64, 64) patterns = [ @@ -90,8 +90,8 @@ class TestQuantizePredicate: class TestQuantizeRoundTrip: def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): """Helper: create model, quantize, save to tmp_path.""" - from mlx_video.convert_wan import _quantize_predicate - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan2.model import WanModel model = WanModel(config) nn.quantize( @@ -116,7 +116,7 @@ class TestQuantizeRoundTrip: config = _make_tiny_config() model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4) - from mlx_video.models.wan.loading import load_wan_model + from mlx_video.models.wan2.utils import load_wan_model loaded = load_wan_model( model_path, @@ -136,7 +136,7 @@ class TestQuantizeRoundTrip: config = _make_tiny_config() model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8) - from mlx_video.models.wan.loading import load_wan_model + from mlx_video.models.wan2.utils import load_wan_model loaded = load_wan_model( model_path, @@ -151,7 +151,7 @@ class TestQuantizeRoundTrip: config = _make_tiny_config() model_path, _ = self._quantize_and_save(config, tmp_path, bits=4) - from mlx_video.models.wan.loading import load_wan_model + from mlx_video.models.wan2.utils import load_wan_model loaded = load_wan_model( model_path, @@ -164,7 +164,7 @@ class TestQuantizeRoundTrip: def test_loading_without_quantization_flag(self, tmp_path): """Loading a non-quantized model should have standard Linear layers.""" - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -172,7 +172,7 @@ class TestQuantizeRoundTrip: model_path = tmp_path / "model.safetensors" mx.save_safetensors(str(model_path), weights_dict) - from mlx_video.models.wan.loading import load_wan_model + from mlx_video.models.wan2.utils import load_wan_model loaded = load_wan_model(model_path, config, quantization=None) @@ -187,8 +187,8 @@ class TestQuantizeRoundTrip: class TestQuantizedInference: def _make_quantized_model(self, config, bits=4): - from mlx_video.convert_wan import _quantize_predicate - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan2.model import WanModel model = WanModel(config) nn.quantize( @@ -238,8 +238,8 @@ class TestQuantizedInference: def test_quantized_output_differs_from_unquantized(self): """Sanity check: quantization should change the weights.""" - from mlx_video.convert_wan import _quantize_predicate - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.convert import _quantize_predicate + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() mx.random.seed(42) @@ -271,8 +271,8 @@ class TestQuantizedInference: class TestQuantizationConfig: def test_config_metadata_written(self, tmp_path): """Verify _quantize_saved_model writes quantization metadata to config.json.""" - from mlx_video.convert_wan import _quantize_saved_model - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.convert import _quantize_saved_model + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -295,8 +295,8 @@ class TestQuantizationConfig: assert cfg["quantization"]["group_size"] == 64 def test_config_metadata_8bit(self, tmp_path): - from mlx_video.convert_wan import _quantize_saved_model - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.convert import _quantize_saved_model + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() model = WanModel(config) @@ -316,8 +316,8 @@ class TestQuantizationConfig: def test_dual_model_quantization(self, tmp_path): """Verify dual-model quantization writes both model files.""" - from mlx_video.convert_wan import _quantize_saved_model - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.convert import _quantize_saved_model + from mlx_video.models.wan2.model import WanModel config = _make_tiny_config() diff --git a/tests/test_wan_rope_freqs.py b/tests/test_wan_rope_freqs.py index b37d7b0..5da2a5f 100644 --- a/tests/test_wan_rope_freqs.py +++ b/tests/test_wan_rope_freqs.py @@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction: def _get_model_freqs(self, dim=64, num_heads=4): """Instantiate a tiny WanModel and return its .freqs tensor.""" - from mlx_video.models.wan.config import WanModelConfig - from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan2.config import WanModelConfig + from mlx_video.models.wan2.model import WanModel config = WanModelConfig() config.dim = dim @@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction: def test_three_call_vs_single_call_differ(self): """Three separate rope_params calls must differ from single call.""" - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params d = 128 # head_dim for all Wan models # Reference: three separate calls @@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction: This verifies each axis gets its own independent frequency range starting from theta^0 = 1.0 (i.e., exponent 0/dim). """ - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params d = 128 freqs = mx.concatenate( @@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction: Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42). """ - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params d = 128 d_h_dim = 2 * (d // 6) # 42 @@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction: axis should be 1.0 (theta^0). A single-call approach would give height starting at ~0.04 and width at ~0.002 instead of 1.0. """ - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params d = 128 freqs = mx.concatenate( @@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction: def test_model_freqs_match_manual_construction(self): """WanModel.freqs should match manually constructed three-call freqs.""" - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4) d = head_dim # 16 @@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction: def test_model_freqs_14b_dimensions(self): """Verify freq dimensions for 14B-scale head_dim=128.""" - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params d = 128 freqs = mx.concatenate( @@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference: """Numerically compare MLX and PyTorch frequency tables.""" import torch - from mlx_video.models.wan.rope import rope_params + from mlx_video.models.wan2.rope import rope_params d = 128 @@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs: This is the key property that was broken by the single-call bug: height/width frequencies were too low to distinguish nearby positions. """ - from mlx_video.models.wan.rope import rope_apply, rope_params + from mlx_video.models.wan2.rope import rope_apply, rope_params d = 128 freqs = mx.concatenate( @@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs: def test_precomputed_matches_online(self): """rope_precompute_cos_sin + rope_apply should match non-precomputed path.""" - from mlx_video.models.wan.rope import ( + from mlx_video.models.wan2.rope import ( rope_apply, rope_params, rope_precompute_cos_sin, diff --git a/tests/test_wan_scheduler.py b/tests/test_wan_scheduler.py index 19cdcd7..df5405c 100644 --- a/tests/test_wan_scheduler.py +++ b/tests/test_wan_scheduler.py @@ -13,7 +13,7 @@ import pytest class TestFlowMatchEulerScheduler: def test_initialization(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() assert sched.num_train_timesteps == 1000 @@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler: assert sched.sigmas is None def test_set_timesteps(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(40, shift=12.0) @@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler: assert sched.sigmas.shape == (41,) # 40 steps + terminal def test_timesteps_decreasing(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(40, shift=12.0) @@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler: assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..." def test_sigmas_decreasing(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(20, shift=1.0) @@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler: assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing" def test_terminal_sigma_is_zero(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(20, shift=5.0) @@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler: def test_shift_effect(self): """Larger shift should push sigmas toward higher values.""" - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched1 = FlowMatchEulerScheduler() sched2 = FlowMatchEulerScheduler() @@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler: assert mean2 > mean1, "Higher shift should push sigmas higher" def test_step_euler(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(10, shift=1.0) @@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler: ) def test_step_index_increments(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) @@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler: assert sched._step_index == 2 def test_reset(self): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) @@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler: @pytest.mark.parametrize("steps", [10, 20, 40, 50]) def test_various_step_counts(self, steps): - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(steps, shift=12.0) @@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler: def test_full_denoise_loop(self): """Run a complete denoise loop with zero velocity -> sample unchanged.""" - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler sched = FlowMatchEulerScheduler() sched.set_timesteps(5, shift=1.0) @@ -153,26 +153,26 @@ class TestComputeSigmas: """Tests for the shared _compute_sigmas helper.""" def test_length(self): - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas sigmas = _compute_sigmas(20, shift=5.0) assert len(sigmas) == 21 # num_steps + terminal def test_terminal_zero(self): - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas sigmas = _compute_sigmas(10, shift=1.0) assert sigmas[-1] == 0.0 def test_starts_near_one(self): - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas sigmas = _compute_sigmas(20, shift=5.0) # Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0) np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3) def test_decreasing(self): - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas sigmas = _compute_sigmas(20, shift=5.0) assert np.all(np.diff(sigmas) <= 0) @@ -185,7 +185,7 @@ class TestComputeSigmas: sigma_max/sigma_min come from the *unshifted* training schedule, and the shift is applied only once (single-shift). """ - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas steps, shift, N = 50, 5.0, 1000 sigmas = _compute_sigmas(steps, shift, N) @@ -200,7 +200,7 @@ class TestComputeSigmas: np.testing.assert_allclose(sigmas, official, atol=1e-6) def test_shift_one_is_near_linear(self): - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas sigmas = _compute_sigmas(10, shift=1.0) # With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule) @@ -210,7 +210,7 @@ class TestComputeSigmas: def test_all_schedulers_same_sigmas(self): """All three schedulers should produce identical sigma schedules.""" - from mlx_video.models.wan.scheduler import ( + from mlx_video.models.wan2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -229,7 +229,7 @@ class TestComputeSigmas: np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6) def test_all_schedulers_same_timesteps(self): - from mlx_video.models.wan.scheduler import ( + from mlx_video.models.wan2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -255,14 +255,14 @@ class TestComputeSigmas: class TestFlowDPMPP2MScheduler: def test_initialization(self): - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() assert sched.num_train_timesteps == 1000 assert sched.lower_order_final is True def test_set_timesteps(self): - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(20, shift=5.0) @@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler: assert sched.sigmas.shape == (21,) def test_step_index_increments(self): - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) @@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler: assert sched._step_index == 2 def test_reset(self): - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) @@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler: def test_full_loop_finite(self): """Full loop with constant velocity should produce finite output.""" - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=1.0) @@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler: def test_first_step_is_first_order(self): """First step should use 1st-order (no prev_x0 available).""" - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=5.0) @@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler: def test_second_step_uses_correction(self): """After first step, DPM++ should have stored prev_x0 for correction.""" - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(10, shift=5.0) @@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler: def test_denoise_to_target(self): """Perfect oracle should denoise to target with any solver.""" - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(20, shift=5.0) @@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler: @pytest.mark.parametrize("steps", [5, 10, 20, 50]) def test_various_step_counts(self, steps): - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(steps, shift=5.0) @@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler: def test_terminal_sigma_produces_x0(self): """When sigma_next=0 the scheduler should return x0 directly.""" - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sched = FlowDPMPP2MScheduler() sched.set_timesteps(5, shift=1.0) @@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler: class TestFlowUniPCScheduler: def test_initialization(self): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() assert sched.num_train_timesteps == 1000 @@ -402,7 +402,7 @@ class TestFlowUniPCScheduler: assert sched.lower_order_final is True def test_set_timesteps(self): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(30, shift=12.0) @@ -411,7 +411,7 @@ class TestFlowUniPCScheduler: assert sched.sigmas.shape == (31,) def test_step_index_increments(self): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(5, shift=1.0) @@ -422,7 +422,7 @@ class TestFlowUniPCScheduler: assert sched._step_index == 1 def test_reset(self): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(5, shift=1.0) @@ -435,7 +435,7 @@ class TestFlowUniPCScheduler: assert all(m is None for m in sched._model_outputs) def test_full_loop_finite(self): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(10, shift=1.0) @@ -448,7 +448,7 @@ class TestFlowUniPCScheduler: def test_corrector_not_applied_first_step(self): """First step should skip the corrector (no history).""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(use_corrector=True) sched.set_timesteps(10, shift=5.0) @@ -462,7 +462,7 @@ class TestFlowUniPCScheduler: def test_corrector_applied_after_first_step(self): """Steps after the first should use the corrector when enabled.""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(use_corrector=True) sched.set_timesteps(10, shift=5.0) @@ -475,7 +475,7 @@ class TestFlowUniPCScheduler: assert sched._lower_order_nums >= 2 def test_denoise_to_target(self): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(20, shift=5.0) @@ -490,7 +490,7 @@ class TestFlowUniPCScheduler: @pytest.mark.parametrize("steps", [5, 10, 20, 50]) def test_various_step_counts(self, steps): - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() sched.set_timesteps(steps, shift=5.0) @@ -500,7 +500,7 @@ class TestFlowUniPCScheduler: def test_disable_corrector(self): """Disabling corrector on step 0 should still work without error.""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0]) sched.set_timesteps(5, shift=1.0) @@ -513,7 +513,7 @@ class TestFlowUniPCScheduler: def test_solver_order_3(self): """Order 3 should work without error.""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler(solver_order=3, use_corrector=True) sched.set_timesteps(10, shift=5.0) @@ -531,7 +531,7 @@ class TestFlowUniPCScheduler: # For 50-step schedule with shift=5.0, order 2 corrector at step 5: # rhos_c[0] (history) should be ~0.07, NOT 0.5 # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5 - from mlx_video.models.wan.scheduler import _compute_sigmas + from mlx_video.models.wan2.scheduler import _compute_sigmas sigmas = _compute_sigmas(50, shift=5.0) @@ -597,7 +597,7 @@ class TestSchedulerCoherence: @staticmethod def _make_schedulers(steps=10, shift=5.0): - from mlx_video.models.wan.scheduler import ( + from mlx_video.models.wan2.scheduler import ( FlowDPMPP2MScheduler, FlowMatchEulerScheduler, FlowUniPCScheduler, @@ -780,7 +780,7 @@ class TestSchedulerCoherence: def test_lambda_boundary_values(self): """_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0.""" - from mlx_video.models.wan.scheduler import ( + from mlx_video.models.wan2.scheduler import ( FlowDPMPP2MScheduler, FlowUniPCScheduler, ) @@ -800,7 +800,7 @@ class TestSchedulerCoherence: def test_lambda_monotonically_decreasing(self): """_lambda(sigma) should decrease as sigma increases (more noise → lower SNR).""" - from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas] @@ -902,7 +902,7 @@ class TestSchedulerCoherence: shape = (1, 2, 1, 2, 2) noise = mx.random.normal(shape) - from mlx_video.models.wan.scheduler import ( + from mlx_video.models.wan2.scheduler import ( FlowDPMPP2MScheduler, FlowUniPCScheduler, ) @@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault: def test_corrector_enabled_by_default(self): """Default construction should have corrector enabled.""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler sched = FlowUniPCScheduler() assert sched._use_corrector is True def test_corrector_affects_output(self): """Corrector should produce different results than no corrector after step 1.""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler mx.random.seed(42) shape = (1, 4, 1, 4, 4) @@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault: def test_corrector_does_not_affect_first_step(self): """Step 0 should be identical regardless of corrector setting.""" - from mlx_video.models.wan.scheduler import FlowUniPCScheduler + from mlx_video.models.wan2.scheduler import FlowUniPCScheduler mx.random.seed(42) shape = (1, 4, 1, 4, 4) diff --git a/tests/test_wan_t5.py b/tests/test_wan_t5.py index 7bf0c18..df103f7 100644 --- a/tests/test_wan_t5.py +++ b/tests/test_wan_t5.py @@ -11,7 +11,7 @@ import numpy as np class TestT5LayerNorm: def test_output_shape(self): - from mlx_video.models.wan.text_encoder import T5LayerNorm + from mlx_video.models.wan2.text_encoder import T5LayerNorm norm = T5LayerNorm(64) x = mx.random.normal((2, 10, 64)) @@ -21,7 +21,7 @@ class TestT5LayerNorm: def test_rms_normalization(self): """After T5LayerNorm with weight=1, RMS should be ~1.""" - from mlx_video.models.wan.text_encoder import T5LayerNorm + from mlx_video.models.wan2.text_encoder import T5LayerNorm norm = T5LayerNorm(128) x = mx.random.normal((1, 5, 128)) * 5.0 @@ -35,7 +35,7 @@ class TestT5LayerNorm: class TestT5RelativeEmbedding: def test_output_shape(self): - from mlx_video.models.wan.text_encoder import T5RelativeEmbedding + from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) out = rel_emb(10, 10) @@ -43,7 +43,7 @@ class TestT5RelativeEmbedding: assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk] def test_asymmetric_lengths(self): - from mlx_video.models.wan.text_encoder import T5RelativeEmbedding + from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4) out = rel_emb(8, 12) @@ -52,7 +52,7 @@ class TestT5RelativeEmbedding: def test_symmetry(self): """Position bias should have structure (not all zeros/random).""" - from mlx_video.models.wan.text_encoder import T5RelativeEmbedding + from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2) out = rel_emb(6, 6) @@ -67,7 +67,7 @@ class TestT5RelativeEmbedding: class TestT5Attention: def test_output_shape(self): - from mlx_video.models.wan.text_encoder import T5Attention + from mlx_video.models.wan2.text_encoder import T5Attention attn = T5Attention(dim=64, dim_attn=64, num_heads=4) x = mx.random.normal((1, 10, 64)) @@ -77,14 +77,14 @@ class TestT5Attention: def test_no_scaling(self): """T5 attention famously has no sqrt(d) scaling. Verify structure.""" - from mlx_video.models.wan.text_encoder import T5Attention + from mlx_video.models.wan2.text_encoder import T5Attention attn = T5Attention(dim=64, dim_attn=64, num_heads=4) # No scale attribute (unlike standard attention) assert not hasattr(attn, "scale") def test_with_position_bias(self): - from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding + from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding attn = T5Attention(dim=64, dim_attn=64, num_heads=4) rel_emb = T5RelativeEmbedding(32, 4) @@ -95,7 +95,7 @@ class TestT5Attention: assert out.shape == (1, 10, 64) def test_with_mask(self): - from mlx_video.models.wan.text_encoder import T5Attention + from mlx_video.models.wan2.text_encoder import T5Attention attn = T5Attention(dim=64, dim_attn=64, num_heads=4) x = mx.random.normal((1, 10, 64)) @@ -108,7 +108,7 @@ class TestT5Attention: class TestT5FeedForward: def test_output_shape(self): - from mlx_video.models.wan.text_encoder import T5FeedForward + from mlx_video.models.wan2.text_encoder import T5FeedForward ffn = T5FeedForward(64, 256) x = mx.random.normal((1, 10, 64)) @@ -118,7 +118,7 @@ class TestT5FeedForward: def test_gated_structure(self): """T5 FFN is gated: gate(x) * fc1(x).""" - from mlx_video.models.wan.text_encoder import T5FeedForward + from mlx_video.models.wan2.text_encoder import T5FeedForward ffn = T5FeedForward(32, 64) assert hasattr(ffn, "gate_proj") @@ -131,7 +131,7 @@ class TestT5Encoder: mx.random.seed(42) def test_output_shape(self): - from mlx_video.models.wan.text_encoder import T5Encoder + from mlx_video.models.wan2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -150,7 +150,7 @@ class TestT5Encoder: assert out.shape == (1, 5, 64) def test_shared_pos(self): - from mlx_video.models.wan.text_encoder import T5Encoder + from mlx_video.models.wan2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -167,7 +167,7 @@ class TestT5Encoder: assert block.pos_embedding is None def test_per_layer_pos(self): - from mlx_video.models.wan.text_encoder import T5Encoder + from mlx_video.models.wan2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -184,7 +184,7 @@ class TestT5Encoder: assert block.pos_embedding is not None def test_param_count(self): - from mlx_video.models.wan.text_encoder import T5Encoder + from mlx_video.models.wan2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, @@ -200,7 +200,7 @@ class TestT5Encoder: assert num_params > 0 def test_without_mask(self): - from mlx_video.models.wan.text_encoder import T5Encoder + from mlx_video.models.wan2.text_encoder import T5Encoder encoder = T5Encoder( vocab_size=100, diff --git a/tests/test_wan_tiling.py b/tests/test_wan_tiling.py index 303f048..e55baac 100644 --- a/tests/test_wan_tiling.py +++ b/tests/test_wan_tiling.py @@ -3,7 +3,7 @@ import mlx.core as mx import numpy as np -from mlx_video.models.ltx.video_vae.tiling import ( +from mlx_video.models.ltx_2.video_vae.tiling import ( TilingConfig, decode_with_tiling, split_in_spatial, @@ -75,7 +75,7 @@ class TestWan22TiledDecoding: def _make_small_wan22_decoder(self): """Create a small Wan2.2 decoder for testing.""" - from mlx_video.models.wan.vae22 import Wan22VAEDecoder + from mlx_video.models.wan2.vae22 import Wan22VAEDecoder # Use very small dimensions for fast testing vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16) @@ -139,7 +139,7 @@ class TestWan21TiledDecoding: def _make_small_wan21_vae(self): """Create a small Wan2.1 VAE for testing.""" - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.vae import WanVAE vae = WanVAE(z_dim=16) mx.eval(vae.parameters()) @@ -192,7 +192,7 @@ class TestWan21TemporalScale: def test_wan21_decoder_temporal_output(self): """Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling).""" - from mlx_video.models.wan.vae import Decoder3d + from mlx_video.models.wan2.vae import Decoder3d # Small decoder for fast test dec = Decoder3d( diff --git a/tests/test_wan_transformer.py b/tests/test_wan_transformer.py index 8cbfb67..7d197c2 100644 --- a/tests/test_wan_transformer.py +++ b/tests/test_wan_transformer.py @@ -10,7 +10,7 @@ import numpy as np class TestWanFFN: def test_output_shape(self): - from mlx_video.models.wan.transformer import WanFFN + from mlx_video.models.wan2.transformer import WanFFN ffn = WanFFN(64, 256) x = mx.random.normal((2, 10, 64)) @@ -20,7 +20,7 @@ class TestWanFFN: def test_gelu_activation(self): """FFN should use GELU activation (non-linearity).""" - from mlx_video.models.wan.transformer import WanFFN + from mlx_video.models.wan2.transformer import WanFFN ffn = WanFFN(32, 128) x = mx.ones((1, 1, 32)) * 2.0 @@ -40,8 +40,8 @@ class TestWanAttentionBlock: self.num_heads = 4 def test_output_shape(self): - from mlx_video.models.wan.rope import rope_params - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock( self.dim, @@ -68,13 +68,13 @@ class TestWanAttentionBlock: assert out.shape == (B, L, self.dim) def test_modulation_shape(self): - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) assert block.modulation.shape == (1, 6, self.dim) def test_with_cross_attn_norm(self): - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock( self.dim, @@ -85,7 +85,7 @@ class TestWanAttentionBlock: assert block.norm3 is not None def test_without_cross_attn_norm(self): - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock( self.dim, @@ -97,8 +97,8 @@ class TestWanAttentionBlock: def test_residual_connection(self): """Output should differ from zero even with small random init.""" - from mlx_video.models.wan.rope import rope_params - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads) B, L = 1, 8 @@ -129,15 +129,15 @@ class TestFloat32Modulation: def test_block_modulation_in_float32(self): """Modulation param starts random but should be usable as float32.""" - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True) assert block.modulation.dtype == mx.float32 def test_block_output_float32_with_bf16_modulation_input(self): """Even if e (time embedding) arrives as bf16, modulation should cast to f32.""" - from mlx_video.models.wan.rope import rope_params - from mlx_video.models.wan.transformer import WanAttentionBlock + from mlx_video.models.wan2.rope import rope_params + from mlx_video.models.wan2.transformer import WanAttentionBlock block = WanAttentionBlock(self.dim, 128, 4) B, L = 1, 8 @@ -153,7 +153,7 @@ class TestFloat32Modulation: def test_head_modulation_float32(self): """Head modulation should be float32 even with bf16 e input.""" - from mlx_video.models.wan.model import Head + from mlx_video.models.wan2.model import Head head = Head(self.dim, 4, (1, 2, 2)) x = mx.random.normal((1, 8, self.dim)) @@ -164,7 +164,7 @@ class TestFloat32Modulation: def test_model_time_embedding_float32(self): """sinusoidal_embedding_1d output must be float32.""" - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d t = mx.array([500.0]) emb = sinusoidal_embedding_1d(256, t) @@ -173,7 +173,7 @@ class TestFloat32Modulation: def test_model_per_token_time_embedding_float32(self): """Per-token time embeddings (I2V) should also be float32.""" - from mlx_video.models.wan.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.model import sinusoidal_embedding_1d t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4] emb = sinusoidal_embedding_1d(256, t) diff --git a/tests/test_wan_vae.py b/tests/test_wan_vae.py index c604e74..85c8381 100644 --- a/tests/test_wan_vae.py +++ b/tests/test_wan_vae.py @@ -12,7 +12,7 @@ import numpy as np class TestCausalConv3d: def test_output_shape_stride1(self): - from mlx_video.models.wan.vae import CausalConv3d + from mlx_video.models.wan2.vae import CausalConv3d conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1) # Initialize weights @@ -28,7 +28,7 @@ class TestCausalConv3d: assert out.shape[4] == 8 # W preserved def test_output_shape_kernel1(self): - from mlx_video.models.wan.vae import CausalConv3d + from mlx_video.models.wan2.vae import CausalConv3d conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0) conv.weight = mx.random.normal(conv.weight.shape) * 0.02 @@ -39,7 +39,7 @@ class TestCausalConv3d: def test_causal_padding(self): """Causal conv should only use past/current frames, not future.""" - from mlx_video.models.wan.vae import CausalConv3d + from mlx_video.models.wan2.vae import CausalConv3d conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1) conv.weight = mx.random.normal(conv.weight.shape) * 0.1 @@ -56,7 +56,7 @@ class TestCausalConv3d: class TestResidualBlock: def test_same_dim(self): - from mlx_video.models.wan.vae import ResidualBlock + from mlx_video.models.wan2.vae import ResidualBlock block = ResidualBlock(8, 8) x = mx.random.normal((1, 8, 2, 4, 4)) @@ -65,7 +65,7 @@ class TestResidualBlock: assert out.shape == (1, 8, 2, 4, 4) def test_different_dim(self): - from mlx_video.models.wan.vae import ResidualBlock + from mlx_video.models.wan2.vae import ResidualBlock block = ResidualBlock(8, 16) x = mx.random.normal((1, 8, 2, 4, 4)) @@ -74,13 +74,13 @@ class TestResidualBlock: assert out.shape == (1, 16, 2, 4, 4) def test_shortcut_exists_when_dims_differ(self): - from mlx_video.models.wan.vae import ResidualBlock + from mlx_video.models.wan2.vae import ResidualBlock block = ResidualBlock(8, 16) assert block.shortcut is not None def test_no_shortcut_when_dims_same(self): - from mlx_video.models.wan.vae import ResidualBlock + from mlx_video.models.wan2.vae import ResidualBlock block = ResidualBlock(8, 8) assert block.shortcut is None @@ -88,7 +88,7 @@ class TestResidualBlock: class TestAttentionBlock: def test_output_shape(self): - from mlx_video.models.wan.vae import AttentionBlock + from mlx_video.models.wan2.vae import AttentionBlock block = AttentionBlock(8) x = mx.random.normal((1, 8, 2, 4, 4)) @@ -97,7 +97,7 @@ class TestAttentionBlock: assert out.shape == (1, 8, 2, 4, 4) def test_residual_connection(self): - from mlx_video.models.wan.vae import AttentionBlock + from mlx_video.models.wan2.vae import AttentionBlock block = AttentionBlock(8) x = mx.random.normal((1, 8, 1, 3, 3)) @@ -109,7 +109,7 @@ class TestAttentionBlock: class TestWanVAE: def test_instantiation(self): - from mlx_video.models.wan.vae import WanVAE + from mlx_video.models.wan2.vae import WanVAE vae = WanVAE(z_dim=16) assert vae.z_dim == 16 @@ -117,7 +117,7 @@ class TestWanVAE: assert vae.std.shape == (16,) def test_normalization_stats(self): - from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD + from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD assert len(VAE_MEAN) == 16 assert len(VAE_STD) == 16 @@ -133,7 +133,7 @@ class TestVAE22CausalConv3d: """Tests for vae22.CausalConv3d (channels-last).""" def test_output_shape_k3(self): - from mlx_video.models.wan.vae22 import CausalConv3d + from mlx_video.models.wan2.vae22 import CausalConv3d conv = CausalConv3d(8, 16, kernel_size=3, padding=1) x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C] @@ -142,7 +142,7 @@ class TestVAE22CausalConv3d: assert out.shape == (1, 4, 8, 8, 16) def test_output_shape_k1(self): - from mlx_video.models.wan.vae22 import CausalConv3d + from mlx_video.models.wan2.vae22 import CausalConv3d conv = CausalConv3d(8, 16, kernel_size=1) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -152,7 +152,7 @@ class TestVAE22CausalConv3d: def test_temporal_causal(self): """Output at t=0 should not depend on t>0.""" - from mlx_video.models.wan.vae22 import CausalConv3d + from mlx_video.models.wan2.vae22 import CausalConv3d conv = CausalConv3d(2, 2, kernel_size=3, padding=1) conv.weight = mx.random.normal(conv.weight.shape) * 0.1 @@ -178,7 +178,7 @@ class TestVAE22CausalConv3d: def test_channels_last_format(self): """Verify input/output are channels-last [B, T, H, W, C].""" - from mlx_video.models.wan.vae22 import CausalConv3d + from mlx_video.models.wan2.vae22 import CausalConv3d conv = CausalConv3d(4, 8, kernel_size=3, padding=1) x = mx.random.normal((2, 3, 6, 6, 4)) @@ -191,7 +191,7 @@ class TestRMSNorm: """Tests for vae22.RMS_norm (actually L2 normalization).""" def test_output_shape(self): - from mlx_video.models.wan.vae22 import RMS_norm + from mlx_video.models.wan2.vae22 import RMS_norm norm = RMS_norm(16) x = mx.random.normal((2, 4, 4, 4, 16)) @@ -201,7 +201,7 @@ class TestRMSNorm: def test_l2_normalization(self): """RMS_norm should normalize to unit L2 norm * sqrt(dim).""" - from mlx_video.models.wan.vae22 import RMS_norm + from mlx_video.models.wan2.vae22 import RMS_norm dim = 32 norm = RMS_norm(dim) @@ -215,7 +215,7 @@ class TestRMSNorm: def test_scale_invariant(self): """Scaling input by constant should not change output (L2 norm property).""" - from mlx_video.models.wan.vae22 import RMS_norm + from mlx_video.models.wan2.vae22 import RMS_norm norm = RMS_norm(8) x = mx.random.normal((1, 1, 1, 1, 8)) @@ -226,7 +226,7 @@ class TestRMSNorm: def test_gamma_effect(self): """Non-unit gamma should scale output.""" - from mlx_video.models.wan.vae22 import RMS_norm + from mlx_video.models.wan2.vae22 import RMS_norm norm = RMS_norm(4) norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0]) @@ -241,7 +241,7 @@ class TestDupUp3D: """Tests for vae22.DupUp3D spatial/temporal upsampling.""" def test_spatial_only(self): - from mlx_video.models.wan.vae22 import DupUp3D + from mlx_video.models.wan2.vae22 import DupUp3D up = DupUp3D(8, 4, factor_t=1, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) @@ -250,7 +250,7 @@ class TestDupUp3D: assert out.shape == (1, 3, 8, 8, 4) def test_temporal_and_spatial(self): - from mlx_video.models.wan.vae22 import DupUp3D + from mlx_video.models.wan2.vae22 import DupUp3D up = DupUp3D(16, 8, factor_t=2, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 16)) @@ -259,7 +259,7 @@ class TestDupUp3D: assert out.shape == (1, 6, 8, 8, 8) def test_first_chunk_trims(self): - from mlx_video.models.wan.vae22 import DupUp3D + from mlx_video.models.wan2.vae22 import DupUp3D up = DupUp3D(8, 4, factor_t=2, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) @@ -271,7 +271,7 @@ class TestDupUp3D: assert out_trimmed.shape[1] == 5 def test_no_temporal_first_chunk_noop(self): - from mlx_video.models.wan.vae22 import DupUp3D + from mlx_video.models.wan2.vae22 import DupUp3D up = DupUp3D(8, 4, factor_t=1, factor_s=2) x = mx.random.normal((1, 3, 4, 4, 8)) @@ -286,7 +286,7 @@ class TestVAE22Resample: """Tests for vae22.Resample (spatial/temporal upsampling).""" def test_upsample2d_shape(self): - from mlx_video.models.wan.vae22 import Resample + from mlx_video.models.wan2.vae22 import Resample r = Resample(8, "upsample2d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -296,7 +296,7 @@ class TestVAE22Resample: assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal def test_upsample3d_shape(self): - from mlx_video.models.wan.vae22 import Resample + from mlx_video.models.wan2.vae22 import Resample r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -306,7 +306,7 @@ class TestVAE22Resample: assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal def test_upsample3d_first_chunk(self): - from mlx_video.models.wan.vae22 import Resample + from mlx_video.models.wan2.vae22 import Resample r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -318,7 +318,7 @@ class TestVAE22Resample: def test_upsample3d_first_chunk_single_frame(self): """Single-frame input with first_chunk: no temporal upsample.""" - from mlx_video.models.wan.vae22 import Resample + from mlx_video.models.wan2.vae22 import Resample r = Resample(8, "upsample3d") r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 @@ -336,7 +336,7 @@ class TestVAE22Resample: We verify this by checking that the first output frame depends only on the first input frame (not on time_conv parameters). """ - from mlx_video.models.wan.vae22 import Resample + from mlx_video.models.wan2.vae22 import Resample C = 8 r = Resample(C, "upsample3d") @@ -373,7 +373,7 @@ class TestVAE22ResidualBlock: """Tests for vae22.ResidualBlock.""" def test_same_dim(self): - from mlx_video.models.wan.vae22 import ResidualBlock + from mlx_video.models.wan2.vae22 import ResidualBlock block = ResidualBlock(8, 8) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -382,7 +382,7 @@ class TestVAE22ResidualBlock: assert out.shape == (1, 2, 4, 4, 8) def test_different_dim(self): - from mlx_video.models.wan.vae22 import ResidualBlock + from mlx_video.models.wan2.vae22 import ResidualBlock block = ResidualBlock(8, 16) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -391,13 +391,13 @@ class TestVAE22ResidualBlock: assert out.shape == (1, 2, 4, 4, 16) def test_shortcut_when_dims_differ(self): - from mlx_video.models.wan.vae22 import ResidualBlock + from mlx_video.models.wan2.vae22 import ResidualBlock block = ResidualBlock(8, 16) assert block.shortcut is not None def test_no_shortcut_same_dim(self): - from mlx_video.models.wan.vae22 import ResidualBlock + from mlx_video.models.wan2.vae22 import ResidualBlock block = ResidualBlock(8, 8) assert block.shortcut is None @@ -408,7 +408,7 @@ class TestResidualBlockLayers: def test_layer_names_no_underscore_prefix(self): """Layer names must NOT start with underscore (MLX ignores them).""" - from mlx_video.models.wan.vae22 import ResidualBlockLayers + from mlx_video.models.wan2.vae22 import ResidualBlockLayers block = ResidualBlockLayers(8, 8) params = dict(block.parameters()) @@ -417,7 +417,7 @@ class TestResidualBlockLayers: assert not key.startswith("_"), f"Parameter {key} starts with underscore" def test_has_expected_layers(self): - from mlx_video.models.wan.vae22 import ResidualBlockLayers + from mlx_video.models.wan2.vae22 import ResidualBlockLayers block = ResidualBlockLayers(8, 16) assert hasattr(block, "layer_0") # first RMS_norm @@ -426,7 +426,7 @@ class TestResidualBlockLayers: assert hasattr(block, "layer_6") # second CausalConv3d def test_forward_shape(self): - from mlx_video.models.wan.vae22 import ResidualBlockLayers + from mlx_video.models.wan2.vae22 import ResidualBlockLayers block = ResidualBlockLayers(8, 16) x = mx.random.normal((1, 2, 4, 4, 8)) @@ -439,7 +439,7 @@ class TestVAE22AttentionBlock: """Tests for vae22.AttentionBlock (per-frame 2D self-attention).""" def test_output_shape(self): - from mlx_video.models.wan.vae22 import AttentionBlock + from mlx_video.models.wan2.vae22 import AttentionBlock block = AttentionBlock(16) block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01 @@ -450,7 +450,7 @@ class TestVAE22AttentionBlock: assert out.shape == (1, 2, 4, 4, 16) def test_residual_connection(self): - from mlx_video.models.wan.vae22 import AttentionBlock + from mlx_video.models.wan2.vae22 import AttentionBlock block = AttentionBlock(8) block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape) @@ -466,7 +466,7 @@ class TestHead22: """Tests for vae22.Head22 output head.""" def test_output_shape(self): - from mlx_video.models.wan.vae22 import Head22 + from mlx_video.models.wan2.vae22 import Head22 head = Head22(16, out_channels=12) x = mx.random.normal((1, 2, 4, 4, 16)) @@ -476,7 +476,7 @@ class TestHead22: def test_layer_names_no_underscore(self): """Head layers must not use underscore prefix.""" - from mlx_video.models.wan.vae22 import Head22 + from mlx_video.models.wan2.vae22 import Head22 head = Head22(8) assert hasattr(head, "layer_0") # RMS_norm @@ -490,7 +490,7 @@ class TestUnpatchify: """Tests for vae22._unpatchify.""" def test_basic_shape(self): - from mlx_video.models.wan.vae22 import _unpatchify + from mlx_video.models.wan2.vae22 import _unpatchify x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2 out = _unpatchify(x, patch_size=2) @@ -498,7 +498,7 @@ class TestUnpatchify: assert out.shape == (1, 2, 8, 8, 3) def test_patch_size_1_noop(self): - from mlx_video.models.wan.vae22 import _unpatchify + from mlx_video.models.wan2.vae22 import _unpatchify x = mx.random.normal((1, 2, 4, 4, 3)) out = _unpatchify(x, patch_size=1) @@ -507,7 +507,7 @@ class TestUnpatchify: def test_preserves_content(self): """Unpatchify should be a lossless rearrangement.""" - from mlx_video.models.wan.vae22 import _unpatchify + from mlx_video.models.wan2.vae22 import _unpatchify x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32) out = _unpatchify(x, patch_size=2) @@ -521,7 +521,7 @@ class TestDenormalizeLatents: """Tests for vae22.denormalize_latents.""" def test_output_shape(self): - from mlx_video.models.wan.vae22 import denormalize_latents + from mlx_video.models.wan2.vae22 import denormalize_latents z = mx.random.normal((1, 2, 4, 4, 48)) out = denormalize_latents(z) @@ -529,7 +529,7 @@ class TestDenormalizeLatents: assert out.shape == (1, 2, 4, 4, 48) def test_custom_mean_std(self): - from mlx_video.models.wan.vae22 import denormalize_latents + from mlx_video.models.wan2.vae22 import denormalize_latents z = mx.ones((1, 1, 1, 1, 4)) mean = mx.array([1.0, 2.0, 3.0, 4.0]) @@ -542,7 +542,7 @@ class TestDenormalizeLatents: ) def test_uses_default_constants(self): - from mlx_video.models.wan.vae22 import ( + from mlx_video.models.wan2.vae22 import ( VAE22_MEAN, denormalize_latents, ) @@ -563,14 +563,14 @@ class TestVAE22NormConstants: """Tests for VAE22_MEAN and VAE22_STD constants.""" def test_dimensions(self): - from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD + from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD mx.eval(VAE22_MEAN, VAE22_STD) assert VAE22_MEAN.shape == (48,) assert VAE22_STD.shape == (48,) def test_std_positive(self): - from mlx_video.models.wan.vae22 import VAE22_STD + from mlx_video.models.wan2.vae22 import VAE22_STD mx.eval(VAE22_STD) assert (np.array(VAE22_STD) > 0).all() @@ -581,7 +581,7 @@ class TestWan22VAEDecoder: def test_output_shape_small(self): """Tiny decoder should produce correct spatial/temporal output.""" - from mlx_video.models.wan.vae22 import Wan22VAEDecoder + from mlx_video.models.wan2.vae22 import Wan22VAEDecoder # Use very small dims to keep test fast dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) @@ -597,7 +597,7 @@ class TestWan22VAEDecoder: assert np.array(out).max() <= 1.0 def test_output_clipped(self): - from mlx_video.models.wan.vae22 import Wan22VAEDecoder + from mlx_video.models.wan2.vae22 import Wan22VAEDecoder dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values @@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights: """Tests for vae22.sanitize_wan22_vae_weights.""" def test_skip_encoder(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "encoder.layer.weight": mx.zeros((4,)), @@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.conv1.bias" in out def test_sequential_index_remapping(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)), @@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.head.layer_2.bias" in out def test_resample_conv_remapping(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)), @@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.upsamples.1.upsamples.3.resample_bias" in out def test_attention_remapping(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights weights = { "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)), @@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights: assert "decoder.middle.1.proj_bias" in out def test_conv3d_transpose(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I] w = mx.zeros((16, 8, 3, 3, 3)) @@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights: assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8) def test_conv2d_transpose(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights # Conv2d weight: [O, I, H, W] → [O, H, W, I] w = mx.zeros((8, 8, 3, 3)) @@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights: assert out[key].shape == (8, 3, 3, 8) def test_gamma_squeeze(self): - from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights # gamma: (dim, 1, 1, 1) → (dim,) w = mx.ones((16, 1, 1, 1)) @@ -698,7 +698,7 @@ class TestUpResidualBlock: """Tests for vae22.Up_ResidualBlock.""" def test_no_upsample(self): - from mlx_video.models.wan.vae22 import Up_ResidualBlock + from mlx_video.models.wan2.vae22 import Up_ResidualBlock block = Up_ResidualBlock( 8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False @@ -710,7 +710,7 @@ class TestUpResidualBlock: assert out.shape == (1, 2, 4, 4, 8) def test_spatial_upsample(self): - from mlx_video.models.wan.vae22 import Up_ResidualBlock + from mlx_video.models.wan2.vae22 import Up_ResidualBlock block = Up_ResidualBlock( 8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True @@ -722,7 +722,7 @@ class TestUpResidualBlock: assert out.shape == (1, 2, 8, 8, 4) def test_spatial_temporal_upsample(self): - from mlx_video.models.wan.vae22 import Up_ResidualBlock + from mlx_video.models.wan2.vae22 import Up_ResidualBlock block = Up_ResidualBlock( 8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True @@ -738,7 +738,7 @@ class TestPatchify: """Tests for _patchify and _unpatchify round-trip.""" def test_roundtrip(self): - from mlx_video.models.wan.vae22 import _patchify, _unpatchify + from mlx_video.models.wan2.vae22 import _patchify, _unpatchify x = mx.random.normal((1, 1, 64, 64, 3)) p = _patchify(x, patch_size=2) @@ -748,7 +748,7 @@ class TestPatchify: assert float(mx.abs(x - back).max()) == 0.0 def test_identity_patch_1(self): - from mlx_video.models.wan.vae22 import _patchify, _unpatchify + from mlx_video.models.wan2.vae22 import _patchify, _unpatchify x = mx.random.normal((1, 2, 8, 8, 3)) assert _patchify(x, patch_size=1).shape == x.shape @@ -759,7 +759,7 @@ class TestAvgDown3D: """Tests for AvgDown3D downsampling.""" def test_spatial_only(self): - from mlx_video.models.wan.vae22 import AvgDown3D + from mlx_video.models.wan2.vae22 import AvgDown3D down = AvgDown3D(8, 16, factor_t=1, factor_s=2) x = mx.random.normal((1, 2, 8, 8, 8)) @@ -768,7 +768,7 @@ class TestAvgDown3D: assert out.shape == (1, 2, 4, 4, 16) def test_temporal_and_spatial(self): - from mlx_video.models.wan.vae22 import AvgDown3D + from mlx_video.models.wan2.vae22 import AvgDown3D down = AvgDown3D(8, 16, factor_t=2, factor_s=2) x = mx.random.normal((1, 4, 8, 8, 8)) @@ -777,7 +777,7 @@ class TestAvgDown3D: assert out.shape == (1, 2, 4, 4, 16) def test_single_frame(self): - from mlx_video.models.wan.vae22 import AvgDown3D + from mlx_video.models.wan2.vae22 import AvgDown3D down = AvgDown3D(8, 8, factor_t=2, factor_s=2) x = mx.random.normal((1, 1, 8, 8, 8)) @@ -791,7 +791,7 @@ class TestDownResidualBlock: """Tests for Down_ResidualBlock.""" def test_no_downsample(self): - from mlx_video.models.wan.vae22 import Down_ResidualBlock + from mlx_video.models.wan2.vae22 import Down_ResidualBlock block = Down_ResidualBlock( 8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False @@ -802,7 +802,7 @@ class TestDownResidualBlock: assert out.shape == (1, 2, 8, 8, 8) def test_spatial_downsample(self): - from mlx_video.models.wan.vae22 import Down_ResidualBlock + from mlx_video.models.wan2.vae22 import Down_ResidualBlock block = Down_ResidualBlock( 8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True @@ -813,7 +813,7 @@ class TestDownResidualBlock: assert out.shape == (1, 2, 4, 4, 16) def test_spatial_temporal_downsample(self): - from mlx_video.models.wan.vae22 import Down_ResidualBlock + from mlx_video.models.wan2.vae22 import Down_ResidualBlock block = Down_ResidualBlock( 8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True @@ -828,7 +828,7 @@ class TestEncoder3d: """Tests for Encoder3d.""" def test_output_shape(self): - from mlx_video.models.wan.vae22 import Encoder3d + from mlx_video.models.wan2.vae22 import Encoder3d enc = Encoder3d(dim=16, z_dim=8) x = mx.random.normal((1, 1, 16, 16, 12)) @@ -839,7 +839,7 @@ class TestEncoder3d: assert out.shape == (1, 1, 2, 2, 8) def test_multi_frame(self): - from mlx_video.models.wan.vae22 import Encoder3d + from mlx_video.models.wan2.vae22 import Encoder3d enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False)) x = mx.random.normal((1, 5, 16, 16, 12)) @@ -854,7 +854,7 @@ class TestWan22VAEEncoder: """Tests for Wan22VAEEncoder wrapper.""" def test_output_shape(self): - from mlx_video.models.wan.vae22 import Wan22VAEEncoder + from mlx_video.models.wan2.vae22 import Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=16) # Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2) @@ -865,7 +865,7 @@ class TestWan22VAEEncoder: assert z.shape == (1, 1, 2, 2, 48) def test_full_dim(self): - from mlx_video.models.wan.vae22 import Wan22VAEEncoder + from mlx_video.models.wan2.vae22 import Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=160) img = mx.random.normal((1, 1, 64, 64, 3)) @@ -880,7 +880,7 @@ class TestNormalizeLatents: """Tests for normalize/denormalize latent roundtrip.""" def test_roundtrip(self): - from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents + from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents z = mx.random.normal((1, 2, 4, 4, 48)) z_norm = normalize_latents(z) @@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder: def test_encoder_temporal_downsample_pattern(self): """Encoder3d with (False, True, True): T=5→5→3→2.""" - from mlx_video.models.wan.vae22 import Encoder3d + from mlx_video.models.wan2.vae22 import Encoder3d enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True)) x = mx.random.normal((1, 5, 16, 16, 12)) @@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder: def test_wrapper_uses_correct_pattern(self): """Wan22VAEEncoder should use (False, True, True) temporal downsample.""" - from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder + from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=16) down_blocks = enc.encoder.downsamples @@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder: def test_single_frame_encoder(self): """Single frame (T=1) should work with (False, True, True) pattern.""" - from mlx_video.models.wan.vae22 import Wan22VAEEncoder + from mlx_video.models.wan2.vae22 import Wan22VAEEncoder enc = Wan22VAEEncoder(z_dim=48, dim=16) img = mx.random.normal((1, 1, 32, 32, 3)) @@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder: def test_wrong_order_gives_different_result(self): """(True, True, False) vs (False, True, True) produce different outputs.""" - from mlx_video.models.wan.vae22 import Encoder3d + from mlx_video.models.wan2.vae22 import Encoder3d enc_correct = Encoder3d( dim=16, z_dim=8, temperal_downsample=(False, True, True) @@ -963,7 +963,7 @@ class TestVAE21RoundTrip: def test_encode_decode_shape_and_values(self): """Encoder3d → Decoder3d: output shape matches input, values are finite.""" - from mlx_video.models.wan.vae import Decoder3d, Encoder3d + from mlx_video.models.wan2.vae import Decoder3d, Encoder3d z_dim = 4 dim = 8 @@ -995,7 +995,7 @@ class TestVAE22RoundTrip: def test_encode_decode_shape_and_values(self): """Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range.""" - from mlx_video.models.wan.vae22 import ( + from mlx_video.models.wan2.vae22 import ( Wan22VAEDecoder, Wan22VAEEncoder, denormalize_latents, diff --git a/tests/wan_test_helpers.py b/tests/wan_test_helpers.py index 0d1a2b1..2b67ada 100644 --- a/tests/wan_test_helpers.py +++ b/tests/wan_test_helpers.py @@ -3,7 +3,7 @@ def _make_tiny_config(): """Create a tiny WanModelConfig for testing.""" - from mlx_video.models.wan.config import WanModelConfig + from mlx_video.models.wan2.config import WanModelConfig config = WanModelConfig() # Override to tiny values