fix loading
This commit is contained in:
@@ -594,11 +594,7 @@ def denoise_dev_av(
|
||||
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
|
||||
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
|
||||
|
||||
# Dynamic CFG: compute per-step effective scale
|
||||
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
|
||||
apply_cfg_this_step = step_cfg_scale > 1.0
|
||||
|
||||
if apply_cfg_this_step:
|
||||
if use_cfg:
|
||||
# Negative conditioning pass
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
@@ -620,8 +616,8 @@ def denoise_dev_av(
|
||||
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
||||
# delta = (scale - 1) * (x0_pos - x0_neg)
|
||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||
|
||||
# Apply CFG rescale if enabled
|
||||
if cfg_rescale > 0.0:
|
||||
@@ -659,8 +655,8 @@ def denoise_dev_av(
|
||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||
else:
|
||||
video_latents = video_denoised
|
||||
audio_latents = audio_denoised
|
||||
video_latents = video_denoised_f32
|
||||
audio_latents = audio_denoised_f32
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
@@ -1125,7 +1121,7 @@ def generate_video(
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
||||
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
Reference in New Issue
Block a user