Refactor LoRA loading for v2.3 in generate.py to prioritize distilled-lora files over full model weights, enhancing model flexibility. Update key sanitization logic to utilize a replacement list for improved readability and maintainability. Modify denoise_dev_av function to include sigma parameters for audio and video modalities, ensuring consistent handling of latent variables during processing. Adjust Vocoder weight loading to allow for non-strict loading, accommodating additional keys in model weights.
This commit is contained in:
@@ -89,7 +89,9 @@ def load_and_merge_lora(
|
||||
candidates = sorted(lora_file.glob("*.safetensors"))
|
||||
if not candidates:
|
||||
raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
|
||||
lora_file = candidates[0]
|
||||
# Prefer distilled-lora files over full model weights
|
||||
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
|
||||
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
|
||||
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
|
||||
else:
|
||||
# Treat as HuggingFace repo ID
|
||||
@@ -97,7 +99,9 @@ def load_and_merge_lora(
|
||||
candidates = sorted(lora_dir.glob("*.safetensors"))
|
||||
if not candidates:
|
||||
raise FileNotFoundError(f"No .safetensors files found in {lora_dir}")
|
||||
lora_file = candidates[0]
|
||||
# Prefer distilled-lora files over full model weights
|
||||
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
|
||||
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
|
||||
console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]")
|
||||
|
||||
# Load LoRA weights
|
||||
@@ -123,17 +127,26 @@ def load_and_merge_lora(
|
||||
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
|
||||
|
||||
# Apply key sanitization only for raw PyTorch format
|
||||
# Replacements handle both mid-string and end-of-string positions
|
||||
# since LoRA base keys end at the module name without trailing dot
|
||||
_LORA_KEY_REPLACEMENTS = [
|
||||
(".to_out.0", ".to_out"),
|
||||
(".ff.net.0.proj", ".ff.proj_in"),
|
||||
(".ff.net.2", ".ff.proj_out"),
|
||||
(".audio_ff.net.0.proj", ".audio_ff.proj_in"),
|
||||
(".audio_ff.net.2", ".audio_ff.proj_out"),
|
||||
(".linear_1", ".linear1"),
|
||||
(".linear_2", ".linear2"),
|
||||
]
|
||||
if has_prefix:
|
||||
sanitized_pairs = {}
|
||||
for key, pair in lora_pairs.items():
|
||||
new_key = key
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
for old, new in _LORA_KEY_REPLACEMENTS:
|
||||
if new_key.endswith(old):
|
||||
new_key = new_key[:-len(old)] + new
|
||||
else:
|
||||
new_key = new_key.replace(old + ".", new + ".")
|
||||
sanitized_pairs[new_key] = pair
|
||||
else:
|
||||
sanitized_pairs = lora_pairs
|
||||
@@ -823,15 +836,17 @@ def denoise_dev_av(
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
# Positive conditioning pass
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
video_modality_pos = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_pos, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_video_rope,
|
||||
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||
)
|
||||
audio_modality_pos = Modality(
|
||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope,
|
||||
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||
)
|
||||
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
||||
mx.eval(video_vel_pos, audio_vel_pos)
|
||||
@@ -857,12 +872,12 @@ def denoise_dev_av(
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_neg, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_video_rope,
|
||||
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||
)
|
||||
audio_modality_neg = Modality(
|
||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope,
|
||||
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||
)
|
||||
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
||||
mx.eval(video_vel_neg, audio_vel_neg)
|
||||
|
||||
@@ -120,8 +120,8 @@ class Vocoder(nn.Module):
|
||||
model = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
|
||||
# weights = vocoder.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
# Use strict=False to skip extra keys (e.g., bwe_generator in LTX-2.3)
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user