feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -53,7 +53,9 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
# Causal temporal padding: always kernel_size-1 on the left.
|
||||
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
|
||||
self._causal_pad_t = kernel_size[0] - 1
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
@@ -250,6 +252,46 @@ class DupUp3D(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
|
||||
"""Downsample by grouping channels across spatial/temporal factors and averaging.
|
||||
|
||||
Inverse of DupUp3D. No learnable parameters.
|
||||
Input: [B, T, H, W, C_in] → Output: [B, T//ft, H//fs, W//fs, C_out]
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = factor_t * factor_s * factor_s
|
||||
assert in_channels * self.factor % out_channels == 0
|
||||
self.group_size = in_channels * self.factor // out_channels
|
||||
|
||||
def __call__(self, x):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
# Pad temporal if not divisible by factor_t
|
||||
pad_t = (self.factor_t - T % self.factor_t) % self.factor_t
|
||||
if pad_t > 0:
|
||||
x = mx.pad(x, [(0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)])
|
||||
T = T + pad_t
|
||||
|
||||
ft, fs = self.factor_t, self.factor_s
|
||||
# Reshape to split spatial/temporal dims
|
||||
x = x.reshape(B, T // ft, ft, H // fs, fs, W // fs, fs, C)
|
||||
# Move factors next to channels
|
||||
x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, T', H', W', C, ft, fs, fs]
|
||||
# Expand channels
|
||||
x = x.reshape(B, T // ft, H // fs, W // fs, C * self.factor)
|
||||
# Group and average
|
||||
x = x.reshape(B, T // ft, H // fs, W // fs, self.out_channels, self.group_size)
|
||||
x = x.mean(axis=-1)
|
||||
return x
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Spatial up/downsampling with optional temporal up/downsampling."""
|
||||
|
||||
@@ -267,6 +309,15 @@ class Resample(nn.Module):
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
elif mode == "downsample2d":
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
elif mode == "downsample3d":
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
||||
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
else:
|
||||
raise ValueError(f"Unsupported mode: {mode}")
|
||||
|
||||
@@ -283,6 +334,12 @@ class Resample(nn.Module):
|
||||
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight) + self.resample_bias
|
||||
|
||||
def _downsample_conv2d(self, x):
|
||||
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
|
||||
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
@@ -320,20 +377,37 @@ class Resample(nn.Module):
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
# Spatial upsample in temporal chunks to limit peak memory
|
||||
chunk_size = 8
|
||||
chunks = []
|
||||
for t_start in range(0, T, chunk_size):
|
||||
t_end = min(t_start + chunk_size, T)
|
||||
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
|
||||
x_chunk = self._upsample2x(x_chunk)
|
||||
x_chunk = self._conv2d(x_chunk)
|
||||
mx.eval(x_chunk)
|
||||
chunks.append(x_chunk)
|
||||
if self.mode == "downsample3d" and T > 1:
|
||||
# Temporal downsample via strided CausalConv3d
|
||||
# Skip for T=1 (single frame) — matches official chunked encoding
|
||||
# where first chunk stores cache but doesn't apply time_conv
|
||||
x = self.time_conv(x)
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
if self.mode in ("upsample2d", "upsample3d"):
|
||||
# Spatial upsample in temporal chunks to limit peak memory
|
||||
chunk_size = 8
|
||||
chunks = []
|
||||
for t_start in range(0, T, chunk_size):
|
||||
t_end = min(t_start + chunk_size, T)
|
||||
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
|
||||
x_chunk = self._upsample2x(x_chunk)
|
||||
x_chunk = self._conv2d(x_chunk)
|
||||
mx.eval(x_chunk)
|
||||
chunks.append(x_chunk)
|
||||
|
||||
x = mx.concatenate(chunks, axis=0)
|
||||
H2, W2 = x.shape[1], x.shape[2]
|
||||
x = x.reshape(B, T, H2, W2, C)
|
||||
elif self.mode in ("downsample2d", "downsample3d"):
|
||||
# Spatial downsample: per-frame strided Conv2d
|
||||
x_flat = x.reshape(B * T, H, W, C)
|
||||
x_flat = self._downsample_conv2d(x_flat)
|
||||
mx.eval(x_flat)
|
||||
H2, W2 = x_flat.shape[1], x_flat.shape[2]
|
||||
x = x_flat.reshape(B, T, H2, W2, C)
|
||||
|
||||
x = mx.concatenate(chunks, axis=0)
|
||||
H2, W2 = x.shape[1], x.shape[2]
|
||||
x = x.reshape(B, T, H2, W2, C)
|
||||
return x
|
||||
|
||||
|
||||
@@ -383,6 +457,44 @@ class Up_ResidualBlock(nn.Module):
|
||||
return x_main
|
||||
|
||||
|
||||
class Down_ResidualBlock(nn.Module):
|
||||
"""Downsampling residual block with AvgDown3D shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
|
||||
super().__init__()
|
||||
self.down_flag = down_flag
|
||||
|
||||
# AvgDown3D shortcut (no learnable params, always present)
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim, out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
|
||||
# Main path: ResidualBlocks + optional Resample
|
||||
blocks = []
|
||||
dim_in = in_dim
|
||||
for _ in range(num_res_blocks):
|
||||
blocks.append(ResidualBlock(dim_in, out_dim))
|
||||
dim_in = out_dim
|
||||
|
||||
if down_flag:
|
||||
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||
blocks.append(Resample(out_dim, mode=mode))
|
||||
|
||||
self.downsamples = blocks
|
||||
|
||||
def __call__(self, x):
|
||||
x_shortcut = self.avg_shortcut(x)
|
||||
mx.eval(x_shortcut)
|
||||
|
||||
for module in self.downsamples:
|
||||
x = module(x)
|
||||
mx.eval(x)
|
||||
|
||||
return x + x_shortcut
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
"""Wan2.2 3D VAE Decoder."""
|
||||
|
||||
@@ -439,6 +551,63 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
"""Wan2.2 3D VAE Encoder. Mirror of Decoder3d with downsampling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=160,
|
||||
z_dim=96,
|
||||
dim_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
temperal_downsample=(False, True, True),
|
||||
):
|
||||
super().__init__()
|
||||
# Channel dimensions: [160, 160, 320, 640, 640]
|
||||
dims = [dim * m for m in [1] + list(dim_mult)]
|
||||
|
||||
# Initial conv: patchified input (12 ch) → first dim
|
||||
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
||||
|
||||
# Downsample blocks
|
||||
self.downsamples = []
|
||||
for i in range(len(dim_mult)):
|
||||
in_d, out_d = dims[i], dims[i + 1]
|
||||
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||
self.downsamples.append(Down_ResidualBlock(
|
||||
in_dim=in_d,
|
||||
out_dim=out_d,
|
||||
num_res_blocks=num_res_blocks,
|
||||
temperal_downsample=t_down,
|
||||
down_flag=(i < len(dim_mult) - 1),
|
||||
))
|
||||
|
||||
# Middle blocks (same as decoder)
|
||||
out_dim = dims[-1]
|
||||
self.middle = [
|
||||
ResidualBlock(out_dim, out_dim),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim),
|
||||
]
|
||||
|
||||
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
|
||||
self.head = Head22(out_dim, out_channels=z_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
# x: [B, T, H, W, 12] (patchified)
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.downsamples:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.middle:
|
||||
x = layer(x)
|
||||
mx.eval(x)
|
||||
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class Head22(nn.Module):
|
||||
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
|
||||
|
||||
@@ -460,6 +629,46 @@ class Head22(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Wan22VAEEncoder(nn.Module):
|
||||
"""Full Wan2.2 VAE encoder with patchify and normalization."""
|
||||
|
||||
def __init__(self, z_dim=48, dim=160):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
# conv1: top-level 1x1x1 conv after encoder (z_dim*2 → z_dim*2)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.encoder = Encoder3d(
|
||||
dim=dim,
|
||||
z_dim=z_dim * 2, # Encoder outputs z_dim*2, split into mu + log_var
|
||||
dim_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
temperal_downsample=(False, True, True),
|
||||
)
|
||||
|
||||
def __call__(self, img):
|
||||
"""Encode image/video to latent space.
|
||||
|
||||
Args:
|
||||
img: [B, T, H, W, 3] image/video in [-1, 1]
|
||||
|
||||
Returns:
|
||||
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
||||
"""
|
||||
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
|
||||
x = _patchify(img, patch_size=2)
|
||||
|
||||
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
|
||||
out = self.encoder(x)
|
||||
|
||||
# conv1 (pointwise) + split into mu, log_var
|
||||
out = self.conv1(out)
|
||||
mu = out[:, :, :, :, :self.z_dim]
|
||||
|
||||
# Normalize
|
||||
mu = normalize_latents(mu)
|
||||
return mu
|
||||
|
||||
|
||||
class Wan22VAEDecoder(nn.Module):
|
||||
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
||||
|
||||
@@ -507,6 +716,15 @@ def denormalize_latents(z, mean=None, std=None):
|
||||
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
|
||||
|
||||
|
||||
def normalize_latents(z, mean=None, std=None):
|
||||
"""Normalize latents: z_norm = (z - mean) / std. Inverse of denormalize_latents."""
|
||||
if mean is None:
|
||||
mean = VAE22_MEAN
|
||||
if std is None:
|
||||
std = VAE22_STD
|
||||
return (z - mean.reshape(1, 1, 1, 1, -1)) / std.reshape(1, 1, 1, 1, -1)
|
||||
|
||||
|
||||
def _unpatchify(x, patch_size=2):
|
||||
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
|
||||
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
|
||||
@@ -527,10 +745,30 @@ def _unpatchify(x, patch_size=2):
|
||||
return x
|
||||
|
||||
|
||||
def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
||||
def _patchify(x, patch_size=2):
|
||||
"""Convert spatial to packed channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p]
|
||||
Inverse of _unpatchify.
|
||||
PyTorch: b c f (h q) (w r) -> b (c r q) f h w
|
||||
In channels-last: [B, T, H*q, W*r, C] → [B, T, H, W, C*r*q]
|
||||
"""
|
||||
if patch_size == 1:
|
||||
return x
|
||||
B, T, Hfull, Wfull, C = x.shape
|
||||
H = Hfull // patch_size
|
||||
W = Wfull // patch_size
|
||||
# [B, T, H, q, W, r, C]
|
||||
x = x.reshape(B, T, H, patch_size, W, patch_size, C)
|
||||
# Rearrange to pack q,r into channels: [B, T, H, W, C, r, q]
|
||||
x = x.transpose(0, 1, 2, 4, 6, 5, 3) # [B, T, H, W, C, r, q]
|
||||
x = x.reshape(B, T, H, W, C * patch_size * patch_size)
|
||||
return x
|
||||
|
||||
|
||||
def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
|
||||
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
|
||||
|
||||
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
|
||||
By default keeps decoder + conv2 weights only. Set include_encoder=True
|
||||
to also keep encoder + conv1 weights (needed for I2V encoding).
|
||||
Transposes conv weights from channels-first to channels-last.
|
||||
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
|
||||
Maps PyTorch nn.Sequential indices to our named layers.
|
||||
@@ -538,9 +776,10 @@ def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
# Skip encoder and conv1 (encoder-only)
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
continue
|
||||
# Skip encoder and conv1 unless requested
|
||||
if not include_encoder:
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
continue
|
||||
|
||||
new_key = key
|
||||
|
||||
|
||||
Reference in New Issue
Block a user