Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
239
mlx_video/models/wan_2/text_encoder.py
Normal file
239
mlx_video/models/wan_2/text_encoder.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
"""RMS-based layer normalization (T5 style)."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
"""T5-style relative position bias with bucketing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_buckets: int,
|
||||
num_heads: int,
|
||||
bidirectional: bool = True,
|
||||
max_dist: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
rel_pos_f = rel_pos.astype(mx.float32)
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
rel_pos_large = mx.minimum(
|
||||
rel_pos_large,
|
||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||
)
|
||||
|
||||
rel_buckets = rel_buckets + mx.where(
|
||||
is_small, rel_pos.astype(mx.int32), rel_pos_large
|
||||
)
|
||||
return rel_buckets
|
||||
|
||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||
positions_k = mx.arange(lk)[None, :] # [1, lk]
|
||||
positions_q = mx.arange(lq)[:, None] # [lq, 1]
|
||||
rel_pos = positions_k - positions_q # [lq, lk]
|
||||
|
||||
buckets = self._relative_position_bucket(rel_pos)
|
||||
embeds = self.embedding(buckets) # [lq, lk, num_heads]
|
||||
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
|
||||
return embeds
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
"""T5-style multi-head attention (no scaling)."""
|
||||
|
||||
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
assert dim_attn % num_heads == 0
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array | None = None,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
context = x if context is None else context
|
||||
b, n, c = x.shape[0], self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
|
||||
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
|
||||
v = self.v(context).reshape(b, -1, n, c)
|
||||
|
||||
# T5 uses no scaling — compute attention manually with float32 softmax
|
||||
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
# Using SDPA with bfloat16 inputs causes precision loss in softmax
|
||||
# since unscaled logits can be very large (no 1/sqrt(d) division).
|
||||
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias.astype(mx.float32)
|
||||
|
||||
# Apply attention mask (use dtype min like official, not -1e9)
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
|
||||
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
|
||||
attn = attn + additive_mask
|
||||
|
||||
# Softmax in float32 (matches official), then cast back
|
||||
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
|
||||
|
||||
# Attention @ V
|
||||
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
|
||||
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = nn.GELU(approx="tanh")
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
|
||||
|
||||
|
||||
class T5SelfAttentionBlock(nn.Module):
|
||||
"""T5 encoder block: self-attention + FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_attn: int,
|
||||
dim_ffn: int,
|
||||
num_heads: int,
|
||||
num_buckets: int,
|
||||
shared_pos: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_pos = shared_pos
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn)
|
||||
self.pos_embedding = (
|
||||
None
|
||||
if shared_pos
|
||||
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
|
||||
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
|
||||
x = x + self.ffn(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
"""T5 Encoder (UMT5-XXL configuration)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 256384,
|
||||
dim: int = 4096,
|
||||
dim_attn: int = 4096,
|
||||
dim_ffn: int = 10240,
|
||||
num_heads: int = 64,
|
||||
num_layers: int = 24,
|
||||
num_buckets: int = 32,
|
||||
shared_pos: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.pos_embedding = (
|
||||
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
if shared_pos
|
||||
else None
|
||||
)
|
||||
self.blocks = [
|
||||
T5SelfAttentionBlock(
|
||||
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
ids: Token IDs [B, L]
|
||||
mask: Attention mask [B, L]
|
||||
|
||||
Returns:
|
||||
Hidden states [B, L, dim]
|
||||
"""
|
||||
x = self.token_embedding(ids)
|
||||
|
||||
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask=mask, pos_bias=e)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
Reference in New Issue
Block a user