format
This commit is contained in:
@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization for z_dim=48 latent space
|
||||
VAE22_MEAN = mx.array([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||
])
|
||||
VAE22_MEAN = mx.array(
|
||||
[
|
||||
-0.2289,
|
||||
-0.0052,
|
||||
-0.1323,
|
||||
-0.2339,
|
||||
-0.2799,
|
||||
0.0174,
|
||||
0.1838,
|
||||
0.1557,
|
||||
-0.1382,
|
||||
0.0542,
|
||||
0.2813,
|
||||
0.0891,
|
||||
0.1570,
|
||||
-0.0098,
|
||||
0.0375,
|
||||
-0.1825,
|
||||
-0.2246,
|
||||
-0.1207,
|
||||
-0.0698,
|
||||
0.5109,
|
||||
0.2665,
|
||||
-0.2108,
|
||||
-0.2158,
|
||||
0.2502,
|
||||
-0.2055,
|
||||
-0.0322,
|
||||
0.1109,
|
||||
0.1567,
|
||||
-0.0729,
|
||||
0.0899,
|
||||
-0.2799,
|
||||
-0.1230,
|
||||
-0.0313,
|
||||
-0.1649,
|
||||
0.0117,
|
||||
0.0723,
|
||||
-0.2839,
|
||||
-0.2083,
|
||||
-0.0520,
|
||||
0.3748,
|
||||
0.0152,
|
||||
0.1957,
|
||||
0.1433,
|
||||
-0.2944,
|
||||
0.3573,
|
||||
-0.0548,
|
||||
-0.1681,
|
||||
-0.0667,
|
||||
]
|
||||
)
|
||||
|
||||
VAE22_STD = mx.array([
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
|
||||
])
|
||||
VAE22_STD = mx.array(
|
||||
[
|
||||
0.4765,
|
||||
1.0364,
|
||||
0.4514,
|
||||
1.1677,
|
||||
0.5313,
|
||||
0.4990,
|
||||
0.4818,
|
||||
0.5013,
|
||||
0.8158,
|
||||
1.0344,
|
||||
0.5894,
|
||||
1.0901,
|
||||
0.6885,
|
||||
0.6165,
|
||||
0.8454,
|
||||
0.4978,
|
||||
0.5759,
|
||||
0.3523,
|
||||
0.7135,
|
||||
0.6804,
|
||||
0.5833,
|
||||
1.4146,
|
||||
0.8986,
|
||||
0.5659,
|
||||
0.7069,
|
||||
0.5338,
|
||||
0.4889,
|
||||
0.4917,
|
||||
0.4069,
|
||||
0.4999,
|
||||
0.6866,
|
||||
0.4093,
|
||||
0.5709,
|
||||
0.6065,
|
||||
0.6415,
|
||||
0.4944,
|
||||
0.5726,
|
||||
1.2042,
|
||||
0.5458,
|
||||
1.6887,
|
||||
0.3971,
|
||||
1.0600,
|
||||
0.3943,
|
||||
0.5537,
|
||||
0.5444,
|
||||
0.4089,
|
||||
0.7468,
|
||||
0.7744,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# Weight: [O, D, H, W, I] for MLX
|
||||
self.weight = mx.zeros((
|
||||
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
|
||||
))
|
||||
self.weight = mx.zeros(
|
||||
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||
)
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x, cache_x=None):
|
||||
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
# Spatial padding
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w), (0, 0)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w),
|
||||
(0, 0),
|
||||
],
|
||||
)
|
||||
|
||||
T_padded = x.shape[1]
|
||||
H_padded, W_padded = x.shape[2], x.shape[3]
|
||||
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
|
||||
for d in range(kd):
|
||||
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
|
||||
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
|
||||
conv_out = mx.conv_general(frame, w2d,
|
||||
stride=(self.stride[1], self.stride[2]))
|
||||
conv_out = mx.conv_general(
|
||||
frame, w2d, stride=(self.stride[1], self.stride[2])
|
||||
)
|
||||
accum = conv_out if accum is None else accum + conv_out
|
||||
outputs.append(accum + self.bias)
|
||||
|
||||
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.scale = dim**0.5
|
||||
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
|
||||
self.gamma = mx.ones((dim,))
|
||||
|
||||
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
|
||||
# x: [..., C] (channels-last)
|
||||
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
|
||||
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
|
||||
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||
return (
|
||||
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||
)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
|
||||
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
|
||||
# We store as named layers matching PyTorch's indices
|
||||
self.residual = ResidualBlockLayers(in_dim, out_dim)
|
||||
self.shortcut = (
|
||||
CausalConv3d(in_dim, out_dim, 1)
|
||||
if in_dim != out_dim
|
||||
else None
|
||||
)
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
h = self.shortcut(x) if self.shortcut is not None else x
|
||||
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
|
||||
# Save last CACHE_T frames before conv (for next chunk's context)
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||
out = conv(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
|
||||
x = self.norm(x)
|
||||
|
||||
# QKV via 1x1 conv2d (equivalent to linear on last dim)
|
||||
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
|
||||
qkv = (
|
||||
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
|
||||
) # [BT, H, W, 3C]
|
||||
qkv = qkv.reshape(B * T, H * W, 3 * C)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
|
||||
|
||||
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
|
||||
k = k[:, None, :, :]
|
||||
v = v[:, None, :, :]
|
||||
|
||||
scale = C ** -0.5
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
|
||||
scale = C**-0.5
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=scale
|
||||
) # [BT, 1, HW, C]
|
||||
out = out.squeeze(1).reshape(B * T, H, W, C)
|
||||
|
||||
# Project output
|
||||
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
|
||||
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
|
||||
|
||||
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
|
||||
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
|
||||
x = x.reshape(
|
||||
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
|
||||
)
|
||||
|
||||
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
|
||||
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
|
||||
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
|
||||
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
|
||||
x = x.reshape(
|
||||
B,
|
||||
T * self.factor_t,
|
||||
H * self.factor_s,
|
||||
W * self.factor_s,
|
||||
self.out_channels,
|
||||
)
|
||||
|
||||
if first_chunk:
|
||||
x = x[:, self.factor_t - 1:, :, :, :]
|
||||
x = x[:, self.factor_t - 1 :, :, :, :]
|
||||
return x
|
||||
|
||||
|
||||
@@ -348,7 +452,9 @@ class Resample(nn.Module):
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
||||
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported mode: {mode}")
|
||||
|
||||
@@ -369,7 +475,9 @@ class Resample(nn.Module):
|
||||
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
|
||||
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||
return (
|
||||
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||
)
|
||||
|
||||
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
|
||||
# x: [B, T, H, W, C]
|
||||
@@ -444,14 +552,17 @@ class Resample(nn.Module):
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
"""Upsampling residual block with optional DupUp3D shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
|
||||
):
|
||||
super().__init__()
|
||||
self.up_flag = up_flag
|
||||
|
||||
# DupUp3D shortcut (no learnable params)
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim, out_dim,
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2 if up_flag else 1,
|
||||
)
|
||||
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
|
||||
class Down_ResidualBlock(nn.Module):
|
||||
"""Downsampling residual block with AvgDown3D shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
num_res_blocks,
|
||||
temperal_downsample=False,
|
||||
down_flag=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.down_flag = down_flag
|
||||
|
||||
# AvgDown3D shortcut (no learnable params, always present)
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim, out_dim,
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
|
||||
self.upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
|
||||
self.upsamples.append(Up_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks + 1,
|
||||
temperal_upsample=t_up,
|
||||
up_flag=(i != len(dim_mult) - 1),
|
||||
))
|
||||
self.upsamples.append(
|
||||
Up_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks + 1,
|
||||
temperal_upsample=t_up,
|
||||
up_flag=(i != len(dim_mult) - 1),
|
||||
)
|
||||
)
|
||||
|
||||
# Output head: [RMS_norm, SiLU, CausalConv3d]
|
||||
self.head = Head22(dims[-1])
|
||||
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
|
||||
for i in range(len(dim_mult)):
|
||||
in_d, out_d = dims[i], dims[i + 1]
|
||||
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||
self.downsamples.append(Down_ResidualBlock(
|
||||
in_dim=in_d,
|
||||
out_dim=out_d,
|
||||
num_res_blocks=num_res_blocks,
|
||||
temperal_downsample=t_down,
|
||||
down_flag=(i < len(dim_mult) - 1),
|
||||
))
|
||||
self.downsamples.append(
|
||||
Down_ResidualBlock(
|
||||
in_dim=in_d,
|
||||
out_dim=out_d,
|
||||
num_res_blocks=num_res_blocks,
|
||||
temperal_downsample=t_down,
|
||||
down_flag=(i < len(dim_mult) - 1),
|
||||
)
|
||||
)
|
||||
|
||||
# Middle blocks (same as decoder)
|
||||
out_dim = dims[-1]
|
||||
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -700,9 +821,7 @@ class Head22(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||
x = self.layer_2(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
|
||||
if i == 0:
|
||||
chunk = x[:, :1]
|
||||
else:
|
||||
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
|
||||
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
|
||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
if out is None:
|
||||
out = chunk_out
|
||||
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
|
||||
|
||||
# conv1 (pointwise) + split into mu, log_var
|
||||
out = self.conv1(out)
|
||||
mu = out[:, :, :, :, :self.z_dim]
|
||||
mu = out[:, :, :, :, : self.z_dim]
|
||||
|
||||
# Normalize
|
||||
mu = normalize_latents(mu)
|
||||
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_cf,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||
causal_temporal=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user