diff --git a/mlx_video/components/__init__.py b/mlx_video/components/__init__.py new file mode 100644 index 0000000..f70fdce --- /dev/null +++ b/mlx_video/components/__init__.py @@ -0,0 +1,3 @@ +from .smart_turn import Model, ModelConfig + +__all__ = ["Model", "ModelConfig"] diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 43bdb70..4121738 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -780,12 +780,9 @@ def denoise_dev_av( def load_audio_decoder(model_path: Path, pipeline: PipelineType): """Load audio VAE decoder.""" - from mlx_video.models.ltx.config import AudioDecoderModelConfig - from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType + from mlx_video.models.ltx.audio_vae import AudioDecoder - weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") - - decoder = AudioDecoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/audio_vae")) + decoder = AudioDecoder.from_pretrained(model_path / "audio_vae") return decoder @@ -794,8 +791,7 @@ def load_vocoder(model_path: Path, pipeline: PipelineType): """Load vocoder for mel to waveform conversion.""" from mlx_video.models.ltx.audio_vae import Vocoder - weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") - vocoder = Vocoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vocoder")) + vocoder = Vocoder.from_pretrained(model_path / "vocoder") return vocoder @@ -951,8 +947,6 @@ def generate_video( text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) # Model weight file - weight_file = "ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors" - # Calculate latent dimensions if pipeline == PipelineType.DISTILLED: stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 @@ -1008,7 +1002,7 @@ def generate_video( # Load transformer transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..." with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"): - transformer = LTXModel.from_pretrained(model_path=Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/transformer"), strict=True) + transformer = LTXModel.from_pretrained(model_path=model_path / "transformer", strict=True) console.print("[green]✓[/] Transformer loaded") @@ -1026,7 +1020,7 @@ def generate_video( stage2_image_latent = None if is_i2v: with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-distilled/vae/encoder")) + vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) @@ -1093,7 +1087,7 @@ def generate_video( upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) mx.eval(upsampler.parameters()) - vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file)) + vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) mx.eval(latents) @@ -1160,7 +1154,7 @@ def generate_video( image_latent = None if is_i2v: with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/encoder")) + vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") input_image = load_image(image, height=height, width=width, dtype=model_dtype) image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) @@ -1173,7 +1167,7 @@ def generate_video( # Generate sigma schedule with token-count-dependent shifting num_tokens = latent_frames * latent_h * latent_w - sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens) + sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1238,7 +1232,7 @@ def generate_video( ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) - vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder") + vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) del transformer mx.clear_cache() diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index b130665..5551b0a 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -497,14 +497,17 @@ class LTXModel(nn.Module): def sanitize(self, weights: dict) -> dict: sanitized = {} - - if "model.diffusion_model." not in weights: + + has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights) + if not has_raw_prefix: return weights for key, value in weights.items(): new_key = key - # Skip non-transformer weights (VAE, vocoder, audio_vae, connectors) - if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key: + + if not key.startswith("model.diffusion_model."): + continue + if "audio_embeddings_connector" in key or "video_embeddings_connector" in key: continue # Remove 'model.diffusion_model.' prefix @@ -520,7 +523,6 @@ class LTXModel(nn.Module): new_key = new_key.replace(".linear_1.", ".linear1.") new_key = new_key.replace(".linear_2.", ".linear2.") - sanitized[new_key] = value return sanitized diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index a38bb6d..3fa22bb 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -646,36 +646,63 @@ class LTX2TextEncoder(nn.Module): self.language_model = LanguageModel.from_pretrained(text_encoder_path) - # Load transformer weights for feature extractor and connector - transformer_files = list(model_path.glob("ltx-2-19*.safetensors")) - if transformer_files: - transformer_weights = mx.load(str(transformer_files[0])) + # Load transformer weights for feature extractor and connector. + # These weights are stored differently depending on the repo format: + # 1. Monolithic (Lightricks/LTX-2): single ltx-2-19b-*.safetensors at root + # with raw PyTorch key names (model.diffusion_model.* prefix) + # 2. Reformatted (prince-canuma/LTX-2-distilled): separate text_projections/ + # directory with pre-sanitized keys (no prefix, already renamed) + transformer_weights = {} + is_reformatted = False + # Try reformatted layout first: text_projections/ subdirectory + text_proj_dir = model_path / "text_projections" + if text_proj_dir.is_dir(): + is_reformatted = True + for sf in text_proj_dir.glob("*.safetensors"): + transformer_weights.update(mx.load(str(sf))) + + # Fall back to monolithic layout: ltx-2-19*.safetensors at root + if not transformer_weights: + transformer_files = list(model_path.glob("ltx-2-19*.safetensors")) + if transformer_files: + transformer_weights = mx.load(str(transformer_files[0])) + + if transformer_weights: # Load feature extractor (aggregate_embed) - if "text_embedding_projection.aggregate_embed.weight" in transformer_weights: - self.feature_extractor.aggregate_embed.weight = transformer_weights[ - "text_embedding_projection.aggregate_embed.weight" - ] - + # Reformatted key: "aggregate_embed.weight" + # Monolithic key: "text_embedding_projection.aggregate_embed.weight" + agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" + if agg_key in transformer_weights: + self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key] # Load video_embeddings_connector weights connector_weights = {} - for key, value in transformer_weights.items(): - if key.startswith("model.diffusion_model.video_embeddings_connector."): - new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "") - connector_weights[new_key] = value + if is_reformatted: + # Reformatted: keys are already sanitized with "video_embeddings_connector." prefix + for key, value in transformer_weights.items(): + if key.startswith("video_embeddings_connector."): + new_key = key.replace("video_embeddings_connector.", "") + connector_weights[new_key] = value + else: + # Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix + for key, value in transformer_weights.items(): + if key.startswith("model.diffusion_model.video_embeddings_connector."): + new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "") + connector_weights[new_key] = value if connector_weights: - # Map weight names to our structure + # Map weight names to our structure (only needed for monolithic/raw PyTorch keys) mapped_weights = {} for key, value in connector_weights.items(): new_key = key - # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - # Map ff.net.2 -> ff.proj_out (output Linear) - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - # Map to_out.0 -> to_out (Sequential -> direct) - new_key = new_key.replace(".to_out.0.", ".to_out.") + if not is_reformatted: + # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + # Map ff.net.2 -> ff.proj_out (output Linear) + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + # Map to_out.0 -> to_out (Sequential -> direct) + new_key = new_key.replace(".to_out.0.", ".to_out.") mapped_weights[new_key] = value self.video_embeddings_connector.load_weights( @@ -688,22 +715,26 @@ class LTX2TextEncoder(nn.Module): # Load audio_embeddings_connector weights (same structure as video connector) audio_connector_weights = {} - for key, value in transformer_weights.items(): - if key.startswith("model.diffusion_model.audio_embeddings_connector."): - new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "") - audio_connector_weights[new_key] = value + if is_reformatted: + for key, value in transformer_weights.items(): + if key.startswith("audio_embeddings_connector."): + new_key = key.replace("audio_embeddings_connector.", "") + audio_connector_weights[new_key] = value + else: + for key, value in transformer_weights.items(): + if key.startswith("model.diffusion_model.audio_embeddings_connector."): + new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "") + audio_connector_weights[new_key] = value if audio_connector_weights: # Map weight names to our structure (same as video connector) mapped_audio_weights = {} for key, value in audio_connector_weights.items(): new_key = key - # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - # Map ff.net.2 -> ff.proj_out (output Linear) - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - # Map to_out.0 -> to_out (Sequential -> direct) - new_key = new_key.replace(".to_out.0.", ".to_out.") + if not is_reformatted: + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + new_key = new_key.replace(".to_out.0.", ".to_out.") mapped_audio_weights[new_key] = value self.audio_embeddings_connector.load_weights( @@ -713,6 +744,9 @@ class LTX2TextEncoder(nn.Module): # Manually load learnable_registers (it's a plain mx.array, not a parameter) if "learnable_registers" in audio_connector_weights: self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"] + else: + print("WARNING: No transformer weights found for text projection connectors. " + "Text conditioning will use uninitialized weights!") # Load tokenizer from transformers import AutoTokenizer diff --git a/mlx_video/utils.py b/mlx_video/utils.py index 2a6eefe..2cd8647 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -12,6 +12,8 @@ from PIL import Image def get_model_path(model_repo: str): """Get or download LTX-2 model path.""" try: + if Path(model_repo).exists(): + return Path(model_repo) return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) except Exception: print("Downloading LTX-2 model weights...")