format
This commit is contained in:
@@ -6,19 +6,45 @@ so weights load directly without key sanitization.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization statistics for z_dim=16
|
||||
VAE_MEAN = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
]
|
||||
VAE_STD = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
]
|
||||
|
||||
|
||||
@@ -50,7 +76,9 @@ class CausalConv3d(nn.Module):
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# MLX Conv3d: weight shape [O, D, H, W, I]
|
||||
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
|
||||
self.weight = mx.zeros(
|
||||
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||
)
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||
@@ -67,8 +95,16 @@ class CausalConv3d(nn.Module):
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
|
||||
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w),
|
||||
],
|
||||
)
|
||||
|
||||
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
||||
out = self._conv3d(x)
|
||||
@@ -118,7 +154,11 @@ class RMS_norm(nn.Module):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
norm_dim = 1 if self.channel_first else -1
|
||||
# L2 normalize along channel dim (matches F.normalize)
|
||||
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
|
||||
norm = mx.sqrt(
|
||||
mx.clip(
|
||||
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
|
||||
)
|
||||
)
|
||||
return (x / norm) * self.scale * self.gamma
|
||||
|
||||
|
||||
@@ -133,12 +173,12 @@ class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.residual = [
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
@@ -226,13 +266,16 @@ class Resample(nn.Module):
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
|
||||
)
|
||||
else:
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
||||
if mode == "downsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
@@ -272,8 +315,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
# Subsequent chunks: use cached frame as temporal context
|
||||
cache_x = x[:, :, -1:]
|
||||
x = self.time_conv(
|
||||
x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
@@ -328,8 +370,8 @@ class Decoder3d(nn.Module):
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
||||
]
|
||||
|
||||
@@ -405,8 +447,7 @@ class Encoder3d(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -431,8 +472,7 @@ class Encoder3d(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.head[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -583,7 +623,7 @@ class WanVAE(nn.Module):
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_denorm,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user