Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -22,7 +22,7 @@ from mlx_video.models.ltx_2.utils import (
load_safetensors,
save_weights,
)
from mlx_video.models.wan2 import WanModel, WanModelConfig
from mlx_video.models.wan_2 import WanModel, WanModelConfig
__all__ = [
# Models

View File

@@ -1,2 +1,2 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan2 import WanModel, WanModelConfig
from mlx_video.models.wan_2 import WanModel, WanModelConfig

View File

@@ -1,2 +0,0 @@
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.wan2 import WanModel

View File

@@ -0,0 +1,2 @@
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.wan_2 import WanModel

View File

@@ -247,7 +247,7 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.models.wan2.generate import Colors
from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -282,7 +282,7 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.models.wan2.generate import Colors
from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import apply_loras_to_weights
if not lora_configs:
@@ -411,7 +411,7 @@ def convert_wan_checkpoint(
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
@@ -522,7 +522,7 @@ def convert_wan_checkpoint(
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(
@@ -594,7 +594,7 @@ def _quantize_saved_model(
import mlx.nn as nn
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
if source_dir is None:
source_dir = output_dir
@@ -704,7 +704,7 @@ def quantize_mlx_model(
).exists()
# Build model config
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__

View File

@@ -11,15 +11,15 @@ import mlx.core as mx
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan2.utils import (
from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan_2.utils import (
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.models.wan2.postprocess import save_video
from mlx_video.models.wan_2.postprocess import save_video
class Colors:
@@ -121,8 +121,8 @@ def generate_video(
"""
import json
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.scheduler import (
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -767,7 +767,7 @@ def generate_video(
)
if is_wan22_vae:
from mlx_video.models.wan2.vae22 import denormalize_latents
from mlx_video.models.wan_2.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]

View File

@@ -21,12 +21,12 @@ def load_wan_model(
If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply.
"""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config)
if quantization:
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
nn.quantize(
model,
@@ -42,7 +42,7 @@ def load_wan_model(
if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.models.wan2.convert import _load_lora_configs
from mlx_video.models.wan_2.convert import _load_lora_configs
from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False)
@@ -53,7 +53,7 @@ def load_wan_model(
return model
else:
# Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.models.wan2.convert import load_and_apply_loras
from mlx_video.models.wan_2.convert import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras)
@@ -69,7 +69,7 @@ def load_t5_encoder(model_path: Path, config):
only runs once per generation, so performance impact is negligible.
This matches the official which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan2.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
@@ -97,11 +97,11 @@ def load_vae_decoder(model_path: Path, config=None):
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan2.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16)
@@ -120,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None):
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
"""
if config is not None and config.vae_z_dim == 16:
from mlx_video.models.wan2.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
else:
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)

View File

@@ -589,7 +589,7 @@ class WanVAE(nn.Module):
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()

View File

@@ -966,7 +966,7 @@ class Wan22VAEDecoder(nn.Module):
Returns:
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
"""
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()