More poodles
This commit is contained in:
22
README.md
22
README.md
@@ -150,7 +150,29 @@ The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and use
|
|||||||
| `--seed` | -1 (random) | Random seed for reproducibility |
|
| `--seed` | -1 (random) | Random seed for reproducibility |
|
||||||
| `--output-path` | `output.mp4` | Output video path |
|
| `--output-path` | `output.mp4` | Output video path |
|
||||||
|
|
||||||
|
## LoRA Support
|
||||||
|
|
||||||
|
LoRAs can be used with the `--lora-high` and `--lora-low` command-line switches.
|
||||||
|
|
||||||
|
For example, to use the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, run the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m mlx_video.generate_wan \
|
||||||
|
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
||||||
|
--width 480 \
|
||||||
|
--height 704 \
|
||||||
|
--num-frames 41 \
|
||||||
|
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
|
||||||
|
--steps 4 \
|
||||||
|
--guide-scale 1 \
|
||||||
|
--trim-first-frames 1 \
|
||||||
|
--seed 2391784614 \
|
||||||
|
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
||||||
|
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
||||||
|
```
|
||||||
|
|
||||||
|
Which results in
|
||||||
|

|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
|
|||||||
BIN
examples/poodles-wan.gif
Normal file
BIN
examples/poodles-wan.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.4 MiB |
@@ -75,7 +75,6 @@ def generate_video(
|
|||||||
trim_first_frames: int = 0,
|
trim_first_frames: int = 0,
|
||||||
debug_latents: bool = False,
|
debug_latents: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -108,7 +107,6 @@ def generate_video(
|
|||||||
discards first 4). Use 2 for more aggressive trimming. Default: 0.
|
discards first 4). Use 2 for more aggressive trimming. Default: 0.
|
||||||
debug_latents: If True, print per-temporal-position latent statistics
|
debug_latents: If True, print per-temporal-position latent statistics
|
||||||
after denoising for diagnosing first-frame artifacts.
|
after denoising for diagnosing first-frame artifacts.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -494,6 +492,7 @@ def generate_video(
|
|||||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
|
|
||||||
|
# Compile model forward for faster denoising
|
||||||
if not no_compile:
|
if not no_compile:
|
||||||
models_to_compile = (
|
models_to_compile = (
|
||||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||||
@@ -501,9 +500,6 @@ def generate_video(
|
|||||||
for m in models_to_compile:
|
for m in models_to_compile:
|
||||||
m._compiled = mx.compile(m)
|
m._compiled = mx.compile(m)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
||||||
timestep_list = sched.timesteps.tolist()
|
timestep_list = sched.timesteps.tolist()
|
||||||
|
|
||||||
@@ -773,7 +769,6 @@ def main():
|
|||||||
"--debug-latents", action="store_true",
|
"--debug-latents", action="store_true",
|
||||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Parse guide scale
|
# Parse guide scale
|
||||||
@@ -814,7 +809,6 @@ def main():
|
|||||||
no_compile=args.no_compile,
|
no_compile=args.no_compile,
|
||||||
trim_first_frames=args.trim_first_frames,
|
trim_first_frames=args.trim_first_frames,
|
||||||
debug_latents=args.debug_latents,
|
debug_latents=args.debug_latents,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -146,12 +146,16 @@ For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.
|
|||||||
python -m mlx_video.generate_wan \
|
python -m mlx_video.generate_wan \
|
||||||
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
||||||
--width 480 \
|
--width 480 \
|
||||||
--height 480 \
|
--height 704 \
|
||||||
--num-frames 121 \
|
--num-frames 41 \
|
||||||
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, close up, cinematic, sunset" \
|
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
|
||||||
--steps 4 \
|
--steps 4 \
|
||||||
--guide-scale 1 \
|
--guide-scale 1 \
|
||||||
--trim-first-frames 1 \
|
--trim-first-frames 1 \
|
||||||
|
--seed 2391784614 \
|
||||||
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
||||||
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Which results in
|
||||||
|

|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import BaseModelConfig
|
from mlx_video.models.ltx.config import BaseModelConfig
|
||||||
|
|
||||||
@@ -104,7 +104,7 @@ class WanModelConfig(BaseModelConfig):
|
|||||||
sample_shift=5.0,
|
sample_shift=5.0,
|
||||||
sample_guide_scale=(3.5, 3.5),
|
sample_guide_scale=(3.5, 3.5),
|
||||||
max_area=704 * 1280,
|
max_area=704 * 1280,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
||||||
@@ -126,4 +126,4 @@ class WanModelConfig(BaseModelConfig):
|
|||||||
sample_guide_scale=5.0,
|
sample_guide_scale=5.0,
|
||||||
sample_fps=24,
|
sample_fps=24,
|
||||||
max_area=704 * 1280,
|
max_area=704 * 1280,
|
||||||
|
)
|
||||||
|
|||||||
@@ -315,11 +315,6 @@ Applied alongside bug fixes to improve inference speed:
|
|||||||
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
|
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
|
||||||
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
|
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
|
||||||
|
|
||||||
### TeaCache Integration
|
|
||||||
- Polynomial rescaling stays in MLX lazy graph (Horner's method)
|
|
||||||
- Single `.item()` call on the accumulated distance for the skip/compute decision
|
|
||||||
- Configurable threshold, retention steps, and cutoff steps
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Resolved: CFG Effectiveness (was Open Investigation)
|
## Resolved: CFG Effectiveness (was Open Investigation)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -354,7 +353,6 @@ class WanModel(nn.Module):
|
|||||||
for i, sl in enumerate(seq_lens_list):
|
for i, sl in enumerate(seq_lens_list):
|
||||||
attn_mask[i, :, :, sl:] = -1e9
|
attn_mask[i, :, :, sl:] = -1e9
|
||||||
|
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
e=e0,
|
e=e0,
|
||||||
seq_lens=seq_lens_list,
|
seq_lens=seq_lens_list,
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ def _make_tiny_i2v_config():
|
|||||||
config.boundary = 0.900
|
config.boundary = 0.900
|
||||||
config.sample_shift = 5.0
|
config.sample_shift = 5.0
|
||||||
config.sample_guide_scale = (3.5, 3.5)
|
config.sample_guide_scale = (3.5, 3.5)
|
||||||
config.teacache_coefficients = None
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -41,7 +40,6 @@ class TestI2VConfig:
|
|||||||
assert config.sample_guide_scale == (3.5, 3.5)
|
assert config.sample_guide_scale == (3.5, 3.5)
|
||||||
assert config.vae_stride == (4, 8, 8)
|
assert config.vae_stride == (4, 8, 8)
|
||||||
assert config.vae_z_dim == 16
|
assert config.vae_z_dim == 16
|
||||||
assert config.teacache_coefficients is None
|
|
||||||
|
|
||||||
def test_i2v_vs_t2v_differences(self):
|
def test_i2v_vs_t2v_differences(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|||||||
313
tests/test_wan_quantization.py
Normal file
313
tests/test_wan_quantization.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
"""Tests for Wan model quantization pipeline."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.utils
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from wan_test_helpers import _make_tiny_config
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quantize Predicate Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQuantizePredicate:
    """Checks which (path, module) pairs ``_quantize_predicate`` accepts."""

    def test_matches_self_attention_layers(self):
        """The q/k/v/o paths under ``self_attn`` are accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for name in ("q", "k", "v", "o"):
            path = f"blocks.0.self_attn.{name}"
            accepted = _quantize_predicate(path, layer)
            assert accepted, f"Should match {path}"

    def test_matches_cross_attention_layers(self):
        """The q/k/v/o paths under ``cross_attn`` are accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for name in ("q", "k", "v", "o"):
            path = f"blocks.0.cross_attn.{name}"
            accepted = _quantize_predicate(path, layer)
            assert accepted, f"Should match {path}"

    def test_matches_ffn_layers(self):
        """Both feed-forward paths (``fc1`` and ``fc2``) are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for path in ("blocks.0.ffn.fc1", "blocks.0.ffn.fc2"):
            assert _quantize_predicate(path, layer)

    def test_rejects_embeddings(self):
        """Embedding-related paths are rejected even for a Linear module."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        embedding_paths = (
            "patch_embedding_proj",
            "text_embedding_fc1",
            "time_embedding.fc1",
        )
        for path in embedding_paths:
            rejected = not _quantize_predicate(path, layer)
            assert rejected, f"Should reject {path}"

    def test_rejects_norms(self):
        """An RMSNorm at a norm path is rejected."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        assert not _quantize_predicate("blocks.0.self_attn.norm_q", norm)

    def test_rejects_non_quantizable_modules(self):
        """An RMSNorm is rejected even when placed at an accepted path."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        # A matching path is not enough: the module must have to_quantized.
        assert not _quantize_predicate("blocks.0.self_attn.q", norm)

    def test_all_10_patterns_covered(self):
        """All ten listed block-level paths are accepted for a Linear."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        targets = (
            "blocks.0.self_attn.q",
            "blocks.0.self_attn.k",
            "blocks.0.self_attn.v",
            "blocks.0.self_attn.o",
            "blocks.0.cross_attn.q",
            "blocks.0.cross_attn.k",
            "blocks.0.cross_attn.v",
            "blocks.0.cross_attn.o",
            "blocks.0.ffn.fc1",
            "blocks.0.ffn.fc2",
        )
        hit_count = sum(1 for p in targets if _quantize_predicate(p, layer))
        assert hit_count == 10
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quantize Round-Trip Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQuantizeRoundTrip:
    """Quantize a tiny WanModel, save it, and reload it via ``load_wan_model``."""

    def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
        """Create a model, quantize it, and write weights + config to tmp_path.

        Returns:
            Tuple of (path to ``model.safetensors``, flattened weights dict).
        """
        from mlx_video.convert_wan import _quantize_predicate
        from mlx_video.models.wan.model import WanModel

        net = WanModel(config)
        nn.quantize(
            net,
            group_size=group_size,
            bits=bits,
            class_predicate=_quantize_predicate,
        )

        flat_params = dict(mlx.utils.tree_flatten(net.parameters()))
        weights_file = tmp_path / "model.safetensors"
        mx.save_safetensors(str(weights_file), flat_params)

        # Companion config.json carrying the quantization settings.
        metadata = {"quantization": {"bits": bits, "group_size": group_size}}
        (tmp_path / "config.json").write_text(json.dumps(metadata))

        return weights_file, flat_params

    def _reload(self, weights_file, config, bits):
        """Load the saved weights with a quantization dict (group_size 64)."""
        from mlx_video.models.wan.loading import load_wan_model

        return load_wan_model(
            weights_file,
            config,
            quantization={"bits": bits, "group_size": 64},
        )

    def test_4bit_roundtrip(self, tmp_path):
        """4-bit weights contain scales and reload as QuantizedLinear layers."""
        config = _make_tiny_config()
        weights_file, saved = self._quantize_and_save(config, tmp_path, bits=4)
        reloaded = self._reload(weights_file, config, bits=4)

        # The saved tensors must include quantization scales.
        scales_present = any("scales" in key for key in saved)
        assert scales_present, "Quantized model should have .scales tensors"

        # Attention and feed-forward projections come back quantized.
        first_block = reloaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.ffn.fc1, nn.QuantizedLinear)

    def test_8bit_roundtrip(self, tmp_path):
        """8-bit weights reload with quantized self- and cross-attention."""
        config = _make_tiny_config()
        weights_file, _saved = self._quantize_and_save(config, tmp_path, bits=8)
        reloaded = self._reload(weights_file, config, bits=8)

        first_block = reloaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.cross_attn.k, nn.QuantizedLinear)

    def test_non_quantized_layers_remain_linear(self, tmp_path):
        """The model head is not a QuantizedLinear after a 4-bit reload."""
        config = _make_tiny_config()
        weights_file, _ = self._quantize_and_save(config, tmp_path, bits=4)
        reloaded = self._reload(weights_file, config, bits=4)

        # Head should NOT be quantized (it's not in the predicate patterns)
        assert not isinstance(reloaded.head, nn.QuantizedLinear)

    def test_loading_without_quantization_flag(self, tmp_path):
        """Loading a non-quantized model should have standard Linear layers."""
        from mlx_video.models.wan.loading import load_wan_model
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        flat_params = dict(mlx.utils.tree_flatten(WanModel(config).parameters()))
        weights_file = tmp_path / "model.safetensors"
        mx.save_safetensors(str(weights_file), flat_params)

        reloaded = load_wan_model(weights_file, config, quantization=None)

        q_proj = reloaded.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.Linear)
        assert not isinstance(q_proj, nn.QuantizedLinear)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quantized Inference Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQuantizedInference:
    """Forward-pass and weight-layout checks on an in-memory quantized model."""

    def _make_quantized_model(self, config, bits=4):
        """Build a WanModel from ``config`` and quantize it in place.

        Args:
            config: Model configuration passed to ``WanModel``.
            bits: Quantization bit width (group size is fixed at 64).

        Returns:
            The quantized model with its parameters evaluated.
        """
        from mlx_video.convert_wan import _quantize_predicate
        from mlx_video.models.wan.model import WanModel

        model = WanModel(config)
        nn.quantize(
            model,
            group_size=64,
            bits=bits,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())
        return model

    def _check_forward_shape(self, bits):
        """Run one forward pass at ``bits`` and assert the output shape.

        Shared by the 4-bit and 8-bit tests, which differ only in bit width.
        """
        config = _make_tiny_config()
        model = self._make_quantized_model(config, bits=bits)

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x = [mx.random.normal((C, F, H, W))]
        t = mx.array([500.0])
        context = [mx.random.normal((4, config.text_dim))]

        out = model(x, t, context, seq_len)
        mx.eval(out[0])

        assert len(out) == 1
        assert out[0].shape == (C, F, H, W)

    def test_forward_pass_4bit(self):
        """A 4-bit quantized model returns one output of the input shape."""
        self._check_forward_shape(bits=4)

    def test_forward_pass_8bit(self):
        """An 8-bit quantized model returns one output of the input shape."""
        self._check_forward_shape(bits=8)

    def test_quantized_output_differs_from_unquantized(self):
        """Sanity check: quantization should change the weights."""
        from mlx_video.convert_wan import _quantize_predicate
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        mx.random.seed(42)

        # Snapshot the unquantized weight before the model is modified.
        model = WanModel(config)
        mx.eval(model.parameters())
        orig_weight = np.array(model.blocks[0].self_attn.q.weight)

        # Quantize
        nn.quantize(
            model,
            group_size=64,
            bits=4,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())

        q_proj = model.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.QuantizedLinear)
        assert hasattr(q_proj, "scales")

        # QuantizedLinear stores weight differently (uint32 packed), so both
        # the dtype and the stored shape must differ from the original.
        assert q_proj.weight.dtype == mx.uint32
        assert tuple(q_proj.weight.shape) != orig_weight.shape

        # Dequantizing restores the original shape but not the exact values;
        # this is the comparison the test name promises.
        restored = mx.dequantize(
            q_proj.weight,
            q_proj.scales,
            q_proj.biases,
            group_size=64,
            bits=4,
        )
        restored = np.array(restored.astype(mx.float32))
        assert restored.shape == orig_weight.shape
        assert not np.array_equal(restored, orig_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config Metadata Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQuantizationConfig:
    """Checks the files ``_quantize_saved_model`` leaves in the model dir."""

    def test_config_metadata_written(self, tmp_path):
        """Verify _quantize_saved_model writes quantization metadata to config.json."""
        from mlx_video.convert_wan import _quantize_saved_model
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        flat_params = dict(mlx.utils.tree_flatten(WanModel(config).parameters()))

        # Lay down an unquantized checkpoint plus a minimal config.json.
        config_file = tmp_path / "config.json"
        mx.save_safetensors(str(tmp_path / "model.safetensors"), flat_params)
        config_file.write_text(json.dumps({"dim": config.dim}))

        # Quantize the saved checkpoint.
        _quantize_saved_model(tmp_path, config, is_dual=False, bits=4, group_size=64)

        # The config must now carry the quantization settings.
        meta = json.loads(config_file.read_text())
        assert "quantization" in meta
        assert meta["quantization"]["bits"] == 4
        assert meta["quantization"]["group_size"] == 64

    def test_config_metadata_8bit(self, tmp_path):
        """An 8-bit / group-size-32 run records those values in config.json."""
        from mlx_video.convert_wan import _quantize_saved_model
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        flat_params = dict(mlx.utils.tree_flatten(WanModel(config).parameters()))

        config_file = tmp_path / "config.json"
        mx.save_safetensors(str(tmp_path / "model.safetensors"), flat_params)
        config_file.write_text(json.dumps({}))

        _quantize_saved_model(tmp_path, config, is_dual=False, bits=8, group_size=32)

        quant_meta = json.loads(config_file.read_text())["quantization"]
        assert quant_meta["bits"] == 8
        assert quant_meta["group_size"] == 32

    def test_dual_model_quantization(self, tmp_path):
        """Verify dual-model quantization writes both model files."""
        from mlx_video.convert_wan import _quantize_saved_model
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        checkpoint_names = (
            "low_noise_model.safetensors",
            "high_noise_model.safetensors",
        )

        # One independent unquantized checkpoint per expert.
        for name in checkpoint_names:
            expert = WanModel(config)
            flat_params = dict(mlx.utils.tree_flatten(expert.parameters()))
            mx.save_safetensors(str(tmp_path / name), flat_params)

        (tmp_path / "config.json").write_text(json.dumps({}))

        _quantize_saved_model(tmp_path, config, is_dual=True, bits=4, group_size=64)

        # Both files should now contain quantized weights (have .scales keys)
        for name in checkpoint_names:
            stored = mx.load(str(tmp_path / name))
            scales_present = any("scales" in key for key in stored)
            assert scales_present, f"{name} should have quantized layers"
|
||||||
Reference in New Issue
Block a user