feat(wan): Add chunked VAE encoding and TI2V-5B support
This commit is contained in:
@@ -29,13 +29,37 @@ from mlx_video.utils import Colors
|
|||||||
_build_i2v_mask = build_i2v_mask
|
_build_i2v_mask = build_i2v_mask
|
||||||
|
|
||||||
|
|
||||||
|
def _best_output_size(w, h, dw, dh, max_area):
|
||||||
|
"""Compute the best output resolution that fits within max_area while
|
||||||
|
preserving the input aspect ratio and satisfying alignment constraints.
|
||||||
|
Matches the reference implementation's best_output_size().
|
||||||
|
"""
|
||||||
|
ratio = w / h
|
||||||
|
ow = (max_area * ratio) ** 0.5
|
||||||
|
oh = max_area / ow
|
||||||
|
|
||||||
|
# Option 1: process width first
|
||||||
|
ow1 = int(ow // dw * dw)
|
||||||
|
oh1 = int(max_area / ow1 // dh * dh)
|
||||||
|
ratio1 = ow1 / oh1
|
||||||
|
|
||||||
|
# Option 2: process height first
|
||||||
|
oh2 = int(oh // dh * dh)
|
||||||
|
ow2 = int(max_area / oh2 // dw * dw)
|
||||||
|
ratio2 = ow2 / oh2
|
||||||
|
|
||||||
|
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
|
||||||
|
return ow1, oh1
|
||||||
|
return ow2, oh2
|
||||||
|
|
||||||
|
|
||||||
def generate_video(
|
def generate_video(
|
||||||
model_dir: str,
|
model_dir: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str | None = None,
|
negative_prompt: str | None = None,
|
||||||
image: str | None = None,
|
image: str | None = None,
|
||||||
width: int = 1280,
|
width: int = 1280,
|
||||||
height: int = 720,
|
height: int = 704,
|
||||||
num_frames: int = 81,
|
num_frames: int = 81,
|
||||||
steps: int = None,
|
steps: int = None,
|
||||||
guide_scale: str | float | tuple = None,
|
guide_scale: str | float | tuple = None,
|
||||||
@@ -232,6 +256,15 @@ def generate_video(
|
|||||||
width = align_w
|
width = align_w
|
||||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||||
|
|
||||||
|
# Enforce max_area constraint (model-specific resolution limit)
|
||||||
|
if config.max_area > 0 and height * width > config.max_area:
|
||||||
|
old_h, old_w = height, width
|
||||||
|
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
|
||||||
|
print(
|
||||||
|
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||||
|
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
# Compute target latent shape
|
# Compute target latent shape
|
||||||
z_dim = config.vae_z_dim
|
z_dim = config.vae_z_dim
|
||||||
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
||||||
@@ -334,7 +367,7 @@ def generate_video(
|
|||||||
mx.eval(img_tensor)
|
mx.eval(img_tensor)
|
||||||
|
|
||||||
vae_enc = load_vae_encoder(vae_path, config)
|
vae_enc = load_vae_encoder(vae_path, config)
|
||||||
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||||
mx.eval(z_img)
|
mx.eval(z_img)
|
||||||
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
||||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||||
@@ -658,8 +691,8 @@ def main():
|
|||||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||||
help="Disable negative prompt (use empty string instead of config default)")
|
help="Disable negative prompt (use empty string instead of config default)")
|
||||||
parser.add_argument("--width", type=int, default=1280, help="Video width")
|
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
|
||||||
parser.add_argument("--height", type=int, default=720, help="Video height")
|
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
|
||||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||||
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
||||||
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class WanModelConfig(BaseModelConfig):
|
|||||||
"杂乱的背景,三条腿,背景人很多,倒着走"
|
"杂乱的背景,三条腿,背景人很多,倒着走"
|
||||||
)
|
)
|
||||||
|
|
||||||
# T5
|
# Resolution constraints
|
||||||
|
max_area: int = 0 # 0 = no limit; e.g. 704*1280 for TI2V-5B
|
||||||
t5_vocab_size: int = 256384
|
t5_vocab_size: int = 256384
|
||||||
t5_dim: int = 4096
|
t5_dim: int = 4096
|
||||||
t5_dim_attn: int = 4096
|
t5_dim_attn: int = 4096
|
||||||
@@ -102,7 +103,8 @@ class WanModelConfig(BaseModelConfig):
|
|||||||
boundary=0.900,
|
boundary=0.900,
|
||||||
sample_shift=5.0,
|
sample_shift=5.0,
|
||||||
sample_guide_scale=(3.5, 3.5),
|
sample_guide_scale=(3.5, 3.5),
|
||||||
)
|
max_area=704 * 1280,
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
||||||
@@ -120,7 +122,8 @@ class WanModelConfig(BaseModelConfig):
|
|||||||
dual_model=False,
|
dual_model=False,
|
||||||
boundary=0.0,
|
boundary=0.0,
|
||||||
sample_shift=5.0,
|
sample_shift=5.0,
|
||||||
sample_steps=50,
|
sample_steps=40,
|
||||||
sample_guide_scale=5.0,
|
sample_guide_scale=5.0,
|
||||||
sample_fps=24,
|
sample_fps=24,
|
||||||
)
|
max_area=704 * 1280,
|
||||||
|
|
||||||
|
|||||||
@@ -56,9 +56,11 @@ class CausalConv3d(nn.Module):
|
|||||||
|
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
# Causal temporal padding: always kernel_size-1 on the left.
|
# Causal temporal padding: matches the reference CausalConv3d(nn.Conv3d)
|
||||||
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
|
# which converts symmetric padding to causal: 2*padding[0] on the left.
|
||||||
self._causal_pad_t = kernel_size[0] - 1
|
# For most convs (kernel=3, padding=1): 2*1 = 2 (same as kernel-1).
|
||||||
|
# For downsample time_conv (kernel=3, padding=0): 2*0 = 0 (NO padding).
|
||||||
|
self._causal_pad_t = 2 * padding[0]
|
||||||
self._pad_h = padding[1]
|
self._pad_h = padding[1]
|
||||||
self._pad_w = padding[2]
|
self._pad_w = padding[2]
|
||||||
|
|
||||||
@@ -68,7 +70,7 @@ class CausalConv3d(nn.Module):
|
|||||||
))
|
))
|
||||||
self.bias = mx.zeros((out_channels,))
|
self.bias = mx.zeros((out_channels,))
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x, cache_x=None):
|
||||||
# x: [B, T, H, W, C]
|
# x: [B, T, H, W, C]
|
||||||
B, T, H, W, C = x.shape
|
B, T, H, W, C = x.shape
|
||||||
kd, kh, kw = self.kernel_size
|
kd, kh, kw = self.kernel_size
|
||||||
@@ -81,10 +83,15 @@ class CausalConv3d(nn.Module):
|
|||||||
y = mx.conv_general(x_flat, w2d) + self.bias
|
y = mx.conv_general(x_flat, w2d) + self.bias
|
||||||
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
||||||
|
|
||||||
# Causal temporal padding (left only) — zeros match the reference
|
# Causal temporal padding: prepend cached frames if available,
|
||||||
# implementation and what the model was trained with.
|
# then zero-pad any remaining positions.
|
||||||
if self._causal_pad_t > 0:
|
pad_needed = self._causal_pad_t
|
||||||
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
if cache_x is not None and pad_needed > 0:
|
||||||
|
x = mx.concatenate([cache_x, x], axis=1)
|
||||||
|
pad_needed -= cache_x.shape[1]
|
||||||
|
|
||||||
|
if pad_needed > 0:
|
||||||
|
pad_t = mx.zeros((B, pad_needed, H, W, C), dtype=x.dtype)
|
||||||
x = mx.concatenate([pad_t, x], axis=1)
|
x = mx.concatenate([pad_t, x], axis=1)
|
||||||
|
|
||||||
# Spatial padding
|
# Spatial padding
|
||||||
@@ -144,9 +151,9 @@ class ResidualBlock(nn.Module):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||||
h = self.shortcut(x) if self.shortcut is not None else x
|
h = self.shortcut(x) if self.shortcut is not None else x
|
||||||
return self.residual(x) + h
|
return self.residual(x, feat_cache, feat_idx) + h
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlockLayers(nn.Module):
|
class ResidualBlockLayers(nn.Module):
|
||||||
@@ -169,13 +176,33 @@ class ResidualBlockLayers(nn.Module):
|
|||||||
# Index 6: CausalConv3d
|
# Index 6: CausalConv3d
|
||||||
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
|
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||||
|
|
||||||
def __call__(self, x):
|
def _conv_with_cache(self, conv, x, feat_cache, feat_idx):
|
||||||
|
"""Apply CausalConv3d with temporal caching for chunked encoding."""
|
||||||
|
idx = feat_idx[0]
|
||||||
|
# Save last CACHE_T frames before conv (for next chunk's context)
|
||||||
|
cache_x = x[:, -CACHE_T:]
|
||||||
|
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = mx.concatenate(
|
||||||
|
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||||
|
)
|
||||||
|
out = conv(x, cache_x=feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||||
x = self.layer_0(x)
|
x = self.layer_0(x)
|
||||||
x = nn.silu(x)
|
x = nn.silu(x)
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = self._conv_with_cache(self.layer_2, x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
x = self.layer_2(x)
|
x = self.layer_2(x)
|
||||||
mx.eval(x) # Eval between convolutions to limit graph size
|
mx.eval(x) # Eval between convolutions to limit graph size
|
||||||
x = self.layer_3(x)
|
x = self.layer_3(x)
|
||||||
x = nn.silu(x)
|
x = nn.silu(x)
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = self._conv_with_cache(self.layer_6, x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
x = self.layer_6(x)
|
x = self.layer_6(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -344,53 +371,34 @@ class Resample(nn.Module):
|
|||||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||||
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||||
|
|
||||||
def __call__(self, x, first_chunk=False):
|
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
|
||||||
# x: [B, T, H, W, C]
|
# x: [B, T, H, W, C]
|
||||||
B, T, H, W, C = x.shape
|
B, T, H, W, C = x.shape
|
||||||
|
|
||||||
|
# --- Temporal upsample (before spatial, matching reference) ---
|
||||||
if self.mode == "upsample3d":
|
if self.mode == "upsample3d":
|
||||||
if first_chunk and T > 1:
|
if first_chunk and T > 1:
|
||||||
# Match official chunked behavior: the first frame bypasses
|
first_frame = x[:, 0:1]
|
||||||
# time_conv entirely (only spatial upsample). Remaining frames
|
rest = x[:, 1:]
|
||||||
# go through time_conv with causal zero-padding, which
|
tc_out = self.time_conv(rest)
|
||||||
# naturally gives each frame the same limited temporal context
|
|
||||||
# as the official frame-by-frame decode with caching.
|
|
||||||
first_frame = x[:, 0:1] # [B, 1, H, W, C]
|
|
||||||
rest = x[:, 1:] # [B, T-1, H, W, C]
|
|
||||||
|
|
||||||
# time_conv on remaining frames (causal pad gives zero context
|
|
||||||
# before rest[0], matching the official "Rep" cache path)
|
|
||||||
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
|
|
||||||
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
|
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
|
||||||
stream0 = tc_out[:, :, :, :, 0, :]
|
stream0 = tc_out[:, :, :, :, 0, :]
|
||||||
stream1 = tc_out[:, :, :, :, 1, :]
|
stream1 = tc_out[:, :, :, :, 1, :]
|
||||||
interleaved = mx.stack([stream0, stream1], axis=2)
|
interleaved = mx.stack([stream0, stream1], axis=2)
|
||||||
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
|
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
|
||||||
|
|
||||||
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
|
|
||||||
x = mx.concatenate([first_frame, interleaved], axis=1)
|
x = mx.concatenate([first_frame, interleaved], axis=1)
|
||||||
elif self.mode == "upsample3d":
|
else:
|
||||||
# Non-first-chunk or single frame: time_conv all frames
|
tc_out = self.time_conv(x)
|
||||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
|
||||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||||
stream0 = tc_out[:, :, :, :, 0, :]
|
stream0 = tc_out[:, :, :, :, 0, :]
|
||||||
stream1 = tc_out[:, :, :, :, 1, :]
|
stream1 = tc_out[:, :, :, :, 1, :]
|
||||||
x = mx.stack([stream0, stream1], axis=2)
|
x = mx.stack([stream0, stream1], axis=2)
|
||||||
x = x.reshape(B, T * 2, H, W, C)
|
x = x.reshape(B, T * 2, H, W, C)
|
||||||
|
|
||||||
mx.eval(x)
|
|
||||||
T = x.shape[1]
|
|
||||||
|
|
||||||
if self.mode == "downsample3d" and T > 1:
|
|
||||||
# Temporal downsample via strided CausalConv3d
|
|
||||||
# Skip for T=1 (single frame) — matches official chunked encoding
|
|
||||||
# where first chunk stores cache but doesn't apply time_conv
|
|
||||||
x = self.time_conv(x)
|
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
T = x.shape[1]
|
T = x.shape[1]
|
||||||
|
|
||||||
|
# --- Spatial operation (all modes, matching reference line 152-155) ---
|
||||||
if self.mode in ("upsample2d", "upsample3d"):
|
if self.mode in ("upsample2d", "upsample3d"):
|
||||||
# Spatial upsample in temporal chunks to limit peak memory
|
|
||||||
chunk_size = 8
|
chunk_size = 8
|
||||||
chunks = []
|
chunks = []
|
||||||
for t_start in range(0, T, chunk_size):
|
for t_start in range(0, T, chunk_size):
|
||||||
@@ -400,18 +408,36 @@ class Resample(nn.Module):
|
|||||||
x_chunk = self._conv2d(x_chunk)
|
x_chunk = self._conv2d(x_chunk)
|
||||||
mx.eval(x_chunk)
|
mx.eval(x_chunk)
|
||||||
chunks.append(x_chunk)
|
chunks.append(x_chunk)
|
||||||
|
|
||||||
x = mx.concatenate(chunks, axis=0)
|
x = mx.concatenate(chunks, axis=0)
|
||||||
H2, W2 = x.shape[1], x.shape[2]
|
H2, W2 = x.shape[1], x.shape[2]
|
||||||
x = x.reshape(B, T, H2, W2, C)
|
x = x.reshape(B, T, H2, W2, C)
|
||||||
elif self.mode in ("downsample2d", "downsample3d"):
|
elif self.mode in ("downsample2d", "downsample3d"):
|
||||||
# Spatial downsample: per-frame strided Conv2d
|
|
||||||
x_flat = x.reshape(B * T, H, W, C)
|
x_flat = x.reshape(B * T, H, W, C)
|
||||||
x_flat = self._downsample_conv2d(x_flat)
|
x_flat = self._downsample_conv2d(x_flat)
|
||||||
mx.eval(x_flat)
|
mx.eval(x_flat)
|
||||||
H2, W2 = x_flat.shape[1], x_flat.shape[2]
|
H2, W2 = x_flat.shape[1], x_flat.shape[2]
|
||||||
x = x_flat.reshape(B, T, H2, W2, C)
|
x = x_flat.reshape(B, T, H2, W2, C)
|
||||||
|
|
||||||
|
# --- Temporal downsample (after spatial, matching reference line 157-168) ---
|
||||||
|
if self.mode == "downsample3d":
|
||||||
|
if feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
if feat_cache[idx] is None:
|
||||||
|
# First chunk: store spatially-downsampled result, skip time_conv
|
||||||
|
feat_cache[idx] = x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
# Subsequent chunks: prepend cached last frame, apply time_conv
|
||||||
|
save_x = x[:, -1:]
|
||||||
|
x = self.time_conv(
|
||||||
|
mx.concatenate([feat_cache[idx][:, -1:], x], axis=1)
|
||||||
|
)
|
||||||
|
feat_cache[idx] = save_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
elif T > 1:
|
||||||
|
x = self.time_conv(x)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -488,11 +514,19 @@ class Down_ResidualBlock(nn.Module):
|
|||||||
|
|
||||||
self.downsamples = blocks
|
self.downsamples = blocks
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||||
x_shortcut = self.avg_shortcut(x)
|
x_shortcut = self.avg_shortcut(x)
|
||||||
mx.eval(x_shortcut)
|
mx.eval(x_shortcut)
|
||||||
|
|
||||||
for module in self.downsamples:
|
for module in self.downsamples:
|
||||||
|
if feat_cache is not None:
|
||||||
|
if isinstance(module, ResidualBlock):
|
||||||
|
x = module(x, feat_cache, feat_idx)
|
||||||
|
elif isinstance(module, Resample):
|
||||||
|
x = module(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||||
|
else:
|
||||||
|
x = module(x)
|
||||||
|
else:
|
||||||
x = module(x)
|
x = module(x)
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
|
|
||||||
@@ -597,18 +631,51 @@ class Encoder3d(nn.Module):
|
|||||||
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
|
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
|
||||||
self.head = Head22(out_dim, out_channels=z_dim)
|
self.head = Head22(out_dim, out_channels=z_dim)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||||
# x: [B, T, H, W, 12] (patchified)
|
# x: [B, T, H, W, 12] (patchified)
|
||||||
x = self.conv1(x)
|
if feat_cache is not None:
|
||||||
|
return self._forward_cached(x, feat_cache, feat_idx)
|
||||||
|
|
||||||
|
# No cache: internally chunk as 1+4+4+... (matches reference behavior)
|
||||||
|
num_convs = _count_conv3d(self)
|
||||||
|
internal_cache = [None] * num_convs
|
||||||
|
T = x.shape[1]
|
||||||
|
starts = [0] + list(range(1, T, 4))
|
||||||
|
ends = starts[1:] + [T]
|
||||||
|
outputs = []
|
||||||
|
for s, e in zip(starts, ends):
|
||||||
|
if s >= e:
|
||||||
|
continue
|
||||||
|
feat_idx_local = [0]
|
||||||
|
out = self._forward_cached(x[:, s:e], internal_cache, feat_idx_local)
|
||||||
|
outputs.append(out)
|
||||||
|
mx.eval(internal_cache)
|
||||||
|
if len(outputs) == 1:
|
||||||
|
return outputs[0]
|
||||||
|
return mx.concatenate(outputs, axis=1)
|
||||||
|
|
||||||
|
def _forward_cached(self, x, feat_cache, feat_idx):
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, -CACHE_T:]
|
||||||
|
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = mx.concatenate(
|
||||||
|
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||||
|
)
|
||||||
|
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
|
||||||
for layer in self.downsamples:
|
for layer in self.downsamples:
|
||||||
x = layer(x)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
|
if isinstance(layer, ResidualBlock):
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
|
|
||||||
x = self.head(x)
|
x = self.head(x, feat_cache, feat_idx)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -626,13 +693,38 @@ class Head22(nn.Module):
|
|||||||
# Index 2: CausalConv3d
|
# Index 2: CausalConv3d
|
||||||
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
|
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||||
x = self.layer_0(x)
|
x = self.layer_0(x)
|
||||||
x = nn.silu(x)
|
x = nn.silu(x)
|
||||||
|
if feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, -CACHE_T:]
|
||||||
|
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = mx.concatenate(
|
||||||
|
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||||
|
)
|
||||||
|
x = self.layer_2(x, cache_x=feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
x = self.layer_2(x)
|
x = self.layer_2(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _count_conv3d(module):
|
||||||
|
"""Count all CausalConv3d instances in a module tree (for cache sizing)."""
|
||||||
|
count = 0
|
||||||
|
if isinstance(module, CausalConv3d):
|
||||||
|
count += 1
|
||||||
|
for child in module.children().values():
|
||||||
|
if isinstance(child, list):
|
||||||
|
for item in child:
|
||||||
|
count += _count_conv3d(item)
|
||||||
|
elif isinstance(child, nn.Module):
|
||||||
|
count += _count_conv3d(child)
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
class Wan22VAEEncoder(nn.Module):
|
class Wan22VAEEncoder(nn.Module):
|
||||||
"""Full Wan2.2 VAE encoder with patchify and normalization."""
|
"""Full Wan2.2 VAE encoder with patchify and normalization."""
|
||||||
|
|
||||||
@@ -649,8 +741,11 @@ class Wan22VAEEncoder(nn.Module):
|
|||||||
temperal_downsample=(False, True, True),
|
temperal_downsample=(False, True, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, img):
|
def encode(self, img):
|
||||||
"""Encode image/video to latent space.
|
"""Encode image/video using chunked encoding (1+4+4+... pattern).
|
||||||
|
|
||||||
|
This matches the reference implementation's chunked encoding with
|
||||||
|
persistent temporal cache, which is critical for correct I2V latents.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img: [B, T, H, W, 3] image/video in [-1, 1]
|
img: [B, T, H, W, 3] image/video in [-1, 1]
|
||||||
@@ -658,11 +753,28 @@ class Wan22VAEEncoder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
||||||
"""
|
"""
|
||||||
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
|
|
||||||
x = _patchify(img, patch_size=2)
|
x = _patchify(img, patch_size=2)
|
||||||
|
T = x.shape[1]
|
||||||
|
|
||||||
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
|
# Initialize temporal cache (one slot per CausalConv3d in encoder)
|
||||||
out = self.encoder(x)
|
num_convs = _count_conv3d(self.encoder)
|
||||||
|
feat_cache = [None] * num_convs
|
||||||
|
|
||||||
|
# Chunked encoding: first chunk = 1 frame, rest = 4 frames each
|
||||||
|
num_chunks = 1 + (T - 1) // 4
|
||||||
|
out = None
|
||||||
|
for i in range(num_chunks):
|
||||||
|
feat_idx = [0] # Reset layer index each chunk (but keep cache)
|
||||||
|
if i == 0:
|
||||||
|
chunk = x[:, :1]
|
||||||
|
else:
|
||||||
|
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
|
||||||
|
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||||
|
if out is None:
|
||||||
|
out = chunk_out
|
||||||
|
else:
|
||||||
|
out = mx.concatenate([out, chunk_out], axis=1)
|
||||||
|
mx.eval(out)
|
||||||
|
|
||||||
# conv1 (pointwise) + split into mu, log_var
|
# conv1 (pointwise) + split into mu, log_var
|
||||||
out = self.conv1(out)
|
out = self.conv1(out)
|
||||||
@@ -672,6 +784,17 @@ class Wan22VAEEncoder(nn.Module):
|
|||||||
mu = normalize_latents(mu)
|
mu = normalize_latents(mu)
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
"""Encode image/video to latent space (delegates to chunked encode).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: [B, T, H, W, 3] image/video in [-1, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
||||||
|
"""
|
||||||
|
return self.encode(img)
|
||||||
|
|
||||||
|
|
||||||
class Wan22VAEDecoder(nn.Module):
|
class Wan22VAEDecoder(nn.Module):
|
||||||
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
||||||
|
|||||||
Reference in New Issue
Block a user