Files
mlx-video/mlx_video/models/ltx/adaln.py
2026-01-11 23:48:33 +01:00

162 lines
4.9 KiB
Python

from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
    """Produce modulation parameters from a timestep embedding.

    The timestep is embedded by ``PixArtAlphaCombinedTimestepSizeEmbeddings``,
    passed through SiLU, and linearly projected to
    ``embedding_coefficient * embedding_dim`` features.
    """

    def __init__(
        self,
        embedding_dim: int,
        embedding_coefficient: int = 6,
        use_additional_conditions: bool = False,
    ):
        """Build the embedder, activation and output projection.

        Args:
            embedding_dim: Width of the timestep embedding.
            embedding_coefficient: Multiplier applied to ``embedding_dim`` to
                obtain the output width of the projection.
            use_additional_conditions: When true, the embedder is configured
                with a size-embedding width of ``embedding_dim // 3``;
                otherwise that width is 0.
        """
        super().__init__()

        if use_additional_conditions:
            size_emb_dim = embedding_dim // 3
        else:
            size_emb_dim = 0

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim=embedding_dim,
            size_emb_dim=size_emb_dim,
            use_additional_conditions=use_additional_conditions,
        )
        self.silu = nn.SiLU()

        projected_width = embedding_coefficient * embedding_dim
        self.linear = nn.Linear(embedding_dim, projected_width, bias=True)

    def __call__(
        self,
        timestep: mx.array,
        added_cond_kwargs: dict | None = None,
        batch_size: int | None = None,
        hidden_dtype: mx.Dtype | None = None,
    ) -> Tuple[mx.array, mx.array]:
        """Embed ``timestep`` and project it.

        Args:
            timestep: Timestep values handed to the embedder.
            added_cond_kwargs: Optional extra keyword arguments forwarded to
                the embedder; ``None`` or an empty dict forwards nothing.
            batch_size: Forwarded unchanged to the embedder.
            hidden_dtype: Forwarded unchanged to the embedder.

        Returns:
            A pair ``(projected, embedded_timestep)``: the output of the
            linear projection and the raw embedding it was computed from.
        """
        extra_conditions = added_cond_kwargs if added_cond_kwargs else {}

        embedded_timestep = self.emb(
            timestep,
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
            **extra_conditions,
        )

        activated = self.silu(embedded_timestep)
        projected = self.linear(activated)
        return projected, embedded_timestep
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """Timestep embedding with an optional additive size-condition term.

    The timestep is projected by ``Timesteps`` and embedded by
    ``TimestepEmbedding``. If additional conditions are enabled with a
    positive ``size_emb_dim``, a ``ConditionEmbedding`` of the resolution and
    aspect ratio is added to the result.
    """

    def __init__(
        self,
        embedding_dim: int,
        size_emb_dim: int = 0,
        use_additional_conditions: bool = False,
        timestep_proj_dim: int = 256,
    ):
        """Create the projection and embedding sub-modules.

        Args:
            embedding_dim: Output width of the timestep embedding.
            size_emb_dim: Input width for the additional-condition embedder;
                the embedder is only created when this is positive.
            use_additional_conditions: Enables the additional-condition
                embedder (together with a positive ``size_emb_dim``).
            timestep_proj_dim: Channel count of the timestep projection.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.size_emb_dim = size_emb_dim
        self.use_additional_conditions = use_additional_conditions

        self.time_proj = Timesteps(
            num_channels=timestep_proj_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=timestep_proj_dim,
            time_embed_dim=embedding_dim,
            out_dim=embedding_dim,
        )

        # The condition embedder exists only when it can actually be used.
        if use_additional_conditions and size_emb_dim > 0:
            self.additional_embedder = ConditionEmbedding(
                size_emb_dim=size_emb_dim,
                embedding_dim=embedding_dim,
            )

    def __call__(
        self,
        timestep: mx.array,
        resolution: mx.array | None = None,
        aspect_ratio: mx.array | None = None,
        batch_size: int | None = None,
        hidden_dtype: mx.Dtype | None = None,
    ) -> mx.array:
        """Compute the timestep embedding, adding size conditions if present.

        Args:
            timestep: Timestep values to embed.
            resolution: Optional resolution conditioning input.
            aspect_ratio: Optional aspect-ratio conditioning input.
            batch_size: Accepted for interface compatibility; not read here.
            hidden_dtype: If given, the projected timestep is cast to this
                dtype before embedding, and it is passed on to the
                additional-condition embedder.

        Returns:
            The timestep embedding, plus the additional-condition embedding
            when that path is enabled and both conditioning inputs are given.
        """
        projected = self.time_proj(timestep)
        if hidden_dtype is not None:
            projected = projected.astype(hidden_dtype)
        embedding = self.timestep_embedder(projected)

        # Size conditioning is disabled for this module.
        if not (self.use_additional_conditions and self.size_emb_dim > 0):
            return embedding

        # Both conditioning inputs are required; otherwise skip the extra term.
        if resolution is None or aspect_ratio is None:
            return embedding

        size_term = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
        return embedding + size_term
class Timesteps(nn.Module):
    """Parameter-free wrapper around ``get_timestep_embedding``.

    Stores the embedding configuration at construction time and applies it
    on every call.
    """

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool = False,
        downscale_freq_shift: float = 1.0,
    ):
        """Record the configuration forwarded to ``get_timestep_embedding``.

        Args:
            num_channels: Passed as the second positional argument.
            flip_sin_to_cos: Passed as the ``flip_sin_to_cos`` keyword.
            downscale_freq_shift: Passed as the ``downscale_freq_shift``
                keyword.
        """
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def __call__(self, timesteps: mx.array) -> mx.array:
        """Return ``get_timestep_embedding`` of ``timesteps`` with the stored settings."""
        projection = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return projection
class TimestepEmbedding(nn.Module):
    """Two-layer MLP: linear, activation, linear."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int | None = None,
    ):
        """Create the two linear layers and the activation between them.

        Args:
            in_channels: Input width of the first linear layer.
            time_embed_dim: Output width of the first linear layer.
            act_fn: ``"silu"`` selects SiLU; any other value selects GELU.
            out_dim: Output width of the second linear layer; a falsy value
                (``None`` or 0) falls back to ``time_embed_dim``.
        """
        super().__init__()

        if out_dim:
            final_dim = out_dim
        else:
            final_dim = time_embed_dim

        self.linear1 = nn.Linear(in_channels, time_embed_dim)

        if act_fn == "silu":
            self.act = nn.SiLU()
        else:
            self.act = nn.GELU()

        self.linear2 = nn.Linear(time_embed_dim, final_dim)

    def __call__(self, sample: mx.array) -> mx.array:
        """Apply ``linear1``, the activation, then ``linear2`` to ``sample``."""
        return self.linear2(self.act(self.linear1(sample)))
class ConditionEmbedding(nn.Module):
    """Sum of separate embeddings for resolution and aspect ratio."""

    def __init__(self, size_emb_dim: int, embedding_dim: int):
        """Create one ``TimestepEmbedding`` per conditioning input.

        Args:
            size_emb_dim: Input width of each embedder.
            embedding_dim: Hidden and output width of each embedder.
        """
        super().__init__()
        self.resolution_embedder = TimestepEmbedding(
            in_channels=size_emb_dim,
            time_embed_dim=embedding_dim,
        )
        self.aspect_ratio_embedder = TimestepEmbedding(
            in_channels=size_emb_dim,
            time_embed_dim=embedding_dim,
        )

    def __call__(
        self,
        resolution: mx.array,
        aspect_ratio: mx.array,
        hidden_dtype: mx.Dtype | None = None,
    ) -> mx.array:
        """Embed both inputs and return their sum.

        Args:
            resolution: Input to the resolution embedder.
            aspect_ratio: Input to the aspect-ratio embedder.
            hidden_dtype: If given, each embedding is cast to this dtype
                before the two are added.

        Returns:
            The element-wise sum of the two embeddings.
        """
        res_embedding = self.resolution_embedder(resolution)
        ratio_embedding = self.aspect_ratio_embedder(aspect_ratio)

        if hidden_dtype is None:
            return res_embedding + ratio_embedding

        # Cast each term first so the addition happens in the requested dtype.
        return res_embedding.astype(hidden_dtype) + ratio_embedding.astype(hidden_dtype)