# Files
# mlx-video/tests/test_wan_quantization.py
# 2026-03-11 09:24:06 +01:00
#
# 314 lines
# 12 KiB
# Python

"""Tests for Wan model quantization pipeline."""
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Quantize Predicate Tests
# ---------------------------------------------------------------------------
class TestQuantizePredicate:
    """Checks which (path, module) pairs ``_quantize_predicate`` accepts."""

    def test_matches_self_attention_layers(self):
        """The q/k/v/o paths under ``self_attn`` are accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        linear = nn.Linear(64, 64)
        paths = [f"blocks.0.self_attn.{suffix}" for suffix in ("q", "k", "v", "o")]
        for path in paths:
            assert _quantize_predicate(path, linear), f"Should match {path}"

    def test_matches_cross_attention_layers(self):
        """The q/k/v/o paths under ``cross_attn`` are accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        linear = nn.Linear(64, 64)
        paths = [f"blocks.0.cross_attn.{suffix}" for suffix in ("q", "k", "v", "o")]
        for path in paths:
            assert _quantize_predicate(path, linear), f"Should match {path}"

    def test_matches_ffn_layers(self):
        """Both feed-forward paths (``fc1`` and ``fc2``) are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        linear = nn.Linear(64, 64)
        assert _quantize_predicate("blocks.0.ffn.fc1", linear)
        assert _quantize_predicate("blocks.0.ffn.fc2", linear)

    def test_rejects_embeddings(self):
        """Embedding-style paths are rejected even for a Linear module."""
        from mlx_video.convert_wan import _quantize_predicate

        linear = nn.Linear(64, 64)
        embedding_paths = (
            "patch_embedding_proj",
            "text_embedding_fc1",
            "time_embedding.fc1",
        )
        for path in embedding_paths:
            assert not _quantize_predicate(path, linear), f"Should reject {path}"

    def test_rejects_norms(self):
        """An RMSNorm at a norm path is rejected."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        assert not _quantize_predicate("blocks.0.self_attn.norm_q", norm)

    def test_rejects_non_quantizable_modules(self):
        """A matching path is still rejected when the module is an RMSNorm."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        # The path alone would match; the module type is what must cause the
        # rejection (presumably because it lacks ``to_quantized`` - verify
        # against the predicate's implementation).
        assert not _quantize_predicate("blocks.0.self_attn.q", norm)

    def test_all_10_patterns_covered(self):
        """Verify all 10 targeted layer patterns are matched."""
        from mlx_video.convert_wan import _quantize_predicate

        linear = nn.Linear(64, 64)
        # 2 attention kinds x 4 projections, followed by the 2 FFN layers.
        patterns = [
            f"blocks.0.{attn}.{proj}"
            for attn in ("self_attn", "cross_attn")
            for proj in ("q", "k", "v", "o")
        ]
        patterns += ["blocks.0.ffn.fc1", "blocks.0.ffn.fc2"]
        match_count = sum(1 for p in patterns if _quantize_predicate(p, linear))
        assert match_count == 10
# ---------------------------------------------------------------------------
# Quantize Round-Trip Tests
# ---------------------------------------------------------------------------
class TestQuantizeRoundTrip:
    """Quantize a tiny WanModel, save it, and reload it via ``load_wan_model``."""

    def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
        """Create a model, quantize it, and save it under ``tmp_path``.

        Writes ``model.safetensors`` and a ``config.json`` holding the
        quantization settings. Returns ``(model_path, weights_dict)``.
        """
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_predicate

        model = WanModel(config)
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=_quantize_predicate,
        )
        weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
        model_path = tmp_path / "model.safetensors"
        mx.save_safetensors(str(model_path), weights_dict)

        # Write config.json
        quant_cfg = {"quantization": {"bits": bits, "group_size": group_size}}
        (tmp_path / "config.json").write_text(json.dumps(quant_cfg))
        return model_path, weights_dict

    def _load(self, model_path, config, quantization):
        """Reload the saved weights through ``load_wan_model``."""
        from mlx_video.models.wan.loading import load_wan_model

        return load_wan_model(model_path, config, quantization=quantization)

    def test_4bit_roundtrip(self, tmp_path):
        """A 4-bit model reloads with QuantizedLinear attention and FFN layers."""
        config = _make_tiny_config()
        model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
        loaded = self._load(model_path, config, {"bits": 4, "group_size": 64})

        # Verify quantized layers have scales
        assert any(
            "scales" in key for key in saved_weights
        ), "Quantized model should have .scales tensors"

        # Verify a self-attention layer is QuantizedLinear
        first_block = loaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.ffn.fc1, nn.QuantizedLinear)

    def test_8bit_roundtrip(self, tmp_path):
        """An 8-bit model reloads with QuantizedLinear attention layers."""
        config = _make_tiny_config()
        model_path, _saved = self._quantize_and_save(config, tmp_path, bits=8)
        loaded = self._load(model_path, config, {"bits": 8, "group_size": 64})

        first_block = loaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.cross_attn.k, nn.QuantizedLinear)

    def test_non_quantized_layers_remain_linear(self, tmp_path):
        """The model head is not a QuantizedLinear after a 4-bit reload."""
        config = _make_tiny_config()
        model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
        loaded = self._load(model_path, config, {"bits": 4, "group_size": 64})

        # Head should NOT be quantized (it's not in the predicate patterns)
        assert not isinstance(loaded.head, nn.QuantizedLinear)

    def test_loading_without_quantization_flag(self, tmp_path):
        """Loading a non-quantized model should have standard Linear layers."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        plain_model = WanModel(config)
        plain_weights = dict(mlx.utils.tree_flatten(plain_model.parameters()))
        model_path = tmp_path / "model.safetensors"
        mx.save_safetensors(str(model_path), plain_weights)

        loaded = self._load(model_path, config, None)
        q_proj = loaded.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.Linear)
        assert not isinstance(q_proj, nn.QuantizedLinear)
# ---------------------------------------------------------------------------
# Quantized Inference Tests
# ---------------------------------------------------------------------------
class TestQuantizedInference:
    """Forward-pass and weight-format checks on an in-memory quantized WanModel."""

    def _make_quantized_model(self, config, bits=4):
        """Build a WanModel, quantize it in place, and evaluate its parameters.

        Quantization uses ``group_size=64`` and ``_quantize_predicate`` to
        select layers. Returns the quantized model.
        """
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_predicate

        model = WanModel(config)
        nn.quantize(
            model,
            group_size=64,
            bits=bits,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())
        return model

    def _check_forward_pass(self, bits):
        """Run one forward pass at ``bits`` precision and check the output shape.

        Shared by the 4-bit and 8-bit tests so both exercise the same inputs:
        a single random (C, 1, 4, 4) sample, timestep 500.0, and a random
        context of 4 rows of width ``config.text_dim``.
        """
        config = _make_tiny_config()
        model = self._make_quantized_model(config, bits=bits)

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        # Number of patches the (F, H, W) volume splits into.
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x = [mx.random.normal((C, F, H, W))]
        t = mx.array([500.0])
        context = [mx.random.normal((4, config.text_dim))]

        out = model(x, t, context, seq_len)
        mx.eval(out[0])  # force the lazy graph to actually run
        assert len(out) == 1
        assert out[0].shape == (C, F, H, W)

    def test_forward_pass_4bit(self):
        """A 4-bit quantized model returns one output shaped like its input."""
        self._check_forward_pass(bits=4)

    def test_forward_pass_8bit(self):
        """An 8-bit quantized model returns one output shaped like its input."""
        self._check_forward_pass(bits=8)

    def test_quantized_output_differs_from_unquantized(self):
        """Sanity check: quantization should change the weights."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_predicate

        config = _make_tiny_config()
        mx.random.seed(42)

        # Snapshot the unquantized weight before the model is modified in place.
        model = WanModel(config)
        mx.eval(model.parameters())
        orig_weight = np.array(model.blocks[0].self_attn.q.weight)

        # Quantize
        nn.quantize(
            model,
            group_size=64,
            bits=4,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())

        quantized = model.blocks[0].self_attn.q
        assert isinstance(quantized, nn.QuantizedLinear)
        assert hasattr(quantized, "scales")

        # QuantizedLinear stores weight differently (uint32 packed), so the
        # stored tensor must differ from the original in dtype and in shape.
        # Previously ``orig_weight`` was captured but never compared.
        assert quantized.weight.dtype == mx.uint32
        assert tuple(quantized.weight.shape) != orig_weight.shape
# ---------------------------------------------------------------------------
# Config Metadata Tests
# ---------------------------------------------------------------------------
class TestQuantizationConfig:
    """Checks the files ``_quantize_saved_model`` rewrites inside a model dir."""

    def test_config_metadata_written(self, tmp_path):
        """Verify _quantize_saved_model writes quantization metadata to config.json."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_saved_model

        config = _make_tiny_config()
        weights = dict(mlx.utils.tree_flatten(WanModel(config).parameters()))

        # Save unquantized model + config
        mx.save_safetensors(str(tmp_path / "model.safetensors"), weights)
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps({"dim": config.dim}))

        # Run quantization
        _quantize_saved_model(tmp_path, config, is_dual=False, bits=4, group_size=64)

        # Verify metadata
        cfg = json.loads(config_file.read_text())
        assert "quantization" in cfg
        quant_meta = cfg["quantization"]
        assert quant_meta["bits"] == 4
        assert quant_meta["group_size"] == 64

    def test_config_metadata_8bit(self, tmp_path):
        """Non-default settings (8 bits, group size 32) are recorded as given."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_saved_model

        config = _make_tiny_config()
        weights = dict(mlx.utils.tree_flatten(WanModel(config).parameters()))
        mx.save_safetensors(str(tmp_path / "model.safetensors"), weights)
        config_file = tmp_path / "config.json"
        config_file.write_text(json.dumps({}))

        _quantize_saved_model(tmp_path, config, is_dual=False, bits=8, group_size=32)

        quant_meta = json.loads(config_file.read_text())["quantization"]
        assert quant_meta["bits"] == 8
        assert quant_meta["group_size"] == 32

    def test_dual_model_quantization(self, tmp_path):
        """Verify dual-model quantization writes both model files."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_saved_model

        config = _make_tiny_config()
        model_files = ("low_noise_model.safetensors", "high_noise_model.safetensors")

        # Each file gets its own freshly constructed model.
        for name in model_files:
            fresh = WanModel(config)
            flat = dict(mlx.utils.tree_flatten(fresh.parameters()))
            mx.save_safetensors(str(tmp_path / name), flat)
        (tmp_path / "config.json").write_text(json.dumps({}))

        _quantize_saved_model(tmp_path, config, is_dual=True, bits=4, group_size=64)

        # Both files should now contain quantized weights (have .scales keys)
        for name in model_files:
            weights = mx.load(str(tmp_path / name))
            assert any(
                "scales" in key for key in weights
            ), f"{name} should have quantized layers"