Refactor LTX-2 model structure

This commit is contained in:
Prince Canuma
2026-03-16 14:50:01 +01:00
parent decb3eb9e5
commit 3a0da19adb
50 changed files with 3882 additions and 3365 deletions

View File

@@ -1,2 +1,2 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig

View File

@@ -1,8 +0,0 @@
from mlx_video.models.ltx.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -1,8 +0,0 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.encoder import encode_image
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -6,8 +6,8 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx.rope import apply_rotary_emb
from mlx_video.models.ltx_2.config import LTXRopeType
from mlx_video.models.ltx_2.rope import apply_rotary_emb
def scaled_dot_product_attention(

View File

@@ -168,7 +168,7 @@ class AudioEncoder(nn.Module):
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
"""Load audio encoder from pretrained weights."""
from mlx_video.models.ltx.config import AudioEncoderModelConfig
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
import json
model_path = Path(model_path)
@@ -380,7 +380,7 @@ class AudioDecoder(nn.Module):
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
"""Load audio VAE decoder from pretrained model."""
from mlx_video.models.ltx.config import AudioDecoderModelConfig
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
import json
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))

View File

@@ -0,0 +1,3 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning

View File

@@ -0,0 +1,199 @@
"""Latent-based conditioning for I2V (Image-to-Video) generation.
This module provides conditioning that injects encoded image latents into
the video generation process at specific frame positions.
"""
from dataclasses import dataclass
from typing import Optional, List, Tuple
import mlx.core as mx
@dataclass
class VideoConditionByLatentIndex:
    """Conditioning item that places encoded latents at a latent frame index.

    When passed to ``apply_conditioning``, the target latent frames starting
    at ``frame_idx`` are overwritten with ``latent`` and their denoise mask is
    set to ``1.0 - strength``.

    Args:
        latent: Encoded latent of shape (B, C, F_cond, H, W); a single
            encoded image has F_cond = 1.
        frame_idx: First latent frame to condition (0 = first frame).
        strength: Conditioning strength. 1.0 gives a denoise mask of 0.0
            (the conditioned frames are kept clean); 0.0 gives a mask of 1.0
            (the frames are fully denoised).
    """

    latent: mx.array
    frame_idx: int = 0
    strength: float = 1.0

    def get_num_latent_frames(self) -> int:
        """Return the length of the conditioning latent's frame axis."""
        frame_axis = 2
        return self.latent.shape[frame_axis]
@dataclass
class LatentState:
    """Bundle of arrays tracked while denoising with conditioning.

    Attributes:
        latent: Current noisy latent (B, C, F, H, W)
        clean_latent: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
            1.0 = full denoise, 0.0 = keep clean
    """

    latent: mx.array
    clean_latent: mx.array
    denoise_mask: mx.array

    def clone(self) -> "LatentState":
        """Return a new LatentState holding the same three arrays."""
        return LatentState(self.latent, self.clean_latent, self.denoise_mask)
def create_initial_state(
    shape: Tuple[int, ...],
    seed: Optional[int] = None,
    noise_scale: float = 1.0,
) -> LatentState:
    """Build the starting LatentState from Gaussian noise.

    Args:
        shape: Shape of latent (B, C, F, H, W)
        seed: Optional random seed; when given it reseeds MLX's global
            random state before sampling
        noise_scale: Multiplier applied to the sampled noise (sigma)

    Returns:
        LatentState whose latent is scaled noise, whose clean latent is all
        zeros, and whose denoise mask is all ones (full denoise everywhere)
    """
    if seed is not None:
        mx.random.seed(seed)

    scaled_noise = mx.random.normal(shape) * noise_scale
    empty_clean = mx.zeros(shape)
    # One mask value per (batch, frame); broadcast over C, H and W.
    full_mask = mx.ones((shape[0], 1, shape[2], 1, 1))

    return LatentState(
        latent=scaled_noise,
        clean_latent=empty_clean,
        denoise_mask=full_mask,
    )
def _splice_frames(
    base: mx.array,
    insert: mx.array,
    start: int,
    end: int,
) -> mx.array:
    """Return ``base`` with frames ``[start, end)`` on axis 2 replaced by ``insert``."""
    total_frames = base.shape[2]
    pieces = []
    if start > 0:
        pieces.append(base[:, :, :start])
    pieces.append(insert)
    if end < total_frames:
        pieces.append(base[:, :, end:])
    return mx.concatenate(pieces, axis=2)


def apply_conditioning(
    state: LatentState,
    conditionings: List[VideoConditionByLatentIndex],
) -> LatentState:
    """Apply conditioning items to a latent state.

    For each item, the frames ``[frame_idx, frame_idx + F_cond)`` (clipped to
    the number of frames in the state) are replaced in both ``latent`` and
    ``clean_latent`` by the conditioning latent, and the matching entries of
    ``denoise_mask`` are set to ``1.0 - strength``. Items are applied in
    order, so a later item overrides an earlier one where they overlap.

    Args:
        state: Current latent state (not modified; a clone is updated)
        conditionings: List of conditioning items to apply

    Returns:
        Updated LatentState with conditioning applied

    Raises:
        ValueError: If a conditioning latent's (C, H, W) differs from the
            state's, or if ``frame_idx`` is negative or not smaller than the
            number of latent frames.
    """
    state = state.clone()
    dtype = state.latent.dtype
    b, c, f, h, w = state.latent.shape

    for cond in conditionings:
        cond_latent = cond.latent
        frame_idx = cond.frame_idx
        strength = cond.strength

        # Validate shapes
        _, cond_c, cond_f, cond_h, cond_w = cond_latent.shape
        if (cond_c, cond_h, cond_w) != (c, h, w):
            raise ValueError(
                f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
                f"does not match target shape ({c}, {h}, {w})"
            )
        # Negative indices are rejected too: they previously slipped through
        # and conditioned the wrong frames (or none) without any error.
        if not 0 <= frame_idx < f:
            raise ValueError(
                f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
            )

        # Conditioning frames that run past the end of the clip are dropped.
        end_idx = min(frame_idx + cond_f, f)
        num_used = end_idx - frame_idx
        if num_used <= 0:
            continue

        cond_frames = cond_latent[:, :, :num_used]
        # 1.0 - strength: stronger conditioning means less denoising.
        cond_mask = mx.full((b, 1, num_used, 1, 1), 1.0 - strength, dtype=dtype)

        state.latent = _splice_frames(state.latent, cond_frames, frame_idx, end_idx)
        state.clean_latent = _splice_frames(
            state.clean_latent, cond_frames, frame_idx, end_idx
        )
        state.denoise_mask = _splice_frames(
            state.denoise_mask, cond_mask, frame_idx, end_idx
        )

    return state
def apply_denoise_mask(
    denoised: mx.array,
    clean: mx.array,
    denoise_mask: mx.array,
) -> mx.array:
    """Linearly mix the denoised latent with the clean latent.

    Args:
        denoised: Denoised latent (B, C, F, H, W)
        clean: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean

    Returns:
        ``denoised * denoise_mask + clean * (1 - denoise_mask)``
    """
    # Build the constant 1 in the latent's dtype before subtracting the mask.
    clean_weight = mx.array(1.0, dtype=denoised.dtype) - denoise_mask
    denoised_part = denoised * denoise_mask
    clean_part = clean * clean_weight
    return denoised_part + clean_part
def add_noise_with_state(
    state: LatentState,
    noise_scale: float,
) -> LatentState:
    """Noise a latent state, scaling the noise per frame by the denoise mask.

    The per-frame noise level is ``noise_scale * denoise_mask``, so frames
    with mask 0 receive no noise and frames with mask 1 receive the full
    ``noise_scale``. The result is
    ``noise * level + latent * (1 - level)``.

    Args:
        state: Current latent state (not modified; a clone is updated)
        noise_scale: Scale for noise (sigma)

    Returns:
        Updated state with noise added
    """
    noised = state.clone()
    current = noised.latent

    fresh_noise = mx.random.normal(current.shape)

    # Conditioned frames (mask < 1) get a proportionally smaller sigma.
    level = noise_scale * noised.denoise_mask
    keep = mx.array(1.0, dtype=current.dtype) - level

    noised.latent = fresh_noise * level + current * keep
    return noised

View File

@@ -355,9 +355,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
])
def __post_init__(self):
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM

View File

@@ -26,14 +26,14 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
Usage:
# From HF repo ID
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
# From local folder containing the monolithic safetensors
python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
# From a direct safetensors file path
python -m mlx_video.models.ltx.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
"""
import argparse

File diff suppressed because it is too large Load Diff

View File

@@ -3,16 +3,16 @@ from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx.config import (
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
from mlx_video.models.ltx.rope import precompute_freqs_cis
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx.transformer import (
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx_2.transformer import (
BasicAVTransformerBlock,
Modality,
TransformerArgs,

View File

@@ -0,0 +1,165 @@
import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
    """Smooth an image with OpenCV's bilateral filter when OpenCV is present.

    Args:
        image: Input image as uint8 numpy array (H, W, C)
        d: Diameter of each pixel neighborhood
        sigma_color: Filter sigma in the color space
        sigma_space: Filter sigma in the coordinate space

    Returns:
        Filtered image. If ``cv2`` cannot be imported, the result of
        ``gaussian_blur(image, kernel_size=3)`` is returned instead.
    """
    try:
        import cv2

        filtered = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
    except ImportError:
        # No OpenCV: degrade to the plain blur helper.
        filtered = gaussian_blur(image, kernel_size=3)
    return filtered
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Blur an image with OpenCV, or with a SciPy box filter as a fallback.

    Args:
        image: Input image as uint8 numpy array (H, W, C)
        kernel_size: Size of the Gaussian kernel (must be odd)

    Returns:
        Blurred image. Without ``cv2`` this is a uniform (box) filter over
        the two spatial axes only, cast to uint8.
    """
    window = (kernel_size, kernel_size)
    try:
        import cv2

        blurred = cv2.GaussianBlur(image, window, 0)
    except ImportError:
        from scipy.ndimage import uniform_filter

        # Size 1 on the last axis keeps the channels independent.
        box = uniform_filter(image, size=window + (1,))
        blurred = box.astype(np.uint8)
    return blurred
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
    """Sharpen an image by subtracting a Gaussian-blurred copy (OpenCV only).

    Args:
        image: Input image as uint8 numpy array
        kernel_size: Size of the Gaussian kernel
        sigma: Gaussian sigma
        amount: Strength of sharpening

    Returns:
        Sharpened uint8 image, or the input image unchanged when ``cv2``
        cannot be imported.
    """
    try:
        import cv2

        window = (kernel_size, kernel_size)
        soft = cv2.GaussianBlur(image, window, sigma)
        # (1 + amount) * image - amount * blurred
        boosted = cv2.addWeighted(image, 1 + amount, soft, -amount, 0)
        result = np.clip(boosted, 0, 255).astype(np.uint8)
    except ImportError:
        result = image
    return result
def reduce_grid_artifacts(
    video: np.ndarray,
    method: str = "bilateral",
    strength: float = 1.0,
) -> np.ndarray:
    """Reduce grid artifacts in video frames.

    Each frame is filtered independently with the chosen method. When
    ``strength`` is below 1.0 the filtered frames are additionally blended
    with the originals as ``strength * filtered + (1 - strength) * video``.

    Args:
        video: Video as numpy array (F, H, W, C) uint8
        method: "bilateral", "gaussian", or "frequency"
        strength: How strong to apply the filter (0-1)

    Returns:
        Processed video. An empty video (F == 0) is returned as a copy.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "bilateral":
        d = max(3, int(5 * strength))
        sigma = 50 + 50 * strength

        def filter_frame(frame):
            return bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)

    elif method == "gaussian":
        kernel_size = max(3, int(3 + 4 * strength))
        if kernel_size % 2 == 0:
            # The blur kernel size must be odd.
            kernel_size += 1

        def filter_frame(frame):
            return gaussian_blur(frame, kernel_size=kernel_size)

    elif method == "frequency":

        def filter_frame(frame):
            return remove_grid_frequency(frame, grid_size=8)

    else:
        raise ValueError(f"Unknown method: {method}")

    # np.stack rejects an empty sequence, so handle a zero-frame video here.
    if len(video) == 0:
        return video.copy()

    processed = np.stack([filter_frame(frame) for frame in video])

    if strength < 1.0:
        # Blend with the original. Round and clip before the uint8 cast:
        # a bare astype truncates (e.g. 199.999... -> 199) and wraps values
        # that fall outside 0-255.
        blended = strength * processed + (1 - strength) * video
        processed = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    return processed
def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
    """Remove grid-frequency components using FFT.

    A notch mask is built once in the centred (fftshift-ed) frequency plane
    and applied to every channel. Notches sit at integer multiples
    (-2..2 in each direction, excluding the origin) of ``h // grid_size`` and
    ``w // grid_size`` bins from the centre; each notch attenuates a 5x5
    neighbourhood by ``min(1, distance / 3)``. The DC bin is always left
    untouched so the mean brightness of each channel is preserved, even on
    small frames where notch neighbourhoods would otherwise overlap it.

    Args:
        frame: Input frame (H, W, C) uint8
        grid_size: Expected grid periodicity in pixels

    Returns:
        Filtered frame with the same shape and dtype as ``frame``
    """
    h, w = frame.shape[:2]
    cy, cx = h // 2, w // 2

    # Bin spacing corresponding to a period of grid_size pixels.
    freq_y = h // grid_size
    freq_x = w // grid_size

    # The mask depends only on the frame size, so build it once rather than
    # once per channel.
    mask = np.ones((h, w), dtype=np.float32)
    for fy in range(-2, 3):
        for fx in range(-2, 3):
            if fy == 0 and fx == 0:
                continue
            y_pos = cy + fy * freq_y
            x_pos = cx + fx * freq_x
            if not (0 <= y_pos < h and 0 <= x_pos < w):
                continue
            # Attenuation grows from 0 at the notch centre to ~0.94 at the
            # corners of the 5x5 neighbourhood.
            for dy in range(-2, 3):
                for dx in range(-2, 3):
                    yy, xx = y_pos + dy, x_pos + dx
                    if 0 <= yy < h and 0 <= xx < w:
                        dist = np.sqrt(dy**2 + dx**2)
                        mask[yy, xx] *= min(1.0, dist / 3.0)

    # After fftshift the zero-frequency term is at (h // 2, w // 2).
    # Restore it so overlapping notches never rescale the channel mean.
    if h > 0 and w > 0:
        mask[cy, cx] = 1.0

    result = np.zeros_like(frame)
    for c in range(frame.shape[2]):
        channel = frame[:, :, c].astype(np.float32)

        fft_shifted = np.fft.fftshift(np.fft.fft2(channel))
        fft_filtered = fft_shifted * mask
        channel_filtered = np.fft.ifft2(np.fft.ifftshift(fft_filtered)).real

        # Round before casting: the FFT round trip leaves values such as
        # 99.99999 that a bare uint8 cast would truncate to 99.
        result[:, :, c] = np.clip(np.rint(channel_filtered), 0, 255).astype(np.uint8)

    return result

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
import mlx.core as mx
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx_2.config import LTXRopeType
def apply_rotary_emb(

View File

@@ -0,0 +1,181 @@
"""Second-order res_2s sampler for diffusion models.
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
from typing import Optional
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
    """Compute phi_j(z) where z = -h (negative step size in log-space).

    phi_1(z) = (e^z - 1) / z
    phi_2(z) = (e^z - 1 - z) / z^2
    phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j

    For |z| < 0.5 the equivalent Taylor series
    phi_j(z) = sum_{k>=0} z^k / (k + j)! is used instead of the closed form.
    The closed form subtracts nearly equal numbers when z is small, which
    destroys precision (for example phi_2(1e-8) evaluates to garbage).

    Args:
        j: Order of the phi function (non-negative integer).
        neg_h: The argument z.

    Returns:
        phi_j(z); exactly 1 / j! at z == 0.
    """
    z = neg_h
    if abs(z) < 0.5:
        # Terms shrink at least as fast as 0.5**k / k!, so 25 of them are
        # far below double precision.
        term = 1.0 / math.factorial(j)
        total = term
        for k in range(1, 26):
            term *= z / (j + k)
            total += term
        return total

    remainder = sum(z**k / math.factorial(k) for k in range(j))
    return (math.exp(z) - remainder) / (z**j)
def get_res2s_coefficients(
    h: float,
    phi_cache: dict,
    c2: float = 0.5,
) -> tuple[float, float, float]:
    """Return the (a21, b1, b2) coefficients of the res_2s scheme for step h.

    The coefficients are
        a21 = c2 * phi_1(-h * c2)
        b2  = phi_2(-h) / c2
        b1  = phi_1(-h) - b2

    Args:
        h: Step size in log-space = log(sigma / sigma_next)
        phi_cache: Dictionary keyed by ``(j, z)`` used to memoise phi values;
            missing entries are computed and stored in it.
        c2: Substep position (default 0.5 = midpoint)

    Returns:
        (a21, b1, b2): RK coefficients.
    """

    def cached_phi(order: int, z: float) -> float:
        key = (order, z)
        if key not in phi_cache:
            phi_cache[key] = phi(order, z)
        return phi_cache[key]

    z_full = -h
    z_sub = z_full * c2

    a21 = c2 * cached_phi(1, z_sub)
    b2 = cached_phi(2, z_full) / c2
    b1 = cached_phi(1, z_full) - b2

    return a21, b1, b2
# ---------------------------------------------------------------------------
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
    sigma_next: float,
) -> tuple[float, float, float]:
    """Derive the noise-injection coefficients for a step ending at sigma_next.

    The injected noise level is fixed at ``sigma_up = 0.5 * sigma_next``
    (capped at ``0.9999 * sigma_next``). With ``sigma_signal = 1 - sigma_next``
    and ``sigma_residual = sqrt(max(sigma_next^2 - sigma_up^2, 0))``:

        alpha_ratio = sigma_signal + sigma_residual
        sigma_down  = sigma_residual / alpha_ratio   (sigma_next if alpha_ratio == 0)

    Any NaN result is replaced: sigma_up -> 0.0, sigma_down -> sigma_next,
    alpha_ratio -> 1.0.

    Returns:
        (alpha_ratio, sigma_down, sigma_up)
    """
    # The cap keeps sigma_up below sigma_next so the sqrt argument is >= 0.
    sigma_up = min(sigma_next * 0.5, sigma_next * 0.9999)

    sigma_signal = 1.0 - sigma_next  # sigma_max=1
    leftover_variance = max(sigma_next**2 - sigma_up**2, 0.0)
    sigma_residual = math.sqrt(leftover_variance)

    alpha_ratio = sigma_signal + sigma_residual
    sigma_down = sigma_next if alpha_ratio == 0 else sigma_residual / alpha_ratio

    # NaN fallbacks
    sigma_up = 0.0 if math.isnan(sigma_up) else sigma_up
    sigma_down = sigma_next if math.isnan(sigma_down) else sigma_down
    alpha_ratio = 1.0 if math.isnan(alpha_ratio) else alpha_ratio

    return alpha_ratio, sigma_down, sigma_up
def sde_noise_step(
    sample: mx.array,
    denoised_sample: mx.array,
    sigma: float,
    sigma_next: float,
    noise: mx.array,
) -> mx.array:
    """Move a sample from ``sigma`` to ``sigma_next`` with stochastic noise.

    Args:
        sample: Current sample (anchor point)
        denoised_sample: Denoised prediction at this step
        sigma: Current noise level
        sigma_next: Next noise level
        noise: Pre-generated noise tensor (channel-wise normalized)

    Returns:
        Noised sample at sigma_next, computed in float32. When ``sigma_up``
        or ``sigma_next`` is 0, ``denoised_sample`` is returned as-is.
    """
    alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)

    nothing_to_inject = sigma_up == 0 or sigma_next == 0
    if nothing_to_inject:
        return denoised_sample

    # Do all arithmetic in float32 regardless of the input dtypes.
    x = sample.astype(mx.float32)
    x0 = denoised_sample.astype(mx.float32)
    eps_noise = noise.astype(mx.float32)

    # Epsilon estimate over this step.
    # NOTE(review): divides by (sigma - sigma_next); callers presumably never
    # pass sigma == sigma_next — confirm.
    eps_next = (x - x0) / (sigma - sigma_next)
    denoised_next = x - sigma * eps_next

    # Deterministic part plus freshly injected noise.
    deterministic = alpha_ratio * (denoised_next + sigma_down * eps_next)
    stochastic = sigma_up * eps_noise
    return deterministic + stochastic
# ---------------------------------------------------------------------------
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
    """Standardise ``x`` over its last two axes.

    The mean over the last two dimensions is subtracted, then the result is
    divided by ``sqrt(mean(centered^2) + 1e-8)`` taken over the same axes.
    """
    last_two = (-2, -1)
    centered = x - mx.mean(x, axis=last_two, keepdims=True)
    variance = mx.mean(centered * centered, axis=last_two, keepdims=True)
    # The epsilon keeps the division finite for constant slices.
    return centered / mx.sqrt(variance + 1e-8)
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
    """Sample float32 Gaussian noise and normalise it globally, then per channel.

    The sample is first shifted by its global mean and divided by its global
    root-mean-square (plus 1e-8), then passed through
    ``channelwise_normalize``.

    Args:
        shape: Shape of the noise tensor
        key: MLX random key for deterministic generation

    Returns:
        Channel-wise normalized noise in float32
    """
    raw = mx.random.normal(shape, dtype=mx.float32, key=key)

    # Global pass: the RMS is taken from the raw (uncentred) sample.
    global_rms = mx.sqrt(mx.mean(raw * raw))
    globally_scaled = (raw - mx.mean(raw)) / (global_rms + 1e-8)

    # Per-channel pass over the last two axes.
    return channelwise_normalize(globally_scaled)

View File

@@ -15,7 +15,7 @@ from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig

View File

@@ -4,9 +4,9 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx.attention import Attention
from mlx_video.models.ltx.feed_forward import FeedForward
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.attention import Attention
from mlx_video.models.ltx_2.feed_forward import FeedForward
from mlx_video.utils import rms_norm

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -21,10 +21,10 @@ from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding(

View File

@@ -6,7 +6,7 @@ to latent space, which can then be used to condition video generation.
"""
import mlx.core as mx
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder

View File

@@ -6,7 +6,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm

View File

@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module):

View File

@@ -7,15 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx.video_vae.resnet import (
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx.video_vae.sampling import (
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
SpaceToDepthDownsample,
)
@@ -229,7 +229,7 @@ class VideoEncoder(nn.Module):
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
from mlx_video.models.ltx.config import VideoEncoderModelConfig
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
@@ -409,7 +409,7 @@ class VideoEncoder(nn.Module):
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx.config import VideoEncoderModelConfig
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config
config_path = model_path / "config.json"