feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -43,7 +43,9 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
# Causal padding: match reference formula dilation*(k-1) + (1-stride)
|
||||
# With dilation=1: k-stride (pads left only, no future context)
|
||||
self._causal_pad_t = kernel_size[0] - stride[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
@@ -51,12 +53,17 @@ class CausalConv3d(nn.Module):
|
||||
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||
"""x: [B, C, T, H, W] (channel-first)"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype)
|
||||
causal_pad = self._causal_pad_t
|
||||
if cache_x is not None and causal_pad > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=2)
|
||||
causal_pad = max(0, causal_pad - cache_x.shape[2])
|
||||
|
||||
if causal_pad > 0:
|
||||
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
@@ -136,12 +143,35 @@ class ResidualBlock(nn.Module):
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
h = x if self.shortcut is None else self.shortcut(x)
|
||||
x = nn.silu(self.residual[0](x))
|
||||
x = self.residual[2](x)
|
||||
x = nn.silu(self.residual[3](x))
|
||||
x = self.residual[6](x)
|
||||
|
||||
if feat_cache is not None:
|
||||
# First conv: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.residual[0](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.residual[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
# Second conv: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.residual[3](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.residual[6](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = nn.silu(self.residual[0](x))
|
||||
x = self.residual[2](x)
|
||||
x = nn.silu(self.residual[3](x))
|
||||
x = self.residual[6](x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
@@ -180,23 +210,31 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Upsample block matching original Wan VAE structure.
|
||||
"""Resample block matching original Wan VAE structure.
|
||||
|
||||
Uses `resample` list with [None, Conv2d] to match original
|
||||
nn.Sequential(Upsample, Conv2d) where index 1 has the conv params.
|
||||
Supports both upsampling (decoder) and downsampling (encoder).
|
||||
Uses list-based param storage to match original nn.Sequential key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str):
|
||||
super().__init__()
|
||||
assert mode in ("upsample2d", "upsample3d")
|
||||
assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
|
||||
self.mode = mode
|
||||
self.dim = dim
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if mode.startswith("upsample"):
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
else:
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
||||
if mode == "downsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
@@ -204,17 +242,43 @@ class Resample(nn.Module):
|
||||
# Temporal upsample via learned conv
|
||||
x_t = self.time_conv(x) # [B, 2C, T, H, W]
|
||||
x_t = x_t.reshape(b, 2, c, t, h, w)
|
||||
# Interleave along time: [B, C, 2T, H, W]
|
||||
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
|
||||
t = t * 2
|
||||
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
if self.mode.startswith("upsample"):
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
else:
|
||||
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
|
||||
x = self.resample[1](x) # Conv2d stride=2
|
||||
c_out = x.shape[-1]
|
||||
h_out, w_out = x.shape[1], x.shape[2]
|
||||
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
# First chunk: save x, skip time_conv
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
# Subsequent chunks: use cached frame as temporal context
|
||||
cache_x = x[:, :, -1:]
|
||||
x = self.time_conv(
|
||||
x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.time_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
@@ -284,10 +348,108 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization."""
|
||||
class Encoder3d(nn.Module):
|
||||
"""3D VAE Encoder matching Wan2.1 architecture.
|
||||
|
||||
def __init__(self, z_dim: int = 16):
|
||||
Mirror of Decoder3d with downsampling instead of upsampling.
|
||||
Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: list = None,
|
||||
num_res_blocks: int = 2,
|
||||
temporal_downsample: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dim_mult is None:
|
||||
dim_mult = [1, 2, 4, 4]
|
||||
if temporal_downsample is None:
|
||||
temporal_downsample = [False, True, True]
|
||||
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# Flat downsample list matching original nn.Sequential indexing
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim))
|
||||
in_dim = out_dim
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
self.downsamples = downsamples
|
||||
|
||||
# Middle: [ResBlock, AttentionBlock, ResBlock]
|
||||
self.middle = [
|
||||
ResidualBlock(dims[-1], dims[-1]),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1]),
|
||||
]
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False),
|
||||
None, # SiLU
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1),
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]"""
|
||||
if feat_cache is not None:
|
||||
# conv1 with caching
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.middle:
|
||||
if feat_cache is not None and isinstance(layer, ResidualBlock):
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
# Head: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.head[0](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.head[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = nn.silu(self.head[0](x))
|
||||
x = self.head[2](x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization.
|
||||
|
||||
Supports both encode (for I2V) and decode (for all models).
|
||||
"""
|
||||
|
||||
def __init__(self, z_dim: int = 16, encoder: bool = False):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.mean = mx.array(VAE_MEAN)
|
||||
@@ -297,6 +459,65 @@ class WanVAE(nn.Module):
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
|
||||
|
||||
if encoder:
|
||||
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
|
||||
def encode(self, x: mx.array) -> mx.array:
|
||||
"""Encode video to normalized latent using chunked encoding.
|
||||
|
||||
Uses chunked encoding with temporal caching to match reference behavior.
|
||||
First frame encoded alone, then 4-frame chunks with cached context.
|
||||
|
||||
Args:
|
||||
x: Video [B, 3, T, H, W] in [-1, 1]
|
||||
|
||||
Returns:
|
||||
Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
|
||||
"""
|
||||
# Count cacheable CausalConv3d slots in encoder
|
||||
num_slots = self._count_encoder_cache_slots()
|
||||
feat_cache = [None] * num_slots
|
||||
|
||||
t = x.shape[2]
|
||||
num_chunks = 1 + (t - 1) // 4
|
||||
|
||||
out = None
|
||||
for i in range(num_chunks):
|
||||
feat_idx = [0]
|
||||
if i == 0:
|
||||
chunk = x[:, :, :1]
|
||||
else:
|
||||
chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
|
||||
|
||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
if out is None:
|
||||
out = chunk_out
|
||||
else:
|
||||
out = mx.concatenate([out, chunk_out], axis=2)
|
||||
|
||||
mu, _ = mx.split(self.conv1(out), 2, axis=1)
|
||||
|
||||
# Normalize: (mu - mean) * inv_std
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
return (mu - mean) * inv_std
|
||||
|
||||
def _count_encoder_cache_slots(self) -> int:
|
||||
"""Count CausalConv3d that participate in chunked encoding cache."""
|
||||
count = 1 # encoder.conv1
|
||||
for layer in self.encoder.downsamples:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2 # two convs in residual path
|
||||
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
|
||||
count += 1 # time_conv
|
||||
for layer in self.encoder.middle:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2
|
||||
count += 1 # encoder.head CausalConv3d
|
||||
return count
|
||||
|
||||
def decode(self, z: mx.array) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user