feat(wan): Add chunked VAE encoding and TI2V-5B support

This commit is contained in:
Daniel
2026-03-09 20:47:37 +01:00
parent 967218b7c1
commit 061ae4407c
3 changed files with 223 additions and 64 deletions

View File

@@ -29,13 +29,37 @@ from mlx_video.utils import Colors
_build_i2v_mask = build_i2v_mask _build_i2v_mask = build_i2v_mask
def _best_output_size(w, h, dw, dh, max_area):
"""Compute the best output resolution that fits within max_area while
preserving the input aspect ratio and satisfying alignment constraints.
Matches the reference implementation's best_output_size().
"""
ratio = w / h
ow = (max_area * ratio) ** 0.5
oh = max_area / ow
# Option 1: process width first
ow1 = int(ow // dw * dw)
oh1 = int(max_area / ow1 // dh * dh)
ratio1 = ow1 / oh1
# Option 2: process height first
oh2 = int(oh // dh * dh)
ow2 = int(max_area / oh2 // dw * dw)
ratio2 = ow2 / oh2
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
return ow1, oh1
return ow2, oh2
def generate_video( def generate_video(
model_dir: str, model_dir: str,
prompt: str, prompt: str,
negative_prompt: str | None = None, negative_prompt: str | None = None,
image: str | None = None, image: str | None = None,
width: int = 1280, width: int = 1280,
height: int = 720, height: int = 704,
num_frames: int = 81, num_frames: int = 81,
steps: int = None, steps: int = None,
guide_scale: str | float | tuple = None, guide_scale: str | float | tuple = None,
@@ -232,6 +256,15 @@ def generate_video(
width = align_w width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}") print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
# Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
)
# Compute target latent shape # Compute target latent shape
z_dim = config.vae_z_dim z_dim = config.vae_z_dim
t_latent = (gen_frames - 1) // vae_stride[0] + 1 t_latent = (gen_frames - 1) // vae_stride[0] + 1
@@ -334,7 +367,7 @@ def generate_video(
mx.eval(img_tensor) mx.eval(img_tensor)
vae_enc = load_vae_encoder(vae_path, config) vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim] z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img) mx.eval(z_img)
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat] z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size) i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
@@ -658,8 +691,8 @@ def main():
help="Negative prompt for CFG (default: official Chinese prompt from config)") help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true", parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)") help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width") parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
parser.add_argument("--height", type=int, default=720, help="Video height") parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)") parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)") parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair") parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")

View File

@@ -45,7 +45,8 @@ class WanModelConfig(BaseModelConfig):
"杂乱的背景,三条腿,背景人很多,倒着走" "杂乱的背景,三条腿,背景人很多,倒着走"
) )
# T5 # Resolution constraints
max_area: int = 0 # 0 = no limit; e.g. 704*1280 for TI2V-5B
t5_vocab_size: int = 256384 t5_vocab_size: int = 256384
t5_dim: int = 4096 t5_dim: int = 4096
t5_dim_attn: int = 4096 t5_dim_attn: int = 4096
@@ -102,7 +103,8 @@ class WanModelConfig(BaseModelConfig):
boundary=0.900, boundary=0.900,
sample_shift=5.0, sample_shift=5.0,
sample_guide_scale=(3.5, 3.5), sample_guide_scale=(3.5, 3.5),
) max_area=704 * 1280,
@classmethod @classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig": def wan22_ti2v_5b(cls) -> "WanModelConfig":
@@ -120,7 +122,8 @@ class WanModelConfig(BaseModelConfig):
dual_model=False, dual_model=False,
boundary=0.0, boundary=0.0,
sample_shift=5.0, sample_shift=5.0,
sample_steps=50, sample_steps=40,
sample_guide_scale=5.0, sample_guide_scale=5.0,
sample_fps=24, sample_fps=24,
) max_area=704 * 1280,

View File

@@ -56,9 +56,11 @@ class CausalConv3d(nn.Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = stride self.stride = stride
# Causal temporal padding: always kernel_size-1 on the left. # Causal temporal padding: matches the reference CausalConv3d(nn.Conv3d)
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...). # which converts symmetric padding to causal: 2*padding[0] on the left.
self._causal_pad_t = kernel_size[0] - 1 # For most convs (kernel=3, padding=1): 2*1 = 2 (same as kernel-1).
# For downsample time_conv (kernel=3, padding=0): 2*0 = 0 (NO padding).
self._causal_pad_t = 2 * padding[0]
self._pad_h = padding[1] self._pad_h = padding[1]
self._pad_w = padding[2] self._pad_w = padding[2]
@@ -68,7 +70,7 @@ class CausalConv3d(nn.Module):
)) ))
self.bias = mx.zeros((out_channels,)) self.bias = mx.zeros((out_channels,))
def __call__(self, x): def __call__(self, x, cache_x=None):
# x: [B, T, H, W, C] # x: [B, T, H, W, C]
B, T, H, W, C = x.shape B, T, H, W, C = x.shape
kd, kh, kw = self.kernel_size kd, kh, kw = self.kernel_size
@@ -81,10 +83,15 @@ class CausalConv3d(nn.Module):
y = mx.conv_general(x_flat, w2d) + self.bias y = mx.conv_general(x_flat, w2d) + self.bias
return y.reshape(B, T, y.shape[1], y.shape[2], -1) return y.reshape(B, T, y.shape[1], y.shape[2], -1)
# Causal temporal padding (left only) — zeros match the reference # Causal temporal padding: prepend cached frames if available,
# implementation and what the model was trained with. # then zero-pad any remaining positions.
if self._causal_pad_t > 0: pad_needed = self._causal_pad_t
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C)) if cache_x is not None and pad_needed > 0:
x = mx.concatenate([cache_x, x], axis=1)
pad_needed -= cache_x.shape[1]
if pad_needed > 0:
pad_t = mx.zeros((B, pad_needed, H, W, C), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=1) x = mx.concatenate([pad_t, x], axis=1)
# Spatial padding # Spatial padding
@@ -144,9 +151,9 @@ class ResidualBlock(nn.Module):
else None else None
) )
def __call__(self, x): def __call__(self, x, feat_cache=None, feat_idx=None):
h = self.shortcut(x) if self.shortcut is not None else x h = self.shortcut(x) if self.shortcut is not None else x
return self.residual(x) + h return self.residual(x, feat_cache, feat_idx) + h
class ResidualBlockLayers(nn.Module): class ResidualBlockLayers(nn.Module):
@@ -169,14 +176,34 @@ class ResidualBlockLayers(nn.Module):
# Index 6: CausalConv3d # Index 6: CausalConv3d
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1) self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
def __call__(self, x): def _conv_with_cache(self, conv, x, feat_cache, feat_idx):
"""Apply CausalConv3d with temporal caching for chunked encoding."""
idx = feat_idx[0]
# Save last CACHE_T frames before conv (for next chunk's context)
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
out = conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
return out
def __call__(self, x, feat_cache=None, feat_idx=None):
x = self.layer_0(x) x = self.layer_0(x)
x = nn.silu(x) x = nn.silu(x)
x = self.layer_2(x) if feat_cache is not None:
x = self._conv_with_cache(self.layer_2, x, feat_cache, feat_idx)
else:
x = self.layer_2(x)
mx.eval(x) # Eval between convolutions to limit graph size mx.eval(x) # Eval between convolutions to limit graph size
x = self.layer_3(x) x = self.layer_3(x)
x = nn.silu(x) x = nn.silu(x)
x = self.layer_6(x) if feat_cache is not None:
x = self._conv_with_cache(self.layer_6, x, feat_cache, feat_idx)
else:
x = self.layer_6(x)
return x return x
@@ -344,53 +371,34 @@ class Resample(nn.Module):
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
def __call__(self, x, first_chunk=False): def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, C] # x: [B, T, H, W, C]
B, T, H, W, C = x.shape B, T, H, W, C = x.shape
# --- Temporal upsample (before spatial, matching reference) ---
if self.mode == "upsample3d": if self.mode == "upsample3d":
if first_chunk and T > 1: if first_chunk and T > 1:
# Match official chunked behavior: the first frame bypasses first_frame = x[:, 0:1]
# time_conv entirely (only spatial upsample). Remaining frames rest = x[:, 1:]
# go through time_conv with causal zero-padding, which tc_out = self.time_conv(rest)
# naturally gives each frame the same limited temporal context
# as the official frame-by-frame decode with caching.
first_frame = x[:, 0:1] # [B, 1, H, W, C]
rest = x[:, 1:] # [B, T-1, H, W, C]
# time_conv on remaining frames (causal pad gives zero context
# before rest[0], matching the official "Rep" cache path)
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C) tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :] stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :] stream1 = tc_out[:, :, :, :, 1, :]
interleaved = mx.stack([stream0, stream1], axis=2) interleaved = mx.stack([stream0, stream1], axis=2)
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C) interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
x = mx.concatenate([first_frame, interleaved], axis=1) x = mx.concatenate([first_frame, interleaved], axis=1)
elif self.mode == "upsample3d": else:
# Non-first-chunk or single frame: time_conv all frames tc_out = self.time_conv(x)
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
tc_out = tc_out.reshape(B, T, H, W, 2, C) tc_out = tc_out.reshape(B, T, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :] stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :] stream1 = tc_out[:, :, :, :, 1, :]
x = mx.stack([stream0, stream1], axis=2) x = mx.stack([stream0, stream1], axis=2)
x = x.reshape(B, T * 2, H, W, C) x = x.reshape(B, T * 2, H, W, C)
mx.eval(x)
T = x.shape[1]
if self.mode == "downsample3d" and T > 1:
# Temporal downsample via strided CausalConv3d
# Skip for T=1 (single frame) — matches official chunked encoding
# where first chunk stores cache but doesn't apply time_conv
x = self.time_conv(x)
mx.eval(x) mx.eval(x)
T = x.shape[1] T = x.shape[1]
# --- Spatial operation (all modes, matching reference line 152-155) ---
if self.mode in ("upsample2d", "upsample3d"): if self.mode in ("upsample2d", "upsample3d"):
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8 chunk_size = 8
chunks = [] chunks = []
for t_start in range(0, T, chunk_size): for t_start in range(0, T, chunk_size):
@@ -400,18 +408,36 @@ class Resample(nn.Module):
x_chunk = self._conv2d(x_chunk) x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk) mx.eval(x_chunk)
chunks.append(x_chunk) chunks.append(x_chunk)
x = mx.concatenate(chunks, axis=0) x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2] H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C) x = x.reshape(B, T, H2, W2, C)
elif self.mode in ("downsample2d", "downsample3d"): elif self.mode in ("downsample2d", "downsample3d"):
# Spatial downsample: per-frame strided Conv2d
x_flat = x.reshape(B * T, H, W, C) x_flat = x.reshape(B * T, H, W, C)
x_flat = self._downsample_conv2d(x_flat) x_flat = self._downsample_conv2d(x_flat)
mx.eval(x_flat) mx.eval(x_flat)
H2, W2 = x_flat.shape[1], x_flat.shape[2] H2, W2 = x_flat.shape[1], x_flat.shape[2]
x = x_flat.reshape(B, T, H2, W2, C) x = x_flat.reshape(B, T, H2, W2, C)
# --- Temporal downsample (after spatial, matching reference line 157-168) ---
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
# First chunk: store spatially-downsampled result, skip time_conv
feat_cache[idx] = x
feat_idx[0] += 1
else:
# Subsequent chunks: prepend cached last frame, apply time_conv
save_x = x[:, -1:]
x = self.time_conv(
mx.concatenate([feat_cache[idx][:, -1:], x], axis=1)
)
feat_cache[idx] = save_x
feat_idx[0] += 1
elif T > 1:
x = self.time_conv(x)
mx.eval(x)
return x return x
@@ -488,12 +514,20 @@ class Down_ResidualBlock(nn.Module):
self.downsamples = blocks self.downsamples = blocks
def __call__(self, x): def __call__(self, x, feat_cache=None, feat_idx=None):
x_shortcut = self.avg_shortcut(x) x_shortcut = self.avg_shortcut(x)
mx.eval(x_shortcut) mx.eval(x_shortcut)
for module in self.downsamples: for module in self.downsamples:
x = module(x) if feat_cache is not None:
if isinstance(module, ResidualBlock):
x = module(x, feat_cache, feat_idx)
elif isinstance(module, Resample):
x = module(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = module(x)
else:
x = module(x)
mx.eval(x) mx.eval(x)
return x + x_shortcut return x + x_shortcut
@@ -597,18 +631,51 @@ class Encoder3d(nn.Module):
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels # Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
self.head = Head22(out_dim, out_channels=z_dim) self.head = Head22(out_dim, out_channels=z_dim)
def __call__(self, x): def __call__(self, x, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, 12] (patchified) # x: [B, T, H, W, 12] (patchified)
x = self.conv1(x) if feat_cache is not None:
return self._forward_cached(x, feat_cache, feat_idx)
# No cache: internally chunk as 1+4+4+... (matches reference behavior)
num_convs = _count_conv3d(self)
internal_cache = [None] * num_convs
T = x.shape[1]
starts = [0] + list(range(1, T, 4))
ends = starts[1:] + [T]
outputs = []
for s, e in zip(starts, ends):
if s >= e:
continue
feat_idx_local = [0]
out = self._forward_cached(x[:, s:e], internal_cache, feat_idx_local)
outputs.append(out)
mx.eval(internal_cache)
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=1)
def _forward_cached(self, x, feat_cache, feat_idx):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
for layer in self.downsamples: for layer in self.downsamples:
x = layer(x) x = layer(x, feat_cache, feat_idx)
for layer in self.middle: for layer in self.middle:
x = layer(x) if isinstance(layer, ResidualBlock):
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
mx.eval(x) mx.eval(x)
x = self.head(x) x = self.head(x, feat_cache, feat_idx)
return x return x
@@ -626,13 +693,38 @@ class Head22(nn.Module):
# Index 2: CausalConv3d # Index 2: CausalConv3d
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1) self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
def __call__(self, x): def __call__(self, x, feat_cache=None, feat_idx=None):
x = self.layer_0(x) x = self.layer_0(x)
x = nn.silu(x) x = nn.silu(x)
x = self.layer_2(x) if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
x = self.layer_2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.layer_2(x)
return x return x
def _count_conv3d(module):
"""Count all CausalConv3d instances in a module tree (for cache sizing)."""
count = 0
if isinstance(module, CausalConv3d):
count += 1
for child in module.children().values():
if isinstance(child, list):
for item in child:
count += _count_conv3d(item)
elif isinstance(child, nn.Module):
count += _count_conv3d(child)
return count
class Wan22VAEEncoder(nn.Module): class Wan22VAEEncoder(nn.Module):
"""Full Wan2.2 VAE encoder with patchify and normalization.""" """Full Wan2.2 VAE encoder with patchify and normalization."""
@@ -649,8 +741,11 @@ class Wan22VAEEncoder(nn.Module):
temperal_downsample=(False, True, True), temperal_downsample=(False, True, True),
) )
def __call__(self, img): def encode(self, img):
"""Encode image/video to latent space. """Encode image/video using chunked encoding (1+4+4+... pattern).
This matches the reference implementation's chunked encoding with
persistent temporal cache, which is critical for correct I2V latents.
Args: Args:
img: [B, T, H, W, 3] image/video in [-1, 1] img: [B, T, H, W, 3] image/video in [-1, 1]
@@ -658,11 +753,28 @@ class Wan22VAEEncoder(nn.Module):
Returns: Returns:
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
""" """
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
x = _patchify(img, patch_size=2) x = _patchify(img, patch_size=2)
T = x.shape[1]
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2] # Initialize temporal cache (one slot per CausalConv3d in encoder)
out = self.encoder(x) num_convs = _count_conv3d(self.encoder)
feat_cache = [None] * num_convs
# Chunked encoding: first chunk = 1 frame, rest = 4 frames each
num_chunks = 1 + (T - 1) // 4
out = None
for i in range(num_chunks):
feat_idx = [0] # Reset layer index each chunk (but keep cache)
if i == 0:
chunk = x[:, :1]
else:
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
else:
out = mx.concatenate([out, chunk_out], axis=1)
mx.eval(out)
# conv1 (pointwise) + split into mu, log_var # conv1 (pointwise) + split into mu, log_var
out = self.conv1(out) out = self.conv1(out)
@@ -672,6 +784,17 @@ class Wan22VAEEncoder(nn.Module):
mu = normalize_latents(mu) mu = normalize_latents(mu)
return mu return mu
def __call__(self, img):
"""Encode image/video to latent space (delegates to chunked encode).
Args:
img: [B, T, H, W, 3] image/video in [-1, 1]
Returns:
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
"""
return self.encode(img)
class Wan22VAEDecoder(nn.Module): class Wan22VAEDecoder(nn.Module):
"""Full Wan2.2 VAE decoder with normalization and unpatchify.""" """Full Wan2.2 VAE decoder with normalization and unpatchify."""