Refactor LoRA loading for v2.3 in generate.py to prioritize distilled-lora files over full model weights, enhancing model flexibility. Update key sanitization logic to utilize a replacement list for improved readability and maintainability. Modify denoise_dev_av function to include sigma parameters for audio and video modalities, ensuring consistent handling of latent variables during processing. Adjust Vocoder weight loading to allow for non-strict loading, accommodating additional keys in model weights.

This commit is contained in:
Prince Canuma
2026-03-14 15:24:50 +01:00
parent 9cba2ea7cd
commit ffe271699a
2 changed files with 30 additions and 15 deletions

View File

@@ -89,7 +89,9 @@ def load_and_merge_lora(
candidates = sorted(lora_file.glob("*.safetensors"))
if not candidates:
raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
lora_file = candidates[0]
# Prefer distilled-lora files over full model weights
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
else:
# Treat as HuggingFace repo ID
@@ -97,7 +99,9 @@ def load_and_merge_lora(
candidates = sorted(lora_dir.glob("*.safetensors"))
if not candidates:
raise FileNotFoundError(f"No .safetensors files found in {lora_dir}")
lora_file = candidates[0]
# Prefer distilled-lora files over full model weights
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]")
# Load LoRA weights
@@ -123,17 +127,26 @@ def load_and_merge_lora(
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
# Apply key sanitization only for raw PyTorch format
# Replacements handle both mid-string and end-of-string positions
# since LoRA base keys end at the module name without trailing dot
_LORA_KEY_REPLACEMENTS = [
(".to_out.0", ".to_out"),
(".ff.net.0.proj", ".ff.proj_in"),
(".ff.net.2", ".ff.proj_out"),
(".audio_ff.net.0.proj", ".audio_ff.proj_in"),
(".audio_ff.net.2", ".audio_ff.proj_out"),
(".linear_1", ".linear1"),
(".linear_2", ".linear2"),
]
if has_prefix:
sanitized_pairs = {}
for key, pair in lora_pairs.items():
new_key = key
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
for old, new in _LORA_KEY_REPLACEMENTS:
if new_key.endswith(old):
new_key = new_key[:-len(old)] + new
else:
new_key = new_key.replace(old + ".", new + ".")
sanitized_pairs[new_key] = pair
else:
sanitized_pairs = lora_pairs
@@ -823,15 +836,17 @@ def denoise_dev_av(
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
# Positive conditioning pass
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
video_modality_pos = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope,
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
)
audio_modality_pos = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope,
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
)
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
mx.eval(video_vel_pos, audio_vel_pos)
@@ -857,12 +872,12 @@ def denoise_dev_av(
video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope,
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
)
audio_modality_neg = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope,
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
)
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
mx.eval(video_vel_neg, audio_vel_neg)