This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,22 +1,22 @@
"""Tests for Wan model quantization pipeline."""
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Quantize Predicate Tests
# ---------------------------------------------------------------------------
class TestQuantizePredicate:
def test_matches_self_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.self_attn.{suffix}"
@@ -24,6 +24,7 @@ class TestQuantizePredicate:
def test_matches_cross_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.cross_attn.{suffix}"
@@ -31,23 +32,31 @@ class TestQuantizePredicate:
def test_matches_ffn_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
for path in [
"patch_embedding_proj",
"text_embedding_fc1",
"time_embedding.fc1",
]:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self):
from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self):
from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
@@ -55,13 +64,19 @@ class TestQuantizePredicate:
def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted."""
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
patterns = [
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
"blocks.0.self_attn.q",
"blocks.0.self_attn.k",
"blocks.0.self_attn.v",
"blocks.0.self_attn.o",
"blocks.0.cross_attn.q",
"blocks.0.cross_attn.k",
"blocks.0.cross_attn.v",
"blocks.0.cross_attn.o",
"blocks.0.ffn.fc1",
"blocks.0.ffn.fc2",
]
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
assert len(matched) == 10
@@ -71,11 +86,12 @@ class TestQuantizePredicate:
# Quantize Round-Trip Tests
# ---------------------------------------------------------------------------
class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
nn.quantize(
@@ -101,8 +117,10 @@ class TestQuantizeRoundTrip:
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 4, "group_size": 64},
)
@@ -119,8 +137,10 @@ class TestQuantizeRoundTrip:
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 8, "group_size": 64},
)
@@ -132,8 +152,10 @@ class TestQuantizeRoundTrip:
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 4, "group_size": 64},
)
@@ -151,6 +173,7 @@ class TestQuantizeRoundTrip:
mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None)
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
@@ -161,10 +184,11 @@ class TestQuantizeRoundTrip:
# Quantized Inference Tests
# ---------------------------------------------------------------------------
class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
nn.quantize(
@@ -214,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
mx.random.seed(42)
@@ -243,11 +267,12 @@ class TestQuantizedInference:
# Config Metadata Tests
# ---------------------------------------------------------------------------
class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -270,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -291,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()