initial commit (LTX-2)

This commit is contained in:
Prince Canuma
2026-01-11 23:48:33 +01:00
parent 9f01d22750
commit d1ca36a315
29 changed files with 7124 additions and 0 deletions

View File

@@ -0,0 +1,120 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert video to patches.
Moves spatial pixels from H, W dimensions to channel dimension.
Args:
x: Input tensor of shape (B, C, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
"""
b, c, f, h, w = x.shape
# Check dimensions are divisible
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
assert f % patch_size_t == 0
# New dimensions
new_h = h // patch_size_hw
new_w = w // patch_size_hw
new_f = f // patch_size_t
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
return x
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert patches back to video.
Inverse of patchify - moves pixels from channel dimension back to spatial.
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
Args:
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
"""
b, c_packed, f, h, w = x.shape
# Calculate original channel count
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
# Permute to interleave patches with spatial dims:
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
return x
class PerChannelStatistics(nn.Module):
def __init__(self, latent_channels: int = 128):
super().__init__()
self.latent_channels = latent_channels
# Learnable per-channel mean and std
self.mean = mx.zeros((latent_channels,))
self.std = mx.ones((latent_channels,))
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latents using per-channel statistics.
Args:
x: Input tensor of shape (B, C, ...)
Returns:
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
return (x - mean) / std
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics.
Args:
x: Normalized tensor of shape (B, C, ...)
Returns:
Denormalized tensor
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
return x * std + mean