Refactor Wan model structure by renaming and relocating model imports from model.py to wan2.py, enhancing code organization and clarity across the Wan2 module.
This commit is contained in:
724
mlx_video/models/ltx_2/ltx_2.py
Normal file
724
mlx_video/models/ltx_2/ltx_2.py
Normal file
@@ -0,0 +1,724 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
LTXRopeType,
|
||||
)
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx_2.transformer import (
|
||||
BasicAVTransformerBlock,
|
||||
Modality,
|
||||
TransformerArgs,
|
||||
)
|
||||
from mlx_video.utils import to_denoised
|
||||
|
||||
|
||||
class TransformerArgsPreprocessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: Optional[PixArtAlphaTextProjection],
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
use_middle_indices_grid: bool,
|
||||
timestep_scale_multiplier: int,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
double_precision_rope: bool = False,
|
||||
prompt_adaln: Optional[AdaLayerNormSingle] = None,
|
||||
):
|
||||
self.patchify_proj = patchify_proj
|
||||
self.adaln = adaln
|
||||
self.caption_projection = caption_projection
|
||||
self.prompt_adaln = prompt_adaln
|
||||
self.inner_dim = inner_dim
|
||||
self.max_pos = max_pos
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
|
||||
def _prepare_timestep(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(
|
||||
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||
)
|
||||
embedded_timestep = mx.reshape(
|
||||
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||
)
|
||||
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_timestep_with_adaln(
|
||||
self,
|
||||
adaln: AdaLayerNormSingle,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
timestep_emb = mx.reshape(
|
||||
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||
)
|
||||
embedded_timestep = mx.reshape(
|
||||
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||
)
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_context(
|
||||
self,
|
||||
context: mx.array,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||
return context, attention_mask
|
||||
|
||||
def _prepare_attention_mask(
|
||||
self,
|
||||
attention_mask: Optional[mx.array],
|
||||
x_dtype: mx.Dtype,
|
||||
) -> Optional[mx.array]:
|
||||
if attention_mask is None:
|
||||
return None
|
||||
|
||||
# Check if already float
|
||||
if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
|
||||
return attention_mask
|
||||
|
||||
# Convert boolean/int mask to float mask
|
||||
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||
mask = mx.reshape(
|
||||
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
)
|
||||
return mask
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
self,
|
||||
positions: mx.array,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
num_attention_heads: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
pe = precompute_freqs_cis(
|
||||
positions,
|
||||
dim=inner_dim,
|
||||
theta=self.positional_embedding_theta,
|
||||
max_pos=max_pos,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
num_attention_heads=num_attention_heads,
|
||||
rope_type=self.rope_type,
|
||||
double_precision=self.double_precision_rope,
|
||||
)
|
||||
return pe
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(
|
||||
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
|
||||
)
|
||||
context, attention_mask = self._prepare_context(
|
||||
modality.context, x, modality.context_mask
|
||||
)
|
||||
attention_mask = self._prepare_attention_mask(
|
||||
attention_mask, modality.latent.dtype
|
||||
)
|
||||
|
||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||
if modality.positional_embeddings is not None:
|
||||
pe = modality.positional_embeddings
|
||||
else:
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
|
||||
# Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
|
||||
prompt_timestep = None
|
||||
prompt_embedded_timestep = None
|
||||
if self.prompt_adaln is not None and modality.sigma is not None:
|
||||
prompt_timestep, prompt_embedded_timestep = (
|
||||
self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln,
|
||||
modality.sigma,
|
||||
x.shape[0],
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
context=context,
|
||||
context_mask=attention_mask,
|
||||
timesteps=timestep,
|
||||
embedded_timestep=embedded_timestep,
|
||||
positional_embeddings=pe,
|
||||
cross_positional_embeddings=None,
|
||||
cross_scale_shift_timestep=None,
|
||||
cross_gate_timestep=None,
|
||||
enabled=modality.enabled,
|
||||
prompt_timesteps=prompt_timestep,
|
||||
prompt_embedded_timestep=prompt_embedded_timestep,
|
||||
)
|
||||
|
||||
|
||||
class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: Optional[PixArtAlphaTextProjection],
|
||||
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||
cross_gate_adaln: AdaLayerNormSingle,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
cross_pe_max_pos: int,
|
||||
use_middle_indices_grid: bool,
|
||||
audio_cross_attention_dim: int,
|
||||
timestep_scale_multiplier: int,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
av_ca_timestep_scale_multiplier: int,
|
||||
double_precision_rope: bool = False,
|
||||
prompt_adaln: Optional[AdaLayerNormSingle] = None,
|
||||
):
|
||||
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=patchify_proj,
|
||||
adaln=adaln,
|
||||
caption_projection=caption_projection,
|
||||
inner_dim=inner_dim,
|
||||
max_pos=max_pos,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
rope_type=rope_type,
|
||||
double_precision_rope=double_precision_rope,
|
||||
prompt_adaln=prompt_adaln,
|
||||
)
|
||||
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||
self.cross_gate_adaln = cross_gate_adaln
|
||||
self.cross_pe_max_pos = cross_pe_max_pos
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
from dataclasses import replace
|
||||
|
||||
transformer_args = self.simple_preprocessor.prepare(modality)
|
||||
|
||||
# Prepare cross-modal positional embeddings
|
||||
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
||||
positions=modality.positions[:, 0:1, :],
|
||||
inner_dim=self.audio_cross_attention_dim,
|
||||
max_pos=[self.cross_pe_max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.simple_preprocessor.num_attention_heads,
|
||||
)
|
||||
|
||||
# Prepare cross-attention timestep embeddings
|
||||
cross_scale_shift_timestep, cross_gate_timestep = (
|
||||
self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
return replace(
|
||||
transformer_args,
|
||||
cross_positional_embeddings=cross_pe,
|
||||
cross_scale_shift_timestep=cross_scale_shift_timestep,
|
||||
cross_gate_timestep=cross_gate_timestep,
|
||||
)
|
||||
|
||||
def _prepare_cross_attention_timestep(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * timestep_scale_multiplier
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
scale_shift_timestep = mx.reshape(
|
||||
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
|
||||
)
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(
|
||||
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
|
||||
)
|
||||
gate_timestep = mx.reshape(
|
||||
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
|
||||
)
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
|
||||
|
||||
class LTXModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LTXModelConfig):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.model_type = config.model_type
|
||||
self.use_middle_indices_grid = config.use_middle_indices_grid
|
||||
self.rope_type = config.rope_type
|
||||
self.timestep_scale_multiplier = config.timestep_scale_multiplier
|
||||
self.positional_embedding_theta = config.positional_embedding_theta
|
||||
|
||||
cross_pe_max_pos = None
|
||||
|
||||
if config.model_type.is_video_enabled():
|
||||
self.positional_embedding_max_pos = config.positional_embedding_max_pos
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.inner_dim = config.inner_dim
|
||||
self._init_video(config)
|
||||
|
||||
if config.model_type.is_audio_enabled():
|
||||
self.audio_positional_embedding_max_pos = (
|
||||
config.audio_positional_embedding_max_pos
|
||||
)
|
||||
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||
self.audio_inner_dim = config.audio_inner_dim
|
||||
self._init_audio(config)
|
||||
|
||||
# Initialize cross-modal components
|
||||
if (
|
||||
config.model_type.is_video_enabled()
|
||||
and config.model_type.is_audio_enabled()
|
||||
):
|
||||
cross_pe_max_pos = max(
|
||||
config.positional_embedding_max_pos[0],
|
||||
config.audio_positional_embedding_max_pos[0],
|
||||
)
|
||||
self.av_ca_timestep_scale_multiplier = (
|
||||
config.av_ca_timestep_scale_multiplier
|
||||
)
|
||||
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||
self._init_audio_video(config)
|
||||
|
||||
self._init_preprocessors(config, cross_pe_max_pos)
|
||||
|
||||
self._init_transformer_blocks(config)
|
||||
|
||||
def _init_video(self, config: LTXModelConfig) -> None:
|
||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||
|
||||
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=adaln_coefficient
|
||||
)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2
|
||||
)
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
self.scale_shift_table = mx.zeros((2, self.inner_dim))
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(
|
||||
config.audio_in_channels, self.audio_inner_dim, bias=True
|
||||
)
|
||||
|
||||
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
|
||||
)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim, embedding_coefficient=2
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
|
||||
# Output components
|
||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||
self.audio_norm_out = nn.LayerNorm(
|
||||
self.audio_inner_dim, eps=config.norm_eps, affine=False
|
||||
)
|
||||
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||
|
||||
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||
num_scale_shift_values = 4
|
||||
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
|
||||
def _init_preprocessors(
|
||||
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
|
||||
) -> None:
|
||||
if (
|
||||
config.model_type.is_video_enabled()
|
||||
and config.model_type.is_audio_enabled()
|
||||
):
|
||||
# Multi-modal preprocessors
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
cross_pe_max_pos=cross_pe_max_pos,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
cross_pe_max_pos=cross_pe_max_pos,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
elif config.model_type.is_video_enabled():
|
||||
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
elif config.model_type.is_audio_enabled():
|
||||
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
|
||||
video_config = config.get_video_config()
|
||||
audio_config = config.get_audio_config()
|
||||
|
||||
self.transformer_blocks = {
|
||||
idx: BasicAVTransformerBlock(
|
||||
idx=idx,
|
||||
video=video_config,
|
||||
audio=audio_config,
|
||||
rope_type=config.rope_type,
|
||||
norm_eps=config.norm_eps,
|
||||
has_prompt_adaln=config.has_prompt_adaln,
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
}
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self,
|
||||
video: Optional[TransformerArgs],
|
||||
audio: Optional[TransformerArgs],
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Process through all transformer blocks.
|
||||
|
||||
Args:
|
||||
stg_video_blocks: Block indices where video self-attention is skipped (STG).
|
||||
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
|
||||
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
|
||||
"""
|
||||
stg_v_set = set(stg_video_blocks) if stg_video_blocks else set()
|
||||
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
||||
for idx, block in self.transformer_blocks.items():
|
||||
video, audio = block(
|
||||
video=video,
|
||||
audio=audio,
|
||||
skip_video_self_attn=(idx in stg_v_set),
|
||||
skip_audio_self_attn=(idx in stg_a_set),
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
return video, audio
|
||||
|
||||
def _process_output(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
norm_out: nn.LayerNorm,
|
||||
proj_out: nn.Linear,
|
||||
x: mx.array,
|
||||
embedded_timestep: mx.array,
|
||||
) -> mx.array:
|
||||
|
||||
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
|
||||
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
|
||||
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
|
||||
timestep_expanded = embedded_timestep[:, :, None, :] # (B, 1, 1, dim)
|
||||
|
||||
# Combine: (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim)
|
||||
scale_shift_values = table_expanded + timestep_expanded
|
||||
|
||||
# Extract shift and scale (first index is shift, second is scale)
|
||||
shift = scale_shift_values[:, :, 0, :] # (B, 1, dim)
|
||||
scale = scale_shift_values[:, :, 1, :] # (B, 1, dim)
|
||||
|
||||
x = norm_out(x)
|
||||
x = x * (1 + scale) + shift # Broadcasts (B, 1, dim) to (B, seq, dim)
|
||||
x = proj_out(x)
|
||||
|
||||
return x
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
video: Video modality input.
|
||||
audio: Audio modality input.
|
||||
stg_video_blocks: Block indices where video self-attention is skipped (STG).
|
||||
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
|
||||
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
|
||||
"""
|
||||
# Validate inputs
|
||||
if not self.model_type.is_video_enabled() and video is not None:
|
||||
raise ValueError("Video is not enabled for this model")
|
||||
if not self.model_type.is_audio_enabled() and audio is not None:
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
# Preprocess arguments
|
||||
video_args = (
|
||||
self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
)
|
||||
audio_args = (
|
||||
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
)
|
||||
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
video=video_args,
|
||||
audio=audio_args,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
# Process outputs
|
||||
vx = (
|
||||
self._process_output(
|
||||
self.scale_shift_table,
|
||||
self.norm_out,
|
||||
self.proj_out,
|
||||
video_out.x,
|
||||
video_out.embedded_timestep,
|
||||
)
|
||||
if video_out is not None
|
||||
else None
|
||||
)
|
||||
|
||||
ax = (
|
||||
self._process_output(
|
||||
self.audio_scale_shift_table,
|
||||
self.audio_norm_out,
|
||||
self.audio_proj_out,
|
||||
audio_out.x,
|
||||
audio_out.embedded_timestep,
|
||||
)
|
||||
if audio_out is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return vx, ax
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
|
||||
if not has_raw_prefix:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
if (
|
||||
"audio_embeddings_connector" in key
|
||||
or "video_embeddings_connector" in key
|
||||
):
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
new_key = new_key.replace("model.diffusion_model.", "")
|
||||
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTXModel":
|
||||
import json
|
||||
|
||||
config_dict = {}
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config_dict = json.load(f)
|
||||
config = LTXModelConfig(**config_dict)
|
||||
model = cls(config)
|
||||
|
||||
weights = {}
|
||||
|
||||
for weight_file in model_path.glob("*.safetensors"):
|
||||
weights.update(mx.load(str(weight_file)))
|
||||
|
||||
sanitized = model.sanitize(weights)
|
||||
sanitized = {
|
||||
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
|
||||
for k, v in sanitized.items()
|
||||
}
|
||||
|
||||
model.load_weights(list(sanitized.items()), strict=strict)
|
||||
mx.eval(model.parameters())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class X0Model(nn.Module):
|
||||
|
||||
def __init__(self, velocity_model: LTXModel):
|
||||
|
||||
super().__init__()
|
||||
self.velocity_model = velocity_model
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
vx, ax = self.velocity_model(
|
||||
video,
|
||||
audio,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
denoised_video = (
|
||||
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
)
|
||||
denoised_audio = (
|
||||
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
)
|
||||
|
||||
return denoised_video, denoised_audio
|
||||
Reference in New Issue
Block a user