Refactor LoRA loading for v2.3 in generate.py to prioritize distilled-lora files over full model weights, enhancing model flexibility. Update key sanitization logic to utilize a replacement list for improved readability and maintainability. Modify denoise_dev_av function to include sigma parameters for audio and video modalities, ensuring consistent handling of latent variables during processing. Adjust Vocoder weight loading to allow for non-strict loading, accommodating additional keys in model weights.
This commit is contained in:
@@ -89,7 +89,9 @@ def load_and_merge_lora(
|
|||||||
candidates = sorted(lora_file.glob("*.safetensors"))
|
candidates = sorted(lora_file.glob("*.safetensors"))
|
||||||
if not candidates:
|
if not candidates:
|
||||||
raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
|
raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
|
||||||
lora_file = candidates[0]
|
# Prefer distilled-lora files over full model weights
|
||||||
|
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
|
||||||
|
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
|
||||||
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
|
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
|
||||||
else:
|
else:
|
||||||
# Treat as HuggingFace repo ID
|
# Treat as HuggingFace repo ID
|
||||||
@@ -97,7 +99,9 @@ def load_and_merge_lora(
|
|||||||
candidates = sorted(lora_dir.glob("*.safetensors"))
|
candidates = sorted(lora_dir.glob("*.safetensors"))
|
||||||
if not candidates:
|
if not candidates:
|
||||||
raise FileNotFoundError(f"No .safetensors files found in {lora_dir}")
|
raise FileNotFoundError(f"No .safetensors files found in {lora_dir}")
|
||||||
lora_file = candidates[0]
|
# Prefer distilled-lora files over full model weights
|
||||||
|
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
|
||||||
|
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
|
||||||
console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]")
|
console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]")
|
||||||
|
|
||||||
# Load LoRA weights
|
# Load LoRA weights
|
||||||
@@ -123,17 +127,26 @@ def load_and_merge_lora(
|
|||||||
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
|
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
|
||||||
|
|
||||||
# Apply key sanitization only for raw PyTorch format
|
# Apply key sanitization only for raw PyTorch format
|
||||||
|
# Replacements handle both mid-string and end-of-string positions
|
||||||
|
# since LoRA base keys end at the module name without trailing dot
|
||||||
|
_LORA_KEY_REPLACEMENTS = [
|
||||||
|
(".to_out.0", ".to_out"),
|
||||||
|
(".ff.net.0.proj", ".ff.proj_in"),
|
||||||
|
(".ff.net.2", ".ff.proj_out"),
|
||||||
|
(".audio_ff.net.0.proj", ".audio_ff.proj_in"),
|
||||||
|
(".audio_ff.net.2", ".audio_ff.proj_out"),
|
||||||
|
(".linear_1", ".linear1"),
|
||||||
|
(".linear_2", ".linear2"),
|
||||||
|
]
|
||||||
if has_prefix:
|
if has_prefix:
|
||||||
sanitized_pairs = {}
|
sanitized_pairs = {}
|
||||||
for key, pair in lora_pairs.items():
|
for key, pair in lora_pairs.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
for old, new in _LORA_KEY_REPLACEMENTS:
|
||||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
if new_key.endswith(old):
|
||||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
new_key = new_key[:-len(old)] + new
|
||||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
else:
|
||||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
new_key = new_key.replace(old + ".", new + ".")
|
||||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
|
||||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
|
||||||
sanitized_pairs[new_key] = pair
|
sanitized_pairs[new_key] = pair
|
||||||
else:
|
else:
|
||||||
sanitized_pairs = lora_pairs
|
sanitized_pairs = lora_pairs
|
||||||
@@ -823,15 +836,17 @@ def denoise_dev_av(
|
|||||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||||
|
|
||||||
# Positive conditioning pass
|
# Positive conditioning pass
|
||||||
|
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||||
|
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||||
video_modality_pos = Modality(
|
video_modality_pos = Modality(
|
||||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||||
context=video_embeddings_pos, context_mask=None, enabled=True,
|
context=video_embeddings_pos, context_mask=None, enabled=True,
|
||||||
positional_embeddings=precomputed_video_rope,
|
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||||
)
|
)
|
||||||
audio_modality_pos = Modality(
|
audio_modality_pos = Modality(
|
||||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||||
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
||||||
positional_embeddings=precomputed_audio_rope,
|
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||||
)
|
)
|
||||||
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
||||||
mx.eval(video_vel_pos, audio_vel_pos)
|
mx.eval(video_vel_pos, audio_vel_pos)
|
||||||
@@ -857,12 +872,12 @@ def denoise_dev_av(
|
|||||||
video_modality_neg = Modality(
|
video_modality_neg = Modality(
|
||||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||||
context=video_embeddings_neg, context_mask=None, enabled=True,
|
context=video_embeddings_neg, context_mask=None, enabled=True,
|
||||||
positional_embeddings=precomputed_video_rope,
|
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||||
)
|
)
|
||||||
audio_modality_neg = Modality(
|
audio_modality_neg = Modality(
|
||||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||||
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
||||||
positional_embeddings=precomputed_audio_rope,
|
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||||
)
|
)
|
||||||
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
||||||
mx.eval(video_vel_neg, audio_vel_neg)
|
mx.eval(video_vel_neg, audio_vel_neg)
|
||||||
|
|||||||
@@ -120,8 +120,8 @@ class Vocoder(nn.Module):
|
|||||||
model = cls(config)
|
model = cls(config)
|
||||||
weights = mx.load(str(model_path / "model.safetensors"))
|
weights = mx.load(str(model_path / "model.safetensors"))
|
||||||
|
|
||||||
# weights = vocoder.sanitize(weights)
|
# Use strict=False to skip extra keys (e.g., bwe_generator in LTX-2.3)
|
||||||
model.load_weights(list(weights.items()), strict=strict)
|
model.load_weights(list(weights.items()), strict=False)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user