Merge pull request #12 from Blaizzy/pc/add-vae-tiling
Add VAE Tiling + BFloat16 Support for Memory-Efficient Video Generation
This commit is contained in:
@@ -95,6 +95,7 @@ def apply_conditioning(
|
||||
Updated LatentState with conditioning applied
|
||||
"""
|
||||
state = state.clone()
|
||||
dtype = state.latent.dtype
|
||||
b, c, f, h, w = state.latent.shape
|
||||
|
||||
for cond in conditionings:
|
||||
@@ -132,7 +133,7 @@ def apply_conditioning(
|
||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength))
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||
else:
|
||||
# Keep original
|
||||
latent_list.append(state.latent[:, :, i:i+1])
|
||||
@@ -161,7 +162,8 @@ def apply_denoise_mask(
|
||||
Returns:
|
||||
Blended latent
|
||||
"""
|
||||
return denoised * denoise_mask + clean * (1.0 - denoise_mask)
|
||||
one = mx.array(1.0, dtype=denoised.dtype)
|
||||
return denoised * denoise_mask + clean * (one - denoise_mask)
|
||||
|
||||
|
||||
def add_noise_with_state(
|
||||
@@ -191,6 +193,7 @@ def add_noise_with_state(
|
||||
# But we scale sigma by the mask for conditioned regions
|
||||
|
||||
effective_scale = noise_scale * state.denoise_mask
|
||||
state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale)
|
||||
one = mx.array(1.0, dtype=state.latent.dtype)
|
||||
state.latent = noise * effective_scale + state.latent * (one - effective_scale)
|
||||
|
||||
return state
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
@@ -27,6 +28,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder
|
||||
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
|
||||
@@ -109,6 +111,7 @@ def create_position_grid(
|
||||
# Convert temporal to time in seconds by dividing by fps
|
||||
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||
|
||||
# Always return float32 for RoPE precision - bfloat16 causes quality degradation
|
||||
return mx.array(pixel_coords, dtype=mx.float32)
|
||||
|
||||
|
||||
@@ -136,6 +139,7 @@ def denoise(
|
||||
Denoised latent tensor
|
||||
"""
|
||||
# If state is provided, use its latent (which may have conditioning applied)
|
||||
dtype = latents.dtype
|
||||
if state is not None:
|
||||
latents = state.latent
|
||||
|
||||
@@ -153,11 +157,11 @@ def denoise(
|
||||
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
||||
# Per-token timesteps: sigma * mask
|
||||
timesteps = sigma * denoise_mask_flat
|
||||
# Per-token timesteps: sigma * mask (preserve dtype)
|
||||
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
# All tokens get the same timestep
|
||||
timesteps = mx.full((b, num_tokens), sigma)
|
||||
# All tokens get the same timestep (use latent dtype)
|
||||
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
@@ -180,8 +184,11 @@ def denoise(
|
||||
|
||||
mx.eval(denoised)
|
||||
|
||||
# Euler step (preserve dtype by converting Python floats to arrays)
|
||||
if sigma_next > 0:
|
||||
latents = denoised + sigma_next * (latents - denoised) / sigma
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
|
||||
else:
|
||||
latents = denoised
|
||||
mx.eval(latents)
|
||||
@@ -207,6 +214,7 @@ def generate_video(
|
||||
image: Optional[str] = None,
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
tiling: str = "auto",
|
||||
):
|
||||
"""Generate video from text prompt, optionally conditioned on an image.
|
||||
|
||||
@@ -228,6 +236,14 @@ def generate_video(
|
||||
image: Path to conditioning image for I2V (Image-to-Video)
|
||||
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
|
||||
image_frame_idx: Frame index to condition (0 = first frame)
|
||||
tiling: Tiling mode for VAE decoding. Options:
|
||||
- "auto": Automatically determine based on video size (default)
|
||||
- "none": Disable tiling
|
||||
- "default": 512px spatial, 64 frame temporal
|
||||
- "aggressive": 256px spatial, 32 frame temporal (lowest memory)
|
||||
- "conservative": 768px spatial, 96 frame temporal (faster)
|
||||
- "spatial": Spatial tiling only
|
||||
- "temporal": Temporal tiling only
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -273,6 +289,7 @@ def generate_video(
|
||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||
|
||||
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||
model_dtype = text_embeddings.dtype # bfloat16 from text encoder
|
||||
mx.eval(text_embeddings)
|
||||
|
||||
del text_encoder
|
||||
@@ -282,6 +299,8 @@ def generate_video(
|
||||
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
# Convert transformer weights to bfloat16 for memory efficiency
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
@@ -300,7 +319,7 @@ def generate_video(
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
transformer = LTXModel(config)
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
|
||||
@@ -313,15 +332,15 @@ def generate_video(
|
||||
mx.eval(vae_encoder.parameters())
|
||||
|
||||
# Load and prepare image for stage 1 (half resolution)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||
|
||||
# Load and prepare image for stage 2 (full resolution)
|
||||
input_image = load_image(image, height=height, width=width)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||
@@ -333,6 +352,7 @@ def generate_video(
|
||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Position grids stay float32 for RoPE precision
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
@@ -343,24 +363,26 @@ def generate_video(
|
||||
# Create initial state with zeros
|
||||
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||
state1 = LatentState(
|
||||
latent=mx.zeros(latent_shape),
|
||||
clean_latent=mx.zeros(latent_shape),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
latent=mx.zeros(latent_shape, dtype=model_dtype),
|
||||
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
latent=stage1_image_latent,
|
||||
frame_idx=image_frame_idx,
|
||||
strength=image_strength,
|
||||
)
|
||||
|
||||
state1 = apply_conditioning(state1, [conditioning])
|
||||
|
||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||
# For Stage 1, noise_scale = 1.0 (first sigma)
|
||||
noise = mx.random.normal(latent_shape)
|
||||
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
||||
noise = mx.random.normal(latent_shape, dtype=model_dtype)
|
||||
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
|
||||
scaled_mask = state1.denoise_mask * noise_scale
|
||||
|
||||
state1 = LatentState(
|
||||
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask),
|
||||
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state1.clean_latent,
|
||||
denoise_mask=state1.denoise_mask,
|
||||
)
|
||||
@@ -368,7 +390,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
else:
|
||||
# T2V: just use random noise
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
|
||||
@@ -391,6 +413,7 @@ def generate_video(
|
||||
|
||||
# Stage 2: Refine at full resolution
|
||||
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||
# Position grids stay float32 for RoPE precision
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
@@ -401,7 +424,7 @@ def generate_video(
|
||||
state2 = LatentState(
|
||||
latent=latents, # Start with upscaled latent
|
||||
clean_latent=mx.zeros_like(latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
latent=stage2_image_latent,
|
||||
@@ -413,11 +436,11 @@ def generate_video(
|
||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||
# For Stage 2, noise_scale = stage_2_sigmas[0]
|
||||
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||
noise = mx.random.normal(latents.shape)
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
scaled_mask = state2.denoise_mask * noise_scale
|
||||
state2 = LatentState(
|
||||
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask),
|
||||
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state2.clean_latent,
|
||||
denoise_mask=state2.denoise_mask,
|
||||
)
|
||||
@@ -425,9 +448,10 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
else:
|
||||
# T2V: add noise to all frames for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
noise = mx.random.normal(latents.shape)
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||
latents = noise * noise_scale + latents * one_minus_scale
|
||||
mx.eval(latents)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
|
||||
@@ -435,9 +459,36 @@ def generate_video(
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode to video
|
||||
# Decode to video with tiling
|
||||
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
|
||||
video = vae_decoder(latents)
|
||||
|
||||
# Select tiling configuration
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
|
||||
@@ -594,6 +645,15 @@ Examples:
|
||||
default=0,
|
||||
help="Frame index to condition for I2V (0 = first frame, default: 0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
help="Tiling mode for VAE decoding (default: auto). "
|
||||
"auto=based on video size, none=disabled, default=512px/64f, "
|
||||
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@@ -30,6 +30,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_w
|
||||
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
|
||||
@@ -163,6 +164,7 @@ def denoise_av(
|
||||
Returns:
|
||||
Tuple of (video_latents, audio_latents)
|
||||
"""
|
||||
dtype = video_latents.dtype
|
||||
# If video state is provided, use its latent
|
||||
if video_state is not None:
|
||||
video_latents = video_state.latent
|
||||
@@ -188,10 +190,10 @@ def denoise_av(
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||
# Per-token timesteps: sigma * mask
|
||||
video_timesteps = sigma * denoise_mask_flat
|
||||
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
# All tokens get the same timestep
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma)
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
|
||||
video_modality = Modality(
|
||||
latent=video_flat,
|
||||
@@ -204,7 +206,7 @@ def denoise_av(
|
||||
|
||||
audio_modality = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=mx.full((ab, at), sigma),
|
||||
timesteps=mx.full((ab, at), sigma, dtype=dtype),
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings,
|
||||
context_mask=None,
|
||||
@@ -229,10 +231,12 @@ def denoise_av(
|
||||
|
||||
mx.eval(video_denoised, audio_denoised)
|
||||
|
||||
# Euler step
|
||||
# Euler step - use dtype-preserving arrays to avoid float32 promotion
|
||||
if sigma_next > 0:
|
||||
video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma
|
||||
audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
|
||||
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||
else:
|
||||
video_latents = video_denoised
|
||||
audio_latents = audio_denoised
|
||||
@@ -363,6 +367,7 @@ def generate_video_with_audio(
|
||||
image: Optional[str] = None,
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
tiling: str = "auto",
|
||||
):
|
||||
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
|
||||
|
||||
@@ -384,6 +389,7 @@ def generate_video_with_audio(
|
||||
image: Path to conditioning image for I2V
|
||||
image_strength: Conditioning strength (1.0 = full denoise)
|
||||
image_frame_idx: Frame index to condition (0 = first frame)
|
||||
tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -432,6 +438,7 @@ def generate_video_with_audio(
|
||||
|
||||
# Get both video and audio embeddings
|
||||
video_embeddings, audio_embeddings = text_encoder(prompt)
|
||||
model_dtype = video_embeddings.dtype # bfloat16 from text encoder
|
||||
mx.eval(video_embeddings, audio_embeddings)
|
||||
|
||||
del text_encoder
|
||||
@@ -442,6 +449,9 @@ def generate_video_with_audio(
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
# Convert transformer weights to bfloat16 for memory efficiency
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.AudioVideo,
|
||||
num_attention_heads=32,
|
||||
@@ -479,18 +489,16 @@ def generate_video_with_audio(
|
||||
mx.eval(vae_encoder.parameters())
|
||||
|
||||
# Load and prepare image for stage 1 (half resolution)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||
|
||||
# Load and prepare image for stage 2 (full resolution)
|
||||
input_image = load_image(image, height=height, width=width)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||
|
||||
del vae_encoder
|
||||
mx.clear_cache()
|
||||
@@ -499,9 +507,10 @@ def generate_video_with_audio(
|
||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create position grids
|
||||
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
# Create position grids - MUST stay float32 for RoPE precision
|
||||
# bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations
|
||||
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32
|
||||
audio_positions = create_audio_position_grid(1, audio_frames) # float32
|
||||
mx.eval(video_positions, audio_positions)
|
||||
|
||||
# Apply I2V conditioning for stage 1 if provided
|
||||
@@ -510,9 +519,9 @@ def generate_video_with_audio(
|
||||
if is_i2v and stage1_image_latent is not None:
|
||||
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
|
||||
video_state1 = LatentState(
|
||||
latent=mx.zeros(video_latent_shape),
|
||||
clean_latent=mx.zeros(video_latent_shape),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
latent=stage1_image_latent,
|
||||
@@ -522,11 +531,11 @@ def generate_video_with_audio(
|
||||
video_state1 = apply_conditioning(video_state1, [conditioning])
|
||||
|
||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||
noise = mx.random.normal(video_latent_shape)
|
||||
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
||||
noise = mx.random.normal(video_latent_shape).astype(model_dtype)
|
||||
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
|
||||
scaled_mask = video_state1.denoise_mask * noise_scale
|
||||
video_state1 = LatentState(
|
||||
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask),
|
||||
latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=video_state1.clean_latent,
|
||||
denoise_mask=video_state1.denoise_mask,
|
||||
)
|
||||
@@ -534,11 +543,11 @@ def generate_video_with_audio(
|
||||
mx.eval(video_latents)
|
||||
else:
|
||||
# T2V: just use random noise
|
||||
video_latents = mx.random.normal(video_latent_shape)
|
||||
video_latents = mx.random.normal(video_latent_shape).astype(model_dtype)
|
||||
mx.eval(video_latents)
|
||||
|
||||
# Audio always uses pure noise (no I2V for audio)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Stage 1 denoising
|
||||
@@ -568,7 +577,8 @@ def generate_video_with_audio(
|
||||
|
||||
# Stage 2: Refine at full resolution
|
||||
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
# Position grids stay float32 for RoPE precision
|
||||
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32
|
||||
mx.eval(video_positions)
|
||||
|
||||
# Apply I2V conditioning for stage 2 if provided
|
||||
@@ -578,7 +588,7 @@ def generate_video_with_audio(
|
||||
video_state2 = LatentState(
|
||||
latent=video_latents, # Start with upscaled latent
|
||||
clean_latent=mx.zeros_like(video_latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
latent=stage2_image_latent,
|
||||
@@ -588,11 +598,11 @@ def generate_video_with_audio(
|
||||
video_state2 = apply_conditioning(video_state2, [conditioning])
|
||||
|
||||
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||
video_noise = mx.random.normal(video_latents.shape)
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
scaled_mask = video_state2.denoise_mask * noise_scale
|
||||
video_state2 = LatentState(
|
||||
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask),
|
||||
latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=video_state2.clean_latent,
|
||||
denoise_mask=video_state2.denoise_mask,
|
||||
)
|
||||
@@ -600,16 +610,18 @@ def generate_video_with_audio(
|
||||
mx.eval(video_latents)
|
||||
|
||||
# Audio still gets noise (no I2V for audio)
|
||||
audio_noise = mx.random.normal(audio_latents.shape)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||
mx.eval(audio_latents)
|
||||
else:
|
||||
# T2V: add noise to all frames for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
video_noise = mx.random.normal(video_latents.shape)
|
||||
audio_noise = mx.random.normal(audio_latents.shape)
|
||||
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
||||
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
|
||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||
video_latents = video_noise * noise_scale + video_latents * one_minus_scale
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
video_latents, audio_latents = denoise_av(
|
||||
@@ -623,9 +635,36 @@ def generate_video_with_audio(
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode video
|
||||
# Decode video with tiling
|
||||
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
|
||||
video = vae_decoder(video_latents)
|
||||
|
||||
# Select tiling configuration
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
video = vae_decoder(video_latents)
|
||||
mx.eval(video)
|
||||
|
||||
# Convert video to uint8 frames
|
||||
@@ -641,27 +680,13 @@ def generate_video_with_audio(
|
||||
vocoder = load_vocoder(model_path)
|
||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||
|
||||
# Debug: check per-channel statistics are loaded
|
||||
pcs = audio_decoder.per_channel_statistics
|
||||
print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]")
|
||||
|
||||
# Debug: check audio latent statistics
|
||||
print(f"Audio latents shape: {audio_latents.shape}")
|
||||
print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}")
|
||||
|
||||
mel_spectrogram = audio_decoder(audio_latents)
|
||||
mx.eval(mel_spectrogram)
|
||||
|
||||
print(f"Mel spectrogram shape: {mel_spectrogram.shape}")
|
||||
print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}")
|
||||
|
||||
# Audio decoder output is already in vocoder format (B, C, T, F)
|
||||
audio_waveform = vocoder(mel_spectrogram)
|
||||
mx.eval(audio_waveform)
|
||||
|
||||
print(f"Audio waveform shape: {audio_waveform.shape}")
|
||||
print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}")
|
||||
|
||||
audio_np = np.array(audio_waveform)
|
||||
if audio_np.ndim == 3:
|
||||
audio_np = audio_np[0] # Remove batch dim
|
||||
@@ -762,6 +787,11 @@ Examples:
|
||||
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
|
||||
parser.add_argument("--image-frame-idx", type=int, default=0,
|
||||
help="Frame index to condition for I2V (0 = first frame, default: 0)")
|
||||
parser.add_argument("--tiling", type=str, default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
help="Tiling mode for VAE decoding (default: auto). "
|
||||
"auto=based on size, none=disabled, default=512px/64f, "
|
||||
"aggressive=256px/32f (lowest memory), conservative=768px/96f")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -783,6 +813,7 @@ Examples:
|
||||
image=args.image,
|
||||
image_strength=args.image_strength,
|
||||
image_frame_idx=args.image_frame_idx,
|
||||
tiling=args.tiling,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -52,10 +52,11 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
@@ -117,7 +118,7 @@ class TransformerArgsPreprocessor:
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
pe = self._prepare_positional_embeddings(
|
||||
@@ -201,6 +202,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
)
|
||||
|
||||
return replace(
|
||||
@@ -215,15 +217,16 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
timestep: mx.array,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * timestep_scale_multiplier
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
|
||||
@@ -128,6 +128,7 @@ def apply_split_rotary_emb(
|
||||
Returns:
|
||||
Tensor with split rotary embeddings applied
|
||||
"""
|
||||
input_dtype = input_tensor.dtype
|
||||
needs_reshape = False
|
||||
original_shape = input_tensor.shape
|
||||
|
||||
@@ -139,6 +140,11 @@ def apply_split_rotary_emb(
|
||||
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
||||
needs_reshape = True
|
||||
|
||||
# Cast to float32 for computation precision
|
||||
input_tensor = input_tensor.astype(mx.float32)
|
||||
cos_freqs = cos_freqs.astype(mx.float32)
|
||||
sin_freqs = sin_freqs.astype(mx.float32)
|
||||
|
||||
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
||||
dim = input_tensor.shape[-1]
|
||||
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
||||
@@ -167,7 +173,7 @@ def apply_split_rotary_emb(
|
||||
output = mx.swapaxes(output, 1, 2)
|
||||
output = mx.reshape(output, (b, t, h * d))
|
||||
|
||||
return output
|
||||
return output.astype(input_dtype)
|
||||
|
||||
|
||||
def generate_freq_grid(
|
||||
@@ -424,8 +430,20 @@ def _precompute_freqs_cis_double_precision(
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
# Convert to numpy float64
|
||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
|
||||
"Use float32 for position grids to avoid quality degradation. "
|
||||
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
||||
# Note: If input is bfloat16, precision is already lost at this step
|
||||
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
|
||||
# Generate frequency indices in float64
|
||||
n_pos_dims = indices_grid_np.shape[1]
|
||||
|
||||
@@ -273,6 +273,13 @@ class ConnectorAttention(nn.Module):
|
||||
Returns:
|
||||
Tensor with SPLIT rotary embeddings applied
|
||||
"""
|
||||
input_dtype = x.dtype
|
||||
|
||||
# Cast to float32 for precision, then cast back
|
||||
x = x.astype(mx.float32)
|
||||
cos_freq = cos_freq.astype(mx.float32)
|
||||
sin_freq = sin_freq.astype(mx.float32)
|
||||
|
||||
# Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2)
|
||||
half_dim = x.shape[-1] // 2
|
||||
x1 = x[..., :half_dim]
|
||||
@@ -284,7 +291,7 @@ class ConnectorAttention(nn.Module):
|
||||
out1 = x1 * cos_freq - x2 * sin_freq
|
||||
out2 = x2 * cos_freq + x1 * sin_freq
|
||||
|
||||
return mx.concatenate([out1, out2], axis=-1)
|
||||
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
@@ -437,14 +444,15 @@ class Embeddings1DConnector(nn.Module):
|
||||
attention_mask: mx.array,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
batch_size, seq_len, dim = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# Binary mask: 1 for valid tokens, 0 for padded
|
||||
# attention_mask is additive: 0 for valid, large negative for padded
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||
|
||||
# Tile registers to match sequence length
|
||||
# Tile registers to match sequence length, cast to hidden_states dtype
|
||||
num_tiles = seq_len // self.num_learnable_registers
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
|
||||
|
||||
# Process each batch item (PyTorch uses advanced indexing)
|
||||
result_list = []
|
||||
@@ -462,7 +470,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
# Pad with zeros on the right to get back to seq_len
|
||||
pad_length = seq_len - num_valid
|
||||
if pad_length > 0:
|
||||
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
|
||||
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||
else:
|
||||
adjusted = valid_tokens
|
||||
@@ -474,9 +482,8 @@ class Embeddings1DConnector(nn.Module):
|
||||
], axis=0) # (seq,)
|
||||
|
||||
# Combine: valid tokens at front, registers at back
|
||||
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
|
||||
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||
|
||||
result_list.append(combined)
|
||||
|
||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||
@@ -491,7 +498,6 @@ class Embeddings1DConnector(nn.Module):
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
# Replace padded tokens with learnable registers
|
||||
if self.num_learnable_registers > 0 and attention_mask is not None:
|
||||
hidden_states, attention_mask = self._replace_padded_with_registers(
|
||||
@@ -521,6 +527,7 @@ def norm_and_concat_hidden_states(
|
||||
|
||||
# Stack hidden states: (batch, seq, dim, num_layers)
|
||||
stacked = mx.stack(hidden_states, axis=-1)
|
||||
dtype = stacked.dtype
|
||||
b, t, d, num_layers = stacked.shape
|
||||
|
||||
# Compute sequence lengths from attention mask
|
||||
@@ -536,16 +543,16 @@ def norm_and_concat_hidden_states(
|
||||
mask = token_indices >= start_indices # (B, T)
|
||||
|
||||
mask = mask[:, :, None, None] # (B, T, 1, 1)
|
||||
eps = 1e-6
|
||||
eps = mx.array(1e-6, dtype=dtype)
|
||||
|
||||
# Compute masked mean per layer
|
||||
# Compute masked mean per layer - ensure dtype consistency
|
||||
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
|
||||
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
|
||||
denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype)
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
@@ -749,13 +756,10 @@ class LTX2TextEncoder(nn.Module):
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
)
|
||||
|
||||
features = self.feature_extractor(concat_hidden)
|
||||
|
||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
@@ -918,7 +922,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
if response.token == 1 or response.token == 107: # EOS tokens
|
||||
break
|
||||
|
||||
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode only the new tokens
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ import mlx.nn as nn
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -347,10 +348,11 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
return x * std + mean
|
||||
return (x * std + mean).astype(dtype)
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
@@ -444,6 +446,75 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def decode_tiled(
|
||||
self,
|
||||
sample: mx.array,
|
||||
tiling_config: Optional[TilingConfig] = None,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
This method is useful for decoding large videos that would otherwise
|
||||
cause out-of-memory errors. It divides the latents into tiles,
|
||||
decodes each tile separately, and blends them together.
|
||||
|
||||
Args:
|
||||
sample: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration. If None, uses TilingConfig.default().
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
|
||||
Returns:
|
||||
Decoded video of shape (B, 3, F*8, H*8, W*8).
|
||||
"""
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
_, _, f, h, w = sample.shape
|
||||
needs_spatial_tiling = False
|
||||
needs_temporal_tiling = False
|
||||
|
||||
# Spatial scale is 32 (8x VAE upsample + 4x unpatchify)
|
||||
# Temporal scale is 8
|
||||
spatial_scale = 32
|
||||
temporal_scale = 8
|
||||
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
if h > tile_size_latent or w > tile_size_latent:
|
||||
needs_spatial_tiling = True
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale
|
||||
if f > tile_size_latent:
|
||||
needs_temporal_tiling = True
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
if debug:
|
||||
print("[Tiling] Input fits within tile size, using regular decode")
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
latents=sample,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=32, # VAE spatial: 8x upsampling + 4x unpatchify = 32x
|
||||
temporal_scale=8, # VAE temporal upsampling factor
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
|
||||
470
mlx_video/models/ltx/video_vae/tiling.py
Normal file
470
mlx_video/models/ltx/video_vae/tiling.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""VAE Tiling Configuration for decoding large videos.
|
||||
|
||||
Implements spatial and temporal tiling with trapezoidal blending masks
|
||||
to decode large videos without running out of memory.
|
||||
|
||||
Default configuration (from PyTorch):
|
||||
- Spatial: 512px tiles with 64px overlap
|
||||
- Temporal: 64 frames with 24 frame overlap
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def compute_trapezoidal_mask_1d(
|
||||
length: int,
|
||||
ramp_left: int,
|
||||
ramp_right: int,
|
||||
left_starts_from_0: bool = False,
|
||||
) -> mx.array:
|
||||
"""Generate a 1D trapezoidal blending mask with linear ramps.
|
||||
|
||||
Args:
|
||||
length: Output length of the mask.
|
||||
ramp_left: Fade-in length on the left.
|
||||
ramp_right: Fade-out length on the right.
|
||||
left_starts_from_0: Whether the ramp starts from 0 or first non-zero value.
|
||||
Useful for temporal tiles where the first tile is causal.
|
||||
|
||||
Returns:
|
||||
A 1D array of shape (length,) with values in [0, 1].
|
||||
"""
|
||||
if length <= 0:
|
||||
raise ValueError("Mask length must be positive.")
|
||||
|
||||
ramp_left = max(0, min(ramp_left, length))
|
||||
ramp_right = max(0, min(ramp_right, length))
|
||||
|
||||
# Start with ones
|
||||
mask = [1.0] * length
|
||||
|
||||
# Apply left ramp (fade in)
|
||||
if ramp_left > 0:
|
||||
interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2
|
||||
# Create fade_in values using linspace logic
|
||||
fade_in_full = [i / (interval_length - 1) for i in range(interval_length)]
|
||||
fade_in = fade_in_full[:-1] # Remove last element
|
||||
if not left_starts_from_0:
|
||||
fade_in = fade_in[1:] # Remove first element too
|
||||
for i in range(min(ramp_left, len(fade_in))):
|
||||
mask[i] *= fade_in[i]
|
||||
|
||||
# Apply right ramp (fade out)
|
||||
if ramp_right > 0:
|
||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
||||
for i in range(ramp_right):
|
||||
mask[length - ramp_right + i] *= fade_out[i]
|
||||
|
||||
return mx.clip(mx.array(mask), 0, 1)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpatialTilingConfig:
|
||||
"""Configuration for dividing each frame into spatial tiles with optional overlap."""
|
||||
|
||||
tile_size_in_pixels: int
|
||||
tile_overlap_in_pixels: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_pixels < 64:
|
||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
||||
if self.tile_size_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
||||
if self.tile_overlap_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemporalTilingConfig:
|
||||
"""Configuration for dividing a video into temporal tiles."""
|
||||
|
||||
tile_size_in_frames: int
|
||||
tile_overlap_in_frames: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_frames < 16:
|
||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
||||
if self.tile_size_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
||||
if self.tile_overlap_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TilingConfig:
|
||||
"""Configuration for splitting video into tiles with optional overlap."""
|
||||
|
||||
spatial_config: Optional[SpatialTilingConfig] = None
|
||||
temporal_config: Optional[TemporalTilingConfig] = None
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "TilingConfig":
|
||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||
"""Spatial tiling only (for short videos with large resolution)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
||||
temporal_config=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
|
||||
"""Temporal tiling only (for long videos with small resolution)."""
|
||||
return cls(
|
||||
spatial_config=None,
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def aggressive(cls) -> "TilingConfig":
|
||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def conservative(cls) -> "TilingConfig":
|
||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def auto(
|
||||
cls,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
spatial_threshold: int = 512,
|
||||
temporal_threshold: int = 65,
|
||||
) -> Optional["TilingConfig"]:
|
||||
"""Automatically determine tiling config based on video dimensions.
|
||||
|
||||
Args:
|
||||
height: Video height in pixels
|
||||
width: Video width in pixels
|
||||
num_frames: Number of video frames
|
||||
spatial_threshold: Enable spatial tiling if either dimension exceeds this
|
||||
temporal_threshold: Enable temporal tiling if frames exceed this
|
||||
|
||||
Returns:
|
||||
TilingConfig if tiling is needed, None otherwise
|
||||
"""
|
||||
needs_spatial = height > spatial_threshold or width > spatial_threshold
|
||||
needs_temporal = num_frames > temporal_threshold
|
||||
|
||||
if not needs_spatial and not needs_temporal:
|
||||
return None
|
||||
|
||||
# Estimate memory requirement (rough heuristic)
|
||||
# Output size in bytes (float32): B * 3 * F * H * W * 4
|
||||
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
|
||||
|
||||
# For very large videos, use aggressive tiling
|
||||
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
|
||||
return cls.aggressive()
|
||||
|
||||
spatial_config = None
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
# Choose tile size based on resolution
|
||||
max_dim = max(height, width)
|
||||
if max_dim > 1024:
|
||||
tile_size = 384 # Smaller tiles for very large resolutions
|
||||
elif max_dim > 768:
|
||||
tile_size = 512
|
||||
else:
|
||||
tile_size = 384
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
|
||||
|
||||
if needs_temporal:
|
||||
# Choose tile size based on frame count
|
||||
if num_frames > 200:
|
||||
tile_size, overlap = 32, 8 # Aggressive for very long videos
|
||||
elif num_frames > 100:
|
||||
tile_size, overlap = 48, 16
|
||||
else:
|
||||
tile_size, overlap = 64, 24
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DimensionIntervals:
|
||||
"""Intervals for splitting a single dimension."""
|
||||
starts: List[int]
|
||||
ends: List[int]
|
||||
left_ramps: List[int]
|
||||
right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
"""Split a spatial dimension into intervals."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
|
||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||
starts = [i * (size - overlap) for i in range(amount)]
|
||||
ends = [start + size for start in starts]
|
||||
ends[-1] = dimension_size
|
||||
left_ramps = [0] + [overlap] * (amount - 1)
|
||||
right_ramps = [overlap] * (amount - 1) + [0]
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
|
||||
# Start with spatial split
|
||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||
|
||||
# Adjust for temporal: starts[1:] -= 1, left_ramps[1:] += 1
|
||||
starts = intervals.starts.copy()
|
||||
left_ramps = intervals.left_ramps.copy()
|
||||
|
||||
for i in range(1, len(starts)):
|
||||
starts[i] = starts[i] - 1
|
||||
left_ramps[i] = left_ramps[i] + 1
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
"""Map temporal latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = 1 + (end - 1) * scale
|
||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
"""Map spatial latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = end * scale
|
||||
left_ramp_scaled = left_ramp * scale
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def decode_with_tiling(
|
||||
decoder_fn,
|
||||
latents: mx.array,
|
||||
tiling_config: TilingConfig,
|
||||
spatial_scale: int = 32,
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Args:
|
||||
decoder_fn: Decoder function to call for each tile.
|
||||
latents: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration.
|
||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
"""
|
||||
import gc
|
||||
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = 1 + (f_latent - 1) * temporal_scale
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
# Get tile size and overlap in latent space
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
|
||||
else:
|
||||
spatial_tile_size = max(h_latent, w_latent)
|
||||
spatial_overlap = 0
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
|
||||
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
|
||||
else:
|
||||
temporal_tile_size = f_latent
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
num_t_tiles = len(temporal_intervals.starts)
|
||||
num_h_tiles = len(height_intervals.starts)
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
|
||||
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
|
||||
mx.eval(output, weights)
|
||||
|
||||
tile_idx = 0
|
||||
for t_idx in range(num_t_tiles):
|
||||
t_start = temporal_intervals.starts[t_idx]
|
||||
t_end = temporal_intervals.ends[t_idx]
|
||||
t_left = temporal_intervals.left_ramps[t_idx]
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
h_end = height_intervals.ends[h_idx]
|
||||
h_left = height_intervals.left_ramps[h_idx]
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
w_end = width_intervals.ends[w_idx]
|
||||
w_left = width_intervals.left_ramps[w_idx]
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
|
||||
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
del tile_latents
|
||||
|
||||
# Get actual decoded dimensions
|
||||
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
|
||||
expected_t = out_t_slice.stop - out_t_slice.start
|
||||
expected_h = out_h_slice.stop - out_h_slice.start
|
||||
expected_w = out_w_slice.stop - out_w_slice.start
|
||||
|
||||
# Handle potential size mismatches (use minimum)
|
||||
actual_t = min(decoded_t, expected_t)
|
||||
actual_h = min(decoded_h, expected_h)
|
||||
actual_w = min(decoded_w, expected_w)
|
||||
|
||||
# Build blend mask
|
||||
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
|
||||
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
|
||||
# Compute output coordinates
|
||||
t_out_start = out_t_slice.start
|
||||
t_out_end = t_out_start + actual_t
|
||||
h_out_start = out_h_slice.start
|
||||
h_out_end = h_out_start + actual_h
|
||||
w_out_start = out_w_slice.start
|
||||
w_out_end = w_out_start + actual_w
|
||||
|
||||
# Use direct slice assignment (MLX supports this)
|
||||
# Weighted accumulation
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
mx.eval(output, weights)
|
||||
|
||||
# Clean up tile-specific arrays
|
||||
del tile_output_slice, weighted_tile, blend_mask
|
||||
del t_mask_slice, h_mask_slice, w_mask_slice
|
||||
|
||||
tile_idx += 1
|
||||
|
||||
# Periodic garbage collection and cache clearing
|
||||
if tile_idx % 4 == 0:
|
||||
gc.collect()
|
||||
try:
|
||||
mx.clear_cache()
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Done. Final shape: {output.shape}")
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
@@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
class_predicate=get_class_predicate,
|
||||
)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
|
||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
||||
|
||||
|
||||
|
||||
@@ -71,9 +70,12 @@ def to_denoised(
|
||||
Denoised tensor x_0
|
||||
"""
|
||||
if isinstance(sigma, (int, float)):
|
||||
return noisy - sigma * velocity
|
||||
# Convert to array with matching dtype to avoid float32 promotion
|
||||
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
|
||||
return noisy - sigma_arr * velocity
|
||||
else:
|
||||
# sigma is per-sample
|
||||
# sigma is per-sample - ensure dtype matches
|
||||
sigma = sigma.astype(velocity.dtype)
|
||||
while sigma.ndim < velocity.ndim:
|
||||
sigma = mx.expand_dims(sigma, axis=-1)
|
||||
return noisy - sigma * velocity
|
||||
@@ -169,6 +171,7 @@ def load_image(
|
||||
image_path: Union[str, Path],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> mx.array:
|
||||
"""Load and preprocess an image for I2V conditioning.
|
||||
|
||||
@@ -208,7 +211,7 @@ def load_image(
|
||||
|
||||
# Convert to numpy then MLX
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
return mx.array(image_np)
|
||||
return mx.array(image_np, dtype=dtype)
|
||||
|
||||
|
||||
def resize_image_aspect_ratio(
|
||||
@@ -251,6 +254,7 @@ def prepare_image_for_encoding(
|
||||
image: mx.array,
|
||||
target_height: int,
|
||||
target_width: int,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> mx.array:
|
||||
"""Prepare image for VAE encoding by resizing and normalizing.
|
||||
|
||||
@@ -281,4 +285,4 @@ def prepare_image_for_encoding(
|
||||
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
|
||||
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
|
||||
|
||||
return image
|
||||
return image.astype(dtype)
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
280
tests/test_rope.py
Normal file
280
tests/test_rope.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> mx.array:
|
||||
"""Create a simple video position grid for testing."""
|
||||
t_coords = np.arange(0, num_frames)
|
||||
h_coords = np.arange(0, height)
|
||||
w_coords = np.arange(0, width)
|
||||
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||
patch_ends = patch_starts + 1
|
||||
|
||||
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
||||
num_patches = num_frames * height * width
|
||||
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
||||
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
||||
|
||||
# Scale to pixel space
|
||||
scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1)
|
||||
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
||||
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds
|
||||
|
||||
return mx.array(pixel_coords, dtype=dtype)
|
||||
|
||||
class TestRoPEPositionPrecision:
|
||||
"""Test suite for RoPE position precision requirements."""
|
||||
|
||||
def test_float32_positions_produce_consistent_output(self):
|
||||
"""Float32 position grids should produce stable RoPE frequencies."""
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Verify output dtype is float32
|
||||
assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}"
|
||||
assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}"
|
||||
|
||||
# Verify no NaN or Inf values
|
||||
assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN"
|
||||
assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN"
|
||||
assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf"
|
||||
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||
|
||||
# Verify cos/sin are in valid range [-1, 1]
|
||||
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
||||
"cos_freq values out of [-1, 1] range"
|
||||
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
||||
"sin_freq values out of [-1, 1] range"
|
||||
|
||||
def test_bfloat16_positions_cause_precision_loss(self):
|
||||
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||
|
||||
This test documents the known issue: bfloat16 has only 7 bits of mantissa
|
||||
vs 23 bits for float32, causing quantization errors that get amplified
|
||||
by sin/cos calculations in RoPE.
|
||||
"""
|
||||
# Create identical position grids in different dtypes
|
||||
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
# Compute RoPE frequencies
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(
|
||||
indices_grid=positions_f32,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
cos_bf16, sin_bf16 = precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Calculate the difference
|
||||
cos_diff = mx.abs(cos_f32 - cos_bf16)
|
||||
sin_diff = mx.abs(sin_f32 - sin_bf16)
|
||||
|
||||
max_cos_diff = mx.max(cos_diff).item()
|
||||
max_sin_diff = mx.max(sin_diff).item()
|
||||
|
||||
# bfloat16 positions WILL cause measurable differences
|
||||
# This test documents this known behavior
|
||||
# The threshold here is intentionally low to catch the issue
|
||||
precision_threshold = 1e-6
|
||||
|
||||
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
|
||||
# Document the precision loss (this is expected behavior)
|
||||
if has_precision_loss:
|
||||
print(f"\nPrecision loss detected (expected):")
|
||||
print(f" Max cos difference: {max_cos_diff:.6e}")
|
||||
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||
|
||||
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||
assert has_precision_loss, \
|
||||
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
|
||||
def test_double_precision_converts_to_float32_internally(self):
|
||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
# The double precision path in rope.py line 434:
|
||||
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
# This means bfloat16 -> float32 -> float64
|
||||
# The precision is already lost at the bfloat16 -> float32 step
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Output should still be float32
|
||||
assert cos_freq.dtype == mx.float32
|
||||
assert sin_freq.dtype == mx.float32
|
||||
|
||||
def test_position_grid_should_be_float32_recommendation(self):
|
||||
"""Test that validates the recommended practice: positions should be float32.
|
||||
|
||||
This test serves as documentation that position grids MUST be float32
|
||||
to avoid quality degradation in generated videos/audio.
|
||||
"""
|
||||
# Recommended: create positions in float32
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
assert positions.dtype == mx.float32, \
|
||||
"Position grids should be created in float32 for RoPE precision"
|
||||
|
||||
# Verify the position values are reasonable
|
||||
# Temporal positions should be small (seconds)
|
||||
temporal_positions = positions[:, 0, :, :]
|
||||
assert mx.max(temporal_positions).item() < 100, \
|
||||
"Temporal positions should be in seconds (small values)"
|
||||
|
||||
# Spatial positions should be larger (pixels)
|
||||
spatial_h = positions[:, 1, :, :]
|
||||
spatial_w = positions[:, 2, :, :]
|
||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||
|
||||
|
||||
class TestRoPEInterleaved:
|
||||
"""Tests for interleaved RoPE mode."""
|
||||
|
||||
def test_interleaved_rope_with_float32_positions(self):
|
||||
"""Interleaved RoPE should work correctly with float32 positions."""
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.INTERLEAVED,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
assert cos_freq.dtype == mx.float32
|
||||
assert sin_freq.dtype == mx.float32
|
||||
assert not mx.any(mx.isnan(cos_freq)).item()
|
||||
assert not mx.any(mx.isnan(sin_freq)).item()
|
||||
|
||||
|
||||
class TestRoPEWarnings:
|
||||
"""Tests for RoPE warnings."""
|
||||
|
||||
def test_bfloat16_positions_trigger_warning(self):
|
||||
"""Verify that bfloat16 positions trigger a UserWarning."""
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
|
||||
precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
def test_float32_positions_no_warning(self):
|
||||
"""Verify that float32 positions do NOT trigger a warning."""
|
||||
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
# This should not raise any warnings
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error") # Turn warnings into errors
|
||||
precompute_freqs_cis(
|
||||
indices_grid=positions_f32,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
|
||||
class TestRoPESplit:
|
||||
"""Tests for split RoPE mode (used by LTX-2)."""
|
||||
|
||||
def test_split_rope_output_shape(self):
|
||||
"""Verify split RoPE output has correct shape (B, H, T, dim_per_head//2)."""
|
||||
batch_size = 1
|
||||
num_frames = 4
|
||||
height = 4
|
||||
width = 4
|
||||
num_heads = 32
|
||||
dim = 128
|
||||
|
||||
positions = create_video_position_grid(batch_size, num_frames, height, width)
|
||||
num_tokens = num_frames * height * width
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=dim,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=num_heads,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Shape should be (B, H, T, dim_per_head//2)
|
||||
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||
dim_per_head = dim // num_heads
|
||||
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||
assert cos_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert sin_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user