Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -0,0 +1,60 @@
"""Image-to-Video utility functions for Wan2.2."""
import mlx.core as mx
import numpy as np
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
"""Load, resize, center-crop, and normalize an image for I2V.
Args:
image_path: Path to input image
width: Target width
height: Target height
Returns:
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
"""
from PIL import Image
img = Image.open(image_path).convert("RGB")
# Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height)
img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
# Center crop
x1 = (img.width - width) // 2
y1 = (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
# To tensor: [H, W, 3] float32 in [-1, 1]
arr = np.array(img, dtype=np.float32) / 255.0
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
def build_i2v_mask(z_shape, patch_size):
"""Build temporal mask for I2V: first frame = 0, rest = 1.
Args:
z_shape: Latent shape (C, T, H, W) in channels-first
patch_size: (pt, ph, pw) patch size
Returns:
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
"""
C, T, H, W = z_shape
mask = mx.ones(z_shape)
# Zero out the first temporal position
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
# Token-level mask for per-token timesteps: subsample to patch grid
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
pt, ph, pw = patch_size
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
return mask, mask_tokens