From 9bdda9f22e55acae5019d68bdeee200fb0680bb5 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 4 Mar 2026 14:32:45 +0100 Subject: [PATCH] feat(wan): Add tiled VAE decoding and fix TI2V quality --- mlx_video/generate_wan.py | 87 +++++++--- mlx_video/models/ltx/video_vae/tiling.py | 21 ++- mlx_video/models/wan/model.py | 12 +- mlx_video/models/wan/transformer.py | 9 +- mlx_video/models/wan/vae.py | 53 ++++++ mlx_video/models/wan/vae22.py | 61 +++++++ tests/test_wan_tiling.py | 198 +++++++++++++++++++++++ 7 files changed, 407 insertions(+), 34 deletions(-) create mode 100644 tests/test_wan_tiling.py diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index f1f0275..10a76d1 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -46,8 +46,10 @@ def generate_video( loras: list | None = None, loras_high: list | None = None, loras_low: list | None = None, - + tiling: str = "auto", + no_compile: bool = False, ): + """Generate video using Wan pipeline (supports T2V and I2V). Args: @@ -67,6 +69,13 @@ def generate_video( loras: Optional list of (path, strength) tuples applied to all models loras_high: Optional list of (path, strength) tuples for high-noise model only loras_low: Optional list of (path, strength) tuples for low-noise model only + tiling: Tiling mode for VAE decoding. Options: + - "auto": Automatically determine tiling based on video size (default) + - "none": Disable tiling + - "default", "aggressive", "conservative": Preset tiling configs + - "spatial": Spatial tiling only + - "temporal": Temporal tiling only + no_compile: If True, skip mx.compile on models (useful for debugging) """ import json @@ -173,12 +182,7 @@ def generate_video( # Validate frame count assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}" - # For T2V: generate 1 extra latent frame so the VAE's causal zero-padding - # artifacts land on throwaway frames. The reference Wan2.2 speech2video.py - # uses a similar "drop_first_motion" approach (drops 3 pixel frames). - # For I2V the reference image provides real first-frame content, so no extra needed. - extra_frames = config.vae_stride[0] if not is_i2v else 0 - gen_frames = num_frames + extra_frames + gen_frames = num_frames version_str = f"Wan{config.model_version}" mode_str = "dual-model" if is_dual else "single-model" @@ -241,8 +245,6 @@ def generate_video( ) print(f"{Colors.DIM} Latent shape: {target_shape}") - if extra_frames > 0: - print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts") print(f" Sequence length: {seq_len}{Colors.RESET}") # Load T5 encoder @@ -419,7 +421,7 @@ def generate_video( rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes) mx.eval(rope_cos_sin_low, rope_cos_sin_high) else: - rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes) + rope_cos_sin = single_model.prepare_rope(rope_grid_sizes) mx.eval(rope_cos_sin) # Setup scheduler @@ -448,12 +450,13 @@ def generate_video( print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}") t3 = time.time() - # Compile model forward for faster denoising - models_to_compile = ( - [high_noise_model, low_noise_model] if is_dual else [single_model] - ) - for m in models_to_compile: - m._compiled = mx.compile(m) + if not no_compile: + models_to_compile = ( + [high_noise_model, low_noise_model] if is_dual else [single_model] + ) + for m in models_to_compile: + m._compiled = mx.compile(m) + @@ -585,24 +588,53 @@ def generate_video( is_wan22_vae = config.vae_z_dim == 48 + # Select tiling configuration + from mlx_video.models.ltx.video_vae.tiling import TilingConfig + + if tiling == "none": + tiling_config = None + elif tiling == "auto": + tiling_config = TilingConfig.auto(height, width, num_frames) + elif tiling == "default": + tiling_config = TilingConfig.default() + elif tiling == "aggressive": + tiling_config = TilingConfig.aggressive() + elif tiling == "conservative": + tiling_config = TilingConfig.conservative() + elif tiling == "spatial": + tiling_config = TilingConfig.spatial_only() + elif tiling == "temporal": + tiling_config = TilingConfig.temporal_only() + else: + print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + tiling_config = TilingConfig.auto(height, width, num_frames) + + if tiling_config is not None: + spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" + temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" + print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + if is_wan22_vae: from mlx_video.models.wan.vae22 import denormalize_latents # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) z = latents.transpose(1, 2, 3, 0)[None] z = denormalize_latents(z) - video = vae(z) + if tiling_config is not None: + video = vae.decode_tiled(z, tiling_config) + else: + video = vae(z) mx.eval(video) print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") video = np.array(video[0]) # [T', H', W', 3] - # Trim extra frames generated for zero-padding warmup - if extra_frames > 0: - video = video[extra_frames:] video = (video + 1.0) / 2.0 video = np.clip(video * 255.0, 0, 255).astype(np.uint8) else: - video = vae.decode(latents[None]) + if tiling_config is not None: + video = vae.decode_tiled(latents[None], tiling_config) + else: + video = vae.decode(latents[None]) mx.eval(video) print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") @@ -651,6 +683,17 @@ def main(): "--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"), help="Apply a LoRA to low-noise model only (dual-model, repeatable)", ) + parser.add_argument( + "--tiling", + type=str, + default="auto", + choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + help="VAE tiling mode to reduce memory during decoding (default: auto)", + ) + parser.add_argument( + "--no-compile", action="store_true", + help="Disable mx.compile on models (for debugging)", + ) args = parser.parse_args() @@ -688,6 +731,8 @@ def main(): loras=_parse_lora_args(args.lora), loras_high=_parse_lora_args(args.lora_high), loras_low=_parse_lora_args(args.lora_low), + tiling=args.tiling, + no_compile=args.no_compile, ) diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index 72d32e4..512edff 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -283,6 +283,7 @@ def decode_with_tiling( spatial_scale: int = 32, temporal_scale: int = 8, causal: bool = False, + causal_temporal: bool = True, timestep: Optional[mx.array] = None, chunked_conv: bool = False, on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, @@ -296,6 +297,10 @@ def decode_with_tiling( spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify). temporal_scale: Temporal scale factor (8 for LTX VAE). causal: Whether to use causal convolutions. + causal_temporal: Whether the decoder uses causal temporal mapping where + T input frames produce 1+(T-1)*scale output frames. When False, uses + simple scaling where T frames produce T*scale output frames. + Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1). timestep: Optional timestep for conditioning. chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. @@ -310,7 +315,7 @@ def decode_with_tiling( b, c, f_latent, h_latent, w_latent = latents.shape # Compute output shape - out_f = 1 + (f_latent - 1) * temporal_scale + out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale) out_h = h_latent * spatial_scale out_w = w_latent * spatial_scale @@ -332,7 +337,10 @@ def decode_with_tiling( temporal_overlap = 0 # Compute intervals for each dimension - temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + if causal_temporal: + temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + else: + temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent) height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) @@ -355,7 +363,10 @@ def decode_with_tiling( t_right = temporal_intervals.right_ramps[t_idx] # Map temporal coordinates - out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + if causal_temporal: + out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + else: + out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale) for h_idx in range(num_h_tiles): h_start = height_intervals.starts[h_idx] @@ -461,8 +472,10 @@ def decode_with_tiling( # Map to output frame index (first frame of next tile's contribution) if next_tile_start_latent == 0: next_tile_start_out = 0 - else: + elif causal_temporal: next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale + else: + next_tile_start_out = next_tile_start_latent * temporal_scale # We need to track how many frames we've already emitted if not hasattr(decode_with_tiling, '_emitted_frames'): diff --git a/mlx_video/models/wan/model.py b/mlx_video/models/wan/model.py index 5f2e391..4f7dfd0 100644 --- a/mlx_video/models/wan/model.py +++ b/mlx_video/models/wan/model.py @@ -48,9 +48,8 @@ class Head(nn.Module): """ if e.ndim == 2: e = e[:, None, :] # [B, 1, dim] - # Compute modulation in float32 for precision, cast to working dtype - w_dtype = _linear_dtype(self.head) - mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype) + # Compute modulation in float32 (matching reference's autocast(float32)) + mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32 e0 = mod[:, :, 0, :] # [B, L_e, dim] shift e1 = mod[:, :, 1, :] # [B, L_e, dim] scale x_norm = self.norm(x) @@ -120,10 +119,13 @@ class WanModel(nn.Module): ], axis=1) # Precompute sinusoidal inv_freq for time embedding + # Use numpy float64 for precision (matches reference torch.float64), + # then store as float32 since MLX GPU doesn't support float64. half = config.freq_dim // 2 - self._inv_freq = mx.power( - 10000.0, -mx.arange(half).astype(mx.float32) / half + inv_freq_np = np.power( + 10000.0, -np.arange(half, dtype=np.float64) / half ) + self._inv_freq = mx.array(inv_freq_np.astype(np.float32)) def _patchify(self, x: mx.array) -> tuple: diff --git a/mlx_video/models/wan/transformer.py b/mlx_video/models/wan/transformer.py index 59aa651..7186b82 100644 --- a/mlx_video/models/wan/transformer.py +++ b/mlx_video/models/wan/transformer.py @@ -51,10 +51,11 @@ class WanAttentionBlock(nn.Module): rope_cos_sin: tuple | None = None, attn_mask: mx.array | None = None, ) -> mx.array: - # Modulation: compute in float32 for precision, cast to working dtype - # to avoid promoting the full hidden state (seq_len × dim) to float32 - w_dtype = _linear_dtype(self.self_attn.q) - mod = (self.modulation + e).astype(w_dtype) + # Modulation: compute in float32 for precision, matching the reference + # which keeps residual x in float32 via torch.amp.autocast(dtype=float32). + # By keeping modulation in float32, type promotion ensures the residual + # stream stays float32 throughout all 30 layers (gate * output + x → float32). + mod = self.modulation + e # float32 e0, e1, e2, e3, e4, e5 = ( mod[:, :, 0, :], # shift for self-attn mod[:, :, 1, :], # scale for self-attn diff --git a/mlx_video/models/wan/vae.py b/mlx_video/models/wan/vae.py index fe8ccaf..f4d83e7 100644 --- a/mlx_video/models/wan/vae.py +++ b/mlx_video/models/wan/vae.py @@ -534,3 +534,56 @@ class WanVAE(nn.Module): x = self.conv2(z) out = self.decoder(x) return mx.clip(out, -1, 1) + + def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array: + """Decode latent to video using tiling to reduce memory usage. + + Splits the latent tensor into overlapping spatial/temporal tiles, + decodes each tile independently, and blends them with trapezoidal + masks. Reuses the LTX-2 tiling infrastructure. + + Args: + z: Normalized latent [B, z_dim, T, H, W] + tiling_config: Optional TilingConfig. If None, uses default. + + Returns: + Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] + """ + from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling + + if tiling_config is None: + tiling_config = TilingConfig.default() + + # Check if tiling is actually needed + _, _, f, h, w = z.shape + needs_tiling = False + if tiling_config.spatial_config is not None: + s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8 + if h > s_tile or w > s_tile: + needs_tiling = True + if tiling_config.temporal_config is not None: + t_tile = tiling_config.temporal_config.tile_size_in_frames // 4 + if f > t_tile: + needs_tiling = True + + if not needs_tiling: + return self.decode(z) + + # Denormalize once (small tensor), then tile the denormalized latents + mean = self.mean.reshape(1, -1, 1, 1, 1) + inv_std = self.inv_std.reshape(1, -1, 1, 1, 1) + z_denorm = z / inv_std + mean + + def tile_decode(tile_latents, **kwargs): + x = self.conv2(tile_latents) + out = self.decoder(x) + return mx.clip(out, -1, 1) + + return decode_with_tiling( + decoder_fn=tile_decode, + latents=z_denorm, + tiling_config=tiling_config, + spatial_scale=8, # 3× spatial 2× upsamples = 8× + temporal_scale=4, # 2× temporal upsamples × 2 = 4× + causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T) + ) diff --git a/mlx_video/models/wan/vae22.py b/mlx_video/models/wan/vae22.py index a0f7234..48058f6 100644 --- a/mlx_video/models/wan/vae22.py +++ b/mlx_video/models/wan/vae22.py @@ -709,6 +709,67 @@ class Wan22VAEDecoder(nn.Module): return mx.clip(out, -1.0, 1.0) + def decode_tiled(self, z, tiling_config=None): + """Decode latents using tiling to reduce memory usage. + + Splits the latent tensor into overlapping spatial/temporal tiles, + decodes each tile independently, and blends them with trapezoidal + masks. Reuses the LTX-2 tiling infrastructure with channels-first + adapter (future: refactor tiling.py to be layout-agnostic). + + Args: + z: [B, T, H, W, C=48] latent tensor (already denormalized) + tiling_config: Optional TilingConfig. If None, uses default. + + Returns: + video: [B, T', H', W', 3] decoded RGB in [-1, 1] + """ + from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling + + if tiling_config is None: + tiling_config = TilingConfig.default() + + # Check if tiling is actually needed + b, t, h_px, w_px, c = z.shape + # Latent dimensions (before conv2/decoder upsampling) + h_lat, w_lat = h_px, w_px + needs_tiling = False + if tiling_config.spatial_config is not None: + s_tile = tiling_config.spatial_config.tile_size_in_pixels // 16 + if h_lat > s_tile or w_lat > s_tile: + needs_tiling = True + if tiling_config.temporal_config is not None: + t_tile = tiling_config.temporal_config.tile_size_in_frames // 4 + if t > t_tile: + needs_tiling = True + + if not needs_tiling: + return self(z) + + # Transpose to channels-first for decode_with_tiling: [B,T,H,W,C] → [B,C,T,H,W] + z_cf = z.transpose(0, 4, 1, 2, 3) + + # Tile decoder: receives (B,C,T,H,W) channels-first, returns (B,3,T',H',W') + def tile_decode(tile_latents, **kwargs): + tile_cl = tile_latents.transpose(0, 2, 3, 4, 1) # → [B,T,H,W,C] + x = self.conv2(tile_cl) + out = self.decoder(x, first_chunk=True) + out = _unpatchify(out, patch_size=2) + out = mx.clip(out, -1.0, 1.0) + return out.transpose(0, 4, 1, 2, 3) # → [B,3,T',H',W'] + + result_cf = decode_with_tiling( + decoder_fn=tile_decode, + latents=z_cf, + tiling_config=tiling_config, + spatial_scale=16, # 8× conv upsample + 2× unpatchify + temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal) + causal_temporal=True, + ) + + # Back to channels-last: [B,3,T',H',W'] → [B,T',H',W',3] + return result_cf.transpose(0, 2, 3, 4, 1) + def denormalize_latents(z, mean=None, std=None): """Denormalize latents: z = z / (1/std) + mean.""" diff --git a/tests/test_wan_tiling.py b/tests/test_wan_tiling.py new file mode 100644 index 0000000..3353dd4 --- /dev/null +++ b/tests/test_wan_tiling.py @@ -0,0 +1,198 @@ +"""Tests for Wan VAE tiled decoding.""" + +import mlx.core as mx +import numpy as np +import pytest + +from mlx_video.models.ltx.video_vae.tiling import ( + TilingConfig, + decode_with_tiling, + split_in_spatial, + split_in_temporal, +) + + +class TestNonCausalTemporal: + """Tests for the causal_temporal=False path in decode_with_tiling.""" + + def test_split_spatial_for_temporal(self): + """Non-causal temporal should use split_in_spatial (no causal shift).""" + intervals = split_in_spatial(8, 2, 20) + # No causal adjustment: starts should be evenly spaced + assert intervals.starts[0] == 0 + for i in range(1, len(intervals.starts)): + assert intervals.starts[i] == intervals.starts[i - 1] + (8 - 2) + + def test_causal_vs_noncausal_output_size(self): + """Causal temporal gives 1+(T-1)*S frames, non-causal gives T*S.""" + mx.random.seed(42) + b, c, t, h, w = 1, 4, 4, 4, 4 + latents = mx.random.normal((b, c, t, h, w)) + scale = 4 + + # Simple passthrough decoder: just repeat along dimensions + def dummy_decoder_causal(x, **kwargs): + b, c, t, h, w = x.shape + out_t = 1 + (t - 1) * scale + out_h = h * scale + out_w = w * scale + return mx.ones((b, 3, out_t, out_h, out_w)) + + def dummy_decoder_noncausal(x, **kwargs): + b, c, t, h, w = x.shape + out_t = t * scale + out_h = h * scale + out_w = w * scale + return mx.ones((b, 3, out_t, out_h, out_w)) + + config = TilingConfig.spatial_only(tile_size=128, overlap=64) + + # Causal: 1 + (4-1)*4 = 13 + out_causal = decode_with_tiling( + dummy_decoder_causal, latents, config, + spatial_scale=scale, temporal_scale=scale, causal_temporal=True, + ) + mx.eval(out_causal) + assert out_causal.shape[2] == 1 + (t - 1) * scale # 13 + + # Non-causal: 4*4 = 16 + out_noncausal = decode_with_tiling( + dummy_decoder_noncausal, latents, config, + spatial_scale=scale, temporal_scale=scale, causal_temporal=False, + ) + mx.eval(out_noncausal) + assert out_noncausal.shape[2] == t * scale # 16 + + +class TestWan22TiledDecoding: + """Tests for Wan2.2 VAE tiled decoding.""" + + def _make_small_wan22_decoder(self): + """Create a small Wan2.2 decoder for testing.""" + from mlx_video.models.wan.vae22 import Wan22VAEDecoder + + # Use very small dimensions for fast testing + vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16) + mx.eval(vae.parameters()) + return vae + + def test_decode_tiled_output_shape(self): + """Tiled decode should produce same shape as non-tiled.""" + mx.random.seed(42) + vae = self._make_small_wan22_decoder() + + # Small input: [B=1, T=3, H=2, W=2, C=48] + z = mx.random.normal((1, 3, 2, 2, 48)) + mx.eval(z) + + # Non-tiled + out_regular = vae(z) + mx.eval(out_regular) + + # Tiled (force tiling with very small tile sizes) + # Use spatial tile=32px (2 latent @ scale 16) and temporal=8 frames (2 latent @ scale 4) + config = TilingConfig( + spatial_config=None, # Don't tile spatially (input is tiny) + temporal_config=None, # Don't tile temporally (input is tiny) + ) + # With no tiling config, decode_tiled should fall through to regular decode + out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default()) + mx.eval(out_tiled) + + # Both should produce the same shape + assert out_regular.shape == out_tiled.shape, ( + f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" + ) + + def test_decode_tiled_falls_through_when_small(self): + """When input is smaller than tile size, decode_tiled should produce same output as __call__.""" + mx.random.seed(42) + vae = self._make_small_wan22_decoder() + + # Input smaller than any tile size + z = mx.random.normal((1, 2, 2, 2, 48)) + mx.eval(z) + + out_regular = vae(z) + mx.eval(out_regular) + + out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default()) + mx.eval(out_tiled) + + np.testing.assert_allclose( + np.array(out_regular), np.array(out_tiled), + rtol=1e-4, atol=1e-4, + err_msg="Tiled decode should match regular decode for small inputs", + ) + + +class TestWan21TiledDecoding: + """Tests for Wan2.1 VAE tiled decoding.""" + + def _make_small_wan21_vae(self): + """Create a small Wan2.1 VAE for testing.""" + from mlx_video.models.wan.vae import WanVAE + + vae = WanVAE(z_dim=16) + mx.eval(vae.parameters()) + return vae + + def test_decode_tiled_output_shape(self): + """Tiled decode should produce correct output shape.""" + mx.random.seed(42) + vae = self._make_small_wan21_vae() + + # [B=1, C=16, T=3, H=4, W=4] + z = mx.random.normal((1, 16, 3, 4, 4)) + mx.eval(z) + + out_regular = vae.decode(z) + mx.eval(out_regular) + + out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default()) + mx.eval(out_tiled) + + assert out_regular.shape == out_tiled.shape, ( + f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}" + ) + + def test_decode_tiled_falls_through_when_small(self): + """When input is smaller than tile size, decode_tiled should produce same output as decode.""" + mx.random.seed(42) + vae = self._make_small_wan21_vae() + + z = mx.random.normal((1, 16, 2, 4, 4)) + mx.eval(z) + + out_regular = vae.decode(z) + mx.eval(out_regular) + + out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default()) + mx.eval(out_tiled) + + np.testing.assert_allclose( + np.array(out_regular), np.array(out_tiled), + rtol=1e-4, atol=1e-4, + err_msg="Tiled decode should match regular decode for small inputs", + ) + + +class TestWan21TemporalScale: + """Verify Wan2.1 decoder temporal output is T*4 (non-causal).""" + + def test_wan21_decoder_temporal_output(self): + """Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling).""" + from mlx_video.models.wan.vae import Decoder3d + + # Small decoder for fast test + dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1, + temporal_upsample=[True, True, False]) + mx.eval(dec.parameters()) + + x = mx.random.normal((1, 4, 3, 4, 4)) # T=3 + mx.eval(x) + out = dec(x) + mx.eval(out) + + # With two temporal 2× upsamples: T=3 → 6 → 12 + assert out.shape[2] == 3 * 4, f"Expected T=12, got T={out.shape[2]}"