From ffe271699a120f8c2d0f9dbea2e49d95f5f185d8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Mar 2026 15:24:50 +0100 Subject: [PATCH] Refactor LoRA loading for v2.3 in generate.py to prioritize distilled-lora files over full model weights, enhancing model flexibility. Update key sanitization logic to utilize a replacement list for improved readability and maintainability. Modify denoise_dev_av function to include sigma parameters for audio and video modalities, ensuring consistent handling of latent variables during processing. Adjust Vocoder weight loading to allow for non-strict loading, accommodating additional keys in model weights. --- mlx_video/generate.py | 41 ++++++++++++++++------- mlx_video/models/ltx/audio_vae/vocoder.py | 4 +-- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index daa7ed0..8253b57 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -89,7 +89,9 @@ def load_and_merge_lora( candidates = sorted(lora_file.glob("*.safetensors")) if not candidates: raise FileNotFoundError(f"No .safetensors files found in {lora_path}") - lora_file = candidates[0] + # Prefer distilled-lora files over full model weights + lora_candidates = [c for c in candidates if "distilled-lora" in c.name] + lora_file = lora_candidates[0] if lora_candidates else candidates[0] console.print(f"[dim]Using LoRA file: {lora_file.name}[/]") else: # Treat as HuggingFace repo ID @@ -97,7 +99,9 @@ def load_and_merge_lora( candidates = sorted(lora_dir.glob("*.safetensors")) if not candidates: raise FileNotFoundError(f"No .safetensors files found in {lora_dir}") - lora_file = candidates[0] + # Prefer distilled-lora files over full model weights + lora_candidates = [c for c in candidates if "distilled-lora" in c.name] + lora_file = lora_candidates[0] if lora_candidates else candidates[0] console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]") # Load LoRA weights @@ -123,17 +127,26 @@ def load_and_merge_lora( lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key] # Apply key sanitization only for raw PyTorch format + # Replacements handle both mid-string and end-of-string positions + # since LoRA base keys end at the module name without trailing dot + _LORA_KEY_REPLACEMENTS = [ + (".to_out.0", ".to_out"), + (".ff.net.0.proj", ".ff.proj_in"), + (".ff.net.2", ".ff.proj_out"), + (".audio_ff.net.0.proj", ".audio_ff.proj_in"), + (".audio_ff.net.2", ".audio_ff.proj_out"), + (".linear_1", ".linear1"), + (".linear_2", ".linear2"), + ] if has_prefix: sanitized_pairs = {} for key, pair in lora_pairs.items(): new_key = key - new_key = new_key.replace(".to_out.0.", ".to_out.") - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") - new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") - new_key = new_key.replace(".linear_1.", ".linear1.") - new_key = new_key.replace(".linear_2.", ".linear2.") + for old, new in _LORA_KEY_REPLACEMENTS: + if new_key.endswith(old): + new_key = new_key[:-len(old)] + new + else: + new_key = new_key.replace(old + ".", new + ".") sanitized_pairs[new_key] = pair else: sanitized_pairs = lora_pairs @@ -823,15 +836,17 @@ def denoise_dev_av( audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) # Positive conditioning pass + sigma_array = mx.full((b,), sigma, dtype=dtype) + audio_sigma_array = mx.full((ab,), sigma, dtype=dtype) video_modality_pos = Modality( latent=video_flat, timesteps=video_timesteps, positions=video_positions, context=video_embeddings_pos, context_mask=None, enabled=True, - positional_embeddings=precomputed_video_rope, + positional_embeddings=precomputed_video_rope, sigma=sigma_array, ) audio_modality_pos = Modality( latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, context=audio_embeddings_pos, context_mask=None, enabled=True, - positional_embeddings=precomputed_audio_rope, + positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, ) video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) mx.eval(video_vel_pos, audio_vel_pos) @@ -857,12 +872,12 @@ def denoise_dev_av( video_modality_neg = Modality( latent=video_flat, timesteps=video_timesteps, positions=video_positions, context=video_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_video_rope, + positional_embeddings=precomputed_video_rope, sigma=sigma_array, ) audio_modality_neg = Modality( latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, context=audio_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_audio_rope, + positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, ) video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) mx.eval(video_vel_neg, audio_vel_neg) diff --git a/mlx_video/models/ltx/audio_vae/vocoder.py b/mlx_video/models/ltx/audio_vae/vocoder.py index f996d2f..ea06f63 100644 --- a/mlx_video/models/ltx/audio_vae/vocoder.py +++ b/mlx_video/models/ltx/audio_vae/vocoder.py @@ -120,8 +120,8 @@ class Vocoder(nn.Module): model = cls(config) weights = mx.load(str(model_path / "model.safetensors")) - # weights = vocoder.sanitize(weights) - model.load_weights(list(weights.items()), strict=strict) + # Use strict=False to skip extra keys (e.g., bwe_generator in LTX-2.3) + model.load_weights(list(weights.items()), strict=False) return model