format
This commit is contained in:
@@ -1,3 +1,2 @@
|
||||
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.wan import WanModel, WanModelConfig
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
|
||||
@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
||||
self.linear = nn.Linear(
|
||||
embedding_dim, embedding_coefficient * embedding_dim, bias=True
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -56,15 +57,19 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
use_additional_conditions: bool = False,
|
||||
timestep_proj_dim: int = 256,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.size_emb_dim = size_emb_dim
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
||||
self.time_proj = Timesteps(
|
||||
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
|
||||
)
|
||||
|
||||
if use_additional_conditions and size_emb_dim > 0:
|
||||
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
# Add additional conditions if enabled
|
||||
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||
if resolution is not None and aspect_ratio is not None:
|
||||
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
||||
additional_embeds = self.additional_embedder(
|
||||
resolution, aspect_ratio, hidden_dtype
|
||||
)
|
||||
timesteps_emb = timesteps_emb + additional_embeds
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
|
||||
@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -103,6 +105,8 @@ def make_attn(
|
||||
elif attn_type == AttentionType.NONE:
|
||||
return Identity()
|
||||
elif attn_type == AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
raise NotImplementedError(
|
||||
f"Attention type {attn_type.value} is not supported yet."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
@@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_audio(
|
||||
@@ -99,14 +98,16 @@ def waveform_to_mel(
|
||||
|
||||
for ch in range(channels):
|
||||
# Magnitude spectrogram (power=1.0)
|
||||
S = np.abs(librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
))
|
||||
S = np.abs(
|
||||
librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
)
|
||||
)
|
||||
|
||||
# Mel filterbank with slaney normalization
|
||||
mel_basis = librosa.filters.mel(
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
@@ -39,7 +39,9 @@ def build_mid_block(
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
mid["attn_1"] = (
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type)
|
||||
if add_attention
|
||||
else None
|
||||
)
|
||||
mid["block_2"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
config.in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
config = AudioEncoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
config.z_channels,
|
||||
base_block_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.norm_out = build_normalization_layer(
|
||||
final_block_channels, normtype=self.norm_type
|
||||
)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
final_block_channels,
|
||||
config.out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
return self._adjust_output_shape(h, target_shape)
|
||||
|
||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||
def _denormalize_latents(
|
||||
self, sample: mx.array
|
||||
) -> tuple[mx.array, AudioLatentShape]:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
# sample shape: (B, H, W, C) in MLX format
|
||||
latent_shape = AudioLatentShape(
|
||||
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
|
||||
batch=latent_shape.batch,
|
||||
channels=self.out_ch,
|
||||
frames=target_frames,
|
||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||
mel_bins=(
|
||||
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
|
||||
),
|
||||
)
|
||||
|
||||
return sample, target_shape
|
||||
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||
:,
|
||||
: min(current_time, target_time),
|
||||
: min(current_freq, target_freq),
|
||||
:target_channels,
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
|
||||
return mx.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||
def decode_audio(
|
||||
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
|
||||
) -> mx.array:
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
|
||||
@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
|
||||
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# Non-causal: symmetric padding
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
||||
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
||||
self.padding = (
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
)
|
||||
elif self.causality_axis in (
|
||||
CausalityAxis.WIDTH,
|
||||
CausalityAxis.WIDTH_COMPATIBILITY,
|
||||
):
|
||||
# Causal on width: pad left (before width axis)
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
|
||||
if any(p > 0 for p in self.padding):
|
||||
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
|
||||
)
|
||||
|
||||
return self.conv(x)
|
||||
|
||||
@@ -124,7 +135,14 @@ def make_conv2d(
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(
|
||||
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
causality_axis,
|
||||
)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in MLX conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -116,10 +118,14 @@ def build_downsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["downsample"] = Downsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
down_modules[i_level] = stage
|
||||
|
||||
@@ -51,7 +51,9 @@ def build_normalization_layer(
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||
return nn.GroupNorm(
|
||||
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
if normtype == NormType.PIXEL:
|
||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""ResNet blocks for audio VAE and vocoder."""
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
self.conv1 = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if temb_channels > 0:
|
||||
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
|
||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||
self.dropout_rate = dropout
|
||||
self.conv2 = make_conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
|
||||
if temb is not None and self.temb_channels > 0:
|
||||
# temb: (B, temb_channels) -> (B, out_channels)
|
||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||
h = h + mx.expand_dims(
|
||||
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
|
||||
)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
@@ -124,10 +128,14 @@ def build_upsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if level != 0:
|
||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["upsample"] = Upsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res *= 2
|
||||
|
||||
up_modules[level] = stage
|
||||
|
||||
@@ -7,8 +7,8 @@ Supports:
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -32,7 +32,9 @@ class Snake(nn.Module):
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.alpha = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.alpha = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
self.beta = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
alpha = self.alpha
|
||||
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff: float, half_width: float, kernel_size: int
|
||||
) -> mx.array:
|
||||
"""Compute a Kaiser-windowed sinc filter."""
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
||||
|
||||
# Kaiser window - compute using scipy-compatible formula
|
||||
import numpy as np
|
||||
|
||||
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
||||
|
||||
if even:
|
||||
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
||||
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
||||
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
||||
import numpy as np
|
||||
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
|
||||
self.kernel_size = filt.shape[2]
|
||||
self.filter = filt
|
||||
else:
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
self.pad_left = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
)
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
self.filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
|
||||
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
|
||||
x = self.ratio * mx.conv_transpose1d(
|
||||
x, filt, stride=self.stride
|
||||
) # (N*C, L', 1)
|
||||
|
||||
# Trim padding
|
||||
x = x[:, self.pad_left:-self.pad_right, :]
|
||||
x = x[:, self.pad_left : -self.pad_right, :]
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
|
||||
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=d, padding=get_padding(kernel_size, d),
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=get_padding(kernel_size, d),
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=1, padding=get_padding(kernel_size, 1),
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
|
||||
y = mx.concatenate([first, y], axis=1)
|
||||
|
||||
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
||||
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
|
||||
basis = mx.transpose(
|
||||
self.forward_basis.astype(y.dtype), (0, 2, 1)
|
||||
) # (514, K, 1)
|
||||
|
||||
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
||||
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
||||
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
|
||||
real = spec[..., :n_freqs]
|
||||
imag = spec[..., n_freqs:]
|
||||
|
||||
magnitude = mx.sqrt(real ** 2 + imag ** 2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
|
||||
magnitude = mx.sqrt(real**2 + imag**2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
|
||||
real.dtype
|
||||
)
|
||||
|
||||
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
||||
return magnitude, phase
|
||||
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
|
||||
def __init__(
|
||||
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
||||
n_freqs = filter_length // 2 + 1
|
||||
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
# magnitude: (B, T_frames, n_freqs)
|
||||
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
|
||||
mel = (
|
||||
magnitude @ self.mel_basis.astype(magnitude.dtype).T
|
||||
) # (B, T_frames, n_mels)
|
||||
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
||||
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
||||
return mx.transpose(log_mel, (0, 2, 1))
|
||||
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
|
||||
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels, config.upsample_initial_channel,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
in_channels,
|
||||
config.upsample_initial_channel,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
)
|
||||
|
||||
# Upsampling layers
|
||||
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
|
||||
for i, (stride, kernel_size) in enumerate(
|
||||
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
||||
):
|
||||
in_ch = config.upsample_initial_channel // (2 ** i)
|
||||
in_ch = config.upsample_initial_channel // (2**i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch, out_ch,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
|
||||
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = AMPBlock1(
|
||||
ch, kernel_size, tuple(dilations),
|
||||
ch,
|
||||
kernel_size,
|
||||
tuple(dilations),
|
||||
activation=config.activation,
|
||||
)
|
||||
block_idx += 1
|
||||
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
self.resblocks[block_idx] = resblock_class(
|
||||
ch, kernel_size, tuple(dilations)
|
||||
)
|
||||
block_idx += 1
|
||||
|
||||
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
|
||||
final_channels = config.upsample_initial_channel // (
|
||||
2 ** len(config.upsample_rates)
|
||||
)
|
||||
|
||||
# Post-activation
|
||||
if self.is_amp:
|
||||
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
|
||||
# Final conv
|
||||
out_channels = 2 if config.stereo else 1
|
||||
self.conv_post = nn.Conv1d(
|
||||
final_channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
final_channels,
|
||||
out_channels,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=config.use_bias_at_final,
|
||||
)
|
||||
|
||||
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
|
||||
"""
|
||||
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
output_length = (
|
||||
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
)
|
||||
|
||||
# Pad to hop_length multiple
|
||||
remainder = length_low_rate % self.hop_length
|
||||
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
from mlx_video.models.ltx_2.conditioning.latent import (
|
||||
VideoConditionByLatentIndex,
|
||||
apply_conditioning,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
|
||||
frame_idx: Frame index to condition (0 = first frame)
|
||||
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
||||
"""
|
||||
|
||||
latent: mx.array
|
||||
frame_idx: int = 0
|
||||
strength: float = 1.0
|
||||
@@ -41,6 +42,7 @@ class LatentState:
|
||||
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
||||
1.0 = full denoise, 0.0 = keep clean
|
||||
"""
|
||||
|
||||
latent: mx.array
|
||||
clean_latent: mx.array
|
||||
denoise_mask: mx.array
|
||||
@@ -130,15 +132,15 @@ def apply_conditioning(
|
||||
if frame_idx <= i < end_idx:
|
||||
# Use conditioning latent
|
||||
cond_idx = i - frame_idx
|
||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||
else:
|
||||
# Keep original
|
||||
latent_list.append(state.latent[:, :, i:i+1])
|
||||
clean_list.append(state.clean_latent[:, :, i:i+1])
|
||||
mask_list.append(state.denoise_mask[:, :, i:i+1])
|
||||
latent_list.append(state.latent[:, :, i : i + 1])
|
||||
clean_list.append(state.clean_latent[:, :, i : i + 1])
|
||||
mask_list.append(state.denoise_mask[:, :, i : i + 1])
|
||||
|
||||
state.latent = mx.concatenate(latent_list, axis=2)
|
||||
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@@ -22,9 +21,11 @@ class LTXRopeType(Enum):
|
||||
SPLIT = "split"
|
||||
TWO_D = "2d"
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
|
||||
@@ -46,7 +47,7 @@ class BaseModelConfig:
|
||||
if v is not None:
|
||||
if isinstance(v, Enum):
|
||||
result[k] = v.value
|
||||
elif hasattr(v, 'to_dict'):
|
||||
elif hasattr(v, "to_dict"):
|
||||
result[k] = v.to_dict()
|
||||
else:
|
||||
result[k] = v
|
||||
@@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig):
|
||||
out_channels: int = 128
|
||||
latent_channels: int = 128
|
||||
patch_size: int = 4
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
])
|
||||
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
])
|
||||
encoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
)
|
||||
decoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig):
|
||||
audio_in_channels: int = 128
|
||||
audio_out_channels: int = 128
|
||||
audio_cross_attention_dim: int = 2048
|
||||
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
|
||||
audio_caption_channels: int = (
|
||||
3840 # Input dim for audio text embeddings (same as video)
|
||||
)
|
||||
|
||||
# Positional embedding config
|
||||
positional_embedding_theta: float = 10000.0
|
||||
@@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
@@ -237,21 +243,22 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
# Import here to avoid circular imports
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
|
||||
from .audio_vae.normalization import NormType
|
||||
|
||||
# Convert causality_axis string to enum
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
|
||||
|
||||
# Convert norm_type string to enum
|
||||
if isinstance(self.norm_type, str):
|
||||
self.norm_type = NormType(self.norm_type)
|
||||
|
||||
|
||||
# Convert attn_type string to enum
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioEncoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
@@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
from .audio_vae.normalization import NormType
|
||||
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
@@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
|
||||
dropout: float = 0.0
|
||||
timestep_conditioning: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderModelConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
@@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
norm_layer: Enum = None
|
||||
latent_log_var: Enum = None
|
||||
encoder_spatial_padding_mode: Enum = None
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2})
|
||||
])
|
||||
encoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
@@ -371,10 +382,12 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
if isinstance(self.latent_log_var, str):
|
||||
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(
|
||||
self.encoder_spatial_padding_mode
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.encoder_blocks is not None:
|
||||
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -49,7 +49,6 @@ from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
||||
|
||||
TRANSFORMER_PREFIX = "model.diffusion_model."
|
||||
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
new_key = key[len(TRANSFORMER_PREFIX):]
|
||||
new_key = key[len(TRANSFORMER_PREFIX) :]
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
else:
|
||||
continue
|
||||
elif key.startswith(VAE_DECODER_PREFIX):
|
||||
new_key = key[len(VAE_DECODER_PREFIX):]
|
||||
new_key = key[len(VAE_DECODER_PREFIX) :]
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if value.dtype != mx.float32:
|
||||
value = value.astype(mx.float32)
|
||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
||||
new_key = key[len(VAE_ENCODER_PREFIX):]
|
||||
new_key = key[len(VAE_ENCODER_PREFIX) :]
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_DECODER_PREFIX):
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX):]
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX) :]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_ENCODER_PREFIX):
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX):]
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if not key.startswith(VOCODER_PREFIX):
|
||||
continue
|
||||
|
||||
new_key = key[len(VOCODER_PREFIX):]
|
||||
new_key = key[len(VOCODER_PREFIX) :]
|
||||
|
||||
# Handle Conv1d/ConvTranspose1d weight shape conversion
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
||||
# aggregate_embed weights (text_embedding_projection.*)
|
||||
for key, value in weights.items():
|
||||
if key.startswith(TEXT_PROJ_PREFIX):
|
||||
new_key = key[len(TEXT_PROJ_PREFIX):]
|
||||
new_key = key[len(TEXT_PROJ_PREFIX) :]
|
||||
extracted[new_key] = value
|
||||
|
||||
# video_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(VIDEO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
|
||||
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
# audio_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
|
||||
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
|
||||
# ─── Source resolution ─────────────────────────────────────────────────────────
|
||||
|
||||
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
||||
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
|
||||
MONOLITHIC_PATTERN = re.compile(
|
||||
r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
|
||||
)
|
||||
|
||||
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
||||
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
||||
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
|
||||
UPSCALER_PATTERN = re.compile(
|
||||
r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
|
||||
)
|
||||
|
||||
|
||||
def resolve_source(source: str, variant: str) -> Path:
|
||||
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
|
||||
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
||||
"""Infer VAE decoder config from weights."""
|
||||
# Check for timestep conditioning keys
|
||||
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
|
||||
has_timestep = any(
|
||||
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
|
||||
)
|
||||
|
||||
# Count channel multipliers from up_blocks
|
||||
max_block = -1
|
||||
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
config = infer_transformer_config(transformer_weights)
|
||||
save_config(config, output_path / "transformer")
|
||||
t_params = sum(v.size for v in transformer_weights.values())
|
||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||
print(
|
||||
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
|
||||
)
|
||||
|
||||
# 2. VAE Decoder
|
||||
print(" [2/7] VAE Decoder...")
|
||||
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
]
|
||||
else:
|
||||
upscaler_files = [
|
||||
f.name for f in source_dir.iterdir()
|
||||
f.name
|
||||
for f in source_dir.iterdir()
|
||||
if f.is_file() and UPSCALER_PATTERN.match(f.name)
|
||||
]
|
||||
|
||||
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
||||
if all_converted < total_keys:
|
||||
known_prefixes = (
|
||||
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
|
||||
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
|
||||
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
|
||||
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
|
||||
TRANSFORMER_PREFIX,
|
||||
VAE_DECODER_PREFIX,
|
||||
VAE_ENCODER_PREFIX,
|
||||
VAE_STATS_PREFIX,
|
||||
AUDIO_DECODER_PREFIX,
|
||||
AUDIO_ENCODER_PREFIX,
|
||||
AUDIO_STATS_PREFIX,
|
||||
VOCODER_PREFIX,
|
||||
TEXT_PROJ_PREFIX,
|
||||
VIDEO_CONNECTOR_PREFIX,
|
||||
AUDIO_CONNECTOR_PREFIX,
|
||||
)
|
||||
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
|
||||
skipped = [
|
||||
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
|
||||
]
|
||||
if skipped:
|
||||
print(f" Skipped {len(skipped)} keys:")
|
||||
for k in sorted(skipped)[:20]:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,14 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
LTXRopeType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx_2.transformer import (
|
||||
@@ -58,11 +57,17 @@ class TransformerArgsPreprocessor:
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb, embedded_timestep = self.adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
timestep_emb = mx.reshape(
|
||||
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||
)
|
||||
embedded_timestep = mx.reshape(
|
||||
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||
)
|
||||
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
@@ -74,9 +79,15 @@ class TransformerArgsPreprocessor:
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
timestep_emb, embedded_timestep = adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
timestep_emb = mx.reshape(
|
||||
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||
)
|
||||
embedded_timestep = mx.reshape(
|
||||
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||
)
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_context(
|
||||
@@ -107,7 +118,9 @@ class TransformerArgsPreprocessor:
|
||||
# Convert boolean/int mask to float mask
|
||||
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
mask = mx.reshape(
|
||||
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
)
|
||||
return mask
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
@@ -132,9 +145,15 @@ class TransformerArgsPreprocessor:
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
timestep, embedded_timestep = self._prepare_timestep(
|
||||
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
|
||||
)
|
||||
context, attention_mask = self._prepare_context(
|
||||
modality.context, x, modality.context_mask
|
||||
)
|
||||
attention_mask = self._prepare_attention_mask(
|
||||
attention_mask, modality.latent.dtype
|
||||
)
|
||||
|
||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||
if modality.positional_embeddings is not None:
|
||||
@@ -152,8 +171,13 @@ class TransformerArgsPreprocessor:
|
||||
prompt_timestep = None
|
||||
prompt_embedded_timestep = None
|
||||
if self.prompt_adaln is not None and modality.sigma is not None:
|
||||
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
|
||||
prompt_timestep, prompt_embedded_timestep = (
|
||||
self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln,
|
||||
modality.sigma,
|
||||
x.shape[0],
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
@@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
)
|
||||
|
||||
# Prepare cross-attention timestep embeddings
|
||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
cross_scale_shift_timestep, cross_gate_timestep = (
|
||||
self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
return replace(
|
||||
@@ -254,17 +280,25 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
scale_shift_timestep = mx.reshape(
|
||||
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
|
||||
)
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||
gate_timestep, _ = self.cross_gate_adaln(
|
||||
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
|
||||
)
|
||||
gate_timestep = mx.reshape(
|
||||
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
|
||||
)
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
|
||||
|
||||
class LTXModel(nn.Module):
|
||||
|
||||
|
||||
def __init__(self, config: LTXModelConfig):
|
||||
|
||||
super().__init__()
|
||||
@@ -285,18 +319,25 @@ class LTXModel(nn.Module):
|
||||
self._init_video(config)
|
||||
|
||||
if config.model_type.is_audio_enabled():
|
||||
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
|
||||
self.audio_positional_embedding_max_pos = (
|
||||
config.audio_positional_embedding_max_pos
|
||||
)
|
||||
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||
self.audio_inner_dim = config.audio_inner_dim
|
||||
self._init_audio(config)
|
||||
|
||||
# Initialize cross-modal components
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
if (
|
||||
config.model_type.is_video_enabled()
|
||||
and config.model_type.is_audio_enabled()
|
||||
):
|
||||
cross_pe_max_pos = max(
|
||||
config.positional_embedding_max_pos[0],
|
||||
config.audio_positional_embedding_max_pos[0],
|
||||
)
|
||||
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
|
||||
self.av_ca_timestep_scale_multiplier = (
|
||||
config.av_ca_timestep_scale_multiplier
|
||||
)
|
||||
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||
self._init_audio_video(config)
|
||||
|
||||
@@ -308,10 +349,14 @@ class LTXModel(nn.Module):
|
||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||
|
||||
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=adaln_coefficient
|
||||
)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2
|
||||
)
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
@@ -323,13 +368,19 @@ class LTXModel(nn.Module):
|
||||
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||
self.audio_patchify_proj = nn.Linear(
|
||||
config.audio_in_channels, self.audio_inner_dim, bias=True
|
||||
)
|
||||
|
||||
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
|
||||
)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim, embedding_coefficient=2
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
@@ -338,7 +389,9 @@ class LTXModel(nn.Module):
|
||||
|
||||
# Output components
|
||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.audio_norm_out = nn.LayerNorm(
|
||||
self.audio_inner_dim, eps=config.norm_eps, affine=False
|
||||
)
|
||||
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||
|
||||
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||
@@ -361,8 +414,13 @@ class LTXModel(nn.Module):
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
|
||||
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
def _init_preprocessors(
|
||||
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
|
||||
) -> None:
|
||||
if (
|
||||
config.model_type.is_video_enabled()
|
||||
and config.model_type.is_audio_enabled()
|
||||
):
|
||||
# Multi-modal preprocessors
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
@@ -468,7 +526,8 @@ class LTXModel(nn.Module):
|
||||
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
||||
for idx, block in self.transformer_blocks.items():
|
||||
video, audio = block(
|
||||
video=video, audio=audio,
|
||||
video=video,
|
||||
audio=audio,
|
||||
skip_video_self_attn=(idx in stg_v_set),
|
||||
skip_audio_self_attn=(idx in stg_a_set),
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
@@ -483,7 +542,7 @@ class LTXModel(nn.Module):
|
||||
x: mx.array,
|
||||
embedded_timestep: mx.array,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
|
||||
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
|
||||
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
|
||||
@@ -526,8 +585,12 @@ class LTXModel(nn.Module):
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
# Preprocess arguments
|
||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
video_args = (
|
||||
self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
)
|
||||
audio_args = (
|
||||
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
)
|
||||
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
@@ -567,7 +630,7 @@ class LTXModel(nn.Module):
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
|
||||
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
|
||||
if not has_raw_prefix:
|
||||
return weights
|
||||
@@ -577,7 +640,10 @@ class LTXModel(nn.Module):
|
||||
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
if (
|
||||
"audio_embeddings_connector" in key
|
||||
or "video_embeddings_connector" in key
|
||||
):
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
@@ -612,9 +678,11 @@ class LTXModel(nn.Module):
|
||||
for weight_file in model_path.glob("*.safetensors"):
|
||||
weights.update(mx.load(str(weight_file)))
|
||||
|
||||
|
||||
sanitized = model.sanitize(weights)
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
sanitized = {
|
||||
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
|
||||
for k, v in sanitized.items()
|
||||
}
|
||||
|
||||
model.load_weights(list(sanitized.items()), strict=strict)
|
||||
mx.eval(model.parameters())
|
||||
@@ -625,7 +693,7 @@ class LTXModel(nn.Module):
|
||||
class X0Model(nn.Module):
|
||||
|
||||
def __init__(self, velocity_model: LTXModel):
|
||||
|
||||
|
||||
super().__init__()
|
||||
self.velocity_model = velocity_model
|
||||
|
||||
@@ -639,13 +707,18 @@ class X0Model(nn.Module):
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
vx, ax = self.velocity_model(
|
||||
video, audio,
|
||||
video,
|
||||
audio,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
denoised_video = (
|
||||
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
)
|
||||
denoised_audio = (
|
||||
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
)
|
||||
|
||||
return denoised_video, denoised_audio
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
||||
def bilateral_filter(
|
||||
image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75
|
||||
) -> np.ndarray:
|
||||
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||
|
||||
Args:
|
||||
@@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
||||
except ImportError:
|
||||
# Fallback to simple Gaussian blur if cv2 not available
|
||||
@@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
||||
except ImportError:
|
||||
# Simple box blur fallback
|
||||
from scipy.ndimage import uniform_filter
|
||||
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
|
||||
|
||||
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(
|
||||
np.uint8
|
||||
)
|
||||
|
||||
|
||||
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
|
||||
def unsharp_mask(
|
||||
image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Apply unsharp masking to enhance edges after blur.
|
||||
|
||||
Args:
|
||||
@@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
||||
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
|
||||
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
||||
@@ -81,23 +90,23 @@ def reduce_grid_artifacts(
|
||||
if method == "bilateral":
|
||||
d = max(3, int(5 * strength))
|
||||
sigma = 50 + 50 * strength
|
||||
processed = np.stack([
|
||||
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||
for frame in video
|
||||
])
|
||||
processed = np.stack(
|
||||
[
|
||||
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||
for frame in video
|
||||
]
|
||||
)
|
||||
elif method == "gaussian":
|
||||
kernel_size = max(3, int(3 + 4 * strength))
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
processed = np.stack([
|
||||
gaussian_blur(frame, kernel_size=kernel_size)
|
||||
for frame in video
|
||||
])
|
||||
processed = np.stack(
|
||||
[gaussian_blur(frame, kernel_size=kernel_size) for frame in video]
|
||||
)
|
||||
elif method == "frequency":
|
||||
processed = np.stack([
|
||||
remove_grid_frequency(frame, grid_size=8)
|
||||
for frame in video
|
||||
])
|
||||
processed = np.stack(
|
||||
[remove_grid_frequency(frame, grid_size=8) for frame in video]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
@@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
|
||||
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||
"""
|
||||
# x: (..., dim) where dim is even
|
||||
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||
return mx.reshape(rotated, x.shape)
|
||||
|
||||
|
||||
def apply_rotary_emb_1d(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
@@ -228,9 +228,9 @@ def get_fractional_positions(
|
||||
Fractional positions in range [-1, 1] after scaling
|
||||
"""
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), (
|
||||
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
)
|
||||
assert n_pos_dims == len(
|
||||
max_pos
|
||||
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
|
||||
# Divide each dimension by its max position
|
||||
fractional_positions = []
|
||||
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
num_attention_heads, rope_type
|
||||
indices_grid,
|
||||
dim,
|
||||
theta,
|
||||
max_pos,
|
||||
use_middle_indices_grid,
|
||||
num_attention_heads,
|
||||
rope_type,
|
||||
)
|
||||
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
|
||||
# Compute frequencies: outer product
|
||||
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
|
||||
freq_indices, (1, 1, 1, -1)
|
||||
)
|
||||
# freqs: (B, T, n_dims, num_indices)
|
||||
|
||||
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||
|
||||
@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
|
||||
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
||||
|
||||
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
|
||||
Returns:
|
||||
(a21, b1, b2): RK coefficients.
|
||||
"""
|
||||
|
||||
def get_phi(j: int, neg_h: float) -> float:
|
||||
cache_key = (j, neg_h)
|
||||
if cache_key in phi_cache:
|
||||
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_sde_coeff(
|
||||
sigma_next: float,
|
||||
) -> tuple[float, float, float]:
|
||||
@@ -139,7 +140,9 @@ def sde_noise_step(
|
||||
denoised_next = sample_f32 - sigma * eps_next
|
||||
|
||||
# Mix deterministic and stochastic components
|
||||
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
x_noised = (
|
||||
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
)
|
||||
|
||||
return x_noised
|
||||
|
||||
@@ -148,6 +151,7 @@ def sde_noise_step(
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
|
||||
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
||||
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
||||
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
from mlx_video.utils import apply_quantization, rms_norm
|
||||
|
||||
# Path to system prompts
|
||||
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||
@@ -36,11 +36,10 @@ def _load_system_prompt(prompt_name: str) -> str:
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
|
||||
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
# Create config matching LTX-2 text encoder requirements
|
||||
self.config = config
|
||||
self.config = config
|
||||
|
||||
# Create the Gemma3Model from mlx-vlm
|
||||
self.model = Gemma3Model(self.config)
|
||||
@@ -51,7 +50,7 @@ class LanguageModel(nn.Module):
|
||||
attention_mask: Optional[mx.array],
|
||||
dtype: mx.Dtype,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
|
||||
|
||||
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype))
|
||||
min_val = (
|
||||
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
)
|
||||
mask = mx.where(
|
||||
combined,
|
||||
mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype),
|
||||
)
|
||||
return mask[:, None, :, :]
|
||||
else:
|
||||
# No padding mask, just causal
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
||||
min_val = (
|
||||
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
)
|
||||
mask = mx.where(
|
||||
causal_mask,
|
||||
mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype),
|
||||
)
|
||||
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||
|
||||
def __call__(
|
||||
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
|
||||
batch_size, seq_len = inputs.shape
|
||||
|
||||
# Get embeddings
|
||||
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
||||
h = (
|
||||
input_embeddings
|
||||
if input_embeddings is not None
|
||||
else self.model.embed_tokens(inputs)
|
||||
)
|
||||
|
||||
# Apply Gemma scaling
|
||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
|
||||
if cache is None:
|
||||
cache = [None] * len(self.model.layers)
|
||||
|
||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||
full_causal_mask = self._create_causal_mask_with_padding(
|
||||
seq_len, attention_mask, h.dtype
|
||||
)
|
||||
|
||||
sliding_mask = full_causal_mask
|
||||
|
||||
|
||||
num_layers = len(self.model.layers)
|
||||
for i, layer in enumerate(self.model.layers):
|
||||
is_global = (
|
||||
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
||||
sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
|
||||
else:
|
||||
sanitized[key[len(prefix):]] = value
|
||||
sanitized[key[len(prefix) :]] = value
|
||||
return sanitized
|
||||
|
||||
@property
|
||||
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
|
||||
|
||||
def make_cache(self):
|
||||
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||
|
||||
caches = []
|
||||
for i in range(len(self.layers)):
|
||||
if (
|
||||
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
import json
|
||||
|
||||
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||
config_file = Path(model_path) / "config.json"
|
||||
config_dict = {}
|
||||
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
|
||||
with open(config_file, "r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
||||
language_model = cls(
|
||||
config=TextConfig.from_dict(config_dict["text_config"])
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Config file not found at {model_path}")
|
||||
|
||||
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
|
||||
for i, wf in enumerate(weight_files):
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
if hasattr(language_model, "sanitize"):
|
||||
weights = language_model.sanitize(weights=weights)
|
||||
|
||||
|
||||
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
||||
apply_quantization(
|
||||
model=language_model, weights=weights, quantization=quantization
|
||||
)
|
||||
|
||||
language_model.load_weights(list(weights.items()), strict=False)
|
||||
|
||||
return language_model
|
||||
|
||||
|
||||
|
||||
class ConnectorAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Reshape to (B, H, T, D) for SPLIT RoPE
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
q = mx.reshape(
|
||||
q, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(
|
||||
k, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(
|
||||
v, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
if pe is not None:
|
||||
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||
@@ -304,7 +327,7 @@ class ConnectorAttention(nn.Module):
|
||||
out2 = x2 * cos_freq + x1 * sin_freq
|
||||
|
||||
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
|
||||
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
"""GELU-gated linear unit."""
|
||||
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
|
||||
|
||||
class ConnectorTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
num_heads: int = 30,
|
||||
head_dim: int = 128,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
self.attn1 = ConnectorAttention(
|
||||
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||
)
|
||||
self.ff = ConnectorFeedForward(dim)
|
||||
|
||||
def __call__(
|
||||
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
|
||||
|
||||
self.transformer_1d_blocks = {
|
||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
i: ConnectorTransformerBlock(
|
||||
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
def _precompute_freqs_cis(
|
||||
self, seq_len: int, dtype: mx.Dtype
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
|
||||
|
||||
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
|
||||
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Binary mask: 1 for valid tokens, 0 for padded
|
||||
# attention_mask is additive: 0 for valid, large negative for padded
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
|
||||
mx.int32
|
||||
) # (batch, seq)
|
||||
|
||||
# Tile registers to match sequence length, cast to hidden_states dtype
|
||||
num_tiles = seq_len // self.num_learnable_registers
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
|
||||
dtype
|
||||
) # (seq_len, dim)
|
||||
|
||||
# Process each batch item (PyTorch uses advanced indexing)
|
||||
result_list = []
|
||||
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Extract valid tokens (where mask is 1)
|
||||
# Since we have left-padded input, valid tokens are at the end
|
||||
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
|
||||
valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
|
||||
|
||||
# Pad with zeros on the right to get back to seq_len
|
||||
pad_length = seq_len - num_valid
|
||||
if pad_length > 0:
|
||||
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||
adjusted = mx.concatenate(
|
||||
[valid_tokens, padding], axis=0
|
||||
) # (seq_len, dim)
|
||||
else:
|
||||
adjusted = valid_tokens
|
||||
|
||||
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
||||
flipped_mask = mx.concatenate([
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32)
|
||||
], axis=0) # (seq,)
|
||||
flipped_mask = mx.concatenate(
|
||||
[
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32),
|
||||
],
|
||||
axis=0,
|
||||
) # (seq,)
|
||||
|
||||
# Combine: valid tokens at front, registers at back
|
||||
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||
combined = (
|
||||
flipped_mask_expanded * adjusted
|
||||
+ (1 - flipped_mask_expanded) * registers
|
||||
)
|
||||
result_list.append(combined)
|
||||
|
||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Process through transformer blocks
|
||||
for i in range(len(self.transformer_1d_blocks)):
|
||||
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
|
||||
hidden_states = self.transformer_1d_blocks[i](
|
||||
hidden_states, attention_mask, freqs_cis
|
||||
)
|
||||
|
||||
# Final RMS norm
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
|
||||
def norm_and_concat_hidden_states(
|
||||
hidden_states: List[mx.array],
|
||||
attention_mask: mx.array,
|
||||
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
|
||||
x_for_min = mx.where(
|
||||
mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
|
||||
)
|
||||
x_for_max = mx.where(
|
||||
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
|
||||
)
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
|
||||
dtype = encoded_text.dtype
|
||||
|
||||
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
|
||||
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
|
||||
variance = mx.mean(
|
||||
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
|
||||
) # (B, T, 1, L)
|
||||
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
|
||||
normed = normed.astype(dtype)
|
||||
|
||||
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
|
||||
class GemmaFeaturesExtractor(nn.Module):
|
||||
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
|
||||
|
||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
|
||||
def __init__(
|
||||
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
|
||||
|
||||
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
|
||||
|
||||
if mode == "video":
|
||||
target_dim = self.video_aggregate_embed.weight.shape[0]
|
||||
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
return self.video_aggregate_embed(
|
||||
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||
)
|
||||
else:
|
||||
target_dim = self.audio_aggregate_embed.weight.shape[0]
|
||||
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
|
||||
|
||||
|
||||
return self.audio_aggregate_embed(
|
||||
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||
)
|
||||
|
||||
|
||||
class AudioEmbeddingsConnector(nn.Module):
|
||||
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
|
||||
video_output_dim = 4096
|
||||
audio_output_dim = 2048
|
||||
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
video_output_dim=video_output_dim,
|
||||
audio_output_dim=audio_output_dim,
|
||||
bias=True,
|
||||
@@ -728,37 +785,57 @@ class LTX2TextEncoder(nn.Module):
|
||||
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
||||
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
dim=video_output_dim,
|
||||
num_heads=32,
|
||||
head_dim=128,
|
||||
num_layers=8,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
has_gate_logits=True,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
dim=audio_output_dim,
|
||||
num_heads=32,
|
||||
head_dim=64,
|
||||
num_layers=8,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
has_gate_logits=True,
|
||||
)
|
||||
else:
|
||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
feature_input_dim, hidden_dim
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1],
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1],
|
||||
)
|
||||
|
||||
self.processor = None
|
||||
|
||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||
def load(
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
|
||||
):
|
||||
|
||||
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
||||
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
||||
|
||||
|
||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||
|
||||
# Load transformer weights for feature extractor and connector.
|
||||
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
if transformer_weights:
|
||||
self._load_feature_extractors(transformer_weights, is_reformatted)
|
||||
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector(
|
||||
"video_embeddings_connector", transformer_weights, is_reformatted
|
||||
)
|
||||
self._load_connector(
|
||||
"audio_embeddings_connector", transformer_weights, is_reformatted
|
||||
)
|
||||
else:
|
||||
print("WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!")
|
||||
print(
|
||||
"WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!"
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer_path = model_path / "tokenizer"
|
||||
if tokenizer_path.exists():
|
||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
str(tokenizer_path), trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
text_encoder_path, trust_remote_code=True
|
||||
)
|
||||
except Exception:
|
||||
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
"google/gemma-3-12b-it", trust_remote_code=True
|
||||
)
|
||||
# Set left padding to match official LTX-2 text encoder
|
||||
self.processor.padding_side = "left"
|
||||
|
||||
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
|
||||
submodule.bias = weights[b_key]
|
||||
else:
|
||||
# LTX-2: single aggregate_embed
|
||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||
agg_key = (
|
||||
"aggregate_embed.weight"
|
||||
if is_reformatted
|
||||
else "text_embedding_projection.aggregate_embed.weight"
|
||||
)
|
||||
if agg_key in weights:
|
||||
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
|
||||
|
||||
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
|
||||
prefix = f"{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
connector_weights[key[len(prefix):]] = value
|
||||
connector_weights[key[len(prefix) :]] = value
|
||||
else:
|
||||
mono_prefix = f"model.diffusion_model.{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mono_prefix):
|
||||
connector_weights[key[len(mono_prefix):]] = value
|
||||
connector_weights[key[len(mono_prefix) :]] = value
|
||||
|
||||
if not connector_weights:
|
||||
return
|
||||
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||
_, all_hidden_states = self.language_model(
|
||||
inputs=input_ids,
|
||||
input_embeddings=None,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
|
||||
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
|
||||
video_features = self.feature_extractor_v2(
|
||||
all_hidden_states, attention_mask, mode="video"
|
||||
)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
additive_mask = (
|
||||
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
)
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
|
||||
audio_features = self.feature_extractor_v2(
|
||||
all_hidden_states, attention_mask, mode="audio"
|
||||
)
|
||||
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
|
||||
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||
audio_features, audio_mask
|
||||
)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
video_features = self.feature_extractor(concat_hidden)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
additive_mask = (
|
||||
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
)
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
# Remove leading/trailing whitespace
|
||||
response = response.strip()
|
||||
# Remove any leading punctuation
|
||||
response = re.sub(r'^[^\w\s]+', '', response)
|
||||
response = re.sub(r"^[^\w\s]+", "", response)
|
||||
return response
|
||||
|
||||
def _apply_chat_template(
|
||||
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
elif isinstance(content, list):
|
||||
# Handle multimodal content (image + text)
|
||||
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
formatted += (
|
||||
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||
# Add generation prompt
|
||||
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
except ImportError:
|
||||
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
||||
logging.warning(
|
||||
"mlx-lm not available for prompt enhancement. Using original prompt."
|
||||
)
|
||||
return prompt
|
||||
|
||||
if self.processor is None:
|
||||
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
|
||||
)
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
|
||||
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
|
||||
sampler = make_sampler(
|
||||
kwargs.get("temperature", 0.7),
|
||||
kwargs.get("top_p", 1.0),
|
||||
top_k=kwargs.get("top_k", -1),
|
||||
)
|
||||
logits_processors = make_logits_processors(
|
||||
kwargs.get("logit_bias", None),
|
||||
kwargs.get("repetition_penalty", 1.3),
|
||||
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode only the new tokens
|
||||
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||
enhanced_prompt = self.processor.decode(
|
||||
generated_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
|
||||
def enhance_i2v(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||
encoder = LTX2TextEncoder()
|
||||
encoder.load(model_path=model_path)
|
||||
return encoder
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.attention import Attention
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
@@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||
timestep_reshaped = mx.reshape(
|
||||
timestep,
|
||||
(batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
)
|
||||
|
||||
# Extract the relevant indices
|
||||
@@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
)
|
||||
|
||||
# Squeeze the sequence dimension if it's 1
|
||||
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
||||
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
||||
scale_shift_squeezed = tuple(
|
||||
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
|
||||
)
|
||||
gate_squeezed = tuple(
|
||||
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
|
||||
)
|
||||
|
||||
return (*scale_shift_squeezed, *gate_squeezed)
|
||||
|
||||
@@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
||||
run_a2v = (
|
||||
run_vx
|
||||
and (audio is not None and audio.enabled and ax.size > 0)
|
||||
and not skip_cross_modal
|
||||
)
|
||||
run_v2a = (
|
||||
run_ax
|
||||
and (video is not None and video.enabled and vx.size > 0)
|
||||
and not skip_cross_modal
|
||||
)
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
@@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn1(
|
||||
norm_vx,
|
||||
pe=video.positional_embeddings,
|
||||
skip_attention=skip_video_self_attn,
|
||||
)
|
||||
* vgate_msa
|
||||
)
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
@@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
||||
)
|
||||
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
||||
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
|
||||
self.prompt_scale_shift_table,
|
||||
vx.shape[0],
|
||||
video.prompt_timesteps,
|
||||
slice(0, 2),
|
||||
)
|
||||
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
||||
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
|
||||
encoder_hidden_states = (
|
||||
video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
)
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn2(
|
||||
attn_input,
|
||||
context=encoder_hidden_states,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
* vgate_q
|
||||
)
|
||||
else:
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
@@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn1(
|
||||
norm_ax,
|
||||
pe=audio.positional_embeddings,
|
||||
skip_attention=skip_audio_self_attn,
|
||||
)
|
||||
* agate_msa
|
||||
)
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
|
||||
self.audio_scale_shift_table,
|
||||
ax.shape[0],
|
||||
audio.timesteps,
|
||||
slice(6, 9),
|
||||
)
|
||||
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
||||
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
|
||||
self.audio_prompt_scale_shift_table,
|
||||
ax.shape[0],
|
||||
audio.prompt_timesteps,
|
||||
slice(0, 2),
|
||||
)
|
||||
attn_input_a = (
|
||||
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
)
|
||||
encoder_hidden_states_a = (
|
||||
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
)
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn2(
|
||||
attn_input_a,
|
||||
context=encoder_hidden_states_a,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
* agate_q
|
||||
)
|
||||
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
|
||||
else:
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
|
||||
self.groups = groups
|
||||
|
||||
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
scale = (
|
||||
1.0
|
||||
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
||||
shape=(
|
||||
out_channels,
|
||||
kernel_size[0],
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
in_channels,
|
||||
),
|
||||
)
|
||||
|
||||
if bias:
|
||||
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
|
||||
n, d, h, w, c = x.shape
|
||||
input_dtype = x.dtype
|
||||
|
||||
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
|
||||
self.den = den
|
||||
|
||||
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
|
||||
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffle2D(num, num)
|
||||
self.blur_down = BlurDownsample(stride=den)
|
||||
|
||||
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
|
||||
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x) # H*num, W*num
|
||||
x = self.blur_down(x) # H*num/den, W*num/den
|
||||
x = self.blur_down(x) # H*num/den, W*num/den
|
||||
|
||||
_, h_out, w_out, _ = x.shape
|
||||
x = mx.reshape(x, (n, d, h_out, w_out, c))
|
||||
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
"""Convert a float scale to a rational fraction (numerator, denominator)."""
|
||||
from fractions import Fraction
|
||||
|
||||
frac = Fraction(scale).limit_denominator(10)
|
||||
return frac.numerator, frac.denominator
|
||||
|
||||
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
|
||||
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||
|
||||
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
self.res_blocks = {
|
||||
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||
}
|
||||
|
||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
|
||||
|
||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
self.post_upsample_res_blocks = {
|
||||
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||
}
|
||||
|
||||
# Final projection
|
||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
|
||||
Returns:
|
||||
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
|
||||
"""
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
print(
|
||||
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
||||
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
||||
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
||||
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
|
||||
conv_key = (
|
||||
"upsampler.conv.weight"
|
||||
if "upsampler.conv.weight" in raw_weights
|
||||
else "upsampler.0.weight"
|
||||
)
|
||||
if conv_key in raw_weights:
|
||||
out_channels = raw_weights[conv_key].shape[0]
|
||||
ratio = out_channels // mid_channels
|
||||
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
rational_resampler = False
|
||||
spatial_scale = 2.0
|
||||
|
||||
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
||||
print(
|
||||
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
|
||||
)
|
||||
|
||||
# Create model
|
||||
upsampler = LatentUpsampler(
|
||||
|
||||
@@ -109,6 +109,7 @@ def convert_audio_encoder(
|
||||
return encoder_dir
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
vae_path = hf_hub_download(
|
||||
source_repo,
|
||||
"audio_vae/diffusion_pytorch_model.safetensors",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
|
||||
:, :, ::-1, :, :
|
||||
] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
|
||||
:, :, :, ::-1, :
|
||||
] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
@@ -50,7 +54,7 @@ def make_conv_nd(
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
) -> nn.Module:
|
||||
|
||||
|
||||
if dims == 2:
|
||||
return CausalConv2d(
|
||||
in_channels=in_channels,
|
||||
@@ -118,15 +122,17 @@ class CausalConv3d(nn.Module):
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
|
||||
|
||||
use_causal = causal if causal is not None else self.causal
|
||||
|
||||
# Apply temporal padding via frame replication
|
||||
# Apply temporal padding via frame replication
|
||||
# Only apply if kernel_size > 1
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
first_frame_pad = mx.repeat(
|
||||
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
|
||||
)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
@@ -191,11 +196,10 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
out_idx = 0
|
||||
out_idx = 0
|
||||
|
||||
# Process chunks
|
||||
in_start = 0
|
||||
|
||||
@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
in_channels=256, time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
def __call__(
|
||||
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
ada_values = self.scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
|
||||
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Build up blocks from config
|
||||
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
|
||||
block_type = block_def[0]
|
||||
ch = block_def[1]
|
||||
if block_type == "res":
|
||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
||||
num_layers = (
|
||||
block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
)
|
||||
self.up_blocks[idx] = ResBlockGroup(
|
||||
ch, num_layers, spatial_padding_mode, timestep_conditioning
|
||||
)
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
@@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return weights
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
if (
|
||||
".conv.conv.weight" not in new_key
|
||||
and ".conv.conv.bias" not in new_key
|
||||
):
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
def from_pretrained(
|
||||
cls, model_path: Path, strict: bool = True
|
||||
) -> "LTX2VideoDecoder":
|
||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||
|
||||
Args:
|
||||
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return final_blocks
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module):
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
@@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
@@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
scaled_timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
@@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
use_chunked_conv = tiling_mode in (
|
||||
"conservative",
|
||||
"none",
|
||||
"auto",
|
||||
"default",
|
||||
"spatial",
|
||||
)
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
return self(
|
||||
sample,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
|
||||
@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
|
||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
x = mx.reshape(
|
||||
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
|
||||
)
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||
@@ -101,7 +102,7 @@ class PerChannelStatistics(nn.Module):
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
@@ -117,7 +118,7 @@ class PerChannelStatistics(nn.Module):
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
dtype = x.dtype
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
@@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module):
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
@@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module):
|
||||
causal: bool = True,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
residual = x
|
||||
|
||||
# First block
|
||||
@@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module):
|
||||
attention_head_dim: Optional[int] = None,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
out_channels_reduction_factor: int = 1,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||
def __call__(
|
||||
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
|
||||
) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
def _chunked_conv_depth_to_space(
|
||||
self, x: mx.array, causal: bool = True
|
||||
) -> mx.array:
|
||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||
|
||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||
|
||||
@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
|
||||
# Apply right ramp (fade out)
|
||||
if ramp_right > 0:
|
||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
||||
fade_out = [
|
||||
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
|
||||
]
|
||||
for i in range(ramp_right):
|
||||
mask[length - ramp_right + i] *= fade_out[i]
|
||||
|
||||
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_pixels < 64:
|
||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_size_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_frames < 16:
|
||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_size_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||
@@ -113,15 +127,21 @@ class TilingConfig:
|
||||
def default(cls) -> "TilingConfig":
|
||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||
"""Spatial tiling only (for short videos with large resolution)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
|
||||
),
|
||||
temporal_config=None,
|
||||
)
|
||||
|
||||
@@ -130,23 +150,33 @@ class TilingConfig:
|
||||
"""Temporal tiling only (for long videos with small resolution)."""
|
||||
return cls(
|
||||
spatial_config=None,
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def aggressive(cls) -> "TilingConfig":
|
||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=256, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=32, tile_overlap_in_frames=8
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def conservative(cls) -> "TilingConfig":
|
||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=768, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=96, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -186,10 +216,14 @@ class TilingConfig:
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
spatial_config = SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
)
|
||||
|
||||
if needs_temporal:
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
temporal_config = TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
@@ -197,16 +231,21 @@ class TilingConfig:
|
||||
@dataclass
|
||||
class DimensionIntervals:
|
||||
"""Intervals for splitting a single dimension."""
|
||||
|
||||
starts: List[int]
|
||||
ends: List[int]
|
||||
left_ramps: List[int]
|
||||
right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_spatial(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a spatial dimension into intervals."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||
starts = [i * (size - overlap) for i in range(amount)]
|
||||
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
|
||||
left_ramps = [0] + [overlap] * (amount - 1)
|
||||
right_ramps = [overlap] * (amount - 1) + [0]
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
|
||||
)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_temporal(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
# Start with spatial split
|
||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
|
||||
starts[i] = starts[i] - 1
|
||||
left_ramps[i] = left_ramps[i] + 1
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts,
|
||||
ends=intervals.ends,
|
||||
left_ramps=left_ramps,
|
||||
right_ramps=intervals.right_ramps,
|
||||
)
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_temporal_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map temporal latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = 1 + (end - 1) * scale
|
||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, True
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_spatial_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map spatial latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = end * scale
|
||||
left_ramp_scaled = left_ramp * scale
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, False
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
@@ -315,7 +373,9 @@ def decode_with_tiling(
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
@@ -338,7 +398,9 @@ def decode_with_tiling(
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
@@ -347,7 +409,9 @@ def decode_with_tiling(
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
@@ -356,13 +420,23 @@ def decode_with_tiling(
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -385,13 +459,15 @@ def decode_with_tiling(
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
@@ -409,11 +485,37 @@ def decode_with_tiling(
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
@@ -445,10 +547,12 @@ def decode_with_tiling(
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
@@ -456,7 +560,10 @@ def decode_with_tiling(
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
@@ -473,7 +580,7 @@ def decode_with_tiling(
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
@@ -481,7 +588,7 @@ def decode_with_tiling(
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
|
||||
@@ -8,12 +8,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.ops import (
|
||||
PerChannelStatistics,
|
||||
patchify,
|
||||
unpatchify,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
|
||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
||||
self.per_channel_statistics = PerChannelStatistics(
|
||||
latent_channels=config.out_channels
|
||||
)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = config.in_channels * config.patch_size ** 2
|
||||
in_channels = config.in_channels * config.patch_size**2
|
||||
feature_channels = config.out_channels
|
||||
|
||||
# Initial convolution
|
||||
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
|
||||
conv_out_channels = config.out_channels
|
||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
elif config.latent_log_var in {
|
||||
LogVarianceType.UNIFORM,
|
||||
LogVarianceType.CONSTANT,
|
||||
}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
sample = mx.concatenate(
|
||||
[
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
means = sample[:, : self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
|
||||
@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
||||
q = rope_apply(
|
||||
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
k = rope_apply(
|
||||
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale
|
||||
)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||
return self.o(out)
|
||||
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale
|
||||
)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||
return self.o(out)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
|
||||
return weights
|
||||
|
||||
|
||||
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
def sanitize_wan_transformer_weights(
|
||||
weights: Dict[str, mx.array]
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
||||
|
||||
Wan2.2 keys follow the pattern:
|
||||
@@ -246,8 +247,8 @@ def _load_lora_configs(
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
|
||||
@@ -264,7 +265,9 @@ def _load_lora_configs(
|
||||
module_to_loras = load_multiple_loras(configs)
|
||||
|
||||
if not module_to_loras:
|
||||
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
|
||||
)
|
||||
|
||||
return module_to_loras
|
||||
|
||||
@@ -279,8 +282,8 @@ def load_and_apply_loras(
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
return model_weights
|
||||
@@ -289,12 +292,17 @@ def load_and_apply_loras(
|
||||
if not module_to_loras:
|
||||
return model_weights
|
||||
|
||||
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
|
||||
)
|
||||
if verbose:
|
||||
print(f" Model has {len(model_weights)} weight keys")
|
||||
|
||||
modified_weights = apply_loras_to_weights(
|
||||
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
|
||||
model_weights,
|
||||
module_to_loras,
|
||||
verbose=verbose,
|
||||
quantization_bits=quantization_bits,
|
||||
)
|
||||
|
||||
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
||||
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
src_text_len = src_config.get("text_len", 512)
|
||||
|
||||
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}")
|
||||
print(
|
||||
f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}"
|
||||
)
|
||||
|
||||
# Use preset for known TI2V 5B configuration
|
||||
if src_model_type == "ti2v" and src_dim == 3072:
|
||||
@@ -513,8 +523,11 @@ def convert_wan_checkpoint(
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
weights, include_encoder=include_encoder
|
||||
)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
||||
print(
|
||||
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
|
||||
)
|
||||
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
||||
|
||||
print(f"\nConversion complete! Output: {output_dir}")
|
||||
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
|
||||
return False
|
||||
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
||||
quantize_patterns = (
|
||||
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
|
||||
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
|
||||
".ffn.fc1", ".ffn.fc2",
|
||||
".self_attn.q",
|
||||
".self_attn.k",
|
||||
".self_attn.v",
|
||||
".self_attn.o",
|
||||
".cross_attn.q",
|
||||
".cross_attn.k",
|
||||
".cross_attn.v",
|
||||
".cross_attn.o",
|
||||
".ffn.fc1",
|
||||
".ffn.fc2",
|
||||
)
|
||||
return any(path.endswith(p) for p in quantize_patterns)
|
||||
|
||||
@@ -684,14 +706,20 @@ def quantize_mlx_model(
|
||||
# Build model config
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**config_dict)
|
||||
|
||||
# Copy non-transformer files to output dir (skip large model weights)
|
||||
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
|
||||
transformer_files = {
|
||||
"low_noise_model.safetensors",
|
||||
"high_noise_model.safetensors",
|
||||
"model.safetensors",
|
||||
}
|
||||
if dst.resolve() != src.resolve():
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for f in src.iterdir():
|
||||
@@ -763,11 +791,18 @@ if __name__ == "__main__":
|
||||
|
||||
if args.quantize_only:
|
||||
quantize_mlx_model(
|
||||
args.checkpoint_dir, args.output_dir,
|
||||
bits=args.bits, group_size=args.group_size,
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
else:
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
args.dtype,
|
||||
args.model_version,
|
||||
quantize=args.quantize,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
|
||||
@@ -4,18 +4,15 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan.loading import (
|
||||
_clean_text,
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
@@ -24,6 +21,7 @@ from mlx_video.models.wan.loading import (
|
||||
)
|
||||
from mlx_video.models.wan.postprocess import save_video
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
@@ -37,6 +35,7 @@ class Colors:
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
# Backward-compat alias (tests and external code may use the old name)
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
@@ -143,10 +142,13 @@ def generate_video(
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**{
|
||||
k: v for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
})
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
k: v
|
||||
for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Auto-detect: dual model files → 2.2, single model → 2.1
|
||||
if (model_dir / "low_noise_model.safetensors").exists():
|
||||
@@ -182,7 +184,9 @@ def generate_video(
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
actual_dim = v.shape[0]
|
||||
if actual_dim != config.dim:
|
||||
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
|
||||
)
|
||||
if actual_dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
@@ -192,13 +196,20 @@ def generate_video(
|
||||
|
||||
# Auto-correct Wan2.2 VAE params from stale configs
|
||||
if config.in_dim == 48 and config.vae_z_dim != 48:
|
||||
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
|
||||
config = WanModelConfig(**{
|
||||
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
})
|
||||
print(
|
||||
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
|
||||
)
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
}
|
||||
)
|
||||
|
||||
# Apply defaults from config if not overridden
|
||||
if steps is None:
|
||||
@@ -227,7 +238,9 @@ def generate_video(
|
||||
gen_frames = num_frames
|
||||
if trim_first_frames > 0:
|
||||
gen_frames = num_frames + trim_first_frames * 4
|
||||
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
|
||||
)
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
@@ -247,10 +260,16 @@ def generate_video(
|
||||
if is_i2v:
|
||||
print(f" Image: {image}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||
neg_display = (
|
||||
neg_prompt_resolved[:60] + "..."
|
||||
if len(neg_prompt_resolved) > 60
|
||||
else neg_prompt_resolved
|
||||
)
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||
print(
|
||||
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
|
||||
)
|
||||
if cfg_disabled:
|
||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||
print(f"{Colors.RESET}")
|
||||
@@ -275,12 +294,16 @@ def generate_video(
|
||||
height = align_h
|
||||
if width == 0:
|
||||
width = align_w
|
||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
|
||||
)
|
||||
|
||||
# Enforce max_area constraint (model-specific resolution limit)
|
||||
if config.max_area > 0 and height * width > config.max_area:
|
||||
old_h, old_w = height, width
|
||||
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
|
||||
width, height = _best_output_size(
|
||||
width, height, align_w, align_h, config.max_area
|
||||
)
|
||||
print(
|
||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||
@@ -309,6 +332,7 @@ def generate_video(
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
|
||||
# Encode prompts
|
||||
@@ -318,12 +342,15 @@ def generate_video(
|
||||
context_null = None
|
||||
mx.eval(context)
|
||||
else:
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
context_null = encode_text(
|
||||
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
|
||||
)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
del t5_encoder
|
||||
gc.collect(); mx.clear_cache()
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
# I2V: encode image to latent space
|
||||
@@ -346,18 +373,25 @@ def generate_video(
|
||||
|
||||
img = Image.open(image).convert("RGB")
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
|
||||
img_arr = mx.array(
|
||||
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
|
||||
) # [H, W, 3]
|
||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||
video = mx.concatenate([
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
], axis=1)
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
@@ -367,12 +401,17 @@ def generate_video(
|
||||
|
||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||
msk = mx.concatenate([
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
@@ -395,13 +434,16 @@ def generate_video(
|
||||
|
||||
del vae_enc, img_tensor
|
||||
|
||||
gc.collect(); mx.clear_cache()
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
# Load transformer models
|
||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||
if quantization:
|
||||
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
|
||||
)
|
||||
t2 = time.time()
|
||||
|
||||
# Merge per-model LoRAs with shared LoRAs
|
||||
@@ -412,10 +454,16 @@ def generate_video(
|
||||
if is_dual:
|
||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
|
||||
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
|
||||
low_noise_model = load_wan_model(
|
||||
low_noise_path, config, quantization, loras=_loras_low
|
||||
)
|
||||
high_noise_model = load_wan_model(
|
||||
high_noise_path, config, quantization, loras=_loras_high
|
||||
)
|
||||
else:
|
||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
|
||||
single_model = load_wan_model(
|
||||
model_dir / "model.safetensors", config, quantization, loras=_loras_single
|
||||
)
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
@@ -437,8 +485,12 @@ def generate_video(
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
||||
context_cfg_low = mx.concatenate(
|
||||
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
|
||||
)
|
||||
context_cfg_high = mx.concatenate(
|
||||
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
|
||||
)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
@@ -534,7 +586,7 @@ def generate_video(
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Use compiled forward when available (faster after first trace)
|
||||
_call = getattr(model, '_compiled', model)
|
||||
_call = getattr(model, "_compiled", model)
|
||||
|
||||
if cfg_disabled:
|
||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||
@@ -552,7 +604,9 @@ def generate_video(
|
||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
if is_dual:
|
||||
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
ctx = (
|
||||
context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
)
|
||||
else:
|
||||
ctx = context_cond
|
||||
preds = _call(
|
||||
@@ -571,7 +625,11 @@ def generate_video(
|
||||
if is_dual:
|
||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||
else:
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
gs = (
|
||||
guide_scale
|
||||
if isinstance(guide_scale, (int, float))
|
||||
else guide_scale[0]
|
||||
)
|
||||
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
@@ -586,8 +644,10 @@ def generate_video(
|
||||
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
ctx = context_cfg if not is_dual else (
|
||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||
ctx = (
|
||||
context_cfg
|
||||
if not is_dual
|
||||
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
|
||||
)
|
||||
preds = _call(
|
||||
[latents, latents],
|
||||
@@ -618,16 +678,24 @@ def generate_video(
|
||||
if debug_latents:
|
||||
lat_np = np.array(latents) # [C, T, H, W]
|
||||
n_t = lat_np.shape[1]
|
||||
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
|
||||
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
|
||||
print(
|
||||
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
|
||||
)
|
||||
for t_pos in range(min(n_t, 8)):
|
||||
frame = lat_np[:, t_pos, :, :]
|
||||
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
|
||||
print(
|
||||
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
|
||||
)
|
||||
if n_t > 8:
|
||||
interior = lat_np[:, 4:, :, :]
|
||||
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
|
||||
print(
|
||||
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Free transformer models and text embeddings
|
||||
@@ -646,7 +714,8 @@ def generate_video(
|
||||
del model, kv, context
|
||||
if context_null is not None:
|
||||
del context_null
|
||||
gc.collect(); mx.clear_cache()
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
||||
@@ -677,13 +746,25 @@ def generate_video(
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
|
||||
)
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
spatial_info = (
|
||||
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
|
||||
if tiling_config.spatial_config
|
||||
else "none"
|
||||
)
|
||||
temporal_info = (
|
||||
f"{tiling_config.temporal_config.tile_size_in_frames}f"
|
||||
if tiling_config.temporal_config
|
||||
else "none"
|
||||
)
|
||||
print(
|
||||
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
|
||||
)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
@@ -718,7 +799,9 @@ def generate_video(
|
||||
if trim_first_frames > 0:
|
||||
trim_pixels = trim_first_frames * 4
|
||||
video = video[trim_pixels:]
|
||||
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
|
||||
)
|
||||
|
||||
save_video(video, output_path, fps=config.sample_fps)
|
||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||
@@ -727,58 +810,124 @@ def generate_video(
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument("--image", type=str, default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)")
|
||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
|
||||
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
|
||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
||||
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
||||
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
|
||||
parser.add_argument(
|
||||
"--scheduler", type=str, default="unipc",
|
||||
"--model-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to converted MLX model directory",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-negative-prompt",
|
||||
action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", type=int, default=1280, help="Video width (default: 1280)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=704,
|
||||
help="Video height (default: 704; 720p models use 704)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of diffusion steps (default: from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide-scale",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Guidance scale: single float or low,high pair",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Noise schedule shift (default: from config)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--output-path", type=str, default="output.mp4", help="Output video path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
"--lora",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
"--lora-high",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
"--lora-low",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
choices=[
|
||||
"auto",
|
||||
"none",
|
||||
"default",
|
||||
"aggressive",
|
||||
"conservative",
|
||||
"spatial",
|
||||
"temporal",
|
||||
],
|
||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-compile", action="store_true",
|
||||
"--no-compile",
|
||||
action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim-first-frames", type=int, default=0, metavar="N",
|
||||
"--trim-first-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="N",
|
||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-latents", action="store_true",
|
||||
"--debug-latents",
|
||||
action="store_true",
|
||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -21,7 +21,9 @@ def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
|
||||
|
||||
# Resize so that the image covers the target size (LANCZOS)
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
|
||||
# Center crop
|
||||
x1 = (img.width - width) // 2
|
||||
|
||||
@@ -6,7 +6,12 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
|
||||
def load_wan_model(
|
||||
model_path: Path,
|
||||
config,
|
||||
quantization: dict | None = None,
|
||||
loras: list | None = None,
|
||||
):
|
||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||
|
||||
Args:
|
||||
@@ -93,9 +98,11 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
@@ -140,6 +147,7 @@ def _clean_text(text: str) -> str:
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
@@ -37,7 +38,9 @@ class Head(nn.Module):
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
|
||||
# Reference computes three rope_params with different dim normalizations
|
||||
# so each axis (temporal/height/width) gets its own full frequency range.
|
||||
d = dim // config.num_heads
|
||||
self.freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
self.freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding.
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.array(
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
).astype(np.float32)
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
|
||||
np.float32
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
|
||||
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
|
||||
seq_lens_list.append(p.shape[1])
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
(
|
||||
mx.concatenate(
|
||||
[
|
||||
p,
|
||||
mx.zeros(
|
||||
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
|
||||
),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
|
||||
t = t[None]
|
||||
|
||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||
sin_emb = mx.concatenate(
|
||||
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
|
||||
)
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
"""Save video frames to MP4.
|
||||
|
||||
@@ -11,6 +13,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
"""
|
||||
try:
|
||||
import imageio
|
||||
|
||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
@@ -18,6 +21,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
except ImportError:
|
||||
try:
|
||||
import cv2
|
||||
|
||||
h, w = frames.shape[1], frames.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||
@@ -27,9 +31,11 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
except (ImportError, Exception):
|
||||
# Last resort: save as individual PNGs
|
||||
from PIL import Image
|
||||
|
||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, frame in enumerate(frames):
|
||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
||||
|
||||
print(
|
||||
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
|
||||
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
freqs = (
|
||||
np.arange(max_seq_len, dtype=np.float64)[:, None]
|
||||
* (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
)
|
||||
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
||||
cos_freqs = np.cos(freqs).astype(np.float32)
|
||||
sin_freqs = np.sin(freqs).astype(np.float32)
|
||||
@@ -46,9 +48,9 @@ def rope_apply(
|
||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||
f0, h0, w0 = grid_sizes[0]
|
||||
seq_len = f0 * h0 * w0
|
||||
all_same_grid = all(
|
||||
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
|
||||
) if b > 1 else True
|
||||
all_same_grid = (
|
||||
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
|
||||
)
|
||||
|
||||
if all_same_grid:
|
||||
# Vectorized path: apply RoPE to all batch elements at once
|
||||
@@ -57,7 +59,9 @@ def rope_apply(
|
||||
x_imag = x_seq[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
|
||||
b, seq_len, n, d
|
||||
)
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||
return x_rotated
|
||||
@@ -102,17 +106,11 @@ def rope_apply(
|
||||
|
||||
# Build per-position frequencies by expanding along grid dims
|
||||
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
||||
ft = mx.broadcast_to(
|
||||
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
|
||||
)
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
||||
fh = mx.broadcast_to(
|
||||
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
|
||||
)
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
||||
fw = mx.broadcast_to(
|
||||
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
|
||||
)
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
# Concatenate: [f*h*w, half_d, 2]
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
|
||||
@@ -7,9 +7,8 @@ for the same quality as Euler.
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _compute_sigmas(
|
||||
@@ -25,9 +24,7 @@ def _compute_sigmas(
|
||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||
"""
|
||||
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
|
||||
::-1
|
||||
]
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
@@ -65,7 +62,10 @@ class FlowMatchEulerScheduler:
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
|
||||
dt = (
|
||||
self._sigmas_float[self._step_index + 1]
|
||||
- self._sigmas_float[self._step_index]
|
||||
)
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
@@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler:
|
||||
|
||||
# Decide order: 1st for first step, last step (if lower_order_final
|
||||
# and few steps), otherwise 2nd
|
||||
use_first_order = (
|
||||
self._prev_x0 is None
|
||||
or (
|
||||
self.lower_order_final
|
||||
and i == self._num_steps - 1
|
||||
and self._num_steps < 15
|
||||
)
|
||||
use_first_order = self._prev_x0 is None or (
|
||||
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
|
||||
)
|
||||
|
||||
if use_first_order or sigma_next == 0.0:
|
||||
|
||||
@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
rel_pos_f = rel_pos.astype(mx.float32)
|
||||
rel_pos_large = (
|
||||
max_exact
|
||||
+ (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
)
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
rel_pos_large = mx.minimum(
|
||||
rel_pos_large,
|
||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||
)
|
||||
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
|
||||
rel_buckets = rel_buckets + mx.where(
|
||||
is_small, rel_pos.astype(mx.int32), rel_pos_large
|
||||
)
|
||||
return rel_buckets
|
||||
|
||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
|
||||
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
|
||||
@@ -75,7 +75,11 @@ def decode_with_tiling(
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
|
||||
out_f = (
|
||||
(1 + (f_latent - 1) * temporal_scale)
|
||||
if causal_temporal
|
||||
else (f_latent * temporal_scale)
|
||||
)
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
@@ -98,9 +102,13 @@ def decode_with_tiling(
|
||||
|
||||
# Compute intervals for each dimension
|
||||
if causal_temporal:
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
else:
|
||||
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_spatial(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
@@ -124,9 +132,13 @@ def decode_with_tiling(
|
||||
|
||||
# Map temporal coordinates
|
||||
if causal_temporal:
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
else:
|
||||
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_spatial_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
@@ -135,7 +147,9 @@ def decode_with_tiling(
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
@@ -144,13 +158,23 @@ def decode_with_tiling(
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -173,13 +197,15 @@ def decode_with_tiling(
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
@@ -196,11 +222,37 @@ def decode_with_tiling(
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
@@ -232,12 +284,14 @@ def decode_with_tiling(
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
elif causal_temporal:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
else:
|
||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
@@ -245,7 +299,10 @@ def decode_with_tiling(
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
@@ -262,7 +319,7 @@ def decode_with_tiling(
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
@@ -270,7 +327,7 @@ def decode_with_tiling(
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
|
||||
@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# Cross-attention (with optional norm on context)
|
||||
self.norm3 = (
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True)
|
||||
if cross_attn_norm
|
||||
else None
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
|
||||
)
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||
|
||||
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||
y = self.self_attn(
|
||||
x_mod,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
|
||||
@@ -6,19 +6,45 @@ so weights load directly without key sanitization.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization statistics for z_dim=16
|
||||
VAE_MEAN = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
]
|
||||
VAE_STD = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
]
|
||||
|
||||
|
||||
@@ -50,7 +76,9 @@ class CausalConv3d(nn.Module):
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# MLX Conv3d: weight shape [O, D, H, W, I]
|
||||
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
|
||||
self.weight = mx.zeros(
|
||||
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||
)
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||
@@ -67,8 +95,16 @@ class CausalConv3d(nn.Module):
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
|
||||
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w),
|
||||
],
|
||||
)
|
||||
|
||||
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
||||
out = self._conv3d(x)
|
||||
@@ -118,7 +154,11 @@ class RMS_norm(nn.Module):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
norm_dim = 1 if self.channel_first else -1
|
||||
# L2 normalize along channel dim (matches F.normalize)
|
||||
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
|
||||
norm = mx.sqrt(
|
||||
mx.clip(
|
||||
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
|
||||
)
|
||||
)
|
||||
return (x / norm) * self.scale * self.gamma
|
||||
|
||||
|
||||
@@ -133,12 +173,12 @@ class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.residual = [
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
@@ -226,13 +266,16 @@ class Resample(nn.Module):
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
|
||||
)
|
||||
else:
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
||||
if mode == "downsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
@@ -272,8 +315,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
# Subsequent chunks: use cached frame as temporal context
|
||||
cache_x = x[:, :, -1:]
|
||||
x = self.time_conv(
|
||||
x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
@@ -328,8 +370,8 @@ class Decoder3d(nn.Module):
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
||||
]
|
||||
|
||||
@@ -405,8 +447,7 @@ class Encoder3d(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -431,8 +472,7 @@ class Encoder3d(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.head[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -583,7 +623,7 @@ class WanVAE(nn.Module):
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_denorm,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization for z_dim=48 latent space
|
||||
VAE22_MEAN = mx.array([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||
])
|
||||
VAE22_MEAN = mx.array(
|
||||
[
|
||||
-0.2289,
|
||||
-0.0052,
|
||||
-0.1323,
|
||||
-0.2339,
|
||||
-0.2799,
|
||||
0.0174,
|
||||
0.1838,
|
||||
0.1557,
|
||||
-0.1382,
|
||||
0.0542,
|
||||
0.2813,
|
||||
0.0891,
|
||||
0.1570,
|
||||
-0.0098,
|
||||
0.0375,
|
||||
-0.1825,
|
||||
-0.2246,
|
||||
-0.1207,
|
||||
-0.0698,
|
||||
0.5109,
|
||||
0.2665,
|
||||
-0.2108,
|
||||
-0.2158,
|
||||
0.2502,
|
||||
-0.2055,
|
||||
-0.0322,
|
||||
0.1109,
|
||||
0.1567,
|
||||
-0.0729,
|
||||
0.0899,
|
||||
-0.2799,
|
||||
-0.1230,
|
||||
-0.0313,
|
||||
-0.1649,
|
||||
0.0117,
|
||||
0.0723,
|
||||
-0.2839,
|
||||
-0.2083,
|
||||
-0.0520,
|
||||
0.3748,
|
||||
0.0152,
|
||||
0.1957,
|
||||
0.1433,
|
||||
-0.2944,
|
||||
0.3573,
|
||||
-0.0548,
|
||||
-0.1681,
|
||||
-0.0667,
|
||||
]
|
||||
)
|
||||
|
||||
VAE22_STD = mx.array([
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
|
||||
])
|
||||
VAE22_STD = mx.array(
|
||||
[
|
||||
0.4765,
|
||||
1.0364,
|
||||
0.4514,
|
||||
1.1677,
|
||||
0.5313,
|
||||
0.4990,
|
||||
0.4818,
|
||||
0.5013,
|
||||
0.8158,
|
||||
1.0344,
|
||||
0.5894,
|
||||
1.0901,
|
||||
0.6885,
|
||||
0.6165,
|
||||
0.8454,
|
||||
0.4978,
|
||||
0.5759,
|
||||
0.3523,
|
||||
0.7135,
|
||||
0.6804,
|
||||
0.5833,
|
||||
1.4146,
|
||||
0.8986,
|
||||
0.5659,
|
||||
0.7069,
|
||||
0.5338,
|
||||
0.4889,
|
||||
0.4917,
|
||||
0.4069,
|
||||
0.4999,
|
||||
0.6866,
|
||||
0.4093,
|
||||
0.5709,
|
||||
0.6065,
|
||||
0.6415,
|
||||
0.4944,
|
||||
0.5726,
|
||||
1.2042,
|
||||
0.5458,
|
||||
1.6887,
|
||||
0.3971,
|
||||
1.0600,
|
||||
0.3943,
|
||||
0.5537,
|
||||
0.5444,
|
||||
0.4089,
|
||||
0.7468,
|
||||
0.7744,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# Weight: [O, D, H, W, I] for MLX
|
||||
self.weight = mx.zeros((
|
||||
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
|
||||
))
|
||||
self.weight = mx.zeros(
|
||||
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||
)
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x, cache_x=None):
|
||||
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
# Spatial padding
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w), (0, 0)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w),
|
||||
(0, 0),
|
||||
],
|
||||
)
|
||||
|
||||
T_padded = x.shape[1]
|
||||
H_padded, W_padded = x.shape[2], x.shape[3]
|
||||
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
|
||||
for d in range(kd):
|
||||
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
|
||||
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
|
||||
conv_out = mx.conv_general(frame, w2d,
|
||||
stride=(self.stride[1], self.stride[2]))
|
||||
conv_out = mx.conv_general(
|
||||
frame, w2d, stride=(self.stride[1], self.stride[2])
|
||||
)
|
||||
accum = conv_out if accum is None else accum + conv_out
|
||||
outputs.append(accum + self.bias)
|
||||
|
||||
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.scale = dim**0.5
|
||||
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
|
||||
self.gamma = mx.ones((dim,))
|
||||
|
||||
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
|
||||
# x: [..., C] (channels-last)
|
||||
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
|
||||
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
|
||||
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||
return (
|
||||
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||
)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
|
||||
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
|
||||
# We store as named layers matching PyTorch's indices
|
||||
self.residual = ResidualBlockLayers(in_dim, out_dim)
|
||||
self.shortcut = (
|
||||
CausalConv3d(in_dim, out_dim, 1)
|
||||
if in_dim != out_dim
|
||||
else None
|
||||
)
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||
h = self.shortcut(x) if self.shortcut is not None else x
|
||||
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
|
||||
# Save last CACHE_T frames before conv (for next chunk's context)
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||
out = conv(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
|
||||
x = self.norm(x)
|
||||
|
||||
# QKV via 1x1 conv2d (equivalent to linear on last dim)
|
||||
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
|
||||
qkv = (
|
||||
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
|
||||
) # [BT, H, W, 3C]
|
||||
qkv = qkv.reshape(B * T, H * W, 3 * C)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
|
||||
|
||||
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
|
||||
k = k[:, None, :, :]
|
||||
v = v[:, None, :, :]
|
||||
|
||||
scale = C ** -0.5
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
|
||||
scale = C**-0.5
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=scale
|
||||
) # [BT, 1, HW, C]
|
||||
out = out.squeeze(1).reshape(B * T, H, W, C)
|
||||
|
||||
# Project output
|
||||
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
|
||||
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
|
||||
|
||||
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
|
||||
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
|
||||
x = x.reshape(
|
||||
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
|
||||
)
|
||||
|
||||
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
|
||||
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
|
||||
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
|
||||
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
|
||||
x = x.reshape(
|
||||
B,
|
||||
T * self.factor_t,
|
||||
H * self.factor_s,
|
||||
W * self.factor_s,
|
||||
self.out_channels,
|
||||
)
|
||||
|
||||
if first_chunk:
|
||||
x = x[:, self.factor_t - 1:, :, :, :]
|
||||
x = x[:, self.factor_t - 1 :, :, :, :]
|
||||
return x
|
||||
|
||||
|
||||
@@ -348,7 +452,9 @@ class Resample(nn.Module):
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
||||
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported mode: {mode}")
|
||||
|
||||
@@ -369,7 +475,9 @@ class Resample(nn.Module):
|
||||
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
|
||||
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||
return (
|
||||
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||
)
|
||||
|
||||
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
|
||||
# x: [B, T, H, W, C]
|
||||
@@ -444,14 +552,17 @@ class Resample(nn.Module):
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
"""Upsampling residual block with optional DupUp3D shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
|
||||
):
|
||||
super().__init__()
|
||||
self.up_flag = up_flag
|
||||
|
||||
# DupUp3D shortcut (no learnable params)
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim, out_dim,
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2 if up_flag else 1,
|
||||
)
|
||||
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
|
||||
class Down_ResidualBlock(nn.Module):
|
||||
"""Downsampling residual block with AvgDown3D shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
num_res_blocks,
|
||||
temperal_downsample=False,
|
||||
down_flag=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.down_flag = down_flag
|
||||
|
||||
# AvgDown3D shortcut (no learnable params, always present)
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim, out_dim,
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
|
||||
self.upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
|
||||
self.upsamples.append(Up_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks + 1,
|
||||
temperal_upsample=t_up,
|
||||
up_flag=(i != len(dim_mult) - 1),
|
||||
))
|
||||
self.upsamples.append(
|
||||
Up_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks + 1,
|
||||
temperal_upsample=t_up,
|
||||
up_flag=(i != len(dim_mult) - 1),
|
||||
)
|
||||
)
|
||||
|
||||
# Output head: [RMS_norm, SiLU, CausalConv3d]
|
||||
self.head = Head22(dims[-1])
|
||||
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
|
||||
for i in range(len(dim_mult)):
|
||||
in_d, out_d = dims[i], dims[i + 1]
|
||||
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||
self.downsamples.append(Down_ResidualBlock(
|
||||
in_dim=in_d,
|
||||
out_dim=out_d,
|
||||
num_res_blocks=num_res_blocks,
|
||||
temperal_downsample=t_down,
|
||||
down_flag=(i < len(dim_mult) - 1),
|
||||
))
|
||||
self.downsamples.append(
|
||||
Down_ResidualBlock(
|
||||
in_dim=in_d,
|
||||
out_dim=out_d,
|
||||
num_res_blocks=num_res_blocks,
|
||||
temperal_downsample=t_down,
|
||||
down_flag=(i < len(dim_mult) - 1),
|
||||
)
|
||||
)
|
||||
|
||||
# Middle blocks (same as decoder)
|
||||
out_dim = dims[-1]
|
||||
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -700,9 +821,7 @@ class Head22(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate(
|
||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
||||
)
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||
x = self.layer_2(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
|
||||
if i == 0:
|
||||
chunk = x[:, :1]
|
||||
else:
|
||||
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
|
||||
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
|
||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
if out is None:
|
||||
out = chunk_out
|
||||
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
|
||||
|
||||
# conv1 (pointwise) + split into mu, log_var
|
||||
out = self.conv1(out)
|
||||
mu = out[:, :, :, :, :self.z_dim]
|
||||
mu = out[:, :, :, :, : self.z_dim]
|
||||
|
||||
# Normalize
|
||||
mu = normalize_latents(mu)
|
||||
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_cf,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||
causal_temporal=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user