Refactor LTX-2 model structure
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
from mlx_video.convert import (
|
||||
load_transformer_weights,
|
||||
load_vae_weights,
|
||||
@@ -9,7 +9,7 @@ from mlx_video.convert import (
|
||||
)
|
||||
|
||||
# Audio VAE components
|
||||
from mlx_video.models.ltx.audio_vae import (
|
||||
from mlx_video.models.ltx_2.audio_vae import (
|
||||
AudioDecoder,
|
||||
Vocoder,
|
||||
decode_audio,
|
||||
@@ -19,7 +19,7 @@ from mlx_video.models.ltx.audio_vae import (
|
||||
)
|
||||
|
||||
# Conditioning
|
||||
from mlx_video.conditioning import (
|
||||
from mlx_video.models.ltx_2.conditioning import (
|
||||
VideoConditionByLatentIndex,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
@@ -7,8 +7,8 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel
|
||||
|
||||
|
||||
def get_model_path(
|
||||
@@ -639,8 +639,8 @@ def convert_audio_encoder(
|
||||
raw_weights = mx.load(vae_path)
|
||||
|
||||
# Extract encoder weights and per-channel statistics
|
||||
from mlx_video.models.ltx.audio_vae import AudioEncoder
|
||||
from mlx_video.models.ltx.config import AudioEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
# Build config from the decoder config (same audio VAE architecture)
|
||||
decoder_config_path = model_path / "audio_vae" / "config.json"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,2 +1,2 @@
|
||||
|
||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
|
||||
from mlx_video.models.ltx.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
)
|
||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
@@ -1,8 +0,0 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
8
mlx_video/models/ltx_2/__init__.py
Normal file
8
mlx_video/models/ltx_2/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
)
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
@@ -6,8 +6,8 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||
from mlx_video.models.ltx_2.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
@@ -168,7 +168,7 @@ class AudioEncoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx.config import AudioEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
@@ -380,7 +380,7 @@ class AudioDecoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
@@ -355,9 +355,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
])
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
@@ -26,14 +26,14 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
|
||||
|
||||
Usage:
|
||||
# From HF repo ID
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||
|
||||
# From local folder containing the monolithic safetensors
|
||||
python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||
|
||||
# From a direct safetensors file path
|
||||
python -m mlx_video.models.ltx.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
||||
"""
|
||||
|
||||
import argparse
|
||||
2566
mlx_video/models/ltx_2/generate.py
Normal file
2566
mlx_video/models/ltx_2/generate.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,16 +3,16 @@ from typing import List, Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
from mlx_video.models.ltx.config import (
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
LTXRopeType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx.transformer import (
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx_2.transformer import (
|
||||
BasicAVTransformerBlock,
|
||||
Modality,
|
||||
TransformerArgs,
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
@@ -15,7 +15,7 @@ from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
||||
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
@@ -4,9 +4,9 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx.attention import Attention
|
||||
from mlx_video.models.ltx.feed_forward import FeedForward
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.attention import Attention
|
||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
|
||||
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
@@ -21,10 +21,10 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -6,7 +6,7 @@ to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
@@ -7,15 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx.video_vae.resnet import (
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx.video_vae.sampling import (
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
SpaceToDepthDownsample,
|
||||
)
|
||||
@@ -229,7 +229,7 @@ class VideoEncoder(nn.Module):
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
@@ -409,7 +409,7 @@ class VideoEncoder(nn.Module):
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
@@ -1,32 +0,0 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_size: int,
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
act_fn: str = "gelu_tanh",
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.linear1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
@@ -2,10 +2,10 @@ import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.rope import (
|
||||
from mlx_video.models.ltx_2.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType
|
||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
|
||||
@@ -4,8 +4,8 @@ import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
compute_trapezoidal_mask_1d,
|
||||
decode_with_tiling,
|
||||
|
||||
Reference in New Issue
Block a user