feat(wan): Add chunked VAE encoding and TI2V-5B support

This commit is contained in:
Daniel
2026-03-09 20:47:37 +01:00
parent 967218b7c1
commit 061ae4407c
3 changed files with 223 additions and 64 deletions

View File

@@ -45,7 +45,8 @@ class WanModelConfig(BaseModelConfig):
"杂乱的背景,三条腿,背景人很多,倒着走"
)
# T5
# Resolution constraints
max_area: int = 0 # 0 = no limit; e.g. 704*1280 for TI2V-5B
t5_vocab_size: int = 256384
t5_dim: int = 4096
t5_dim_attn: int = 4096
@@ -102,7 +103,8 @@ class WanModelConfig(BaseModelConfig):
boundary=0.900,
sample_shift=5.0,
sample_guide_scale=(3.5, 3.5),
)
max_area=704 * 1280,
@classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig":
@@ -120,7 +122,8 @@ class WanModelConfig(BaseModelConfig):
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_steps=40,
sample_guide_scale=5.0,
sample_fps=24,
)
max_area=704 * 1280,

View File

@@ -56,9 +56,11 @@ class CausalConv3d(nn.Module):
self.kernel_size = kernel_size
self.stride = stride
# Causal temporal padding: always kernel_size-1 on the left.
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
self._causal_pad_t = kernel_size[0] - 1
# Causal temporal padding: matches the reference CausalConv3d(nn.Conv3d)
# which converts symmetric padding to causal: 2*padding[0] on the left.
# For most convs (kernel=3, padding=1): 2*1 = 2 (same as kernel-1).
# For downsample time_conv (kernel=3, padding=0): 2*0 = 0 (NO padding).
self._causal_pad_t = 2 * padding[0]
self._pad_h = padding[1]
self._pad_w = padding[2]
@@ -68,7 +70,7 @@ class CausalConv3d(nn.Module):
))
self.bias = mx.zeros((out_channels,))
def __call__(self, x):
def __call__(self, x, cache_x=None):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
kd, kh, kw = self.kernel_size
@@ -81,10 +83,15 @@ class CausalConv3d(nn.Module):
y = mx.conv_general(x_flat, w2d) + self.bias
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
# Causal temporal padding (left only) — zeros match the reference
# implementation and what the model was trained with.
if self._causal_pad_t > 0:
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
# Causal temporal padding: prepend cached frames if available,
# then zero-pad any remaining positions.
pad_needed = self._causal_pad_t
if cache_x is not None and pad_needed > 0:
x = mx.concatenate([cache_x, x], axis=1)
pad_needed -= cache_x.shape[1]
if pad_needed > 0:
pad_t = mx.zeros((B, pad_needed, H, W, C), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=1)
# Spatial padding
@@ -144,9 +151,9 @@ class ResidualBlock(nn.Module):
else None
)
def __call__(self, x):
def __call__(self, x, feat_cache=None, feat_idx=None):
h = self.shortcut(x) if self.shortcut is not None else x
return self.residual(x) + h
return self.residual(x, feat_cache, feat_idx) + h
class ResidualBlockLayers(nn.Module):
@@ -169,14 +176,34 @@ class ResidualBlockLayers(nn.Module):
# Index 6: CausalConv3d
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
def __call__(self, x):
def _conv_with_cache(self, conv, x, feat_cache, feat_idx):
"""Apply CausalConv3d with temporal caching for chunked encoding."""
idx = feat_idx[0]
# Save last CACHE_T frames before conv (for next chunk's context)
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
out = conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
return out
def __call__(self, x, feat_cache=None, feat_idx=None):
x = self.layer_0(x)
x = nn.silu(x)
x = self.layer_2(x)
if feat_cache is not None:
x = self._conv_with_cache(self.layer_2, x, feat_cache, feat_idx)
else:
x = self.layer_2(x)
mx.eval(x) # Eval between convolutions to limit graph size
x = self.layer_3(x)
x = nn.silu(x)
x = self.layer_6(x)
if feat_cache is not None:
x = self._conv_with_cache(self.layer_6, x, feat_cache, feat_idx)
else:
x = self.layer_6(x)
return x
@@ -344,53 +371,34 @@ class Resample(nn.Module):
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
def __call__(self, x, first_chunk=False):
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
# --- Temporal upsample (before spatial, matching reference) ---
if self.mode == "upsample3d":
if first_chunk and T > 1:
# Match official chunked behavior: the first frame bypasses
# time_conv entirely (only spatial upsample). Remaining frames
# go through time_conv with causal zero-padding, which
# naturally gives each frame the same limited temporal context
# as the official frame-by-frame decode with caching.
first_frame = x[:, 0:1] # [B, 1, H, W, C]
rest = x[:, 1:] # [B, T-1, H, W, C]
# time_conv on remaining frames (causal pad gives zero context
# before rest[0], matching the official "Rep" cache path)
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
first_frame = x[:, 0:1]
rest = x[:, 1:]
tc_out = self.time_conv(rest)
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :]
interleaved = mx.stack([stream0, stream1], axis=2)
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
x = mx.concatenate([first_frame, interleaved], axis=1)
elif self.mode == "upsample3d":
# Non-first-chunk or single frame: time_conv all frames
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
else:
tc_out = self.time_conv(x)
tc_out = tc_out.reshape(B, T, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :]
x = mx.stack([stream0, stream1], axis=2)
x = x.reshape(B, T * 2, H, W, C)
mx.eval(x)
T = x.shape[1]
if self.mode == "downsample3d" and T > 1:
# Temporal downsample via strided CausalConv3d
# Skip for T=1 (single frame) — matches official chunked encoding
# where first chunk stores cache but doesn't apply time_conv
x = self.time_conv(x)
mx.eval(x)
T = x.shape[1]
# --- Spatial operation (all modes, matching reference line 152-155) ---
if self.mode in ("upsample2d", "upsample3d"):
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8
chunks = []
for t_start in range(0, T, chunk_size):
@@ -400,18 +408,36 @@ class Resample(nn.Module):
x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk)
chunks.append(x_chunk)
x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C)
elif self.mode in ("downsample2d", "downsample3d"):
# Spatial downsample: per-frame strided Conv2d
x_flat = x.reshape(B * T, H, W, C)
x_flat = self._downsample_conv2d(x_flat)
mx.eval(x_flat)
H2, W2 = x_flat.shape[1], x_flat.shape[2]
x = x_flat.reshape(B, T, H2, W2, C)
# --- Temporal downsample (after spatial, matching reference line 157-168) ---
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
# First chunk: store spatially-downsampled result, skip time_conv
feat_cache[idx] = x
feat_idx[0] += 1
else:
# Subsequent chunks: prepend cached last frame, apply time_conv
save_x = x[:, -1:]
x = self.time_conv(
mx.concatenate([feat_cache[idx][:, -1:], x], axis=1)
)
feat_cache[idx] = save_x
feat_idx[0] += 1
elif T > 1:
x = self.time_conv(x)
mx.eval(x)
return x
@@ -488,12 +514,20 @@ class Down_ResidualBlock(nn.Module):
self.downsamples = blocks
def __call__(self, x):
def __call__(self, x, feat_cache=None, feat_idx=None):
x_shortcut = self.avg_shortcut(x)
mx.eval(x_shortcut)
for module in self.downsamples:
x = module(x)
if feat_cache is not None:
if isinstance(module, ResidualBlock):
x = module(x, feat_cache, feat_idx)
elif isinstance(module, Resample):
x = module(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = module(x)
else:
x = module(x)
mx.eval(x)
return x + x_shortcut
@@ -597,18 +631,51 @@ class Encoder3d(nn.Module):
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
self.head = Head22(out_dim, out_channels=z_dim)
def __call__(self, x):
def __call__(self, x, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, 12] (patchified)
x = self.conv1(x)
if feat_cache is not None:
return self._forward_cached(x, feat_cache, feat_idx)
# No cache: internally chunk as 1+4+4+... (matches reference behavior)
num_convs = _count_conv3d(self)
internal_cache = [None] * num_convs
T = x.shape[1]
starts = [0] + list(range(1, T, 4))
ends = starts[1:] + [T]
outputs = []
for s, e in zip(starts, ends):
if s >= e:
continue
feat_idx_local = [0]
out = self._forward_cached(x[:, s:e], internal_cache, feat_idx_local)
outputs.append(out)
mx.eval(internal_cache)
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=1)
def _forward_cached(self, x, feat_cache, feat_idx):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
for layer in self.downsamples:
x = layer(x)
x = layer(x, feat_cache, feat_idx)
for layer in self.middle:
x = layer(x)
if isinstance(layer, ResidualBlock):
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
mx.eval(x)
x = self.head(x)
x = self.head(x, feat_cache, feat_idx)
return x
@@ -626,13 +693,38 @@ class Head22(nn.Module):
# Index 2: CausalConv3d
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
def __call__(self, x):
def __call__(self, x, feat_cache=None, feat_idx=None):
x = self.layer_0(x)
x = nn.silu(x)
x = self.layer_2(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
x = self.layer_2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.layer_2(x)
return x
def _count_conv3d(module):
"""Count all CausalConv3d instances in a module tree (for cache sizing)."""
count = 0
if isinstance(module, CausalConv3d):
count += 1
for child in module.children().values():
if isinstance(child, list):
for item in child:
count += _count_conv3d(item)
elif isinstance(child, nn.Module):
count += _count_conv3d(child)
return count
class Wan22VAEEncoder(nn.Module):
"""Full Wan2.2 VAE encoder with patchify and normalization."""
@@ -649,8 +741,11 @@ class Wan22VAEEncoder(nn.Module):
temperal_downsample=(False, True, True),
)
def __call__(self, img):
"""Encode image/video to latent space.
def encode(self, img):
"""Encode image/video using chunked encoding (1+4+4+... pattern).
This matches the reference implementation's chunked encoding with
persistent temporal cache, which is critical for correct I2V latents.
Args:
img: [B, T, H, W, 3] image/video in [-1, 1]
@@ -658,11 +753,28 @@ class Wan22VAEEncoder(nn.Module):
Returns:
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
"""
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
x = _patchify(img, patch_size=2)
T = x.shape[1]
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
out = self.encoder(x)
# Initialize temporal cache (one slot per CausalConv3d in encoder)
num_convs = _count_conv3d(self.encoder)
feat_cache = [None] * num_convs
# Chunked encoding: first chunk = 1 frame, rest = 4 frames each
num_chunks = 1 + (T - 1) // 4
out = None
for i in range(num_chunks):
feat_idx = [0] # Reset layer index each chunk (but keep cache)
if i == 0:
chunk = x[:, :1]
else:
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
else:
out = mx.concatenate([out, chunk_out], axis=1)
mx.eval(out)
# conv1 (pointwise) + split into mu, log_var
out = self.conv1(out)
@@ -672,6 +784,17 @@ class Wan22VAEEncoder(nn.Module):
mu = normalize_latents(mu)
return mu
def __call__(self, img):
"""Encode image/video to latent space (delegates to chunked encode).
Args:
img: [B, T, H, W, 3] image/video in [-1, 1]
Returns:
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
"""
return self.encode(img)
class Wan22VAEDecoder(nn.Module):
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""