fix tiling, rope precision and weights
This commit is contained in:
@@ -147,11 +147,13 @@ class LTXModelConfig(BaseModelConfig):
|
||||
if self.audio_positional_embedding_max_pos is None:
|
||||
self.audio_positional_embedding_max_pos = [20]
|
||||
|
||||
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision"
|
||||
# instead of "rope_double_precision" from the config, so double_precision_rope
|
||||
# is always False in PyTorch regardless of what the config file says. Since the
|
||||
# model was trained with this behavior, we must match it.
|
||||
self.double_precision_rope = False
|
||||
# PyTorch LTX-2 configurator reads "frequencies_precision" (not
|
||||
# "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
|
||||
# the key is absent, so double_precision_rope = False. For LTX-2.3
|
||||
# (has_prompt_adaln=True) the safetensors config has
|
||||
# frequencies_precision="float64", so double_precision_rope = True.
|
||||
if not self.has_prompt_adaln:
|
||||
self.double_precision_rope = False
|
||||
|
||||
# Convert string enum values if loading from dict
|
||||
if isinstance(self.model_type, str):
|
||||
|
||||
@@ -399,13 +399,13 @@ def precompute_freqs_cis(
|
||||
num_attention_heads, rope_type
|
||||
)
|
||||
|
||||
# Cast positions to bfloat16 to match PyTorch's behavior.
|
||||
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
|
||||
# generate_freqs computation — fractional positions, scaling, etc. are all
|
||||
# computed in bfloat16. The multiplication with float32 freq_indices then
|
||||
# upcasts to float32. This precision behavior is what the model was trained
|
||||
# with, so we must replicate it.
|
||||
indices_grid = indices_grid.astype(mx.bfloat16)
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
|
||||
# empirical comparison shows float32 positions produce RoPE values matching
|
||||
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
|
||||
# position computation that gets amplified by high-frequency indices
|
||||
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
|
||||
indices_grid = indices_grid.astype(mx.float32)
|
||||
|
||||
# Generate frequency indices
|
||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||
|
||||
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
|
||||
computation (log-spaced values), then converts to float32 for the final tensor.
|
||||
This provides better numerical precision in the frequency generation phase.
|
||||
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
|
||||
frequency grid computation (log-spaced values), then converts to float32.
|
||||
Position grid stays in bfloat16 to match PyTorch behavior (positions are in
|
||||
model dtype throughout generate_freqs).
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
|
||||
"Use float32 for position grids to avoid quality degradation.",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Cast to float32 for position computation
|
||||
# Keep positions in float32 — same reasoning as the non-double-precision path.
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
|
||||
@@ -725,17 +725,17 @@ class LTX2TextEncoder(nn.Module):
|
||||
)
|
||||
|
||||
# Deeper connectors with matching dims and gate_logits
|
||||
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default
|
||||
# (connector_positional_embedding_max_pos not in LTX-2.3 config)
|
||||
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
||||
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1], has_gate_logits=True,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1], has_gate_logits=True,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
)
|
||||
else:
|
||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||
|
||||
@@ -160,6 +160,9 @@ class TilingConfig:
|
||||
) -> Optional["TilingConfig"]:
|
||||
"""Automatically determine tiling config based on video dimensions.
|
||||
|
||||
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
|
||||
enough context for CausalConv3d and sufficient overlap for clean blending.
|
||||
|
||||
Args:
|
||||
height: Video height in pixels
|
||||
width: Video width in pixels
|
||||
@@ -176,37 +179,17 @@ class TilingConfig:
|
||||
if not needs_spatial and not needs_temporal:
|
||||
return None
|
||||
|
||||
# Estimate memory requirement (rough heuristic)
|
||||
# Output size in bytes (float32): B * 3 * F * H * W * 4
|
||||
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
|
||||
|
||||
# For very large videos, use aggressive tiling
|
||||
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
|
||||
return cls.aggressive()
|
||||
|
||||
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
|
||||
# Smaller tiles cause quality degradation because CausalConv3d needs
|
||||
# sufficient temporal context and overlap for clean blending.
|
||||
spatial_config = None
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
# Choose tile size based on resolution
|
||||
max_dim = max(height, width)
|
||||
if max_dim > 1024:
|
||||
tile_size = 384 # Smaller tiles for very large resolutions
|
||||
elif max_dim > 768:
|
||||
tile_size = 512
|
||||
else:
|
||||
tile_size = 384
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
|
||||
if needs_temporal:
|
||||
# Choose tile size based on frame count
|
||||
if num_frames > 200:
|
||||
tile_size, overlap = 32, 8 # Aggressive for very long videos
|
||||
elif num_frames > 100:
|
||||
tile_size, overlap = 48, 16
|
||||
else:
|
||||
tile_size, overlap = 64, 24
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user