Refactor LTX-2 model structure
This commit is contained in:
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
199
mlx_video/models/ltx_2/conditioning/latent.py
Normal file
199
mlx_video/models/ltx_2/conditioning/latent.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Latent-based conditioning for I2V (Image-to-Video) generation.
|
||||
|
||||
This module provides conditioning that injects encoded image latents into
|
||||
the video generation process at specific frame positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
class VideoConditionByLatentIndex:
    """Condition video generation by injecting latents at a specific frame index.

    The latent frames starting at ``frame_idx`` are replaced with the
    conditioning latent. ``strength`` controls how strongly those frames are
    pinned to the conditioning: ``apply_conditioning`` gives them a denoise
    mask of ``1.0 - strength``.

    Args:
        latent: Encoded image latent of shape (B, C, F_cond, H, W); F_cond is
            typically 1 for a single conditioning image.
        frame_idx: Frame index to condition (0 = first frame)
        strength: Conditioning strength. 1.0 yields a denoise mask of 0.0
            (the conditioned frames are kept as-is); 0.0 yields a denoise
            mask of 1.0 (the conditioned frames are fully denoised).
    """

    latent: mx.array
    frame_idx: int = 0
    strength: float = 1.0

    def get_num_latent_frames(self) -> int:
        """Return the size of the frame axis (axis 2) of the conditioning latent."""
        return self.latent.shape[2]
|
||||
|
||||
|
||||
@dataclass
class LatentState:
    """Bundle of arrays tracked while denoising a conditioned video latent.

    Attributes:
        latent: Current noisy latent (B, C, F, H, W)
        clean_latent: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
            1.0 = full denoise, 0.0 = keep clean
    """

    latent: mx.array
    clean_latent: mx.array
    denoise_mask: mx.array

    def clone(self) -> "LatentState":
        """Return a new LatentState holding the same three arrays.

        The copy is shallow: the arrays themselves are shared with ``self``,
        only the container object is new.
        """
        return LatentState(self.latent, self.clean_latent, self.denoise_mask)
|
||||
|
||||
|
||||
def create_initial_state(
    shape: Tuple[int, ...],
    seed: Optional[int] = None,
    noise_scale: float = 1.0,
) -> LatentState:
    """Build the starting LatentState: scaled Gaussian noise, nothing conditioned.

    Args:
        shape: Shape of latent (B, C, F, H, W)
        seed: Optional random seed; when given it is passed to
            ``mx.random.seed`` before the noise is drawn
        noise_scale: Scale for initial noise (sigma)

    Returns:
        LatentState whose latent is ``noise * noise_scale``, whose clean
        latent is all zeros, and whose denoise mask is all ones
    """
    if seed is not None:
        mx.random.seed(seed)

    scaled_noise = mx.random.normal(shape) * noise_scale
    empty_clean = mx.zeros(shape)
    # One mask value per (batch, frame); all ones means every frame is fully denoised.
    full_mask = mx.ones((shape[0], 1, shape[2], 1, 1))

    return LatentState(
        latent=scaled_noise,
        clean_latent=empty_clean,
        denoise_mask=full_mask,
    )
|
||||
|
||||
|
||||
def apply_conditioning(
    state: LatentState,
    conditionings: List[VideoConditionByLatentIndex],
) -> LatentState:
    """Apply conditioning items to a latent state.

    For each item, the frames ``[frame_idx, frame_idx + F_cond)`` (clipped to
    the number of frames in the state) of both ``latent`` and ``clean_latent``
    are replaced by the conditioning latent, and the matching entries of
    ``denoise_mask`` are set to ``1.0 - strength``. Items are applied in list
    order, so a later item overrides an earlier one where they overlap.

    Args:
        state: Current latent state; it is cloned, not modified
        conditionings: List of conditioning items to apply

    Returns:
        Updated LatentState with conditioning applied

    Raises:
        ValueError: If a conditioning latent's (C, H, W) differs from the
            state's, or if ``frame_idx`` is negative or not less than the
            number of frames in the state.
    """
    state = state.clone()
    dtype = state.latent.dtype
    b, c, f, h, w = state.latent.shape

    for cond in conditionings:
        cond_latent = cond.latent
        frame_idx = cond.frame_idx

        # Validate shapes
        _, cond_c, cond_f, cond_h, cond_w = cond_latent.shape
        if (cond_c, cond_h, cond_w) != (c, h, w):
            raise ValueError(
                f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
                f"does not match target shape ({c}, {h}, {w})"
            )

        # A negative index would otherwise match no frame and silently drop
        # the conditioning, so reject it together with too-large indices.
        if not 0 <= frame_idx < f:
            raise ValueError(
                f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
            )

        # Conditioning frames that would run past the last frame are dropped.
        end_idx = min(frame_idx + cond_f, f)
        num_frames = end_idx - frame_idx
        if num_frames <= 0:
            # Zero-frame conditioning latent: nothing to inject.
            continue

        cond_frames = cond_latent[:, :, :num_frames]
        # 1.0 - strength: higher strength means less denoising on these frames.
        cond_mask = mx.full((b, 1, num_frames, 1, 1), 1.0 - cond.strength, dtype=dtype)

        # New arrays are built by concatenation instead of slice assignment so
        # the arrays shared with the caller's state (clone() is shallow) are
        # never written to.
        state.latent = _splice_frames(state.latent, cond_frames, frame_idx, end_idx)
        state.clean_latent = _splice_frames(
            state.clean_latent, cond_frames, frame_idx, end_idx
        )
        state.denoise_mask = _splice_frames(
            state.denoise_mask, cond_mask, frame_idx, end_idx
        )

    return state


def _splice_frames(base: mx.array, patch: mx.array, start: int, stop: int) -> mx.array:
    """Return ``base`` with frames ``[start, stop)`` on axis 2 replaced by ``patch``."""
    parts = []
    if start > 0:
        parts.append(base[:, :, :start])
    parts.append(patch)
    if stop < base.shape[2]:
        parts.append(base[:, :, stop:])
    return mx.concatenate(parts, axis=2)
|
||||
|
||||
|
||||
def apply_denoise_mask(
    denoised: mx.array,
    clean: mx.array,
    denoise_mask: mx.array,
) -> mx.array:
    """Mix the denoised latent with the clean latent, weighted by the mask.

    Computes ``denoised * denoise_mask + clean * (1 - denoise_mask)``.

    Args:
        denoised: Denoised latent (B, C, F, H, W)
        clean: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean

    Returns:
        Blended latent
    """
    # The constant 1 is created in the dtype of `denoised`.
    keep_weight = mx.array(1.0, dtype=denoised.dtype) - denoise_mask
    from_denoised = denoised * denoise_mask
    from_clean = clean * keep_weight
    return from_denoised + from_clean
|
||||
|
||||
|
||||
def add_noise_with_state(
    state: LatentState,
    noise_scale: float,
) -> LatentState:
    """Noise the latent of a state, attenuated per frame by its denoise mask.

    Each element is mixed as ``noise * s + latent * (1 - s)`` with
    ``s = noise_scale * denoise_mask``. Where the mask is 0 the latent is
    left unchanged; where it is 1 the full ``noise_scale`` is used.

    Args:
        state: Current latent state; it is cloned, not modified
        noise_scale: Scale for noise (sigma)

    Returns:
        Updated state with noise added
    """
    noised = state.clone()

    fresh_noise = mx.random.normal(noised.latent.shape)

    # Per-frame sigma: the global sigma scaled down by the denoise mask, so
    # conditioned frames (mask < 1) receive proportionally less noise.
    sigma = noise_scale * noised.denoise_mask
    unit = mx.array(1.0, dtype=noised.latent.dtype)

    noised.latent = fresh_noise * sigma + noised.latent * (unit - sigma)
    return noised
|
||||
Reference in New Issue
Block a user