Update .gitignore to exclude additional configuration and model files. Modify generate.py to enhance console output with rescale parameter and adjust default values for inference steps and CFG scale. Refactor text encoder to align positional embedding max position with PyTorch defaults, improving compatibility and performance.

This commit is contained in:
Prince Canuma
2026-03-12 17:13:43 +01:00
parent d1fa47722b
commit b07b1e3213
3 changed files with 43 additions and 33 deletions

View File

@@ -328,7 +328,7 @@ class ConnectorFeedForward(nn.Module):
self.proj_out = nn.Linear(inner_dim, dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
x = nn.gelu(self.proj_in(x))
x = nn.gelu_approx(self.proj_in(x))
x = self.dropout(x)
x = self.proj_out(x)
return x
@@ -385,7 +385,7 @@ class Embeddings1DConnector(nn.Module):
self.head_dim = head_dim
self.num_learnable_registers = num_learnable_registers
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
self.transformer_1d_blocks = {
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
@@ -403,50 +403,54 @@ class Embeddings1DConnector(nn.Module):
import numpy as np
dim = self.num_heads * self.head_dim # inner_dim = 3840
dim = self.num_heads * self.head_dim # inner_dim
theta = self.positional_embedding_theta
max_pos = self.positional_embedding_max_pos # [4096] from PyTorch
max_pos = self.positional_embedding_max_pos # [1] = PyTorch default
n_elem = 2 * len(max_pos) # = 2
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
num_indices = dim // n_elem
# Use numpy float64 for precision (double_precision_rope=True in PyTorch)
# generate_freq_grid_np: compute indices in float64 then cast to float32
# (matches PyTorch: double_precision_rope generates in np.float64,
# but returns torch.float32)
log_start = np.log(start) / np.log(theta) # = 0
log_end = np.log(end) / np.log(theta) # = 1
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float64)
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
# Generate positions and compute freqs (matches generate_freqs)
positions = np.arange(seq_len, dtype=np.float64)
# Scale positions by max_pos (PyTorch uses max_pos=[4096])
# generate_freqs: positions and freqs in float32 (matching PyTorch)
positions = np.arange(seq_len, dtype=np.float32)
fractional_positions = positions / max_pos[0]
scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,)
# freqs = indices * scaled_positions (outer product)
# Shape: (seq_len, num_indices)
# freqs = scaled_positions * indices (outer product) in float32
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin
cos_freq = np.cos(freqs) # (seq_len, 1920)
# split_freqs_cis: cos/sin in float32 (matching PyTorch)
expected_freqs = dim // 2
pad_size = expected_freqs - freqs.shape[-1]
cos_freq = np.cos(freqs) # (seq_len, num_indices)
sin_freq = np.sin(freqs)
# For SPLIT RoPE: pad to head_dim//2 = 64 per head, then reshape to (1, H, T, D//2)
# Current: (T, 1920) -> need (1, 30, T, 64)
# 30 heads * 64 = 1920, so no padding needed
if pad_size > 0:
cos_padding = np.ones((seq_len, pad_size), dtype=np.float32)
sin_padding = np.zeros((seq_len, pad_size), dtype=np.float32)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape: (T, 1920) -> (T, 30, 64) -> (1, 30, T, 64)
# Reshape: (T, dim//2) -> (T, H, D//2) -> (1, H, T, D//2)
cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
# Transpose to (1, H, T, D//2)
cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...]
sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...]
# Convert to MLX
cos_full = mx.array(cos_freq.astype(np.float32))
sin_full = mx.array(sin_freq.astype(np.float32))
cos_full = mx.array(cos_freq)
sin_full = mx.array(sin_freq)
return cos_full.astype(dtype), sin_full.astype(dtype)
@@ -721,15 +725,17 @@ class LTX2TextEncoder(nn.Module):
)
# Deeper connectors with matching dims and gate_logits
# NOTE: positional_embedding_max_pos=[1] matches PyTorch default
# (connector_positional_embedding_max_pos not in LTX-2.3 config)
self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[4096], has_gate_logits=True,
positional_embedding_max_pos=[1], has_gate_logits=True,
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[4096], has_gate_logits=True,
positional_embedding_max_pos=[1], has_gate_logits=True,
)
else:
# LTX-2: shared feature extractor, 3840-dim connectors
@@ -738,12 +744,12 @@ class LTX2TextEncoder(nn.Module):
self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
positional_embedding_max_pos=[4096],
positional_embedding_max_pos=[1],
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
positional_embedding_max_pos=[4096],
positional_embedding_max_pos=[1],
)
self.processor = None