Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -22,7 +22,7 @@ from mlx_video.models.ltx_2.utils import (
|
||||
load_safetensors,
|
||||
save_weights,
|
||||
)
|
||||
from mlx_video.models.wan2 import WanModel, WanModelConfig
|
||||
from mlx_video.models.wan_2 import WanModel, WanModelConfig
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.wan2 import WanModel, WanModelConfig
|
||||
from mlx_video.models.wan_2 import WanModel, WanModelConfig
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
2
mlx_video/models/wan_2/__init__.py
Normal file
2
mlx_video/models/wan_2/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
@@ -247,7 +247,7 @@ def _load_lora_configs(
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.models.wan2.generate import Colors
|
||||
from mlx_video.models.wan_2.generate import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
@@ -282,7 +282,7 @@ def load_and_apply_loras(
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.models.wan2.generate import Colors
|
||||
from mlx_video.models.wan_2.generate import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
@@ -411,7 +411,7 @@ def convert_wan_checkpoint(
|
||||
print(" Warning: No transformer weights found!")
|
||||
|
||||
# Save config — detect model size from source config.json or transformer weights
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
@@ -522,7 +522,7 @@ def convert_wan_checkpoint(
|
||||
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
@@ -594,7 +594,7 @@ def _quantize_saved_model(
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
@@ -704,7 +704,7 @@ def quantize_mlx_model(
|
||||
).exists()
|
||||
|
||||
# Build model config
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
@@ -11,15 +11,15 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan2.utils import (
|
||||
from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan_2.utils import (
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
load_vae_encoder,
|
||||
load_wan_model,
|
||||
)
|
||||
from mlx_video.models.wan2.postprocess import save_video
|
||||
from mlx_video.models.wan_2.postprocess import save_video
|
||||
|
||||
|
||||
class Colors:
|
||||
@@ -121,8 +121,8 @@ def generate_video(
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
@@ -767,7 +767,7 @@ def generate_video(
|
||||
)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents
|
||||
from mlx_video.models.wan_2.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None]
|
||||
@@ -21,12 +21,12 @@ def load_wan_model(
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||
"""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
|
||||
if quantization:
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
@@ -42,7 +42,7 @@ def load_wan_model(
|
||||
if quantization:
|
||||
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
|
||||
from mlx_video.models.wan2.convert import _load_lora_configs
|
||||
from mlx_video.models.wan_2.convert import _load_lora_configs
|
||||
from mlx_video.lora import apply_loras_to_model
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
@@ -53,7 +53,7 @@ def load_wan_model(
|
||||
return model
|
||||
else:
|
||||
# Weight merging: fold LoRA into bf16 weights before loading
|
||||
from mlx_video.models.wan2.convert import load_and_apply_loras
|
||||
from mlx_video.models.wan_2.convert import load_and_apply_loras
|
||||
|
||||
weights = load_and_apply_loras(dict(weights), loras)
|
||||
|
||||
@@ -69,7 +69,7 @@ def load_t5_encoder(model_path: Path, config):
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official implementation, which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=config.t5_vocab_size,
|
||||
@@ -97,11 +97,11 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
is_wan22 = config is not None and config.vae_z_dim == 48
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
|
||||
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
@@ -120,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None):
|
||||
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
|
||||
"""
|
||||
if config is not None and config.vae_z_dim == 16:
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
else:
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
|
||||
|
||||
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
|
||||
|
||||
@@ -589,7 +589,7 @@ class WanVAE(nn.Module):
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling
|
||||
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
@@ -966,7 +966,7 @@ class Wan22VAEDecoder(nn.Module):
|
||||
Returns:
|
||||
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling
|
||||
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
Reference in New Issue
Block a user