127 lines
4.1 KiB
Python
127 lines
4.1 KiB
Python
"""Operations for Video VAE."""
|
|
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
|
|
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
|
"""Convert video to patches.
|
|
|
|
Moves spatial pixels from H, W dimensions to channel dimension.
|
|
|
|
Args:
|
|
x: Input tensor of shape (B, C, F, H, W)
|
|
patch_size_hw: Spatial patch size
|
|
patch_size_t: Temporal patch size
|
|
|
|
Returns:
|
|
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
|
|
"""
|
|
b, c, f, h, w = x.shape
|
|
|
|
# Check dimensions are divisible
|
|
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
|
|
assert f % patch_size_t == 0
|
|
|
|
# New dimensions
|
|
new_h = h // patch_size_hw
|
|
new_w = w // patch_size_hw
|
|
new_f = f // patch_size_t
|
|
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
|
|
|
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
|
x = mx.reshape(
|
|
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
|
|
)
|
|
|
|
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
|
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
|
x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))
|
|
|
|
# Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
|
|
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
|
|
|
return x
|
|
|
|
|
|
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
|
"""Convert patches back to video.
|
|
|
|
Inverse of patchify - moves pixels from channel dimension back to spatial.
|
|
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
|
|
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
|
|
|
|
Args:
|
|
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
|
|
patch_size_hw: Spatial patch size
|
|
patch_size_t: Temporal patch size
|
|
|
|
Returns:
|
|
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
|
|
"""
|
|
b, c_packed, f, h, w = x.shape
|
|
|
|
# Calculate original channel count
|
|
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
|
|
|
|
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
|
|
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
|
|
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
|
|
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
|
|
|
|
# Permute to interleave patches with spatial dims:
|
|
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
|
|
|
|
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
|
|
|
|
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
|
|
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
|
|
|
|
return x
|
|
|
|
|
|
class PerChannelStatistics(nn.Module):
|
|
|
|
def __init__(self, latent_channels: int = 128):
|
|
|
|
super().__init__()
|
|
self.latent_channels = latent_channels
|
|
|
|
# Learnable per-channel mean and std
|
|
self.mean = mx.zeros((latent_channels,))
|
|
self.std = mx.ones((latent_channels,))
|
|
|
|
def normalize(self, x: mx.array) -> mx.array:
|
|
"""Normalize latents using per-channel statistics.
|
|
|
|
Args:
|
|
x: Input tensor of shape (B, C, ...)
|
|
|
|
Returns:
|
|
Normalized tensor
|
|
"""
|
|
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
|
dtype = x.dtype
|
|
# Cast to float32 for precision
|
|
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
|
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
|
|
|
return ((x - mean) / std).astype(dtype)
|
|
|
|
def un_normalize(self, x: mx.array) -> mx.array:
|
|
"""Denormalize latents using per-channel statistics.
|
|
|
|
Args:
|
|
x: Normalized tensor of shape (B, C, ...)
|
|
|
|
Returns:
|
|
Denormalized tensor
|
|
"""
|
|
dtype = x.dtype
|
|
# Cast to float32 for precision
|
|
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
|
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
|
|
|
return (x * std + mean).astype(dtype)
|