Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -11,15 +11,15 @@ import mlx.core as mx
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan2.utils import (
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.models.wan.postprocess import save_video
from mlx_video.models.wan2.postprocess import save_video
class Colors:
@@ -121,8 +121,8 @@ def generate_video(
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -729,7 +729,7 @@ def generate_video(
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
# This gives the first real frame a full temporal receptive field of real data.
# Select tiling configuration
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig
if tiling == "none":
tiling_config = None
@@ -767,7 +767,7 @@ def generate_video(
)
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]