Refactor video generation and model loading processes to utilize from_pretrained methods for VideoEncoder and VideoDecoder. Update denoising functions to include a cfg_rescale parameter for improved artifact reduction. Ensure consistent dtype handling across audio and video processing, enhancing precision and aligning with PyTorch behavior.
This commit is contained in:
@@ -10,24 +10,16 @@ from mlx_video.convert import (
|
||||
|
||||
# Audio VAE components
|
||||
from mlx_video.models.ltx.audio_vae import (
|
||||
AudioEncoder,
|
||||
AudioDecoder,
|
||||
Vocoder,
|
||||
AudioProcessor,
|
||||
decode_audio,
|
||||
)
|
||||
|
||||
# Patchifiers
|
||||
from mlx_video.components.patchifiers import (
|
||||
VideoLatentPatchifier,
|
||||
AudioPatchifier,
|
||||
VideoLatentShape,
|
||||
AudioLatentShape,
|
||||
PerChannelStatistics,
|
||||
)
|
||||
|
||||
# Conditioning
|
||||
from mlx_video.conditioning import (
|
||||
VideoConditionByKeyframeIndex,
|
||||
VideoConditionByLatentIndex,
|
||||
)
|
||||
|
||||
@@ -43,17 +35,12 @@ __all__ = [
|
||||
"sanitize_audio_vae_weights",
|
||||
"sanitize_vocoder_weights",
|
||||
# Audio VAE
|
||||
"AudioEncoder",
|
||||
"AudioDecoder",
|
||||
"Vocoder",
|
||||
"AudioProcessor",
|
||||
"decode_audio",
|
||||
# Patchifiers
|
||||
"VideoLatentPatchifier",
|
||||
"AudioPatchifier",
|
||||
"VideoLatentShape",
|
||||
"AudioLatentShape",
|
||||
"PerChannelStatistics",
|
||||
# Conditioning
|
||||
"VideoConditionByKeyframeIndex",
|
||||
"VideoConditionByLatentIndex",
|
||||
]
|
||||
@@ -21,13 +21,12 @@ from rich.panel import Panel
|
||||
console = Console()
|
||||
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
|
||||
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
from mlx_video.models.ltx.video_vae.decoder import VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||
@@ -58,19 +57,8 @@ AUDIO_MEL_BINS = 16
|
||||
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
|
||||
|
||||
# Default negative prompt for CFG (dev pipeline)
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
# Matches PyTorch LTX-2 reference InferenceConfig default
|
||||
DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
|
||||
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
||||
@@ -123,6 +111,7 @@ def ltx2_scheduler(
|
||||
|
||||
# Apply shift transformation
|
||||
power = 1
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
sigmas = np.where(
|
||||
sigmas != 0,
|
||||
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
||||
@@ -194,7 +183,13 @@ def create_position_grid(
|
||||
a_max=None
|
||||
)
|
||||
|
||||
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||
# Compute temporal division in bfloat16 to match PyTorch's precision behavior
|
||||
# This ensures RoPE frequencies are computed identically to the reference implementation
|
||||
temporal_coords = mx.array(pixel_coords[:, 0, :, :], dtype=mx.bfloat16)
|
||||
fps_bf16 = mx.array(fps, dtype=mx.bfloat16)
|
||||
temporal_coords = temporal_coords / fps_bf16
|
||||
mx.eval(temporal_coords)
|
||||
pixel_coords[:, 0, :, :] = np.array(temporal_coords.astype(mx.float32))
|
||||
|
||||
return mx.array(pixel_coords, dtype=mx.float32)
|
||||
|
||||
@@ -484,16 +479,29 @@ def denoise_dev_av(
|
||||
transformer: LTXModel,
|
||||
sigmas: mx.array,
|
||||
cfg_scale: float = 4.0,
|
||||
cfg_rescale: float = 0.0,
|
||||
verbose: bool = True,
|
||||
video_state: Optional[LatentState] = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run denoising loop for dev pipeline with CFG and audio."""
|
||||
"""Run denoising loop for dev pipeline with CFG and audio.
|
||||
|
||||
Args:
|
||||
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result
|
||||
towards the positive-only prediction, helping reduce artifacts.
|
||||
Default 0.0 means no rescaling (standard CFG).
|
||||
"""
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
|
||||
dtype = video_latents.dtype
|
||||
if video_state is not None:
|
||||
video_latents = video_state.latent
|
||||
|
||||
# Keep latents in float32 throughout the denoising loop to avoid
|
||||
# bfloat16 quantization noise accumulation over many steps.
|
||||
# PyTorch keeps latents in float32; model input is cast to model dtype.
|
||||
video_latents = video_latents.astype(mx.float32)
|
||||
audio_latents = audio_latents.astype(mx.float32)
|
||||
|
||||
sigmas_list = sigmas.tolist()
|
||||
use_cfg = cfg_scale != 1.0
|
||||
num_steps = len(sigmas_list) - 1
|
||||
@@ -538,15 +546,15 @@ def denoise_dev_av(
|
||||
sigma = sigmas_list[i]
|
||||
sigma_next = sigmas_list[i + 1]
|
||||
|
||||
# Flatten video latents
|
||||
# Flatten video latents (cast to model dtype for transformer input)
|
||||
b, c, f, h, w = video_latents.shape
|
||||
num_video_tokens = f * h * w
|
||||
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
||||
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
||||
|
||||
# Flatten audio latents
|
||||
# Flatten audio latents (cast to model dtype for transformer input)
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
||||
|
||||
# Compute timesteps
|
||||
if video_state is not None:
|
||||
@@ -571,8 +579,26 @@ def denoise_dev_av(
|
||||
positional_embeddings=precomputed_audio_rope,
|
||||
)
|
||||
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
||||
mx.eval(video_vel_pos, audio_vel_pos)
|
||||
|
||||
if use_cfg:
|
||||
# Convert velocity to denoised (x0) using per-token timesteps
|
||||
# This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity
|
||||
# For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant)
|
||||
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
||||
# Use the float32 latents (not the bfloat16 model input) for precision
|
||||
video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
||||
audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af))
|
||||
video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
|
||||
audio_timesteps_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
|
||||
|
||||
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
|
||||
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
|
||||
|
||||
# Dynamic CFG: compute per-step effective scale
|
||||
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
|
||||
apply_cfg_this_step = step_cfg_scale > 1.0
|
||||
|
||||
if apply_cfg_this_step:
|
||||
# Negative conditioning pass
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
@@ -585,39 +611,53 @@ def denoise_dev_av(
|
||||
positional_embeddings=precomputed_audio_rope,
|
||||
)
|
||||
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
||||
mx.eval(video_vel_neg, audio_vel_neg)
|
||||
|
||||
# Apply CFG
|
||||
video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg)
|
||||
audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg)
|
||||
# Convert negative velocity to x0 using per-token timesteps
|
||||
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
|
||||
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
|
||||
|
||||
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
||||
# delta = (scale - 1) * (x0_pos - x0_neg)
|
||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||
|
||||
# Apply CFG rescale if enabled
|
||||
if cfg_rescale > 0.0:
|
||||
video_x0_guided_f32 = cfg_rescale * video_x0_pos_f32 + (1.0 - cfg_rescale) * video_x0_guided_f32
|
||||
audio_x0_guided_f32 = cfg_rescale * audio_x0_pos_f32 + (1.0 - cfg_rescale) * audio_x0_guided_f32
|
||||
else:
|
||||
video_velocity_flat = video_vel_pos
|
||||
audio_velocity_flat = audio_vel_pos
|
||||
video_x0_guided_f32 = video_x0_pos_f32
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32
|
||||
|
||||
# Reshape velocities
|
||||
video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w))
|
||||
audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af))
|
||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
|
||||
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
|
||||
video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
|
||||
audio_denoised_f32 = mx.reshape(audio_x0_guided_f32, (ab, at, ac, af))
|
||||
audio_denoised_f32 = mx.transpose(audio_denoised_f32, (0, 2, 1, 3))
|
||||
|
||||
# Compute denoised
|
||||
video_denoised = to_denoised(video_latents, video_velocity, sigma)
|
||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||
|
||||
if video_state is not None:
|
||||
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
|
||||
|
||||
# Euler step
|
||||
if sigma_next > 0:
|
||||
# Compute Euler step in float32 for precision (matching PyTorch behavior)
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
# Post-process: blend denoised with clean latent using mask
|
||||
# Matches PyTorch's post_process_latent: denoised * mask + clean * (1 - mask)
|
||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||
|
||||
video_latents_f32 = video_latents.astype(mx.float32)
|
||||
video_denoised_f32 = video_denoised.astype(mx.float32)
|
||||
video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype)
|
||||
if video_state is not None:
|
||||
clean_f32 = video_state.clean_latent.astype(mx.float32)
|
||||
mask_f32 = video_state.denoise_mask.astype(mx.float32)
|
||||
video_denoised_f32 = video_denoised_f32 * mask_f32 + clean_f32 * (1.0 - mask_f32)
|
||||
|
||||
audio_latents_f32 = audio_latents.astype(mx.float32)
|
||||
audio_denoised_f32 = audio_denoised.astype(mx.float32)
|
||||
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
|
||||
mx.eval(video_denoised_f32, audio_denoised_f32)
|
||||
|
||||
# Euler step matching PyTorch: sample + velocity * dt
|
||||
# Latents stay in float32 throughout (matching PyTorch behavior)
|
||||
if sigma_next > 0:
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
dt_f32 = sigma_next_f32 - sigma_f32
|
||||
|
||||
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
|
||||
video_latents = video_latents + video_velocity_f32 * dt_f32
|
||||
|
||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||
else:
|
||||
video_latents = video_denoised
|
||||
audio_latents = audio_denoised
|
||||
@@ -634,33 +674,12 @@ def denoise_dev_av(
|
||||
|
||||
def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
||||
"""Load audio VAE decoder."""
|
||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
|
||||
from mlx_video.convert import sanitize_audio_vae_weights
|
||||
|
||||
decoder = AudioDecoder(
|
||||
ch=128,
|
||||
out_ch=2,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=set(),
|
||||
resolution=256,
|
||||
z_channels=AUDIO_LATENT_CHANNELS,
|
||||
norm_type=NormType.PIXEL,
|
||||
causality_axis=CausalityAxis.HEIGHT,
|
||||
mel_bins=64,
|
||||
mid_block_add_attention=False, # Config says no attention in mid block
|
||||
)
|
||||
|
||||
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
|
||||
if weight_file.exists():
|
||||
raw_weights = mx.load(str(weight_file))
|
||||
sanitized = sanitize_audio_vae_weights(raw_weights)
|
||||
if sanitized:
|
||||
decoder.load_weights(list(sanitized.items()), strict=False)
|
||||
if "per_channel_statistics._mean_of_means" in sanitized:
|
||||
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
|
||||
if "per_channel_statistics._std_of_means" in sanitized:
|
||||
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
|
||||
|
||||
decoder = AudioDecoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/audio_vae"))
|
||||
|
||||
return decoder
|
||||
|
||||
@@ -668,24 +687,9 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
||||
def load_vocoder(model_path: Path, pipeline: PipelineType):
|
||||
"""Load vocoder for mel to waveform conversion."""
|
||||
from mlx_video.models.ltx.audio_vae import Vocoder
|
||||
from mlx_video.convert import sanitize_vocoder_weights
|
||||
|
||||
vocoder = Vocoder(
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
upsample_rates=[6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes=[16, 15, 8, 4, 4],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel=1024,
|
||||
stereo=True,
|
||||
output_sample_rate=AUDIO_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
|
||||
if weight_file.exists():
|
||||
raw_weights = mx.load(str(weight_file))
|
||||
sanitized = sanitize_vocoder_weights(raw_weights)
|
||||
if sanitized:
|
||||
vocoder.load_weights(list(sanitized.items()), strict=False)
|
||||
vocoder = Vocoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vocoder"))
|
||||
|
||||
return vocoder
|
||||
|
||||
@@ -747,6 +751,7 @@ def generate_video(
|
||||
num_frames: int = 33,
|
||||
num_inference_steps: int = 40,
|
||||
cfg_scale: float = 4.0,
|
||||
cfg_rescale: float = 0.0,
|
||||
seed: int = 42,
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
@@ -891,40 +896,7 @@ def generate_video(
|
||||
# Load transformer
|
||||
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
|
||||
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
||||
|
||||
|
||||
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
|
||||
|
||||
config_kwargs = dict(
|
||||
model_type=model_type,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
if audio:
|
||||
config_kwargs.update(
|
||||
audio_num_attention_heads=32,
|
||||
audio_attention_head_dim=64,
|
||||
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||
audio_cross_attention_dim=2048,
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
)
|
||||
|
||||
config = LTXModelConfig(**config_kwargs)
|
||||
|
||||
transformer = LTXModel.from_pretrained(model_path=model_path/weight_file, config=config, strict=True)
|
||||
transformer = LTXModel.from_pretrained(model_path=Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/transformer"), strict=True)
|
||||
|
||||
console.print("[green]✓[/] Transformer loaded")
|
||||
|
||||
@@ -942,8 +914,7 @@ def generate_video(
|
||||
stage2_image_latent = None
|
||||
if is_i2v:
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = load_vae_encoder(str(model_path / weight_file))
|
||||
mx.eval(vae_encoder.parameters())
|
||||
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-distilled/vae/encoder"))
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
@@ -1010,9 +981,9 @@ def generate_video(
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None)
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
||||
mx.eval(latents)
|
||||
|
||||
del upsampler
|
||||
@@ -1077,8 +1048,7 @@ def generate_video(
|
||||
image_latent = None
|
||||
if is_i2v:
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = load_vae_encoder(str(model_path / weight_file))
|
||||
mx.eval(vae_encoder.parameters())
|
||||
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/encoder"))
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
@@ -1090,8 +1060,9 @@ def generate_video(
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Generate sigma schedule
|
||||
num_tokens = latent_frames * latent_h * latent_w
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
|
||||
# PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses
|
||||
# the default MAX_SHIFT_ANCHOR (4096) for the shift calculation
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||
|
||||
@@ -1141,16 +1112,20 @@ def generate_video(
|
||||
video_positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state
|
||||
)
|
||||
else:
|
||||
# Use original denoise_dev with computed sigmas
|
||||
latents = denoise_dev(
|
||||
latents, video_positions, video_embeddings_pos, video_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state
|
||||
latents, video_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, state=video_state
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None)
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
@@ -1356,6 +1331,7 @@ Examples:
|
||||
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
|
||||
parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)")
|
||||
parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)")
|
||||
parser.add_argument("--cfg-rescale", type=float, default=0.0, help="CFG rescale factor (0.0-1.0). Higher values reduce artifacts by blending towards positive-only prediction (dev pipeline only)")
|
||||
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
|
||||
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
|
||||
parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path")
|
||||
@@ -1391,6 +1367,7 @@ Examples:
|
||||
num_frames=args.num_frames,
|
||||
num_inference_steps=args.steps,
|
||||
cfg_scale=args.cfg_scale,
|
||||
cfg_rescale=args.cfg_rescale,
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
output_path=args.output_path,
|
||||
|
||||
@@ -560,3 +560,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
# Backward-compatible alias
|
||||
VideoDecoder = LTX2VideoDecoder
|
||||
|
||||
Reference in New Issue
Block a user