diff --git a/README.md b/README.md index 313cd29..6d4fe11 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,29 @@ The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and use | `--seed` | -1 (random) | Random seed for reproducibility | | `--output-path` | `output.mp4` | Output video path | +## LoRA Support +LoRAs can be used with the `--lora-high` and `--lora-low` command-line switches. + +For example, to use the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, use the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1. + +```bash +python -m mlx_video.generate_wan \ + --model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \ + --width 480 \ + --height 704 \ + --num-frames 41 \ + --prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \ + --steps 4 \ + --guide-scale 1 \ + --trim-first-frames 1 \ + --seed 2391784614 \ + --lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \ + --lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1 + ``` + +Which results in +![Poodles](examples/poodles-wan.gif) ## Requirements diff --git a/examples/poodles-wan.gif b/examples/poodles-wan.gif new file mode 100644 index 0000000..d4f36d2 Binary files /dev/null and b/examples/poodles-wan.gif differ diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index 6df2e77..697ce50 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -75,7 +75,6 @@ def generate_video( trim_first_frames: int = 0, debug_latents: bool = False, ): - """Generate video using Wan pipeline (supports T2V and I2V). Args: @@ -108,7 +107,6 @@ def generate_video( discards first 4). Use 2 for more aggressive trimming. Default: 0. 
debug_latents: If True, print per-temporal-position latent statistics after denoising for diagnosing first-frame artifacts. - """ import json @@ -494,6 +492,7 @@ def generate_video( print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}") t3 = time.time() + # Compile model forward for faster denoising if not no_compile: models_to_compile = ( [high_noise_model, low_noise_model] if is_dual else [single_model] @@ -501,9 +500,6 @@ def generate_video( for m in models_to_compile: m._compiled = mx.compile(m) - - - # Pre-convert timesteps to Python list to avoid .item() sync each step timestep_list = sched.timesteps.tolist() @@ -773,7 +769,6 @@ def main(): "--debug-latents", action="store_true", help="Print per-temporal-position latent statistics after denoising (diagnostic)", ) - args = parser.parse_args() # Parse guide scale @@ -814,7 +809,6 @@ def main(): no_compile=args.no_compile, trim_first_frames=args.trim_first_frames, debug_latents=args.debug_latents, - ) diff --git a/mlx_video/models/wan/README.md b/mlx_video/models/wan/README.md index 369e56f..b5830a4 100644 --- a/mlx_video/models/wan/README.md +++ b/mlx_video/models/wan/README.md @@ -146,12 +146,16 @@ For example, for using the the distilled [Wan2.2-Lightning](https://huggingface. 
python -m mlx_video.generate_wan \ --model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \ --width 480 \ - --height 480 \ - --num-frames 121 \ - --prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, close up, cinematic, sunset" \ + --height 704 \ + --num-frames 41 \ + --prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \ --steps 4 \ --guide-scale 1 \ --trim-first-frames 1 \ + --seed 2391784614 \ --lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \ --lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1 ``` + +Which results in +![Poodles](../../../examples/poodles-wan.gif) \ No newline at end of file diff --git a/mlx_video/models/wan/config.py b/mlx_video/models/wan/config.py index 5e51f3b..deb0d78 100644 --- a/mlx_video/models/wan/config.py +++ b/mlx_video/models/wan/config.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Tuple, Union from mlx_video.models.ltx.config import BaseModelConfig @@ -104,7 +104,7 @@ class WanModelConfig(BaseModelConfig): sample_shift=5.0, sample_guide_scale=(3.5, 3.5), max_area=704 * 1280, - + ) @classmethod def wan22_ti2v_5b(cls) -> "WanModelConfig": @@ -126,4 +126,4 @@ class WanModelConfig(BaseModelConfig): sample_guide_scale=5.0, sample_fps=24, max_area=704 * 1280, - + ) diff --git a/mlx_video/models/wan/docs/DIAGNOSTICS.md b/mlx_video/models/wan/docs/DIAGNOSTICS.md index 246ee0d..3b6c456 100644 --- a/mlx_video/models/wan/docs/DIAGNOSTICS.md +++ b/mlx_video/models/wan/docs/DIAGNOSTICS.md @@ -315,11 +315,6 @@ Applied alongside bug fixes to improve inference speed: - **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` 
automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks) - **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync -### TeaCache Integration -- Polynomial rescaling stays in MLX lazy graph (Horner's method) -- Single `.item()` call on the accumulated distance for the skip/compute decision -- Configurable threshold, retention steps, and cutoff steps - --- ## Resolved: CFG Effectiveness (was Open Investigation) diff --git a/mlx_video/models/wan/model.py b/mlx_video/models/wan/model.py index 620c7d2..989e712 100644 --- a/mlx_video/models/wan/model.py +++ b/mlx_video/models/wan/model.py @@ -1,5 +1,4 @@ import math - import mlx.core as mx import mlx.nn as nn import numpy as np @@ -354,7 +353,6 @@ class WanModel(nn.Module): for i, sl in enumerate(seq_lens_list): attn_mask[i, :, :, sl:] = -1e9 - kwargs = dict( e=e0, seq_lens=seq_lens_list, diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py index 1843715..067d6c3 100644 --- a/tests/test_wan_i2v.py +++ b/tests/test_wan_i2v.py @@ -19,7 +19,6 @@ def _make_tiny_i2v_config(): config.boundary = 0.900 config.sample_shift = 5.0 config.sample_guide_scale = (3.5, 3.5) - config.teacache_coefficients = None return config @@ -41,7 +40,6 @@ class TestI2VConfig: assert config.sample_guide_scale == (3.5, 3.5) assert config.vae_stride == (4, 8, 8) assert config.vae_z_dim == 16 - assert config.teacache_coefficients is None def test_i2v_vs_t2v_differences(self): from mlx_video.models.wan.config import WanModelConfig diff --git a/tests/test_wan_quantization.py b/tests/test_wan_quantization.py new file mode 100644 index 0000000..a219eb7 --- /dev/null +++ b/tests/test_wan_quantization.py @@ -0,0 +1,313 @@ +"""Tests for Wan model quantization pipeline.""" + +import json +import mlx.core as mx +import mlx.nn as nn +import mlx.utils +import numpy as np +import pytest + +from wan_test_helpers import _make_tiny_config + + +# 
--------------------------------------------------------------------------- +# Quantize Predicate Tests +# --------------------------------------------------------------------------- + +class TestQuantizePredicate: + def test_matches_self_attention_layers(self): + from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) + for suffix in ["q", "k", "v", "o"]: + path = f"blocks.0.self_attn.{suffix}" + assert _quantize_predicate(path, mock_linear), f"Should match {path}" + + def test_matches_cross_attention_layers(self): + from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) + for suffix in ["q", "k", "v", "o"]: + path = f"blocks.0.cross_attn.{suffix}" + assert _quantize_predicate(path, mock_linear), f"Should match {path}" + + def test_matches_ffn_layers(self): + from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) + assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear) + assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear) + + def test_rejects_embeddings(self): + from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) + for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]: + assert not _quantize_predicate(path, mock_linear), f"Should reject {path}" + + def test_rejects_norms(self): + from mlx_video.convert_wan import _quantize_predicate + mock_norm = nn.RMSNorm(64) + assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm) + + def test_rejects_non_quantizable_modules(self): + from mlx_video.convert_wan import _quantize_predicate + mock_norm = nn.RMSNorm(64) + # Even if path matches, module must have to_quantized + assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm) + + def test_all_10_patterns_covered(self): + """Verify exactly 10 layer patterns are targeted.""" + from mlx_video.convert_wan import _quantize_predicate + mock_linear = nn.Linear(64, 64) + patterns = [ + 
"blocks.0.self_attn.q", "blocks.0.self_attn.k", + "blocks.0.self_attn.v", "blocks.0.self_attn.o", + "blocks.0.cross_attn.q", "blocks.0.cross_attn.k", + "blocks.0.cross_attn.v", "blocks.0.cross_attn.o", + "blocks.0.ffn.fc1", "blocks.0.ffn.fc2", + ] + matched = [p for p in patterns if _quantize_predicate(p, mock_linear)] + assert len(matched) == 10 + + +# --------------------------------------------------------------------------- +# Quantize Round-Trip Tests +# --------------------------------------------------------------------------- + +class TestQuantizeRoundTrip: + def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): + """Helper: create model, quantize, save to tmp_path.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.convert_wan import _quantize_predicate + + model = WanModel(config) + nn.quantize( + model, + group_size=group_size, + bits=bits, + class_predicate=lambda path, m: _quantize_predicate(path, m), + ) + + weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) + model_path = tmp_path / "model.safetensors" + mx.save_safetensors(str(model_path), weights_dict) + + # Write config.json + cfg = {"quantization": {"bits": bits, "group_size": group_size}} + with open(tmp_path / "config.json", "w") as f: + json.dump(cfg, f) + + return model_path, weights_dict + + def test_4bit_roundtrip(self, tmp_path): + config = _make_tiny_config() + model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4) + + from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model( + model_path, config, + quantization={"bits": 4, "group_size": 64}, + ) + + # Verify quantized layers have scales + has_scales = any("scales" in k for k in saved_weights) + assert has_scales, "Quantized model should have .scales tensors" + + # Verify a self-attention layer is QuantizedLinear + assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear) + assert isinstance(loaded.blocks[0].ffn.fc1, 
nn.QuantizedLinear) + + def test_8bit_roundtrip(self, tmp_path): + config = _make_tiny_config() + model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8) + + from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model( + model_path, config, + quantization={"bits": 8, "group_size": 64}, + ) + + assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear) + assert isinstance(loaded.blocks[0].cross_attn.k, nn.QuantizedLinear) + + def test_non_quantized_layers_remain_linear(self, tmp_path): + config = _make_tiny_config() + model_path, _ = self._quantize_and_save(config, tmp_path, bits=4) + + from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model( + model_path, config, + quantization={"bits": 4, "group_size": 64}, + ) + + # Head should NOT be quantized (it's not in the predicate patterns) + assert not isinstance(loaded.head, nn.QuantizedLinear) + + def test_loading_without_quantization_flag(self, tmp_path): + """Loading a non-quantized model should have standard Linear layers.""" + from mlx_video.models.wan.model import WanModel + + config = _make_tiny_config() + model = WanModel(config) + weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) + model_path = tmp_path / "model.safetensors" + mx.save_safetensors(str(model_path), weights_dict) + + from mlx_video.models.wan.loading import load_wan_model + loaded = load_wan_model(model_path, config, quantization=None) + + assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear) + assert not isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear) + + +# --------------------------------------------------------------------------- +# Quantized Inference Tests +# --------------------------------------------------------------------------- + +class TestQuantizedInference: + def _make_quantized_model(self, config, bits=4): + from mlx_video.models.wan.model import WanModel + from mlx_video.convert_wan import _quantize_predicate + + model = 
WanModel(config) + nn.quantize( + model, + group_size=64, + bits=bits, + class_predicate=lambda path, m: _quantize_predicate(path, m), + ) + mx.eval(model.parameters()) + return model + + def test_forward_pass_4bit(self): + config = _make_tiny_config() + model = self._make_quantized_model(config, bits=4) + + C, F, H, W = config.in_dim, 1, 4, 4 + pt, ph, pw = config.patch_size + seq_len = (F // pt) * (H // ph) * (W // pw) + + x = [mx.random.normal((C, F, H, W))] + t = mx.array([500.0]) + context = [mx.random.normal((4, config.text_dim))] + + out = model(x, t, context, seq_len) + mx.eval(out[0]) + + assert len(out) == 1 + assert out[0].shape == (C, F, H, W) + + def test_forward_pass_8bit(self): + config = _make_tiny_config() + model = self._make_quantized_model(config, bits=8) + + C, F, H, W = config.in_dim, 1, 4, 4 + pt, ph, pw = config.patch_size + seq_len = (F // pt) * (H // ph) * (W // pw) + + x = [mx.random.normal((C, F, H, W))] + t = mx.array([500.0]) + context = [mx.random.normal((4, config.text_dim))] + + out = model(x, t, context, seq_len) + mx.eval(out[0]) + + assert len(out) == 1 + assert out[0].shape == (C, F, H, W) + + def test_quantized_output_differs_from_unquantized(self): + """Sanity check: quantization should change the weights.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.convert_wan import _quantize_predicate + + config = _make_tiny_config() + mx.random.seed(42) + + # Get unquantized weights + model = WanModel(config) + mx.eval(model.parameters()) + orig_weight = np.array(model.blocks[0].self_attn.q.weight) + + # Quantize + nn.quantize( + model, + group_size=64, + bits=4, + class_predicate=lambda path, m: _quantize_predicate(path, m), + ) + mx.eval(model.parameters()) + + # QuantizedLinear stores weight differently (uint32 packed) + assert isinstance(model.blocks[0].self_attn.q, nn.QuantizedLinear) + assert hasattr(model.blocks[0].self_attn.q, "scales") + + +# 
--------------------------------------------------------------------------- +# Config Metadata Tests +# --------------------------------------------------------------------------- + +class TestQuantizationConfig: + def test_config_metadata_written(self, tmp_path): + """Verify _quantize_saved_model writes quantization metadata to config.json.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.convert_wan import _quantize_saved_model + + config = _make_tiny_config() + model = WanModel(config) + weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) + + # Save unquantized model + config + model_path = tmp_path / "model.safetensors" + mx.save_safetensors(str(model_path), weights_dict) + with open(tmp_path / "config.json", "w") as f: + json.dump({"dim": config.dim}, f) + + # Run quantization + _quantize_saved_model(tmp_path, config, is_dual=False, bits=4, group_size=64) + + # Verify metadata + with open(tmp_path / "config.json") as f: + cfg = json.load(f) + assert "quantization" in cfg + assert cfg["quantization"]["bits"] == 4 + assert cfg["quantization"]["group_size"] == 64 + + def test_config_metadata_8bit(self, tmp_path): + from mlx_video.models.wan.model import WanModel + from mlx_video.convert_wan import _quantize_saved_model + + config = _make_tiny_config() + model = WanModel(config) + weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) + + model_path = tmp_path / "model.safetensors" + mx.save_safetensors(str(model_path), weights_dict) + with open(tmp_path / "config.json", "w") as f: + json.dump({}, f) + + _quantize_saved_model(tmp_path, config, is_dual=False, bits=8, group_size=32) + + with open(tmp_path / "config.json") as f: + cfg = json.load(f) + assert cfg["quantization"]["bits"] == 8 + assert cfg["quantization"]["group_size"] == 32 + + def test_dual_model_quantization(self, tmp_path): + """Verify dual-model quantization writes both model files.""" + from mlx_video.models.wan.model import WanModel + from 
mlx_video.convert_wan import _quantize_saved_model + + config = _make_tiny_config() + + for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]: + model = WanModel(config) + weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) + mx.save_safetensors(str(tmp_path / name), weights_dict) + + with open(tmp_path / "config.json", "w") as f: + json.dump({}, f) + + _quantize_saved_model(tmp_path, config, is_dual=True, bits=4, group_size=64) + + # Both files should now contain quantized weights (have .scales keys) + for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]: + weights = mx.load(str(tmp_path / name)) + has_scales = any("scales" in k for k in weights) + assert has_scales, f"{name} should have quantized layers"