Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -247,7 +247,7 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.generate_wan import Colors
from mlx_video.models.wan2.generate import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -282,7 +282,7 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.generate_wan import Colors
from mlx_video.models.wan2.generate import Colors
from mlx_video.lora import apply_loras_to_weights
if not lora_configs:
@@ -411,7 +411,7 @@ def convert_wan_checkpoint(
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
@@ -522,7 +522,7 @@ def convert_wan_checkpoint(
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(
@@ -594,7 +594,7 @@ def _quantize_saved_model(
import mlx.nn as nn
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
if source_dir is None:
source_dir = output_dir
@@ -704,7 +704,7 @@ def quantize_mlx_model(
).exists()
# Build model config
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__