fix loading

This commit is contained in:
Prince Canuma
2026-01-24 01:37:38 +01:00
parent ef76ec0921
commit cb2d19c84d
2 changed files with 22 additions and 24 deletions

View File

@@ -594,11 +594,7 @@ def denoise_dev_av(
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
# Dynamic CFG: compute per-step effective scale
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
apply_cfg_this_step = step_cfg_scale > 1.0
if apply_cfg_this_step:
if use_cfg:
# Negative conditioning pass
video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
@@ -620,8 +616,8 @@ def denoise_dev_av(
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
# delta = (scale - 1) * (x0_pos - x0_neg)
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
# Apply CFG rescale if enabled
if cfg_rescale > 0.0:
@@ -659,8 +655,8 @@ def denoise_dev_av(
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
else:
video_latents = video_denoised
audio_latents = audio_denoised
video_latents = video_denoised_f32
audio_latents = audio_denoised_f32
mx.eval(video_latents, audio_latents)
progress.advance(task)
@@ -1125,7 +1121,7 @@ def generate_video(
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
del transformer
mx.clear_cache()