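Refactor model loading in generate.py so that the audio and video components (audio VAE decoder, vocoder, transformer, VAE encoder/decoder) are loaded from subdirectories of `model_path` instead of hard-coded absolute paths. Rework the weight loading logic in LTX2TextEncoder to support both the monolithic layout (a single ltx-2-19b-*.safetensors file with raw `model.diffusion_model.*` keys) and the reformatted layout (a `text_projections/` directory with pre-sanitized keys), and warn when no projection weights are found. Make get_model_path return `model_repo` directly when it is an existing local path, before attempting the snapshot lookup.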
This commit is contained in:
3
mlx_video/components/__init__.py
Normal file
3
mlx_video/components/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .smart_turn import Model, ModelConfig
|
||||||
|
|
||||||
|
__all__ = ["Model", "ModelConfig"]
|
||||||
@@ -780,12 +780,9 @@ def denoise_dev_av(
|
|||||||
|
|
||||||
def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
||||||
"""Load audio VAE decoder."""
|
"""Load audio VAE decoder."""
|
||||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
from mlx_video.models.ltx.audio_vae import AudioDecoder
|
||||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
|
|
||||||
|
|
||||||
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
|
decoder = AudioDecoder.from_pretrained(model_path / "audio_vae")
|
||||||
|
|
||||||
decoder = AudioDecoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/audio_vae"))
|
|
||||||
|
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
@@ -794,8 +791,7 @@ def load_vocoder(model_path: Path, pipeline: PipelineType):
|
|||||||
"""Load vocoder for mel to waveform conversion."""
|
"""Load vocoder for mel to waveform conversion."""
|
||||||
from mlx_video.models.ltx.audio_vae import Vocoder
|
from mlx_video.models.ltx.audio_vae import Vocoder
|
||||||
|
|
||||||
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
|
vocoder = Vocoder.from_pretrained(model_path / "vocoder")
|
||||||
vocoder = Vocoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vocoder"))
|
|
||||||
|
|
||||||
return vocoder
|
return vocoder
|
||||||
|
|
||||||
@@ -951,8 +947,6 @@ def generate_video(
|
|||||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||||
|
|
||||||
# Model weight file
|
# Model weight file
|
||||||
weight_file = "ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors"
|
|
||||||
|
|
||||||
# Calculate latent dimensions
|
# Calculate latent dimensions
|
||||||
if pipeline == PipelineType.DISTILLED:
|
if pipeline == PipelineType.DISTILLED:
|
||||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||||
@@ -1008,7 +1002,7 @@ def generate_video(
|
|||||||
# Load transformer
|
# Load transformer
|
||||||
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
|
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
|
||||||
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
||||||
transformer = LTXModel.from_pretrained(model_path=Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/transformer"), strict=True)
|
transformer = LTXModel.from_pretrained(model_path=model_path / "transformer", strict=True)
|
||||||
|
|
||||||
console.print("[green]✓[/] Transformer loaded")
|
console.print("[green]✓[/] Transformer loaded")
|
||||||
|
|
||||||
@@ -1026,7 +1020,7 @@ def generate_video(
|
|||||||
stage2_image_latent = None
|
stage2_image_latent = None
|
||||||
if is_i2v:
|
if is_i2v:
|
||||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||||
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-distilled/vae/encoder"))
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||||
|
|
||||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||||
@@ -1093,7 +1087,7 @@ def generate_video(
|
|||||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||||
mx.eval(upsampler.parameters())
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||||
|
|
||||||
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
@@ -1160,7 +1154,7 @@ def generate_video(
|
|||||||
image_latent = None
|
image_latent = None
|
||||||
if is_i2v:
|
if is_i2v:
|
||||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||||
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/encoder"))
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||||
|
|
||||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||||
image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||||
@@ -1173,7 +1167,7 @@ def generate_video(
|
|||||||
|
|
||||||
# Generate sigma schedule with token-count-dependent shifting
|
# Generate sigma schedule with token-count-dependent shifting
|
||||||
num_tokens = latent_frames * latent_h * latent_w
|
num_tokens = latent_frames * latent_h * latent_w
|
||||||
sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
|
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||||
|
|
||||||
@@ -1238,7 +1232,7 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||||
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|||||||
@@ -498,13 +498,16 @@ class LTXModel(nn.Module):
|
|||||||
def sanitize(self, weights: dict) -> dict:
|
def sanitize(self, weights: dict) -> dict:
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
|
||||||
if "model.diffusion_model." not in weights:
|
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
|
||||||
|
if not has_raw_prefix:
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
|
||||||
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
if not key.startswith("model.diffusion_model."):
|
||||||
|
continue
|
||||||
|
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Remove 'model.diffusion_model.' prefix
|
# Remove 'model.diffusion_model.' prefix
|
||||||
@@ -520,7 +523,6 @@ class LTXModel(nn.Module):
|
|||||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||||
|
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|||||||
@@ -646,36 +646,63 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||||
|
|
||||||
# Load transformer weights for feature extractor and connector
|
# Load transformer weights for feature extractor and connector.
|
||||||
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
# These weights are stored differently depending on the repo format:
|
||||||
if transformer_files:
|
# 1. Monolithic (Lightricks/LTX-2): single ltx-2-19b-*.safetensors at root
|
||||||
transformer_weights = mx.load(str(transformer_files[0]))
|
# with raw PyTorch key names (model.diffusion_model.* prefix)
|
||||||
|
# 2. Reformatted (prince-canuma/LTX-2-distilled): separate text_projections/
|
||||||
|
# directory with pre-sanitized keys (no prefix, already renamed)
|
||||||
|
transformer_weights = {}
|
||||||
|
is_reformatted = False
|
||||||
|
|
||||||
|
# Try reformatted layout first: text_projections/ subdirectory
|
||||||
|
text_proj_dir = model_path / "text_projections"
|
||||||
|
if text_proj_dir.is_dir():
|
||||||
|
is_reformatted = True
|
||||||
|
for sf in text_proj_dir.glob("*.safetensors"):
|
||||||
|
transformer_weights.update(mx.load(str(sf)))
|
||||||
|
|
||||||
|
# Fall back to monolithic layout: ltx-2-19*.safetensors at root
|
||||||
|
if not transformer_weights:
|
||||||
|
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||||
|
if transformer_files:
|
||||||
|
transformer_weights = mx.load(str(transformer_files[0]))
|
||||||
|
|
||||||
|
if transformer_weights:
|
||||||
# Load feature extractor (aggregate_embed)
|
# Load feature extractor (aggregate_embed)
|
||||||
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
|
# Reformatted key: "aggregate_embed.weight"
|
||||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
# Monolithic key: "text_embedding_projection.aggregate_embed.weight"
|
||||||
"text_embedding_projection.aggregate_embed.weight"
|
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||||
]
|
if agg_key in transformer_weights:
|
||||||
|
self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key]
|
||||||
|
|
||||||
# Load video_embeddings_connector weights
|
# Load video_embeddings_connector weights
|
||||||
connector_weights = {}
|
connector_weights = {}
|
||||||
for key, value in transformer_weights.items():
|
if is_reformatted:
|
||||||
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
# Reformatted: keys are already sanitized with "video_embeddings_connector." prefix
|
||||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
for key, value in transformer_weights.items():
|
||||||
connector_weights[new_key] = value
|
if key.startswith("video_embeddings_connector."):
|
||||||
|
new_key = key.replace("video_embeddings_connector.", "")
|
||||||
|
connector_weights[new_key] = value
|
||||||
|
else:
|
||||||
|
# Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix
|
||||||
|
for key, value in transformer_weights.items():
|
||||||
|
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||||
|
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
||||||
|
connector_weights[new_key] = value
|
||||||
|
|
||||||
if connector_weights:
|
if connector_weights:
|
||||||
# Map weight names to our structure
|
# Map weight names to our structure (only needed for monolithic/raw PyTorch keys)
|
||||||
mapped_weights = {}
|
mapped_weights = {}
|
||||||
for key, value in connector_weights.items():
|
for key, value in connector_weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
if not is_reformatted:
|
||||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||||
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
mapped_weights[new_key] = value
|
mapped_weights[new_key] = value
|
||||||
|
|
||||||
self.video_embeddings_connector.load_weights(
|
self.video_embeddings_connector.load_weights(
|
||||||
@@ -688,22 +715,26 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
# Load audio_embeddings_connector weights (same structure as video connector)
|
# Load audio_embeddings_connector weights (same structure as video connector)
|
||||||
audio_connector_weights = {}
|
audio_connector_weights = {}
|
||||||
for key, value in transformer_weights.items():
|
if is_reformatted:
|
||||||
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
for key, value in transformer_weights.items():
|
||||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
if key.startswith("audio_embeddings_connector."):
|
||||||
audio_connector_weights[new_key] = value
|
new_key = key.replace("audio_embeddings_connector.", "")
|
||||||
|
audio_connector_weights[new_key] = value
|
||||||
|
else:
|
||||||
|
for key, value in transformer_weights.items():
|
||||||
|
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||||
|
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
||||||
|
audio_connector_weights[new_key] = value
|
||||||
|
|
||||||
if audio_connector_weights:
|
if audio_connector_weights:
|
||||||
# Map weight names to our structure (same as video connector)
|
# Map weight names to our structure (same as video connector)
|
||||||
mapped_audio_weights = {}
|
mapped_audio_weights = {}
|
||||||
for key, value in audio_connector_weights.items():
|
for key, value in audio_connector_weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
if not is_reformatted:
|
||||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
|
||||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
|
||||||
mapped_audio_weights[new_key] = value
|
mapped_audio_weights[new_key] = value
|
||||||
|
|
||||||
self.audio_embeddings_connector.load_weights(
|
self.audio_embeddings_connector.load_weights(
|
||||||
@@ -713,6 +744,9 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||||
if "learnable_registers" in audio_connector_weights:
|
if "learnable_registers" in audio_connector_weights:
|
||||||
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
|
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
|
||||||
|
else:
|
||||||
|
print("WARNING: No transformer weights found for text projection connectors. "
|
||||||
|
"Text conditioning will use uninitialized weights!")
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from PIL import Image
|
|||||||
def get_model_path(model_repo: str):
|
def get_model_path(model_repo: str):
|
||||||
"""Get or download LTX-2 model path."""
|
"""Get or download LTX-2 model path."""
|
||||||
try:
|
try:
|
||||||
|
if Path(model_repo).exists():
|
||||||
|
return Path(model_repo)
|
||||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Downloading LTX-2 model weights...")
|
print("Downloading LTX-2 model weights...")
|
||||||
|
|||||||
Reference in New Issue
Block a user