format
This commit is contained in:
@@ -4,8 +4,8 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.attention import Attention
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
@@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||
timestep_reshaped = mx.reshape(
|
||||
timestep,
|
||||
(batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
)
|
||||
|
||||
# Extract the relevant indices
|
||||
@@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
)
|
||||
|
||||
# Squeeze the sequence dimension if it's 1
|
||||
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
||||
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
||||
scale_shift_squeezed = tuple(
|
||||
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
|
||||
)
|
||||
gate_squeezed = tuple(
|
||||
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
|
||||
)
|
||||
|
||||
return (*scale_shift_squeezed, *gate_squeezed)
|
||||
|
||||
@@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
||||
run_a2v = (
|
||||
run_vx
|
||||
and (audio is not None and audio.enabled and ax.size > 0)
|
||||
and not skip_cross_modal
|
||||
)
|
||||
run_v2a = (
|
||||
run_ax
|
||||
and (video is not None and video.enabled and vx.size > 0)
|
||||
and not skip_cross_modal
|
||||
)
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
@@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn1(
|
||||
norm_vx,
|
||||
pe=video.positional_embeddings,
|
||||
skip_attention=skip_video_self_attn,
|
||||
)
|
||||
* vgate_msa
|
||||
)
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
@@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
||||
)
|
||||
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
||||
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
|
||||
self.prompt_scale_shift_table,
|
||||
vx.shape[0],
|
||||
video.prompt_timesteps,
|
||||
slice(0, 2),
|
||||
)
|
||||
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
||||
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
|
||||
encoder_hidden_states = (
|
||||
video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
)
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn2(
|
||||
attn_input,
|
||||
context=encoder_hidden_states,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
* vgate_q
|
||||
)
|
||||
else:
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
@@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn1(
|
||||
norm_ax,
|
||||
pe=audio.positional_embeddings,
|
||||
skip_attention=skip_audio_self_attn,
|
||||
)
|
||||
* agate_msa
|
||||
)
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
|
||||
self.audio_scale_shift_table,
|
||||
ax.shape[0],
|
||||
audio.timesteps,
|
||||
slice(6, 9),
|
||||
)
|
||||
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
||||
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
|
||||
self.audio_prompt_scale_shift_table,
|
||||
ax.shape[0],
|
||||
audio.prompt_timesteps,
|
||||
slice(0, 2),
|
||||
)
|
||||
attn_input_a = (
|
||||
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
)
|
||||
encoder_hidden_states_a = (
|
||||
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
)
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn2(
|
||||
attn_input_a,
|
||||
context=encoder_hidden_states_a,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
* agate_q
|
||||
)
|
||||
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
|
||||
else:
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
|
||||
Reference in New Issue
Block a user