From 061ae4407cdbb627bc67653d167c4bc67485259c Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 9 Mar 2026 20:47:37 +0100 Subject: [PATCH] feat(wan): Add chunked VAE encoding and TI2V-5B support --- mlx_video/generate_wan.py | 41 +++++- mlx_video/models/wan/config.py | 11 +- mlx_video/models/wan/vae22.py | 235 +++++++++++++++++++++++++-------- 3 files changed, 223 insertions(+), 64 deletions(-) diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index 10a76d1..14358b7 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -29,13 +29,37 @@ from mlx_video.utils import Colors _build_i2v_mask = build_i2v_mask +def _best_output_size(w, h, dw, dh, max_area): + """Compute the best output resolution that fits within max_area while + preserving the input aspect ratio and satisfying alignment constraints. + Matches the reference implementation's best_output_size(). + """ + ratio = w / h + ow = (max_area * ratio) ** 0.5 + oh = max_area / ow + + # Option 1: process width first + ow1 = int(ow // dw * dw) + oh1 = int(max_area / ow1 // dh * dh) + ratio1 = ow1 / oh1 + + # Option 2: process height first + oh2 = int(oh // dh * dh) + ow2 = int(max_area / oh2 // dw * dw) + ratio2 = ow2 / oh2 + + if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio): + return ow1, oh1 + return ow2, oh2 + + def generate_video( model_dir: str, prompt: str, negative_prompt: str | None = None, image: str | None = None, width: int = 1280, - height: int = 720, + height: int = 704, num_frames: int = 81, steps: int = None, guide_scale: str | float | tuple = None, @@ -232,6 +256,15 @@ def generate_video( width = align_w print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}") + # Enforce max_area constraint (model-specific resolution limit) + if config.max_area > 0 and height * width > config.max_area: + old_h, old_w = height, width + width, height = _best_output_size(width, height, align_w, 
align_h, config.max_area) + print( + f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area " + f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}" + ) + # Compute target latent shape z_dim = config.vae_z_dim t_latent = (gen_frames - 1) // vae_stride[0] + 1 @@ -334,7 +367,7 @@ def generate_video( mx.eval(img_tensor) vae_enc = load_vae_encoder(vae_path, config) - z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim] + z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim] mx.eval(z_img) z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat] i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size) @@ -658,8 +691,8 @@ def main(): help="Negative prompt for CFG (default: official Chinese prompt from config)") parser.add_argument("--no-negative-prompt", action="store_true", help="Disable negative prompt (use empty string instead of config default)") - parser.add_argument("--width", type=int, default=1280, help="Video width") - parser.add_argument("--height", type=int, default=720, help="Video height") + parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)") + parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)") parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)") parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)") parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair") diff --git a/mlx_video/models/wan/config.py b/mlx_video/models/wan/config.py index 08370d4..5e51f3b 100644 --- a/mlx_video/models/wan/config.py +++ b/mlx_video/models/wan/config.py @@ -45,7 +45,8 @@ class WanModelConfig(BaseModelConfig): "杂乱的背景,三条腿,背景人很多,倒着走" ) - # T5 + # Resolution constraints + max_area: int = 0 # 0 = no limit; e.g. 
704*1280 for TI2V-5B t5_vocab_size: int = 256384 t5_dim: int = 4096 t5_dim_attn: int = 4096 @@ -102,7 +103,8 @@ class WanModelConfig(BaseModelConfig): boundary=0.900, sample_shift=5.0, sample_guide_scale=(3.5, 3.5), - ) + max_area=704 * 1280, + ) @classmethod def wan22_ti2v_5b(cls) -> "WanModelConfig": @@ -120,7 +122,8 @@ class WanModelConfig(BaseModelConfig): dual_model=False, boundary=0.0, sample_shift=5.0, - sample_steps=50, + sample_steps=40, sample_guide_scale=5.0, sample_fps=24, - ) + max_area=704 * 1280, + ) diff --git a/mlx_video/models/wan/vae22.py b/mlx_video/models/wan/vae22.py index 48058f6..72a4f04 100644 --- a/mlx_video/models/wan/vae22.py +++ b/mlx_video/models/wan/vae22.py @@ -56,9 +56,11 @@ class CausalConv3d(nn.Module): self.kernel_size = kernel_size self.stride = stride - # Causal temporal padding: always kernel_size-1 on the left. - # This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...). - self._causal_pad_t = kernel_size[0] - 1 + # Causal temporal padding: matches the reference CausalConv3d(nn.Conv3d) + # which converts symmetric padding to causal: 2*padding[0] on the left. + # For most convs (kernel=3, padding=1): 2*1 = 2 (same as kernel-1). + # For downsample time_conv (kernel=3, padding=0): 2*0 = 0 (NO padding). + self._causal_pad_t = 2 * padding[0] self._pad_h = padding[1] self._pad_w = padding[2] @@ -68,7 +70,7 @@ class CausalConv3d(nn.Module): )) self.bias = mx.zeros((out_channels,)) - def __call__(self, x): + def __call__(self, x, cache_x=None): # x: [B, T, H, W, C] B, T, H, W, C = x.shape kd, kh, kw = self.kernel_size @@ -81,10 +83,15 @@ class CausalConv3d(nn.Module): y = mx.conv_general(x_flat, w2d) + self.bias return y.reshape(B, T, y.shape[1], y.shape[2], -1) - # Causal temporal padding (left only) — zeros match the reference - # implementation and what the model was trained with. 
- if self._causal_pad_t > 0: - pad_t = mx.zeros((B, self._causal_pad_t, H, W, C)) + # Causal temporal padding: prepend cached frames if available, + # then zero-pad any remaining positions. + pad_needed = self._causal_pad_t + if cache_x is not None and pad_needed > 0: + x = mx.concatenate([cache_x, x], axis=1) + pad_needed -= cache_x.shape[1] + + if pad_needed > 0: + pad_t = mx.zeros((B, pad_needed, H, W, C), dtype=x.dtype) x = mx.concatenate([pad_t, x], axis=1) # Spatial padding @@ -144,9 +151,9 @@ class ResidualBlock(nn.Module): else None ) - def __call__(self, x): + def __call__(self, x, feat_cache=None, feat_idx=None): h = self.shortcut(x) if self.shortcut is not None else x - return self.residual(x) + h + return self.residual(x, feat_cache, feat_idx) + h class ResidualBlockLayers(nn.Module): @@ -169,14 +176,34 @@ class ResidualBlockLayers(nn.Module): # Index 6: CausalConv3d self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1) - def __call__(self, x): + def _conv_with_cache(self, conv, x, feat_cache, feat_idx): + """Apply CausalConv3d with temporal caching for chunked encoding.""" + idx = feat_idx[0] + # Save last CACHE_T frames before conv (for next chunk's context) + cache_x = x[:, -CACHE_T:] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate( + [feat_cache[idx][:, -1:], cache_x], axis=1 + ) + out = conv(x, cache_x=feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return out + + def __call__(self, x, feat_cache=None, feat_idx=None): x = self.layer_0(x) x = nn.silu(x) - x = self.layer_2(x) + if feat_cache is not None: + x = self._conv_with_cache(self.layer_2, x, feat_cache, feat_idx) + else: + x = self.layer_2(x) mx.eval(x) # Eval between convolutions to limit graph size x = self.layer_3(x) x = nn.silu(x) - x = self.layer_6(x) + if feat_cache is not None: + x = self._conv_with_cache(self.layer_6, x, feat_cache, feat_idx) + else: + x = self.layer_6(x) return x @@ -344,53 +371,34 @@ class 
Resample(nn.Module): x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias - def __call__(self, x, first_chunk=False): + def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None): # x: [B, T, H, W, C] B, T, H, W, C = x.shape + # --- Temporal upsample (before spatial, matching reference) --- if self.mode == "upsample3d": if first_chunk and T > 1: - # Match official chunked behavior: the first frame bypasses - # time_conv entirely (only spatial upsample). Remaining frames - # go through time_conv with causal zero-padding, which - # naturally gives each frame the same limited temporal context - # as the official frame-by-frame decode with caching. - first_frame = x[:, 0:1] # [B, 1, H, W, C] - rest = x[:, 1:] # [B, T-1, H, W, C] - - # time_conv on remaining frames (causal pad gives zero context - # before rest[0], matching the official "Rep" cache path) - tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C] + first_frame = x[:, 0:1] + rest = x[:, 1:] + tc_out = self.time_conv(rest) tc_out = tc_out.reshape(B, T - 1, H, W, 2, C) stream0 = tc_out[:, :, :, :, 0, :] stream1 = tc_out[:, :, :, :, 1, :] interleaved = mx.stack([stream0, stream1], axis=2) interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C) - - # first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames x = mx.concatenate([first_frame, interleaved], axis=1) - elif self.mode == "upsample3d": - # Non-first-chunk or single frame: time_conv all frames - tc_out = self.time_conv(x) # [B, T, H, W, 2C] + else: + tc_out = self.time_conv(x) tc_out = tc_out.reshape(B, T, H, W, 2, C) stream0 = tc_out[:, :, :, :, 0, :] stream1 = tc_out[:, :, :, :, 1, :] x = mx.stack([stream0, stream1], axis=2) x = x.reshape(B, T * 2, H, W, C) - - mx.eval(x) - T = x.shape[1] - - if self.mode == "downsample3d" and T > 1: - # Temporal downsample via strided CausalConv3d - # Skip for T=1 (single frame) — matches official chunked encoding - # where first 
chunk stores cache but doesn't apply time_conv - x = self.time_conv(x) mx.eval(x) T = x.shape[1] + # --- Spatial operation (all modes, matching reference line 152-155) --- if self.mode in ("upsample2d", "upsample3d"): - # Spatial upsample in temporal chunks to limit peak memory chunk_size = 8 chunks = [] for t_start in range(0, T, chunk_size): @@ -400,18 +408,36 @@ class Resample(nn.Module): x_chunk = self._conv2d(x_chunk) mx.eval(x_chunk) chunks.append(x_chunk) - x = mx.concatenate(chunks, axis=0) H2, W2 = x.shape[1], x.shape[2] x = x.reshape(B, T, H2, W2, C) elif self.mode in ("downsample2d", "downsample3d"): - # Spatial downsample: per-frame strided Conv2d x_flat = x.reshape(B * T, H, W, C) x_flat = self._downsample_conv2d(x_flat) mx.eval(x_flat) H2, W2 = x_flat.shape[1], x_flat.shape[2] x = x_flat.reshape(B, T, H2, W2, C) + # --- Temporal downsample (after spatial, matching reference line 157-168) --- + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + # First chunk: store spatially-downsampled result, skip time_conv + feat_cache[idx] = x + feat_idx[0] += 1 + else: + # Subsequent chunks: prepend cached last frame, apply time_conv + save_x = x[:, -1:] + x = self.time_conv( + mx.concatenate([feat_cache[idx][:, -1:], x], axis=1) + ) + feat_cache[idx] = save_x + feat_idx[0] += 1 + elif T > 1: + x = self.time_conv(x) + mx.eval(x) + return x @@ -488,12 +514,20 @@ class Down_ResidualBlock(nn.Module): self.downsamples = blocks - def __call__(self, x): + def __call__(self, x, feat_cache=None, feat_idx=None): x_shortcut = self.avg_shortcut(x) mx.eval(x_shortcut) for module in self.downsamples: - x = module(x) + if feat_cache is not None: + if isinstance(module, ResidualBlock): + x = module(x, feat_cache, feat_idx) + elif isinstance(module, Resample): + x = module(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = module(x) + else: + x = module(x) mx.eval(x) return x + x_shortcut @@ -597,18 +631,51 @@ 
class Encoder3d(nn.Module): # Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels self.head = Head22(out_dim, out_channels=z_dim) - def __call__(self, x): + def __call__(self, x, feat_cache=None, feat_idx=None): # x: [B, T, H, W, 12] (patchified) - x = self.conv1(x) + if feat_cache is not None: + return self._forward_cached(x, feat_cache, feat_idx) + + # No cache: internally chunk as 1+4+4+... (matches reference behavior) + num_convs = _count_conv3d(self) + internal_cache = [None] * num_convs + T = x.shape[1] + starts = [0] + list(range(1, T, 4)) + ends = starts[1:] + [T] + outputs = [] + for s, e in zip(starts, ends): + if s >= e: + continue + feat_idx_local = [0] + out = self._forward_cached(x[:, s:e], internal_cache, feat_idx_local) + outputs.append(out) + mx.eval(internal_cache) + if len(outputs) == 1: + return outputs[0] + return mx.concatenate(outputs, axis=1) + + def _forward_cached(self, x, feat_cache, feat_idx): + idx = feat_idx[0] + cache_x = x[:, -CACHE_T:] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate( + [feat_cache[idx][:, -1:], cache_x], axis=1 + ) + x = self.conv1(x, cache_x=feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 for layer in self.downsamples: - x = layer(x) + x = layer(x, feat_cache, feat_idx) for layer in self.middle: - x = layer(x) + if isinstance(layer, ResidualBlock): + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) mx.eval(x) - x = self.head(x) + x = self.head(x, feat_cache, feat_idx) return x @@ -626,13 +693,38 @@ class Head22(nn.Module): # Index 2: CausalConv3d self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1) - def __call__(self, x): + def __call__(self, x, feat_cache=None, feat_idx=None): x = self.layer_0(x) x = nn.silu(x) - x = self.layer_2(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, -CACHE_T:] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate( + [feat_cache[idx][:, -1:], cache_x], 
axis=1 + ) + x = self.layer_2(x, cache_x=feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.layer_2(x) return x +def _count_conv3d(module): + """Count all CausalConv3d instances in a module tree (for cache sizing).""" + count = 0 + if isinstance(module, CausalConv3d): + count += 1 + for child in module.children().values(): + if isinstance(child, list): + for item in child: + count += _count_conv3d(item) + elif isinstance(child, nn.Module): + count += _count_conv3d(child) + return count + + class Wan22VAEEncoder(nn.Module): """Full Wan2.2 VAE encoder with patchify and normalization.""" @@ -649,8 +741,11 @@ class Wan22VAEEncoder(nn.Module): temperal_downsample=(False, True, True), ) - def __call__(self, img): - """Encode image/video to latent space. + def encode(self, img): + """Encode image/video using chunked encoding (1+4+4+... pattern). + + This matches the reference implementation's chunked encoding with + persistent temporal cache, which is critical for correct I2V latents. 
Args: img: [B, T, H, W, 3] image/video in [-1, 1] @@ -658,11 +753,28 @@ class Wan22VAEEncoder(nn.Module): Returns: mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent """ - # Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12] x = _patchify(img, patch_size=2) + T = x.shape[1] - # Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2] - out = self.encoder(x) + # Initialize temporal cache (one slot per CausalConv3d in encoder) + num_convs = _count_conv3d(self.encoder) + feat_cache = [None] * num_convs + + # Chunked encoding: first chunk = 1 frame, rest = 4 frames each + num_chunks = 1 + (T - 1) // 4 + out = None + for i in range(num_chunks): + feat_idx = [0] # Reset layer index each chunk (but keep cache) + if i == 0: + chunk = x[:, :1] + else: + chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i] + chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx) + if out is None: + out = chunk_out + else: + out = mx.concatenate([out, chunk_out], axis=1) + mx.eval(out) # conv1 (pointwise) + split into mu, log_var out = self.conv1(out) @@ -672,6 +784,17 @@ class Wan22VAEEncoder(nn.Module): mu = normalize_latents(mu) return mu + def __call__(self, img): + """Encode image/video to latent space (delegates to chunked encode). + + Args: + img: [B, T, H, W, 3] image/video in [-1, 1] + + Returns: + mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent + """ + return self.encode(img) + class Wan22VAEDecoder(nn.Module): """Full Wan2.2 VAE decoder with normalization and unpatchify."""