From cb2d19c84d8ff02f8c77a3522f7122f76cd6eff2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 24 Jan 2026 01:37:38 +0100 Subject: [PATCH] fix loading --- mlx_video/generate.py | 16 +++++------- mlx_video/models/ltx/video_vae/decoder.py | 30 ++++++++++++----------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 2811368..2486ef0 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -594,11 +594,7 @@ def denoise_dev_av( video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32) audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32) - # Dynamic CFG: compute per-step effective scale - step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0 - apply_cfg_this_step = step_cfg_scale > 1.0 - - if apply_cfg_this_step: + if use_cfg: # Negative conditioning pass video_modality_neg = Modality( latent=video_flat, timesteps=video_timesteps, positions=video_positions, @@ -620,8 +616,8 @@ def denoise_dev_av( # Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider # delta = (scale - 1) * (x0_pos - x0_neg) # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect) - video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) - audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) + video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) + audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) # Apply CFG rescale if enabled if cfg_rescale > 0.0: @@ -659,8 +655,8 @@ def denoise_dev_av( audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32 audio_latents = audio_latents + audio_velocity_f32 * dt_f32 else: - video_latents = video_denoised - audio_latents = audio_denoised + video_latents = video_denoised_f32 + audio_latents = audio_denoised_f32 mx.eval(video_latents, audio_latents) progress.advance(task) @@ -1125,7 +1121,7 @@ def generate_video( ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) - vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file)) + vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder") del transformer mx.clear_cache() diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 4896c22..5f45d8a 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -349,6 +349,8 @@ class LTX2VideoDecoder(nn.Module): def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: # Build decoder weights dict with key remapping sanitized = {} + if "per_channel_statistics.mean" in weights: + return weights for key, value in weights.items(): new_key = key @@ -399,21 +401,21 @@ class LTX2VideoDecoder(nn.Module): import json model_path = Path(model_path) + config_dict = {} + + # Load config from directory + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path) as f: + config_dict = json.load(f) - if model_path.is_dir(): - # Load config from directory - config_path = model_path / "config.json" - if config_path.exists(): - with open(config_path) as f: - config_dict = json.load(f) - - # Load weights from directory - weight_files = sorted(model_path.glob("*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors files found in {model_path}") - weights = {} - for wf in weight_files: - weights.update(mx.load(str(wf))) + # Load weights from directory + weight_files = sorted(model_path.glob("*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + weights = {} + for wf in weight_files: + weights.update(mx.load(str(wf))) model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))