Refactor LTX-2 model structure
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||||
from mlx_video.convert import (
|
from mlx_video.convert import (
|
||||||
load_transformer_weights,
|
load_transformer_weights,
|
||||||
load_vae_weights,
|
load_vae_weights,
|
||||||
@@ -9,7 +9,7 @@ from mlx_video.convert import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Audio VAE components
|
# Audio VAE components
|
||||||
from mlx_video.models.ltx.audio_vae import (
|
from mlx_video.models.ltx_2.audio_vae import (
|
||||||
AudioDecoder,
|
AudioDecoder,
|
||||||
Vocoder,
|
Vocoder,
|
||||||
decode_audio,
|
decode_audio,
|
||||||
@@ -19,7 +19,7 @@ from mlx_video.models.ltx.audio_vae import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Conditioning
|
# Conditioning
|
||||||
from mlx_video.conditioning import (
|
from mlx_video.models.ltx_2.conditioning import (
|
||||||
VideoConditionByLatentIndex,
|
VideoConditionByLatentIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
"""Conditioning modules for LTX-2 video generation."""
|
|
||||||
|
|
||||||
from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
|
||||||
@@ -7,8 +7,8 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType
|
||||||
from mlx_video.models.ltx.ltx import LTXModel
|
from mlx_video.models.ltx_2.ltx import LTXModel
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(
|
def get_model_path(
|
||||||
@@ -639,8 +639,8 @@ def convert_audio_encoder(
|
|||||||
raw_weights = mx.load(vae_path)
|
raw_weights = mx.load(vae_path)
|
||||||
|
|
||||||
# Extract encoder weights and per-channel statistics
|
# Extract encoder weights and per-channel statistics
|
||||||
from mlx_video.models.ltx.audio_vae import AudioEncoder
|
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
|
||||||
from mlx_video.models.ltx.config import AudioEncoderModelConfig
|
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||||
|
|
||||||
# Build config from the decoder config (same audio VAE architecture)
|
# Build config from the decoder config (same audio VAE architecture)
|
||||||
decoder_config_path = model_path / "audio_vae" / "config.json"
|
decoder_config_path = model_path / "audio_vae" / "config.json"
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,2 +1,2 @@
|
|||||||
|
|
||||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
|
|
||||||
from mlx_video.models.ltx.config import (
|
|
||||||
LTXModelConfig,
|
|
||||||
TransformerConfig,
|
|
||||||
LTXModelType,
|
|
||||||
)
|
|
||||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
|
||||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
|
||||||
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
|
||||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
|
||||||
from mlx_video.models.ltx.video_vae.tiling import (
|
|
||||||
TilingConfig,
|
|
||||||
SpatialTilingConfig,
|
|
||||||
TemporalTilingConfig,
|
|
||||||
)
|
|
||||||
8
mlx_video/models/ltx_2/__init__.py
Normal file
8
mlx_video/models/ltx_2/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
from mlx_video.models.ltx_2.config import (
|
||||||
|
LTXModelConfig,
|
||||||
|
TransformerConfig,
|
||||||
|
LTXModelType,
|
||||||
|
)
|
||||||
|
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||||
|
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||||
@@ -6,8 +6,8 @@ from typing import Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXRopeType
|
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||||
from mlx_video.models.ltx.rope import apply_rotary_emb
|
from mlx_video.models.ltx_2.rope import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
def scaled_dot_product_attention(
|
def scaled_dot_product_attention(
|
||||||
@@ -168,7 +168,7 @@ class AudioEncoder(nn.Module):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||||
"""Load audio encoder from pretrained weights."""
|
"""Load audio encoder from pretrained weights."""
|
||||||
from mlx_video.models.ltx.config import AudioEncoderModelConfig
|
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||||
import json
|
import json
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
@@ -380,7 +380,7 @@ class AudioDecoder(nn.Module):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||||
"""Load audio VAE decoder from pretrained model."""
|
"""Load audio VAE decoder from pretrained model."""
|
||||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||||
import json
|
import json
|
||||||
|
|
||||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||||
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Conditioning modules for LTX-2 video generation."""
|
||||||
|
|
||||||
|
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||||
@@ -355,9 +355,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
|||||||
])
|
])
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
|
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||||
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
|
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||||
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||||
|
|
||||||
if self.norm_layer is None:
|
if self.norm_layer is None:
|
||||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||||
@@ -26,14 +26,14 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# From HF repo ID
|
# From HF repo ID
|
||||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||||
|
|
||||||
# From local folder containing the monolithic safetensors
|
# From local folder containing the monolithic safetensors
|
||||||
python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||||
|
|
||||||
# From a direct safetensors file path
|
# From a direct safetensors file path
|
||||||
python -m mlx_video.models.ltx.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
2566
mlx_video/models/ltx_2/generate.py
Normal file
2566
mlx_video/models/ltx_2/generate.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,16 +3,16 @@ from typing import List, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from mlx_video.models.ltx.config import (
|
from mlx_video.models.ltx_2.config import (
|
||||||
LTXModelConfig,
|
LTXModelConfig,
|
||||||
LTXModelType,
|
LTXModelType,
|
||||||
LTXRopeType,
|
LTXRopeType,
|
||||||
TransformerConfig,
|
TransformerConfig,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
|
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||||
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
|
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||||
from mlx_video.models.ltx.transformer import (
|
from mlx_video.models.ltx_2.transformer import (
|
||||||
BasicAVTransformerBlock,
|
BasicAVTransformerBlock,
|
||||||
Modality,
|
Modality,
|
||||||
TransformerArgs,
|
TransformerArgs,
|
||||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXRopeType
|
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(
|
def apply_rotary_emb(
|
||||||
@@ -15,7 +15,7 @@ from rich.console import Console
|
|||||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm, apply_quantization
|
from mlx_video.utils import rms_norm, apply_quantization
|
||||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
||||||
|
|
||||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||||
from mlx_vlm.models.gemma3.config import TextConfig
|
from mlx_vlm.models.gemma3.config import TextConfig
|
||||||
@@ -4,9 +4,9 @@ from typing import Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
|
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||||
from mlx_video.models.ltx.attention import Attention
|
from mlx_video.models.ltx_2.attention import Attention
|
||||||
from mlx_video.models.ltx.feed_forward import FeedForward
|
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||||
from mlx_video.utils import rms_norm
|
from mlx_video.utils import rms_norm
|
||||||
|
|
||||||
|
|
||||||
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||||
|
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||||
|
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||||
|
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||||
|
TilingConfig,
|
||||||
|
SpatialTilingConfig,
|
||||||
|
TemporalTilingConfig,
|
||||||
|
)
|
||||||
@@ -21,10 +21,10 @@ from pathlib import Path
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
@@ -6,7 +6,7 @@ to latent space, which can then be used to condition video generation.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
from mlx_video.utils import PixelNorm
|
from mlx_video.utils import PixelNorm
|
||||||
|
|
||||||
|
|
||||||
@@ -5,7 +5,7 @@ from typing import Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
|
|
||||||
|
|
||||||
class SpaceToDepthDownsample(nn.Module):
|
class SpaceToDepthDownsample(nn.Module):
|
||||||
@@ -7,15 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||||
from mlx_video.models.ltx.video_vae.resnet import (
|
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||||
NormLayerType,
|
NormLayerType,
|
||||||
ResnetBlock3D,
|
ResnetBlock3D,
|
||||||
UNetMidBlock3D,
|
UNetMidBlock3D,
|
||||||
get_norm_layer,
|
get_norm_layer,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx.video_vae.sampling import (
|
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||||
DepthToSpaceUpsample,
|
DepthToSpaceUpsample,
|
||||||
SpaceToDepthDownsample,
|
SpaceToDepthDownsample,
|
||||||
)
|
)
|
||||||
@@ -229,7 +229,7 @@ class VideoEncoder(nn.Module):
|
|||||||
config: VideoEncoderModelConfig with encoder parameters
|
config: VideoEncoderModelConfig with encoder parameters
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||||
|
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.norm_layer = config.norm_layer
|
self.norm_layer = config.norm_layer
|
||||||
@@ -409,7 +409,7 @@ class VideoEncoder(nn.Module):
|
|||||||
Loaded VideoEncoder instance
|
Loaded VideoEncoder instance
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config_path = model_path / "config.json"
|
config_path = model_path / "config.json"
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class PixArtAlphaTextProjection(nn.Module):
    """Project inputs through linear -> activation -> linear.

    Args:
        in_features: Input width handed to the first ``nn.Linear``.
        hidden_size: Output width of the first ``nn.Linear`` (and input
            width of the second).
        out_features: Output width of the second ``nn.Linear``. Any falsy
            value (``None`` or ``0``) falls back to ``hidden_size``.
        bias: Forwarded as ``bias=`` to both linear layers.
        act_fn: Activation selector, either ``"gelu_tanh"`` or ``"silu"``.

    Raises:
        ValueError: If ``act_fn`` is neither ``"gelu_tanh"`` nor ``"silu"``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        out_features: int | None = None,
        bias: bool = True,
        act_fn: str = "gelu_tanh",
    ):
        super().__init__()

        # Falsy out_features means "same width as the hidden layer".
        if not out_features:
            out_features = hidden_size

        # Attributes are assigned in the order linear1 -> act -> linear2.
        self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)

        # Reject unsupported selectors before building the activation.
        if act_fn not in ("gelu_tanh", "silu"):
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.act = nn.GELU(approx="tanh") if act_fn == "gelu_tanh" else nn.SiLU()

        self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``linear2(act(linear1(x)))``."""
        hidden = self.act(self.linear1(x))
        return self.linear2(hidden)
|
|
||||||
@@ -2,10 +2,10 @@ import pytest
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mlx_video.models.ltx.rope import (
|
from mlx_video.models.ltx_2.rope import (
|
||||||
precompute_freqs_cis,
|
precompute_freqs_cis,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType
|
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
||||||
|
|
||||||
|
|
||||||
def create_video_position_grid(
|
def create_video_position_grid(
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import pytest
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||||
from mlx_video.models.ltx.video_vae.tiling import (
|
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||||
TilingConfig,
|
TilingConfig,
|
||||||
compute_trapezoidal_mask_1d,
|
compute_trapezoidal_mask_1d,
|
||||||
decode_with_tiling,
|
decode_with_tiling,
|
||||||
|
|||||||
Reference in New Issue
Block a user