From df753312c7bfc6e4c721b9c637e6023fafd96694 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 23 Jan 2026 17:39:02 +0100 Subject: [PATCH] Refactor video generation and model loading processes to utilize from_pretrained methods for VideoEncoder and VideoDecoder. Update denoising functions to include a cfg_rescale parameter for improved artifact reduction. Ensure consistent dtype handling across audio and video processing, enhancing precision and aligning with PyTorch behavior. --- mlx_video/__init__.py | 17 +- mlx_video/generate.py | 249 ++++++++++------------ mlx_video/models/ltx/video_vae/decoder.py | 4 + 3 files changed, 119 insertions(+), 151 deletions(-) diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index 07fd7c1..0256f7b 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -10,24 +10,16 @@ from mlx_video.convert import ( # Audio VAE components from mlx_video.models.ltx.audio_vae import ( - AudioEncoder, AudioDecoder, Vocoder, - AudioProcessor, decode_audio, -) - -# Patchifiers -from mlx_video.components.patchifiers import ( - VideoLatentPatchifier, AudioPatchifier, - VideoLatentShape, AudioLatentShape, + PerChannelStatistics, ) # Conditioning from mlx_video.conditioning import ( - VideoConditionByKeyframeIndex, VideoConditionByLatentIndex, ) @@ -43,17 +35,12 @@ __all__ = [ "sanitize_audio_vae_weights", "sanitize_vocoder_weights", # Audio VAE - "AudioEncoder", "AudioDecoder", "Vocoder", - "AudioProcessor", "decode_audio", - # Patchifiers - "VideoLatentPatchifier", "AudioPatchifier", - "VideoLatentShape", "AudioLatentShape", + "PerChannelStatistics", # Conditioning - "VideoConditionByKeyframeIndex", "VideoConditionByLatentIndex", ] \ No newline at end of file diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 8c99153..2811368 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -21,13 +21,12 @@ from rich.panel import Panel console = Console() -from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path -from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder -from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder +from mlx_video.models.ltx.video_vae.decoder import VideoDecoder +from mlx_video.models.ltx.video_vae import VideoEncoder from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning @@ -58,19 +57,8 @@ AUDIO_MEL_BINS = 16 AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 # Default negative prompt for CFG (dev pipeline) -DEFAULT_NEGATIVE_PROMPT = ( - "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " - "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " - "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " - "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " - "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " - "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " - "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " - "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " - "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " - "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " - "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -) +# Matches PyTorch LTX-2 reference InferenceConfig default +DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted" def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: @@ -123,11 +111,12 @@ def ltx2_scheduler( # Apply shift transformation power = 1 - sigmas = np.where( - sigmas != 0, - math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), - 0, - ) + with np.errstate(divide='ignore', invalid='ignore'): + sigmas = np.where( + sigmas != 0, + math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), + 0, + ) # Stretch sigmas to terminal value if stretch: @@ -194,7 +183,13 @@ def create_position_grid( a_max=None ) - pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps + # Compute temporal division in bfloat16 to match PyTorch's precision behavior + # This ensures RoPE frequencies are computed identically to the reference implementation + temporal_coords = mx.array(pixel_coords[:, 0, :, :], dtype=mx.bfloat16) + fps_bf16 = mx.array(fps, dtype=mx.bfloat16) + temporal_coords = temporal_coords / fps_bf16 + mx.eval(temporal_coords) + pixel_coords[:, 0, :, :] = np.array(temporal_coords.astype(mx.float32)) return mx.array(pixel_coords, dtype=mx.float32) @@ -484,16 +479,29 @@ def denoise_dev_av( transformer: LTXModel, sigmas: mx.array, cfg_scale: float = 4.0, + cfg_rescale: float = 0.0, verbose: bool = True, video_state: Optional[LatentState] = None, ) -> tuple[mx.array, mx.array]: - """Run denoising loop for dev pipeline with CFG and audio.""" + """Run denoising loop for dev pipeline with CFG and audio. + + Args: + cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result + towards the positive-only prediction, helping reduce artifacts. + Default 0.0 means no rescaling (standard CFG). + """ from mlx_video.models.ltx.rope import precompute_freqs_cis dtype = video_latents.dtype if video_state is not None: video_latents = video_state.latent + # Keep latents in float32 throughout the denoising loop to avoid + # bfloat16 quantization noise accumulation over many steps. + # PyTorch keeps latents in float32; model input is cast to model dtype. + video_latents = video_latents.astype(mx.float32) + audio_latents = audio_latents.astype(mx.float32) + sigmas_list = sigmas.tolist() use_cfg = cfg_scale != 1.0 num_steps = len(sigmas_list) - 1 @@ -538,15 +546,15 @@ def denoise_dev_av( sigma = sigmas_list[i] sigma_next = sigmas_list[i + 1] - # Flatten video latents + # Flatten video latents (cast to model dtype for transformer input) b, c, f, h, w = video_latents.shape num_video_tokens = f * h * w - video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) + video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)).astype(dtype) - # Flatten audio latents + # Flatten audio latents (cast to model dtype for transformer input) ab, ac, at, af = audio_latents.shape audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) - audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype) # Compute timesteps if video_state is not None: @@ -571,8 +579,26 @@ def denoise_dev_av( positional_embeddings=precomputed_audio_rope, ) video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) + mx.eval(video_vel_pos, audio_vel_pos) - if use_cfg: + # Convert velocity to denoised (x0) using per-token timesteps + # This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity + # For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant) + # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity + # Use the float32 latents (not the bfloat16 model input) for precision + video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) + audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af)) + video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1) + audio_timesteps_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1) + + video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32) + audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32) + + # Dynamic CFG: compute per-step effective scale + step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0 + apply_cfg_this_step = step_cfg_scale > 1.0 + + if apply_cfg_this_step: # Negative conditioning pass video_modality_neg = Modality( latent=video_flat, timesteps=video_timesteps, positions=video_positions, @@ -585,39 +611,53 @@ def denoise_dev_av( positional_embeddings=precomputed_audio_rope, ) video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) + mx.eval(video_vel_neg, audio_vel_neg) - # Apply CFG - video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg) - audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg) + # Convert negative velocity to x0 using per-token timesteps + video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32) + audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) + + # Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider + # delta = (scale - 1) * (x0_pos - x0_neg) + # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect) + video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) + audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) + + # Apply CFG rescale if enabled + if cfg_rescale > 0.0: + video_x0_guided_f32 = cfg_rescale * video_x0_pos_f32 + (1.0 - cfg_rescale) * video_x0_guided_f32 + audio_x0_guided_f32 = cfg_rescale * audio_x0_pos_f32 + (1.0 - cfg_rescale) * audio_x0_guided_f32 else: - video_velocity_flat = video_vel_pos - audio_velocity_flat = audio_vel_pos + video_x0_guided_f32 = video_x0_pos_f32 + audio_x0_guided_f32 = audio_x0_pos_f32 - # Reshape velocities - video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w)) - audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af)) - audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) + # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) + video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) + audio_denoised_f32 = mx.reshape(audio_x0_guided_f32, (ab, at, ac, af)) + audio_denoised_f32 = mx.transpose(audio_denoised_f32, (0, 2, 1, 3)) - # Compute denoised - video_denoised = to_denoised(video_latents, video_velocity, sigma) - audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + # Post-process: blend denoised with clean latent using mask + # Matches PyTorch's post_process_latent: denoised * mask + clean * (1 - mask) + sigma_f32 = mx.array(sigma, dtype=mx.float32) if video_state is not None: - video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask) + clean_f32 = video_state.clean_latent.astype(mx.float32) + mask_f32 = video_state.denoise_mask.astype(mx.float32) + video_denoised_f32 = video_denoised_f32 * mask_f32 + clean_f32 * (1.0 - mask_f32) - # Euler step + mx.eval(video_denoised_f32, audio_denoised_f32) + + # Euler step matching PyTorch: sample + velocity * dt + # Latents stay in float32 throughout (matching PyTorch behavior) if sigma_next > 0: - # Compute Euler step in float32 for precision (matching PyTorch behavior) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) - sigma_f32 = mx.array(sigma, dtype=mx.float32) + dt_f32 = sigma_next_f32 - sigma_f32 - video_latents_f32 = video_latents.astype(mx.float32) - video_denoised_f32 = video_denoised.astype(mx.float32) - video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype) + video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32 + video_latents = video_latents + video_velocity_f32 * dt_f32 - audio_latents_f32 = audio_latents.astype(mx.float32) - audio_denoised_f32 = audio_denoised.astype(mx.float32) - audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) + audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32 + audio_latents = audio_latents + audio_velocity_f32 * dt_f32 else: video_latents = video_denoised audio_latents = audio_denoised @@ -634,33 +674,12 @@ def denoise_dev_av( def load_audio_decoder(model_path: Path, pipeline: PipelineType): """Load audio VAE decoder.""" + from mlx_video.models.ltx.config import AudioDecoderModelConfig from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType - from mlx_video.convert import sanitize_audio_vae_weights - - decoder = AudioDecoder( - ch=128, - out_ch=2, - ch_mult=(1, 2, 4), - num_res_blocks=2, - attn_resolutions=set(), - resolution=256, - z_channels=AUDIO_LATENT_CHANNELS, - norm_type=NormType.PIXEL, - causality_axis=CausalityAxis.HEIGHT, - mel_bins=64, - mid_block_add_attention=False, # Config says no attention in mid block - ) weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") - if weight_file.exists(): - raw_weights = mx.load(str(weight_file)) - sanitized = sanitize_audio_vae_weights(raw_weights) - if sanitized: - decoder.load_weights(list(sanitized.items()), strict=False) - if "per_channel_statistics._mean_of_means" in sanitized: - decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] - if "per_channel_statistics._std_of_means" in sanitized: - decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"] + + decoder = AudioDecoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/audio_vae")) return decoder @@ -668,24 +687,9 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType): def load_vocoder(model_path: Path, pipeline: PipelineType): """Load vocoder for mel to waveform conversion.""" from mlx_video.models.ltx.audio_vae import Vocoder - from mlx_video.convert import sanitize_vocoder_weights - - vocoder = Vocoder( - resblock_kernel_sizes=[3, 7, 11], - upsample_rates=[6, 5, 2, 2, 2], - upsample_kernel_sizes=[16, 15, 8, 4, 4], - resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - upsample_initial_channel=1024, - stereo=True, - output_sample_rate=AUDIO_SAMPLE_RATE, - ) weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") - if weight_file.exists(): - raw_weights = mx.load(str(weight_file)) - sanitized = sanitize_vocoder_weights(raw_weights) - if sanitized: - vocoder.load_weights(list(sanitized.items()), strict=False) + vocoder = Vocoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vocoder")) return vocoder @@ -747,6 +751,7 @@ def generate_video( num_frames: int = 33, num_inference_steps: int = 40, cfg_scale: float = 4.0, + cfg_rescale: float = 0.0, seed: int = 42, fps: int = 24, output_path: str = "output.mp4", @@ -891,40 +896,7 @@ def generate_video( # Load transformer transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..." with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"): - - - model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly - - config_kwargs = dict( - model_type=model_type, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) - - if audio: - config_kwargs.update( - audio_num_attention_heads=32, - audio_attention_head_dim=64, - audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, - audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, - audio_cross_attention_dim=2048, - audio_positional_embedding_max_pos=[20], - ) - - config = LTXModelConfig(**config_kwargs) - - transformer = LTXModel.from_pretrained(model_path=model_path/weight_file, config=config, strict=True) + transformer = LTXModel.from_pretrained(model_path=Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/transformer"), strict=True) console.print("[green]✓[/] Transformer loaded") @@ -942,8 +914,7 @@ def generate_video( stage2_image_latent = None if is_i2v: with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = load_vae_encoder(str(model_path / weight_file)) - mx.eval(vae_encoder.parameters()) + vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-distilled/vae/encoder")) input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) @@ -1010,9 +981,9 @@ def generate_video( upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) mx.eval(upsampler.parameters()) - vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None) + vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file)) - latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) + latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) mx.eval(latents) del upsampler @@ -1077,8 +1048,7 @@ def generate_video( image_latent = None if is_i2v: with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = load_vae_encoder(str(model_path / weight_file)) - mx.eval(vae_encoder.parameters()) + vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/encoder")) input_image = load_image(image, height=height, width=width, dtype=model_dtype) image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) @@ -1090,8 +1060,9 @@ def generate_video( console.print("[green]✓[/] VAE encoder loaded and image encoded") # Generate sigma schedule - num_tokens = latent_frames * latent_h * latent_w - sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens) + # PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses + # the default MAX_SHIFT_ANCHOR (4096) for the shift calculation + sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1141,16 +1112,20 @@ def generate_video( video_positions, audio_positions, video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state + transformer, sigmas, cfg_scale=cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state ) else: + # Use original denoise_dev with computed sigmas latents = denoise_dev( - latents, video_positions, video_embeddings_pos, video_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state + latents, video_positions, + video_embeddings_pos, video_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, state=video_state ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) - vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None) + vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file)) del transformer mx.clear_cache() @@ -1356,6 +1331,7 @@ Examples: parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames") parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)") parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)") + parser.add_argument("--cfg-rescale", type=float, default=0.0, help="CFG rescale factor (0.0-1.0). Higher values reduce artifacts by blending towards positive-only prediction (dev pipeline only)") parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed") parser.add_argument("--fps", type=int, default=24, help="Frames per second") parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path") @@ -1391,6 +1367,7 @@ Examples: num_frames=args.num_frames, num_inference_steps=args.steps, cfg_scale=args.cfg_scale, + cfg_rescale=args.cfg_rescale, seed=args.seed, fps=args.fps, output_path=args.output_path, diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 7499238..f14cca0 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -560,3 +560,7 @@ class LTX2VideoDecoder(nn.Module): chunked_conv=use_chunked_conv, on_frames_ready=on_frames_ready, ) + + +# Backward-compatible alias +VideoDecoder = LTX2VideoDecoder