format
This commit is contained in:
@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
in_channels=256, time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
def __call__(
|
||||
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
ada_values = self.scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
|
||||
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Build up blocks from config
|
||||
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
|
||||
block_type = block_def[0]
|
||||
ch = block_def[1]
|
||||
if block_type == "res":
|
||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
||||
num_layers = (
|
||||
block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
)
|
||||
self.up_blocks[idx] = ResBlockGroup(
|
||||
ch, num_layers, spatial_padding_mode, timestep_conditioning
|
||||
)
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
@@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return weights
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
if (
|
||||
".conv.conv.weight" not in new_key
|
||||
and ".conv.conv.bias" not in new_key
|
||||
):
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
def from_pretrained(
|
||||
cls, model_path: Path, strict: bool = True
|
||||
) -> "LTX2VideoDecoder":
|
||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||
|
||||
Args:
|
||||
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return final_blocks
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module):
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
@@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
@@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
scaled_timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
@@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
use_chunked_conv = tiling_mode in (
|
||||
"conservative",
|
||||
"none",
|
||||
"auto",
|
||||
"default",
|
||||
"spatial",
|
||||
)
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
return self(
|
||||
sample,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
|
||||
Reference in New Issue
Block a user