fix tiling, rope precision and weights

This commit is contained in:
Prince Canuma
2026-03-15 22:58:55 +01:00
parent ebcd5dd4e4
commit cecd68197c
5 changed files with 86 additions and 149 deletions

View File

@@ -1192,9 +1192,11 @@ def generate_video(
if is_i2v: if is_i2v:
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
audio_frames = None # Always compute audio frames - PyTorch distilled pipeline unconditionally
if audio: # generates audio alongside video (model was trained with joint audio-video).
# The --audio flag only controls whether audio is decoded and saved to output.
audio_frames = compute_audio_frames(num_frames, fps) audio_frames = compute_audio_frames(num_frames, fps)
if audio:
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]") console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
# Get model path # Get model path
@@ -1233,32 +1235,21 @@ def generate_video(
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
# Encode prompts # Encode prompts - always get audio embeddings since the model was trained
# with joint audio-video processing (PyTorch unconditionally generates audio)
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
if audio:
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
model_dtype = video_embeddings_pos.dtype model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
else:
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
audio_embeddings_pos = audio_embeddings_neg = None
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg)
# For dev-two-stage, stage 2 uses single positive embedding (no CFG) # For dev-two-stage, stage 2 uses single positive embedding (no CFG)
if pipeline == PipelineType.DEV_TWO_STAGE: if pipeline == PipelineType.DEV_TWO_STAGE:
text_embeddings = video_embeddings_pos text_embeddings = video_embeddings_pos
else: else:
# Distilled pipeline - single embedding # Distilled pipeline - single embedding
if audio:
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
mx.eval(text_embeddings, audio_embeddings) mx.eval(text_embeddings, audio_embeddings)
else:
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
audio_embeddings = None
mx.eval(text_embeddings)
model_dtype = text_embeddings.dtype model_dtype = text_embeddings.dtype
del text_encoder del text_encoder
@@ -1317,9 +1308,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
audio_positions = None # Always init audio latents/positions - PyTorch unconditionally generates audio
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames) audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents) mx.eval(audio_positions, audio_latents)
@@ -1406,7 +1395,7 @@ def generate_video(
mx.eval(latents) mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio and audio_latents is not None: if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1417,7 +1406,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2, verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions, audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings if audio else None, audio_embeddings=audio_embeddings,
) )
elif pipeline == PipelineType.DEV: elif pipeline == PipelineType.DEV:
@@ -1451,9 +1440,7 @@ def generate_video(
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
mx.eval(video_positions) mx.eval(video_positions)
audio_positions = None # Always init audio latents/positions - PyTorch unconditionally generates audio
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames) audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents) mx.eval(audio_positions, audio_latents)
@@ -1484,8 +1471,7 @@ def generate_video(
latents = mx.random.normal(video_latent_shape, dtype=model_dtype) latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
# Denoise with CFG/APG/STG/modality # Always use A/V denoising - PyTorch always processes audio+video jointly
if audio:
latents, audio_latents = denoise_dev_av( latents, audio_latents = denoise_dev_av(
latents, audio_latents, latents, audio_latents,
video_positions, audio_positions, video_positions, audio_positions,
@@ -1498,17 +1484,6 @@ def generate_video(
stg_scale=stg_scale, stg_video_blocks=stg_blocks, stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale, stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
) )
else:
# Use original denoise_dev with computed sigmas
latents = denoise_dev(
latents, video_positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
@@ -1553,9 +1528,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
audio_positions = None # Always init audio latents/positions - PyTorch unconditionally generates audio
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames) audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents) mx.eval(audio_positions, audio_latents)
@@ -1586,8 +1559,7 @@ def generate_video(
latents = mx.random.normal(stage1_shape, dtype=model_dtype) latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
# Stage 1: Joint AV denoising at half resolution (matches PyTorch) # Stage 1: Always use joint AV denoising (matches PyTorch)
if audio:
latents, audio_latents = denoise_dev_av( latents, audio_latents = denoise_dev_av(
latents, audio_latents, latents, audio_latents,
positions, audio_positions, positions, audio_positions,
@@ -1600,18 +1572,7 @@ def generate_video(
stg_scale=stg_scale, stg_video_blocks=stg_blocks, stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale, stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
) )
else:
latents = denoise_dev(
latents, positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
if audio and audio_latents is not None:
mx.eval(audio_latents) mx.eval(audio_latents)
# Upsample latents 2x # Upsample latents 2x
@@ -1680,7 +1641,7 @@ def generate_video(
mx.eval(latents) mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio and audio_latents is not None: if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1691,7 +1652,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2, verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions, audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos if audio else None, audio_embeddings=audio_embeddings_pos,
) )
del transformer del transformer

View File

@@ -147,10 +147,12 @@ class LTXModelConfig(BaseModelConfig):
if self.audio_positional_embedding_max_pos is None: if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20] self.audio_positional_embedding_max_pos = [20]
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision" # PyTorch LTX-2 configurator reads "frequencies_precision" (not
# instead of "rope_double_precision" from the config, so double_precision_rope # "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
# is always False in PyTorch regardless of what the config file says. Since the # the key is absent, so double_precision_rope = False. For LTX-2.3
# model was trained with this behavior, we must match it. # (has_prompt_adaln=True) the safetensors config has
# frequencies_precision="float64", so double_precision_rope = True.
if not self.has_prompt_adaln:
self.double_precision_rope = False self.double_precision_rope = False
# Convert string enum values if loading from dict # Convert string enum values if loading from dict

View File

@@ -399,13 +399,13 @@ def precompute_freqs_cis(
num_attention_heads, rope_type num_attention_heads, rope_type
) )
# Cast positions to bfloat16 to match PyTorch's behavior. # Keep positions in float32 for RoPE computation.
# In PyTorch, positions are in bfloat16 (model dtype) during the entire # Even though PyTorch nominally casts positions to model dtype (bfloat16),
# generate_freqs computation — fractional positions, scaling, etc. are all # empirical comparison shows float32 positions produce RoPE values matching
# computed in bfloat16. The multiplication with float32 freq_indices then # PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
# upcasts to float32. This precision behavior is what the model was trained # position computation that gets amplified by high-frequency indices
# with, so we must replicate it. # (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
indices_grid = indices_grid.astype(mx.bfloat16) indices_grid = indices_grid.astype(mx.float32)
# Generate frequency indices # Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim) indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float64 for frequency grid. """Compute RoPE frequencies with higher precision using float64 for frequency grid.
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
computation (log-spaced values), then converts to float32 for the final tensor. frequency grid computation (log-spaced values), then converts to float32.
This provides better numerical precision in the frequency generation phase. Position grid is cast to float32 before the RoPE computation, same as the
non-double-precision path (bfloat16 positions lose precision there).
""" """
import numpy as np import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation # Keep positions in float32 — same reasoning as the non-double-precision path.
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
"Use float32 for position grids to avoid quality degradation.",
UserWarning,
stacklevel=2
)
# Cast to float32 for position computation
indices_grid_f32 = indices_grid.astype(mx.float32) indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1] n_pos_dims = indices_grid_f32.shape[1]

View File

@@ -725,17 +725,17 @@ class LTX2TextEncoder(nn.Module):
) )
# Deeper connectors with matching dims and gate_logits # Deeper connectors with matching dims and gate_logits
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default # connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
# (connector_positional_embedding_max_pos not in LTX-2.3 config) # config (nested under config.transformer.connector_positional_embedding_max_pos)
self.video_embeddings_connector = Embeddings1DConnector( self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128, dim=video_output_dim, num_heads=32, head_dim=128,
num_layers=8, num_learnable_registers=128, num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[1], has_gate_logits=True, positional_embedding_max_pos=[4096], has_gate_logits=True,
) )
self.audio_embeddings_connector = Embeddings1DConnector( self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64, dim=audio_output_dim, num_heads=32, head_dim=64,
num_layers=8, num_learnable_registers=128, num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[1], has_gate_logits=True, positional_embedding_max_pos=[4096], has_gate_logits=True,
) )
else: else:
# LTX-2: shared feature extractor, 3840-dim connectors # LTX-2: shared feature extractor, 3840-dim connectors

View File

@@ -160,6 +160,9 @@ class TilingConfig:
) -> Optional["TilingConfig"]: ) -> Optional["TilingConfig"]:
"""Automatically determine tiling config based on video dimensions. """Automatically determine tiling config based on video dimensions.
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
enough context for CausalConv3d and sufficient overlap for clean blending.
Args: Args:
height: Video height in pixels height: Video height in pixels
width: Video width in pixels width: Video width in pixels
@@ -176,37 +179,17 @@ class TilingConfig:
if not needs_spatial and not needs_temporal: if not needs_spatial and not needs_temporal:
return None return None
# Estimate memory requirement (rough heuristic) # Use the same defaults as PyTorch (512px spatial, 64f temporal).
# Output size in bytes (float32): B * 3 * F * H * W * 4 # Smaller tiles cause quality degradation because CausalConv3d needs
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3) # sufficient temporal context and overlap for clean blending.
# For very large videos, use aggressive tiling
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
return cls.aggressive()
spatial_config = None spatial_config = None
temporal_config = None temporal_config = None
if needs_spatial: if needs_spatial:
# Choose tile size based on resolution spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
max_dim = max(height, width)
if max_dim > 1024:
tile_size = 384 # Smaller tiles for very large resolutions
elif max_dim > 768:
tile_size = 512
else:
tile_size = 384
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
if needs_temporal: if needs_temporal:
# Choose tile size based on frame count temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
if num_frames > 200:
tile_size, overlap = 32, 8 # Aggressive for very long videos
elif num_frames > 100:
tile_size, overlap = 48, 16
else:
tile_size, overlap = 64, 24
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
return cls(spatial_config=spatial_config, temporal_config=temporal_config) return cls(spatial_config=spatial_config, temporal_config=temporal_config)