fix tiling, rope precision and weights

This commit is contained in:
Prince Canuma
2026-03-15 22:58:55 +01:00
parent ebcd5dd4e4
commit cecd68197c
5 changed files with 86 additions and 149 deletions

View File

@@ -1192,9 +1192,11 @@ def generate_video(
if is_i2v:
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
audio_frames = None
if audio:
# Always compute audio frames - PyTorch distilled pipeline unconditionally
# generates audio alongside video (model was trained with joint audio-video).
# The --audio flag only controls whether audio is decoded and saved to output.
audio_frames = compute_audio_frames(num_frames, fps)
if audio:
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
# Get model path
@@ -1233,32 +1235,21 @@ def generate_video(
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
# Encode prompts
# Encode prompts - always get audio embeddings since the model was trained
# with joint audio-video processing (PyTorch unconditionally generates audio)
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
if audio:
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
else:
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
audio_embeddings_pos = audio_embeddings_neg = None
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg)
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
if pipeline == PipelineType.DEV_TWO_STAGE:
text_embeddings = video_embeddings_pos
else:
# Distilled pipeline - single embedding
if audio:
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
mx.eval(text_embeddings, audio_embeddings)
else:
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
audio_embeddings = None
mx.eval(text_embeddings)
model_dtype = text_embeddings.dtype
del text_encoder
@@ -1317,9 +1308,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = None
audio_latents = None
if audio:
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents)
@@ -1406,7 +1395,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio and audio_latents is not None:
if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1417,7 +1406,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings if audio else None,
audio_embeddings=audio_embeddings,
)
elif pipeline == PipelineType.DEV:
@@ -1451,9 +1440,7 @@ def generate_video(
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
mx.eval(video_positions)
audio_positions = None
audio_latents = None
if audio:
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
@@ -1484,8 +1471,7 @@ def generate_video(
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
mx.eval(latents)
# Denoise with CFG/APG/STG/modality
if audio:
# Always use A/V denoising - PyTorch always processes audio+video jointly
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
video_positions, audio_positions,
@@ -1498,17 +1484,6 @@ def generate_video(
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
)
else:
# Use original denoise_dev with computed sigmas
latents = denoise_dev(
latents, video_positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
@@ -1553,9 +1528,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = None
audio_latents = None
if audio:
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
@@ -1586,8 +1559,7 @@ def generate_video(
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents)
# Stage 1: Joint AV denoising at half resolution (matches PyTorch)
if audio:
# Stage 1: Always use joint AV denoising (matches PyTorch)
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
positions, audio_positions,
@@ -1600,18 +1572,7 @@ def generate_video(
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
)
else:
latents = denoise_dev(
latents, positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
if audio and audio_latents is not None:
mx.eval(audio_latents)
# Upsample latents 2x
@@ -1680,7 +1641,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio and audio_latents is not None:
if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1691,7 +1652,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos if audio else None,
audio_embeddings=audio_embeddings_pos,
)
del transformer

View File

@@ -147,10 +147,12 @@ class LTXModelConfig(BaseModelConfig):
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20]
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision"
# instead of "rope_double_precision" from the config, so double_precision_rope
# is always False in PyTorch regardless of what the config file says. Since the
# model was trained with this behavior, we must match it.
# PyTorch LTX-2 configurator reads "frequencies_precision" (not
# "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
# the key is absent, so double_precision_rope = False. For LTX-2.3
# (has_prompt_adaln=True) the safetensors config has
# frequencies_precision="float64", so double_precision_rope = True.
if not self.has_prompt_adaln:
self.double_precision_rope = False
# Convert string enum values if loading from dict

View File

@@ -399,13 +399,13 @@ def precompute_freqs_cis(
num_attention_heads, rope_type
)
# Cast positions to bfloat16 to match PyTorch's behavior.
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
# generate_freqs computation — fractional positions, scaling, etc. are all
# computed in bfloat16. The multiplication with float32 freq_indices then
# upcasts to float32. This precision behavior is what the model was trained
# with, so we must replicate it.
indices_grid = indices_grid.astype(mx.bfloat16)
# Keep positions in float32 for RoPE computation.
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
# empirical comparison shows float32 positions produce RoPE values matching
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
# position computation that gets amplified by high-frequency indices
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
indices_grid = indices_grid.astype(mx.float32)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
computation (log-spaced values), then converts to float32 for the final tensor.
This provides better numerical precision in the frequency generation phase.
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
frequency grid computation (log-spaced values), then converts to float32.
Position grid stays in bfloat16 to match PyTorch behavior (positions are in
model dtype throughout generate_freqs).
"""
import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
"Use float32 for position grids to avoid quality degradation.",
UserWarning,
stacklevel=2
)
# Cast to float32 for position computation
# Keep positions in float32 — same reasoning as the non-double-precision path.
indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1]

View File

@@ -725,17 +725,17 @@ class LTX2TextEncoder(nn.Module):
)
# Deeper connectors with matching dims and gate_logits
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default
# (connector_positional_embedding_max_pos not in LTX-2.3 config)
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
# config (nested under config.transformer.connector_positional_embedding_max_pos)
self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[1], has_gate_logits=True,
positional_embedding_max_pos=[4096], has_gate_logits=True,
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[1], has_gate_logits=True,
positional_embedding_max_pos=[4096], has_gate_logits=True,
)
else:
# LTX-2: shared feature extractor, 3840-dim connectors

View File

@@ -160,6 +160,9 @@ class TilingConfig:
) -> Optional["TilingConfig"]:
"""Automatically determine tiling config based on video dimensions.
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
enough context for CausalConv3d and sufficient overlap for clean blending.
Args:
height: Video height in pixels
width: Video width in pixels
@@ -176,37 +179,17 @@ class TilingConfig:
if not needs_spatial and not needs_temporal:
return None
# Estimate memory requirement (rough heuristic)
# Output size in bytes (float32): B * 3 * F * H * W * 4
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
# For very large videos, use aggressive tiling
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
return cls.aggressive()
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
# Smaller tiles cause quality degradation because CausalConv3d needs
# sufficient temporal context and overlap for clean blending.
spatial_config = None
temporal_config = None
if needs_spatial:
# Choose tile size based on resolution
max_dim = max(height, width)
if max_dim > 1024:
tile_size = 384 # Smaller tiles for very large resolutions
elif max_dim > 768:
tile_size = 512
else:
tile_size = 384
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
if needs_temporal:
# Choose tile size based on frame count
if num_frames > 200:
tile_size, overlap = 32, 8 # Aggressive for very long videos
elif num_frames > 100:
tile_size, overlap = 48, 16
else:
tile_size, overlap = 64, 24
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)