- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
451
main.py
451
main.py
@@ -1,6 +1,9 @@
|
|||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
@@ -10,205 +13,127 @@ from mlx_video.convert import sanitize_transformer_weights
|
|||||||
from mlx_video.generate import create_position_grid
|
from mlx_video.generate import create_position_grid
|
||||||
from mlx_video.utils import to_denoised
|
from mlx_video.utils import to_denoised
|
||||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
|
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||||
|
|
||||||
# Paths
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from pathlib import Path
|
|
||||||
import os
|
|
||||||
|
|
||||||
LTX2_REPO = "Lightricks/LTX-2"
|
|
||||||
|
|
||||||
def get_ltx2_cache_dir():
|
# Distilled sigma schedules
|
||||||
# Try to get local cache (local_only), will not download files
|
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model_repo: str):
|
||||||
|
"""Get or download LTX-2 model path."""
|
||||||
try:
|
try:
|
||||||
ref_path = snapshot_download(
|
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||||
repo_id=LTX2_REPO,
|
|
||||||
local_files_only=True,
|
|
||||||
allow_patterns=["*"],
|
|
||||||
ignore_patterns=[],
|
|
||||||
# leave as default revision and cache_dir, only local
|
|
||||||
)
|
|
||||||
return ref_path
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# If not present locally, download from hub
|
print("Downloading LTX-2 model weights...")
|
||||||
return snapshot_download(
|
return Path(snapshot_download(
|
||||||
repo_id=LTX2_REPO,
|
repo_id=model_repo,
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
resume_download=True,
|
resume_download=True,
|
||||||
allow_patterns=["*.safetensors", "*.json"],
|
allow_patterns=["*.safetensors", "*.json"],
|
||||||
ignore_patterns=[]
|
))
|
||||||
)
|
|
||||||
|
|
||||||
LTX2_PATH = Path(get_ltx2_cache_dir())
|
|
||||||
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
|
|
||||||
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
|
|
||||||
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
|
|
||||||
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
|
|
||||||
|
|
||||||
# Distilled sigma schedules (from PyTorch)
|
|
||||||
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
|
||||||
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
|
|
||||||
|
|
||||||
|
|
||||||
def denoise_loop(
|
def denoise(
|
||||||
latents: mx.array,
|
latents: mx.array,
|
||||||
positions: mx.array,
|
positions: mx.array,
|
||||||
text_embeddings: mx.array,
|
text_embeddings: mx.array,
|
||||||
transformer: LTXModel,
|
transformer: LTXModel,
|
||||||
sigma_schedule: list,
|
sigmas: list,
|
||||||
stage_name: str = "Stage",
|
|
||||||
negative_embeddings: mx.array = None,
|
|
||||||
cfg_scale: float = 1.0,
|
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop for given sigma schedule.
|
"""Run denoising loop."""
|
||||||
|
for i in range(len(sigmas) - 1):
|
||||||
Args:
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||||
latents: Noisy latent tensor
|
|
||||||
positions: Position embeddings
|
|
||||||
text_embeddings: Positive prompt embeddings
|
|
||||||
transformer: The transformer model
|
|
||||||
sigma_schedule: List of sigma values for each step
|
|
||||||
stage_name: Name for logging
|
|
||||||
negative_embeddings: Negative prompt embeddings for CFG (optional)
|
|
||||||
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
|
|
||||||
"""
|
|
||||||
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
|
|
||||||
|
|
||||||
for i in range(len(sigma_schedule) - 1):
|
|
||||||
sigma = sigma_schedule[i]
|
|
||||||
sigma_next = sigma_schedule[i + 1]
|
|
||||||
|
|
||||||
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
|
|
||||||
|
|
||||||
b, c, f, h, w = latents.shape
|
b, c, f, h, w = latents.shape
|
||||||
latents_flat = mx.reshape(latents, (b, c, -1))
|
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||||
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
|
||||||
|
|
||||||
timesteps = mx.full((1,), sigma)
|
|
||||||
|
|
||||||
# Positive (conditioned) prediction
|
|
||||||
video_modality = Modality(
|
video_modality = Modality(
|
||||||
latent=latents_flat,
|
latent=latents_flat,
|
||||||
timesteps=timesteps,
|
timesteps=mx.full((1,), sigma),
|
||||||
positions=positions,
|
positions=positions,
|
||||||
context=text_embeddings,
|
context=text_embeddings,
|
||||||
context_mask=None,
|
context_mask=None,
|
||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
vx_cond, _ = transformer(video=video_modality, audio=None)
|
velocity, _ = transformer(video=video_modality, audio=None)
|
||||||
mx.eval(vx_cond)
|
mx.eval(velocity)
|
||||||
|
|
||||||
if use_cfg:
|
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||||
# Negative (unconditioned) prediction
|
denoised = to_denoised(latents, velocity, sigma)
|
||||||
video_modality_neg = Modality(
|
|
||||||
latent=latents_flat,
|
|
||||||
timesteps=timesteps,
|
|
||||||
positions=positions,
|
|
||||||
context=negative_embeddings,
|
|
||||||
context_mask=None,
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
|
|
||||||
mx.eval(vx_uncond)
|
|
||||||
|
|
||||||
# CFG: output = uncond + cfg_scale * (cond - uncond)
|
|
||||||
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
|
|
||||||
else:
|
|
||||||
vx = vx_cond
|
|
||||||
|
|
||||||
vx_reshaped = mx.transpose(vx, (0, 2, 1))
|
|
||||||
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
|
|
||||||
|
|
||||||
# Debug: Print velocity stats
|
|
||||||
vx_np = np.array(vx_reshaped)
|
|
||||||
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
|
|
||||||
|
|
||||||
# Get denoised prediction: x_0 = x_t - sigma * velocity
|
|
||||||
denoised = to_denoised(latents, vx_reshaped, sigma)
|
|
||||||
mx.eval(denoised)
|
mx.eval(denoised)
|
||||||
|
|
||||||
# Debug: Print denoised stats
|
|
||||||
denoised_np = np.array(denoised)
|
|
||||||
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
|
|
||||||
|
|
||||||
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
|
|
||||||
if sigma_next > 0:
|
if sigma_next > 0:
|
||||||
velocity = (latents - denoised) / sigma
|
latents = denoised + sigma_next * (latents - denoised) / sigma
|
||||||
latents = denoised + sigma_next * velocity
|
|
||||||
else:
|
else:
|
||||||
latents = denoised
|
latents = denoised
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
# Debug: Print latents after step
|
|
||||||
latents_np = np.array(latents)
|
|
||||||
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
|
||||||
|
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def generate_video(
|
||||||
print("="*60)
|
model_repo: str,
|
||||||
print("MLX LTX-2 Video Generation (Two-Stage)")
|
prompt: str,
|
||||||
print("="*60)
|
height: int = 512,
|
||||||
|
width: int = 512,
|
||||||
|
num_frames: int = 33,
|
||||||
|
seed: int = 42,
|
||||||
|
fps: int = 24,
|
||||||
|
output_path: str = "output.mp4",
|
||||||
|
save_frames: bool = False,
|
||||||
|
):
|
||||||
|
"""Generate video from text prompt.
|
||||||
|
|
||||||
# Config - same as PyTorch reference
|
Args:
|
||||||
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
|
prompt: Text description of the video to generate
|
||||||
negative_prompt = "" # PyTorch script doesn't use negative prompt
|
height: Output video height (must be divisible by 64)
|
||||||
cfg_scale = 1.0 # No CFG in the distilled pipeline
|
width: Output video width (must be divisible by 64)
|
||||||
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
|
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
|
||||||
seed = 123
|
seed: Random seed for reproducibility
|
||||||
|
fps: Frames per second for output video
|
||||||
|
output_path: Path to save the output video
|
||||||
|
save_frames: Whether to save individual frames as images
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
# Stage 1: Half resolution
|
# Validate dimensions
|
||||||
stage1_height = height // 2
|
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||||
stage1_width = width // 2
|
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||||
stage1_latent_height = stage1_height // 32
|
|
||||||
stage1_latent_width = stage1_width // 32
|
print(f"Generating {width}x{height} video with {num_frames} frames")
|
||||||
|
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||||
|
|
||||||
|
# Get model path
|
||||||
|
model_path = get_model_path(model_repo)
|
||||||
|
|
||||||
|
# Calculate latent dimensions
|
||||||
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||||
|
stage2_h, stage2_w = height // 32, width // 32
|
||||||
latent_frames = 1 + (num_frames - 1) // 8
|
latent_frames = 1 + (num_frames - 1) // 8
|
||||||
|
|
||||||
# Stage 2: Full resolution
|
|
||||||
latent_height = height // 32
|
|
||||||
latent_width = width // 32
|
|
||||||
|
|
||||||
print(f"\nConfig:")
|
|
||||||
print(f" Prompt: {prompt}")
|
|
||||||
print(f" Negative prompt: '{negative_prompt}'")
|
|
||||||
print(f" CFG scale: {cfg_scale}")
|
|
||||||
print(f" Final resolution: {width}x{height}, {num_frames} frames")
|
|
||||||
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
|
|
||||||
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
|
|
||||||
print(f" Seed: {seed}")
|
|
||||||
|
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
# Load text encoder
|
# Load text encoder
|
||||||
print("\nLoading text encoder...")
|
print("Loading text encoder...")
|
||||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||||
|
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||||
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
|
text_encoder.load(str(model_path))
|
||||||
text_encoder.load(str(LTX2_PATH))
|
|
||||||
mx.eval(text_encoder.parameters())
|
mx.eval(text_encoder.parameters())
|
||||||
|
|
||||||
# Encode positive prompt
|
text_embeddings, _ = text_encoder(prompt)
|
||||||
print("Encoding text...")
|
|
||||||
text_embeddings, attention_mask = text_encoder(prompt)
|
|
||||||
mx.eval(text_embeddings)
|
mx.eval(text_embeddings)
|
||||||
print(f" Positive embeddings: {text_embeddings.shape}")
|
|
||||||
|
|
||||||
# Encode negative prompt for CFG
|
|
||||||
negative_embeddings, _ = text_encoder(negative_prompt)
|
|
||||||
mx.eval(negative_embeddings)
|
|
||||||
print(f" Negative embeddings: {negative_embeddings.shape}")
|
|
||||||
|
|
||||||
# Free text encoder memory
|
|
||||||
del text_encoder
|
del text_encoder
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# Load transformer
|
# Load transformer
|
||||||
print("\nLoading transformer...")
|
print("Loading transformer...")
|
||||||
raw_weights = mx.load(MODEL_PATH)
|
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||||
sanitized = sanitize_transformer_weights(raw_weights)
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
|
|
||||||
config = LTXModelConfig(
|
config = LTXModelConfig(
|
||||||
@@ -231,164 +156,166 @@ def main():
|
|||||||
transformer = LTXModel(config)
|
transformer = LTXModel(config)
|
||||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
mx.eval(transformer.parameters())
|
mx.eval(transformer.parameters())
|
||||||
print(" Transformer loaded!")
|
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Stage 1: Generate at half resolution
|
# Stage 1: Generate at half resolution
|
||||||
# ========================================
|
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
|
||||||
print("\n" + "="*60)
|
|
||||||
print("Stage 1: Generating at half resolution")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
|
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
print(f" Initial latents: {latents.shape}")
|
|
||||||
|
|
||||||
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
latents = denoise_loop(
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||||
latents=latents,
|
|
||||||
positions=positions,
|
|
||||||
text_embeddings=text_embeddings,
|
|
||||||
transformer=transformer,
|
|
||||||
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
|
|
||||||
stage_name="Stage 1",
|
|
||||||
negative_embeddings=negative_embeddings,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\nStage 1 latents: {latents.shape}")
|
# Upsample latents
|
||||||
latents_np = np.array(latents)
|
print("Upsampling latents 2x...")
|
||||||
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Upsample latents 2x
|
|
||||||
# ========================================
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("Upsampling latents 2x")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
# Load upsampler
|
|
||||||
print(" Loading spatial upsampler...")
|
|
||||||
upsampler = load_upsampler(UPSAMPLER_PATH)
|
|
||||||
mx.eval(upsampler.parameters())
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
# Load latent statistics for normalization
|
vae_decoder = load_vae_decoder(
|
||||||
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
|
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||||
# EXPERIMENT: Disable VAE decode noise for sharper output
|
timestep_conditioning=True
|
||||||
# vae_decoder.decode_noise_scale = 0.0
|
)
|
||||||
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
|
|
||||||
latent_mean = vae_decoder.latents_mean
|
|
||||||
latent_std = vae_decoder.latents_std
|
|
||||||
|
|
||||||
# Upsample
|
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||||
print(" Upsampling...")
|
|
||||||
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
|
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
print(f" Upsampled latents: {latents.shape}")
|
|
||||||
|
|
||||||
# Free upsampler memory
|
|
||||||
del upsampler
|
del upsampler
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Stage 2: Refine at full resolution
|
# Stage 2: Refine at full resolution
|
||||||
# ========================================
|
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
|
||||||
print("\n" + "="*60)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
print("Stage 2: Refining at full resolution")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
# Debug: Print upsampled latent stats before adding noise
|
|
||||||
latents_np = np.array(latents)
|
|
||||||
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
|
||||||
|
|
||||||
# Create new position grid for full resolution
|
|
||||||
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
|
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
# Add noise at initial sigma for stage 2
|
# Add noise for refinement
|
||||||
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
|
noise_scale = STAGE_2_SIGMAS[0]
|
||||||
# NOT addition: noisy = clean + scale * noise
|
|
||||||
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
|
|
||||||
noise = mx.random.normal(latents.shape)
|
noise = mx.random.normal(latents.shape)
|
||||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
# Debug: Print latents after adding noise
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||||
latents_np = np.array(latents)
|
|
||||||
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
|
||||||
|
|
||||||
latents = denoise_loop(
|
|
||||||
latents=latents,
|
|
||||||
positions=positions,
|
|
||||||
text_embeddings=text_embeddings,
|
|
||||||
transformer=transformer,
|
|
||||||
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
|
|
||||||
stage_name="Stage 2",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\nFinal latents: {latents.shape}")
|
|
||||||
latents_np = np.array(latents)
|
|
||||||
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
|
||||||
|
|
||||||
# Save latents for PyTorch comparison
|
|
||||||
np.save("mlx_final_latents.npy", latents_np)
|
|
||||||
print(" Saved latents to mlx_final_latents.npy")
|
|
||||||
|
|
||||||
# Free transformer memory
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Decode to video
|
# Decode to video
|
||||||
# ========================================
|
print("Decoding video...")
|
||||||
print("\n" + "="*60)
|
video = vae_decoder(latents)
|
||||||
print("Decoding with VAE")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
# Decode latents to video
|
|
||||||
video = vae_decoder(latents, debug=True)
|
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f" Video shape: {video.shape}")
|
mx.clear_cache()
|
||||||
|
|
||||||
# Convert to frames
|
# Convert to uint8 frames
|
||||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||||
|
|
||||||
# Debug: check raw RGB values before conversion
|
|
||||||
video_raw = np.array(video)
|
|
||||||
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
|
|
||||||
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
|
|
||||||
|
|
||||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||||
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||||
video = mx.clip(video, 0.0, 1.0)
|
|
||||||
video = (video * 255).astype(mx.uint8)
|
video = (video * 255).astype(mx.uint8)
|
||||||
video_np = np.array(video)
|
video_np = np.array(video)
|
||||||
|
|
||||||
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
|
# Save outputs
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Save first frame
|
|
||||||
output_path = Path("mlx_output_frame0_2.png")
|
|
||||||
Image.fromarray(video_np[0]).save(output_path)
|
|
||||||
print(f"\nSaved first frame to {output_path}")
|
|
||||||
|
|
||||||
# Save video
|
|
||||||
try:
|
try:
|
||||||
import imageio
|
import imageio
|
||||||
video_path = "mlx_output_video_2.mp4"
|
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
|
||||||
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
|
print(f"Saved video to {output_path}")
|
||||||
print(f"Saved video to {video_path}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not save video: {e}")
|
print(f"Could not save video: {e}")
|
||||||
|
|
||||||
print("\nDone!")
|
if save_frames:
|
||||||
|
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||||
|
frames_dir.mkdir(exist_ok=True)
|
||||||
|
for i, frame in enumerate(video_np):
|
||||||
|
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||||
|
print(f"Saved {len(video_np)} frames to {frames_dir}")
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
|
||||||
|
|
||||||
|
return video_np
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Generate videos with MLX LTX-2",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
python main.py --prompt "A cat walking on grass"
|
||||||
|
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
|
||||||
|
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt", "-p",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Text description of the video to generate"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--height", "-H",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Output video height (default: 512, must be divisible by 32)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--width", "-W",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Output video width (default: 512, must be divisible by 32)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-frames", "-n",
|
||||||
|
type=int,
|
||||||
|
default=33,
|
||||||
|
help="Number of frames (default: 33, must be 1 + 8*k)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", "-s",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="Random seed for reproducibility (default: 42)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fps",
|
||||||
|
type=int,
|
||||||
|
default=24,
|
||||||
|
help="Frames per second for output video (default: 24)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o",
|
||||||
|
type=str,
|
||||||
|
default="output.mp4",
|
||||||
|
help="Output video path (default: output.mp4)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-frames",
|
||||||
|
action="store_true",
|
||||||
|
help="Save individual frames as images"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-repo",
|
||||||
|
type=str,
|
||||||
|
default="Lightricks/LTX-2",
|
||||||
|
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
generate_video(
|
||||||
|
model_repo=args.model_repo,
|
||||||
|
prompt=args.prompt,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
seed=args.seed,
|
||||||
|
fps=args.fps,
|
||||||
|
output_path=args.output,
|
||||||
|
save_frames=args.save_frames,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
main()
|
main()
|
||||||
end_time = time.time()
|
|
||||||
print(f"Time taken: {end_time - start_time} seconds")
|
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
weights: Dictionary of weights with PyTorch naming
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with MLX-compatible naming for VAE
|
Dictionary with MLX-compatible naming for VAE decoder
|
||||||
"""
|
"""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
|
||||||
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
if "position_ids" in key:
|
if "position_ids" in key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Only process VAE decoder weights (skip audio_vae, etc.)
|
||||||
|
if not key.startswith("vae."):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle per-channel statistics key mapping
|
||||||
|
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
|
||||||
|
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
|
||||||
|
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
|
||||||
|
if "vae.per_channel_statistics" in key:
|
||||||
|
if key == "vae.per_channel_statistics.mean-of-means":
|
||||||
|
new_key = "per_channel_statistics.mean"
|
||||||
|
elif key == "vae.per_channel_statistics.std-of-means":
|
||||||
|
new_key = "per_channel_statistics.std"
|
||||||
|
else:
|
||||||
|
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
|
||||||
|
continue
|
||||||
|
elif key.startswith("vae.decoder."):
|
||||||
|
# Strip the vae.decoder. prefix for decoder weights
|
||||||
|
new_key = key.replace("vae.decoder.", "")
|
||||||
|
else:
|
||||||
|
# Skip other vae.* keys that are not decoder weights
|
||||||
|
continue
|
||||||
|
|
||||||
# Handle Conv3d weight shape conversion
|
# Handle Conv3d weight shape conversion
|
||||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||||
# MLX: (out_channels, D, H, W, in_channels)
|
# MLX: (out_channels, D, H, W, in_channels)
|
||||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||||
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
# Handle Conv2d weight shape conversion
|
# Handle Conv2d weight shape conversion
|
||||||
# PyTorch: (out_channels, in_channels, H, W)
|
# PyTorch: (out_channels, in_channels, H, W)
|
||||||
# MLX: (out_channels, H, W, in_channels)
|
# MLX: (out_channels, H, W, in_channels)
|
||||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
value = mx.transpose(value, (0, 2, 3, 1))
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
|||||||
@@ -381,8 +381,7 @@ def precompute_freqs_cis(
|
|||||||
if max_pos is None:
|
if max_pos is None:
|
||||||
max_pos = [20, 2048, 2048]
|
max_pos = [20, 2048, 2048]
|
||||||
|
|
||||||
# For double precision, compute in numpy (float64) then convert back to MLX
|
|
||||||
# MLX GPU doesn't support float64, so we use numpy for high precision computation
|
|
||||||
if double_precision:
|
if double_precision:
|
||||||
return _precompute_freqs_cis_double_precision(
|
return _precompute_freqs_cis_double_precision(
|
||||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||||
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
num_attention_heads: int,
|
num_attention_heads: int,
|
||||||
rope_type: LTXRopeType,
|
rope_type: LTXRopeType,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
"""Compute RoPE frequencies in double precision using numpy.
|
|
||||||
|
|
||||||
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
|
|
||||||
"""
|
|
||||||
# Convert to numpy float64
|
# Convert to numpy float64
|
||||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm
|
from mlx_video.utils import rms_norm
|
||||||
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
|
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Gemma3Config:
|
class Gemma3Config:
|
||||||
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
pe: Optional[mx.array] = None,
|
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
batch_size, seq_len, _ = x.shape
|
batch_size, seq_len, _ = x.shape
|
||||||
|
|
||||||
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
# pe: (1, seq_len, num_heads, head_dim, 2)
|
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
|
||||||
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
|
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
|
||||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
|
||||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
|
|
||||||
q, k = apply_rotary_emb_1d(q, k, pe)
|
|
||||||
# Reshape back for attention computation
|
|
||||||
q = mx.reshape(q, (batch_size, seq_len, -1))
|
|
||||||
k = mx.reshape(k, (batch_size, seq_len, -1))
|
|
||||||
|
|
||||||
|
|
||||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
|
|
||||||
if attention_mask is not None:
|
|
||||||
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
|
|
||||||
|
|
||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
|
# No mask needed for connector - after register replacement, all positions are valid
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||||
|
|
||||||
return self.to_out[0](out)
|
return self.to_out[0](out)
|
||||||
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
if num_learnable_registers > 0:
|
if num_learnable_registers > 0:
|
||||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||||
|
|
||||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||||
|
|
||||||
|
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
||||||
|
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||||
|
"""
|
||||||
import math
|
import math
|
||||||
|
|
||||||
dim = self.num_heads * self.head_dim
|
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||||
theta = self.positional_embedding_theta
|
theta = self.positional_embedding_theta
|
||||||
n_elem = 2
|
max_pos = [1] # Default for connector
|
||||||
|
n_elem = 2 * len(max_pos) # = 2
|
||||||
|
|
||||||
|
# Generate frequency indices (matches generate_freq_grid_pytorch)
|
||||||
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
|
start = 1.0
|
||||||
indices = (theta ** linspace_vals) * (math.pi / 2)
|
end = theta
|
||||||
|
num_indices = dim // n_elem # 1920
|
||||||
|
|
||||||
|
log_start = math.log(start) / math.log(theta) # = 0
|
||||||
|
log_end = math.log(end) / math.log(theta) # = 1
|
||||||
|
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||||
|
indices = (theta ** lin_space) * (math.pi / 2)
|
||||||
|
|
||||||
|
# Generate positions and compute freqs (matches generate_freqs)
|
||||||
positions = mx.arange(seq_len).astype(mx.float32)
|
positions = mx.arange(seq_len).astype(mx.float32)
|
||||||
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
|
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
||||||
|
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
||||||
|
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
||||||
|
|
||||||
cos = mx.cos(freqs) # (seq_len, dim//2)
|
# freqs = indices * scaled_positions (outer product)
|
||||||
sin = mx.sin(freqs)
|
# Shape: (seq_len, num_indices)
|
||||||
|
freqs = scaled_positions[:, None] * indices[None, :]
|
||||||
|
|
||||||
|
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
||||||
|
cos_freq = mx.cos(freqs)
|
||||||
|
sin_freq = mx.sin(freqs)
|
||||||
|
|
||||||
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
||||||
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
||||||
|
cos_full = mx.repeat(cos_freq, 2, axis=-1)
|
||||||
|
sin_full = mx.repeat(sin_freq, 2, axis=-1)
|
||||||
|
|
||||||
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
|
# Add batch dimension: (1, seq_len, dim)
|
||||||
return freqs_cis.astype(dtype)
|
cos_full = cos_full[None, :, :]
|
||||||
|
sin_full = sin_full[None, :, :]
|
||||||
|
|
||||||
|
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||||
|
|
||||||
def _replace_padded_with_registers(
|
def _replace_padded_with_registers(
|
||||||
self,
|
self,
|
||||||
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
|
|||||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||||
|
|
||||||
# Compute masked min/max per layer
|
# Compute masked min/max per layer
|
||||||
large_val = 1e9
|
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
|
||||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
|
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
|
||||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
|
|
||||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||||
range_val = x_max - x_min
|
range_val = x_max - x_min
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
|
|
||||||
out_features = out_features or hidden_size
|
out_features = out_features or hidden_size
|
||||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||||
self.act = nn.GELU(approx="precise")
|
self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
|
||||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
|||||||
@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
out_features: int | None = None,
|
out_features: int | None = None,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
|
act_fn: str = "gelu_tanh",
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
out_features = out_features or hidden_size
|
out_features = out_features or hidden_size
|
||||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||||
self.act = nn.GELU(approx="precise")
|
if act_fn == "gelu_tanh":
|
||||||
|
self.act = nn.GELU(approx="tanh")
|
||||||
|
elif act_fn == "silu":
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
|||||||
Reference in New Issue
Block a user