339 lines
12 KiB
Python
339 lines
12 KiB
Python
"""Tests for Wan model quantization pipeline."""
|
|
|
|
import json
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import mlx.utils
|
|
import numpy as np
|
|
from wan_test_helpers import _make_tiny_config
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Quantize Predicate Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQuantizePredicate:
    """Checks which ``(path, module)`` pairs ``_quantize_predicate`` accepts."""

    def test_matches_self_attention_layers(self):
        """Each self-attention projection path is accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for path in (f"blocks.0.self_attn.{name}" for name in "qkvo"):
            assert _quantize_predicate(path, layer), f"Should match {path}"

    def test_matches_cross_attention_layers(self):
        """Each cross-attention projection path is accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for path in (f"blocks.0.cross_attn.{name}" for name in "qkvo"):
            assert _quantize_predicate(path, layer), f"Should match {path}"

    def test_matches_ffn_layers(self):
        """Both feed-forward paths (``fc1`` and ``fc2``) are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        assert _quantize_predicate("blocks.0.ffn.fc1", layer)
        assert _quantize_predicate("blocks.0.ffn.fc2", layer)

    def test_rejects_embeddings(self):
        """Embedding paths are rejected even when the module is a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        embedding_paths = (
            "patch_embedding_proj",
            "text_embedding_fc1",
            "time_embedding.fc1",
        )
        for path in embedding_paths:
            assert not _quantize_predicate(path, layer), f"Should reject {path}"

    def test_rejects_norms(self):
        """An RMSNorm at a norm path is rejected."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        assert not _quantize_predicate("blocks.0.self_attn.norm_q", norm)

    def test_rejects_non_quantizable_modules(self):
        """An RMSNorm is rejected at a path that is accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        # The path alone is not sufficient: the predicate is expected to also
        # require that the module can be quantized (has ``to_quantized``).
        assert not _quantize_predicate("blocks.0.self_attn.q", norm)

    def test_all_10_patterns_covered(self):
        """All ten expected layer paths are accepted by the predicate."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        # Same order as spelled out originally: self-attention projections,
        # cross-attention projections, then the two feed-forward layers.
        attention_paths = [
            f"blocks.0.{attn}.{proj}"
            for attn in ("self_attn", "cross_attn")
            for proj in "qkvo"
        ]
        candidates = attention_paths + ["blocks.0.ffn.fc1", "blocks.0.ffn.fc2"]
        hit_count = sum(
            1 for path in candidates if _quantize_predicate(path, layer)
        )
        assert hit_count == 10
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Quantize Round-Trip Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQuantizeRoundTrip:
    """Round-trip tests: quantize a tiny model, save it, then reload it."""

    def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
        """Quantize a fresh ``WanModel`` and persist it under ``tmp_path``.

        Writes ``model.safetensors`` (the flattened parameters) and a
        ``config.json`` recording the ``bits`` / ``group_size`` that were used.

        Returns:
            A ``(model_path, weights_dict)`` tuple: the safetensors path and
            the flattened parameter dict that was saved to it.
        """
        from mlx_video.convert_wan import _quantize_predicate
        from mlx_video.models.wan.model import WanModel

        net = WanModel(config)
        nn.quantize(
            net,
            group_size=group_size,
            bits=bits,
            class_predicate=lambda name, module: _quantize_predicate(name, module),
        )

        flat_params = dict(mlx.utils.tree_flatten(net.parameters()))
        weights_file = tmp_path / "model.safetensors"
        mx.save_safetensors(str(weights_file), flat_params)

        # Write config.json alongside the weights.
        quant_meta = {"quantization": {"bits": bits, "group_size": group_size}}
        with open(tmp_path / "config.json", "w") as handle:
            json.dump(quant_meta, handle)

        return weights_file, flat_params

    def _reload(self, model_path, config, bits):
        """Load the saved model with ``bits`` and a group size of 64."""
        from mlx_video.models.wan.loading import load_wan_model

        return load_wan_model(
            model_path,
            config,
            quantization={"bits": bits, "group_size": 64},
        )

    def test_4bit_roundtrip(self, tmp_path):
        """A 4-bit model reloads with quantized attention and FFN layers."""
        tiny = _make_tiny_config()
        weights_file, flat_params = self._quantize_and_save(tiny, tmp_path, bits=4)

        reloaded = self._reload(weights_file, tiny, bits=4)

        # The saved tensors must include quantization scales.
        found_scales = any("scales" in key for key in flat_params)
        assert found_scales, "Quantized model should have .scales tensors"

        # Spot-check one attention projection and one feed-forward layer.
        first_block = reloaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.ffn.fc1, nn.QuantizedLinear)

    def test_8bit_roundtrip(self, tmp_path):
        """An 8-bit model reloads with quantized self- and cross-attention."""
        tiny = _make_tiny_config()
        weights_file, _unused_weights = self._quantize_and_save(
            tiny, tmp_path, bits=8
        )

        reloaded = self._reload(weights_file, tiny, bits=8)

        first_block = reloaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.cross_attn.k, nn.QuantizedLinear)

    def test_non_quantized_layers_remain_linear(self, tmp_path):
        """The output head is not a ``QuantizedLinear`` after reloading."""
        tiny = _make_tiny_config()
        weights_file, _ = self._quantize_and_save(tiny, tmp_path, bits=4)

        reloaded = self._reload(weights_file, tiny, bits=4)

        # The head is not among the predicate's target patterns, so it should
        # come back unquantized.
        assert not isinstance(reloaded.head, nn.QuantizedLinear)

    def test_loading_without_quantization_flag(self, tmp_path):
        """Loading a non-quantized model should have standard Linear layers."""
        from mlx_video.models.wan.model import WanModel

        tiny = _make_tiny_config()
        plain_params = dict(mlx.utils.tree_flatten(WanModel(tiny).parameters()))
        weights_file = tmp_path / "model.safetensors"
        mx.save_safetensors(str(weights_file), plain_params)

        from mlx_video.models.wan.loading import load_wan_model

        reloaded = load_wan_model(weights_file, tiny, quantization=None)

        q_proj = reloaded.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.Linear)
        assert not isinstance(q_proj, nn.QuantizedLinear)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Quantized Inference Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQuantizedInference:
    """Smoke tests that run or inspect an in-memory quantized ``WanModel``."""

    def _make_quantized_model(self, config, bits=4):
        """Build a ``WanModel`` from ``config`` and quantize it in place.

        Quantization uses a group size of 64 and only touches the modules
        accepted by ``_quantize_predicate``. Parameters are evaluated before
        the model is returned.
        """
        from mlx_video.convert_wan import _quantize_predicate
        from mlx_video.models.wan.model import WanModel

        model = WanModel(config)
        nn.quantize(
            model,
            group_size=64,
            bits=bits,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())
        return model

    def _assert_forward_preserves_shape(self, bits):
        """Run one forward pass at ``bits`` and check the output shape.

        Shared by the 4-bit and 8-bit tests, which differ only in bit width.
        """
        config = _make_tiny_config()
        model = self._make_quantized_model(config, bits=bits)

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x = [mx.random.normal((C, F, H, W))]
        t = mx.array([500.0])
        context = [mx.random.normal((4, config.text_dim))]

        out = model(x, t, context, seq_len)
        mx.eval(out[0])

        assert len(out) == 1
        assert out[0].shape == (C, F, H, W)

    def test_forward_pass_4bit(self):
        """A 4-bit quantized model returns one output shaped like its input."""
        self._assert_forward_preserves_shape(bits=4)

    def test_forward_pass_8bit(self):
        """An 8-bit quantized model returns one output shaped like its input."""
        self._assert_forward_preserves_shape(bits=8)

    def test_quantized_output_differs_from_unquantized(self):
        """Sanity check: quantization should change the weights."""
        from mlx_video.convert_wan import _quantize_predicate
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        mx.random.seed(42)

        # Snapshot the unquantized weight before the model is modified.
        model = WanModel(config)
        mx.eval(model.parameters())
        orig_weight = np.array(model.blocks[0].self_attn.q.weight)

        # Quantize in place.
        nn.quantize(
            model,
            group_size=64,
            bits=4,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())

        q_proj = model.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.QuantizedLinear)
        assert hasattr(q_proj, "scales")

        # QuantizedLinear stores its weight packed into uint32 words, so the
        # stored tensor must differ from the snapshot in both dtype and shape.
        # Previously the snapshot was taken but never compared.
        assert q_proj.weight.dtype == mx.uint32
        assert tuple(q_proj.weight.shape) != orig_weight.shape
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config Metadata Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQuantizationConfig:
    """Tests for the on-disk results of ``_quantize_saved_model``."""

    @staticmethod
    def _save_fresh_model(target, config):
        """Save the flattened parameters of a new ``WanModel`` to ``target``."""
        from mlx_video.models.wan.model import WanModel

        flat_params = dict(mlx.utils.tree_flatten(WanModel(config).parameters()))
        mx.save_safetensors(str(target), flat_params)

    @staticmethod
    def _write_config(directory, payload):
        """Write ``payload`` as JSON to ``config.json`` in ``directory``."""
        with open(directory / "config.json", "w") as handle:
            json.dump(payload, handle)

    @staticmethod
    def _read_config(directory):
        """Return the parsed contents of ``config.json`` in ``directory``."""
        with open(directory / "config.json") as handle:
            return json.load(handle)

    def test_config_metadata_written(self, tmp_path):
        """Verify _quantize_saved_model writes quantization metadata to config.json."""
        from mlx_video.convert_wan import _quantize_saved_model

        tiny = _make_tiny_config()

        # Start from an unquantized model and a config without quantization.
        self._save_fresh_model(tmp_path / "model.safetensors", tiny)
        self._write_config(tmp_path, {"dim": tiny.dim})

        _quantize_saved_model(tmp_path, tiny, is_dual=False, bits=4, group_size=64)

        stored = self._read_config(tmp_path)
        assert "quantization" in stored
        assert stored["quantization"]["bits"] == 4
        assert stored["quantization"]["group_size"] == 64

    def test_config_metadata_8bit(self, tmp_path):
        """8-bit / group-size-32 settings are recorded in ``config.json``."""
        from mlx_video.convert_wan import _quantize_saved_model

        tiny = _make_tiny_config()

        self._save_fresh_model(tmp_path / "model.safetensors", tiny)
        self._write_config(tmp_path, {})

        _quantize_saved_model(tmp_path, tiny, is_dual=False, bits=8, group_size=32)

        stored = self._read_config(tmp_path)
        assert stored["quantization"]["bits"] == 8
        assert stored["quantization"]["group_size"] == 32

    def test_dual_model_quantization(self, tmp_path):
        """Verify dual-model quantization writes both model files."""
        from mlx_video.convert_wan import _quantize_saved_model

        tiny = _make_tiny_config()
        file_names = ("low_noise_model.safetensors", "high_noise_model.safetensors")

        # Each file gets its own freshly constructed model.
        for name in file_names:
            self._save_fresh_model(tmp_path / name, tiny)

        self._write_config(tmp_path, {})

        _quantize_saved_model(tmp_path, tiny, is_dual=True, bits=4, group_size=64)

        # Both files should now contain quantized weights (".scales" keys).
        for name in file_names:
            tensors = mx.load(str(tmp_path / name))
            found_scales = any("scales" in key for key in tensors)
            assert found_scales, f"{name} should have quantized layers"
|