initial commit (LTX-2)
This commit is contained in:
7
mlx_video/models/ltx/__init__.py
Normal file
7
mlx_video/models/ltx/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
from mlx_video.models.ltx.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
)
|
||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||
161
mlx_video/models/ltx/adaln.py
Normal file
161
mlx_video/models/ltx/adaln.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import get_timestep_embedding
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
embedding_coefficient: int = 6,
|
||||
use_additional_conditions: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim=embedding_dim,
|
||||
size_emb_dim=0 if not use_additional_conditions else embedding_dim // 3,
|
||||
use_additional_conditions=use_additional_conditions,
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
added_cond_kwargs: dict | None = None,
|
||||
batch_size: int | None = None,
|
||||
hidden_dtype: mx.Dtype | None = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
added_cond_kwargs = added_cond_kwargs or {}
|
||||
|
||||
embedded_timestep = self.emb(
|
||||
timestep,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
**added_cond_kwargs,
|
||||
)
|
||||
|
||||
scale_shift_params = self.linear(self.silu(embedded_timestep))
|
||||
return scale_shift_params, embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
size_emb_dim: int = 0,
|
||||
use_additional_conditions: bool = False,
|
||||
timestep_proj_dim: int = 256,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.size_emb_dim = size_emb_dim
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
||||
|
||||
if use_additional_conditions and size_emb_dim > 0:
|
||||
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
resolution: mx.array | None = None,
|
||||
aspect_ratio: mx.array | None = None,
|
||||
batch_size: int | None = None,
|
||||
hidden_dtype: mx.Dtype | None = None,
|
||||
) -> mx.array:
|
||||
# Project timestep
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
if hidden_dtype is not None:
|
||||
timesteps_proj = timesteps_proj.astype(hidden_dtype)
|
||||
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj)
|
||||
|
||||
# Add additional conditions if enabled
|
||||
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||
if resolution is not None and aspect_ratio is not None:
|
||||
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
||||
timesteps_emb = timesteps_emb + additional_embeds
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def __call__(self, timesteps: mx.array) -> mx.array:
|
||||
return get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_dim = out_dim or time_embed_dim
|
||||
self.linear1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||
self.linear2 = nn.Linear(time_embed_dim, out_dim)
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
sample = self.linear1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class ConditionEmbedding(nn.Module):
|
||||
def __init__(self, size_emb_dim: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.resolution_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
|
||||
self.aspect_ratio_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
resolution: mx.array,
|
||||
aspect_ratio: mx.array,
|
||||
hidden_dtype: mx.Dtype | None = None,
|
||||
) -> mx.array:
|
||||
resolution_emb = self.resolution_embedder(resolution)
|
||||
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio)
|
||||
|
||||
if hidden_dtype is not None:
|
||||
resolution_emb = resolution_emb.astype(hidden_dtype)
|
||||
aspect_ratio_emb = aspect_ratio_emb.astype(hidden_dtype)
|
||||
|
||||
return resolution_emb + aspect_ratio_emb
|
||||
142
mlx_video/models/ltx/attention.py
Normal file
142
mlx_video/models/ltx/attention.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Attention module for LTX-2."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
v: mx.array,
|
||||
heads: int,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
|
||||
b, q_seq_len, dim = q.shape
|
||||
_, kv_seq_len, _ = k.shape
|
||||
dim_head = dim // heads
|
||||
|
||||
# Reshape to (B, seq_len, heads, dim_head)
|
||||
q = mx.reshape(q, (b, q_seq_len, heads, dim_head))
|
||||
k = mx.reshape(k, (b, kv_seq_len, heads, dim_head))
|
||||
v = mx.reshape(v, (b, kv_seq_len, heads, dim_head))
|
||||
|
||||
# Transpose to (B, heads, seq_len, dim_head)
|
||||
q = mx.swapaxes(q, 1, 2)
|
||||
k = mx.swapaxes(k, 1, 2)
|
||||
v = mx.swapaxes(v, 1, 2)
|
||||
|
||||
# Handle mask dimensions
|
||||
if mask is not None:
|
||||
# Add batch dimension if needed
|
||||
if mask.ndim == 2:
|
||||
mask = mx.expand_dims(mask, axis=0)
|
||||
# Add heads dimension if needed
|
||||
if mask.ndim == 3:
|
||||
mask = mx.expand_dims(mask, axis=1)
|
||||
|
||||
# Compute scaled dot-product attention
|
||||
scale = 1.0 / math.sqrt(dim_head)
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
# Reshape back to (B, q_seq_len, heads * dim_head)
|
||||
out = mx.swapaxes(out, 1, 2)
|
||||
out = mx.reshape(out, (b, q_seq_len, heads * dim_head))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head attention with rotary position embeddings.
|
||||
|
||||
Supports both self-attention and cross-attention.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
norm_eps: float = 1e-6,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
):
|
||||
"""Initialize attention module.
|
||||
|
||||
Args:
|
||||
query_dim: Dimension of query input
|
||||
context_dim: Dimension of context (key/value) input. If None, same as query_dim
|
||||
heads: Number of attention heads
|
||||
dim_head: Dimension per head
|
||||
norm_eps: Epsilon for RMS normalization
|
||||
rope_type: Type of rotary position embedding
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.rope_type = rope_type
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
|
||||
# Q, K, V projections
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=True)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=True)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=True)
|
||||
|
||||
# Q and K normalization
|
||||
self.q_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
|
||||
self.k_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
|
||||
|
||||
# Output projection
|
||||
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: Optional[mx.array] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Query input of shape (B, seq_len, query_dim)
|
||||
context: Context for cross-attention. If None, uses x (self-attention)
|
||||
mask: Attention mask
|
||||
pe: Position embeddings for query (and key if k_pe is None)
|
||||
k_pe: Position embeddings for key (optional, uses pe if None)
|
||||
|
||||
Returns:
|
||||
Attention output of shape (B, seq_len, query_dim)
|
||||
"""
|
||||
# Compute Q, K, V
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
# Apply normalization
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||
k_pe_to_use = pe if k_pe is None else k_pe
|
||||
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||
|
||||
# Compute attention
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
|
||||
# Project output
|
||||
return self.to_out(out)
|
||||
181
mlx_video/models/ltx/config.py
Normal file
181
mlx_video/models/ltx/config.py
Normal file
@@ -0,0 +1,181 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
|
||||
|
||||
class LTXModelType(Enum):
|
||||
AudioVideo = "ltx av model"
|
||||
VideoOnly = "ltx video only model"
|
||||
AudioOnly = "ltx audio only model"
|
||||
|
||||
def is_video_enabled(self) -> bool:
|
||||
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
|
||||
|
||||
def is_audio_enabled(self) -> bool:
|
||||
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
|
||||
|
||||
|
||||
class LTXRopeType(Enum):
|
||||
INTERLEAVED = "interleaved"
|
||||
SPLIT = "split"
|
||||
TWO_D = "2d"
|
||||
|
||||
class AttentionType(Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
|
||||
"""Create config from dictionary, filtering only valid parameters."""
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Export config to dictionary."""
|
||||
result = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if v is not None:
|
||||
if isinstance(v, Enum):
|
||||
result[k] = v.value
|
||||
elif hasattr(v, 'to_dict'):
|
||||
result[k] = v.to_dict()
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig(BaseModelConfig):
|
||||
dim: int
|
||||
heads: int
|
||||
d_head: int
|
||||
context_dim: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoVAEConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
in_channels: int = 3
|
||||
out_channels: int = 128
|
||||
latent_channels: int = 128
|
||||
patch_size: int = 4
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
])
|
||||
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
])
|
||||
|
||||
|
||||
@dataclass
|
||||
class LTXModelConfig(BaseModelConfig):
|
||||
|
||||
# Model type
|
||||
model_type: LTXModelType = LTXModelType.AudioVideo
|
||||
|
||||
# Video transformer config
|
||||
num_attention_heads: int = 32
|
||||
attention_head_dim: int = 128
|
||||
in_channels: int = 128
|
||||
out_channels: int = 128
|
||||
num_layers: int = 48
|
||||
cross_attention_dim: int = 4096
|
||||
caption_channels: int = 3840
|
||||
|
||||
# Audio transformer config
|
||||
audio_num_attention_heads: int = 32
|
||||
audio_attention_head_dim: int = 64
|
||||
audio_in_channels: int = 128
|
||||
audio_out_channels: int = 128
|
||||
audio_cross_attention_dim: int = 2048
|
||||
|
||||
# Positional embedding config
|
||||
positional_embedding_theta: float = 10000.0
|
||||
positional_embedding_max_pos: Optional[List[int]] = None
|
||||
audio_positional_embedding_max_pos: Optional[List[int]] = None
|
||||
use_middle_indices_grid: bool = True
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
|
||||
double_precision_rope: bool = False
|
||||
|
||||
# Timestep config
|
||||
timestep_scale_multiplier: int = 1000
|
||||
av_ca_timestep_scale_multiplier: int = 1
|
||||
|
||||
# Normalization
|
||||
norm_eps: float = 1e-6
|
||||
|
||||
# Attention type
|
||||
attention_type: AttentionType = AttentionType.DEFAULT
|
||||
|
||||
# VAE config
|
||||
vae_config: Optional[VideoVAEConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default values after initialization."""
|
||||
if self.positional_embedding_max_pos is None:
|
||||
self.positional_embedding_max_pos = [20, 2048, 2048]
|
||||
if self.audio_positional_embedding_max_pos is None:
|
||||
self.audio_positional_embedding_max_pos = [20]
|
||||
|
||||
# Convert string enum values if loading from dict
|
||||
if isinstance(self.model_type, str):
|
||||
self.model_type = LTXModelType(self.model_type)
|
||||
if isinstance(self.rope_type, str):
|
||||
self.rope_type = LTXRopeType(self.rope_type)
|
||||
if isinstance(self.attention_type, str):
|
||||
self.attention_type = AttentionType(self.attention_type)
|
||||
|
||||
@property
|
||||
def inner_dim(self) -> int:
|
||||
"""Video inner dimension."""
|
||||
return self.num_attention_heads * self.attention_head_dim
|
||||
|
||||
@property
|
||||
def audio_inner_dim(self) -> int:
|
||||
"""Audio inner dimension."""
|
||||
return self.audio_num_attention_heads * self.audio_attention_head_dim
|
||||
|
||||
def get_video_config(self) -> Optional[TransformerConfig]:
|
||||
"""Get video transformer configuration."""
|
||||
if not self.model_type.is_video_enabled():
|
||||
return None
|
||||
return TransformerConfig(
|
||||
dim=self.inner_dim,
|
||||
heads=self.num_attention_heads,
|
||||
d_head=self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
)
|
||||
|
||||
def get_audio_config(self) -> Optional[TransformerConfig]:
|
||||
"""Get audio transformer configuration."""
|
||||
if not self.model_type.is_audio_enabled():
|
||||
return None
|
||||
return TransformerConfig(
|
||||
dim=self.audio_inner_dim,
|
||||
heads=self.audio_num_attention_heads,
|
||||
d_head=self.audio_attention_head_dim,
|
||||
context_dim=self.audio_cross_attention_dim,
|
||||
)
|
||||
40
mlx_video/models/ltx/feed_forward.py
Normal file
40
mlx_video/models/ltx/feed_forward.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, approximate: str = "tanh"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.approximate == "tanh":
|
||||
return nn.gelu_approx(x)
|
||||
else:
|
||||
return nn.gelu(x)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: int | None = None,
|
||||
mult: int = 4,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dim_out = dim_out or dim
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
self.proj_in = nn.Linear(dim, inner_dim, bias=bias)
|
||||
self.act = GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
x = self.proj_in(x)
|
||||
x = self.act(x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
518
mlx_video/models/ltx/ltx.py
Normal file
518
mlx_video/models/ltx/ltx.py
Normal file
@@ -0,0 +1,518 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
LTXRopeType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx.transformer import (
|
||||
BasicAVTransformerBlock,
|
||||
Modality,
|
||||
TransformerArgs,
|
||||
)
|
||||
from mlx_video.utils import to_denoised
|
||||
|
||||
|
||||
class TransformerArgsPreprocessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
use_middle_indices_grid: bool,
|
||||
timestep_scale_multiplier: int,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
double_precision_rope: bool = False,
|
||||
):
|
||||
self.patchify_proj = patchify_proj
|
||||
self.adaln = adaln
|
||||
self.caption_projection = caption_projection
|
||||
self.inner_dim = inner_dim
|
||||
self.max_pos = max_pos
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
|
||||
def _prepare_timestep(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_context(
|
||||
self,
|
||||
context: mx.array,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||
return context, attention_mask
|
||||
|
||||
def _prepare_attention_mask(
|
||||
self,
|
||||
attention_mask: Optional[mx.array],
|
||||
x_dtype: mx.Dtype,
|
||||
) -> Optional[mx.array]:
|
||||
if attention_mask is None:
|
||||
return None
|
||||
|
||||
# Check if already float
|
||||
if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
|
||||
return attention_mask
|
||||
|
||||
# Convert boolean/int mask to float mask
|
||||
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
return mask
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
self,
|
||||
positions: mx.array,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
num_attention_heads: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
pe = precompute_freqs_cis(
|
||||
positions,
|
||||
dim=inner_dim,
|
||||
theta=self.positional_embedding_theta,
|
||||
max_pos=max_pos,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
num_attention_heads=num_attention_heads,
|
||||
rope_type=self.rope_type,
|
||||
double_precision=self.double_precision_rope,
|
||||
)
|
||||
return pe
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
context=context,
|
||||
context_mask=attention_mask,
|
||||
timesteps=timestep,
|
||||
embedded_timestep=embedded_timestep,
|
||||
positional_embeddings=pe,
|
||||
cross_positional_embeddings=None,
|
||||
cross_scale_shift_timestep=None,
|
||||
cross_gate_timestep=None,
|
||||
enabled=modality.enabled,
|
||||
)
|
||||
|
||||
|
||||
class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||
cross_gate_adaln: AdaLayerNormSingle,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
cross_pe_max_pos: int,
|
||||
use_middle_indices_grid: bool,
|
||||
audio_cross_attention_dim: int,
|
||||
timestep_scale_multiplier: int,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
av_ca_timestep_scale_multiplier: int,
|
||||
double_precision_rope: bool = False,
|
||||
):
|
||||
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=patchify_proj,
|
||||
adaln=adaln,
|
||||
caption_projection=caption_projection,
|
||||
inner_dim=inner_dim,
|
||||
max_pos=max_pos,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
rope_type=rope_type,
|
||||
double_precision_rope=double_precision_rope,
|
||||
)
|
||||
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||
self.cross_gate_adaln = cross_gate_adaln
|
||||
self.cross_pe_max_pos = cross_pe_max_pos
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
from dataclasses import replace
|
||||
|
||||
transformer_args = self.simple_preprocessor.prepare(modality)
|
||||
|
||||
# Prepare cross-modal positional embeddings
|
||||
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
||||
positions=modality.positions[:, 0:1, :],
|
||||
inner_dim=self.audio_cross_attention_dim,
|
||||
max_pos=[self.cross_pe_max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.simple_preprocessor.num_attention_heads,
|
||||
)
|
||||
|
||||
# Prepare cross-attention timestep embeddings
|
||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
)
|
||||
|
||||
return replace(
|
||||
transformer_args,
|
||||
cross_positional_embeddings=cross_pe,
|
||||
cross_scale_shift_timestep=cross_scale_shift_timestep,
|
||||
cross_gate_timestep=cross_gate_timestep,
|
||||
)
|
||||
|
||||
def _prepare_cross_attention_timestep(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * timestep_scale_multiplier
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
|
||||
|
||||
class LTXModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LTXModelConfig):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.model_type = config.model_type
|
||||
self.use_middle_indices_grid = config.use_middle_indices_grid
|
||||
self.rope_type = config.rope_type
|
||||
self.timestep_scale_multiplier = config.timestep_scale_multiplier
|
||||
self.positional_embedding_theta = config.positional_embedding_theta
|
||||
|
||||
cross_pe_max_pos = None
|
||||
|
||||
if config.model_type.is_video_enabled():
|
||||
self.positional_embedding_max_pos = config.positional_embedding_max_pos
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.inner_dim = config.inner_dim
|
||||
self._init_video(config)
|
||||
|
||||
if config.model_type.is_audio_enabled():
|
||||
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
|
||||
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||
self.audio_inner_dim = config.audio_inner_dim
|
||||
self._init_audio(config)
|
||||
|
||||
# Initialize cross-modal components
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
cross_pe_max_pos = max(
|
||||
config.positional_embedding_max_pos[0],
|
||||
config.audio_positional_embedding_max_pos[0],
|
||||
)
|
||||
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
|
||||
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||
self._init_audio_video(config)
|
||||
|
||||
self._init_preprocessors(config, cross_pe_max_pos)
|
||||
|
||||
self._init_transformer_blocks(config)
|
||||
|
||||
def _init_video(self, config: LTXModelConfig) -> None:
|
||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
self.scale_shift_table = mx.zeros((2, self.inner_dim))
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
|
||||
# Output components
|
||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||
|
||||
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||
num_scale_shift_values = 4
|
||||
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
|
||||
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
# Multi-modal preprocessors
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
cross_pe_max_pos=cross_pe_max_pos,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
)
|
||||
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
cross_pe_max_pos=cross_pe_max_pos,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
)
|
||||
elif config.model_type.is_video_enabled():
|
||||
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
)
|
||||
elif config.model_type.is_audio_enabled():
|
||||
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
|
||||
video_config = config.get_video_config()
|
||||
audio_config = config.get_audio_config()
|
||||
|
||||
self.transformer_blocks = [
|
||||
BasicAVTransformerBlock(
|
||||
idx=idx,
|
||||
video=video_config,
|
||||
audio=audio_config,
|
||||
rope_type=config.rope_type,
|
||||
norm_eps=config.norm_eps,
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self,
|
||||
video: Optional[TransformerArgs],
|
||||
audio: Optional[TransformerArgs],
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Process through all transformer blocks."""
|
||||
for block in self.transformer_blocks:
|
||||
video, audio = block(video=video, audio=audio)
|
||||
return video, audio
|
||||
|
||||
def _process_output(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
norm_out: nn.LayerNorm,
|
||||
proj_out: nn.Linear,
|
||||
x: mx.array,
|
||||
embedded_timestep: mx.array,
|
||||
) -> mx.array:
|
||||
|
||||
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
|
||||
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
|
||||
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
|
||||
timestep_expanded = embedded_timestep[:, :, None, :] # (B, 1, 1, dim)
|
||||
|
||||
# Combine: (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim)
|
||||
scale_shift_values = table_expanded + timestep_expanded
|
||||
|
||||
# Extract shift and scale (first index is shift, second is scale)
|
||||
shift = scale_shift_values[:, :, 0, :] # (B, 1, dim)
|
||||
scale = scale_shift_values[:, :, 1, :] # (B, 1, dim)
|
||||
|
||||
x = norm_out(x)
|
||||
x = x * (1 + scale) + shift # Broadcasts (B, 1, dim) to (B, seq, dim)
|
||||
x = proj_out(x)
|
||||
|
||||
return x
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
# Validate inputs
|
||||
if not self.model_type.is_video_enabled() and video is not None:
|
||||
raise ValueError("Video is not enabled for this model")
|
||||
if not self.model_type.is_audio_enabled() and audio is not None:
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
# Preprocess arguments
|
||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
video=video_args,
|
||||
audio=audio_args,
|
||||
)
|
||||
|
||||
# Process outputs
|
||||
vx = (
|
||||
self._process_output(
|
||||
self.scale_shift_table,
|
||||
self.norm_out,
|
||||
self.proj_out,
|
||||
video_out.x,
|
||||
video_out.embedded_timestep,
|
||||
)
|
||||
if video_out is not None
|
||||
else None
|
||||
)
|
||||
|
||||
ax = (
|
||||
self._process_output(
|
||||
self.audio_scale_shift_table,
|
||||
self.audio_norm_out,
|
||||
self.audio_proj_out,
|
||||
audio_out.x,
|
||||
audio_out.embedded_timestep,
|
||||
)
|
||||
if audio_out is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return vx, ax
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle common remappings
|
||||
# transformer_blocks.X -> transformer_blocks[X]
|
||||
if "transformer_blocks." in new_key:
|
||||
# Keep as-is for now, MLX handles this
|
||||
pass
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
class X0Model(nn.Module):
|
||||
|
||||
def __init__(self, velocity_model: LTXModel):
|
||||
|
||||
super().__init__()
|
||||
self.velocity_model = velocity_model
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
vx, ax = self.velocity_model(video, audio)
|
||||
|
||||
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
|
||||
return denoised_video, denoised_audio
|
||||
508
mlx_video/models/ltx/rope.py
Normal file
508
mlx_video/models/ltx/rope.py
Normal file
@@ -0,0 +1,508 @@
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
input_tensor: mx.array,
|
||||
freqs_cis: Tuple[mx.array, mx.array],
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
) -> mx.array:
|
||||
"""Apply rotary position embeddings to input tensor.
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor to apply RoPE to
|
||||
freqs_cis: Tuple of (cos_freqs, sin_freqs)
|
||||
rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)
|
||||
|
||||
Returns:
|
||||
Tensor with rotary embeddings applied
|
||||
"""
|
||||
if rope_type == LTXRopeType.INTERLEAVED:
|
||||
return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||
elif rope_type == LTXRopeType.SPLIT:
|
||||
return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||
else:
|
||||
raise ValueError(f"Invalid rope type: {rope_type}")
|
||||
|
||||
|
||||
def apply_interleaved_rotary_emb(
|
||||
input_tensor: mx.array,
|
||||
cos_freqs: mx.array,
|
||||
sin_freqs: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply interleaved rotary embeddings.
|
||||
|
||||
Pairs adjacent dimensions and applies rotation.
|
||||
Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ...
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor of shape (..., dim)
|
||||
cos_freqs: Cosine frequencies
|
||||
sin_freqs: Sine frequencies
|
||||
|
||||
Returns:
|
||||
Tensor with interleaved rotary embeddings applied
|
||||
"""
|
||||
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
||||
shape = input_tensor.shape
|
||||
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
||||
|
||||
# Extract pairs
|
||||
t1 = input_tensor[..., 0] # Even indices
|
||||
t2 = input_tensor[..., 1] # Odd indices
|
||||
|
||||
# Apply rotation: (-t2, t1) pattern
|
||||
t_rot = mx.stack([-t2, t1], axis=-1)
|
||||
|
||||
# Flatten back: (..., dim/2, 2) -> (..., dim)
|
||||
input_tensor = mx.reshape(input_tensor, shape)
|
||||
t_rot = mx.reshape(t_rot, shape)
|
||||
|
||||
# Apply rotary embeddings
|
||||
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||
"""Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
|
||||
|
||||
PyTorch equivalent:
|
||||
t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
return rearrange(t_dup, "... d r -> ... (d r)")
|
||||
"""
|
||||
# x: (..., dim) where dim is even
|
||||
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||
return mx.reshape(rotated, x.shape)
|
||||
|
||||
def apply_rotary_emb_1d(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
freqs_cis: mx.array,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Apply 1D rotary embeddings using precomputed frequencies (interleaved)."""
|
||||
# freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin
|
||||
cos = freqs_cis[..., 0] # (1, seq_len, num_heads, head_dim)
|
||||
sin = freqs_cis[..., 1]
|
||||
|
||||
# q, k: (batch, seq_len, num_heads, head_dim)
|
||||
# Interleaved RoPE: pairs of adjacent dims rotate together
|
||||
q_r = q * cos + rotate_half_interleaved(q) * sin
|
||||
k_r = k * cos + rotate_half_interleaved(k) * sin
|
||||
|
||||
return q_r, k_r
|
||||
|
||||
|
||||
def apply_split_rotary_emb(
|
||||
input_tensor: mx.array,
|
||||
cos_freqs: mx.array,
|
||||
sin_freqs: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply split rotary embeddings.
|
||||
|
||||
Splits dimensions into two halves and applies rotation.
|
||||
Pattern: split into first half and second half
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor
|
||||
cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
|
||||
sin_freqs: Sine frequencies of shape (B, H, T, D//2)
|
||||
|
||||
Returns:
|
||||
Tensor with split rotary embeddings applied
|
||||
"""
|
||||
needs_reshape = False
|
||||
original_shape = input_tensor.shape
|
||||
|
||||
# Handle dimension mismatch
|
||||
if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
|
||||
b, h, t, _ = cos_freqs.shape
|
||||
# Reshape from (B, T, H*D) to (B, H, T, D)
|
||||
input_tensor = mx.reshape(input_tensor, (b, t, h, -1))
|
||||
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
||||
needs_reshape = True
|
||||
|
||||
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
||||
dim = input_tensor.shape[-1]
|
||||
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
||||
|
||||
# Get first and second halves
|
||||
first_half = split_input[..., 0, :] # (..., dim//2)
|
||||
second_half = split_input[..., 1, :] # (..., dim//2)
|
||||
|
||||
# Apply cosine to both halves
|
||||
output_first = first_half * cos_freqs
|
||||
output_second = second_half * cos_freqs
|
||||
|
||||
# Apply sine cross-terms (addcmul pattern)
|
||||
output_first = output_first - sin_freqs * second_half
|
||||
output_second = output_second + sin_freqs * first_half
|
||||
|
||||
# Stack back together
|
||||
output = mx.stack([output_first, output_second], axis=-2)
|
||||
|
||||
# Flatten: (..., 2, dim//2) -> (..., dim)
|
||||
output = mx.reshape(output, input_tensor.shape)
|
||||
|
||||
if needs_reshape:
|
||||
# Reshape back: (B, H, T, D) -> (B, T, H*D)
|
||||
b, h, t, d = output.shape
|
||||
output = mx.swapaxes(output, 1, 2)
|
||||
output = mx.reshape(output, (b, t, h * d))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def generate_freq_grid(
|
||||
positional_embedding_theta: float,
|
||||
positional_embedding_max_pos_count: int,
|
||||
inner_dim: int,
|
||||
) -> mx.array:
|
||||
"""Generate frequency grid for RoPE.
|
||||
|
||||
Args:
|
||||
positional_embedding_theta: Base theta value
|
||||
positional_embedding_max_pos_count: Number of position dimensions
|
||||
inner_dim: Inner dimension of the model
|
||||
|
||||
Returns:
|
||||
Frequency indices tensor
|
||||
"""
|
||||
theta = positional_embedding_theta
|
||||
start = 1.0
|
||||
end = theta
|
||||
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
|
||||
# Compute logarithmic spacing
|
||||
log_start = math.log(start) / math.log(theta)
|
||||
log_end = math.log(end) / math.log(theta)
|
||||
|
||||
num_indices = inner_dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
|
||||
# Create linearly spaced values in log space
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
|
||||
# Compute power indices
|
||||
pow_indices = mx.power(theta, lin_space)
|
||||
|
||||
# Scale by pi/2
|
||||
return pow_indices * (math.pi / 2)
|
||||
|
||||
|
||||
def get_fractional_positions(
|
||||
indices_grid: mx.array,
|
||||
max_pos: List[int],
|
||||
) -> mx.array:
|
||||
"""Convert indices to fractional positions.
|
||||
|
||||
Args:
|
||||
indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
|
||||
max_pos: Maximum position for each dimension
|
||||
|
||||
Returns:
|
||||
Fractional positions in range [-1, 1] after scaling
|
||||
"""
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), (
|
||||
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
)
|
||||
|
||||
# Divide each dimension by its max position
|
||||
fractional_positions = []
|
||||
for i in range(n_pos_dims):
|
||||
frac = indices_grid[:, i] / max_pos[i]
|
||||
fractional_positions.append(frac)
|
||||
|
||||
return mx.stack(fractional_positions, axis=-1)
|
||||
|
||||
|
||||
def generate_freqs(
|
||||
indices: mx.array,
|
||||
indices_grid: mx.array,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
) -> mx.array:
|
||||
"""Generate frequencies from indices and position grid.
|
||||
|
||||
Args:
|
||||
indices: Frequency indices
|
||||
indices_grid: Position indices grid
|
||||
max_pos: Maximum positions per dimension
|
||||
use_middle_indices_grid: Whether to use middle of index ranges
|
||||
|
||||
Returns:
|
||||
Frequency tensor
|
||||
"""
|
||||
# Handle middle indices grid
|
||||
if use_middle_indices_grid:
|
||||
# indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end]
|
||||
assert len(indices_grid.shape) == 4
|
||||
assert indices_grid.shape[-1] == 2
|
||||
indices_grid_start = indices_grid[..., 0]
|
||||
indices_grid_end = indices_grid[..., 1]
|
||||
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid.shape) == 4:
|
||||
indices_grid = indices_grid[..., 0]
|
||||
|
||||
# Get fractional positions
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
# Compute frequencies
|
||||
# fractional_positions: (B, T, n_dims)
|
||||
# indices: (inner_dim // n_elem,)
|
||||
# Result: (B, T, inner_dim // n_elem * n_dims)
|
||||
|
||||
# Scale fractional positions to [-1, 1]
|
||||
scaled_positions = fractional_positions * 2 - 1 # (B, T, n_dims)
|
||||
|
||||
# Outer product with indices
|
||||
# (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims(
|
||||
mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0
|
||||
)
|
||||
|
||||
# Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims)
|
||||
freqs = mx.swapaxes(freqs, -1, -2) # (B, T, n_indices, n_dims)
|
||||
freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,))
|
||||
|
||||
return freqs
|
||||
|
||||
|
||||
def split_freqs_cis(
|
||||
freqs: mx.array,
|
||||
pad_size: int,
|
||||
num_attention_heads: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Prepare cos/sin frequencies for split RoPE.
|
||||
|
||||
Args:
|
||||
freqs: Frequency tensor
|
||||
pad_size: Padding size for dimension alignment
|
||||
num_attention_heads: Number of attention heads
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
|
||||
"""
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
# Add padding if needed
|
||||
if pad_size != 0:
|
||||
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||
|
||||
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
|
||||
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
|
||||
|
||||
# Swap axes: (B, T, H, D//2) -> (B, H, T, D//2)
|
||||
cos_freq = mx.swapaxes(cos_freq, 1, 2)
|
||||
sin_freq = mx.swapaxes(sin_freq, 1, 2)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
|
||||
def interleaved_freqs_cis(
|
||||
freqs: mx.array,
|
||||
pad_size: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Prepare cos/sin frequencies for interleaved RoPE.
|
||||
|
||||
Args:
|
||||
freqs: Frequency tensor of shape (B, T, dim//2)
|
||||
pad_size: Padding size for dimension alignment
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
|
||||
"""
|
||||
# Compute cos and sin
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
# Repeat interleave: each element repeated twice
|
||||
# (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...]
|
||||
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
# Add padding if needed
|
||||
if pad_size != 0:
|
||||
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
|
||||
def precompute_freqs_cis(
|
||||
indices_grid: mx.array,
|
||||
dim: int,
|
||||
theta: float = 10000.0,
|
||||
max_pos: Optional[List[int]] = None,
|
||||
use_middle_indices_grid: bool = False,
|
||||
num_attention_heads: int = 32,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
double_precision: bool = False,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Precompute RoPE frequencies.
|
||||
|
||||
Args:
|
||||
indices_grid: Position indices grid
|
||||
dim: Dimension for RoPE
|
||||
theta: Base theta value for frequency computation
|
||||
max_pos: Maximum position per dimension
|
||||
use_middle_indices_grid: Whether to use middle indices
|
||||
num_attention_heads: Number of attention heads
|
||||
rope_type: Type of RoPE (INTERLEAVED or SPLIT)
|
||||
double_precision: If True, compute frequencies in float64 for higher precision
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freq, sin_freq) tensors
|
||||
"""
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
# For double precision, compute in numpy (float64) then convert back to MLX
|
||||
# MLX GPU doesn't support float64, so we use numpy for high precision computation
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
num_attention_heads, rope_type
|
||||
)
|
||||
|
||||
# Generate frequency indices
|
||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||
|
||||
# Generate frequencies
|
||||
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||
|
||||
# Prepare cos/sin based on rope type
|
||||
if rope_type == LTXRopeType.SPLIT:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||
else:
|
||||
# Interleaved
|
||||
n_elem = 2 * indices_grid.shape[1]
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
|
||||
def _precompute_freqs_cis_double_precision(
|
||||
indices_grid: mx.array,
|
||||
dim: int,
|
||||
theta: float,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies in double precision using numpy.
|
||||
|
||||
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
|
||||
"""
|
||||
# Convert to numpy float64
|
||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||
|
||||
# Generate frequency indices in float64
|
||||
n_pos_dims = indices_grid_np.shape[1]
|
||||
n_elem = 2 * n_pos_dims
|
||||
|
||||
# Compute log-spaced frequencies
|
||||
log_start = math.log(1.0) / math.log(theta)
|
||||
log_end = math.log(theta) / math.log(theta)
|
||||
num_indices = dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
lin_space = np.linspace(log_start, log_end, num_indices)
|
||||
indices_np = np.power(theta, lin_space) * (math.pi / 2)
|
||||
|
||||
# Handle middle indices grid
|
||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||
if use_middle_indices_grid:
|
||||
assert len(indices_grid_np.shape) == 4
|
||||
assert indices_grid_np.shape[-1] == 2
|
||||
indices_grid_start = indices_grid_np[..., 0]
|
||||
indices_grid_end = indices_grid_np[..., 1]
|
||||
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid_np.shape) == 4:
|
||||
indices_grid_np = indices_grid_np[..., 0]
|
||||
# After handling: indices_grid_np shape is (B, n_dims, T)
|
||||
|
||||
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
|
||||
batch_size = indices_grid_np.shape[0]
|
||||
seq_len = indices_grid_np.shape[2]
|
||||
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
|
||||
for i in range(n_pos_dims):
|
||||
# indices_grid_np[:, i, :] has shape (B, T)
|
||||
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
|
||||
|
||||
# Scale to [-1, 1]
|
||||
scaled_positions = fractional_positions * 2 - 1
|
||||
|
||||
# Compute frequencies: outer product
|
||||
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
|
||||
freqs = np.swapaxes(freqs, -1, -2)
|
||||
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
|
||||
|
||||
# Compute cos/sin in float64
|
||||
cos_freq = np.cos(freqs)
|
||||
sin_freq = np.sin(freqs)
|
||||
|
||||
# Prepare based on rope type
|
||||
if rope_type == LTXRopeType.SPLIT:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = cos_freq.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
|
||||
# Add padding
|
||||
if pad_size > 0:
|
||||
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
|
||||
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
|
||||
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
|
||||
cos_freq = np.swapaxes(cos_freq, 1, 2)
|
||||
sin_freq = np.swapaxes(sin_freq, 1, 2)
|
||||
else:
|
||||
# Interleaved
|
||||
cos_freq = np.repeat(cos_freq, 2, axis=-1)
|
||||
sin_freq = np.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
pad_size = dim % n_elem
|
||||
if pad_size > 0:
|
||||
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Convert back to MLX (float32 for GPU compatibility)
|
||||
cos_freq = mx.array(cos_freq.astype(np.float32))
|
||||
sin_freq = mx.array(sin_freq.astype(np.float32))
|
||||
|
||||
return cos_freq, sin_freq
|
||||
727
mlx_video/models/ltx/text_encoder.py
Normal file
727
mlx_video/models/ltx/text_encoder.py
Normal file
@@ -0,0 +1,727 @@
|
||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import rms_norm
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
|
||||
|
||||
@dataclass
|
||||
class Gemma3Config:
|
||||
"""Configuration for Gemma 3 text model."""
|
||||
hidden_size: int = 3840
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
head_dim: int = 256
|
||||
intermediate_size: int = 15360
|
||||
num_hidden_layers: int = 48
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
vocab_size: int = 262208
|
||||
max_position_embeddings: int = 131072
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""RMS Normalization (Gemma style with 1+weight scaling)."""
|
||||
|
||||
def __init__(self, dims: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
# Gemma initializes to ones, but uses (1+weight) scaling
|
||||
# After loading weights, weight will have the actual learned values
|
||||
self.weight = mx.ones((dims,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
positions: mx.array,
|
||||
head_dim: int,
|
||||
rope_theta: float = 1000000.0,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Apply rotary position embeddings to Q and K."""
|
||||
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
|
||||
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
|
||||
cos = mx.cos(freqs)
|
||||
sin = mx.sin(freqs)
|
||||
cos = cos[:, :, None, :]
|
||||
sin = sin[:, :, None, :]
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return mx.concatenate([-x2, x1], axis=-1)
|
||||
|
||||
cos_full = mx.concatenate([cos, cos], axis=-1)
|
||||
sin_full = mx.concatenate([sin, sin], axis=-1)
|
||||
q_embed = q * cos_full + rotate_half(q) * sin_full
|
||||
k_embed = k * cos_full + rotate_half(k) * sin_full
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
||||
|
||||
class Gemma3MLP(nn.Module):
|
||||
"""Gemma 3 MLP with gated activation."""
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gate = nn.gelu_approx(self.gate_proj(x))
|
||||
up = self.up_proj(x)
|
||||
return self.down_proj(gate * up)
|
||||
|
||||
|
||||
class Gemma3Attention(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.scale = 1.0 / math.sqrt(config.head_dim)
|
||||
|
||||
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
positions: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
|
||||
|
||||
q = mx.transpose(q, (0, 2, 1, 3))
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
# Create causal mask (lower triangular)
|
||||
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
|
||||
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
|
||||
out = mx.transpose(out, (0, 2, 1, 3))
|
||||
out = mx.reshape(out, (batch_size, seq_len, -1))
|
||||
|
||||
return self.o_proj(out)
|
||||
|
||||
|
||||
class Gemma3DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.self_attn = Gemma3Attention(config)
|
||||
self.mlp = Gemma3MLP(config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
positions: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Gemma3TextModel(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
# Gemma scales embeddings by sqrt(hidden_size)
|
||||
self.embed_scale = config.hidden_size ** 0.5
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
output_hidden_states: bool = True,
|
||||
) -> Tuple[mx.array, List[mx.array]]:
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# Gemma scales embeddings by sqrt(hidden_size)
|
||||
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
all_hidden_states = [hidden_states] if output_hidden_states else []
|
||||
|
||||
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
|
||||
positions = mx.broadcast_to(positions, (batch_size, seq_len))
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, positions, attention_mask)
|
||||
if output_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states, all_hidden_states
|
||||
|
||||
|
||||
|
||||
class ConnectorAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
num_heads: int = 30,
|
||||
head_dim: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
inner_dim = num_heads * head_dim
|
||||
self.scale = 1.0 / math.sqrt(head_dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
|
||||
|
||||
# Standard RMSNorm (not Gemma-style) on full inner_dim
|
||||
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||
self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
pe: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
# Project to Q, K, V
|
||||
q = self.to_q(x) # (B, seq, inner_dim)
|
||||
k = self.to_k(x)
|
||||
v = self.to_v(x)
|
||||
|
||||
# QK normalization on full inner_dim BEFORE reshape (matches PyTorch)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
|
||||
if pe is not None:
|
||||
# pe: (1, seq_len, num_heads, head_dim, 2)
|
||||
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
q, k = apply_rotary_emb_1d(q, k, pe)
|
||||
# Reshape back for attention computation
|
||||
q = mx.reshape(q, (batch_size, seq_len, -1))
|
||||
k = mx.reshape(k, (batch_size, seq_len, -1))
|
||||
|
||||
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
|
||||
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
|
||||
if attention_mask is not None:
|
||||
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||
|
||||
return self.to_out[0](out)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
"""GELU-gated linear unit."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return nn.gelu_approx(self.proj(x))
|
||||
|
||||
|
||||
class ConnectorFeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim * mult
|
||||
self.net = [
|
||||
GEGLU(dim, inner_dim),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias=True),
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
for layer in self.net:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConnectorTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128):
|
||||
super().__init__()
|
||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim)
|
||||
self.ff = ConnectorFeedForward(dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
pe: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
# Pre-norm + attention + residual
|
||||
norm_x = rms_norm(x)
|
||||
if norm_x.ndim == 4:
|
||||
norm_x = mx.squeeze(norm_x, axis=1)
|
||||
attn_out = self.attn1(norm_x, attention_mask, pe)
|
||||
x = x + attn_out
|
||||
if x.ndim == 4:
|
||||
x = mx.squeeze(x, axis=1)
|
||||
|
||||
# Pre-norm + FFN + residual
|
||||
norm_x = rms_norm(x)
|
||||
ff_out = self.ff(norm_x)
|
||||
x = x + ff_out
|
||||
if x.ndim == 4:
|
||||
x = mx.squeeze(x, axis=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
num_heads: int = 30,
|
||||
head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: int = 128,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
|
||||
self.transformer_1d_blocks = [
|
||||
ConnectorTransformerBlock(dim, num_heads, head_dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
||||
import math
|
||||
|
||||
dim = self.num_heads * self.head_dim
|
||||
theta = self.positional_embedding_theta
|
||||
n_elem = 2
|
||||
|
||||
|
||||
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
|
||||
indices = (theta ** linspace_vals) * (math.pi / 2)
|
||||
|
||||
positions = mx.arange(seq_len).astype(mx.float32)
|
||||
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
|
||||
|
||||
cos = mx.cos(freqs) # (seq_len, dim//2)
|
||||
sin = mx.sin(freqs)
|
||||
|
||||
|
||||
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
|
||||
return freqs_cis.astype(dtype)
|
||||
|
||||
def _replace_padded_with_registers(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: mx.array,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
batch_size, seq_len, dim = hidden_states.shape
|
||||
|
||||
# Binary mask: 1 for valid tokens, 0 for padded
|
||||
# attention_mask is additive: 0 for valid, large negative for padded
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||
|
||||
# Tile registers to match sequence length
|
||||
num_tiles = seq_len // self.num_learnable_registers
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
|
||||
|
||||
# Process each batch item (PyTorch uses advanced indexing)
|
||||
result_list = []
|
||||
for b in range(batch_size):
|
||||
mask_b = mask_binary[b] # (seq,)
|
||||
hs_b = hidden_states[b] # (seq, dim)
|
||||
|
||||
# Count valid tokens
|
||||
num_valid = int(mx.sum(mask_b))
|
||||
|
||||
# Extract valid tokens (where mask is 1)
|
||||
# Since we have left-padded input, valid tokens are at the end
|
||||
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
|
||||
|
||||
# Pad with zeros on the right to get back to seq_len
|
||||
pad_length = seq_len - num_valid
|
||||
if pad_length > 0:
|
||||
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
|
||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||
else:
|
||||
adjusted = valid_tokens
|
||||
|
||||
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
||||
flipped_mask = mx.concatenate([
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32)
|
||||
], axis=0) # (seq,)
|
||||
|
||||
# Combine: valid tokens at front, registers at back
|
||||
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
|
||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||
|
||||
result_list.append(combined)
|
||||
|
||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||
|
||||
# Reset attention mask to all zeros (no masking after register replacement)
|
||||
attention_mask = mx.zeros_like(attention_mask)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
# Replace padded tokens with learnable registers
|
||||
if self.num_learnable_registers > 0 and attention_mask is not None:
|
||||
hidden_states, attention_mask = self._replace_padded_with_registers(
|
||||
hidden_states, attention_mask
|
||||
)
|
||||
|
||||
# Compute RoPE frequencies
|
||||
seq_len = hidden_states.shape[1]
|
||||
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
|
||||
|
||||
# Process through transformer blocks
|
||||
for block in self.transformer_1d_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask, freqs_cis)
|
||||
|
||||
# Final RMS norm
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
|
||||
def norm_and_concat_hidden_states(
|
||||
hidden_states: List[mx.array],
|
||||
attention_mask: mx.array,
|
||||
padding_side: str = "left",
|
||||
) -> mx.array:
|
||||
|
||||
# Stack hidden states: (batch, seq, dim, num_layers)
|
||||
stacked = mx.stack(hidden_states, axis=-1)
|
||||
b, t, d, num_layers = stacked.shape
|
||||
|
||||
# Compute sequence lengths from attention mask
|
||||
sequence_lengths = mx.sum(attention_mask, axis=-1) # (batch,)
|
||||
|
||||
# Build mask based on padding side
|
||||
token_indices = mx.arange(t)[None, :] # (1, T)
|
||||
|
||||
if padding_side == "right":
|
||||
mask = token_indices < sequence_lengths[:, None] # (B, T)
|
||||
else: # left padding
|
||||
start_indices = t - sequence_lengths[:, None] # (B, 1)
|
||||
mask = token_indices >= start_indices # (B, T)
|
||||
|
||||
mask = mask[:, :, None, None] # (B, T, 1, 1)
|
||||
eps = 1e-6
|
||||
|
||||
# Compute masked mean per layer
|
||||
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
|
||||
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
large_val = 1e9
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
|
||||
# Normalize: 8 * (x - mean) / range
|
||||
normed = 8 * (stacked - mean) / (range_val + eps)
|
||||
|
||||
# Flatten layers into feature dimension: (B, T, D*L)
|
||||
normed = mx.reshape(normed, (b, t, -1))
|
||||
|
||||
# Zero out padded positions
|
||||
mask_flat = mx.broadcast_to(mask[:, :, :, 0], (b, t, d * num_layers))
|
||||
normed = mx.where(mask_flat, normed, mx.zeros_like(normed))
|
||||
|
||||
return normed
|
||||
|
||||
|
||||
class GemmaFeaturesExtractor(nn.Module):
|
||||
|
||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840):
|
||||
super().__init__()
|
||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.aggregate_embed(x)
|
||||
|
||||
|
||||
|
||||
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = None
|
||||
|
||||
if key.startswith("base_text_encoder.language_model."):
|
||||
new_key = key.replace("base_text_encoder.language_model.", "")
|
||||
elif key.startswith("language_model.model."):
|
||||
new_key = key.replace("language_model.model.", "")
|
||||
elif key.startswith("language_model."):
|
||||
new_key = key.replace("language_model.", "")
|
||||
else:
|
||||
continue
|
||||
|
||||
if new_key is None:
|
||||
continue
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
class LTX2TextEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "Lightricks/LTX-2",
|
||||
hidden_dim: int = 3840,
|
||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||
):
|
||||
super().__init__()
|
||||
self._model_path = model_path
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Gemma 3 model
|
||||
self.config = Gemma3Config()
|
||||
self.model = Gemma3TextModel(self.config)
|
||||
|
||||
# Feature extractor: 3840*49 -> 3840
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
input_dim=hidden_dim * num_layers,
|
||||
output_dim=hidden_dim,
|
||||
)
|
||||
|
||||
# Video embeddings connector: 2-layer transformer
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
)
|
||||
|
||||
self.processor = None
|
||||
|
||||
def load(self, model_path: Optional[str] = None):
|
||||
path = model_path or self._model_path
|
||||
|
||||
# Load Gemma weights from text_encoder subdirectory
|
||||
if Path(path).is_dir():
|
||||
text_encoder_path = Path(path) / "text_encoder"
|
||||
if text_encoder_path.exists():
|
||||
gemma_path = str(text_encoder_path)
|
||||
else:
|
||||
gemma_path = path
|
||||
else:
|
||||
gemma_path = path
|
||||
|
||||
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
|
||||
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
|
||||
all_weights = {}
|
||||
for i, wf in enumerate(weight_files):
|
||||
print(f" Loading weight file {i+1}/{len(weight_files)}...")
|
||||
weights = mx.load(str(wf))
|
||||
all_weights.update(weights)
|
||||
|
||||
# Sanitize and load Gemma weights
|
||||
sanitized = sanitize_gemma3_weights(all_weights)
|
||||
print(f" Sanitized Gemma weights: {len(sanitized)}")
|
||||
self.model.load_weights(list(sanitized.items()), strict=False)
|
||||
|
||||
# Load transformer weights for feature extractor and connector
|
||||
transformer_path = Path(model_path or self._model_path)
|
||||
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
|
||||
if transformer_files:
|
||||
print(f"Loading transformer weights for text pipeline...")
|
||||
transformer_weights = mx.load(str(transformer_files[0]))
|
||||
|
||||
# Load feature extractor (aggregate_embed)
|
||||
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
|
||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
||||
"text_embedding_projection.aggregate_embed.weight"
|
||||
]
|
||||
print(" Loaded aggregate_embed weights")
|
||||
|
||||
# Load video_embeddings_connector weights
|
||||
connector_weights = {}
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
||||
connector_weights[new_key] = value
|
||||
|
||||
if connector_weights:
|
||||
# Map weight names to our structure
|
||||
mapped_weights = {}
|
||||
for key, value in connector_weights.items():
|
||||
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
|
||||
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
|
||||
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
|
||||
mapped_weights[key] = value
|
||||
|
||||
self.video_embeddings_connector.load_weights(
|
||||
list(mapped_weights.items()), strict=False
|
||||
)
|
||||
print(f" Loaded {len(connector_weights)} connector weights")
|
||||
|
||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||
if "learnable_registers" in connector_weights:
|
||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
|
||||
if tokenizer_path.exists():
|
||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||
else:
|
||||
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
|
||||
# Set left padding to match official LTX-2 text encoder
|
||||
self.processor.padding_side = "left"
|
||||
|
||||
print("Text encoder loaded successfully")
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: str,
|
||||
max_length: int = 1024,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
if self.processor is None:
|
||||
raise RuntimeError("Model not loaded. Call load() first.")
|
||||
|
||||
# Tokenize with left padding (as in PyTorch version)
|
||||
inputs = self.processor(
|
||||
prompt,
|
||||
return_tensors="np",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
# Get all hidden states from Gemma
|
||||
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
||||
|
||||
# Normalize and concatenate all hidden states
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
)
|
||||
|
||||
# Project through feature extractor
|
||||
features = self.feature_extractor(concat_hidden)
|
||||
|
||||
# Convert attention mask to additive format for connector
|
||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
# Process through connector
|
||||
# Note: connector replaces padding with learnable registers and resets mask to zeros
|
||||
# This means all positions now have valid embeddings (no need for final masking)
|
||||
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||
|
||||
# Return embeddings without zeroing - the connector's register replacement
|
||||
# means all positions have meaningful values now
|
||||
return embeddings, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
max_length: int = 1024,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
return self.encode(prompt, max_length)
|
||||
|
||||
|
||||
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||
encoder = LTX2TextEncoder(model_path=model_path)
|
||||
encoder.load()
|
||||
return encoder
|
||||
|
||||
26
mlx_video/models/ltx/text_projection.py
Normal file
26
mlx_video/models/ltx/text_projection.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_size: int,
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.linear1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
359
mlx_video/models/ltx/transformer.py
Normal file
359
mlx_video/models/ltx/transformer.py
Normal file
@@ -0,0 +1,359 @@
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx.attention import Attention
|
||||
from mlx_video.models.ltx.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Modality:
|
||||
latent: mx.array
|
||||
timesteps: mx.array
|
||||
positions: mx.array
|
||||
context: mx.array
|
||||
enabled: bool = True
|
||||
context_mask: Optional[mx.array] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransformerArgs:
|
||||
x: mx.array
|
||||
context: mx.array
|
||||
context_mask: Optional[mx.array]
|
||||
timesteps: mx.array
|
||||
embedded_timestep: mx.array
|
||||
positional_embeddings: Tuple[mx.array, mx.array]
|
||||
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
|
||||
cross_scale_shift_timestep: Optional[mx.array]
|
||||
cross_gate_timestep: Optional[mx.array]
|
||||
enabled: bool
|
||||
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
"""Audio-Video transformer block with cross-modal attention.
|
||||
|
||||
Supports video-only, audio-only, or combined audio-video processing
|
||||
with bidirectional cross-attention between modalities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idx: int,
|
||||
video: Optional[TransformerConfig] = None,
|
||||
audio: Optional[TransformerConfig] = None,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
norm_eps: float = 1e-6,
|
||||
):
|
||||
"""Initialize transformer block.
|
||||
|
||||
Args:
|
||||
idx: Block index
|
||||
video: Video modality configuration
|
||||
audio: Audio modality configuration
|
||||
rope_type: Type of rotary position embedding
|
||||
norm_eps: Epsilon for normalization
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.idx = idx
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
# Video components
|
||||
if video is not None:
|
||||
self.attn1 = Attention(
|
||||
query_dim=video.dim,
|
||||
heads=video.heads,
|
||||
dim_head=video.d_head,
|
||||
context_dim=None, # Self-attention
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=video.dim,
|
||||
context_dim=video.context_dim,
|
||||
heads=video.heads,
|
||||
dim_head=video.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
self.ff = FeedForward(video.dim, dim_out=video.dim)
|
||||
# 6 scale-shift parameters: 3 for attention, 3 for MLP
|
||||
self.scale_shift_table = mx.zeros((6, video.dim))
|
||||
|
||||
# Audio components
|
||||
if audio is not None:
|
||||
self.audio_attn1 = Attention(
|
||||
query_dim=audio.dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
self.audio_attn2 = Attention(
|
||||
query_dim=audio.dim,
|
||||
context_dim=audio.context_dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
|
||||
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
|
||||
|
||||
# Cross-modal attention (when both video and audio are enabled)
|
||||
if audio is not None and video is not None:
|
||||
# Audio-to-Video: Q from video, K/V from audio
|
||||
self.audio_to_video_attn = Attention(
|
||||
query_dim=video.dim,
|
||||
context_dim=audio.dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
# Video-to-Audio: Q from audio, K/V from video
|
||||
self.video_to_audio_attn = Attention(
|
||||
query_dim=audio.dim,
|
||||
context_dim=video.dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
# Scale-shift tables for cross-attention
|
||||
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
|
||||
self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))
|
||||
|
||||
def get_ada_values(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
batch_size: int,
|
||||
timestep: mx.array,
|
||||
indices: slice,
|
||||
) -> Tuple[mx.array, ...]:
|
||||
"""Get adaptive normalization values from scale-shift table.
|
||||
|
||||
Args:
|
||||
scale_shift_table: Table of shape (num_params, dim)
|
||||
batch_size: Batch size
|
||||
timestep: Timestep embeddings of shape (B, 1, num_params * dim) or similar
|
||||
indices: Slice for which parameters to extract
|
||||
|
||||
Returns:
|
||||
Tuple of scale-shift values
|
||||
"""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# scale_shift_table[indices]: (num_selected, dim)
|
||||
# Add batch and sequence dimensions: (1, 1, num_selected, dim)
|
||||
table_slice = scale_shift_table[indices]
|
||||
table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)
|
||||
|
||||
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||
timestep_reshaped = mx.reshape(
|
||||
timestep,
|
||||
(batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
)
|
||||
|
||||
# Extract the relevant indices
|
||||
timestep_slice = timestep_reshaped[:, :, indices, :]
|
||||
|
||||
# Add table values to timestep
|
||||
ada_values = table_expanded + timestep_slice
|
||||
|
||||
# Unbind along the parameter dimension
|
||||
# Result: tuple of tensors, each of shape (B, seq, dim)
|
||||
num_sliced = ada_values.shape[2]
|
||||
result = tuple(ada_values[:, :, i, :] for i in range(num_sliced))
|
||||
|
||||
return result
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: mx.array,
|
||||
gate_timestep: mx.array,
|
||||
num_scale_shift_values: int = 4,
|
||||
) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
|
||||
"""Get adaptive values for cross-modal attention.
|
||||
|
||||
Args:
|
||||
scale_shift_table: Table with 5 parameters (4 scale-shift + 1 gate)
|
||||
batch_size: Batch size
|
||||
scale_shift_timestep: Timestep for scale-shift
|
||||
gate_timestep: Timestep for gating
|
||||
num_scale_shift_values: Number of scale-shift values (default 4)
|
||||
|
||||
Returns:
|
||||
Tuple of 5 tensors: (scale1, shift1, scale2, shift2, gate)
|
||||
"""
|
||||
# Get scale-shift values
|
||||
scale_shift_ada = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
slice(None, None),
|
||||
)
|
||||
|
||||
# Get gate values
|
||||
gate_ada = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
slice(None, None),
|
||||
)
|
||||
|
||||
# Squeeze the sequence dimension if it's 1
|
||||
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
||||
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
||||
|
||||
return (*scale_shift_squeezed, *gate_squeezed)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[TransformerArgs] = None,
|
||||
audio: Optional[TransformerArgs] = None,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Forward pass through transformer block.
|
||||
|
||||
Args:
|
||||
video: Video modality arguments
|
||||
audio: Audio modality arguments
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_video, updated_audio) TransformerArgs
|
||||
"""
|
||||
batch_size = video.x.shape[0] if video is not None else audio.x.shape[0]
|
||||
|
||||
vx = video.x if video is not None else None
|
||||
ax = audio.x if audio is not None else None
|
||||
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
context=video.context,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
|
||||
# Process audio self-attention and cross-attention with text
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
context=audio.context,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
|
||||
# Audio-Video cross-modal attention
|
||||
if run_a2v or run_v2a:
|
||||
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
||||
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
|
||||
|
||||
# Get adaptive values for audio cross-attention
|
||||
(
|
||||
scale_ca_audio_a2v,
|
||||
shift_ca_audio_a2v,
|
||||
scale_ca_audio_v2a,
|
||||
shift_ca_audio_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
)
|
||||
|
||||
# Get adaptive values for video cross-attention
|
||||
(
|
||||
scale_ca_video_a2v,
|
||||
shift_ca_video_a2v,
|
||||
scale_ca_video_v2a,
|
||||
shift_ca_video_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
)
|
||||
|
||||
# Audio-to-Video cross-attention
|
||||
if run_a2v:
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
|
||||
vx = vx + (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=video.cross_positional_embeddings,
|
||||
k_pe=audio.cross_positional_embeddings,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
# Video-to-Audio cross-attention
|
||||
if run_v2a:
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
||||
ax = ax + (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=audio.cross_positional_embeddings,
|
||||
k_pe=video.cross_positional_embeddings,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
# Process video feed-forward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
||||
)
|
||||
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||
|
||||
# Process audio feed-forward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
||||
)
|
||||
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
# Return updated TransformerArgs
|
||||
video_out = replace(video, x=vx) if video is not None else None
|
||||
audio_out = replace(audio, x=ax) if audio is not None else None
|
||||
|
||||
return video_out, audio_out
|
||||
364
mlx_video/models/ltx/upsampler.py
Normal file
364
mlx_video/models/ltx/upsampler.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class Conv3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
|
||||
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (N, D, H, W, C_in)
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (N, D', H', W', C_out)
|
||||
"""
|
||||
y = mx.conv3d(
|
||||
x,
|
||||
self.weight,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
)
|
||||
|
||||
if self.bias is not None:
|
||||
y = y + self.bias
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class GroupNorm3d(nn.Module):
|
||||
|
||||
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.num_groups = num_groups
|
||||
self.num_channels = num_channels
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((num_channels,))
|
||||
self.bias = mx.zeros((num_channels,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, D, H, W, C)
|
||||
n, d, h, w, c = x.shape
|
||||
|
||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
||||
|
||||
# Compute mean and var over spatial and channel group dims
|
||||
mean = mx.mean(x, axis=(1, 3), keepdims=True)
|
||||
var = mx.var(x, axis=(1, 3), keepdims=True)
|
||||
|
||||
# Normalize
|
||||
x = (x - mean) / mx.sqrt(var + self.eps)
|
||||
|
||||
# Reshape back
|
||||
x = mx.reshape(x, (n, d, h, w, c))
|
||||
|
||||
# Apply weight and bias
|
||||
x = x * self.weight + self.bias
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
"""Pixel shuffle for 2D spatial upsampling."""
|
||||
|
||||
def __init__(self, upscale_factor: int = 2):
|
||||
super().__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
|
||||
n, h, w, c = x.shape
|
||||
r = self.upscale_factor
|
||||
out_c = c // (r * r)
|
||||
|
||||
# Reshape: (N, H, W, out_c, r, r)
|
||||
x = mx.reshape(x, (n, h, w, out_c, r, r))
|
||||
|
||||
# Permute: (N, H, r, W, r, out_c)
|
||||
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
|
||||
|
||||
# Reshape: (N, H*r, W*r, out_c)
|
||||
x = mx.reshape(x, (n, h * r, w * r, out_c))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(nn.Module):
|
||||
|
||||
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
|
||||
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
|
||||
|
||||
# Blur kernel for antialiasing
|
||||
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
|
||||
|
||||
self.pixel_shuffle = PixelShuffle2D(2)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, D, H, W, C) - channels last 3D format
|
||||
|
||||
n, d, h, w, c = x.shape
|
||||
|
||||
# Process frame by frame
|
||||
# Reshape to (N*D, H, W, C) for 2D operations
|
||||
x = mx.reshape(x, (n * d, h, w, c))
|
||||
|
||||
# Apply 2D conv
|
||||
x = self.conv(x)
|
||||
|
||||
# Pixel shuffle for 2x upscaling
|
||||
x = self.pixel_shuffle(x)
|
||||
|
||||
# Reshape back to (N, D, H*2, W*2, C)
|
||||
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock3D(nn.Module):
|
||||
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.conv1 = Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.norm1 = GroupNorm3d(32, channels)
|
||||
self.conv2 = Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = GroupNorm3d(32, channels)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
residual = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = nn.silu(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
|
||||
# Activation AFTER residual addition
|
||||
x = nn.silu(x + residual)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LatentUpsampler(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
# Initial projection
|
||||
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||
|
||||
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
|
||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
|
||||
|
||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
|
||||
# Final projection
|
||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
|
||||
"""Upsample latents by 2x spatially.
|
||||
|
||||
Args:
|
||||
latent: Input tensor of shape (B, C, F, H, W) - channels first
|
||||
debug: If True, print intermediate values for debugging
|
||||
|
||||
Returns:
|
||||
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
|
||||
"""
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
if debug:
|
||||
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||
debug_stats("Input (channels first)", latent)
|
||||
|
||||
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
|
||||
x = mx.transpose(latent, (0, 2, 3, 4, 1))
|
||||
if debug:
|
||||
debug_stats("After transpose to channels-last", x)
|
||||
|
||||
# Initial conv
|
||||
x = self.initial_conv(x)
|
||||
if debug:
|
||||
debug_stats("After initial_conv", x)
|
||||
x = self.initial_norm(x)
|
||||
if debug:
|
||||
debug_stats("After initial_norm", x)
|
||||
x = nn.silu(x)
|
||||
if debug:
|
||||
debug_stats("After silu", x)
|
||||
|
||||
# Pre-upsample blocks
|
||||
for i in sorted(self.res_blocks.keys()):
|
||||
x = self.res_blocks[i](x)
|
||||
if debug:
|
||||
debug_stats(f"After res_blocks[{i}]", x)
|
||||
|
||||
# Upsample (2D spatial, frame-by-frame)
|
||||
x = self.upsampler(x)
|
||||
if debug:
|
||||
debug_stats("After upsampler (spatial 2x)", x)
|
||||
|
||||
# Post-upsample blocks
|
||||
for i in sorted(self.post_upsample_res_blocks.keys()):
|
||||
x = self.post_upsample_res_blocks[i](x)
|
||||
if debug:
|
||||
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
|
||||
|
||||
# Final conv
|
||||
x = self.final_conv(x)
|
||||
if debug:
|
||||
debug_stats("After final_conv", x)
|
||||
|
||||
# Convert back to channels first (B, C, F, H, W)
|
||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||
if debug:
|
||||
debug_stats("Output (channels first)", x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def upsample_latents(
|
||||
latent: mx.array,
|
||||
upsampler: LatentUpsampler,
|
||||
latent_mean: mx.array,
|
||||
latent_std: mx.array,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
# Un-normalize: latent * std + mean
|
||||
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
|
||||
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
|
||||
latent = latent * latent_std + latent_mean
|
||||
|
||||
# Upsample
|
||||
latent = upsampler(latent, debug=debug)
|
||||
|
||||
# Re-normalize: (latent - mean) / std
|
||||
latent = (latent - latent_mean) / latent_std
|
||||
|
||||
return latent
|
||||
|
||||
|
||||
def load_upsampler(weights_path: str) -> LatentUpsampler:
|
||||
"""Load upsampler from safetensors weights.
|
||||
|
||||
Args:
|
||||
weights_path: Path to upsampler weights file
|
||||
|
||||
Returns:
|
||||
Loaded LatentUpsampler model
|
||||
"""
|
||||
print(f"Loading spatial upsampler from {weights_path}...")
|
||||
raw_weights = mx.load(weights_path)
|
||||
|
||||
# Check weight shapes to determine mid_channels
|
||||
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
|
||||
sample_key = "res_blocks.0.conv1.weight"
|
||||
if sample_key in raw_weights:
|
||||
mid_channels = raw_weights[sample_key].shape[0]
|
||||
else:
|
||||
mid_channels = 1024 # default
|
||||
|
||||
print(f" Detected mid_channels: {mid_channels}")
|
||||
|
||||
# Create model
|
||||
upsampler = LatentUpsampler(
|
||||
in_channels=128,
|
||||
mid_channels=mid_channels,
|
||||
num_blocks_per_stage=4,
|
||||
)
|
||||
|
||||
# Sanitize weights - convert from PyTorch to MLX format
|
||||
sanitized = {}
|
||||
for key, value in raw_weights.items():
|
||||
new_key = key
|
||||
|
||||
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if "conv" in key and "weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in key and "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
|
||||
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
|
||||
if key.startswith("upsampler."):
|
||||
new_key = key # Keep as is for SpatialRationalResampler
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
# Load weights
|
||||
upsampler.load_weights(list(sanitized.items()), strict=False)
|
||||
|
||||
print(f" Loaded {len(sanitized)} weights")
|
||||
|
||||
return upsampler
|
||||
1
mlx_video/models/ltx/video_vae/__init__.py
Normal file
1
mlx_video/models/ltx/video_vae/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PaddingModeType(Enum):
|
||||
ZEROS = "zeros"
|
||||
REFLECT = "reflect"
|
||||
|
||||
|
||||
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
"""Apply reflect padding to spatial dimensions of a 5D tensor.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) - channels last
|
||||
pad_h: Padding for height dimension
|
||||
pad_w: Padding for width dimension
|
||||
|
||||
Returns:
|
||||
Padded tensor
|
||||
"""
|
||||
if pad_h == 0 and pad_w == 0:
|
||||
return x
|
||||
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_conv_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, ...]],
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
padding: Union[int, Tuple[int, ...], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
) -> nn.Module:
|
||||
|
||||
if dims == 2:
|
||||
return CausalConv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of dimensions: {dims}")
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride to tuples
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
# Calculate spatial padding (temporal is handled separately via frame replication)
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
self.spatial_padding = (height_pad, width_pad)
|
||||
|
||||
# Create the base convolution (without padding, we'll handle it manually)
|
||||
self.conv = nn.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0, # We handle padding manually
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
|
||||
use_causal = causal if causal is not None else self.causal
|
||||
|
||||
# Apply temporal padding via frame replication
|
||||
# Only apply if kernel_size > 1
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
pad_size = (self.time_kernel_size - 1) // 2
|
||||
if pad_size > 0:
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
|
||||
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
|
||||
|
||||
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 4, 1))
|
||||
|
||||
# Apply spatial padding
|
||||
pad_h, pad_w = self.spatial_padding
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
if self.spatial_padding_mode == PaddingModeType.REFLECT:
|
||||
# Use reflect padding for spatial dimensions
|
||||
x = reflect_pad_2d(x, pad_h, pad_w)
|
||||
else:
|
||||
# Use zero padding for spatial dimensions
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(0, 0), # D (temporal - already padded)
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
# Apply convolution with chunking for large tensors
|
||||
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
|
||||
x = self._chunked_conv3d(x)
|
||||
|
||||
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
|
||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv3d(self, x: mx.array) -> mx.array:
|
||||
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) in channels-last format
|
||||
|
||||
Returns:
|
||||
Output tensor after conv3d
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
if total_elements <= max_safe_elements:
|
||||
return self.conv(x)
|
||||
|
||||
elements_per_frame = h * w * c
|
||||
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
|
||||
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
|
||||
|
||||
kernel_t = self.time_kernel_size
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
out_idx = 0
|
||||
|
||||
# Process chunks
|
||||
in_start = 0
|
||||
while out_idx < expected_output_frames:
|
||||
remaining = expected_output_frames - out_idx
|
||||
out_frames_this_chunk = min(chunk_size, remaining)
|
||||
|
||||
in_frames_needed = out_frames_this_chunk + overlap
|
||||
in_end = min(in_start + in_frames_needed, d)
|
||||
|
||||
chunk = x[:, in_start:in_end, :, :, :]
|
||||
|
||||
chunk_out = self.conv(chunk)
|
||||
mx.eval(chunk_out)
|
||||
|
||||
outputs.append(chunk_out)
|
||||
|
||||
out_idx += chunk_out.shape[1]
|
||||
in_start += chunk_out.shape[1]
|
||||
|
||||
# Concatenate all chunks
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=1)
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""2D convolution with optional causal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize CausalConv2d."""
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
|
||||
# Calculate padding
|
||||
if isinstance(padding, str) and padding == "same":
|
||||
self.padding = (
|
||||
(kernel_size[0] - 1) // 2,
|
||||
(kernel_size[1] - 1) // 2,
|
||||
)
|
||||
elif isinstance(padding, int):
|
||||
self.padding = (padding, padding)
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
"""Forward pass."""
|
||||
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 1))
|
||||
|
||||
# Apply padding
|
||||
pad_h, pad_w = self.padding
|
||||
if pad_h != 0 or pad_w != 0:
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
# Transpose back: (B, H, W, C) -> (B, C, H, W)
|
||||
x = mx.transpose(x, (0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
524
mlx_video/models/ltx/video_vae/decoder.py
Normal file
524
mlx_video/models/ltx/video_vae/decoder.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""Video VAE Decoder for LTX-2 with timestep conditioning.
|
||||
|
||||
Architecture (from PyTorch weights):
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- pixel_norm + timestep modulation (last_scale_shift_table)
|
||||
- conv_out: 128 -> 48
|
||||
- unpatchify: 48 -> 3 with patch_size=4
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: mx.array,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = True,
|
||||
downscale_freq_shift: float = 0,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
) -> mx.array:
|
||||
"""Create sinusoidal timestep embeddings."""
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = mx.exp(exponent)
|
||||
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
|
||||
emb = scale * emb
|
||||
|
||||
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
|
||||
|
||||
if flip_sin_to_cos:
|
||||
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
|
||||
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
"""MLP for timestep embedding."""
|
||||
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
"""Combined timestep embedding (sinusoidal + MLP)."""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class ResnetBlock3DSimple(nn.Module):
|
||||
"""ResNet block with optional timestep conditioning.
|
||||
|
||||
Weight keys: conv1.conv, conv2.conv, scale_shift_table
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Nested conv structure to match PyTorch naming: conv1.conv.weight
|
||||
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = mx.zeros((4, channels))
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = False,
|
||||
timestep_embed: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# Block 1 with optional timestep conditioning
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
|
||||
scale1 = ada_values[:, 1]
|
||||
shift2 = ada_values[:, 2]
|
||||
scale2 = ada_values[:, 3]
|
||||
|
||||
x = x * (1 + scale1) + shift1
|
||||
|
||||
x = self.act(x)
|
||||
x = self.conv1(x, causal=causal)
|
||||
|
||||
# Block 2 with optional timestep conditioning
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
x = x * (1 + scale2) + shift2
|
||||
|
||||
x = self.act(x)
|
||||
x = self.conv2(x, causal=causal)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class ResBlockGroup(nn.Module):
|
||||
"""Group of ResNet blocks with shared timestep embedding.
|
||||
|
||||
PyTorch naming: res_blocks.0, res_blocks.1, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_layers: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
|
||||
self.res_blocks = [
|
||||
ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
timestep_embed = None
|
||||
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
for res_block in self.res_blocks:
|
||||
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2VideoDecoder(nn.Module):
|
||||
"""LTX-2 Video VAE Decoder with timestep conditioning.
|
||||
|
||||
Architecture:
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Upsampler 1024 -> 512
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Upsampler 512 -> 256
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Upsampler 256 -> 128
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
patch_size: int = 4,
|
||||
num_layers_per_block: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Decode parameters (configurable via constructor)
|
||||
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.latents_mean = mx.zeros((in_channels,))
|
||||
self.latents_std = mx.ones((in_channels,))
|
||||
|
||||
# Initial conv: 128 -> 1024
|
||||
class ConvInWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=1024,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||
|
||||
self.up_blocks = [
|
||||
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=1024,
|
||||
stride=(2, 2, 2),
|
||||
residual=True, # CRITICAL: Must match PyTorch config!
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=512,
|
||||
stride=(2, 2, 2),
|
||||
residual=True, # CRITICAL: Must match PyTorch config!
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
residual=True, # CRITICAL: Must match PyTorch config!
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
]
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=128,
|
||||
out_channels=final_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = mx.array(1000.0)
|
||||
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=128 * 2 # 256, matches (2, 128) table
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.reshape(1, -1, 1, 1, 1)
|
||||
return x * std + mean
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: mx.array,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
if debug:
|
||||
debug_stats("Input", sample)
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
if debug:
|
||||
debug_stats("After noise", sample)
|
||||
|
||||
if debug:
|
||||
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
||||
sample = self.denormalize(sample)
|
||||
if debug:
|
||||
debug_stats("After denormalize", sample)
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_in", x)
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
block_type = type(block).__name__
|
||||
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
if debug:
|
||||
debug_stats("After pixel_norm", x)
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
if debug:
|
||||
debug_stats("After timestep modulation", x)
|
||||
|
||||
x = self.act(x)
|
||||
if debug:
|
||||
debug_stats("After activation", x)
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_out", x)
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
if debug:
|
||||
debug_stats("After unpatchify", x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE decoder from {weights_path}...")
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.decoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
elif has_decoder_prefix:
|
||||
prefix = "decoder."
|
||||
stats_prefix = ""
|
||||
else:
|
||||
prefix = ""
|
||||
stats_prefix = ""
|
||||
|
||||
# Load per-channel statistics for denormalization
|
||||
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
||||
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
||||
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
||||
|
||||
if mean_key in weights:
|
||||
decoder.latents_mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
||||
if std_key in weights:
|
||||
decoder.latents_std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
||||
|
||||
# Build decoder weights dict with key remapping
|
||||
decoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
decoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(decoder_weights)} decoder weights")
|
||||
|
||||
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
||||
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
||||
|
||||
# Load weights
|
||||
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE decoder loaded successfully")
|
||||
return decoder
|
||||
120
mlx_video/models/ltx/video_vae/ops.py
Normal file
120
mlx_video/models/ltx/video_vae/ops.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||
"""Convert video to patches.
|
||||
|
||||
Moves spatial pixels from H, W dimensions to channel dimension.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, F, H, W)
|
||||
patch_size_hw: Spatial patch size
|
||||
patch_size_t: Temporal patch size
|
||||
|
||||
Returns:
|
||||
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
|
||||
"""
|
||||
b, c, f, h, w = x.shape
|
||||
|
||||
# Check dimensions are divisible
|
||||
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
|
||||
assert f % patch_size_t == 0
|
||||
|
||||
# New dimensions
|
||||
new_h = h // patch_size_hw
|
||||
new_w = w // patch_size_hw
|
||||
new_f = f // patch_size_t
|
||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
|
||||
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||
"""Convert patches back to video.
|
||||
|
||||
Inverse of patchify - moves pixels from channel dimension back to spatial.
|
||||
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
|
||||
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
|
||||
|
||||
Args:
|
||||
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
|
||||
patch_size_hw: Spatial patch size
|
||||
patch_size_t: Temporal patch size
|
||||
|
||||
Returns:
|
||||
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
|
||||
"""
|
||||
b, c_packed, f, h, w = x.shape
|
||||
|
||||
# Calculate original channel count
|
||||
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
|
||||
|
||||
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
|
||||
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
|
||||
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
|
||||
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
|
||||
|
||||
# Permute to interleave patches with spatial dims:
|
||||
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
|
||||
|
||||
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
|
||||
|
||||
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
|
||||
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PerChannelStatistics(nn.Module):
|
||||
|
||||
def __init__(self, latent_channels: int = 128):
|
||||
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
|
||||
# Learnable per-channel mean and std
|
||||
self.mean = mx.zeros((latent_channels,))
|
||||
self.std = mx.ones((latent_channels,))
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latents using per-channel statistics.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, ...)
|
||||
|
||||
Returns:
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return (x - mean) / std
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics.
|
||||
|
||||
Args:
|
||||
x: Normalized tensor of shape (B, C, ...)
|
||||
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return x * std + mean
|
||||
171
mlx_video/models/ltx/video_vae/resnet.py
Normal file
171
mlx_video/models/ltx/video_vae/resnet.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""ResNet blocks for Video VAE."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class NormLayerType(Enum):
|
||||
GROUP_NORM = "group_norm"
|
||||
PIXEL_NORM = "pixel_norm"
|
||||
|
||||
|
||||
def get_norm_layer(
|
||||
norm_type: NormLayerType,
|
||||
num_channels: int,
|
||||
num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
) -> nn.Module:
|
||||
|
||||
if norm_type == NormLayerType.GROUP_NORM:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
|
||||
elif norm_type == NormLayerType.PIXEL_NORM:
|
||||
return PixelNorm(eps=eps)
|
||||
else:
|
||||
raise ValueError(f"Unknown norm type: {norm_type}")
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
eps: float = 1e-6,
|
||||
groups: int = 32,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.inject_noise = inject_noise
|
||||
|
||||
# First normalization and convolution
|
||||
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
|
||||
self.conv1 = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Second normalization and convolution
|
||||
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
|
||||
self.conv2 = CausalConv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Shortcut connection if channels change
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
# Activation
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = True,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
residual = x
|
||||
|
||||
# First block
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = self.conv1(x, causal=causal)
|
||||
|
||||
# Inject noise if enabled
|
||||
if self.inject_noise and generator is not None:
|
||||
noise = mx.random.normal(x.shape)
|
||||
x = x + noise * 0.01
|
||||
|
||||
# Second block
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = self.conv2(x, causal=causal)
|
||||
|
||||
# Shortcut
|
||||
if self.shortcut is not None:
|
||||
residual = self.shortcut(residual, causal=causal)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Create ResNet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x, causal=causal, generator=generator)
|
||||
|
||||
return x
|
||||
173
mlx_video/models/ltx/video_vae/sampling.py
Normal file
173
mlx_video/models/ltx/video_vae/sampling.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Sampling operations for Video VAE (upsampling/downsampling)."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
|
||||
# Calculate the multiplier for channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
intermediate_channels = in_channels * multiplier
|
||||
|
||||
# 1x1x1 convolution to adjust channels
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=intermediate_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Pad if necessary to make dimensions divisible by stride
|
||||
pad_d = (st - d % st) % st
|
||||
pad_h = (sh - h % sh) % sh
|
||||
pad_w = (sw - w % sw) % sw
|
||||
|
||||
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
||||
# For causal, pad at the end of temporal dimension
|
||||
if causal:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
|
||||
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# Reshape to group spatial elements
|
||||
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
|
||||
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||
|
||||
# Permute to move stride elements to channel dim
|
||||
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape to combine channels
|
||||
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
|
||||
new_c = c * st * sh * sw
|
||||
new_d = d // st
|
||||
new_h = h // sh
|
||||
new_w = w // sw
|
||||
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||
|
||||
# Apply 1x1 conv to adjust channels
|
||||
x = self.conv(x, causal=causal)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
residual: bool = False,
|
||||
out_channels_reduction_factor: int = 1,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
|
||||
# Calculate output channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
out_channels = in_channels // out_channels_reduction_factor
|
||||
self.out_channels = out_channels
|
||||
|
||||
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * multiplier,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def _depth_to_space(self, x: mx.array) -> mx.array:
|
||||
b, c_packed, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
c = c_packed // (st * sh * sw)
|
||||
|
||||
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
|
||||
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
|
||||
|
||||
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
|
||||
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
|
||||
|
||||
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
|
||||
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Compute residual path if enabled
|
||||
x_residual = None
|
||||
if self.residual:
|
||||
# Reshape input: treat channels as spatial factors
|
||||
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
|
||||
x_residual = self._depth_to_space(x)
|
||||
|
||||
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
|
||||
# num_repeat = prod(stride) / out_channels_reduction_factor
|
||||
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
|
||||
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
|
||||
|
||||
# Remove first temporal frame if temporal upsampling
|
||||
if st > 1:
|
||||
x_residual = x_residual[:, :, 1:, :, :]
|
||||
|
||||
# Apply conv
|
||||
x = self.conv(x, causal=causal)
|
||||
|
||||
# Depth to space rearrangement
|
||||
x = self._depth_to_space(x)
|
||||
|
||||
# Remove first frame for causal temporal upsampling
|
||||
if st > 1:
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
# Add residual
|
||||
if self.residual and x_residual is not None:
|
||||
x = x + x_residual
|
||||
|
||||
return x
|
||||
528
mlx_video/models/ltx/video_vae/video_vae.py
Normal file
528
mlx_video/models/ltx/video_vae/video_vae.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Video VAE Encoder and Decoder for LTX-2."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
SpaceToDepthDownsample,
|
||||
)
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
def _make_encoder_block(
|
||||
block_name: str,
|
||||
block_config: Dict[str, Any],
|
||||
in_channels: int,
|
||||
convolution_dimensions: int,
|
||||
norm_layer: NormLayerType,
|
||||
norm_num_groups: int,
|
||||
spatial_padding_mode: PaddingModeType,
|
||||
) -> Tuple[nn.Module, int]:
|
||||
"""Create an encoder block.
|
||||
|
||||
Args:
|
||||
block_name: Type of block
|
||||
block_config: Block configuration
|
||||
in_channels: Input channels
|
||||
convolution_dimensions: Number of dimensions
|
||||
norm_layer: Normalization layer type
|
||||
norm_num_groups: Number of groups for group norm
|
||||
spatial_padding_mode: Padding mode
|
||||
|
||||
Returns:
|
||||
Tuple of (block, output_channels)
|
||||
"""
|
||||
out_channels = in_channels
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
num_layers=block_config["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder block: {block_name}")
|
||||
|
||||
return block, out_channels
|
||||
|
||||
|
||||
def _make_decoder_block(
|
||||
block_name: str,
|
||||
block_config: Dict[str, Any],
|
||||
in_channels: int,
|
||||
convolution_dimensions: int,
|
||||
norm_layer: NormLayerType,
|
||||
timestep_conditioning: bool,
|
||||
norm_num_groups: int,
|
||||
spatial_padding_mode: PaddingModeType,
|
||||
) -> Tuple[nn.Module, int]:
|
||||
"""Create a decoder block."""
|
||||
out_channels = in_channels
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
num_layers=block_config["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_config.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
out_channels = in_channels // block_config.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_config.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_config.get("residual", False),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown decoder block: {block_name}")
|
||||
|
||||
return block, out_channels
|
||||
|
||||
|
||||
class VideoEncoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
convolution_dimensions: int = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
encoder_blocks: List[Tuple[str, Any]] = None,
|
||||
patch_size: int = 4,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize VideoEncoder.
|
||||
|
||||
Args:
|
||||
convolution_dimensions: Number of dimensions (3 for video)
|
||||
in_channels: Input channels (3 for RGB)
|
||||
out_channels: Output latent channels
|
||||
encoder_blocks: List of (block_name, config) tuples
|
||||
patch_size: Spatial patch size
|
||||
norm_layer: Normalization layer type
|
||||
latent_log_var: Log variance mode
|
||||
encoder_spatial_padding_mode: Padding mode
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if encoder_blocks is None:
|
||||
encoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = in_channels * patch_size ** 2
|
||||
feature_channels = out_channels
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feature_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks
|
||||
self.down_blocks = []
|
||||
for block_name, block_params in encoder_blocks:
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=convolution_dimensions,
|
||||
norm_layer=norm_layer,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks.append(block)
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
# Calculate output convolution channels
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
in_channels=feature_channels,
|
||||
out_channels=conv_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""Encode video to latent representation.
|
||||
|
||||
Args:
|
||||
sample: Input video of shape (B, C, F, H, W).
|
||||
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
|
||||
|
||||
Returns:
|
||||
Normalized latent means of shape (B, 128, F', H', W')
|
||||
"""
|
||||
# Validate frame count
|
||||
frames_count = sample.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError(
|
||||
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
|
||||
)
|
||||
|
||||
# Initial patchify
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for down_block in self.down_blocks:
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
sample = down_block(sample, causal=True)
|
||||
|
||||
# Output processing
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=True)
|
||||
|
||||
# Handle log variance modes
|
||||
if self.latent_log_var == LogVarianceType.UNIFORM:
|
||||
means = sample[:, :-1, ...]
|
||||
logvar = sample[:, -1:, ...]
|
||||
num_channels = means.shape[1]
|
||||
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
|
||||
sample = mx.concatenate([means, repeated_logvar], axis=1)
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
|
||||
class VideoDecoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
convolution_dimensions: int = 3,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
decoder_blocks: List[Tuple[str, Any]] = None,
|
||||
patch_size: int = 4,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
causal: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
):
|
||||
"""Initialize VideoDecoder.
|
||||
|
||||
Args:
|
||||
convolution_dimensions: Number of dimensions
|
||||
in_channels: Input latent channels
|
||||
out_channels: Output channels (3 for RGB)
|
||||
decoder_blocks: List of (block_name, config) tuples
|
||||
patch_size: Spatial patch size
|
||||
norm_layer: Normalization layer type
|
||||
causal: Whether to use causal convolutions
|
||||
timestep_conditioning: Whether to use timestep conditioning
|
||||
decoder_spatial_padding_mode: Padding mode
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if decoder_blocks is None:
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Noise and timestep parameters
|
||||
self.decode_noise_scale = 0.025
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Compute initial feature channels
|
||||
feature_channels = in_channels
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feature_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build decoder blocks (reversed order)
|
||||
self.up_blocks = []
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=convolution_dimensions,
|
||||
norm_layer=norm_layer,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
self.up_blocks.append(block)
|
||||
|
||||
# Output normalization
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(
|
||||
in_channels=feature_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: mx.array,
|
||||
timestep: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
Args:
|
||||
sample: Latent tensor of shape (B, 128, F', H', W')
|
||||
timestep: Optional timestep for conditioning
|
||||
|
||||
Returns:
|
||||
Decoded video of shape (B, 3, F, H, W)
|
||||
"""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
# Denormalize latents
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
# Use default timestep if not provided
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
# Initial convolution
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
# Process through decoder blocks
|
||||
for up_block in self.up_blocks:
|
||||
if isinstance(up_block, UNetMidBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
elif isinstance(up_block, ResnetBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
else:
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
|
||||
# Output processing
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
# Unpatchify to restore spatial resolution
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
Reference in New Issue
Block a user