This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional, Dict
from pathlib import Path
from typing import Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim
in_channels=256, time_embed_dim=embedding_dim
)
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
def __call__(
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
) -> mx.array:
timesteps_proj = get_timestep_embedding(
timestep,
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
)
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
ada_values = self.scale_shift_table[
None, :, :, None, None, None
] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
# Use dict with int keys for MLX to track parameters properly
self.res_blocks = {
i: ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
timestep_conditioning=timestep_conditioning,
)
for i in range(num_layers)
}
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
timestep.flatten(), hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Build up blocks from config
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
block_type = block_def[0]
ch = block_def[1]
if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
num_layers = (
block_def[2] if len(block_def) > 2 else num_layers_per_block
)
self.up_blocks[idx] = ResBlockGroup(
ch, num_layers, spatial_padding_mode, timestep_conditioning
)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
)
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
@@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module):
return weights
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
if (
".conv.conv.weight" not in new_key
and ".conv.conv.bias" not in new_key
):
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
def from_pretrained(
cls, model_path: Path, strict: bool = True
) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights.
Args:
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
return final_blocks
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module):
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
x = self.pixel_norm(x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
scaled_timestep.flatten(), hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ada_values = self.last_scale_shift_table[
None, :, :, None, None, None
] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped
@@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
x = self.act(x)
x = self.conv_out(x, causal=causal)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
return x
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
use_chunked_conv = tiling_mode in (
"conservative",
"none",
"auto",
"default",
"spatial",
)
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
return self(
sample,
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
)
return decode_with_tiling(
decoder_fn=self,