fix tiling, rope precision and weights
This commit is contained in:
@@ -1192,9 +1192,11 @@ def generate_video(
|
||||
if is_i2v:
|
||||
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
|
||||
|
||||
audio_frames = None
|
||||
# Always compute audio frames - PyTorch distilled pipeline unconditionally
|
||||
# generates audio alongside video (model was trained with joint audio-video).
|
||||
# The --audio flag only controls whether audio is decoded and saved to output.
|
||||
audio_frames = compute_audio_frames(num_frames, fps)
|
||||
if audio:
|
||||
audio_frames = compute_audio_frames(num_frames, fps)
|
||||
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
|
||||
|
||||
# Get model path
|
||||
@@ -1233,32 +1235,21 @@ def generate_video(
|
||||
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
|
||||
|
||||
# Encode prompts
|
||||
# Encode prompts - always get audio embeddings since the model was trained
|
||||
# with joint audio-video processing (PyTorch unconditionally generates audio)
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
|
||||
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
|
||||
if audio:
|
||||
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
||||
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
||||
else:
|
||||
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
|
||||
audio_embeddings_pos = audio_embeddings_neg = None
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg)
|
||||
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
||||
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
||||
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
|
||||
if pipeline == PipelineType.DEV_TWO_STAGE:
|
||||
text_embeddings = video_embeddings_pos
|
||||
else:
|
||||
# Distilled pipeline - single embedding
|
||||
if audio:
|
||||
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
|
||||
mx.eval(text_embeddings, audio_embeddings)
|
||||
else:
|
||||
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||
audio_embeddings = None
|
||||
mx.eval(text_embeddings)
|
||||
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
|
||||
mx.eval(text_embeddings, audio_embeddings)
|
||||
model_dtype = text_embeddings.dtype
|
||||
|
||||
del text_encoder
|
||||
@@ -1317,12 +1308,10 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
audio_positions = None
|
||||
audio_latents = None
|
||||
if audio:
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning
|
||||
state1 = None
|
||||
@@ -1406,7 +1395,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
||||
if audio and audio_latents is not None:
|
||||
if audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -1417,7 +1406,7 @@ def generate_video(
|
||||
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings if audio else None,
|
||||
audio_embeddings=audio_embeddings,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV:
|
||||
@@ -1451,12 +1440,10 @@ def generate_video(
|
||||
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
||||
mx.eval(video_positions)
|
||||
|
||||
audio_positions = None
|
||||
audio_latents = None
|
||||
if audio:
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Initialize latents with optional I2V conditioning
|
||||
video_state = None
|
||||
@@ -1484,31 +1471,19 @@ def generate_video(
|
||||
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
|
||||
# Denoise with CFG/APG/STG/modality
|
||||
if audio:
|
||||
latents, audio_latents = denoise_dev_av(
|
||||
latents, audio_latents,
|
||||
video_positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
audio_cfg_scale=audio_cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
)
|
||||
else:
|
||||
# Use original denoise_dev with computed sigmas
|
||||
latents = denoise_dev(
|
||||
latents, video_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
cfg_rescale=cfg_rescale,
|
||||
verbose=verbose, state=video_state,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_blocks=stg_blocks,
|
||||
)
|
||||
# Always use A/V denoising - PyTorch always processes audio+video jointly
|
||||
latents, audio_latents = denoise_dev_av(
|
||||
latents, audio_latents,
|
||||
video_positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
audio_cfg_scale=audio_cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
@@ -1553,12 +1528,10 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
audio_positions = None
|
||||
audio_latents = None
|
||||
if audio:
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
state1 = None
|
||||
@@ -1586,33 +1559,21 @@ def generate_video(
|
||||
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
|
||||
# Stage 1: Joint AV denoising at half resolution (matches PyTorch)
|
||||
if audio:
|
||||
latents, audio_latents = denoise_dev_av(
|
||||
latents, audio_latents,
|
||||
positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
audio_cfg_scale=audio_cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
)
|
||||
else:
|
||||
latents = denoise_dev(
|
||||
latents, positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
cfg_rescale=cfg_rescale,
|
||||
verbose=verbose, state=state1,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_blocks=stg_blocks,
|
||||
)
|
||||
# Stage 1: Always use joint AV denoising (matches PyTorch)
|
||||
latents, audio_latents = denoise_dev_av(
|
||||
latents, audio_latents,
|
||||
positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
audio_cfg_scale=audio_cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
)
|
||||
|
||||
if audio and audio_latents is not None:
|
||||
mx.eval(audio_latents)
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Upsample latents 2x
|
||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
||||
@@ -1680,7 +1641,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
||||
if audio and audio_latents is not None:
|
||||
if audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -1691,7 +1652,7 @@ def generate_video(
|
||||
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings_pos if audio else None,
|
||||
audio_embeddings=audio_embeddings_pos,
|
||||
)
|
||||
|
||||
del transformer
|
||||
|
||||
@@ -147,11 +147,13 @@ class LTXModelConfig(BaseModelConfig):
|
||||
if self.audio_positional_embedding_max_pos is None:
|
||||
self.audio_positional_embedding_max_pos = [20]
|
||||
|
||||
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision"
|
||||
# instead of "rope_double_precision" from the config, so double_precision_rope
|
||||
# is always False in PyTorch regardless of what the config file says. Since the
|
||||
# model was trained with this behavior, we must match it.
|
||||
self.double_precision_rope = False
|
||||
# PyTorch LTX-2 configurator reads "frequencies_precision" (not
|
||||
# "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
|
||||
# the key is absent, so double_precision_rope = False. For LTX-2.3
|
||||
# (has_prompt_adaln=True) the safetensors config has
|
||||
# frequencies_precision="float64", so double_precision_rope = True.
|
||||
if not self.has_prompt_adaln:
|
||||
self.double_precision_rope = False
|
||||
|
||||
# Convert string enum values if loading from dict
|
||||
if isinstance(self.model_type, str):
|
||||
|
||||
@@ -399,13 +399,13 @@ def precompute_freqs_cis(
|
||||
num_attention_heads, rope_type
|
||||
)
|
||||
|
||||
# Cast positions to bfloat16 to match PyTorch's behavior.
|
||||
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
|
||||
# generate_freqs computation — fractional positions, scaling, etc. are all
|
||||
# computed in bfloat16. The multiplication with float32 freq_indices then
|
||||
# upcasts to float32. This precision behavior is what the model was trained
|
||||
# with, so we must replicate it.
|
||||
indices_grid = indices_grid.astype(mx.bfloat16)
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
|
||||
# empirical comparison shows float32 positions produce RoPE values matching
|
||||
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
|
||||
# position computation that gets amplified by high-frequency indices
|
||||
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
|
||||
indices_grid = indices_grid.astype(mx.float32)
|
||||
|
||||
# Generate frequency indices
|
||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||
|
||||
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
|
||||
computation (log-spaced values), then converts to float32 for the final tensor.
|
||||
This provides better numerical precision in the frequency generation phase.
|
||||
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
|
||||
frequency grid computation (log-spaced values), then converts to float32.
|
||||
Position grid is cast to float32 before the RoPE computation, for the same
|
||||
reason as the non-double-precision path (bfloat16 positions lose precision).
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
|
||||
"Use float32 for position grids to avoid quality degradation.",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Cast to float32 for position computation
|
||||
# Keep positions in float32 — same reasoning as the non-double-precision path.
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
|
||||
@@ -725,17 +725,17 @@ class LTX2TextEncoder(nn.Module):
|
||||
)
|
||||
|
||||
# Deeper connectors with matching dims and gate_logits
|
||||
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default
|
||||
# (connector_positional_embedding_max_pos not in LTX-2.3 config)
|
||||
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
||||
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1], has_gate_logits=True,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1], has_gate_logits=True,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
)
|
||||
else:
|
||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||
|
||||
@@ -160,6 +160,9 @@ class TilingConfig:
|
||||
) -> Optional["TilingConfig"]:
|
||||
"""Automatically determine tiling config based on video dimensions.
|
||||
|
||||
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
|
||||
enough context for CausalConv3d and sufficient overlap for clean blending.
|
||||
|
||||
Args:
|
||||
height: Video height in pixels
|
||||
width: Video width in pixels
|
||||
@@ -176,37 +179,17 @@ class TilingConfig:
|
||||
if not needs_spatial and not needs_temporal:
|
||||
return None
|
||||
|
||||
# Estimate memory requirement (rough heuristic)
|
||||
# Output size in bytes (float32): B * 3 * F * H * W * 4
|
||||
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
|
||||
|
||||
# For very large videos, use aggressive tiling
|
||||
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
|
||||
return cls.aggressive()
|
||||
|
||||
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
|
||||
# Smaller tiles cause quality degradation because CausalConv3d needs
|
||||
# sufficient temporal context and overlap for clean blending.
|
||||
spatial_config = None
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
# Choose tile size based on resolution
|
||||
max_dim = max(height, width)
|
||||
if max_dim > 1024:
|
||||
tile_size = 384 # Smaller tiles for very large resolutions
|
||||
elif max_dim > 768:
|
||||
tile_size = 512
|
||||
else:
|
||||
tile_size = 384
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
|
||||
if needs_temporal:
|
||||
# Choose tile size based on frame count
|
||||
if num_frames > 200:
|
||||
tile_size, overlap = 32, 8 # Aggressive for very long videos
|
||||
elif num_frames > 100:
|
||||
tile_size, overlap = 48, 16
|
||||
else:
|
||||
tile_size, overlap = 64, 24
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user