Merge pull request #12 from Blaizzy/pc/add-vae-tiling

Add VAE Tiling + BFloat16 Support for Memory-Efficient Video Generation
This commit is contained in:
Prince Canuma
2026-01-21 15:41:46 +01:00
committed by GitHub
12 changed files with 1062 additions and 113 deletions

View File

@@ -95,6 +95,7 @@ def apply_conditioning(
Updated LatentState with conditioning applied Updated LatentState with conditioning applied
""" """
state = state.clone() state = state.clone()
dtype = state.latent.dtype
b, c, f, h, w = state.latent.shape b, c, f, h, w = state.latent.shape
for cond in conditionings: for cond in conditionings:
@@ -132,7 +133,7 @@ def apply_conditioning(
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
# Set mask: 1.0 - strength means less denoising for conditioned frames # Set mask: 1.0 - strength means less denoising for conditioned frames
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength)) mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
else: else:
# Keep original # Keep original
latent_list.append(state.latent[:, :, i:i+1]) latent_list.append(state.latent[:, :, i:i+1])
@@ -161,7 +162,8 @@ def apply_denoise_mask(
Returns: Returns:
Blended latent Blended latent
""" """
return denoised * denoise_mask + clean * (1.0 - denoise_mask) one = mx.array(1.0, dtype=denoised.dtype)
return denoised * denoise_mask + clean * (one - denoise_mask)
def add_noise_with_state( def add_noise_with_state(
@@ -191,6 +193,7 @@ def add_noise_with_state(
# But we scale sigma by the mask for conditioned regions # But we scale sigma by the mask for conditioned regions
effective_scale = noise_scale * state.denoise_mask effective_scale = noise_scale * state.denoise_mask
state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale) one = mx.array(1.0, dtype=state.latent.dtype)
state.latent = noise * effective_scale + state.latent * (one - effective_scale)
return state return state

View File

@@ -1,9 +1,10 @@
import argparse import argparse
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
@@ -27,6 +28,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
@@ -109,6 +111,7 @@ def create_position_grid(
# Convert temporal to time in seconds by dividing by fps # Convert temporal to time in seconds by dividing by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
# Always return float32 for RoPE precision - bfloat16 causes quality degradation
return mx.array(pixel_coords, dtype=mx.float32) return mx.array(pixel_coords, dtype=mx.float32)
@@ -136,6 +139,7 @@ def denoise(
Denoised latent tensor Denoised latent tensor
""" """
# If state is provided, use its latent (which may have conditioning applied) # If state is provided, use its latent (which may have conditioning applied)
dtype = latents.dtype
if state is not None: if state is not None:
latents = state.latent latents = state.latent
@@ -153,11 +157,11 @@ def denoise(
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask # Per-token timesteps: sigma * mask (preserve dtype)
timesteps = sigma * denoise_mask_flat timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else: else:
# All tokens get the same timestep # All tokens get the same timestep (use latent dtype)
timesteps = mx.full((b, num_tokens), sigma) timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
video_modality = Modality( video_modality = Modality(
latent=latents_flat, latent=latents_flat,
@@ -180,8 +184,11 @@ def denoise(
mx.eval(denoised) mx.eval(denoised)
# Euler step (preserve dtype by converting Python floats to arrays)
if sigma_next > 0: if sigma_next > 0:
latents = denoised + sigma_next * (latents - denoised) / sigma sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
else: else:
latents = denoised latents = denoised
mx.eval(latents) mx.eval(latents)
@@ -207,6 +214,7 @@ def generate_video(
image: Optional[str] = None, image: Optional[str] = None,
image_strength: float = 1.0, image_strength: float = 1.0,
image_frame_idx: int = 0, image_frame_idx: int = 0,
tiling: str = "auto",
): ):
"""Generate video from text prompt, optionally conditioned on an image. """Generate video from text prompt, optionally conditioned on an image.
@@ -228,6 +236,14 @@ def generate_video(
image: Path to conditioning image for I2V (Image-to-Video) image: Path to conditioning image for I2V (Image-to-Video)
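image_strength: Conditioning strength; the conditioned frame's denoise mask is 1.0 - strength (1.0 = frame held to the image latent, 0.0 = frame denoised like unconditioned frames) image_strength: Conditioning strength; the conditioned frame's denoise mask is 1.0 - strength (1.0 = frame held to the image latent, 0.0 = frame denoised like unconditioned frames)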
image_frame_idx: Frame index to condition (0 = first frame) image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine based on video size (default)
- "none": Disable tiling
- "default": 512px spatial, 64 frame temporal
- "aggressive": 256px spatial, 32 frame temporal (lowest memory)
- "conservative": 768px spatial, 96 frame temporal (faster)
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
""" """
start_time = time.time() start_time = time.time()
@@ -273,6 +289,7 @@ def generate_video(
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
model_dtype = text_embeddings.dtype # bfloat16 from text encoder
mx.eval(text_embeddings) mx.eval(text_embeddings)
del text_encoder del text_encoder
@@ -282,6 +299,8 @@ def generate_video(
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig( config = LTXModelConfig(
model_type=LTXModelType.VideoOnly, model_type=LTXModelType.VideoOnly,
@@ -300,7 +319,7 @@ def generate_video(
timestep_scale_multiplier=1000, timestep_scale_multiplier=1000,
) )
transformer = LTXModel(config) transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False) transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters()) mx.eval(transformer.parameters())
@@ -313,15 +332,15 @@ def generate_video(
mx.eval(vae_encoder.parameters()) mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution) # Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2) input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor) stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent) mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}") print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution) # Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width) input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor) stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent) mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}") print(f" Stage 2 image latent: {stage2_image_latent.shape}")
@@ -333,6 +352,7 @@ def generate_video(
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed) mx.random.seed(seed)
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
@@ -343,24 +363,26 @@ def generate_video(
# Create initial state with zeros # Create initial state with zeros
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
state1 = LatentState( state1 = LatentState(
latent=mx.zeros(latent_shape), latent=mx.zeros(latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(latent_shape), clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent, latent=stage1_image_latent,
frame_idx=image_frame_idx, frame_idx=image_frame_idx,
strength=image_strength, strength=image_strength,
) )
state1 = apply_conditioning(state1, [conditioning]) state1 = apply_conditioning(state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 1, noise_scale = 1.0 (first sigma) # For Stage 1, noise_scale = 1.0 (first sigma)
noise = mx.random.normal(latent_shape) noise = mx.random.normal(latent_shape, dtype=model_dtype)
noise_scale = STAGE_1_SIGMAS[0] # 1.0 noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = state1.denoise_mask * noise_scale scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState( state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask), latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent, clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask, denoise_mask=state1.denoise_mask,
) )
@@ -368,7 +390,7 @@ def generate_video(
mx.eval(latents) mx.eval(latents)
else: else:
# T2V: just use random noise # T2V: just use random noise
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
@@ -391,6 +413,7 @@ def generate_video(
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
@@ -401,7 +424,7 @@ def generate_video(
state2 = LatentState( state2 = LatentState(
latent=latents, # Start with upscaled latent latent=latents, # Start with upscaled latent
clean_latent=mx.zeros_like(latents), clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent, latent=stage2_image_latent,
@@ -413,11 +436,11 @@ def generate_video(
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 2, noise_scale = stage_2_sigmas[0] # For Stage 2, noise_scale = stage_2_sigmas[0]
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise # Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
noise = mx.random.normal(latents.shape) noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState( state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask), latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent, clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask, denoise_mask=state2.denoise_mask,
) )
@@ -425,9 +448,10 @@ def generate_video(
mx.eval(latents) mx.eval(latents)
else: else:
# T2V: add noise to all frames for refinement # T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape) one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
latents = noise * noise_scale + latents * (1 - noise_scale) noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
@@ -435,9 +459,36 @@ def generate_video(
del transformer del transformer
mx.clear_cache() mx.clear_cache()
# Decode to video # Decode to video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
video = vae_decoder(latents)
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(latents)
mx.eval(video) mx.eval(video)
mx.clear_cache() mx.clear_cache()
@@ -594,6 +645,15 @@ Examples:
default=0, default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)" help="Frame index to condition for I2V (0 = first frame, default: 0)"
) )
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on video size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
)
args = parser.parse_args() args = parser.parse_args()
generate_video( generate_video(

View File

@@ -3,7 +3,7 @@
import argparse import argparse
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional, List from typing import Optional
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@@ -30,6 +30,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_w
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
@@ -163,6 +164,7 @@ def denoise_av(
Returns: Returns:
Tuple of (video_latents, audio_latents) Tuple of (video_latents, audio_latents)
""" """
dtype = video_latents.dtype
# If video state is provided, use its latent # If video state is provided, use its latent
if video_state is not None: if video_state is not None:
video_latents = video_state.latent video_latents = video_state.latent
@@ -188,10 +190,10 @@ def denoise_av(
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
# Per-token timesteps: sigma * mask # Per-token timesteps: sigma * mask
video_timesteps = sigma * denoise_mask_flat video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else: else:
# All tokens get the same timestep # All tokens get the same timestep
video_timesteps = mx.full((b, num_video_tokens), sigma) video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
video_modality = Modality( video_modality = Modality(
latent=video_flat, latent=video_flat,
@@ -204,7 +206,7 @@ def denoise_av(
audio_modality = Modality( audio_modality = Modality(
latent=audio_flat, latent=audio_flat,
timesteps=mx.full((ab, at), sigma), timesteps=mx.full((ab, at), sigma, dtype=dtype),
positions=audio_positions, positions=audio_positions,
context=audio_embeddings, context=audio_embeddings,
context_mask=None, context_mask=None,
@@ -229,10 +231,12 @@ def denoise_av(
mx.eval(video_denoised, audio_denoised) mx.eval(video_denoised, audio_denoised)
# Euler step # Euler step - use dtype-preserving arrays to avoid float32 promotion
if sigma_next > 0: if sigma_next > 0:
video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma sigma_next_arr = mx.array(sigma_next, dtype=dtype)
audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma sigma_arr = mx.array(sigma, dtype=dtype)
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else: else:
video_latents = video_denoised video_latents = video_denoised
audio_latents = audio_denoised audio_latents = audio_denoised
@@ -363,6 +367,7 @@ def generate_video_with_audio(
image: Optional[str] = None, image: Optional[str] = None,
image_strength: float = 1.0, image_strength: float = 1.0,
image_frame_idx: int = 0, image_frame_idx: int = 0,
tiling: str = "auto",
): ):
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image. """Generate video with synchronized audio from text prompt, optionally conditioned on an image.
@@ -384,6 +389,7 @@ def generate_video_with_audio(
image: Path to conditioning image for I2V image: Path to conditioning image for I2V
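image_strength: Conditioning strength; the conditioned frame's denoise mask is 1.0 - strength (1.0 = frame held to the image latent) image_strength: Conditioning strength; the conditioned frame's denoise mask is 1.0 - strength (1.0 = frame held to the image latent)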
image_frame_idx: Frame index to condition (0 = first frame) image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal)
""" """
start_time = time.time() start_time = time.time()
@@ -432,6 +438,7 @@ def generate_video_with_audio(
# Get both video and audio embeddings # Get both video and audio embeddings
video_embeddings, audio_embeddings = text_encoder(prompt) video_embeddings, audio_embeddings = text_encoder(prompt)
model_dtype = video_embeddings.dtype # bfloat16 from text encoder
mx.eval(video_embeddings, audio_embeddings) mx.eval(video_embeddings, audio_embeddings)
del text_encoder del text_encoder
@@ -442,6 +449,9 @@ def generate_video_with_audio(
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig( config = LTXModelConfig(
model_type=LTXModelType.AudioVideo, model_type=LTXModelType.AudioVideo,
num_attention_heads=32, num_attention_heads=32,
@@ -479,18 +489,16 @@ def generate_video_with_audio(
mx.eval(vae_encoder.parameters()) mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution) # Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2) input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor) stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent) mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution) # Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width) input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor) stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent) mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder del vae_encoder
mx.clear_cache() mx.clear_cache()
@@ -499,9 +507,10 @@ def generate_video_with_audio(
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed) mx.random.seed(seed)
# Create position grids # Create position grids - MUST stay float32 for RoPE precision
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations
audio_positions = create_audio_position_grid(1, audio_frames) video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32
audio_positions = create_audio_position_grid(1, audio_frames) # float32
mx.eval(video_positions, audio_positions) mx.eval(video_positions, audio_positions)
# Apply I2V conditioning for stage 1 if provided # Apply I2V conditioning for stage 1 if provided
@@ -510,9 +519,9 @@ def generate_video_with_audio(
if is_i2v and stage1_image_latent is not None: if is_i2v and stage1_image_latent is not None:
# PyTorch flow: create zeros -> apply conditioning -> apply noiser # PyTorch flow: create zeros -> apply conditioning -> apply noiser
video_state1 = LatentState( video_state1 = LatentState(
latent=mx.zeros(video_latent_shape), latent=mx.zeros(video_latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(video_latent_shape), clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent, latent=stage1_image_latent,
@@ -522,11 +531,11 @@ def generate_video_with_audio(
video_state1 = apply_conditioning(video_state1, [conditioning]) video_state1 = apply_conditioning(video_state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
noise = mx.random.normal(video_latent_shape) noise = mx.random.normal(video_latent_shape).astype(model_dtype)
noise_scale = STAGE_1_SIGMAS[0] # 1.0 noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = video_state1.denoise_mask * noise_scale scaled_mask = video_state1.denoise_mask * noise_scale
video_state1 = LatentState( video_state1 = LatentState(
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask), latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state1.clean_latent, clean_latent=video_state1.clean_latent,
denoise_mask=video_state1.denoise_mask, denoise_mask=video_state1.denoise_mask,
) )
@@ -534,11 +543,11 @@ def generate_video_with_audio(
mx.eval(video_latents) mx.eval(video_latents)
else: else:
# T2V: just use random noise # T2V: just use random noise
video_latents = mx.random.normal(video_latent_shape) video_latents = mx.random.normal(video_latent_shape).astype(model_dtype)
mx.eval(video_latents) mx.eval(video_latents)
# Audio always uses pure noise (no I2V for audio) # Audio always uses pure noise (no I2V for audio)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_latents) mx.eval(audio_latents)
# Stage 1 denoising # Stage 1 denoising
@@ -568,7 +577,8 @@ def generate_video_with_audio(
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # Position grids stay float32 for RoPE precision
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32
mx.eval(video_positions) mx.eval(video_positions)
# Apply I2V conditioning for stage 2 if provided # Apply I2V conditioning for stage 2 if provided
@@ -578,7 +588,7 @@ def generate_video_with_audio(
video_state2 = LatentState( video_state2 = LatentState(
latent=video_latents, # Start with upscaled latent latent=video_latents, # Start with upscaled latent
clean_latent=mx.zeros_like(video_latents), clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent, latent=stage2_image_latent,
@@ -588,11 +598,11 @@ def generate_video_with_audio(
video_state2 = apply_conditioning(video_state2, [conditioning]) video_state2 = apply_conditioning(video_state2, [conditioning])
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise # Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
video_noise = mx.random.normal(video_latents.shape) video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = video_state2.denoise_mask * noise_scale scaled_mask = video_state2.denoise_mask * noise_scale
video_state2 = LatentState( video_state2 = LatentState(
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask), latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state2.clean_latent, clean_latent=video_state2.clean_latent,
denoise_mask=video_state2.denoise_mask, denoise_mask=video_state2.denoise_mask,
) )
@@ -600,16 +610,18 @@ def generate_video_with_audio(
mx.eval(video_latents) mx.eval(video_latents)
# Audio still gets noise (no I2V for audio) # Audio still gets noise (no I2V for audio)
audio_noise = mx.random.normal(audio_latents.shape) audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents) mx.eval(audio_latents)
else: else:
# T2V: add noise to all frames for refinement # T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
video_noise = mx.random.normal(video_latents.shape) one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_noise = mx.random.normal(audio_latents.shape) video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) video_latents = video_noise * noise_scale + video_latents * one_minus_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(video_latents, audio_latents) mx.eval(video_latents, audio_latents)
video_latents, audio_latents = denoise_av( video_latents, audio_latents = denoise_av(
@@ -623,9 +635,36 @@ def generate_video_with_audio(
del transformer del transformer
mx.clear_cache() mx.clear_cache()
# Decode video # Decode video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
video = vae_decoder(video_latents)
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(video_latents)
mx.eval(video) mx.eval(video)
# Convert video to uint8 frames # Convert video to uint8 frames
@@ -641,27 +680,13 @@ def generate_video_with_audio(
vocoder = load_vocoder(model_path) vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters()) mx.eval(audio_decoder.parameters(), vocoder.parameters())
# Debug: check per-channel statistics are loaded
pcs = audio_decoder.per_channel_statistics
print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]")
# Debug: check audio latent statistics
print(f"Audio latents shape: {audio_latents.shape}")
print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}")
mel_spectrogram = audio_decoder(audio_latents) mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram) mx.eval(mel_spectrogram)
print(f"Mel spectrogram shape: {mel_spectrogram.shape}")
print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}")
# Audio decoder output is already in vocoder format (B, C, T, F) # Audio decoder output is already in vocoder format (B, C, T, F)
audio_waveform = vocoder(mel_spectrogram) audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform) mx.eval(audio_waveform)
print(f"Audio waveform shape: {audio_waveform.shape}")
print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}")
audio_np = np.array(audio_waveform) audio_np = np.array(audio_waveform)
if audio_np.ndim == 3: if audio_np.ndim == 3:
audio_np = audio_np[0] # Remove batch dim audio_np = audio_np[0] # Remove batch dim
@@ -762,6 +787,11 @@ Examples:
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)") help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
parser.add_argument("--image-frame-idx", type=int, default=0, parser.add_argument("--image-frame-idx", type=int, default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)") help="Frame index to condition for I2V (0 = first frame, default: 0)")
parser.add_argument("--tiling", type=str, default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f")
args = parser.parse_args() args = parser.parse_args()
@@ -783,6 +813,7 @@ Examples:
image=args.image, image=args.image,
image_strength=args.image_strength, image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx, image_frame_idx=args.image_frame_idx,
tiling=args.tiling,
) )

View File

@@ -52,10 +52,11 @@ class TransformerArgsPreprocessor:
self, self,
timestep: mx.array, timestep: mx.array,
batch_size: int, batch_size: int,
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1)) timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
# Reshape to (batch, tokens, dim) # Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
@@ -117,7 +118,7 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs: def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent) x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0]) timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings( pe = self._prepare_positional_embeddings(
@@ -201,6 +202,7 @@ class MultiModalTransformerArgsPreprocessor:
timestep=modality.timesteps, timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0], batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
) )
return replace( return replace(
@@ -215,15 +217,16 @@ class MultiModalTransformerArgsPreprocessor:
timestep: mx.array, timestep: mx.array,
timestep_scale_multiplier: int, timestep_scale_multiplier: int,
batch_size: int, batch_size: int,
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
timestep = timestep * timestep_scale_multiplier timestep = timestep * timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1)) scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor) gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
return scale_shift_timestep, gate_timestep return scale_shift_timestep, gate_timestep

View File

@@ -128,6 +128,7 @@ def apply_split_rotary_emb(
Returns: Returns:
Tensor with split rotary embeddings applied Tensor with split rotary embeddings applied
""" """
input_dtype = input_tensor.dtype
needs_reshape = False needs_reshape = False
original_shape = input_tensor.shape original_shape = input_tensor.shape
@@ -139,6 +140,11 @@ def apply_split_rotary_emb(
input_tensor = mx.swapaxes(input_tensor, 1, 2) input_tensor = mx.swapaxes(input_tensor, 1, 2)
needs_reshape = True needs_reshape = True
# Cast to float32 for computation precision
input_tensor = input_tensor.astype(mx.float32)
cos_freqs = cos_freqs.astype(mx.float32)
sin_freqs = sin_freqs.astype(mx.float32)
# Split into two halves: (..., dim) -> (..., 2, dim//2) # Split into two halves: (..., dim) -> (..., 2, dim//2)
dim = input_tensor.shape[-1] dim = input_tensor.shape[-1]
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2)) split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
@@ -167,7 +173,7 @@ def apply_split_rotary_emb(
output = mx.swapaxes(output, 1, 2) output = mx.swapaxes(output, 1, 2)
output = mx.reshape(output, (b, t, h * d)) output = mx.reshape(output, (b, t, h * d))
return output return output.astype(input_dtype)
def generate_freq_grid( def generate_freq_grid(
@@ -424,8 +430,20 @@ def _precompute_freqs_cis_double_precision(
rope_type: LTXRopeType, rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
# Convert to numpy float64 # Warn if positions are bfloat16 - this causes quality degradation
indices_grid_np = np.array(indices_grid).astype(np.float64) if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
"Use float32 for position grids to avoid quality degradation. "
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
UserWarning,
stacklevel=2
)
# Convert to numpy float64 (first to float32 for numpy compatibility)
# Note: If input is bfloat16, precision is already lost at this step
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Generate frequency indices in float64 # Generate frequency indices in float64
n_pos_dims = indices_grid_np.shape[1] n_pos_dims = indices_grid_np.shape[1]

View File

@@ -273,6 +273,13 @@ class ConnectorAttention(nn.Module):
Returns: Returns:
Tensor with SPLIT rotary embeddings applied Tensor with SPLIT rotary embeddings applied
""" """
input_dtype = x.dtype
# Cast to float32 for precision, then cast back
x = x.astype(mx.float32)
cos_freq = cos_freq.astype(mx.float32)
sin_freq = sin_freq.astype(mx.float32)
# Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2) # Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2)
half_dim = x.shape[-1] // 2 half_dim = x.shape[-1] // 2
x1 = x[..., :half_dim] x1 = x[..., :half_dim]
@@ -284,7 +291,7 @@ class ConnectorAttention(nn.Module):
out1 = x1 * cos_freq - x2 * sin_freq out1 = x1 * cos_freq - x2 * sin_freq
out2 = x2 * cos_freq + x1 * sin_freq out2 = x2 * cos_freq + x1 * sin_freq
return mx.concatenate([out1, out2], axis=-1) return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
class GEGLU(nn.Module): class GEGLU(nn.Module):
@@ -437,14 +444,15 @@ class Embeddings1DConnector(nn.Module):
attention_mask: mx.array, attention_mask: mx.array,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
batch_size, seq_len, dim = hidden_states.shape batch_size, seq_len, dim = hidden_states.shape
dtype = hidden_states.dtype
# Binary mask: 1 for valid tokens, 0 for padded # Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded # attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq) mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
# Tile registers to match sequence length # Tile registers to match sequence length, cast to hidden_states dtype
num_tiles = seq_len // self.num_learnable_registers num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim) registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing) # Process each batch item (PyTorch uses advanced indexing)
result_list = [] result_list = []
@@ -462,7 +470,7 @@ class Embeddings1DConnector(nn.Module):
# Pad with zeros on the right to get back to seq_len # Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid pad_length = seq_len - num_valid
if pad_length > 0: if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype) padding = mx.zeros((pad_length, dim), dtype=dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim) adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
else: else:
adjusted = valid_tokens adjusted = valid_tokens
@@ -474,9 +482,8 @@ class Embeddings1DConnector(nn.Module):
], axis=0) # (seq,) ], axis=0) # (seq,)
# Combine: valid tokens at front, registers at back # Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1) flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
result_list.append(combined) result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim) hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
@@ -491,7 +498,6 @@ class Embeddings1DConnector(nn.Module):
hidden_states: mx.array, hidden_states: mx.array,
attention_mask: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
# Replace padded tokens with learnable registers # Replace padded tokens with learnable registers
if self.num_learnable_registers > 0 and attention_mask is not None: if self.num_learnable_registers > 0 and attention_mask is not None:
hidden_states, attention_mask = self._replace_padded_with_registers( hidden_states, attention_mask = self._replace_padded_with_registers(
@@ -521,6 +527,7 @@ def norm_and_concat_hidden_states(
# Stack hidden states: (batch, seq, dim, num_layers) # Stack hidden states: (batch, seq, dim, num_layers)
stacked = mx.stack(hidden_states, axis=-1) stacked = mx.stack(hidden_states, axis=-1)
dtype = stacked.dtype
b, t, d, num_layers = stacked.shape b, t, d, num_layers = stacked.shape
# Compute sequence lengths from attention mask # Compute sequence lengths from attention mask
@@ -536,16 +543,16 @@ def norm_and_concat_hidden_states(
mask = token_indices >= start_indices # (B, T) mask = token_indices >= start_indices # (B, T)
mask = mask[:, :, None, None] # (B, T, 1, 1) mask = mask[:, :, None, None] # (B, T, 1, 1)
eps = 1e-6 eps = mx.array(1e-6, dtype=dtype)
# Compute masked mean per layer # Compute masked mean per layer - ensure dtype consistency
masked = mx.where(mask, stacked, mx.zeros_like(stacked)) masked = mx.where(mask, stacked, mx.zeros_like(stacked))
denom = (sequence_lengths * d).reshape(b, 1, 1, 1) denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype)
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer # Compute masked min/max per layer
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype)) x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype)) x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min range_val = x_max - x_min
@@ -749,13 +756,10 @@ class LTX2TextEncoder(nn.Module):
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states( concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left" all_hidden_states, attention_mask, padding_side="left"
) )
features = self.feature_extractor(concat_hidden) features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(features.dtype) additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
@@ -918,7 +922,7 @@ class LTX2TextEncoder(nn.Module):
if response.token == 1 or response.token == 107: # EOS tokens if response.token == 1 or response.token == 107: # EOS tokens
break break
mx.clear_cache()
# Decode only the new tokens # Decode only the new tokens

View File

@@ -1,3 +1,8 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -23,6 +23,7 @@ import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding( def get_timestep_embedding(
@@ -347,10 +348,11 @@ class LTX2VideoDecoder(nn.Module):
def denormalize(self, x: mx.array) -> mx.array: def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics.""" """Denormalize latents using per-channel statistics."""
dtype = x.dtype
# Cast to float32 for precision (statistics may be in bfloat16) # Cast to float32 for precision (statistics may be in bfloat16)
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1) std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return x * std + mean return (x * std + mean).astype(dtype)
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization.""" """Apply pixel normalization."""
@@ -444,6 +446,75 @@ class LTX2VideoDecoder(nn.Module):
return x return x
def decode_tiled(
    self,
    sample: mx.array,
    tiling_config: Optional[TilingConfig] = None,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    debug: bool = False,
) -> mx.array:
    """Decode latents tile by tile to reduce peak memory usage.

    The latents are split into spatial and/or temporal tiles, each tile is
    decoded separately and the results are blended together. When the input
    already fits inside a single tile, the regular (untiled) decode path is
    used instead.

    Args:
        sample: Input latents of shape (B, C, F, H, W).
        tiling_config: Tiling configuration. If None, uses
            TilingConfig.default().
        causal: Whether to use causal convolutions.
        timestep: Optional timestep for conditioning.
        debug: Whether to print debug info.

    Returns:
        The decoded video. With the scale factors used here, the spatial
        dimensions are upsampled 32x and the temporal dimension 8x relative
        to the latents.
        NOTE(review): the exact output frame count is decided by the
        decoder / decode_with_tiling - confirm there.
    """
    if tiling_config is None:
        tiling_config = TilingConfig.default()

    # Latent -> pixel scale factors, defined once so the "is tiling needed"
    # check and the tiled decode below can never disagree.
    # Spatial: 8x VAE upsampling + 4x unpatchify = 32x. Temporal: 8x.
    spatial_scale = 32
    temporal_scale = 8

    _, _, f, h, w = sample.shape

    # Spatial tiling is needed when either latent side exceeds one tile.
    needs_spatial_tiling = False
    s_cfg = tiling_config.spatial_config
    if s_cfg is not None:
        spatial_tile_latent = s_cfg.tile_size_in_pixels // spatial_scale
        needs_spatial_tiling = h > spatial_tile_latent or w > spatial_tile_latent

    # Temporal tiling is needed when the latent frame count exceeds one tile.
    needs_temporal_tiling = False
    t_cfg = tiling_config.temporal_config
    if t_cfg is not None:
        temporal_tile_latent = t_cfg.tile_size_in_frames // temporal_scale
        needs_temporal_tiling = f > temporal_tile_latent

    if not needs_spatial_tiling and not needs_temporal_tiling:
        # No tiling needed, use regular decode
        if debug:
            print("[Tiling] Input fits within tile size, using regular decode")
        return self(sample, causal=causal, timestep=timestep, debug=debug)

    if debug:
        print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")

    return decode_with_tiling(
        decoder_fn=self,
        latents=sample,
        tiling_config=tiling_config,
        spatial_scale=spatial_scale,
        temporal_scale=temporal_scale,
        causal=causal,
        timestep=timestep,
        debug=debug,
    )
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder: def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path from pathlib import Path

View File

@@ -0,0 +1,470 @@
"""VAE Tiling Configuration for decoding large videos.
Implements spatial and temporal tiling with trapezoidal blending masks
to decode large videos without running out of memory.
Default configuration (from PyTorch):
- Spatial: 512px tiles with 64px overlap
- Temporal: 64 frames with 24 frame overlap
"""
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple
import mlx.core as mx
def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> mx.array:
    """Build a 1D trapezoidal blending mask with linear fade-in/fade-out.

    Args:
        length: Number of samples in the mask.
        ramp_left: Number of samples that fade in at the left edge.
        ramp_right: Number of samples that fade out at the right edge.
        left_starts_from_0: If True the left ramp begins exactly at 0;
            otherwise it begins at the first non-zero step of the ramp.

    Returns:
        A 1D array of shape (length,) with values in [0, 1].

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    # Ramps can neither be negative nor longer than the mask itself.
    ramp_left = max(0, min(ramp_left, length))
    ramp_right = max(0, min(ramp_right, length))

    weights = [1.0] * length

    if ramp_left > 0:
        # Equivalent to linspace(0, 1, n)[:-1], additionally dropping the
        # leading 0 when the ramp must not start from zero.
        first_step = 0 if left_starts_from_0 else 1
        denominator = ramp_left + first_step
        for pos in range(ramp_left):
            weights[pos] *= (pos + first_step) / denominator

    if ramp_right > 0:
        # Equivalent to linspace(1, 0, ramp_right + 2)[1:-1].
        tail_start = length - ramp_right
        for offset in range(ramp_right):
            weights[tail_start + offset] *= (ramp_right - offset) / (ramp_right + 1)

    return mx.clip(mx.array(weights), 0, 1)
@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.

    Both values are expressed in pixels. Validation happens on construction.

    Raises:
        ValueError: If the tile size is below 64 or not a multiple of 32, or
            if the overlap is not a multiple of 32, is negative, or is not
            strictly smaller than the tile size.
    """

    # Side length of one square tile, in pixels.
    tile_size_in_pixels: int
    # Number of pixels shared between neighbouring tiles.
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_pixels < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
        if self.tile_size_in_pixels % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
        if self.tile_overlap_in_pixels % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
        # A negative multiple of 32 passes the checks above and below, but
        # would make the tile stride larger than the tile and leave gaps.
        if self.tile_overlap_in_pixels < 0:
            raise ValueError(f"tile_overlap_in_pixels must be non-negative, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
            )
@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles.

    Both values are expressed in frames. Validation happens on construction.

    Raises:
        ValueError: If the tile size is below 16 or not a multiple of 8, or
            if the overlap is not a multiple of 8, is negative, or is not
            strictly smaller than the tile size.
    """

    # Length of one temporal tile, in frames.
    tile_size_in_frames: int
    # Number of frames shared between consecutive tiles.
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_frames < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
        if self.tile_size_in_frames % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
        if self.tile_overlap_in_frames % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
        # A negative multiple of 8 passes the checks above and below, but
        # would make the tile stride larger than the tile and leave gaps.
        if self.tile_overlap_in_frames < 0:
            raise ValueError(f"tile_overlap_in_frames must be non-negative, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames >= self.tile_size_in_frames:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
            )
@dataclass(frozen=True)
class TilingConfig:
    """Spatial and/or temporal tiling settings for decoding a video in tiles.

    A ``None`` entry disables tiling along that axis.
    """

    spatial_config: Optional[SpatialTilingConfig] = None
    temporal_config: Optional[TemporalTilingConfig] = None

    @classmethod
    def default(cls) -> "TilingConfig":
        """Return the default preset: 512px/64px spatial, 64f/24f temporal."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
        """Return a preset that tiles spatially only (short, high-resolution videos)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap)
        return cls(spatial_config=spatial, temporal_config=None)

    @classmethod
    def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
        """Return a preset that tiles temporally only (long, low-resolution videos)."""
        temporal = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
        return cls(spatial_config=None, temporal_config=temporal)

    @classmethod
    def aggressive(cls) -> "TilingConfig":
        """Return the smallest-tile preset: 256px/64px spatial, 32f/8f temporal."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def conservative(cls) -> "TilingConfig":
        """Return the large-tile preset: 768px/64px spatial, 96f/24f temporal."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def auto(
        cls,
        height: int,
        width: int,
        num_frames: int,
        spatial_threshold: int = 512,
        temporal_threshold: int = 65,
    ) -> Optional["TilingConfig"]:
        """Pick a tiling configuration from the output video dimensions.

        Args:
            height: Video height in pixels.
            width: Video width in pixels.
            num_frames: Number of video frames.
            spatial_threshold: Spatial tiling is enabled when either side
                exceeds this many pixels.
            temporal_threshold: Temporal tiling is enabled when the frame
                count exceeds this value.

        Returns:
            A TilingConfig when tiling is needed, otherwise None.
        """
        exceeds_spatial = max(height, width) > spatial_threshold
        exceeds_temporal = num_frames > temporal_threshold
        if not (exceeds_spatial or exceeds_temporal):
            return None

        # Rough size of the decoded float32 output: 3 channels, 4 bytes each.
        output_gb = (3 * num_frames * height * width * 4) / (1024**3)
        is_large_and_long = height * width > 768 * 1024 and num_frames > 100
        if output_gb > 2.0 or is_large_and_long:
            return cls.aggressive()

        spatial = None
        if exceeds_spatial:
            longest_side = max(height, width)
            # 512px tiles only for sides in (768, 1024]; 384px otherwise.
            # NOTE(review): the mapping is not monotonic in resolution -
            # confirm this is intended.
            spatial_tile = 512 if 768 < longest_side <= 1024 else 384
            spatial = SpatialTilingConfig(tile_size_in_pixels=spatial_tile, tile_overlap_in_pixels=64)

        temporal = None
        if exceeds_temporal:
            # Longer videos get shorter temporal tiles.
            if num_frames > 200:
                frames_per_tile, frame_overlap = 32, 8
            elif num_frames > 100:
                frames_per_tile, frame_overlap = 48, 16
            else:
                frames_per_tile, frame_overlap = 64, 24
            temporal = TemporalTilingConfig(
                tile_size_in_frames=frames_per_tile,
                tile_overlap_in_frames=frame_overlap,
            )

        return cls(spatial_config=spatial, temporal_config=temporal)
@dataclass
class DimensionIntervals:
    """Tile layout along one dimension, stored as four parallel lists.

    Entry ``i`` of every list describes tile ``i``.
    """

    # Start index of each tile.
    starts: List[int]
    # End index of each tile.
    ends: List[int]
    # Blend-ramp length on the left edge of each tile.
    left_ramps: List[int]
    # Blend-ramp length on the right edge of each tile.
    right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split a spatial dimension into overlapping tiles.

    Args:
        size: Tile length.
        overlap: Number of positions shared by neighbouring tiles.
        dimension_size: Total length of the dimension being split.

    Returns:
        The tile intervals; a single un-ramped interval when the whole
        dimension fits in one tile.
    """
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    stride = size - overlap
    tile_count = (dimension_size + size - 2 * overlap - 1) // stride

    tile_starts = [index * stride for index in range(tile_count)]
    tile_ends = [begin + size for begin in tile_starts]
    # The final tile is truncated at the end of the dimension.
    tile_ends[-1] = dimension_size

    # Only interior edges are blended: the first tile has no left ramp and
    # the last tile has no right ramp.
    last = tile_count - 1
    fade_in = [overlap if index > 0 else 0 for index in range(tile_count)]
    fade_out = [overlap if index < last else 0 for index in range(tile_count)]

    return DimensionIntervals(starts=tile_starts, ends=tile_ends, left_ramps=fade_in, right_ramps=fade_out)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
# Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size)
# Adjust for temporal: starts[1:] -= 1, left_ramps[1:] += 1
starts = intervals.starts.copy()
left_ramps = intervals.left_ramps.copy()
for i in range(1, len(starts)):
starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Translate a latent-frame interval into an output-frame slice and its blend mask.

    The output span runs from ``begin * scale`` up to (but excluding)
    ``1 + (end - 1) * scale``. A positive left ramp is scaled as
    ``1 + (left_ramp - 1) * scale``; the right ramp is scaled linearly.

    Args:
        begin: First latent frame of the tile (inclusive).
        end: Last latent frame of the tile (exclusive).
        left_ramp: Leading blend-ramp length in latent frames.
        right_ramp: Trailing blend-ramp length in latent frames.
        scale: Temporal upsampling factor.

    Returns:
        ``(slice, mask)`` where the mask is produced by
        ``compute_trapezoidal_mask_1d`` for the output span.
    """
    out_begin = begin * scale
    out_end = (end - 1) * scale + 1

    if left_ramp > 0:
        ramp_in = (left_ramp - 1) * scale + 1
    else:
        ramp_in = 0
    ramp_out = right_ramp * scale

    # The trailing positional flag is True for the temporal axis; its exact
    # meaning is defined by compute_trapezoidal_mask_1d.
    blend = compute_trapezoidal_mask_1d(out_end - out_begin, ramp_in, ramp_out, True)
    return slice(out_begin, out_end), blend
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Translate a latent spatial interval into an output-pixel slice and its blend mask.

    Every quantity (bounds and both ramps) is multiplied by ``scale``.

    Args:
        begin: First latent position of the tile (inclusive).
        end: Last latent position of the tile (exclusive).
        left_ramp: Leading blend-ramp length in latent units.
        right_ramp: Trailing blend-ramp length in latent units.
        scale: Spatial upsampling factor.

    Returns:
        ``(slice, mask)`` where the mask is produced by
        ``compute_trapezoidal_mask_1d`` for the output span.
    """
    out_begin, out_end = begin * scale, end * scale
    # The trailing positional flag is False for spatial axes; its exact
    # meaning is defined by compute_trapezoidal_mask_1d.
    blend = compute_trapezoidal_mask_1d(
        out_end - out_begin,
        left_ramp * scale,
        right_ramp * scale,
        False,
    )
    return slice(out_begin, out_end), blend
def decode_with_tiling(
    decoder_fn,
    latents: mx.array,
    tiling_config: TilingConfig,
    spatial_scale: int = 32,
    temporal_scale: int = 8,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    debug: bool = False,
) -> mx.array:
    """Decode latents using tiling to reduce memory usage.

    The latent volume is cut into overlapping tiles along time, height and
    width. Each tile is decoded on its own, multiplied by a separable blend
    mask (the product of the per-axis masks), and added into a float32
    accumulator; the mask itself is added into a parallel weight buffer. After
    all tiles are processed the accumulator is divided by the weights.

    Args:
        decoder_fn: Decoder function to call for each tile. It is invoked as
            ``decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)``.
        latents: Input latents of shape (B, C, F, H, W).
        tiling_config: Tiling configuration. A ``None`` spatial or temporal
            sub-config means that axis is covered by a single tile.
        spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
        temporal_scale: Temporal scale factor (8 for LTX VAE).
        causal: Whether to use causal convolutions (forwarded to ``decoder_fn``).
        timestep: Optional timestep for conditioning (forwarded to ``decoder_fn``).
        debug: Whether to print debug info.

    Returns:
        Decoded video of shape (B, 3, 1 + (F - 1) * temporal_scale,
        H * spatial_scale, W * spatial_scale), cast to ``latents.dtype``.
    """
    import gc

    b, c, f_latent, h_latent, w_latent = latents.shape
    # Compute output shape
    out_f = 1 + (f_latent - 1) * temporal_scale
    out_h = h_latent * spatial_scale
    out_w = w_latent * spatial_scale
    # Get tile size and overlap in latent space
    # (pixel/frame sizes from the config are floor-divided by the scale factors)
    if tiling_config.spatial_config is not None:
        s_cfg = tiling_config.spatial_config
        spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
        spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
    else:
        # No spatial tiling: a tile as large as the bigger axis covers both axes
        spatial_tile_size = max(h_latent, w_latent)
        spatial_overlap = 0
    if tiling_config.temporal_config is not None:
        t_cfg = tiling_config.temporal_config
        temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
        temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
    else:
        # No temporal tiling: one tile spans every latent frame
        temporal_tile_size = f_latent
        temporal_overlap = 0
    # Compute intervals for each dimension
    temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
    height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
    width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
    num_t_tiles = len(temporal_intervals.starts)
    num_h_tiles = len(height_intervals.starts)
    num_w_tiles = len(width_intervals.starts)
    total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
    if debug:
        print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
        print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
    # Initialize output and weight accumulator
    # Use float32 for accumulation to avoid precision issues
    # (the output channel count is fixed at 3)
    output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
    weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
    mx.eval(output, weights)
    tile_idx = 0
    for t_idx in range(num_t_tiles):
        t_start = temporal_intervals.starts[t_idx]
        t_end = temporal_intervals.ends[t_idx]
        t_left = temporal_intervals.left_ramps[t_idx]
        t_right = temporal_intervals.right_ramps[t_idx]
        # Map temporal coordinates
        out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
        for h_idx in range(num_h_tiles):
            h_start = height_intervals.starts[h_idx]
            h_end = height_intervals.ends[h_idx]
            h_left = height_intervals.left_ramps[h_idx]
            h_right = height_intervals.right_ramps[h_idx]
            # Map height coordinates
            out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
            for w_idx in range(num_w_tiles):
                w_start = width_intervals.starts[w_idx]
                w_end = width_intervals.ends[w_idx]
                w_left = width_intervals.left_ramps[w_idx]
                w_right = width_intervals.right_ramps[w_idx]
                # Map width coordinates
                out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
                if debug:
                    print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
                          f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
                # Extract tile latents (small slice)
                tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
                # Decode tile
                tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
                mx.eval(tile_output)
                # Clear tile_latents reference
                del tile_latents
                # Get actual decoded dimensions
                _, _, decoded_t, decoded_h, decoded_w = tile_output.shape
                expected_t = out_t_slice.stop - out_t_slice.start
                expected_h = out_h_slice.stop - out_h_slice.start
                expected_w = out_w_slice.stop - out_w_slice.start
                # Handle potential size mismatches (use minimum)
                actual_t = min(decoded_t, expected_t)
                actual_h = min(decoded_h, expected_h)
                actual_w = min(decoded_w, expected_w)
                # Build blend mask
                # (each 1-D mask is truncated to the usable extent, then broadcast along its own axis)
                t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
                h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
                w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
                blend_mask = (
                    t_mask_slice.reshape(1, 1, -1, 1, 1) *
                    h_mask_slice.reshape(1, 1, 1, -1, 1) *
                    w_mask_slice.reshape(1, 1, 1, 1, -1)
                )
                # Slice tile output to match
                tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
                # Clear full tile_output
                del tile_output
                # Compute output coordinates
                t_out_start = out_t_slice.start
                t_out_end = t_out_start + actual_t
                h_out_start = out_h_slice.start
                h_out_end = h_out_start + actual_h
                w_out_start = out_w_slice.start
                w_out_end = w_out_start + actual_w
                # Use direct slice assignment (MLX supports this)
                # Weighted accumulation
                weighted_tile = tile_output_slice * blend_mask
                # Update output using slice assignment
                output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
                    output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
                )
                weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
                    weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
                )
                # Force evaluation to free memory
                mx.eval(output, weights)
                # Clean up tile-specific arrays
                del tile_output_slice, weighted_tile, blend_mask
                del t_mask_slice, h_mask_slice, w_mask_slice
                tile_idx += 1
                # Periodic garbage collection and cache clearing
                # (every fourth tile; cache clearing is best-effort)
                if tile_idx % 4 == 0:
                    gc.collect()
                    try:
                        mx.clear_cache()
                    except Exception:
                        pass  # May not be available on all platforms
    # Normalize by weights
    # (clamped to 1e-8 so positions that received no tile do not divide by zero)
    weights = mx.maximum(weights, 1e-8)
    output = output / weights
    mx.eval(output)
    # Clean up weights
    del weights
    gc.collect()
    if debug:
        print(f"[Tiling] Done. Final shape: {output.shape}")
    # Convert back to original dtype if needed
    return output.astype(latents.dtype)

View File

@@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
class_predicate=get_class_predicate, class_predicate=get_class_predicate,
) )
@partial(mx.compile, shapeless=True)
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps) return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
@@ -71,9 +70,12 @@ def to_denoised(
Denoised tensor x_0 Denoised tensor x_0
""" """
if isinstance(sigma, (int, float)): if isinstance(sigma, (int, float)):
return noisy - sigma * velocity # Convert to array with matching dtype to avoid float32 promotion
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
return noisy - sigma_arr * velocity
else: else:
# sigma is per-sample # sigma is per-sample - ensure dtype matches
sigma = sigma.astype(velocity.dtype)
while sigma.ndim < velocity.ndim: while sigma.ndim < velocity.ndim:
sigma = mx.expand_dims(sigma, axis=-1) sigma = mx.expand_dims(sigma, axis=-1)
return noisy - sigma * velocity return noisy - sigma * velocity
@@ -169,6 +171,7 @@ def load_image(
image_path: Union[str, Path], image_path: Union[str, Path],
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
dtype: mx.Dtype = mx.float32,
) -> mx.array: ) -> mx.array:
"""Load and preprocess an image for I2V conditioning. """Load and preprocess an image for I2V conditioning.
@@ -208,7 +211,7 @@ def load_image(
# Convert to numpy then MLX # Convert to numpy then MLX
image_np = np.array(image).astype(np.float32) / 255.0 image_np = np.array(image).astype(np.float32) / 255.0
return mx.array(image_np) return mx.array(image_np, dtype=dtype)
def resize_image_aspect_ratio( def resize_image_aspect_ratio(
@@ -251,6 +254,7 @@ def prepare_image_for_encoding(
image: mx.array, image: mx.array,
target_height: int, target_height: int,
target_width: int, target_width: int,
dtype: mx.Dtype = mx.float32,
) -> mx.array: ) -> mx.array:
"""Prepare image for VAE encoding by resizing and normalizing. """Prepare image for VAE encoding by resizing and normalizing.
@@ -281,4 +285,4 @@ def prepare_image_for_encoding(
image = mx.expand_dims(image, axis=0) # (1, 3, H, W) image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W) image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
return image return image.astype(dtype)

0
tests/__init__.py Normal file
View File

280
tests/test_rope.py Normal file
View File

@@ -0,0 +1,280 @@
import pytest
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx.config import LTXRopeType
def create_video_position_grid(
    batch_size: int,
    num_frames: int,
    height: int,
    width: int,
    dtype: mx.Dtype = mx.float32,
) -> mx.array:
    """Build a small (batch, 3, tokens, 2) position grid for the RoPE tests.

    Each token carries a [start, end) pair per axis (time, height, width).
    Latent indices are scaled by (8, 32, 32) and the time axis is then divided
    by 24.0 to express it in seconds.
    """
    axes = (np.arange(0, num_frames), np.arange(0, height), np.arange(0, width))
    t_grid, h_grid, w_grid = np.meshgrid(*axes, indexing='ij')

    # Stack per-axis start indices, then pair each start with start + 1.
    starts = np.stack([t_grid, h_grid, w_grid], axis=0)
    bounds = np.stack([starts, starts + 1], axis=-1)

    # Flatten the (F, H, W) volume into a token axis and replicate per batch item.
    token_count = num_frames * height * width
    bounds = bounds.reshape(3, token_count, 2)
    batched = np.tile(bounds[np.newaxis, ...], (batch_size, 1, 1, 1))

    # Scale to pixel space
    axis_scale = np.array([8, 32, 32]).reshape(1, 3, 1, 1)
    pixel_coords = (batched * axis_scale).astype(np.float32)
    pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0  # Convert to seconds
    return mx.array(pixel_coords, dtype=dtype)
class TestRoPEPositionPrecision:
    """Checks on how the position-grid dtype affects the RoPE frequency tables."""

    @staticmethod
    def _split_freqs(grid):
        """Call precompute_freqs_cis with the SPLIT, double-precision settings shared by these tests."""
        return precompute_freqs_cis(
            indices_grid=grid,
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )

    def test_float32_positions_produce_consistent_output(self):
        """A float32 grid must yield finite float32 cos/sin tables within [-1, 1]."""
        grid = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        cos_freq, sin_freq = self._split_freqs(grid)

        # Output dtype
        assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}"
        assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}"

        # No NaN / Inf entries
        assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN"
        assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN"
        assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf"
        assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"

        # Values bounded like a real cosine / sine
        assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
            "cos_freq values out of [-1, 1] range"
        assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
            "sin_freq values out of [-1, 1] range"

    def test_bfloat16_positions_cause_precision_loss(self):
        """A bfloat16 grid must give measurably different tables than a float32 grid.

        Documents a known issue: bfloat16 keeps 7 mantissa bits against 23 for
        float32, and the resulting rounding of positions is amplified by the
        sin/cos evaluation in RoPE.
        """
        # Same grid, two storage dtypes
        positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)

        cos_f32, sin_f32 = self._split_freqs(positions_f32)
        cos_bf16, sin_bf16 = self._split_freqs(positions_bf16)

        # Largest absolute deviation per table
        max_cos_diff = mx.max(mx.abs(cos_f32 - cos_bf16)).item()
        max_sin_diff = mx.max(mx.abs(sin_f32 - sin_bf16)).item()

        # Deliberately tight threshold so any deviation is caught
        precision_threshold = 1e-6
        has_precision_loss = (
            max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
        )

        # Report the measured loss (expected behavior)
        if has_precision_loss:
            print("\nPrecision loss detected (expected):")
            print(f"  Max cos difference: {max_cos_diff:.6e}")
            print(f"  Max sin difference: {max_sin_diff:.6e}")

        # If this ever fails, bfloat16 positions no longer degrade the result
        assert has_precision_loss, \
            "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"

    def test_double_precision_converts_to_float32_internally(self):
        """With double_precision=True a bfloat16 grid still yields float32 tables."""
        positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)

        # Per the original author's note, the double-precision path in rope.py
        # widens the grid bfloat16 -> float32 -> float64. Widening cannot
        # restore digits already rounded away when the grid was stored as
        # bfloat16, so only the output dtype is checked here.
        cos_freq, sin_freq = self._split_freqs(positions_bf16)

        assert cos_freq.dtype == mx.float32
        assert sin_freq.dtype == mx.float32

    def test_position_grid_should_be_float32_recommendation(self):
        """Record the recommended practice: build position grids in float32.

        Serves as documentation that position grids MUST be float32 to avoid
        quality degradation in generated videos/audio.
        """
        positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        assert positions.dtype == mx.float32, \
            "Position grids should be created in float32 for RoPE precision"

        # Axis 0 is time in seconds, so its values stay small
        temporal_positions = positions[:, 0, :, :]
        assert mx.max(temporal_positions).item() < 100, \
            "Temporal positions should be in seconds (small values)"

        # Axes 1 and 2 are pixel coordinates
        spatial_h = positions[:, 1, :, :]
        spatial_w = positions[:, 2, :, :]
        assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
        assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
class TestRoPEInterleaved:
    """Coverage for the interleaved RoPE layout."""

    def test_interleaved_rope_with_float32_positions(self):
        """A float32 grid in interleaved mode must give NaN-free float32 tables."""
        grid = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        rope_kwargs = {
            "indices_grid": grid,
            "dim": 128,
            "theta": 10000.0,
            "max_pos": [20, 2048, 2048],
            "use_middle_indices_grid": True,
            "num_attention_heads": 32,
            "rope_type": LTXRopeType.INTERLEAVED,
            "double_precision": False,
        }
        cos_freq, sin_freq = precompute_freqs_cis(**rope_kwargs)

        assert cos_freq.dtype == mx.float32
        assert sin_freq.dtype == mx.float32
        assert not mx.any(mx.isnan(cos_freq)).item()
        assert not mx.any(mx.isnan(sin_freq)).item()
class TestRoPEWarnings:
    """Checks on the dtype warning emitted by precompute_freqs_cis."""

    @staticmethod
    def _split_freqs(grid):
        """Call precompute_freqs_cis with the SPLIT, double-precision settings used by both tests."""
        return precompute_freqs_cis(
            indices_grid=grid,
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )

    def test_bfloat16_positions_trigger_warning(self):
        """A bfloat16 grid must raise a UserWarning that names the dtype."""
        positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)

        with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
            self._split_freqs(positions_bf16)

    def test_float32_positions_no_warning(self):
        """A float32 grid must run without emitting any warning."""
        import warnings

        positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)

        with warnings.catch_warnings():
            # Escalate every warning to an exception so a stray one fails the test
            warnings.simplefilter("error")
            self._split_freqs(positions_f32)
class TestRoPESplit:
    """Coverage for the split RoPE layout (used by LTX-2)."""

    def test_split_rope_output_shape(self):
        """Split RoPE tables must have shape (B, H, T, dim_per_head//2)."""
        batch_size, num_frames, height, width = 1, 4, 4, 4
        num_heads, dim = 32, 128

        positions = create_video_position_grid(batch_size, num_frames, height, width)
        cos_freq, sin_freq = precompute_freqs_cis(
            indices_grid=positions,
            dim=dim,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=num_heads,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )

        # One token per (frame, row, column); with dim=128 and 32 heads each
        # head gets 4 channels, and the split layout keeps half of them (2).
        num_tokens = num_frames * height * width
        dim_per_head = dim // num_heads
        expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)

        assert cos_freq.shape == expected_shape, \
            f"Expected shape {expected_shape}, got {cos_freq.shape}"
        assert sin_freq.shape == expected_shape, \
            f"Expected shape {expected_shape}, got {sin_freq.shape}"
# Allow running this test module directly as a script; pytest is invoked on
# this file only, in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])