This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -8,12 +8,15 @@ import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx_2.video_vae.ops import (
PerChannelStatistics,
patchify,
unpatchify,
)
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
self.per_channel_statistics = PerChannelStatistics(
latent_channels=config.out_channels
)
# After patchify, channels increase by patch_size^2
in_channels = config.in_channels * config.patch_size ** 2
in_channels = config.in_channels * config.patch_size**2
feature_channels = config.out_channels
# Initial convolution
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_encoder_block(
block_name=block_name,
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
elif config.latent_log_var in {
LogVarianceType.UNIFORM,
LogVarianceType.CONSTANT,
}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
sample = mx.concatenate(
[
sample,
mx.full_like(sample, approx_ln_0),
],
axis=1,
)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
means = sample[:, : self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
out_channels = out_channels * patch_size**2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_decoder_block(
block_name=block_name,