Refactor video generation and model loading processes to utilize from_pretrained methods for VideoEncoder and VideoDecoder. Update denoising functions to include a cfg_rescale parameter for improved artifact reduction. Ensure consistent dtype handling across audio and video processing, enhancing precision and aligning with PyTorch behavior.

This commit is contained in:
Prince Canuma
2026-01-23 17:39:02 +01:00
parent 02bfa228d9
commit df753312c7
3 changed files with 119 additions and 151 deletions

View File

@@ -10,24 +10,16 @@ from mlx_video.convert import (
# Audio VAE components
from mlx_video.models.ltx.audio_vae import (
AudioEncoder,
AudioDecoder,
Vocoder,
AudioProcessor,
decode_audio,
)
# Patchifiers
from mlx_video.components.patchifiers import (
VideoLatentPatchifier,
AudioPatchifier,
VideoLatentShape,
AudioLatentShape,
PerChannelStatistics,
)
# Conditioning
from mlx_video.conditioning import (
VideoConditionByKeyframeIndex,
VideoConditionByLatentIndex,
)
@@ -43,17 +35,12 @@ __all__ = [
"sanitize_audio_vae_weights",
"sanitize_vocoder_weights",
# Audio VAE
"AudioEncoder",
"AudioDecoder",
"Vocoder",
"AudioProcessor",
"decode_audio",
# Patchifiers
"VideoLatentPatchifier",
"AudioPatchifier",
"VideoLatentShape",
"AudioLatentShape",
"PerChannelStatistics",
# Conditioning
"VideoConditionByKeyframeIndex",
"VideoConditionByLatentIndex",
]

View File

@@ -21,13 +21,12 @@ from rich.panel import Panel
console = Console()
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.decoder import VideoDecoder
from mlx_video.models.ltx.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
@@ -58,19 +57,8 @@ AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
# Default negative prompt for CFG (dev pipeline)
DEFAULT_NEGATIVE_PROMPT = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
# Matches PyTorch LTX-2 reference InferenceConfig default
DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
@@ -123,6 +111,7 @@ def ltx2_scheduler(
# Apply shift transformation
power = 1
with np.errstate(divide='ignore', invalid='ignore'):
sigmas = np.where(
sigmas != 0,
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
@@ -194,7 +183,13 @@ def create_position_grid(
a_max=None
)
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
# Compute temporal division in bfloat16 to match PyTorch's precision behavior
# This ensures RoPE frequencies are computed identically to the reference implementation
temporal_coords = mx.array(pixel_coords[:, 0, :, :], dtype=mx.bfloat16)
fps_bf16 = mx.array(fps, dtype=mx.bfloat16)
temporal_coords = temporal_coords / fps_bf16
mx.eval(temporal_coords)
pixel_coords[:, 0, :, :] = np.array(temporal_coords.astype(mx.float32))
return mx.array(pixel_coords, dtype=mx.float32)
@@ -484,16 +479,29 @@ def denoise_dev_av(
transformer: LTXModel,
sigmas: mx.array,
cfg_scale: float = 4.0,
cfg_rescale: float = 0.0,
verbose: bool = True,
video_state: Optional[LatentState] = None,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for dev pipeline with CFG and audio."""
"""Run denoising loop for dev pipeline with CFG and audio.
Args:
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result
towards the positive-only prediction, helping reduce artifacts.
Default 0.0 means no rescaling (standard CFG).
"""
from mlx_video.models.ltx.rope import precompute_freqs_cis
dtype = video_latents.dtype
if video_state is not None:
video_latents = video_state.latent
# Keep latents in float32 throughout the denoising loop to avoid
# bfloat16 quantization noise accumulation over many steps.
# PyTorch keeps latents in float32; model input is cast to model dtype.
video_latents = video_latents.astype(mx.float32)
audio_latents = audio_latents.astype(mx.float32)
sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0
num_steps = len(sigmas_list) - 1
@@ -538,15 +546,15 @@ def denoise_dev_av(
sigma = sigmas_list[i]
sigma_next = sigmas_list[i + 1]
# Flatten video latents
# Flatten video latents (cast to model dtype for transformer input)
b, c, f, h, w = video_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
# Flatten audio latents
# Flatten audio latents (cast to model dtype for transformer input)
ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
# Compute timesteps
if video_state is not None:
@@ -571,8 +579,26 @@ def denoise_dev_av(
positional_embeddings=precomputed_audio_rope,
)
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
mx.eval(video_vel_pos, audio_vel_pos)
if use_cfg:
# Convert velocity to denoised (x0) using per-token timesteps
# This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant)
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
# Use the float32 latents (not the bfloat16 model input) for precision
video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af))
video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
audio_timesteps_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
# Dynamic CFG: compute per-step effective scale
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
apply_cfg_this_step = step_cfg_scale > 1.0
if apply_cfg_this_step:
# Negative conditioning pass
video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
@@ -585,39 +611,53 @@ def denoise_dev_av(
positional_embeddings=precomputed_audio_rope,
)
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
mx.eval(video_vel_neg, audio_vel_neg)
# Apply CFG
video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg)
audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg)
# Convert negative velocity to x0 using per-token timesteps
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
# delta = (scale - 1) * (x0_pos - x0_neg)
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
# Apply CFG rescale if enabled
if cfg_rescale > 0.0:
video_x0_guided_f32 = cfg_rescale * video_x0_pos_f32 + (1.0 - cfg_rescale) * video_x0_guided_f32
audio_x0_guided_f32 = cfg_rescale * audio_x0_pos_f32 + (1.0 - cfg_rescale) * audio_x0_guided_f32
else:
video_velocity_flat = video_vel_pos
audio_velocity_flat = audio_vel_pos
video_x0_guided_f32 = video_x0_pos_f32
audio_x0_guided_f32 = audio_x0_pos_f32
# Reshape velocities
video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w))
audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
audio_denoised_f32 = mx.reshape(audio_x0_guided_f32, (ab, at, ac, af))
audio_denoised_f32 = mx.transpose(audio_denoised_f32, (0, 2, 1, 3))
# Compute denoised
video_denoised = to_denoised(video_latents, video_velocity, sigma)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
if video_state is not None:
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
# Euler step
if sigma_next > 0:
# Compute Euler step in float32 for precision (matching PyTorch behavior)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
# Post-process: blend denoised with clean latent using mask
# Matches PyTorch's post_process_latent: denoised * mask + clean * (1 - mask)
sigma_f32 = mx.array(sigma, dtype=mx.float32)
video_latents_f32 = video_latents.astype(mx.float32)
video_denoised_f32 = video_denoised.astype(mx.float32)
video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype)
if video_state is not None:
clean_f32 = video_state.clean_latent.astype(mx.float32)
mask_f32 = video_state.denoise_mask.astype(mx.float32)
video_denoised_f32 = video_denoised_f32 * mask_f32 + clean_f32 * (1.0 - mask_f32)
audio_latents_f32 = audio_latents.astype(mx.float32)
audio_denoised_f32 = audio_denoised.astype(mx.float32)
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
mx.eval(video_denoised_f32, audio_denoised_f32)
# Euler step matching PyTorch: sample + velocity * dt
# Latents stay in float32 throughout (matching PyTorch behavior)
if sigma_next > 0:
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
dt_f32 = sigma_next_f32 - sigma_f32
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
video_latents = video_latents + video_velocity_f32 * dt_f32
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
else:
video_latents = video_denoised
audio_latents = audio_denoised
@@ -634,33 +674,12 @@ def denoise_dev_av(
def load_audio_decoder(model_path: Path, pipeline: PipelineType):
"""Load audio VAE decoder."""
from mlx_video.models.ltx.config import AudioDecoderModelConfig
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
from mlx_video.convert import sanitize_audio_vae_weights
decoder = AudioDecoder(
ch=128,
out_ch=2,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_resolutions=set(),
resolution=256,
z_channels=AUDIO_LATENT_CHANNELS,
norm_type=NormType.PIXEL,
causality_axis=CausalityAxis.HEIGHT,
mel_bins=64,
mid_block_add_attention=False, # Config says no attention in mid block
)
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_audio_vae_weights(raw_weights)
if sanitized:
decoder.load_weights(list(sanitized.items()), strict=False)
if "per_channel_statistics._mean_of_means" in sanitized:
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
if "per_channel_statistics._std_of_means" in sanitized:
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
decoder = AudioDecoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/audio_vae"))
return decoder
@@ -668,24 +687,9 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType):
def load_vocoder(model_path: Path, pipeline: PipelineType):
"""Load vocoder for mel to waveform conversion."""
from mlx_video.models.ltx.audio_vae import Vocoder
from mlx_video.convert import sanitize_vocoder_weights
vocoder = Vocoder(
resblock_kernel_sizes=[3, 7, 11],
upsample_rates=[6, 5, 2, 2, 2],
upsample_kernel_sizes=[16, 15, 8, 4, 4],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel=1024,
stereo=True,
output_sample_rate=AUDIO_SAMPLE_RATE,
)
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_vocoder_weights(raw_weights)
if sanitized:
vocoder.load_weights(list(sanitized.items()), strict=False)
vocoder = Vocoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vocoder"))
return vocoder
@@ -747,6 +751,7 @@ def generate_video(
num_frames: int = 33,
num_inference_steps: int = 40,
cfg_scale: float = 4.0,
cfg_rescale: float = 0.0,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
@@ -891,40 +896,7 @@ def generate_video(
# Load transformer
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
config_kwargs = dict(
model_type=model_type,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
if audio:
config_kwargs.update(
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
audio_positional_embedding_max_pos=[20],
)
config = LTXModelConfig(**config_kwargs)
transformer = LTXModel.from_pretrained(model_path=model_path/weight_file, config=config, strict=True)
transformer = LTXModel.from_pretrained(model_path=Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/transformer"), strict=True)
console.print("[green]✓[/] Transformer loaded")
@@ -942,8 +914,7 @@ def generate_video(
stage2_image_latent = None
if is_i2v:
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = load_vae_encoder(str(model_path / weight_file))
mx.eval(vae_encoder.parameters())
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-distilled/vae/encoder"))
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
@@ -1010,9 +981,9 @@ def generate_video(
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
mx.eval(latents)
del upsampler
@@ -1077,8 +1048,7 @@ def generate_video(
image_latent = None
if is_i2v:
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = load_vae_encoder(str(model_path / weight_file))
mx.eval(vae_encoder.parameters())
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/encoder"))
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
@@ -1090,8 +1060,9 @@ def generate_video(
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Generate sigma schedule
num_tokens = latent_frames * latent_h * latent_w
sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
# PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses
# the default MAX_SHIFT_ANCHOR (4096) for the shift calculation
sigmas = ltx2_scheduler(steps=num_inference_steps)
mx.eval(sigmas)
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
@@ -1141,16 +1112,20 @@ def generate_video(
video_positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state
)
else:
# Use original denoise_dev with computed sigmas
latents = denoise_dev(
latents, video_positions, video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state
latents, video_positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, state=video_state
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
del transformer
mx.clear_cache()
@@ -1356,6 +1331,7 @@ Examples:
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)")
parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)")
parser.add_argument("--cfg-rescale", type=float, default=0.0, help="CFG rescale factor (0.0-1.0). Higher values reduce artifacts by blending towards positive-only prediction (dev pipeline only)")
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path")
@@ -1391,6 +1367,7 @@ Examples:
num_frames=args.num_frames,
num_inference_steps=args.steps,
cfg_scale=args.cfg_scale,
cfg_rescale=args.cfg_rescale,
seed=args.seed,
fps=args.fps,
output_path=args.output_path,

View File

@@ -560,3 +560,7 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)
# Backward-compatible alias
VideoDecoder = LTX2VideoDecoder