diff --git a/mlx_video/generate.py b/mlx_video/generate.py index f542abb..5945486 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1192,9 +1192,11 @@ def generate_video( if is_i2v: console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") - audio_frames = None + # Always compute audio frames - PyTorch distilled pipeline unconditionally + # generates audio alongside video (model was trained with joint audio-video). + # The --audio flag only controls whether audio is decoded and saved to output. + audio_frames = compute_audio_frames(num_frames, fps) if audio: - audio_frames = compute_audio_frames(num_frames, fps) console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]") # Get model path @@ -1233,32 +1235,21 @@ def generate_video( prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") - # Encode prompts + # Encode prompts - always get audio embeddings since the model was trained + # with joint audio-video processing (PyTorch unconditionally generates audio) if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG - if audio: - video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) - video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) - model_dtype = video_embeddings_pos.dtype - mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) - else: - video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False) - video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False) - audio_embeddings_pos = audio_embeddings_neg = None - model_dtype = video_embeddings_pos.dtype - 
mx.eval(video_embeddings_pos, video_embeddings_neg) + video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) + video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) + model_dtype = video_embeddings_pos.dtype + mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) # For dev-two-stage, stage 2 uses single positive embedding (no CFG) if pipeline == PipelineType.DEV_TWO_STAGE: text_embeddings = video_embeddings_pos else: # Distilled pipeline - single embedding - if audio: - text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) - mx.eval(text_embeddings, audio_embeddings) - else: - text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) - audio_embeddings = None - mx.eval(text_embeddings) + text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) + mx.eval(text_embeddings, audio_embeddings) model_dtype = text_embeddings.dtype del text_encoder @@ -1317,12 +1308,10 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) - audio_positions = None - audio_latents = None - if audio: - audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) - mx.eval(audio_positions, audio_latents) + # Always init audio latents/positions - PyTorch unconditionally generates audio + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) + mx.eval(audio_positions, audio_latents) # Apply I2V conditioning state1 = None @@ -1406,7 +1395,7 @@ def generate_video( mx.eval(latents) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) - if audio and audio_latents is not None: + if 
audio_latents is not None: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) @@ -1417,7 +1406,7 @@ def generate_video( latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2, audio_latents=audio_latents, audio_positions=audio_positions, - audio_embeddings=audio_embeddings if audio else None, + audio_embeddings=audio_embeddings, ) elif pipeline == PipelineType.DEV: @@ -1451,12 +1440,10 @@ def generate_video( video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) mx.eval(video_positions) - audio_positions = None - audio_latents = None - if audio: - audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) - mx.eval(audio_positions, audio_latents) + # Always init audio latents/positions - PyTorch unconditionally generates audio + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + mx.eval(audio_positions, audio_latents) # Initialize latents with optional I2V conditioning video_state = None @@ -1484,31 +1471,19 @@ def generate_video( latents = mx.random.normal(video_latent_shape, dtype=model_dtype) mx.eval(latents) - # Denoise with CFG/APG/STG/modality - if audio: - latents, audio_latents = denoise_dev_av( - latents, audio_latents, - video_positions, audio_positions, - video_embeddings_pos, video_embeddings_neg, - audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, - audio_cfg_scale=audio_cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state, - use_apg=use_apg, apg_eta=apg_eta, 
apg_norm_threshold=apg_norm_threshold, - stg_scale=stg_scale, stg_video_blocks=stg_blocks, - stg_audio_blocks=stg_blocks, modality_scale=modality_scale, - ) - else: - # Use original denoise_dev with computed sigmas - latents = denoise_dev( - latents, video_positions, - video_embeddings_pos, video_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, - cfg_rescale=cfg_rescale, - verbose=verbose, state=video_state, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, - stg_scale=stg_scale, stg_blocks=stg_blocks, - ) + # Always use A/V denoising - PyTorch always processes audio+video jointly + latents, audio_latents = denoise_dev_av( + latents, audio_latents, + video_positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + audio_cfg_scale=audio_cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) @@ -1553,12 +1528,10 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) - audio_positions = None - audio_latents = None - if audio: - audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) - mx.eval(audio_positions, audio_latents) + # Always init audio latents/positions - PyTorch unconditionally generates audio + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + 
mx.eval(audio_positions, audio_latents) # Apply I2V conditioning for stage 1 state1 = None @@ -1586,33 +1559,21 @@ def generate_video( latents = mx.random.normal(stage1_shape, dtype=model_dtype) mx.eval(latents) - # Stage 1: Joint AV denoising at half resolution (matches PyTorch) - if audio: - latents, audio_latents = denoise_dev_av( - latents, audio_latents, - positions, audio_positions, - video_embeddings_pos, video_embeddings_neg, - audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, - audio_cfg_scale=audio_cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, - stg_scale=stg_scale, stg_video_blocks=stg_blocks, - stg_audio_blocks=stg_blocks, modality_scale=modality_scale, - ) - else: - latents = denoise_dev( - latents, positions, - video_embeddings_pos, video_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, - cfg_rescale=cfg_rescale, - verbose=verbose, state=state1, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, - stg_scale=stg_scale, stg_blocks=stg_blocks, - ) + # Stage 1: Always use joint AV denoising (matches PyTorch) + latents, audio_latents = denoise_dev_av( + latents, audio_latents, + positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + audio_cfg_scale=audio_cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + ) - if audio and audio_latents is not None: - mx.eval(audio_latents) + mx.eval(audio_latents) # Upsample latents 2x with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): @@ -1680,7 +1641,7 @@ def generate_video( mx.eval(latents) # 
Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) - if audio and audio_latents is not None: + if audio_latents is not None: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) @@ -1691,7 +1652,7 @@ def generate_video( latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2, audio_latents=audio_latents, audio_positions=audio_positions, - audio_embeddings=audio_embeddings_pos if audio else None, + audio_embeddings=audio_embeddings_pos, ) del transformer diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 009bf62..1cfb0a6 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -147,11 +147,13 @@ class LTXModelConfig(BaseModelConfig): if self.audio_positional_embedding_max_pos is None: self.audio_positional_embedding_max_pos = [20] - # PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision" - # instead of "rope_double_precision" from the config, so double_precision_rope - # is always False in PyTorch regardless of what the config file says. Since the - # model was trained with this behavior, we must match it. - self.double_precision_rope = False + # PyTorch LTX-2 configurator reads "frequencies_precision" (not + # "double_precision_rope") from the config. For LTX-2 (no prompt adaln) + # the key is absent, so double_precision_rope = False. For LTX-2.3 + # (has_prompt_adaln=True) the safetensors config has + # frequencies_precision="float64", so double_precision_rope = True. 
+ if not self.has_prompt_adaln: + self.double_precision_rope = False # Convert string enum values if loading from dict if isinstance(self.model_type, str): diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index d9ae359..cd2bda4 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -399,13 +399,13 @@ def precompute_freqs_cis( num_attention_heads, rope_type ) - # Cast positions to bfloat16 to match PyTorch's behavior. - # In PyTorch, positions are in bfloat16 (model dtype) during the entire - # generate_freqs computation — fractional positions, scaling, etc. are all - # computed in bfloat16. The multiplication with float32 freq_indices then - # upcasts to float32. This precision behavior is what the model was trained - # with, so we must replicate it. - indices_grid = indices_grid.astype(mx.bfloat16) + # Keep positions in float32 for RoPE computation. + # Even though PyTorch nominally casts positions to model dtype (bfloat16), + # empirical comparison shows float32 positions produce RoPE values matching + # PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional + # position computation that gets amplified by high-frequency indices + # (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88. + indices_grid = indices_grid.astype(mx.float32) # Generate frequency indices indices = generate_freq_grid(theta, indices_grid.shape[1], dim) @@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision( ) -> Tuple[mx.array, mx.array]: """Compute RoPE frequencies with higher precision using float64 for frequency grid. - Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid - computation (log-spaced values), then converts to float32 for the final tensor. - This provides better numerical precision in the frequency generation phase. 
+ Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical + frequency grid computation (log-spaced values), then converts to float32. + The position grid is cast to float32 before use (not kept in bfloat16), + for the same precision reasons as the non-double-precision path. """ import numpy as np - # Warn if positions are bfloat16 - this causes quality degradation - if indices_grid.dtype == mx.bfloat16: - import warnings - warnings.warn( - "Position grid has dtype bfloat16, which causes precision loss in RoPE. " - "Use float32 for position grids to avoid quality degradation.", - UserWarning, - stacklevel=2 - ) - - # Cast to float32 for position computation + # Keep positions in float32 — same reasoning as the non-double-precision path. indices_grid_f32 = indices_grid.astype(mx.float32) n_pos_dims = indices_grid_f32.shape[1] diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 90c061b..de95504 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -725,17 +725,17 @@ class LTX2TextEncoder(nn.Module): ) # Deeper connectors with matching dims and gate_logits - # NOTE: positional_embedding_max_pos=[1] matches PyTorch default - # (connector_positional_embedding_max_pos not in LTX-2.3 config) + # connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors + # config (nested under config.transformer.connector_positional_embedding_max_pos) self.video_embeddings_connector = Embeddings1DConnector( dim=video_output_dim, num_heads=32, head_dim=128, num_layers=8, num_learnable_registers=128, - positional_embedding_max_pos=[1], has_gate_logits=True, + positional_embedding_max_pos=[4096], has_gate_logits=True, ) self.audio_embeddings_connector = Embeddings1DConnector( dim=audio_output_dim, num_heads=32, head_dim=64, num_layers=8, num_learnable_registers=128, - positional_embedding_max_pos=[1], has_gate_logits=True, + positional_embedding_max_pos=[4096], has_gate_logits=True, ) else: # 
LTX-2: shared feature extractor, 3840-dim connectors diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index 72d32e4..ad4c442 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -160,6 +160,9 @@ class TilingConfig: ) -> Optional["TilingConfig"]: """Automatically determine tiling config based on video dimensions. + Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides + enough context for CausalConv3d and sufficient overlap for clean blending. + Args: height: Video height in pixels width: Video width in pixels @@ -176,37 +179,17 @@ class TilingConfig: if not needs_spatial and not needs_temporal: return None - # Estimate memory requirement (rough heuristic) - # Output size in bytes (float32): B * 3 * F * H * W * 4 - estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3) - - # For very large videos, use aggressive tiling - if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100): - return cls.aggressive() - + # Use the same defaults as PyTorch (512px spatial, 64f temporal). + # Smaller tiles cause quality degradation because CausalConv3d needs + # sufficient temporal context and overlap for clean blending. 
spatial_config = None temporal_config = None if needs_spatial: - # Choose tile size based on resolution - max_dim = max(height, width) - if max_dim > 1024: - tile_size = 384 # Smaller tiles for very large resolutions - elif max_dim > 768: - tile_size = 512 - else: - tile_size = 384 - spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64) + spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64) if needs_temporal: - # Choose tile size based on frame count - if num_frames > 200: - tile_size, overlap = 32, 8 # Aggressive for very long videos - elif num_frames > 100: - tile_size, overlap = 48, 16 - else: - tile_size, overlap = 64, 24 - temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap) + temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24) return cls(spatial_config=spatial_config, temporal_config=temporal_config)