format
This commit is contained in:
@@ -8,12 +8,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.ops import (
|
||||
PerChannelStatistics,
|
||||
patchify,
|
||||
unpatchify,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
|
||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
||||
self.per_channel_statistics = PerChannelStatistics(
|
||||
latent_channels=config.out_channels
|
||||
)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = config.in_channels * config.patch_size ** 2
|
||||
in_channels = config.in_channels * config.patch_size**2
|
||||
feature_channels = config.out_channels
|
||||
|
||||
# Initial convolution
|
||||
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
|
||||
conv_out_channels = config.out_channels
|
||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
elif config.latent_log_var in {
|
||||
LogVarianceType.UNIFORM,
|
||||
LogVarianceType.CONSTANT,
|
||||
}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
sample = mx.concatenate(
|
||||
[
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
means = sample[:, : self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
|
||||
Reference in New Issue
Block a user