Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -208,6 +208,7 @@ class ConnectorAttention(nn.Module):
|
||||
dim: int = 3840,
|
||||
num_heads: int = 30,
|
||||
head_dim: int = 128,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
@@ -218,13 +219,14 @@ class ConnectorAttention(nn.Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=True)
|
||||
# Direct attribute for MLX parameter tracking (not a list)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=True)
|
||||
|
||||
# Standard RMSNorm (not Gemma-style) on full inner_dim
|
||||
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||
self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||
|
||||
if has_gate_logits:
|
||||
self.to_gate_logits = nn.Linear(dim, num_heads, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
@@ -233,12 +235,17 @@ class ConnectorAttention(nn.Module):
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
# Compute per-head gate early (from original input)
|
||||
gate = None
|
||||
if hasattr(self, "to_gate_logits"):
|
||||
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
|
||||
|
||||
# Project to Q, K, V
|
||||
q = self.to_q(x) # (B, seq, inner_dim)
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(x)
|
||||
v = self.to_v(x)
|
||||
|
||||
# QK normalization on full inner_dim BEFORE reshape (matches PyTorch)
|
||||
# QK normalization on full inner_dim BEFORE reshape
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
@@ -248,15 +255,18 @@ class ConnectorAttention(nn.Module):
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
|
||||
if pe is not None:
|
||||
# pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2)
|
||||
# Apply SPLIT RoPE: operates on first half of head dimensions
|
||||
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||
k = self._apply_split_rope(k, pe[0], pe[1])
|
||||
|
||||
# No mask needed for connector - after register replacement, all positions are valid
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||
|
||||
# Apply per-head gating
|
||||
if gate is not None:
|
||||
out = mx.reshape(out, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
out = out * gate[..., None]
|
||||
out = mx.reshape(out, (batch_size, seq_len, -1))
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _apply_split_rope(
|
||||
@@ -326,9 +336,9 @@ class ConnectorFeedForward(nn.Module):
|
||||
|
||||
class ConnectorTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128):
|
||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False):
|
||||
super().__init__()
|
||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim)
|
||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
self.ff = ConnectorFeedForward(dim)
|
||||
|
||||
def __call__(
|
||||
@@ -367,6 +377,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
num_learnable_registers: int = 128,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list = None,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -376,9 +387,8 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
|
||||
|
||||
# Use dict with int keys for MLX to track parameters (lists are not tracked)
|
||||
self.transformer_1d_blocks = {
|
||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim)
|
||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
@@ -572,16 +582,100 @@ def norm_and_concat_hidden_states(
|
||||
return normed
|
||||
|
||||
|
||||
class GemmaFeaturesExtractor(nn.Module):
|
||||
def norm_and_concat_per_token_rms(
|
||||
encoded_text: mx.array,
|
||||
attention_mask: mx.array,
|
||||
) -> mx.array:
|
||||
"""Per-token RMSNorm normalization for V2 feature extraction (LTX-2.3).
|
||||
|
||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840):
|
||||
Args:
|
||||
encoded_text: (B, T, D, L) stacked hidden states
|
||||
attention_mask: (B, T) binary mask (1=valid, 0=padding)
|
||||
|
||||
Returns:
|
||||
(B, T, D*L) normalized tensor with padding zeroed out.
|
||||
"""
|
||||
b, t, d, num_layers = encoded_text.shape
|
||||
dtype = encoded_text.dtype
|
||||
|
||||
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
|
||||
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
|
||||
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
|
||||
normed = normed.astype(dtype)
|
||||
|
||||
# Flatten layers: (B, T, D*L)
|
||||
normed = mx.reshape(normed, (b, t, d * num_layers))
|
||||
|
||||
# Zero out padded positions
|
||||
mask_3d = attention_mask[:, :, None].astype(mx.bool_) # (B, T, 1)
|
||||
normed = mx.where(mask_3d, normed, mx.zeros_like(normed))
|
||||
|
||||
return normed
|
||||
|
||||
|
||||
def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
|
||||
"""Rescale normalization: x * sqrt(target_dim / source_dim)."""
|
||||
return x * math.sqrt(target_dim / source_dim)
|
||||
|
||||
|
||||
class GemmaFeaturesExtractor(nn.Module):
|
||||
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
|
||||
|
||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
|
||||
super().__init__()
|
||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.aggregate_embed(x)
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorV2(nn.Module):
|
||||
"""V2 feature extractor (LTX-2.3): per-token RMSNorm + rescale normalization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flat_dim: int,
|
||||
embedding_dim: int,
|
||||
video_output_dim: int,
|
||||
audio_output_dim: int,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim # Gemma hidden_dim (3840), used for rescale
|
||||
self.video_aggregate_embed = nn.Linear(flat_dim, video_output_dim, bias=bias)
|
||||
self.audio_aggregate_embed = nn.Linear(flat_dim, audio_output_dim, bias=bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: List[mx.array],
|
||||
attention_mask: mx.array,
|
||||
mode: str = "video",
|
||||
) -> mx.array:
|
||||
"""Extract features with per-token RMSNorm + rescale.
|
||||
|
||||
Args:
|
||||
hidden_states: List of hidden states from all Gemma layers
|
||||
attention_mask: Binary attention mask (B, T)
|
||||
mode: "video" or "audio" to select which aggregate embed to use
|
||||
|
||||
Returns:
|
||||
Projected features
|
||||
"""
|
||||
# Stack hidden states: (B, T, D, L)
|
||||
encoded = mx.stack(hidden_states, axis=-1)
|
||||
|
||||
# Per-token RMSNorm + flatten
|
||||
normed = norm_and_concat_per_token_rms(encoded, attention_mask)
|
||||
normed = normed.astype(encoded.dtype)
|
||||
|
||||
if mode == "video":
|
||||
target_dim = self.video_aggregate_embed.weight.shape[0]
|
||||
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
else:
|
||||
target_dim = self.audio_aggregate_embed.weight.shape[0]
|
||||
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -603,39 +697,54 @@ class LTX2TextEncoder(nn.Module):
|
||||
hidden_dim: int = 3840,
|
||||
audio_dim: int = 2048,
|
||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||
has_prompt_adaln: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.audio_dim = audio_dim
|
||||
self.num_layers = num_layers
|
||||
self.has_prompt_adaln = has_prompt_adaln
|
||||
self.language_model = None
|
||||
|
||||
# Feature extractor: 3840*49 -> 3840
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
input_dim=hidden_dim * num_layers,
|
||||
output_dim=hidden_dim,
|
||||
)
|
||||
feature_input_dim = hidden_dim * num_layers
|
||||
|
||||
# Video embeddings connector: 2-layer transformer
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], # Match PyTorch
|
||||
)
|
||||
if has_prompt_adaln:
|
||||
# LTX-2.3: V2 feature extractor with per-token RMSNorm + rescale
|
||||
video_output_dim = 4096
|
||||
audio_output_dim = 2048
|
||||
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
video_output_dim=video_output_dim,
|
||||
audio_output_dim=audio_output_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Audio embeddings connector: separate 2-layer transformer (same architecture as video)
|
||||
# Both connectors process the feature extractor output independently
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], # Match PyTorch
|
||||
)
|
||||
# Deeper connectors with matching dims and gate_logits
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
)
|
||||
else:
|
||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
)
|
||||
|
||||
self.processor = None
|
||||
|
||||
@@ -669,81 +778,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
transformer_weights = mx.load(str(transformer_files[0]))
|
||||
|
||||
if transformer_weights:
|
||||
# Load feature extractor (aggregate_embed)
|
||||
# Reformatted key: "aggregate_embed.weight"
|
||||
# Monolithic key: "text_embedding_projection.aggregate_embed.weight"
|
||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||
if agg_key in transformer_weights:
|
||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key]
|
||||
|
||||
# Load video_embeddings_connector weights
|
||||
connector_weights = {}
|
||||
if is_reformatted:
|
||||
# Reformatted: keys are already sanitized with "video_embeddings_connector." prefix
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("video_embeddings_connector."):
|
||||
new_key = key.replace("video_embeddings_connector.", "")
|
||||
connector_weights[new_key] = value
|
||||
else:
|
||||
# Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
||||
connector_weights[new_key] = value
|
||||
|
||||
if connector_weights:
|
||||
# Map weight names to our structure (only needed for monolithic/raw PyTorch keys)
|
||||
mapped_weights = {}
|
||||
for key, value in connector_weights.items():
|
||||
new_key = key
|
||||
if not is_reformatted:
|
||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped_weights[new_key] = value
|
||||
|
||||
self.video_embeddings_connector.load_weights(
|
||||
list(mapped_weights.items()), strict=False
|
||||
)
|
||||
|
||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||
if "learnable_registers" in connector_weights:
|
||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||
|
||||
# Load audio_embeddings_connector weights (same structure as video connector)
|
||||
audio_connector_weights = {}
|
||||
if is_reformatted:
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("audio_embeddings_connector."):
|
||||
new_key = key.replace("audio_embeddings_connector.", "")
|
||||
audio_connector_weights[new_key] = value
|
||||
else:
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
||||
audio_connector_weights[new_key] = value
|
||||
|
||||
if audio_connector_weights:
|
||||
# Map weight names to our structure (same as video connector)
|
||||
mapped_audio_weights = {}
|
||||
for key, value in audio_connector_weights.items():
|
||||
new_key = key
|
||||
if not is_reformatted:
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped_audio_weights[new_key] = value
|
||||
|
||||
self.audio_embeddings_connector.load_weights(
|
||||
list(mapped_audio_weights.items()), strict=False
|
||||
)
|
||||
|
||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||
if "learnable_registers" in audio_connector_weights:
|
||||
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
|
||||
self._load_feature_extractors(transformer_weights, is_reformatted)
|
||||
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
|
||||
else:
|
||||
print("WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!")
|
||||
@@ -763,6 +800,63 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
print("Text encoder loaded successfully")
|
||||
|
||||
def _load_feature_extractors(self, weights: dict, is_reformatted: bool):
|
||||
"""Load feature extractor weights for both LTX-2 and LTX-2.3."""
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: V2 feature extractor with separate video/audio aggregate embeds
|
||||
for attr, prefix in [
|
||||
("video_aggregate_embed", "video_aggregate_embed"),
|
||||
("audio_aggregate_embed", "audio_aggregate_embed"),
|
||||
]:
|
||||
w_key = f"{prefix}.weight"
|
||||
b_key = f"{prefix}.bias"
|
||||
if w_key in weights:
|
||||
submodule = getattr(self.feature_extractor_v2, attr)
|
||||
submodule.weight = weights[w_key]
|
||||
if b_key in weights:
|
||||
submodule.bias = weights[b_key]
|
||||
else:
|
||||
# LTX-2: single aggregate_embed
|
||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||
if agg_key in weights:
|
||||
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
|
||||
|
||||
def _load_connector(self, name: str, weights: dict, is_reformatted: bool):
|
||||
"""Load a connector's weights (video or audio)."""
|
||||
connector = getattr(self, name)
|
||||
|
||||
# Extract connector-specific weights
|
||||
connector_weights = {}
|
||||
if is_reformatted:
|
||||
prefix = f"{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
connector_weights[key[len(prefix):]] = value
|
||||
else:
|
||||
mono_prefix = f"model.diffusion_model.{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mono_prefix):
|
||||
connector_weights[key[len(mono_prefix):]] = value
|
||||
|
||||
if not connector_weights:
|
||||
return
|
||||
|
||||
# Sanitize key names (only needed for monolithic/raw PyTorch keys)
|
||||
mapped = {}
|
||||
for key, value in connector_weights.items():
|
||||
new_key = key
|
||||
if not is_reformatted:
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped[new_key] = value
|
||||
|
||||
connector.load_weights(list(mapped.items()), strict=False)
|
||||
|
||||
# Manually load learnable_registers (plain mx.array, not tracked as parameter)
|
||||
if "learnable_registers" in connector_weights:
|
||||
connector.learnable_registers = connector_weights["learnable_registers"]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -795,21 +889,40 @@ class LTX2TextEncoder(nn.Module):
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
)
|
||||
features = self.feature_extractor(concat_hidden)
|
||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
|
||||
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
if return_audio_embeddings:
|
||||
# Process features through audio connector independently (same input as video)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask)
|
||||
return video_embeddings, audio_embeddings
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
|
||||
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
|
||||
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
# LTX-2: V1 feature extraction (8 * (x - mean) / range)
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
)
|
||||
|
||||
video_features = self.feature_extractor(concat_hidden)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user