Enhance README.md with new usage examples for the STG and modality scale parameters in video generation. Update generate.py to support STG and modality guidance in the denoising process, allowing for improved audio-visual integration. Refactor the transformer's attention modules to add a `skip_attention` option that bypasses self-attention, enabling STG perturbation and modality isolation. Update LTXModel and transformer block processing to accept the new parameters, allowing more flexible model configurations.
This commit is contained in:
@@ -101,6 +101,7 @@ class Attention(nn.Module):
|
||||
mask: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
skip_attention: bool = False,
|
||||
) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
@@ -110,6 +111,8 @@ class Attention(nn.Module):
|
||||
mask: Attention mask
|
||||
pe: Position embeddings for query (and key if k_pe is None)
|
||||
k_pe: Position embeddings for key (optional, uses pe if None)
|
||||
skip_attention: If True, bypass Q*K*V attention and use value projection
|
||||
only (for STG perturbation). Matches PyTorch all_perturbed=True.
|
||||
|
||||
Returns:
|
||||
Attention output of shape (B, seq_len, query_dim)
|
||||
@@ -119,24 +122,26 @@ class Attention(nn.Module):
|
||||
if hasattr(self, "to_gate_logits"):
|
||||
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
|
||||
|
||||
# Compute Q, K, V
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
# Apply normalization
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if skip_attention:
|
||||
# STG: bypass Q*K*V attention, use value projection only
|
||||
out = v
|
||||
else:
|
||||
# Standard attention
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(context)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||
k_pe_to_use = pe if k_pe is None else k_pe
|
||||
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Compute attention
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||
k_pe_to_use = pe if k_pe is None else k_pe
|
||||
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
|
||||
# Apply per-head gating
|
||||
if gate is not None:
|
||||
|
||||
Reference in New Issue
Block a user