Refactor LTX-2 model structure
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
|
||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
|
||||
from mlx_video.models.ltx.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
)
|
||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
@@ -1,8 +0,0 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
8
mlx_video/models/ltx_2/__init__.py
Normal file
8
mlx_video/models/ltx_2/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
)
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
@@ -6,8 +6,8 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||
from mlx_video.models.ltx_2.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
@@ -168,7 +168,7 @@ class AudioEncoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx.config import AudioEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
@@ -380,7 +380,7 @@ class AudioDecoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
199
mlx_video/models/ltx_2/conditioning/latent.py
Normal file
199
mlx_video/models/ltx_2/conditioning/latent.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Latent-based conditioning for I2V (Image-to-Video) generation.
|
||||
|
||||
This module provides conditioning that injects encoded image latents into
|
||||
the video generation process at specific frame positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
class VideoConditionByLatentIndex:
    """Condition video generation by injecting latents at a specific frame index.

    ``apply_conditioning`` overwrites the latent frames starting at
    ``frame_idx`` with ``latent`` and sets the denoise mask of those frames
    to ``1.0 - strength``.

    Args:
        latent: Encoded conditioning latent of shape (B, C, F_cond, H, W);
            F_cond is 1 for a single encoded image.
        frame_idx: First latent frame index to condition (0 = first frame).
        strength: Conditioning strength. 1.0 yields a denoise mask of 0.0
            (the conditioned frames are kept as the clean latent); 0.0 yields
            a mask of 1.0 (the frames are fully denoised like any other).
    """

    latent: mx.array
    frame_idx: int = 0
    strength: float = 1.0

    def get_num_latent_frames(self) -> int:
        """Return the number of latent frames (axis 2) in the conditioning."""
        return self.latent.shape[2]
|
||||
|
||||
|
||||
@dataclass
class LatentState:
    """Bundle of arrays tracked while denoising with conditioning.

    Attributes:
        latent: Current noisy latent (B, C, F, H, W)
        clean_latent: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
            1.0 = full denoise, 0.0 = keep clean
    """

    latent: mx.array
    clean_latent: mx.array
    denoise_mask: mx.array

    def clone(self) -> "LatentState":
        """Return a new ``LatentState`` that references the same three arrays."""
        # Positional order matches the field declaration order above.
        members = (self.latent, self.clean_latent, self.denoise_mask)
        return LatentState(*members)
|
||||
|
||||
|
||||
def create_initial_state(
    shape: Tuple[int, ...],
    seed: Optional[int] = None,
    noise_scale: float = 1.0,
) -> LatentState:
    """Build the starting state for sampling: scaled Gaussian noise, no conditioning.

    Args:
        shape: Shape of latent (B, C, F, H, W)
        seed: Optional random seed; when given, the global MLX RNG is seeded
        noise_scale: Scale for initial noise (sigma)

    Returns:
        LatentState whose latent is ``noise * noise_scale``, whose clean
        latent is all zeros and whose denoise mask is all ones.
    """
    if seed is not None:
        mx.random.seed(seed)

    scaled_noise = mx.random.normal(shape) * noise_scale

    # One mask value per (batch, frame); ones mean "denoise everything".
    batch, num_frames = shape[0], shape[2]
    full_denoise = mx.ones((batch, 1, num_frames, 1, 1))

    return LatentState(
        latent=scaled_noise,
        clean_latent=mx.zeros(shape),
        denoise_mask=full_denoise,
    )
|
||||
|
||||
|
||||
def apply_conditioning(
    state: LatentState,
    conditionings: List[VideoConditionByLatentIndex],
) -> LatentState:
    """Apply conditioning items to a latent state.

    For every conditioning item, the frames ``[frame_idx, frame_idx + F_cond)``
    (clipped to the number of frames in the state) are overwritten in both
    ``latent`` and ``clean_latent`` with the conditioning latent, and their
    ``denoise_mask`` is set to ``1.0 - strength``. Items are applied in list
    order, so later items win where ranges overlap.

    Args:
        state: Current latent state (not modified; a clone is updated)
        conditionings: List of conditioning items to apply

    Returns:
        Updated LatentState with conditioning applied

    Raises:
        ValueError: If a conditioning latent's (C, H, W) differs from the
            state's, or if ``frame_idx`` is negative or >= the frame count.
    """
    state = state.clone()
    dtype = state.latent.dtype
    b, c, f, h, w = state.latent.shape

    for cond in conditionings:
        cond_latent = cond.latent
        frame_idx = cond.frame_idx

        # Validate shapes
        _, cond_c, cond_f, cond_h, cond_w = cond_latent.shape
        if (cond_c, cond_h, cond_w) != (c, h, w):
            raise ValueError(
                f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
                f"does not match target shape ({c}, {h}, {w})"
            )

        # Negative indices are rejected too: they would otherwise be silently
        # ignored or applied to the wrong frames.
        if not 0 <= frame_idx < f:
            raise ValueError(
                f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
            )

        end_idx = min(frame_idx + cond_f, f)
        num_frames = end_idx - frame_idx
        if num_frames <= 0:
            # Conditioning latent has no frames; nothing to replace.
            continue

        cond_frames = cond_latent[:, :, :num_frames]
        # Mask of 1.0 - strength: higher strength means less denoising.
        cond_mask = mx.full((b, 1, num_frames, 1, 1), 1.0 - cond.strength, dtype=dtype)

        def splice(original: mx.array, replacement: mx.array) -> mx.array:
            """Replace frames [frame_idx, end_idx) of ``original`` along axis 2."""
            parts = []
            if frame_idx > 0:
                parts.append(original[:, :, :frame_idx])
            parts.append(replacement)
            if end_idx < f:
                parts.append(original[:, :, end_idx:])
            return mx.concatenate(parts, axis=2)

        state.latent = splice(state.latent, cond_frames)
        state.clean_latent = splice(state.clean_latent, cond_frames)
        state.denoise_mask = splice(state.denoise_mask, cond_mask)

    return state
|
||||
|
||||
|
||||
def apply_denoise_mask(
    denoised: mx.array,
    clean: mx.array,
    denoise_mask: mx.array,
) -> mx.array:
    """Linearly interpolate between the clean latent and the denoised output.

    Args:
        denoised: Denoised latent (B, C, F, H, W)
        clean: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean

    Returns:
        ``denoised * mask + clean * (1 - mask)``
    """
    # The constant 1 is created in the dtype of ``denoised``.
    keep_clean = mx.array(1.0, dtype=denoised.dtype) - denoise_mask
    return denoised * denoise_mask + clean * keep_clean
|
||||
|
||||
|
||||
def add_noise_with_state(
    state: LatentState,
    noise_scale: float,
) -> LatentState:
    """Mix fresh Gaussian noise into a cloned state, scaled by the denoise mask.

    The per-element noise level is ``noise_scale * denoise_mask``, so frames
    with mask 0 receive no noise and frames with mask 1 receive the full
    ``noise_scale``.

    Args:
        state: Current latent state (not modified)
        noise_scale: Scale for noise (sigma)

    Returns:
        Clone of ``state`` whose latent is
        ``noise * sigma_eff + latent * (1 - sigma_eff)``
    """
    noisy_state = state.clone()
    fresh_noise = mx.random.normal(noisy_state.latent.shape)

    # Effective sigma per element: conditioned frames (mask < 1) get less noise.
    sigma_eff = noise_scale * noisy_state.denoise_mask
    keep_fraction = mx.array(1.0, dtype=noisy_state.latent.dtype) - sigma_eff

    noisy_state.latent = fresh_noise * sigma_eff + noisy_state.latent * keep_fraction
    return noisy_state
|
||||
@@ -355,9 +355,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
])
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
@@ -26,14 +26,14 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
|
||||
|
||||
Usage:
|
||||
# From HF repo ID
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||
|
||||
# From local folder containing the monolithic safetensors
|
||||
python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||
|
||||
# From a direct safetensors file path
|
||||
python -m mlx_video.models.ltx.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
||||
"""
|
||||
|
||||
import argparse
|
||||
2566
mlx_video/models/ltx_2/generate.py
Normal file
2566
mlx_video/models/ltx_2/generate.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,16 +3,16 @@ from typing import List, Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
from mlx_video.models.ltx.config import (
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
LTXRopeType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx.transformer import (
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx_2.transformer import (
|
||||
BasicAVTransformerBlock,
|
||||
Modality,
|
||||
TransformerArgs,
|
||||
165
mlx_video/models/ltx_2/postprocess.py
Normal file
165
mlx_video/models/ltx_2/postprocess.py
Normal file
@@ -0,0 +1,165 @@
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
    """Edge-preserving smoothing used to suppress grid artifacts.

    Uses ``cv2.bilateralFilter`` when OpenCV can be imported; otherwise falls
    back to ``gaussian_blur`` with a 3x3 kernel.

    Args:
        image: Input image as uint8 numpy array (H, W, C)
        d: Diameter of each pixel neighborhood
        sigma_color: Filter sigma in the color space
        sigma_space: Filter sigma in the coordinate space

    Returns:
        Filtered image
    """
    try:
        import cv2

        filtered = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
    except ImportError:
        # OpenCV unavailable: approximate with a plain blur.
        filtered = gaussian_blur(image, kernel_size=3)
    return filtered
|
||||
|
||||
|
||||
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Apply Gaussian blur (or a box blur when OpenCV is unavailable).

    Args:
        image: Input image as uint8 numpy array, (H, W, C) or (H, W)
        kernel_size: Size of the kernel along H and W (must be odd for cv2)

    Returns:
        Blurred image
    """
    try:
        import cv2
    except ImportError:
        cv2 = None

    if cv2 is not None:
        return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

    # Simple box blur fallback. Only the two leading (spatial) axes are
    # filtered; any trailing axes (e.g. channels) get a window of 1 so the
    # fallback accepts grayscale (H, W) input like the cv2 path does.
    from scipy.ndimage import uniform_filter

    window = (kernel_size, kernel_size) + (1,) * (image.ndim - 2)
    return uniform_filter(image, size=window).astype(np.uint8)
|
||||
|
||||
|
||||
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
    """Sharpen an image by subtracting a blurred copy (unsharp masking).

    When OpenCV cannot be imported the input is returned unchanged.

    Args:
        image: Input image as uint8 numpy array
        kernel_size: Size of the Gaussian kernel
        sigma: Gaussian sigma
        amount: Strength of sharpening

    Returns:
        Sharpened image
    """
    output = image
    try:
        import cv2

        soft = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
        # (1 + amount) * image - amount * blurred
        boosted = cv2.addWeighted(image, 1 + amount, soft, -amount, 0)
        output = np.clip(boosted, 0, 255).astype(np.uint8)
    except ImportError:
        pass
    return output
|
||||
|
||||
|
||||
def reduce_grid_artifacts(
    video: np.ndarray,
    method: str = "bilateral",
    strength: float = 1.0,
) -> np.ndarray:
    """Reduce grid artifacts in video frames.

    Each frame is filtered independently with the selected method. When
    ``strength`` is below 1.0 the filtered frames are additionally blended
    with the originals (``strength`` is the weight of the filtered result).

    Args:
        video: Video as numpy array (F, H, W, C) uint8
        method: "bilateral", "gaussian", or "frequency"
        strength: How strong to apply the filter (0-1)

    Returns:
        Processed video. An empty video (F == 0) is returned as a copy.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Resolve the per-frame filter first so an unknown method is reported
    # even when there are no frames to process.
    if method == "bilateral":
        d = max(3, int(5 * strength))
        sigma = 50 + 50 * strength

        def filter_frame(frame: np.ndarray) -> np.ndarray:
            return bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)

    elif method == "gaussian":
        kernel_size = max(3, int(3 + 4 * strength))
        if kernel_size % 2 == 0:
            kernel_size += 1  # the Gaussian kernel size must be odd

        def filter_frame(frame: np.ndarray) -> np.ndarray:
            return gaussian_blur(frame, kernel_size=kernel_size)

    elif method == "frequency":

        def filter_frame(frame: np.ndarray) -> np.ndarray:
            return remove_grid_frequency(frame, grid_size=8)

    else:
        raise ValueError(f"Unknown method: {method}")

    if len(video) == 0:
        # np.stack rejects an empty list; there is nothing to filter anyway.
        return np.array(video, copy=True)

    processed = np.stack([filter_frame(frame) for frame in video])

    if strength < 1.0:
        # Blend with original based on strength. Round (instead of truncating)
        # and clip before the uint8 cast so the blend neither darkens the
        # image by up to one level nor wraps around for out-of-range strength.
        alpha = strength
        blended = alpha * processed + (1 - alpha) * video
        processed = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    return processed
|
||||
|
||||
|
||||
def _grid_notch_mask(h: int, w: int, grid_size: int) -> np.ndarray:
    """Build the (h, w) multiplicative notch mask for a centered (fftshifted) spectrum."""
    cy, cx = h // 2, w // 2
    mask = np.ones((h, w), dtype=np.float32)

    # Spacing (in frequency bins) between harmonics of the grid periodicity.
    freq_y = h // grid_size
    freq_x = w // grid_size

    # Harmonic positions up to order 2 in each direction. De-duplicated (order
    # preserved) because a spacing of 0 collapses several harmonics onto one
    # bin; the DC bin itself is never a notch centre.
    centres = dict.fromkeys(
        (cy + fy * freq_y, cx + fx * freq_x)
        for fy in range(-2, 3)
        for fx in range(-2, 3)
    )
    centres.pop((cy, cx), None)

    for y_pos, x_pos in centres:
        if not (0 <= y_pos < h and 0 <= x_pos < w):
            continue
        # Linear ramp around the notch: 0 at the centre, rising with distance.
        for dy in range(-2, 3):
            for dx in range(-2, 3):
                yy, xx = y_pos + dy, x_pos + dx
                if 0 <= yy < h and 0 <= xx < w:
                    dist = np.sqrt(dy**2 + dx**2)
                    mask[yy, xx] *= min(1.0, dist / 3.0)

    # Never attenuate DC: it carries the mean brightness of the frame, and a
    # notch only a bin or two away would otherwise darken small frames.
    if mask.size:
        mask[cy, cx] = 1.0
    return mask


def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
    """Remove grid-frequency components using FFT.

    Each channel is transformed with a 2-D FFT, multiplied by a notch mask
    that attenuates the first two harmonics of the grid periodicity, and
    transformed back. The DC component is always preserved.

    Args:
        frame: Input frame (H, W, C) uint8
        grid_size: Expected grid periodicity in pixels

    Returns:
        Filtered frame with the same shape and dtype as ``frame``
    """
    result = np.zeros_like(frame)
    h, w = frame.shape[:2]

    # The mask depends only on the frame size, so build it once for all channels.
    mask = _grid_notch_mask(h, w, grid_size)

    for c in range(frame.shape[2]):
        channel = frame[:, :, c].astype(np.float32)

        # FFT with the zero frequency shifted to the centre of the spectrum
        fft_shifted = np.fft.fftshift(np.fft.fft2(channel))

        # Apply mask and inverse FFT
        channel_filtered = np.fft.ifft2(np.fft.ifftshift(fft_shifted * mask)).real

        # Round before the cast so float error (e.g. 99.99999) does not
        # truncate to the next lower level.
        result[:, :, c] = np.clip(np.rint(channel_filtered), 0, 255).astype(np.uint8)

    return result
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
181
mlx_video/models/ltx_2/samplers.py
Normal file
181
mlx_video/models/ltx_2/samplers.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Second-order res_2s sampler for diffusion models.
|
||||
|
||||
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
|
||||
noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
    """Compute phi_j(z) where z = -h (negative step size in log-space).

    phi_1(z) = (e^z - 1) / z
    phi_2(z) = (e^z - 1 - z) / z^2
    phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j

    For small |z| the closed form subtracts nearly equal numbers and loses
    most of its precision, so the equivalent Taylor series
    phi_j(z) = sum_{k>=0} z^k / (k + j)! is used instead.

    Args:
        j: Order of the phi function (non-negative integer).
        neg_h: The argument z.

    Returns:
        phi_j(neg_h); exactly 1 / j! at neg_h == 0.
    """
    if abs(neg_h) < 0.1:
        # term_k = z^k / (j + k)!, built incrementally from term_0 = 1 / j!.
        # With |z| < 0.1, 16 terms are far below double precision.
        term = 1.0 / math.factorial(j)
        total = term
        for k in range(1, 16):
            term *= neg_h / (j + k)
            total += term
        return total

    remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
    return (math.exp(neg_h) - remainder) / (neg_h**j)
|
||||
|
||||
|
||||
def get_res2s_coefficients(
    h: float,
    phi_cache: dict,
    c2: float = 0.5,
) -> tuple[float, float, float]:
    """Return the res_2s Runge-Kutta coefficients for one step.

    Args:
        h: Step size in log-space = log(sigma / sigma_next)
        phi_cache: Dictionary keyed by ``(j, z)`` used to memoize phi values;
            missing entries are computed with ``phi`` and stored.
        c2: Substep position (default 0.5 = midpoint)

    Returns:
        (a21, b1, b2): RK coefficients, where
        a21 = c2 * phi_1(-h * c2), b2 = phi_2(-h) / c2, b1 = phi_1(-h) - b2.
    """

    def cached_phi(order: int, z: float) -> float:
        key = (order, z)
        if key not in phi_cache:
            phi_cache[key] = phi(order, z)
        return phi_cache[key]

    z_substep = -h * c2
    z_full = -h

    a21 = c2 * cached_phi(1, z_substep)
    b2 = cached_phi(2, z_full) / c2
    b1 = cached_phi(1, z_full) - b2

    return a21, b1, b2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_sde_coeff(
    sigma_next: float,
) -> tuple[float, float, float]:
    """Derive the coefficients used by the SDE noise-injection step.

    Uses sigma_up = sigma_next * 0.5 (hardcoded in PyTorch Res2sDiffusionStep).

    Returns:
        (alpha_ratio, sigma_down, sigma_up)
    """
    # Half of sigma_next, capped at 0.9999 * sigma_next so the square root
    # below stays real.
    sigma_up = min(sigma_next * 0.5, sigma_next * 0.9999)

    residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
    # Signal level is 1 - sigma_next (sigma_max = 1).
    alpha_ratio = (1.0 - sigma_next) + residual

    sigma_down = sigma_next if alpha_ratio == 0 else residual / alpha_ratio

    # Replace NaN results with neutral values.
    if math.isnan(sigma_up):
        sigma_up = 0.0
    if math.isnan(sigma_down):
        sigma_down = sigma_next
    if math.isnan(alpha_ratio):
        alpha_ratio = 1.0

    return alpha_ratio, sigma_down, sigma_up
|
||||
|
||||
|
||||
def sde_noise_step(
    sample: mx.array,
    denoised_sample: mx.array,
    sigma: float,
    sigma_next: float,
    noise: mx.array,
) -> mx.array:
    """Move a sample from ``sigma`` to ``sigma_next`` with stochastic noise injection.

    Args:
        sample: Current sample (anchor point)
        denoised_sample: Denoised prediction at this step
        sigma: Current noise level
        sigma_next: Next noise level
        noise: Pre-generated noise tensor (channel-wise normalized)

    Returns:
        Noised sample at sigma_next (float32), or ``denoised_sample``
        unchanged when ``sigma_next`` or the derived ``sigma_up`` is 0.
    """
    alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)

    # No noise to inject: hand back the denoised prediction as-is.
    if sigma_next == 0 or sigma_up == 0:
        return denoised_sample

    # All arithmetic below is carried out in float32.
    x = sample.astype(mx.float32)
    x0 = denoised_sample.astype(mx.float32)

    # Epsilon estimate from the anchor/denoised pair, and the matching x0.
    eps = (x - x0) / (sigma - sigma_next)
    x0_next = x - sigma * eps

    deterministic = alpha_ratio * (x0_next + sigma_down * eps)
    stochastic = sigma_up * noise.astype(mx.float32)
    return deterministic + stochastic
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
    """Standardize ``x`` to zero mean and unit variance over its last two axes.

    Operates on the last 2 dimensions (spatial H, W or time, freq); a small
    epsilon (1e-8) is added to the variance before the square root.
    """
    trailing_axes = (-2, -1)
    centered = x - mx.mean(x, axis=trailing_axes, keepdims=True)
    variance = mx.mean(centered * centered, axis=trailing_axes, keepdims=True)
    return centered / mx.sqrt(variance + 1e-8)
|
||||
|
||||
|
||||
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
    """Draw Gaussian noise and normalize it globally, then channel-wise.

    PyTorch uses float64; we use float32 (MLX doesn't support float64).
    The channel-wise normalization is the key quality-affecting step.

    Args:
        shape: Shape of the noise tensor
        key: MLX random key for deterministic generation

    Returns:
        Channel-wise normalized noise in float32
    """
    raw = mx.random.normal(shape, dtype=mx.float32, key=key)

    # Global step: subtract the overall mean and divide by the root mean
    # square of the raw (un-centered) draw.
    global_rms = mx.sqrt(mx.mean(raw * raw))
    globally_scaled = (raw - mx.mean(raw)) / (global_rms + 1e-8)

    return channelwise_normalize(globally_scaled)
|
||||
@@ -15,7 +15,7 @@ from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
||||
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
@@ -4,9 +4,9 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx.attention import Attention
|
||||
from mlx_video.models.ltx.feed_forward import FeedForward
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.attention import Attention
|
||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
|
||||
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
@@ -21,10 +21,10 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -6,7 +6,7 @@ to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
@@ -7,15 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx.video_vae.resnet import (
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx.video_vae.sampling import (
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
SpaceToDepthDownsample,
|
||||
)
|
||||
@@ -229,7 +229,7 @@ class VideoEncoder(nn.Module):
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
@@ -409,7 +409,7 @@ class VideoEncoder(nn.Module):
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
Reference in New Issue
Block a user