feat(wan): Add chunked VAE encoding and TI2V-5B support
This commit is contained in:
@@ -56,9 +56,11 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
# Causal temporal padding: always kernel_size-1 on the left.
|
||||
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
|
||||
self._causal_pad_t = kernel_size[0] - 1
|
||||
# Causal temporal padding: matches the reference CausalConv3d(nn.Conv3d)
|
||||
# which converts symmetric padding to causal: 2*padding[0] on the left.
|
||||
# For most convs (kernel=3, padding=1): 2*1 = 2 (same as kernel-1).
|
||||
# For downsample time_conv (kernel=3, padding=0): 2*0 = 0 (NO padding).
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
@@ -68,7 +70,7 @@ class CausalConv3d(nn.Module):
|
||||
))
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, cache_x=None):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
kd, kh, kw = self.kernel_size
|
||||
@@ -81,10 +83,15 @@ class CausalConv3d(nn.Module):
|
||||
y = mx.conv_general(x_flat, w2d) + self.bias
|
||||
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
||||
|
||||
# Causal temporal padding (left only) — zeros match the reference
|
||||
# implementation and what the model was trained with.
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
||||
# Causal temporal padding: prepend cached frames if available,
|
||||
# then zero-pad any remaining positions.
|
||||
pad_needed = self._causal_pad_t
|
||||
if cache_x is not None and pad_needed > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=1)
|
||||
pad_needed -= cache_x.shape[1]
|
||||
|
||||
if pad_needed > 0:
|
||||
pad_t = mx.zeros((B, pad_needed, H, W, C), dtype=x.dtype)
|
||||
x = mx.concatenate([pad_t, x], axis=1)
|
||||
|
||||
# Spatial padding
|
||||
@@ -144,9 +151,9 @@ class ResidualBlock(nn.Module):
|
||||
else None
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
h = self.shortcut(x) if self.shortcut is not None else x
|
||||
return self.residual(x) + h
|
||||
return self.residual(x, feat_cache, feat_idx) + h
|
||||
|
||||
|
||||
class ResidualBlockLayers(nn.Module):
|
||||
@@ -169,14 +176,34 @@ class ResidualBlockLayers(nn.Module):
|
||||
# Index 6: CausalConv3d
|
||||
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
def _conv_with_cache(self, conv, x, feat_cache, feat_idx):
|
||||
"""Apply CausalConv3d with temporal caching for chunked encoding."""
|
||||
idx = feat_idx[0]
|
||||
# Save last CACHE_T frames before conv (for next chunk's context)
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
out = conv(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return out
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
x = self.layer_0(x)
|
||||
x = nn.silu(x)
|
||||
x = self.layer_2(x)
|
||||
if feat_cache is not None:
|
||||
x = self._conv_with_cache(self.layer_2, x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.layer_2(x)
|
||||
mx.eval(x) # Eval between convolutions to limit graph size
|
||||
x = self.layer_3(x)
|
||||
x = nn.silu(x)
|
||||
x = self.layer_6(x)
|
||||
if feat_cache is not None:
|
||||
x = self._conv_with_cache(self.layer_6, x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.layer_6(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -344,53 +371,34 @@ class Resample(nn.Module):
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
# --- Temporal upsample (before spatial, matching reference) ---
|
||||
if self.mode == "upsample3d":
|
||||
if first_chunk and T > 1:
|
||||
# Match official chunked behavior: the first frame bypasses
|
||||
# time_conv entirely (only spatial upsample). Remaining frames
|
||||
# go through time_conv with causal zero-padding, which
|
||||
# naturally gives each frame the same limited temporal context
|
||||
# as the official frame-by-frame decode with caching.
|
||||
first_frame = x[:, 0:1] # [B, 1, H, W, C]
|
||||
rest = x[:, 1:] # [B, T-1, H, W, C]
|
||||
|
||||
# time_conv on remaining frames (causal pad gives zero context
|
||||
# before rest[0], matching the official "Rep" cache path)
|
||||
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
|
||||
first_frame = x[:, 0:1]
|
||||
rest = x[:, 1:]
|
||||
tc_out = self.time_conv(rest)
|
||||
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
|
||||
stream0 = tc_out[:, :, :, :, 0, :]
|
||||
stream1 = tc_out[:, :, :, :, 1, :]
|
||||
interleaved = mx.stack([stream0, stream1], axis=2)
|
||||
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
|
||||
|
||||
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
|
||||
x = mx.concatenate([first_frame, interleaved], axis=1)
|
||||
elif self.mode == "upsample3d":
|
||||
# Non-first-chunk or single frame: time_conv all frames
|
||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
||||
else:
|
||||
tc_out = self.time_conv(x)
|
||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||
stream0 = tc_out[:, :, :, :, 0, :]
|
||||
stream1 = tc_out[:, :, :, :, 1, :]
|
||||
x = mx.stack([stream0, stream1], axis=2)
|
||||
x = x.reshape(B, T * 2, H, W, C)
|
||||
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
if self.mode == "downsample3d" and T > 1:
|
||||
# Temporal downsample via strided CausalConv3d
|
||||
# Skip for T=1 (single frame) — matches official chunked encoding
|
||||
# where first chunk stores cache but doesn't apply time_conv
|
||||
x = self.time_conv(x)
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
# --- Spatial operation (all modes, matching reference line 152-155) ---
|
||||
if self.mode in ("upsample2d", "upsample3d"):
|
||||
# Spatial upsample in temporal chunks to limit peak memory
|
||||
chunk_size = 8
|
||||
chunks = []
|
||||
for t_start in range(0, T, chunk_size):
|
||||
@@ -400,18 +408,36 @@ class Resample(nn.Module):
|
||||
x_chunk = self._conv2d(x_chunk)
|
||||
mx.eval(x_chunk)
|
||||
chunks.append(x_chunk)
|
||||
|
||||
x = mx.concatenate(chunks, axis=0)
|
||||
H2, W2 = x.shape[1], x.shape[2]
|
||||
x = x.reshape(B, T, H2, W2, C)
|
||||
elif self.mode in ("downsample2d", "downsample3d"):
|
||||
# Spatial downsample: per-frame strided Conv2d
|
||||
x_flat = x.reshape(B * T, H, W, C)
|
||||
x_flat = self._downsample_conv2d(x_flat)
|
||||
mx.eval(x_flat)
|
||||
H2, W2 = x_flat.shape[1], x_flat.shape[2]
|
||||
x = x_flat.reshape(B, T, H2, W2, C)
|
||||
|
||||
# --- Temporal downsample (after spatial, matching reference line 157-168) ---
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
# First chunk: store spatially-downsampled result, skip time_conv
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
# Subsequent chunks: prepend cached last frame, apply time_conv
|
||||
save_x = x[:, -1:]
|
||||
x = self.time_conv(
|
||||
mx.concatenate([feat_cache[idx][:, -1:], x], axis=1)
|
||||
)
|
||||
feat_cache[idx] = save_x
|
||||
feat_idx[0] += 1
|
||||
elif T > 1:
|
||||
x = self.time_conv(x)
|
||||
mx.eval(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -488,12 +514,20 @@ class Down_ResidualBlock(nn.Module):
|
||||
|
||||
self.downsamples = blocks
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
x_shortcut = self.avg_shortcut(x)
|
||||
mx.eval(x_shortcut)
|
||||
|
||||
for module in self.downsamples:
|
||||
x = module(x)
|
||||
if feat_cache is not None:
|
||||
if isinstance(module, ResidualBlock):
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
elif isinstance(module, Resample):
|
||||
x = module(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = module(x)
|
||||
else:
|
||||
x = module(x)
|
||||
mx.eval(x)
|
||||
|
||||
return x + x_shortcut
|
||||
@@ -597,18 +631,51 @@ class Encoder3d(nn.Module):
|
||||
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
|
||||
self.head = Head22(out_dim, out_channels=z_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
# x: [B, T, H, W, 12] (patchified)
|
||||
x = self.conv1(x)
|
||||
if feat_cache is not None:
|
||||
return self._forward_cached(x, feat_cache, feat_idx)
|
||||
|
||||
# No cache: internally chunk as 1+4+4+... (matches reference behavior)
|
||||
num_convs = _count_conv3d(self)
|
||||
internal_cache = [None] * num_convs
|
||||
T = x.shape[1]
|
||||
starts = [0] + list(range(1, T, 4))
|
||||
ends = starts[1:] + [T]
|
||||
outputs = []
|
||||
for s, e in zip(starts, ends):
|
||||
if s >= e:
|
||||
continue
|
||||
feat_idx_local = [0]
|
||||
out = self._forward_cached(x[:, s:e], internal_cache, feat_idx_local)
|
||||
outputs.append(out)
|
||||
mx.eval(internal_cache)
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=1)
|
||||
|
||||
def _forward_cached(self, x, feat_cache, feat_idx):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
for layer in self.downsamples:
|
||||
x = layer(x)
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
|
||||
for layer in self.middle:
|
||||
x = layer(x)
|
||||
if isinstance(layer, ResidualBlock):
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
mx.eval(x)
|
||||
|
||||
x = self.head(x)
|
||||
x = self.head(x, feat_cache, feat_idx)
|
||||
return x
|
||||
|
||||
|
||||
@@ -626,13 +693,38 @@ class Head22(nn.Module):
|
||||
# Index 2: CausalConv3d
|
||||
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
x = self.layer_0(x)
|
||||
x = nn.silu(x)
|
||||
x = self.layer_2(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
x = self.layer_2(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.layer_2(x)
|
||||
return x
|
||||
|
||||
|
||||
def _count_conv3d(module):
|
||||
"""Count all CausalConv3d instances in a module tree (for cache sizing)."""
|
||||
count = 0
|
||||
if isinstance(module, CausalConv3d):
|
||||
count += 1
|
||||
for child in module.children().values():
|
||||
if isinstance(child, list):
|
||||
for item in child:
|
||||
count += _count_conv3d(item)
|
||||
elif isinstance(child, nn.Module):
|
||||
count += _count_conv3d(child)
|
||||
return count
|
||||
|
||||
|
||||
class Wan22VAEEncoder(nn.Module):
|
||||
"""Full Wan2.2 VAE encoder with patchify and normalization."""
|
||||
|
||||
@@ -649,8 +741,11 @@ class Wan22VAEEncoder(nn.Module):
|
||||
temperal_downsample=(False, True, True),
|
||||
)
|
||||
|
||||
def __call__(self, img):
|
||||
"""Encode image/video to latent space.
|
||||
def encode(self, img):
|
||||
"""Encode image/video using chunked encoding (1+4+4+... pattern).
|
||||
|
||||
This matches the reference implementation's chunked encoding with
|
||||
persistent temporal cache, which is critical for correct I2V latents.
|
||||
|
||||
Args:
|
||||
img: [B, T, H, W, 3] image/video in [-1, 1]
|
||||
@@ -658,11 +753,28 @@ class Wan22VAEEncoder(nn.Module):
|
||||
Returns:
|
||||
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
||||
"""
|
||||
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
|
||||
x = _patchify(img, patch_size=2)
|
||||
T = x.shape[1]
|
||||
|
||||
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
|
||||
out = self.encoder(x)
|
||||
# Initialize temporal cache (one slot per CausalConv3d in encoder)
|
||||
num_convs = _count_conv3d(self.encoder)
|
||||
feat_cache = [None] * num_convs
|
||||
|
||||
# Chunked encoding: first chunk = 1 frame, rest = 4 frames each
|
||||
num_chunks = 1 + (T - 1) // 4
|
||||
out = None
|
||||
for i in range(num_chunks):
|
||||
feat_idx = [0] # Reset layer index each chunk (but keep cache)
|
||||
if i == 0:
|
||||
chunk = x[:, :1]
|
||||
else:
|
||||
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
|
||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
if out is None:
|
||||
out = chunk_out
|
||||
else:
|
||||
out = mx.concatenate([out, chunk_out], axis=1)
|
||||
mx.eval(out)
|
||||
|
||||
# conv1 (pointwise) + split into mu, log_var
|
||||
out = self.conv1(out)
|
||||
@@ -672,6 +784,17 @@ class Wan22VAEEncoder(nn.Module):
|
||||
mu = normalize_latents(mu)
|
||||
return mu
|
||||
|
||||
def __call__(self, img):
|
||||
"""Encode image/video to latent space (delegates to chunked encode).
|
||||
|
||||
Args:
|
||||
img: [B, T, H, W, 3] image/video in [-1, 1]
|
||||
|
||||
Returns:
|
||||
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
||||
"""
|
||||
return self.encode(img)
|
||||
|
||||
|
||||
class Wan22VAEDecoder(nn.Module):
|
||||
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
||||
|
||||
Reference in New Issue
Block a user