format
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
|
||||
:, :, ::-1, :, :
|
||||
] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
|
||||
:, :, :, ::-1, :
|
||||
] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
@@ -50,7 +54,7 @@ def make_conv_nd(
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
) -> nn.Module:
|
||||
|
||||
|
||||
if dims == 2:
|
||||
return CausalConv2d(
|
||||
in_channels=in_channels,
|
||||
@@ -118,15 +122,17 @@ class CausalConv3d(nn.Module):
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
|
||||
|
||||
use_causal = causal if causal is not None else self.causal
|
||||
|
||||
# Apply temporal padding via frame replication
|
||||
# Apply temporal padding via frame replication
|
||||
# Only apply if kernel_size > 1
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
first_frame_pad = mx.repeat(
|
||||
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
|
||||
)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
@@ -191,11 +196,10 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
out_idx = 0
|
||||
out_idx = 0
|
||||
|
||||
# Process chunks
|
||||
in_start = 0
|
||||
|
||||
@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
in_channels=256, time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
def __call__(
|
||||
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
ada_values = self.scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
|
||||
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Build up blocks from config
|
||||
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
|
||||
block_type = block_def[0]
|
||||
ch = block_def[1]
|
||||
if block_type == "res":
|
||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
||||
num_layers = (
|
||||
block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
)
|
||||
self.up_blocks[idx] = ResBlockGroup(
|
||||
ch, num_layers, spatial_padding_mode, timestep_conditioning
|
||||
)
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
@@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return weights
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
if (
|
||||
".conv.conv.weight" not in new_key
|
||||
and ".conv.conv.bias" not in new_key
|
||||
):
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
def from_pretrained(
|
||||
cls, model_path: Path, strict: bool = True
|
||||
) -> "LTX2VideoDecoder":
|
||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||
|
||||
Args:
|
||||
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return final_blocks
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module):
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
@@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
@@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
scaled_timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
@@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
use_chunked_conv = tiling_mode in (
|
||||
"conservative",
|
||||
"none",
|
||||
"auto",
|
||||
"default",
|
||||
"spatial",
|
||||
)
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
return self(
|
||||
sample,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
|
||||
@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
|
||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
x = mx.reshape(
|
||||
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
|
||||
)
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||
@@ -101,7 +102,7 @@ class PerChannelStatistics(nn.Module):
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
@@ -117,7 +118,7 @@ class PerChannelStatistics(nn.Module):
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
dtype = x.dtype
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
@@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module):
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
@@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module):
|
||||
causal: bool = True,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
residual = x
|
||||
|
||||
# First block
|
||||
@@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module):
|
||||
attention_head_dim: Optional[int] = None,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
out_channels_reduction_factor: int = 1,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||
def __call__(
|
||||
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
|
||||
) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
def _chunked_conv_depth_to_space(
|
||||
self, x: mx.array, causal: bool = True
|
||||
) -> mx.array:
|
||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||
|
||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||
|
||||
@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
|
||||
# Apply right ramp (fade out)
|
||||
if ramp_right > 0:
|
||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
||||
fade_out = [
|
||||
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
|
||||
]
|
||||
for i in range(ramp_right):
|
||||
mask[length - ramp_right + i] *= fade_out[i]
|
||||
|
||||
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_pixels < 64:
|
||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_size_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_frames < 16:
|
||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_size_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||
@@ -113,15 +127,21 @@ class TilingConfig:
|
||||
def default(cls) -> "TilingConfig":
|
||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||
"""Spatial tiling only (for short videos with large resolution)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
|
||||
),
|
||||
temporal_config=None,
|
||||
)
|
||||
|
||||
@@ -130,23 +150,33 @@ class TilingConfig:
|
||||
"""Temporal tiling only (for long videos with small resolution)."""
|
||||
return cls(
|
||||
spatial_config=None,
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def aggressive(cls) -> "TilingConfig":
|
||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=256, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=32, tile_overlap_in_frames=8
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def conservative(cls) -> "TilingConfig":
|
||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=768, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=96, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -186,10 +216,14 @@ class TilingConfig:
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
spatial_config = SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
)
|
||||
|
||||
if needs_temporal:
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
temporal_config = TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
@@ -197,16 +231,21 @@ class TilingConfig:
|
||||
@dataclass
|
||||
class DimensionIntervals:
|
||||
"""Intervals for splitting a single dimension."""
|
||||
|
||||
starts: List[int]
|
||||
ends: List[int]
|
||||
left_ramps: List[int]
|
||||
right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_spatial(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a spatial dimension into intervals."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||
starts = [i * (size - overlap) for i in range(amount)]
|
||||
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
|
||||
left_ramps = [0] + [overlap] * (amount - 1)
|
||||
right_ramps = [overlap] * (amount - 1) + [0]
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
|
||||
)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_temporal(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
# Start with spatial split
|
||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
|
||||
starts[i] = starts[i] - 1
|
||||
left_ramps[i] = left_ramps[i] + 1
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts,
|
||||
ends=intervals.ends,
|
||||
left_ramps=left_ramps,
|
||||
right_ramps=intervals.right_ramps,
|
||||
)
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_temporal_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map temporal latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = 1 + (end - 1) * scale
|
||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, True
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_spatial_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map spatial latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = end * scale
|
||||
left_ramp_scaled = left_ramp * scale
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, False
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
@@ -315,7 +373,9 @@ def decode_with_tiling(
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
@@ -338,7 +398,9 @@ def decode_with_tiling(
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
@@ -347,7 +409,9 @@ def decode_with_tiling(
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
@@ -356,13 +420,23 @@ def decode_with_tiling(
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -385,13 +459,15 @@ def decode_with_tiling(
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
@@ -409,11 +485,37 @@ def decode_with_tiling(
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
@@ -445,10 +547,12 @@ def decode_with_tiling(
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
@@ -456,7 +560,10 @@ def decode_with_tiling(
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
@@ -473,7 +580,7 @@ def decode_with_tiling(
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
@@ -481,7 +588,7 @@ def decode_with_tiling(
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
|
||||
@@ -8,12 +8,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.ops import (
|
||||
PerChannelStatistics,
|
||||
patchify,
|
||||
unpatchify,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
|
||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
||||
self.per_channel_statistics = PerChannelStatistics(
|
||||
latent_channels=config.out_channels
|
||||
)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = config.in_channels * config.patch_size ** 2
|
||||
in_channels = config.in_channels * config.patch_size**2
|
||||
feature_channels = config.out_channels
|
||||
|
||||
# Initial convolution
|
||||
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
|
||||
conv_out_channels = config.out_channels
|
||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
elif config.latent_log_var in {
|
||||
LogVarianceType.UNIFORM,
|
||||
LogVarianceType.CONSTANT,
|
||||
}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
sample = mx.concatenate(
|
||||
[
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
means = sample[:, : self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
|
||||
Reference in New Issue
Block a user