This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -3,17 +3,16 @@
import logging
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Transformer Weight Conversion Tests
# ---------------------------------------------------------------------------
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
@@ -25,6 +24,7 @@ class TestSanitizeTransformerWeights:
def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.0.bias": mx.zeros((64,)),
@@ -39,6 +39,7 @@ class TestSanitizeTransformerWeights:
def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)),
@@ -49,6 +50,7 @@ class TestSanitizeTransformerWeights:
def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
"time_projection.1.bias": mx.zeros((384,)),
@@ -59,6 +61,7 @@ class TestSanitizeTransformerWeights:
def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.0.bias": mx.zeros((128,)),
@@ -73,6 +76,7 @@ class TestSanitizeTransformerWeights:
def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
"blocks.0.norm1.weight": mx.zeros((64,)),
@@ -83,6 +87,7 @@ class TestSanitizeTransformerWeights:
def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
@@ -98,6 +103,7 @@ class TestSanitizeTransformerWeights:
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
@@ -121,6 +127,7 @@ class TestSanitizeTransformerWeights:
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
@@ -133,6 +140,7 @@ class TestSanitizeT5Weights:
def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
@@ -144,6 +152,7 @@ class TestSanitizeT5Weights:
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -159,6 +168,7 @@ class TestSanitizeT5Weights:
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
}
@@ -167,6 +177,7 @@ class TestSanitizeVAEWeights:
def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
}
@@ -175,6 +186,7 @@ class TestSanitizeVAEWeights:
def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
"decoder.bias": mx.zeros((16,)),
@@ -185,6 +197,7 @@ class TestSanitizeVAEWeights:
def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
@@ -199,6 +212,7 @@ class TestSanitizeVAEWeights:
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
@@ -214,6 +228,7 @@ class TestSanitizeVAEWeights:
# Wan2.1 Conversion Tests
# ---------------------------------------------------------------------------
class TestWan21Convert:
"""Tests for Wan2.1 conversion support."""
@@ -222,7 +237,7 @@ class TestWan21Convert:
# Create a Wan2.1-style directory (no low_noise_model subdir)
(tmp_path / "dummy.safetensors").touch()
# The auto-detect logic: no low_noise_model dir → 2.1
from pathlib import Path
low = tmp_path / "low_noise_model"
assert not low.exists()
# Simulates auto detection
@@ -233,7 +248,7 @@ class TestWan21Convert:
"""Auto-detect dual-model directory as Wan2.2."""
(tmp_path / "low_noise_model").mkdir()
(tmp_path / "high_noise_model").mkdir()
from pathlib import Path
low = tmp_path / "low_noise_model"
assert low.exists()
version = "2.2" if low.exists() else "2.1"
@@ -242,6 +257,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
@@ -254,6 +270,7 @@ class TestWan21Convert:
# Encoder Weight Sanitization Tests
# ---------------------------------------------------------------------------
class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""