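Update .gitignore to exclude additional configuration and model files (.claude/*, CLAUDE.md, config.json, *.safetensors, *.safetensors.index.json). Modify generate.py to print the CFG rescale value in its console output and to adjust the dev-pipeline defaults: inference steps 40 → 30, CFG scale 4.0 → 3.0, and CFG rescale 0.0 → 0.7. Refactor the text encoder so that the connector's positional embedding max position matches the PyTorch default ([1] instead of [4096]), the RoPE frequencies are computed in float32 as in PyTorch, and the connector feed-forward uses the approximate GELU, improving compatibility and performance.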
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,5 +1,9 @@
|
|||||||
.env
|
.env
|
||||||
claude.md
|
.claude/*
|
||||||
|
CLAUDE.md
|
||||||
|
config.json
|
||||||
|
*.safetensors
|
||||||
|
*.safetensors.index.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
**.pyc
|
**.pyc
|
||||||
__pycache__/*
|
__pycache__/*
|
||||||
|
|||||||
@@ -938,7 +938,7 @@ def generate_video(
|
|||||||
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
|
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
|
||||||
|
|
||||||
if pipeline == PipelineType.DEV:
|
if pipeline == PipelineType.DEV:
|
||||||
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}[/]")
|
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]")
|
||||||
|
|
||||||
if is_i2v:
|
if is_i2v:
|
||||||
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
|
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
|
||||||
@@ -1188,7 +1188,7 @@ def generate_video(
|
|||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||||
|
|
||||||
console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})")
|
console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
||||||
@@ -1432,8 +1432,8 @@ Examples:
|
|||||||
python -m mlx_video.generate --prompt "Ocean waves" --pipeline distilled
|
python -m mlx_video.generate --prompt "Ocean waves" --pipeline distilled
|
||||||
|
|
||||||
# Dev pipeline (single-stage, CFG, higher quality)
|
# Dev pipeline (single-stage, CFG, higher quality)
|
||||||
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 4.0
|
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0
|
||||||
python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 50
|
python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40
|
||||||
|
|
||||||
# Image-to-Video (works with both pipelines)
|
# Image-to-Video (works with both pipelines)
|
||||||
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
|
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
|
||||||
@@ -1453,9 +1453,9 @@ Examples:
|
|||||||
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
|
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
|
||||||
parser.add_argument("--width", "-W", type=int, default=512, help="Output video width")
|
parser.add_argument("--width", "-W", type=int, default=512, help="Output video width")
|
||||||
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
|
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
|
||||||
parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)")
|
parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)")
|
||||||
parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)")
|
parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale (dev pipeline only, default 3.0)")
|
||||||
parser.add_argument("--cfg-rescale", type=float, default=0.0, help="CFG rescale factor (0.0-1.0). Higher values reduce artifacts by blending towards positive-only prediction (dev pipeline only)")
|
parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)")
|
||||||
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
|
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
|
||||||
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
|
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
|
||||||
parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path")
|
parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path")
|
||||||
|
|||||||
@@ -328,7 +328,7 @@ class ConnectorFeedForward(nn.Module):
|
|||||||
self.proj_out = nn.Linear(inner_dim, dim, bias=True)
|
self.proj_out = nn.Linear(inner_dim, dim, bias=True)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
x = nn.gelu(self.proj_in(x))
|
x = nn.gelu_approx(self.proj_in(x))
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
return x
|
return x
|
||||||
@@ -385,7 +385,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.num_learnable_registers = num_learnable_registers
|
self.num_learnable_registers = num_learnable_registers
|
||||||
self.positional_embedding_theta = positional_embedding_theta
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
|
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
|
||||||
|
|
||||||
self.transformer_1d_blocks = {
|
self.transformer_1d_blocks = {
|
||||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||||
@@ -403,50 +403,54 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
dim = self.num_heads * self.head_dim # inner_dim
|
||||||
theta = self.positional_embedding_theta
|
theta = self.positional_embedding_theta
|
||||||
max_pos = self.positional_embedding_max_pos # [4096] from PyTorch
|
max_pos = self.positional_embedding_max_pos # [1] = PyTorch default
|
||||||
n_elem = 2 * len(max_pos) # = 2
|
n_elem = 2 * len(max_pos) # = 2
|
||||||
|
|
||||||
start = 1.0
|
start = 1.0
|
||||||
end = theta
|
end = theta
|
||||||
num_indices = dim // n_elem # 1920
|
num_indices = dim // n_elem
|
||||||
|
|
||||||
# Use numpy float64 for precision (double_precision_rope=True in PyTorch)
|
# generate_freq_grid_np: compute indices in float64 then cast to float32
|
||||||
|
# (matches PyTorch: double_precision_rope generates in np.float64,
|
||||||
|
# but returns torch.float32)
|
||||||
log_start = np.log(start) / np.log(theta) # = 0
|
log_start = np.log(start) / np.log(theta) # = 0
|
||||||
log_end = np.log(end) / np.log(theta) # = 1
|
log_end = np.log(end) / np.log(theta) # = 1
|
||||||
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
||||||
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float64)
|
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
|
||||||
|
|
||||||
# Generate positions and compute freqs (matches generate_freqs)
|
# generate_freqs: positions and freqs in float32 (matching PyTorch)
|
||||||
positions = np.arange(seq_len, dtype=np.float64)
|
positions = np.arange(seq_len, dtype=np.float32)
|
||||||
# Scale positions by max_pos (PyTorch uses max_pos=[4096])
|
|
||||||
fractional_positions = positions / max_pos[0]
|
fractional_positions = positions / max_pos[0]
|
||||||
scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,)
|
scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,)
|
||||||
|
|
||||||
# freqs = indices * scaled_positions (outer product)
|
# freqs = scaled_positions * indices (outer product) in float32
|
||||||
# Shape: (seq_len, num_indices)
|
|
||||||
freqs = scaled_positions[:, None] * indices[None, :]
|
freqs = scaled_positions[:, None] * indices[None, :]
|
||||||
|
|
||||||
# Compute cos/sin
|
# split_freqs_cis: cos/sin in float32 (matching PyTorch)
|
||||||
cos_freq = np.cos(freqs) # (seq_len, 1920)
|
expected_freqs = dim // 2
|
||||||
|
pad_size = expected_freqs - freqs.shape[-1]
|
||||||
|
|
||||||
|
cos_freq = np.cos(freqs) # (seq_len, num_indices)
|
||||||
sin_freq = np.sin(freqs)
|
sin_freq = np.sin(freqs)
|
||||||
|
|
||||||
# For SPLIT RoPE: pad to head_dim//2 = 64 per head, then reshape to (1, H, T, D//2)
|
if pad_size > 0:
|
||||||
# Current: (T, 1920) -> need (1, 30, T, 64)
|
cos_padding = np.ones((seq_len, pad_size), dtype=np.float32)
|
||||||
# 30 heads * 64 = 1920, so no padding needed
|
sin_padding = np.zeros((seq_len, pad_size), dtype=np.float32)
|
||||||
|
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
# Reshape: (T, 1920) -> (T, 30, 64) -> (1, 30, T, 64)
|
# Reshape: (T, dim//2) -> (T, H, D//2) -> (1, H, T, D//2)
|
||||||
cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
|
cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
|
||||||
sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
|
sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
|
||||||
|
|
||||||
# Transpose to (1, H, T, D//2)
|
|
||||||
cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...]
|
cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...]
|
||||||
sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...]
|
sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...]
|
||||||
|
|
||||||
# Convert to MLX
|
# Convert to MLX
|
||||||
cos_full = mx.array(cos_freq.astype(np.float32))
|
cos_full = mx.array(cos_freq)
|
||||||
sin_full = mx.array(sin_freq.astype(np.float32))
|
sin_full = mx.array(sin_freq)
|
||||||
|
|
||||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||||
|
|
||||||
@@ -721,15 +725,17 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deeper connectors with matching dims and gate_logits
|
# Deeper connectors with matching dims and gate_logits
|
||||||
|
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default
|
||||||
|
# (connector_positional_embedding_max_pos not in LTX-2.3 config)
|
||||||
self.video_embeddings_connector = Embeddings1DConnector(
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||||
num_layers=8, num_learnable_registers=128,
|
num_layers=8, num_learnable_registers=128,
|
||||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
positional_embedding_max_pos=[1], has_gate_logits=True,
|
||||||
)
|
)
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||||
num_layers=8, num_learnable_registers=128,
|
num_layers=8, num_learnable_registers=128,
|
||||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
positional_embedding_max_pos=[1], has_gate_logits=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||||
@@ -738,12 +744,12 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self.video_embeddings_connector = Embeddings1DConnector(
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||||
num_layers=2, num_learnable_registers=128,
|
num_layers=2, num_learnable_registers=128,
|
||||||
positional_embedding_max_pos=[4096],
|
positional_embedding_max_pos=[1],
|
||||||
)
|
)
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||||
num_layers=2, num_learnable_registers=128,
|
num_layers=2, num_learnable_registers=128,
|
||||||
positional_embedding_max_pos=[4096],
|
positional_embedding_max_pos=[1],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
|||||||
Reference in New Issue
Block a user