Refactor LoRA loading for v2.3 in generate.py to prioritize distilled-lora files over full model weights, enhancing model flexibility. Update key sanitization logic to utilize a replacement list for improved readability and maintainability. Modify denoise_dev_av function to include sigma parameters for audio and video modalities, ensuring consistent handling of latent variables during processing. Adjust Vocoder weight loading to allow for non-strict loading, accommodating additional keys in model weights.

This commit is contained in:
Prince Canuma
2026-03-14 15:24:50 +01:00
parent 9cba2ea7cd
commit ffe271699a
2 changed files with 30 additions and 15 deletions

View File

@@ -89,7 +89,9 @@ def load_and_merge_lora(
candidates = sorted(lora_file.glob("*.safetensors")) candidates = sorted(lora_file.glob("*.safetensors"))
if not candidates: if not candidates:
raise FileNotFoundError(f"No .safetensors files found in {lora_path}") raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
lora_file = candidates[0] # Prefer distilled-lora files over full model weights
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]") console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
else: else:
# Treat as HuggingFace repo ID # Treat as HuggingFace repo ID
@@ -97,7 +99,9 @@ def load_and_merge_lora(
candidates = sorted(lora_dir.glob("*.safetensors")) candidates = sorted(lora_dir.glob("*.safetensors"))
if not candidates: if not candidates:
raise FileNotFoundError(f"No .safetensors files found in {lora_dir}") raise FileNotFoundError(f"No .safetensors files found in {lora_dir}")
lora_file = candidates[0] # Prefer distilled-lora files over full model weights
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]") console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]")
# Load LoRA weights # Load LoRA weights
@@ -123,17 +127,26 @@ def load_and_merge_lora(
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key] lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
# Apply key sanitization only for raw PyTorch format # Apply key sanitization only for raw PyTorch format
# Replacements handle both mid-string and end-of-string positions
# since LoRA base keys end at the module name without trailing dot
_LORA_KEY_REPLACEMENTS = [
(".to_out.0", ".to_out"),
(".ff.net.0.proj", ".ff.proj_in"),
(".ff.net.2", ".ff.proj_out"),
(".audio_ff.net.0.proj", ".audio_ff.proj_in"),
(".audio_ff.net.2", ".audio_ff.proj_out"),
(".linear_1", ".linear1"),
(".linear_2", ".linear2"),
]
if has_prefix: if has_prefix:
sanitized_pairs = {} sanitized_pairs = {}
for key, pair in lora_pairs.items(): for key, pair in lora_pairs.items():
new_key = key new_key = key
new_key = new_key.replace(".to_out.0.", ".to_out.") for old, new in _LORA_KEY_REPLACEMENTS:
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") if new_key.endswith(old):
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") new_key = new_key[:-len(old)] + new
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") else:
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") new_key = new_key.replace(old + ".", new + ".")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
sanitized_pairs[new_key] = pair sanitized_pairs[new_key] = pair
else: else:
sanitized_pairs = lora_pairs sanitized_pairs = lora_pairs
@@ -823,15 +836,17 @@ def denoise_dev_av(
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
# Positive conditioning pass # Positive conditioning pass
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
video_modality_pos = Modality( video_modality_pos = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions, latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_pos, context_mask=None, enabled=True, context=video_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope, positional_embeddings=precomputed_video_rope, sigma=sigma_array,
) )
audio_modality_pos = Modality( audio_modality_pos = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_pos, context_mask=None, enabled=True, context=audio_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope, positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
) )
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
mx.eval(video_vel_pos, audio_vel_pos) mx.eval(video_vel_pos, audio_vel_pos)
@@ -857,12 +872,12 @@ def denoise_dev_av(
video_modality_neg = Modality( video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions, latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_neg, context_mask=None, enabled=True, context=video_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope, positional_embeddings=precomputed_video_rope, sigma=sigma_array,
) )
audio_modality_neg = Modality( audio_modality_neg = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_neg, context_mask=None, enabled=True, context=audio_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope, positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
) )
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
mx.eval(video_vel_neg, audio_vel_neg) mx.eval(video_vel_neg, audio_vel_neg)

View File

@@ -120,8 +120,8 @@ class Vocoder(nn.Module):
model = cls(config) model = cls(config)
weights = mx.load(str(model_path / "model.safetensors")) weights = mx.load(str(model_path / "model.safetensors"))
# weights = vocoder.sanitize(weights) # Use strict=False to skip extra keys (e.g., bwe_generator in LTX-2.3)
model.load_weights(list(weights.items()), strict=strict) model.load_weights(list(weights.items()), strict=False)
return model return model