Refactor LTX-2 model structure

This commit is contained in:
Prince Canuma
2026-03-16 14:50:01 +01:00
parent decb3eb9e5
commit 3a0da19adb
50 changed files with 3882 additions and 3365 deletions

View File

@@ -1,4 +1,4 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.convert import ( from mlx_video.convert import (
load_transformer_weights, load_transformer_weights,
load_vae_weights, load_vae_weights,
@@ -9,7 +9,7 @@ from mlx_video.convert import (
) )
# Audio VAE components # Audio VAE components
from mlx_video.models.ltx.audio_vae import ( from mlx_video.models.ltx_2.audio_vae import (
AudioDecoder, AudioDecoder,
Vocoder, Vocoder,
decode_audio, decode_audio,
@@ -19,7 +19,7 @@ from mlx_video.models.ltx.audio_vae import (
) )
# Conditioning # Conditioning
from mlx_video.conditioning import ( from mlx_video.models.ltx_2.conditioning import (
VideoConditionByLatentIndex, VideoConditionByLatentIndex,
) )

View File

@@ -1,3 +0,0 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning

View File

@@ -7,8 +7,8 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx_2.ltx import LTXModel
def get_model_path( def get_model_path(
@@ -639,8 +639,8 @@ def convert_audio_encoder(
raw_weights = mx.load(vae_path) raw_weights = mx.load(vae_path)
# Extract encoder weights and per-channel statistics # Extract encoder weights and per-channel statistics
from mlx_video.models.ltx.audio_vae import AudioEncoder from mlx_video.models.ltx_2.audio_vae import AudioEncoder
from mlx_video.models.ltx.config import AudioEncoderModelConfig from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
# Build config from the decoder config (same audio VAE architecture) # Build config from the decoder config (same audio VAE architecture)
decoder_config_path = model_path / "audio_vae" / "config.json" decoder_config_path = model_path / "audio_vae" / "config.json"

File diff suppressed because it is too large Load Diff

View File

@@ -1,2 +1,2 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig

View File

@@ -1,8 +0,0 @@
from mlx_video.models.ltx.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -1,8 +0,0 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.encoder import encode_image
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -6,8 +6,8 @@ from typing import Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType from mlx_video.models.ltx_2.config import LTXRopeType
from mlx_video.models.ltx.rope import apply_rotary_emb from mlx_video.models.ltx_2.rope import apply_rotary_emb
def scaled_dot_product_attention( def scaled_dot_product_attention(

View File

@@ -168,7 +168,7 @@ class AudioEncoder(nn.Module):
@classmethod @classmethod
def from_pretrained(cls, model_path: Path) -> "AudioEncoder": def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
"""Load audio encoder from pretrained weights.""" """Load audio encoder from pretrained weights."""
from mlx_video.models.ltx.config import AudioEncoderModelConfig from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
import json import json
model_path = Path(model_path) model_path = Path(model_path)
@@ -380,7 +380,7 @@ class AudioDecoder(nn.Module):
@classmethod @classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder": def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
"""Load audio VAE decoder from pretrained model.""" """Load audio VAE decoder from pretrained model."""
from mlx_video.models.ltx.config import AudioDecoderModelConfig from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
import json import json
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))

View File

@@ -0,0 +1,3 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning

View File

@@ -355,9 +355,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
]) ])
def __post_init__(self): def __post_init__(self):
from mlx_video.models.ltx.video_vae.resnet import NormLayerType from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
if self.norm_layer is None: if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM self.norm_layer = NormLayerType.PIXEL_NORM

View File

@@ -26,14 +26,14 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
Usage: Usage:
# From HF repo ID # From HF repo ID
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
# From local folder containing the monolithic safetensors # From local folder containing the monolithic safetensors
python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
# From a direct safetensors file path # From a direct safetensors file path
python -m mlx_video.models.ltx.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
""" """
import argparse import argparse

File diff suppressed because it is too large Load Diff

View File

@@ -3,16 +3,16 @@ from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from pathlib import Path from pathlib import Path
from mlx_video.models.ltx.config import ( from mlx_video.models.ltx_2.config import (
LTXModelConfig, LTXModelConfig,
LTXModelType, LTXModelType,
LTXRopeType, LTXRopeType,
TransformerConfig, TransformerConfig,
) )
from mlx_video.models.ltx.adaln import AdaLayerNormSingle from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx.rope import precompute_freqs_cis from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx.transformer import ( from mlx_video.models.ltx_2.transformer import (
BasicAVTransformerBlock, BasicAVTransformerBlock,
Modality, Modality,
TransformerArgs, TransformerArgs,

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
from mlx_video.models.ltx.config import LTXRopeType from mlx_video.models.ltx_2.config import LTXRopeType
def apply_rotary_emb( def apply_rotary_emb(

View File

@@ -15,7 +15,7 @@ from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from mlx_video.utils import rms_norm, apply_quantization from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig from mlx_vlm.models.gemma3.config import TextConfig

View File

@@ -4,9 +4,9 @@ from typing import Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx.attention import Attention from mlx_video.models.ltx_2.attention import Attention
from mlx_video.models.ltx.feed_forward import FeedForward from mlx_video.models.ltx_2.feed_forward import FeedForward
from mlx_video.utils import rms_norm from mlx_video.utils import rms_norm

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -21,10 +21,10 @@ from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding( def get_timestep_embedding(

View File

@@ -6,7 +6,7 @@ to latent space, which can then be used to condition video generation.
""" """
import mlx.core as mx import mlx.core as mx
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder

View File

@@ -6,7 +6,7 @@ from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm from mlx_video.utils import PixelNorm

View File

@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module): class SpaceToDepthDownsample(nn.Module):

View File

@@ -7,15 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx.video_vae.resnet import ( from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType, NormLayerType,
ResnetBlock3D, ResnetBlock3D,
UNetMidBlock3D, UNetMidBlock3D,
get_norm_layer, get_norm_layer,
) )
from mlx_video.models.ltx.video_vae.sampling import ( from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample, DepthToSpaceUpsample,
SpaceToDepthDownsample, SpaceToDepthDownsample,
) )
@@ -229,7 +229,7 @@ class VideoEncoder(nn.Module):
config: VideoEncoderModelConfig with encoder parameters config: VideoEncoderModelConfig with encoder parameters
""" """
super().__init__() super().__init__()
from mlx_video.models.ltx.config import VideoEncoderModelConfig from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.norm_layer = config.norm_layer self.norm_layer = config.norm_layer
@@ -409,7 +409,7 @@ class VideoEncoder(nn.Module):
Loaded VideoEncoder instance Loaded VideoEncoder instance
""" """
import json import json
from mlx_video.models.ltx.config import VideoEncoderModelConfig from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config # Load config
config_path = model_path / "config.json" config_path = model_path / "config.json"

View File

@@ -1,32 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
act_fn: str = "gelu_tanh",
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
if act_fn == "gelu_tanh":
self.act = nn.GELU(approx="tanh")
elif act_fn == "silu":
self.act = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x

View File

@@ -2,10 +2,10 @@ import pytest
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from mlx_video.models.ltx.rope import ( from mlx_video.models.ltx_2.rope import (
precompute_freqs_cis, precompute_freqs_cis,
) )
from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
def create_video_position_grid( def create_video_position_grid(

View File

@@ -4,8 +4,8 @@ import pytest
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import ( from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig, TilingConfig,
compute_trapezoidal_mask_1d, compute_trapezoidal_mask_1d,
decode_with_tiling, decode_with_tiling,

1954
uv.lock generated

File diff suppressed because it is too large Load Diff