diff --git a/.gitignore b/.gitignore index 04c7330..3c2f021 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ .env -claude.md +.claude/* +CLAUDE.md +config.json +*.safetensors +*.safetensors.index.json .DS_Store **.pyc __pycache__/* diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 21790a7..37f0824 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -938,7 +938,7 @@ def generate_video( console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") if pipeline == PipelineType.DEV: - console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}[/]") + console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]") if is_i2v: console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") @@ -1188,7 +1188,7 @@ def generate_video( mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") - console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})") + console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})") mx.random.seed(seed) video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) @@ -1432,8 +1432,8 @@ Examples: python -m mlx_video.generate --prompt "Ocean waves" --pipeline distilled # Dev pipeline (single-stage, CFG, higher quality) - python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 4.0 - python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 50 + python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0 + python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40 # Image-to-Video (works with both pipelines) python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg @@ -1453,9 +1453,9 @@ Examples: parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") parser.add_argument("--width", "-W", type=int, default=512, help="Output video width") parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames") - parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)") - parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)") - parser.add_argument("--cfg-rescale", type=float, default=0.0, help="CFG rescale factor (0.0-1.0). Higher values reduce artifacts by blending towards positive-only prediction (dev pipeline only)") + parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)") + parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale (dev pipeline only, default 3.0)") + parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)") parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed") parser.add_argument("--fps", type=int, default=24, help="Frames per second") parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path") diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 1c16524..90c061b 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -328,7 +328,7 @@ class ConnectorFeedForward(nn.Module): self.proj_out = nn.Linear(inner_dim, dim, bias=True) def __call__(self, x: mx.array) -> mx.array: - x = nn.gelu(self.proj_in(x)) + x = nn.gelu_approx(self.proj_in(x)) x = self.dropout(x) x = self.proj_out(x) return x @@ -385,7 +385,7 @@ class Embeddings1DConnector(nn.Module): self.head_dim = head_dim self.num_learnable_registers = num_learnable_registers self.positional_embedding_theta = positional_embedding_theta - self.positional_embedding_max_pos = positional_embedding_max_pos or [4096] + self.positional_embedding_max_pos = positional_embedding_max_pos or [1] self.transformer_1d_blocks = { i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) @@ -403,50 +403,54 @@ class Embeddings1DConnector(nn.Module): import numpy as np - dim = self.num_heads * self.head_dim # inner_dim = 3840 + dim = self.num_heads * self.head_dim # inner_dim theta = self.positional_embedding_theta - max_pos = self.positional_embedding_max_pos # [4096] from PyTorch + max_pos = self.positional_embedding_max_pos # [1] = PyTorch default n_elem = 2 * len(max_pos) # = 2 start = 1.0 end = theta - num_indices = dim // n_elem # 1920 + num_indices = dim // n_elem - # Use numpy float64 for precision (double_precision_rope=True in PyTorch) + # generate_freq_grid_np: compute indices in float64 then cast to float32 + # (matches PyTorch: double_precision_rope generates in np.float64, + # but returns torch.float32) log_start = np.log(start) / np.log(theta) # = 0 log_end = np.log(end) / np.log(theta) # = 1 lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64) - indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float64) + indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32) - # Generate positions and compute freqs (matches generate_freqs) - positions = np.arange(seq_len, dtype=np.float64) - # Scale positions by max_pos (PyTorch uses max_pos=[4096]) + # generate_freqs: positions and freqs in float32 (matching PyTorch) + positions = np.arange(seq_len, dtype=np.float32) fractional_positions = positions / max_pos[0] scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,) - # freqs = indices * scaled_positions (outer product) - # Shape: (seq_len, num_indices) + # freqs = scaled_positions * indices (outer product) in float32 freqs = scaled_positions[:, None] * indices[None, :] - # Compute cos/sin - cos_freq = np.cos(freqs) # (seq_len, 1920) + # split_freqs_cis: cos/sin in float32 (matching PyTorch) + expected_freqs = dim // 2 + pad_size = expected_freqs - freqs.shape[-1] + + cos_freq = np.cos(freqs) # (seq_len, num_indices) sin_freq = np.sin(freqs) - # For SPLIT RoPE: pad to head_dim//2 = 64 per head, then reshape to (1, H, T, D//2) - # Current: (T, 1920) -> need (1, 30, T, 64) - # 30 heads * 64 = 1920, so no padding needed + if pad_size > 0: + cos_padding = np.ones((seq_len, pad_size), dtype=np.float32) + sin_padding = np.zeros((seq_len, pad_size), dtype=np.float32) + cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) - # Reshape: (T, 1920) -> (T, 30, 64) -> (1, 30, T, 64) + # Reshape: (T, dim//2) -> (T, H, D//2) -> (1, H, T, D//2) cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2) sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2) - # Transpose to (1, H, T, D//2) cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...] sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...] # Convert to MLX - cos_full = mx.array(cos_freq.astype(np.float32)) - sin_full = mx.array(sin_freq.astype(np.float32)) + cos_full = mx.array(cos_freq) + sin_full = mx.array(sin_freq) return cos_full.astype(dtype), sin_full.astype(dtype) @@ -721,15 +725,17 @@ class LTX2TextEncoder(nn.Module): ) # Deeper connectors with matching dims and gate_logits + # NOTE: positional_embedding_max_pos=[1] matches PyTorch default + # (connector_positional_embedding_max_pos not in LTX-2.3 config) self.video_embeddings_connector = Embeddings1DConnector( dim=video_output_dim, num_heads=32, head_dim=128, num_layers=8, num_learnable_registers=128, - positional_embedding_max_pos=[4096], has_gate_logits=True, + positional_embedding_max_pos=[1], has_gate_logits=True, ) self.audio_embeddings_connector = Embeddings1DConnector( dim=audio_output_dim, num_heads=32, head_dim=64, num_layers=8, num_learnable_registers=128, - positional_embedding_max_pos=[4096], has_gate_logits=True, + positional_embedding_max_pos=[1], has_gate_logits=True, ) else: # LTX-2: shared feature extractor, 3840-dim connectors @@ -738,12 +744,12 @@ class LTX2TextEncoder(nn.Module): self.video_embeddings_connector = Embeddings1DConnector( dim=hidden_dim, num_heads=30, head_dim=128, num_layers=2, num_learnable_registers=128, - positional_embedding_max_pos=[4096], + positional_embedding_max_pos=[1], ) self.audio_embeddings_connector = Embeddings1DConnector( dim=hidden_dim, num_heads=30, head_dim=128, num_layers=2, num_learnable_registers=128, - positional_embedding_max_pos=[4096], + positional_embedding_max_pos=[1], ) self.processor = None