Refactor and remove Wan2.1/2.2 model files; update README.md to include new model features and usage instructions for LTX-2 and Wan2 models.
This commit is contained in:
58
mlx_video/models/wan2/i2v_utils.py
Normal file
58
mlx_video/models/wan2/i2v_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Image-to-Video utility functions for Wan2.2."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
|
||||
"""Load, resize, center-crop, and normalize an image for I2V.
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
width: Target width
|
||||
height: Target height
|
||||
|
||||
Returns:
|
||||
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Resize so that the image covers the target size (LANCZOS)
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||
|
||||
# Center crop
|
||||
x1 = (img.width - width) // 2
|
||||
y1 = (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
|
||||
# To tensor: [H, W, 3] float32 in [-1, 1]
|
||||
arr = np.array(img, dtype=np.float32) / 255.0
|
||||
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
|
||||
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
|
||||
|
||||
|
||||
def build_i2v_mask(z_shape, patch_size):
|
||||
"""Build temporal mask for I2V: first frame = 0, rest = 1.
|
||||
|
||||
Args:
|
||||
z_shape: Latent shape (C, T, H, W) in channels-first
|
||||
patch_size: (pt, ph, pw) patch size
|
||||
|
||||
Returns:
|
||||
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
|
||||
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
|
||||
"""
|
||||
C, T, H, W = z_shape
|
||||
mask = mx.ones(z_shape)
|
||||
# Zero out the first temporal position
|
||||
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
|
||||
|
||||
# Token-level mask for per-token timesteps: subsample to patch grid
|
||||
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
|
||||
pt, ph, pw = patch_size
|
||||
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
|
||||
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
|
||||
return mask, mask_tokens
|
||||
Reference in New Issue
Block a user