Refactor model loading in generate.py to use dynamic model paths for audio and video components. Simplify weight loading logic in LTX2TextEncoder to accommodate both monolithic and reformatted model structures. Introduce a check for existing model paths in get_model_path function to enhance robustness.
This commit is contained in:
@@ -780,12 +780,9 @@ def denoise_dev_av(
|
||||
|
||||
def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
||||
"""Load audio VAE decoder."""
|
||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
|
||||
from mlx_video.models.ltx.audio_vae import AudioDecoder
|
||||
|
||||
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
|
||||
|
||||
decoder = AudioDecoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/audio_vae"))
|
||||
decoder = AudioDecoder.from_pretrained(model_path / "audio_vae")
|
||||
|
||||
return decoder
|
||||
|
||||
@@ -794,8 +791,7 @@ def load_vocoder(model_path: Path, pipeline: PipelineType):
|
||||
"""Load vocoder for mel to waveform conversion."""
|
||||
from mlx_video.models.ltx.audio_vae import Vocoder
|
||||
|
||||
weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors")
|
||||
vocoder = Vocoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vocoder"))
|
||||
vocoder = Vocoder.from_pretrained(model_path / "vocoder")
|
||||
|
||||
return vocoder
|
||||
|
||||
@@ -951,8 +947,6 @@ def generate_video(
|
||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||
|
||||
# Model weight file
|
||||
weight_file = "ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors"
|
||||
|
||||
# Calculate latent dimensions
|
||||
if pipeline == PipelineType.DISTILLED:
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
@@ -1008,7 +1002,7 @@ def generate_video(
|
||||
# Load transformer
|
||||
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
|
||||
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
||||
transformer = LTXModel.from_pretrained(model_path=Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/transformer"), strict=True)
|
||||
transformer = LTXModel.from_pretrained(model_path=model_path / "transformer", strict=True)
|
||||
|
||||
console.print("[green]✓[/] Transformer loaded")
|
||||
|
||||
@@ -1026,7 +1020,7 @@ def generate_video(
|
||||
stage2_image_latent = None
|
||||
if is_i2v:
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-distilled/vae/encoder"))
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
@@ -1093,7 +1087,7 @@ def generate_video(
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
||||
mx.eval(latents)
|
||||
@@ -1160,7 +1154,7 @@ def generate_video(
|
||||
image_latent = None
|
||||
if is_i2v:
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(Path("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/encoder"))
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
@@ -1173,7 +1167,7 @@ def generate_video(
|
||||
|
||||
# Generate sigma schedule with token-count-dependent shifting
|
||||
num_tokens = latent_frames * latent_h * latent_w
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||
|
||||
@@ -1238,7 +1232,7 @@ def generate_video(
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
Reference in New Issue
Block a user