More poodles

This commit is contained in:
Daniel
2026-03-11 08:14:12 +01:00
parent d207275fea
commit 1cf878f5e0
9 changed files with 347 additions and 23 deletions

View File

@@ -150,7 +150,29 @@ The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and use
| `--seed` | -1 (random) | Random seed for reproducibility | | `--seed` | -1 (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video path | | `--output-path` | `output.mp4` | Output video path |
## LoRA Support
LoRAs can be applied with the `--lora-high` and `--lora-low` command line switches.
For example, to use the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, run the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
```bash
python -m mlx_video.generate_wan \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \
--num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \
--guide-scale 1 \
--trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
```
Which results in
![Poodles](examples/poodles-wan.gif)
## Requirements ## Requirements

BIN
examples/poodles-wan.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.4 MiB

View File

@@ -75,7 +75,6 @@ def generate_video(
trim_first_frames: int = 0, trim_first_frames: int = 0,
debug_latents: bool = False, debug_latents: bool = False,
): ):
"""Generate video using Wan pipeline (supports T2V and I2V). """Generate video using Wan pipeline (supports T2V and I2V).
Args: Args:
@@ -108,7 +107,6 @@ def generate_video(
discards first 4). Use 2 for more aggressive trimming. Default: 0. discards first 4). Use 2 for more aggressive trimming. Default: 0.
debug_latents: If True, print per-temporal-position latent statistics debug_latents: If True, print per-temporal-position latent statistics
after denoising for diagnosing first-frame artifacts. after denoising for diagnosing first-frame artifacts.
""" """
import json import json
@@ -494,6 +492,7 @@ def generate_video(
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}") print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time() t3 = time.time()
# Compile model forward for faster denoising
if not no_compile: if not no_compile:
models_to_compile = ( models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model] [high_noise_model, low_noise_model] if is_dual else [single_model]
@@ -501,9 +500,6 @@ def generate_video(
for m in models_to_compile: for m in models_to_compile:
m._compiled = mx.compile(m) m._compiled = mx.compile(m)
# Pre-convert timesteps to Python list to avoid .item() sync each step # Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_list = sched.timesteps.tolist() timestep_list = sched.timesteps.tolist()
@@ -773,7 +769,6 @@ def main():
"--debug-latents", action="store_true", "--debug-latents", action="store_true",
help="Print per-temporal-position latent statistics after denoising (diagnostic)", help="Print per-temporal-position latent statistics after denoising (diagnostic)",
) )
args = parser.parse_args() args = parser.parse_args()
# Parse guide scale # Parse guide scale
@@ -814,7 +809,6 @@ def main():
no_compile=args.no_compile, no_compile=args.no_compile,
trim_first_frames=args.trim_first_frames, trim_first_frames=args.trim_first_frames,
debug_latents=args.debug_latents, debug_latents=args.debug_latents,
) )

View File

@@ -146,12 +146,16 @@ For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.
python -m mlx_video.generate_wan \ python -m mlx_video.generate_wan \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \ --model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \ --width 480 \
--height 480 \ --height 704 \
--num-frames 121 \ --num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, close up, cinematic, sunset" \ --prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \ --steps 4 \
--guide-scale 1 \ --guide-scale 1 \
--trim-first-frames 1 \ --trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \ --lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1 --lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
``` ```
Which results in
![Poodles](../../../examples/poodles-wan.gif)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import Tuple, Union
from mlx_video.models.ltx.config import BaseModelConfig from mlx_video.models.ltx.config import BaseModelConfig
@@ -104,7 +104,7 @@ class WanModelConfig(BaseModelConfig):
sample_shift=5.0, sample_shift=5.0,
sample_guide_scale=(3.5, 3.5), sample_guide_scale=(3.5, 3.5),
max_area=704 * 1280, max_area=704 * 1280,
)
@classmethod @classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig": def wan22_ti2v_5b(cls) -> "WanModelConfig":
@@ -126,4 +126,4 @@ class WanModelConfig(BaseModelConfig):
sample_guide_scale=5.0, sample_guide_scale=5.0,
sample_fps=24, sample_fps=24,
max_area=704 * 1280, max_area=704 * 1280,
)

View File

@@ -315,11 +315,6 @@ Applied alongside bug fixes to improve inference speed:
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks) - **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync - **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
### TeaCache Integration
- Polynomial rescaling stays in MLX lazy graph (Horner's method)
- Single `.item()` call on the accumulated distance for the skip/compute decision
- Configurable threshold, retention steps, and cutoff steps
--- ---
## Resolved: CFG Effectiveness (was Open Investigation) ## Resolved: CFG Effectiveness (was Open Investigation)

View File

@@ -1,5 +1,4 @@
import math import math
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
@@ -354,7 +353,6 @@ class WanModel(nn.Module):
for i, sl in enumerate(seq_lens_list): for i, sl in enumerate(seq_lens_list):
attn_mask[i, :, :, sl:] = -1e9 attn_mask[i, :, :, sl:] = -1e9
kwargs = dict( kwargs = dict(
e=e0, e=e0,
seq_lens=seq_lens_list, seq_lens=seq_lens_list,

View File

@@ -19,7 +19,6 @@ def _make_tiny_i2v_config():
config.boundary = 0.900 config.boundary = 0.900
config.sample_shift = 5.0 config.sample_shift = 5.0
config.sample_guide_scale = (3.5, 3.5) config.sample_guide_scale = (3.5, 3.5)
config.teacache_coefficients = None
return config return config
@@ -41,7 +40,6 @@ class TestI2VConfig:
assert config.sample_guide_scale == (3.5, 3.5) assert config.sample_guide_scale == (3.5, 3.5)
assert config.vae_stride == (4, 8, 8) assert config.vae_stride == (4, 8, 8)
assert config.vae_z_dim == 16 assert config.vae_z_dim == 16
assert config.teacache_coefficients is None
def test_i2v_vs_t2v_differences(self): def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig

View File

@@ -0,0 +1,313 @@
"""Tests for Wan model quantization pipeline."""
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Quantize Predicate Tests
# ---------------------------------------------------------------------------
class TestQuantizePredicate:
    """Exercise `_quantize_predicate` against layer paths and module types."""

    def test_matches_self_attention_layers(self):
        """Self-attention q/k/v/o projections are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for path in [f"blocks.0.self_attn.{name}" for name in "qkvo"]:
            assert _quantize_predicate(path, layer), f"Should match {path}"

    def test_matches_cross_attention_layers(self):
        """Cross-attention q/k/v/o projections are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for path in [f"blocks.0.cross_attn.{name}" for name in "qkvo"]:
            assert _quantize_predicate(path, layer), f"Should match {path}"

    def test_matches_ffn_layers(self):
        """Both feed-forward projections are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        for name in ("fc1", "fc2"):
            assert _quantize_predicate(f"blocks.0.ffn.{name}", layer)

    def test_rejects_embeddings(self):
        """Embedding-related paths are rejected even for Linear modules."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        embedding_paths = (
            "patch_embedding_proj",
            "text_embedding_fc1",
            "time_embedding.fc1",
        )
        for path in embedding_paths:
            assert not _quantize_predicate(path, layer), f"Should reject {path}"

    def test_rejects_norms(self):
        """A norm module under a norm path is rejected."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        accepted = _quantize_predicate("blocks.0.self_attn.norm_q", norm)
        assert not accepted

    def test_rejects_non_quantizable_modules(self):
        """A matching path alone is not enough to be accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        norm = nn.RMSNorm(64)
        # The path is one accepted for Linear layers; per the original note,
        # the module must also have `to_quantized`, which RMSNorm lacks.
        accepted = _quantize_predicate("blocks.0.self_attn.q", norm)
        assert not accepted

    def test_all_10_patterns_covered(self):
        """All ten expected attention/FFN layer paths are accepted."""
        from mlx_video.convert_wan import _quantize_predicate

        layer = nn.Linear(64, 64)
        attention_paths = [
            f"blocks.0.{attn}.{name}"
            for attn in ("self_attn", "cross_attn")
            for name in "qkvo"
        ]
        ffn_paths = ["blocks.0.ffn.fc1", "blocks.0.ffn.fc2"]
        accepted_count = sum(
            1
            for path in attention_paths + ffn_paths
            if _quantize_predicate(path, layer)
        )
        assert accepted_count == 10
# ---------------------------------------------------------------------------
# Quantize Round-Trip Tests
# ---------------------------------------------------------------------------
class TestQuantizeRoundTrip:
    """Quantize a tiny Wan model, save it, and reload it via `load_wan_model`."""

    def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
        """Build a model, quantize it, and write weights + config.json.

        Returns:
            A ``(model_path, weights_dict)`` tuple: the safetensors path that
            was written and the flattened parameter dict that was saved.
        """
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_predicate

        model = WanModel(config)
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=_quantize_predicate,
        )

        flat_weights = dict(mlx.utils.tree_flatten(model.parameters()))
        model_path = tmp_path / "model.safetensors"
        mx.save_safetensors(str(model_path), flat_weights)

        # Record the quantization settings next to the weights.
        metadata = {"quantization": {"bits": bits, "group_size": group_size}}
        with open(tmp_path / "config.json", "w") as handle:
            json.dump(metadata, handle)

        return model_path, flat_weights

    def _load(self, model_path, config, bits):
        """Reload the saved model, declaring ``bits``-bit / group-64 quantization."""
        from mlx_video.models.wan.loading import load_wan_model

        return load_wan_model(
            model_path,
            config,
            quantization={"bits": bits, "group_size": 64},
        )

    def test_4bit_roundtrip(self, tmp_path):
        """4-bit weights contain scales and reload as QuantizedLinear."""
        config = _make_tiny_config()
        model_path, saved_weights = self._quantize_and_save(
            config, tmp_path, bits=4
        )
        loaded = self._load(model_path, config, bits=4)

        # The saved tensors must include quantization scales.
        scale_keys = [key for key in saved_weights if "scales" in key]
        assert bool(scale_keys), "Quantized model should have .scales tensors"

        # Attention and FFN projections come back quantized.
        first_block = loaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.ffn.fc1, nn.QuantizedLinear)

    def test_8bit_roundtrip(self, tmp_path):
        """8-bit weights reload as QuantizedLinear."""
        config = _make_tiny_config()
        model_path, _saved = self._quantize_and_save(config, tmp_path, bits=8)
        loaded = self._load(model_path, config, bits=8)

        first_block = loaded.blocks[0]
        assert isinstance(first_block.self_attn.q, nn.QuantizedLinear)
        assert isinstance(first_block.cross_attn.k, nn.QuantizedLinear)

    def test_non_quantized_layers_remain_linear(self, tmp_path):
        """Layers outside the predicate patterns are not quantized on reload."""
        config = _make_tiny_config()
        model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
        loaded = self._load(model_path, config, bits=4)

        # The head is not one of the predicate's patterns, so it stays as-is.
        head_is_quantized = isinstance(loaded.head, nn.QuantizedLinear)
        assert not head_is_quantized

    def test_loading_without_quantization_flag(self, tmp_path):
        """Loading a non-quantized model should have standard Linear layers."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        plain_model = WanModel(config)
        flat_weights = dict(mlx.utils.tree_flatten(plain_model.parameters()))
        model_path = tmp_path / "model.safetensors"
        mx.save_safetensors(str(model_path), flat_weights)

        from mlx_video.models.wan.loading import load_wan_model

        loaded = load_wan_model(model_path, config, quantization=None)
        q_proj = loaded.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.Linear)
        assert not isinstance(q_proj, nn.QuantizedLinear)
# ---------------------------------------------------------------------------
# Quantized Inference Tests
# ---------------------------------------------------------------------------
class TestQuantizedInference:
    """Run forward passes through an in-memory quantized tiny Wan model."""

    def _make_quantized_model(self, config, bits=4):
        """Return a freshly built model quantized to ``bits`` (group size 64)."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_predicate

        model = WanModel(config)
        nn.quantize(
            model,
            group_size=64,
            bits=bits,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())
        return model

    def _check_forward_pass(self, bits):
        """Run one forward pass at ``bits`` and check the output list/shape.

        Shared by the 4-bit and 8-bit tests, which differ only in bit width.
        """
        config = _make_tiny_config()
        model = self._make_quantized_model(config, bits=bits)

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x = [mx.random.normal((C, F, H, W))]
        t = mx.array([500.0])
        context = [mx.random.normal((4, config.text_dim))]

        out = model(x, t, context, seq_len)
        mx.eval(out[0])

        # One output per input sample, with the same shape as the input.
        assert len(out) == 1
        assert out[0].shape == (C, F, H, W)

    def test_forward_pass_4bit(self):
        """A 4-bit quantized model produces an output of the input's shape."""
        self._check_forward_pass(bits=4)

    def test_forward_pass_8bit(self):
        """An 8-bit quantized model produces an output of the input's shape."""
        self._check_forward_pass(bits=8)

    def test_quantized_output_differs_from_unquantized(self):
        """Sanity check: quantization should change the weights."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.convert_wan import _quantize_predicate

        config = _make_tiny_config()
        mx.random.seed(42)

        # Capture the unquantized weight before the model is modified in place.
        model = WanModel(config)
        mx.eval(model.parameters())
        orig_weight = np.array(model.blocks[0].self_attn.q.weight)

        # Quantize
        nn.quantize(
            model,
            group_size=64,
            bits=4,
            class_predicate=_quantize_predicate,
        )
        mx.eval(model.parameters())

        q_proj = model.blocks[0].self_attn.q
        assert isinstance(q_proj, nn.QuantizedLinear)
        assert hasattr(q_proj, "scales")

        # QuantizedLinear stores weight differently (uint32 packed), so the
        # stored tensor must differ from the original in dtype and shape.
        # Previously `orig_weight` was captured but never compared.
        assert q_proj.weight.dtype == mx.uint32
        assert tuple(q_proj.weight.shape) != tuple(orig_weight.shape)
# ---------------------------------------------------------------------------
# Config Metadata Tests
# ---------------------------------------------------------------------------
class TestQuantizationConfig:
    """Check what `_quantize_saved_model` writes back to a model directory."""

    def _save_unquantized(self, tmp_path, config, filenames, config_json):
        """Write one fresh unquantized model per filename, then config.json."""
        from mlx_video.models.wan.model import WanModel

        for filename in filenames:
            model = WanModel(config)
            flat_weights = dict(mlx.utils.tree_flatten(model.parameters()))
            mx.save_safetensors(str(tmp_path / filename), flat_weights)

        with open(tmp_path / "config.json", "w") as handle:
            json.dump(config_json, handle)

    def _read_config(self, tmp_path):
        """Return the parsed contents of config.json in ``tmp_path``."""
        with open(tmp_path / "config.json") as handle:
            return json.load(handle)

    def test_config_metadata_written(self, tmp_path):
        """Verify _quantize_saved_model writes quantization metadata to config.json."""
        from mlx_video.convert_wan import _quantize_saved_model

        config = _make_tiny_config()
        # Save unquantized model + config
        self._save_unquantized(
            tmp_path, config, ["model.safetensors"], {"dim": config.dim}
        )

        # Run quantization
        _quantize_saved_model(
            tmp_path, config, is_dual=False, bits=4, group_size=64
        )

        # Verify metadata
        written = self._read_config(tmp_path)
        assert "quantization" in written
        quant = written["quantization"]
        assert quant["bits"] == 4
        assert quant["group_size"] == 64

    def test_config_metadata_8bit(self, tmp_path):
        """8-bit / group-32 settings are recorded in config.json."""
        from mlx_video.convert_wan import _quantize_saved_model

        config = _make_tiny_config()
        self._save_unquantized(tmp_path, config, ["model.safetensors"], {})

        _quantize_saved_model(
            tmp_path, config, is_dual=False, bits=8, group_size=32
        )

        written = self._read_config(tmp_path)
        assert written["quantization"]["bits"] == 8
        assert written["quantization"]["group_size"] == 32

    def test_dual_model_quantization(self, tmp_path):
        """Verify dual-model quantization writes both model files."""
        from mlx_video.convert_wan import _quantize_saved_model

        config = _make_tiny_config()
        model_files = [
            "low_noise_model.safetensors",
            "high_noise_model.safetensors",
        ]
        self._save_unquantized(tmp_path, config, model_files, {})

        _quantize_saved_model(
            tmp_path, config, is_dual=True, bits=4, group_size=64
        )

        # Both files should now contain quantized weights (have .scales keys)
        for name in model_files:
            stored = mx.load(str(tmp_path / name))
            scale_keys = [key for key in stored if "scales" in key]
            assert bool(scale_keys), f"{name} should have quantized layers"