fix tiling, rope precision and weights

This commit is contained in:
Prince Canuma
2026-03-15 22:58:55 +01:00
parent ebcd5dd4e4
commit cecd68197c
5 changed files with 86 additions and 149 deletions

View File

@@ -147,11 +147,13 @@ class LTXModelConfig(BaseModelConfig):
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20]
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision"
# instead of "rope_double_precision" from the config, so double_precision_rope
# is always False in PyTorch regardless of what the config file says. Since the
# model was trained with this behavior, we must match it.
self.double_precision_rope = False
# PyTorch LTX-2 configurator reads "frequencies_precision" (not
# "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
# the key is absent, so double_precision_rope = False. For LTX-2.3
# (has_prompt_adaln=True) the safetensors config has
# frequencies_precision="float64", so double_precision_rope = True.
if not self.has_prompt_adaln:
self.double_precision_rope = False
# Convert string enum values if loading from dict
if isinstance(self.model_type, str):

View File

@@ -399,13 +399,13 @@ def precompute_freqs_cis(
num_attention_heads, rope_type
)
# Cast positions to bfloat16 to match PyTorch's behavior.
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
# generate_freqs computation — fractional positions, scaling, etc. are all
# computed in bfloat16. The multiplication with float32 freq_indices then
# upcasts to float32. This precision behavior is what the model was trained
# with, so we must replicate it.
indices_grid = indices_grid.astype(mx.bfloat16)
# Keep positions in float32 for RoPE computation.
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
# empirical comparison shows float32 positions produce RoPE values matching
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
# position computation that gets amplified by high-frequency indices
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
indices_grid = indices_grid.astype(mx.float32)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
computation (log-spaced values), then converts to float32 for the final tensor.
This provides better numerical precision in the frequency generation phase.
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
frequency grid computation (log-spaced values), then converts to float32.
Position grid is cast to float32 before use: bfloat16 positions lose
precision in the fractional position computation for RoPE.
"""
import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
"Use float32 for position grids to avoid quality degradation.",
UserWarning,
stacklevel=2
)
# Cast to float32 for position computation
# Keep positions in float32 — same reasoning as the non-double-precision path.
indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1]

View File

@@ -725,17 +725,17 @@ class LTX2TextEncoder(nn.Module):
)
# Deeper connectors with matching dims and gate_logits
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default
# (connector_positional_embedding_max_pos not in LTX-2.3 config)
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
# config (nested under config.transformer.connector_positional_embedding_max_pos)
self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[1], has_gate_logits=True,
positional_embedding_max_pos=[4096], has_gate_logits=True,
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[1], has_gate_logits=True,
positional_embedding_max_pos=[4096], has_gate_logits=True,
)
else:
# LTX-2: shared feature extractor, 3840-dim connectors

View File

@@ -160,6 +160,9 @@ class TilingConfig:
) -> Optional["TilingConfig"]:
"""Automatically determine tiling config based on video dimensions.
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
enough context for CausalConv3d and sufficient overlap for clean blending.
Args:
height: Video height in pixels
width: Video width in pixels
@@ -176,37 +179,17 @@ class TilingConfig:
if not needs_spatial and not needs_temporal:
return None
# Estimate memory requirement (rough heuristic)
# Output size in bytes (float32): B * 3 * F * H * W * 4
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
# For very large videos, use aggressive tiling
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
return cls.aggressive()
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
# Smaller tiles cause quality degradation because CausalConv3d needs
# sufficient temporal context and overlap for clean blending.
spatial_config = None
temporal_config = None
if needs_spatial:
# Choose tile size based on resolution
max_dim = max(height, width)
if max_dim > 1024:
tile_size = 384 # Smaller tiles for very large resolutions
elif max_dim > 768:
tile_size = 512
else:
tile_size = 384
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
if needs_temporal:
# Choose tile size based on frame count
if num_frames > 200:
tile_size, overlap = 32, 8 # Aggressive for very long videos
elif num_frames > 100:
tile_size, overlap = 48, 16
else:
tile_size, overlap = 64, 24
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)