format
This commit is contained in:
@@ -1,36 +1,33 @@
|
|||||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||||
from mlx_video.models.wan import WanModel, WanModelConfig
|
|
||||||
|
|
||||||
# Audio VAE components
|
# Audio VAE components
|
||||||
from mlx_video.models.ltx_2.audio_vae import (
|
from mlx_video.models.ltx_2.audio_vae import (
|
||||||
AudioDecoder,
|
AudioDecoder,
|
||||||
AudioEncoder,
|
AudioEncoder,
|
||||||
|
AudioLatentShape,
|
||||||
|
AudioPatchifier,
|
||||||
|
PerChannelStatistics,
|
||||||
Vocoder,
|
Vocoder,
|
||||||
decode_audio,
|
decode_audio,
|
||||||
AudioPatchifier,
|
|
||||||
AudioLatentShape,
|
|
||||||
PerChannelStatistics,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conditioning
|
# Conditioning
|
||||||
from mlx_video.models.ltx_2.conditioning import (
|
from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex
|
||||||
VideoConditionByLatentIndex,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
from mlx_video.models.ltx_2.utils import (
|
from mlx_video.models.ltx_2.utils import (
|
||||||
convert_audio_encoder,
|
convert_audio_encoder,
|
||||||
get_model_path,
|
get_model_path,
|
||||||
load_safetensors,
|
|
||||||
load_config,
|
load_config,
|
||||||
|
load_safetensors,
|
||||||
save_weights,
|
save_weights,
|
||||||
)
|
)
|
||||||
|
from mlx_video.models.wan import WanModel, WanModelConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Models
|
# Models
|
||||||
"LTXModel",
|
"LTXModel",
|
||||||
"LTXModelConfig",
|
"LTXModelConfig",
|
||||||
|
|
||||||
# Audio VAE
|
# Audio VAE
|
||||||
"AudioDecoder",
|
"AudioDecoder",
|
||||||
"AudioEncoder",
|
"AudioEncoder",
|
||||||
|
|||||||
@@ -6,10 +6,7 @@ from mlx_video.lora.apply import (
|
|||||||
apply_loras_to_model,
|
apply_loras_to_model,
|
||||||
apply_loras_to_weights,
|
apply_loras_to_weights,
|
||||||
)
|
)
|
||||||
from mlx_video.lora.loader import (
|
from mlx_video.lora.loader import load_lora_weights, load_multiple_loras
|
||||||
load_lora_weights,
|
|
||||||
load_multiple_loras,
|
|
||||||
)
|
|
||||||
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
|
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
|||||||
candidates = [lora_key]
|
candidates = [lora_key]
|
||||||
for prefix in prefixes_to_strip:
|
for prefix in prefixes_to_strip:
|
||||||
if lora_key.startswith(prefix):
|
if lora_key.startswith(prefix):
|
||||||
candidates.append(lora_key[len(prefix):])
|
candidates.append(lora_key[len(prefix) :])
|
||||||
|
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
# Try as-is
|
# Try as-is
|
||||||
@@ -80,33 +80,36 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
|||||||
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
|
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
|
||||||
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
|
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
|
||||||
if transformed.endswith(".ffn.0"):
|
if transformed.endswith(".ffn.0"):
|
||||||
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
|
transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1"
|
||||||
if transformed.endswith(".ffn.2"):
|
if transformed.endswith(".ffn.2"):
|
||||||
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
|
transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2"
|
||||||
|
|
||||||
# Text embedding: text_embedding.0 → text_embedding_0
|
# Text embedding: text_embedding.0 → text_embedding_0
|
||||||
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
|
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
|
||||||
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
|
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
|
||||||
if transformed.endswith("text_embedding.0"):
|
if transformed.endswith("text_embedding.0"):
|
||||||
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
|
transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0"
|
||||||
if transformed.endswith("text_embedding.2"):
|
if transformed.endswith("text_embedding.2"):
|
||||||
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
|
transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1"
|
||||||
|
|
||||||
# Time embedding: time_embedding.0 → time_embedding_0
|
# Time embedding: time_embedding.0 → time_embedding_0
|
||||||
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
|
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
|
||||||
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
|
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
|
||||||
if transformed.endswith("time_embedding.0"):
|
if transformed.endswith("time_embedding.0"):
|
||||||
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
|
transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0"
|
||||||
if transformed.endswith("time_embedding.2"):
|
if transformed.endswith("time_embedding.2"):
|
||||||
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
|
transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1"
|
||||||
|
|
||||||
# Time projection: time_projection.1 → time_projection
|
# Time projection: time_projection.1 → time_projection
|
||||||
transformed = transformed.replace("time_projection.1.", "time_projection.")
|
transformed = transformed.replace("time_projection.1.", "time_projection.")
|
||||||
if transformed.endswith("time_projection.1"):
|
if transformed.endswith("time_projection.1"):
|
||||||
transformed = transformed[:-len("time_projection.1")] + "time_projection"
|
transformed = transformed[: -len("time_projection.1")] + "time_projection"
|
||||||
|
|
||||||
# Patch embedding: patch_embedding → patch_embedding_proj
|
# Patch embedding: patch_embedding → patch_embedding_proj
|
||||||
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
|
if (
|
||||||
|
"patch_embedding" in transformed
|
||||||
|
and "patch_embedding_proj" not in transformed
|
||||||
|
):
|
||||||
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
|
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
|
||||||
|
|
||||||
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||||
@@ -115,7 +118,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
|||||||
# Return best attempt with prefix stripped
|
# Return best attempt with prefix stripped
|
||||||
for prefix in prefixes_to_strip:
|
for prefix in prefixes_to_strip:
|
||||||
if lora_key.startswith(prefix):
|
if lora_key.startswith(prefix):
|
||||||
return lora_key[len(prefix):]
|
return lora_key[len(prefix) :]
|
||||||
|
|
||||||
return lora_key
|
return lora_key
|
||||||
|
|
||||||
@@ -134,21 +137,25 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
|||||||
|
|
||||||
for prefix in prefixes_to_strip:
|
for prefix in prefixes_to_strip:
|
||||||
if lora_key.startswith(prefix):
|
if lora_key.startswith(prefix):
|
||||||
normalized = lora_key[len(prefix):]
|
normalized = lora_key[len(prefix) :]
|
||||||
|
|
||||||
if f"{normalized}.weight" in model_keys or normalized in model_keys:
|
if f"{normalized}.weight" in model_keys or normalized in model_keys:
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
transformed = normalized
|
transformed = normalized
|
||||||
if transformed.endswith(".to_out.0"):
|
if transformed.endswith(".to_out.0"):
|
||||||
transformed = transformed[:-len(".to_out.0")] + ".to_out"
|
transformed = transformed[: -len(".to_out.0")] + ".to_out"
|
||||||
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||||
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||||
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
||||||
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
transformed = transformed.replace(
|
||||||
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
|
".audio_ff.net.0.proj.", ".audio_ff.proj_in."
|
||||||
|
)
|
||||||
|
transformed = transformed.replace(
|
||||||
|
".audio_ff.net.0.proj", ".audio_ff.proj_in"
|
||||||
|
)
|
||||||
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||||
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
|
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
|
||||||
|
|
||||||
@@ -158,7 +165,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
|||||||
# Try transformations on the original key
|
# Try transformations on the original key
|
||||||
transformed = lora_key
|
transformed = lora_key
|
||||||
if transformed.endswith(".to_out.0"):
|
if transformed.endswith(".to_out.0"):
|
||||||
transformed = transformed[:-len(".to_out.0")] + ".to_out"
|
transformed = transformed[: -len(".to_out.0")] + ".to_out"
|
||||||
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||||
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||||
@@ -170,7 +177,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
|||||||
|
|
||||||
for prefix in prefixes_to_strip:
|
for prefix in prefixes_to_strip:
|
||||||
if lora_key.startswith(prefix):
|
if lora_key.startswith(prefix):
|
||||||
return lora_key[len(prefix):]
|
return lora_key[len(prefix) :]
|
||||||
|
|
||||||
return lora_key
|
return lora_key
|
||||||
|
|
||||||
@@ -226,7 +233,9 @@ def apply_loras_to_weights(
|
|||||||
skipped_count += 1
|
skipped_count += 1
|
||||||
skipped_modules.append(module_name)
|
skipped_modules.append(module_name)
|
||||||
if verbose and skipped_count <= 5:
|
if verbose and skipped_count <= 5:
|
||||||
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
|
print(
|
||||||
|
f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND"
|
||||||
|
)
|
||||||
similar = [
|
similar = [
|
||||||
k
|
k
|
||||||
for k in list(model_keys)[:1000]
|
for k in list(model_keys)[:1000]
|
||||||
@@ -251,13 +260,21 @@ def apply_loras_to_weights(
|
|||||||
if is_quantized:
|
if is_quantized:
|
||||||
scales = modified_weights[scales_key]
|
scales = modified_weights[scales_key]
|
||||||
biases = modified_weights[biases_key]
|
biases = modified_weights[biases_key]
|
||||||
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
|
group_size = (original_weight.shape[-1] * 32) // (
|
||||||
|
scales.shape[-1] * quantization_bits
|
||||||
|
)
|
||||||
dequantized = mx.dequantize(
|
dequantized = mx.dequantize(
|
||||||
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
|
original_weight,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
group_size=group_size,
|
||||||
|
bits=quantization_bits,
|
||||||
)
|
)
|
||||||
modified = apply_lora_to_linear(dequantized, loras)
|
modified = apply_lora_to_linear(dequantized, loras)
|
||||||
# Re-quantize with same parameters
|
# Re-quantize with same parameters
|
||||||
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
|
new_w, new_scales, new_biases = mx.quantize(
|
||||||
|
modified, group_size=group_size, bits=quantization_bits
|
||||||
|
)
|
||||||
modified_weights[weight_key] = new_w
|
modified_weights[weight_key] = new_w
|
||||||
modified_weights[scales_key] = new_scales
|
modified_weights[scales_key] = new_scales
|
||||||
modified_weights[biases_key] = new_biases
|
modified_weights[biases_key] = new_biases
|
||||||
@@ -346,9 +363,15 @@ def apply_loras_to_model(
|
|||||||
parent = model
|
parent = model
|
||||||
try:
|
try:
|
||||||
for part in parts[:-1]:
|
for part in parts[:-1]:
|
||||||
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
|
parent = (
|
||||||
|
getattr(parent, part) if not part.isdigit() else parent[int(part)]
|
||||||
|
)
|
||||||
leaf_name = parts[-1]
|
leaf_name = parts[-1]
|
||||||
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
|
target = (
|
||||||
|
getattr(parent, leaf_name)
|
||||||
|
if not leaf_name.isdigit()
|
||||||
|
else parent[int(leaf_name)]
|
||||||
|
)
|
||||||
except (AttributeError, IndexError, TypeError):
|
except (AttributeError, IndexError, TypeError):
|
||||||
skipped.append(lora_key)
|
skipped.append(lora_key)
|
||||||
if verbose:
|
if verbose:
|
||||||
@@ -358,8 +381,11 @@ def apply_loras_to_model(
|
|||||||
if isinstance(target, nn.QuantizedLinear):
|
if isinstance(target, nn.QuantizedLinear):
|
||||||
# Dequantize → merge LoRA → replace with bf16 Linear
|
# Dequantize → merge LoRA → replace with bf16 Linear
|
||||||
weight = mx.dequantize(
|
weight = mx.dequantize(
|
||||||
target.weight, target.scales, target.biases,
|
target.weight,
|
||||||
group_size=target.group_size, bits=target.bits,
|
target.scales,
|
||||||
|
target.biases,
|
||||||
|
group_size=target.group_size,
|
||||||
|
bits=target.bits,
|
||||||
)
|
)
|
||||||
merged = apply_lora_to_linear(weight, loras)
|
merged = apply_lora_to_linear(weight, loras)
|
||||||
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
|
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
|
||||||
@@ -379,7 +405,9 @@ def apply_loras_to_model(
|
|||||||
else:
|
else:
|
||||||
skipped.append(lora_key)
|
skipped.append(lora_key)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
|
print(
|
||||||
|
f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if applied_count > 0:
|
if applied_count > 0:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,2 @@
|
|||||||
|
|
||||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||||
from mlx_video.models.wan import WanModel, WanModelConfig
|
from mlx_video.models.wan import WanModel, WanModelConfig
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
|
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||||
from mlx_video.models.ltx_2.config import (
|
from mlx_video.models.ltx_2.config import (
|
||||||
LTXModelConfig,
|
LTXModelConfig,
|
||||||
TransformerConfig,
|
|
||||||
LTXModelType,
|
LTXModelType,
|
||||||
|
TransformerConfig,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
|
|||||||
|
|
||||||
class AdaLayerNormSingle(nn.Module):
|
class AdaLayerNormSingle(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
self.silu = nn.SiLU()
|
||||||
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
self.linear = nn.Linear(
|
||||||
|
embedding_dim, embedding_coefficient * embedding_dim, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -63,8 +64,12 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
|||||||
self.size_emb_dim = size_emb_dim
|
self.size_emb_dim = size_emb_dim
|
||||||
self.use_additional_conditions = use_additional_conditions
|
self.use_additional_conditions = use_additional_conditions
|
||||||
|
|
||||||
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
self.time_proj = Timesteps(
|
||||||
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||||
|
)
|
||||||
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
|
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
|
||||||
|
)
|
||||||
|
|
||||||
if use_additional_conditions and size_emb_dim > 0:
|
if use_additional_conditions and size_emb_dim > 0:
|
||||||
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||||
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
|||||||
# Add additional conditions if enabled
|
# Add additional conditions if enabled
|
||||||
if self.use_additional_conditions and self.size_emb_dim > 0:
|
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||||
if resolution is not None and aspect_ratio is not None:
|
if resolution is not None and aspect_ratio is not None:
|
||||||
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
additional_embeds = self.additional_embedder(
|
||||||
|
resolution, aspect_ratio, hidden_dtype
|
||||||
|
)
|
||||||
timesteps_emb = timesteps_emb + additional_embeds
|
timesteps_emb = timesteps_emb + additional_embeds
|
||||||
|
|
||||||
return timesteps_emb
|
return timesteps_emb
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
"""Audio VAE module for LTX-2 audio generation."""
|
"""Audio VAE module for LTX-2 audio generation."""
|
||||||
|
|
||||||
from .attention import AttentionType, AttnBlock, make_attn
|
|
||||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
|
||||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
|
||||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
|
||||||
from ..config import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
|
from .attention import AttentionType, AttnBlock, make_attn
|
||||||
|
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
|
||||||
|
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||||
|
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||||
from .downsample import Downsample, build_downsampling_path
|
from .downsample import Downsample, build_downsampling_path
|
||||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
|
|||||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
"""
|
"""
|
||||||
@@ -103,6 +105,8 @@ def make_attn(
|
|||||||
elif attn_type == AttentionType.NONE:
|
elif attn_type == AttentionType.NONE:
|
||||||
return Identity()
|
return Identity()
|
||||||
elif attn_type == AttentionType.LINEAR:
|
elif attn_type == AttentionType.LINEAR:
|
||||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
raise NotImplementedError(
|
||||||
|
f"Attention type {attn_type.value} is not supported yet."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||||
|
|||||||
@@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog
|
|||||||
using librosa for macOS/MLX compatibility.
|
using librosa for macOS/MLX compatibility.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def load_audio(
|
def load_audio(
|
||||||
@@ -99,14 +98,16 @@ def waveform_to_mel(
|
|||||||
|
|
||||||
for ch in range(channels):
|
for ch in range(channels):
|
||||||
# Magnitude spectrogram (power=1.0)
|
# Magnitude spectrogram (power=1.0)
|
||||||
S = np.abs(librosa.stft(
|
S = np.abs(
|
||||||
waveform[ch],
|
librosa.stft(
|
||||||
n_fft=n_fft,
|
waveform[ch],
|
||||||
hop_length=hop_length,
|
n_fft=n_fft,
|
||||||
win_length=win_length,
|
hop_length=hop_length,
|
||||||
center=True,
|
win_length=win_length,
|
||||||
pad_mode="reflect",
|
center=True,
|
||||||
))
|
pad_mode="reflect",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Mel filterbank with slaney normalization
|
# Mel filterbank with slaney normalization
|
||||||
mel_basis = librosa.filters.mel(
|
mel_basis = librosa.filters.mel(
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""Audio VAE encoder and decoder for LTX-2."""
|
"""Audio VAE encoder and decoder for LTX-2."""
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx_vlm.models.base import check_array_shape
|
from mlx_vlm.models.base import check_array_shape
|
||||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
|
||||||
|
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
|
||||||
from .attention import AttentionType, make_attn
|
from .attention import AttentionType, make_attn
|
||||||
from .causal_conv_2d import make_conv2d
|
from .causal_conv_2d import make_conv2d
|
||||||
from ..config import CausalityAxis
|
|
||||||
from .downsample import build_downsampling_path
|
from .downsample import build_downsampling_path
|
||||||
from .normalization import NormType, build_normalization_layer
|
from .normalization import NormType, build_normalization_layer
|
||||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||||
@@ -39,7 +39,9 @@ def build_mid_block(
|
|||||||
causality_axis=causality_axis,
|
causality_axis=causality_axis,
|
||||||
)
|
)
|
||||||
mid["attn_1"] = (
|
mid["attn_1"] = (
|
||||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
make_attn(channels, attn_type=attn_type, norm_type=norm_type)
|
||||||
|
if add_attention
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
mid["block_2"] = ResnetBlock(
|
mid["block_2"] = ResnetBlock(
|
||||||
in_channels=channels,
|
in_channels=channels,
|
||||||
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
|
|||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
|
|
||||||
self.conv_in = make_conv2d(
|
self.conv_in = make_conv2d(
|
||||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
config.in_channels,
|
||||||
|
self.ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
causality_axis=self.causality_axis,
|
causality_axis=self.causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
|
|||||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||||
self.conv_out = make_conv2d(
|
self.conv_out = make_conv2d(
|
||||||
block_in, out_channels, kernel_size=3, stride=1,
|
block_in,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
causality_axis=self.causality_axis,
|
causality_axis=self.causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
value = (
|
||||||
|
value
|
||||||
|
if check_array_shape(value)
|
||||||
|
else mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
)
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
return sanitized
|
return sanitized
|
||||||
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||||
"""Load audio encoder from pretrained weights."""
|
"""Load audio encoder from pretrained weights."""
|
||||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
config = AudioEncoderModelConfig.from_dict(
|
||||||
|
json.load(open(model_path / "config.json"))
|
||||||
|
)
|
||||||
encoder = cls(config)
|
encoder = cls(config)
|
||||||
weights = mx.load(str(model_path / "model.safetensors"))
|
weights = mx.load(str(model_path / "model.safetensors"))
|
||||||
encoder.load_weights(list(weights.items()), strict=True)
|
encoder.load_weights(list(weights.items()), strict=True)
|
||||||
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
# Per-channel statistics for denormalizing latents
|
# Per-channel statistics for denormalizing latents
|
||||||
# Uses ch (base channel count) to match the patchified latent dimension
|
# Uses ch (base channel count) to match the patchified latent dimension
|
||||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||||
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
|
|||||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||||
|
|
||||||
self.conv_in = make_conv2d(
|
self.conv_in = make_conv2d(
|
||||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
config.z_channels,
|
||||||
|
base_block_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=self.causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mid = build_mid_block(
|
self.mid = build_mid_block(
|
||||||
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
|
|||||||
initial_block_channels=base_block_channels,
|
initial_block_channels=base_block_channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
self.norm_out = build_normalization_layer(
|
||||||
|
final_block_channels, normtype=self.norm_type
|
||||||
|
)
|
||||||
self.conv_out = make_conv2d(
|
self.conv_out = make_conv2d(
|
||||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
final_block_channels,
|
||||||
|
config.out_ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=self.causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
|
|||||||
# PyTorch: (out_channels, in_channels, H, W)
|
# PyTorch: (out_channels, in_channels, H, W)
|
||||||
# MLX: (out_channels, H, W, in_channels)
|
# MLX: (out_channels, H, W, in_channels)
|
||||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
value = (
|
||||||
|
value
|
||||||
|
if check_array_shape(value)
|
||||||
|
else mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
)
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
|
||||||
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||||
"""Load audio VAE decoder from pretrained model."""
|
"""Load audio VAE decoder from pretrained model."""
|
||||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||||
|
|
||||||
|
config = AudioDecoderModelConfig.from_dict(
|
||||||
|
json.load(open(model_path / "config.json"))
|
||||||
|
)
|
||||||
decoder = cls(config)
|
decoder = cls(config)
|
||||||
weights = mx.load(str(model_path / "model.safetensors"))
|
weights = mx.load(str(model_path / "model.safetensors"))
|
||||||
# weights = decoder.sanitize(weights)
|
# weights = decoder.sanitize(weights)
|
||||||
decoder.load_weights(list(weights.items()), strict=True)
|
decoder.load_weights(list(weights.items()), strict=True)
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def __call__(self, sample: mx.array) -> mx.array:
|
def __call__(self, sample: mx.array) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Decode latent features back to audio spectrograms.
|
Decode latent features back to audio spectrograms.
|
||||||
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
|
|||||||
|
|
||||||
return self._adjust_output_shape(h, target_shape)
|
return self._adjust_output_shape(h, target_shape)
|
||||||
|
|
||||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
def _denormalize_latents(
|
||||||
|
self, sample: mx.array
|
||||||
|
) -> tuple[mx.array, AudioLatentShape]:
|
||||||
"""Denormalize latents using per-channel statistics."""
|
"""Denormalize latents using per-channel statistics."""
|
||||||
# sample shape: (B, H, W, C) in MLX format
|
# sample shape: (B, H, W, C) in MLX format
|
||||||
latent_shape = AudioLatentShape(
|
latent_shape = AudioLatentShape(
|
||||||
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
|
|||||||
batch=latent_shape.batch,
|
batch=latent_shape.batch,
|
||||||
channels=self.out_ch,
|
channels=self.out_ch,
|
||||||
frames=target_frames,
|
frames=target_frames,
|
||||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
mel_bins=(
|
||||||
|
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return sample, target_shape
|
return sample, target_shape
|
||||||
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
|
|||||||
|
|
||||||
# Step 1: Crop first to avoid exceeding target dimensions
|
# Step 1: Crop first to avoid exceeding target dimensions
|
||||||
decoded_output = decoded_output[
|
decoded_output = decoded_output[
|
||||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
:,
|
||||||
|
: min(current_time, target_time),
|
||||||
|
: min(current_freq, target_freq),
|
||||||
|
:target_channels,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||||
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
|
|||||||
return mx.tanh(h) if self.tanh_out else h
|
return mx.tanh(h) if self.tanh_out else h
|
||||||
|
|
||||||
|
|
||||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
def decode_audio(
|
||||||
|
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
|
||||||
|
) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
|
|||||||
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||||
if self.causality_axis == CausalityAxis.NONE:
|
if self.causality_axis == CausalityAxis.NONE:
|
||||||
# Non-causal: symmetric padding
|
# Non-causal: symmetric padding
|
||||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
self.padding = (
|
||||||
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
pad_h // 2,
|
||||||
|
pad_h - pad_h // 2,
|
||||||
|
pad_w // 2,
|
||||||
|
pad_w - pad_w // 2,
|
||||||
|
)
|
||||||
|
elif self.causality_axis in (
|
||||||
|
CausalityAxis.WIDTH,
|
||||||
|
CausalityAxis.WIDTH_COMPATIBILITY,
|
||||||
|
):
|
||||||
# Causal on width: pad left (before width axis)
|
# Causal on width: pad left (before width axis)
|
||||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||||
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
|
|||||||
if any(p > 0 for p in self.padding):
|
if any(p > 0 for p in self.padding):
|
||||||
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||||
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||||
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
x = mx.pad(
|
||||||
|
x,
|
||||||
|
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
|
||||||
|
)
|
||||||
|
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
@@ -124,7 +135,14 @@ def make_conv2d(
|
|||||||
if causality_axis is not None:
|
if causality_axis is not None:
|
||||||
# For causal convolution, padding is handled internally by CausalConv2d
|
# For causal convolution, padding is handled internally by CausalConv2d
|
||||||
return CausalConv2d(
|
return CausalConv2d(
|
||||||
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
bias,
|
||||||
|
causality_axis,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# For non-causal convolution, use symmetric padding if not specified
|
# For non-causal convolution, use symmetric padding if not specified
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import Set, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .attention import AttentionType, make_attn
|
|
||||||
from ..config import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
|
from .attention import AttentionType, make_attn
|
||||||
from .normalization import NormType
|
from .normalization import NormType
|
||||||
from .resnet import ResnetBlock
|
from .resnet import ResnetBlock
|
||||||
|
|
||||||
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
|
|||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# Do time downsampling here
|
# Do time downsampling here
|
||||||
# no asymmetric padding in MLX conv, must do it ourselves
|
# no asymmetric padding in MLX conv, must do it ourselves
|
||||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
"""
|
"""
|
||||||
@@ -116,10 +118,14 @@ def build_downsampling_path(
|
|||||||
)
|
)
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
if curr_res in attn_resolutions:
|
if curr_res in attn_resolutions:
|
||||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
stage["attn"][i_block] = make_attn(
|
||||||
|
block_in, attn_type=attn_type, norm_type=norm_type
|
||||||
|
)
|
||||||
|
|
||||||
if i_level != num_resolutions - 1:
|
if i_level != num_resolutions - 1:
|
||||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
stage["downsample"] = Downsample(
|
||||||
|
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||||
|
)
|
||||||
curr_res = curr_res // 2
|
curr_res = curr_res // 2
|
||||||
|
|
||||||
down_modules[i_level] = stage
|
down_modules[i_level] = stage
|
||||||
|
|||||||
@@ -51,7 +51,9 @@ def build_normalization_layer(
|
|||||||
A normalization layer
|
A normalization layer
|
||||||
"""
|
"""
|
||||||
if normtype == NormType.GROUP:
|
if normtype == NormType.GROUP:
|
||||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
return nn.GroupNorm(
|
||||||
|
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
if normtype == NormType.PIXEL:
|
if normtype == NormType.PIXEL:
|
||||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
"""ResNet blocks for audio VAE and vocoder."""
|
"""ResNet blocks for audio VAE and vocoder."""
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .causal_conv_2d import make_conv2d
|
|
||||||
from ..config import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
|
from .causal_conv_2d import make_conv2d
|
||||||
from .normalization import NormType, build_normalization_layer
|
from .normalization import NormType, build_normalization_layer
|
||||||
|
|
||||||
LRELU_SLOPE = 0.1
|
LRELU_SLOPE = 0.1
|
||||||
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||||
self.conv1 = make_conv2d(
|
self.conv1 = make_conv2d(
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
if temb_channels > 0:
|
if temb_channels > 0:
|
||||||
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
|
|||||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||||
self.dropout_rate = dropout
|
self.dropout_rate = dropout
|
||||||
self.conv2 = make_conv2d(
|
self.conv2 = make_conv2d(
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
out_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = make_conv2d(
|
self.conv_shortcut = make_conv2d(
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = make_conv2d(
|
self.nin_shortcut = make_conv2d(
|
||||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
|
|||||||
if temb is not None and self.temb_channels > 0:
|
if temb is not None and self.temb_channels > 0:
|
||||||
# temb: (B, temb_channels) -> (B, out_channels)
|
# temb: (B, temb_channels) -> (B, out_channels)
|
||||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
h = h + mx.expand_dims(
|
||||||
|
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
|
||||||
|
)
|
||||||
|
|
||||||
h = self.norm2(h)
|
h = self.norm2(h)
|
||||||
h = nn.silu(h)
|
h = nn.silu(h)
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from typing import Set, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from ..config import CausalityAxis
|
||||||
from .attention import AttentionType, make_attn
|
from .attention import AttentionType, make_attn
|
||||||
from .causal_conv_2d import make_conv2d
|
from .causal_conv_2d import make_conv2d
|
||||||
from ..config import CausalityAxis
|
|
||||||
from .normalization import NormType
|
from .normalization import NormType
|
||||||
from .resnet import ResnetBlock
|
from .resnet import ResnetBlock
|
||||||
|
|
||||||
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
|
|||||||
self.causality_axis = causality_axis
|
self.causality_axis = causality_axis
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = make_conv2d(
|
self.conv = make_conv2d(
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
@@ -124,10 +128,14 @@ def build_upsampling_path(
|
|||||||
)
|
)
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
if curr_res in attn_resolutions:
|
if curr_res in attn_resolutions:
|
||||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
stage["attn"][i_block] = make_attn(
|
||||||
|
block_in, attn_type=attn_type, norm_type=norm_type
|
||||||
|
)
|
||||||
|
|
||||||
if level != 0:
|
if level != 0:
|
||||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
stage["upsample"] = Upsample(
|
||||||
|
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||||
|
)
|
||||||
curr_res *= 2
|
curr_res *= 2
|
||||||
|
|
||||||
up_modules[level] = stage
|
up_modules[level] = stage
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ Supports:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -32,7 +32,9 @@ class Snake(nn.Module):
|
|||||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.alpha_logscale = alpha_logscale
|
self.alpha_logscale = alpha_logscale
|
||||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
self.alpha = (
|
||||||
|
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
# x: (N, L, C) in MLX format
|
# x: (N, L, C) in MLX format
|
||||||
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
|
|||||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.alpha_logscale = alpha_logscale
|
self.alpha_logscale = alpha_logscale
|
||||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
self.alpha = (
|
||||||
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||||
|
)
|
||||||
|
self.beta = (
|
||||||
|
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
alpha = self.alpha
|
alpha = self.alpha
|
||||||
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
|
def kaiser_sinc_filter1d(
|
||||||
|
cutoff: float, half_width: float, kernel_size: int
|
||||||
|
) -> mx.array:
|
||||||
"""Compute a Kaiser-windowed sinc filter."""
|
"""Compute a Kaiser-windowed sinc filter."""
|
||||||
even = kernel_size % 2 == 0
|
even = kernel_size % 2 == 0
|
||||||
half_size = kernel_size // 2
|
half_size = kernel_size // 2
|
||||||
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
|||||||
|
|
||||||
# Kaiser window - compute using scipy-compatible formula
|
# Kaiser window - compute using scipy-compatible formula
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
||||||
|
|
||||||
if even:
|
if even:
|
||||||
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
|||||||
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
||||||
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
rolloff = 0.99
|
rolloff = 0.99
|
||||||
lowpass_filter_width = 6
|
lowpass_filter_width = 6
|
||||||
width = math.ceil(lowpass_filter_width / rolloff)
|
width = math.ceil(lowpass_filter_width / rolloff)
|
||||||
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
|
|||||||
self.kernel_size = filt.shape[2]
|
self.kernel_size = filt.shape[2]
|
||||||
self.filter = filt
|
self.filter = filt
|
||||||
else:
|
else:
|
||||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
self.kernel_size = (
|
||||||
|
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
)
|
||||||
self.pad = self.kernel_size // ratio - 1
|
self.pad = self.kernel_size // ratio - 1
|
||||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
self.pad_left = (
|
||||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
)
|
||||||
|
self.pad_right = (
|
||||||
|
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||||
|
)
|
||||||
self.filter = kaiser_sinc_filter1d(
|
self.filter = kaiser_sinc_filter1d(
|
||||||
cutoff=0.5 / ratio,
|
cutoff=0.5 / ratio,
|
||||||
half_width=0.6 / ratio,
|
half_width=0.6 / ratio,
|
||||||
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
|
|||||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||||
|
|
||||||
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
|
x = self.ratio * mx.conv_transpose1d(
|
||||||
|
x, filt, stride=self.stride
|
||||||
|
) # (N*C, L', 1)
|
||||||
|
|
||||||
# Trim padding
|
# Trim padding
|
||||||
x = x[:, self.pad_left:-self.pad_right, :]
|
x = x[:, self.pad_left : -self.pad_right, :]
|
||||||
|
|
||||||
x = x.reshape(n, c, -1) # (N, C, L')
|
x = x.reshape(n, c, -1) # (N, C, L')
|
||||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||||
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
|
|||||||
|
|
||||||
self.convs1 = {
|
self.convs1 = {
|
||||||
i: nn.Conv1d(
|
i: nn.Conv1d(
|
||||||
channels, channels, kernel_size, stride=1,
|
channels,
|
||||||
dilation=d, padding=get_padding(kernel_size, d),
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=get_padding(kernel_size, d),
|
||||||
)
|
)
|
||||||
for i, d in enumerate(dilation)
|
for i, d in enumerate(dilation)
|
||||||
}
|
}
|
||||||
|
|
||||||
self.convs2 = {
|
self.convs2 = {
|
||||||
i: nn.Conv1d(
|
i: nn.Conv1d(
|
||||||
channels, channels, kernel_size, stride=1,
|
channels,
|
||||||
dilation=1, padding=get_padding(kernel_size, 1),
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
)
|
)
|
||||||
for i in range(len(dilation))
|
for i in range(len(dilation))
|
||||||
}
|
}
|
||||||
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
|
|||||||
y = mx.concatenate([first, y], axis=1)
|
y = mx.concatenate([first, y], axis=1)
|
||||||
|
|
||||||
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
||||||
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
|
basis = mx.transpose(
|
||||||
|
self.forward_basis.astype(y.dtype), (0, 2, 1)
|
||||||
|
) # (514, K, 1)
|
||||||
|
|
||||||
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
||||||
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
||||||
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
|
|||||||
real = spec[..., :n_freqs]
|
real = spec[..., :n_freqs]
|
||||||
imag = spec[..., n_freqs:]
|
imag = spec[..., n_freqs:]
|
||||||
|
|
||||||
magnitude = mx.sqrt(real ** 2 + imag ** 2)
|
magnitude = mx.sqrt(real**2 + imag**2)
|
||||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
|
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
|
||||||
|
real.dtype
|
||||||
|
)
|
||||||
|
|
||||||
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
||||||
return magnitude, phase
|
return magnitude, phase
|
||||||
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
|
|||||||
class MelSTFT(nn.Module):
|
class MelSTFT(nn.Module):
|
||||||
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
||||||
|
|
||||||
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
|
def __init__(
|
||||||
|
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
||||||
n_freqs = filter_length // 2 + 1
|
n_freqs = filter_length // 2 + 1
|
||||||
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
|
|||||||
"""
|
"""
|
||||||
magnitude, phase = self.stft_fn(y)
|
magnitude, phase = self.stft_fn(y)
|
||||||
# magnitude: (B, T_frames, n_freqs)
|
# magnitude: (B, T_frames, n_freqs)
|
||||||
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
|
mel = (
|
||||||
|
magnitude @ self.mel_basis.astype(magnitude.dtype).T
|
||||||
|
) # (B, T_frames, n_mels)
|
||||||
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
||||||
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
||||||
return mx.transpose(log_mel, (0, 2, 1))
|
return mx.transpose(log_mel, (0, 2, 1))
|
||||||
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
|
|||||||
|
|
||||||
in_channels = 128 if config.stereo else 64
|
in_channels = 128 if config.stereo else 64
|
||||||
self.conv_pre = nn.Conv1d(
|
self.conv_pre = nn.Conv1d(
|
||||||
in_channels, config.upsample_initial_channel,
|
in_channels,
|
||||||
kernel_size=7, stride=1, padding=3,
|
config.upsample_initial_channel,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=1,
|
||||||
|
padding=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Upsampling layers
|
# Upsampling layers
|
||||||
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
|
|||||||
for i, (stride, kernel_size) in enumerate(
|
for i, (stride, kernel_size) in enumerate(
|
||||||
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
||||||
):
|
):
|
||||||
in_ch = config.upsample_initial_channel // (2 ** i)
|
in_ch = config.upsample_initial_channel // (2**i)
|
||||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||||
self.ups[i] = nn.ConvTranspose1d(
|
self.ups[i] = nn.ConvTranspose1d(
|
||||||
in_ch, out_ch,
|
in_ch,
|
||||||
kernel_size=kernel_size, stride=stride,
|
out_ch,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
padding=(kernel_size - stride) // 2,
|
padding=(kernel_size - stride) // 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
|
|||||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||||
):
|
):
|
||||||
self.resblocks[block_idx] = AMPBlock1(
|
self.resblocks[block_idx] = AMPBlock1(
|
||||||
ch, kernel_size, tuple(dilations),
|
ch,
|
||||||
|
kernel_size,
|
||||||
|
tuple(dilations),
|
||||||
activation=config.activation,
|
activation=config.activation,
|
||||||
)
|
)
|
||||||
block_idx += 1
|
block_idx += 1
|
||||||
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
|
|||||||
for kernel_size, dilations in zip(
|
for kernel_size, dilations in zip(
|
||||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||||
):
|
):
|
||||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
self.resblocks[block_idx] = resblock_class(
|
||||||
|
ch, kernel_size, tuple(dilations)
|
||||||
|
)
|
||||||
block_idx += 1
|
block_idx += 1
|
||||||
|
|
||||||
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
|
final_channels = config.upsample_initial_channel // (
|
||||||
|
2 ** len(config.upsample_rates)
|
||||||
|
)
|
||||||
|
|
||||||
# Post-activation
|
# Post-activation
|
||||||
if self.is_amp:
|
if self.is_amp:
|
||||||
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
|
|||||||
# Final conv
|
# Final conv
|
||||||
out_channels = 2 if config.stereo else 1
|
out_channels = 2 if config.stereo else 1
|
||||||
self.conv_post = nn.Conv1d(
|
self.conv_post = nn.Conv1d(
|
||||||
final_channels, out_channels,
|
final_channels,
|
||||||
kernel_size=7, stride=1, padding=3,
|
out_channels,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=1,
|
||||||
|
padding=3,
|
||||||
bias=config.use_bias_at_final,
|
bias=config.use_bias_at_final,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
|
|||||||
"""
|
"""
|
||||||
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
||||||
_, _, length_low_rate = x.shape
|
_, _, length_low_rate = x.shape
|
||||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
output_length = (
|
||||||
|
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||||
|
)
|
||||||
|
|
||||||
# Pad to hop_length multiple
|
# Pad to hop_length multiple
|
||||||
remainder = length_low_rate % self.hop_length
|
remainder = length_low_rate % self.hop_length
|
||||||
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
|
|||||||
|
|
||||||
model.load_weights(list(weights.items()), strict=False)
|
model.load_weights(list(weights.items()), strict=False)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
"""Conditioning modules for LTX-2 video generation."""
|
"""Conditioning modules for LTX-2 video generation."""
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
from mlx_video.models.ltx_2.conditioning.latent import (
|
||||||
|
VideoConditionByLatentIndex,
|
||||||
|
apply_conditioning,
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
|
|||||||
frame_idx: Frame index to condition (0 = first frame)
|
frame_idx: Frame index to condition (0 = first frame)
|
||||||
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
latent: mx.array
|
latent: mx.array
|
||||||
frame_idx: int = 0
|
frame_idx: int = 0
|
||||||
strength: float = 1.0
|
strength: float = 1.0
|
||||||
@@ -41,6 +42,7 @@ class LatentState:
|
|||||||
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
||||||
1.0 = full denoise, 0.0 = keep clean
|
1.0 = full denoise, 0.0 = keep clean
|
||||||
"""
|
"""
|
||||||
|
|
||||||
latent: mx.array
|
latent: mx.array
|
||||||
clean_latent: mx.array
|
clean_latent: mx.array
|
||||||
denoise_mask: mx.array
|
denoise_mask: mx.array
|
||||||
@@ -130,15 +132,15 @@ def apply_conditioning(
|
|||||||
if frame_idx <= i < end_idx:
|
if frame_idx <= i < end_idx:
|
||||||
# Use conditioning latent
|
# Use conditioning latent
|
||||||
cond_idx = i - frame_idx
|
cond_idx = i - frame_idx
|
||||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||||
else:
|
else:
|
||||||
# Keep original
|
# Keep original
|
||||||
latent_list.append(state.latent[:, :, i:i+1])
|
latent_list.append(state.latent[:, :, i : i + 1])
|
||||||
clean_list.append(state.clean_latent[:, :, i:i+1])
|
clean_list.append(state.clean_latent[:, :, i : i + 1])
|
||||||
mask_list.append(state.denoise_mask[:, :, i:i+1])
|
mask_list.append(state.denoise_mask[:, :, i : i + 1])
|
||||||
|
|
||||||
state.latent = mx.concatenate(latent_list, axis=2)
|
state.latent = mx.concatenate(latent_list, axis=2)
|
||||||
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -22,9 +21,11 @@ class LTXRopeType(Enum):
|
|||||||
SPLIT = "split"
|
SPLIT = "split"
|
||||||
TWO_D = "2d"
|
TWO_D = "2d"
|
||||||
|
|
||||||
|
|
||||||
class AttentionType(Enum):
|
class AttentionType(Enum):
|
||||||
DEFAULT = "default"
|
DEFAULT = "default"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelConfig:
|
class BaseModelConfig:
|
||||||
|
|
||||||
@@ -46,7 +47,7 @@ class BaseModelConfig:
|
|||||||
if v is not None:
|
if v is not None:
|
||||||
if isinstance(v, Enum):
|
if isinstance(v, Enum):
|
||||||
result[k] = v.value
|
result[k] = v.value
|
||||||
elif hasattr(v, 'to_dict'):
|
elif hasattr(v, "to_dict"):
|
||||||
result[k] = v.to_dict()
|
result[k] = v.to_dict()
|
||||||
else:
|
else:
|
||||||
result[k] = v
|
result[k] = v
|
||||||
@@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig):
|
|||||||
out_channels: int = 128
|
out_channels: int = 128
|
||||||
latent_channels: int = 128
|
latent_channels: int = 128
|
||||||
patch_size: int = 4
|
patch_size: int = 4
|
||||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
encoder_blocks: List[tuple] = field(
|
||||||
("res_x", {"num_layers": 4}),
|
default_factory=lambda: [
|
||||||
("compress_space_res", {"multiplier": 2}),
|
("res_x", {"num_layers": 4}),
|
||||||
("res_x", {"num_layers": 6}),
|
("compress_space_res", {"multiplier": 2}),
|
||||||
("compress_time_res", {"multiplier": 2}),
|
("res_x", {"num_layers": 6}),
|
||||||
("res_x", {"num_layers": 6}),
|
("compress_time_res", {"multiplier": 2}),
|
||||||
("compress_all_res", {"multiplier": 2}),
|
("res_x", {"num_layers": 6}),
|
||||||
("res_x", {"num_layers": 2}),
|
("compress_all_res", {"multiplier": 2}),
|
||||||
("compress_all_res", {"multiplier": 2}),
|
("res_x", {"num_layers": 2}),
|
||||||
("res_x", {"num_layers": 2}),
|
("compress_all_res", {"multiplier": 2}),
|
||||||
])
|
("res_x", {"num_layers": 2}),
|
||||||
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
]
|
||||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
)
|
||||||
("compress_all", {"residual": True, "multiplier": 2}),
|
decoder_blocks: List[tuple] = field(
|
||||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
default_factory=lambda: [
|
||||||
("compress_all", {"residual": True, "multiplier": 2}),
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
("compress_all", {"residual": True, "multiplier": 2}),
|
||||||
("compress_all", {"residual": True, "multiplier": 2}),
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
("compress_all", {"residual": True, "multiplier": 2}),
|
||||||
])
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
|
("compress_all", {"residual": True, "multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig):
|
|||||||
audio_in_channels: int = 128
|
audio_in_channels: int = 128
|
||||||
audio_out_channels: int = 128
|
audio_out_channels: int = 128
|
||||||
audio_cross_attention_dim: int = 2048
|
audio_cross_attention_dim: int = 2048
|
||||||
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
|
audio_caption_channels: int = (
|
||||||
|
3840 # Input dim for audio text embeddings (same as video)
|
||||||
|
)
|
||||||
|
|
||||||
# Positional embedding config
|
# Positional embedding config
|
||||||
positional_embedding_theta: float = 10000.0
|
positional_embedding_theta: float = 10000.0
|
||||||
@@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CausalityAxis(Enum):
|
class CausalityAxis(Enum):
|
||||||
"""Enum for specifying the causality axis in causal convolutions."""
|
"""Enum for specifying the causality axis in causal convolutions."""
|
||||||
|
|
||||||
@@ -237,8 +243,8 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Convert string enum values to proper enum types."""
|
"""Convert string enum values to proper enum types."""
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
from .audio_vae.normalization import NormType
|
|
||||||
from .audio_vae.attention import AttentionType
|
from .audio_vae.attention import AttentionType
|
||||||
|
from .audio_vae.normalization import NormType
|
||||||
|
|
||||||
# Convert causality_axis string to enum
|
# Convert causality_axis string to enum
|
||||||
if isinstance(self.causality_axis, str):
|
if isinstance(self.causality_axis, str):
|
||||||
@@ -252,6 +258,7 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
|||||||
if isinstance(self.attn_type, str):
|
if isinstance(self.attn_type, str):
|
||||||
self.attn_type = AttentionType(self.attn_type)
|
self.attn_type = AttentionType(self.attn_type)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AudioEncoderModelConfig(BaseModelConfig):
|
class AudioEncoderModelConfig(BaseModelConfig):
|
||||||
ch: int = 128
|
ch: int = 128
|
||||||
@@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig):
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Convert string enum values to proper enum types."""
|
"""Convert string enum values to proper enum types."""
|
||||||
from .audio_vae.normalization import NormType
|
|
||||||
from .audio_vae.attention import AttentionType
|
from .audio_vae.attention import AttentionType
|
||||||
|
from .audio_vae.normalization import NormType
|
||||||
|
|
||||||
if isinstance(self.causality_axis, str):
|
if isinstance(self.causality_axis, str):
|
||||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||||
@@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
|
|||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
timestep_conditioning: bool = False
|
timestep_conditioning: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VideoEncoderModelConfig(BaseModelConfig):
|
class VideoEncoderModelConfig(BaseModelConfig):
|
||||||
convolution_dimensions: int = 3
|
convolution_dimensions: int = 3
|
||||||
@@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
|||||||
norm_layer: Enum = None
|
norm_layer: Enum = None
|
||||||
latent_log_var: Enum = None
|
latent_log_var: Enum = None
|
||||||
encoder_spatial_padding_mode: Enum = None
|
encoder_spatial_padding_mode: Enum = None
|
||||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
encoder_blocks: List[tuple] = field(
|
||||||
("compress_space_res", {"multiplier": 2}),
|
default_factory=lambda: [
|
||||||
("res_x", {"num_layers": 6}),
|
("res_x", {"num_layers": 4}),
|
||||||
("compress_time_res", {"multiplier": 2}),
|
("compress_space_res", {"multiplier": 2}),
|
||||||
("res_x", {"num_layers": 6}),
|
("res_x", {"num_layers": 6}),
|
||||||
("compress_all_res", {"multiplier": 2}),
|
("compress_time_res", {"multiplier": 2}),
|
||||||
("res_x", {"num_layers": 2}),
|
("res_x", {"num_layers": 6}),
|
||||||
("compress_all_res", {"multiplier": 2}),
|
("compress_all_res", {"multiplier": 2}),
|
||||||
("res_x", {"num_layers": 2})
|
("res_x", {"num_layers": 2}),
|
||||||
])
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
|
||||||
|
|
||||||
if self.norm_layer is None:
|
if self.norm_layer is None:
|
||||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||||
@@ -371,7 +382,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
|||||||
if isinstance(self.latent_log_var, str):
|
if isinstance(self.latent_log_var, str):
|
||||||
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||||
if isinstance(self.encoder_spatial_padding_mode, str):
|
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||||
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
self.encoder_spatial_padding_mode = PaddingModeType(
|
||||||
|
self.encoder_spatial_padding_mode
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
result = super().to_dict()
|
result = super().to_dict()
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ from typing import Dict
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
TRANSFORMER_PREFIX = "model.diffusion_model."
|
TRANSFORMER_PREFIX = "model.diffusion_model."
|
||||||
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_key = key[len(TRANSFORMER_PREFIX):]
|
new_key = key[len(TRANSFORMER_PREFIX) :]
|
||||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
elif key.startswith(VAE_DECODER_PREFIX):
|
elif key.startswith(VAE_DECODER_PREFIX):
|
||||||
new_key = key[len(VAE_DECODER_PREFIX):]
|
new_key = key[len(VAE_DECODER_PREFIX) :]
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
if value.dtype != mx.float32:
|
if value.dtype != mx.float32:
|
||||||
value = value.astype(mx.float32)
|
value = value.astype(mx.float32)
|
||||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
elif key.startswith(VAE_ENCODER_PREFIX):
|
||||||
new_key = key[len(VAE_ENCODER_PREFIX):]
|
new_key = key[len(VAE_ENCODER_PREFIX) :]
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
new_key = None
|
new_key = None
|
||||||
|
|
||||||
if key.startswith(AUDIO_DECODER_PREFIX):
|
if key.startswith(AUDIO_DECODER_PREFIX):
|
||||||
new_key = key[len(AUDIO_DECODER_PREFIX):]
|
new_key = key[len(AUDIO_DECODER_PREFIX) :]
|
||||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||||
if "mean-of-means" in key:
|
if "mean-of-means" in key:
|
||||||
new_key = "per_channel_statistics.mean_of_means"
|
new_key = "per_channel_statistics.mean_of_means"
|
||||||
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
new_key = None
|
new_key = None
|
||||||
|
|
||||||
if key.startswith(AUDIO_ENCODER_PREFIX):
|
if key.startswith(AUDIO_ENCODER_PREFIX):
|
||||||
new_key = key[len(AUDIO_ENCODER_PREFIX):]
|
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
|
||||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||||
if "mean-of-means" in key:
|
if "mean-of-means" in key:
|
||||||
new_key = "per_channel_statistics.mean_of_means"
|
new_key = "per_channel_statistics.mean_of_means"
|
||||||
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
if not key.startswith(VOCODER_PREFIX):
|
if not key.startswith(VOCODER_PREFIX):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_key = key[len(VOCODER_PREFIX):]
|
new_key = key[len(VOCODER_PREFIX) :]
|
||||||
|
|
||||||
# Handle Conv1d/ConvTranspose1d weight shape conversion
|
# Handle Conv1d/ConvTranspose1d weight shape conversion
|
||||||
if "weight" in new_key and value.ndim == 3:
|
if "weight" in new_key and value.ndim == 3:
|
||||||
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
|||||||
# aggregate_embed weights (text_embedding_projection.*)
|
# aggregate_embed weights (text_embedding_projection.*)
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
if key.startswith(TEXT_PROJ_PREFIX):
|
if key.startswith(TEXT_PROJ_PREFIX):
|
||||||
new_key = key[len(TEXT_PROJ_PREFIX):]
|
new_key = key[len(TEXT_PROJ_PREFIX) :]
|
||||||
extracted[new_key] = value
|
extracted[new_key] = value
|
||||||
|
|
||||||
# video_embeddings_connector
|
# video_embeddings_connector
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
if key.startswith(VIDEO_CONNECTOR_PREFIX):
|
if key.startswith(VIDEO_CONNECTOR_PREFIX):
|
||||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
|
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
|
||||||
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
|
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
|
||||||
extracted[new_key] = value
|
extracted[new_key] = value
|
||||||
|
|
||||||
# audio_embeddings_connector
|
# audio_embeddings_connector
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
if key.startswith(AUDIO_CONNECTOR_PREFIX):
|
if key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
|
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
|
||||||
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
|
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
|
||||||
extracted[new_key] = value
|
extracted[new_key] = value
|
||||||
|
|
||||||
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
|
|||||||
# ─── Source resolution ─────────────────────────────────────────────────────────
|
# ─── Source resolution ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
||||||
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
|
MONOLITHIC_PATTERN = re.compile(
|
||||||
|
r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
|
||||||
|
)
|
||||||
|
|
||||||
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
||||||
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
||||||
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
|
UPSCALER_PATTERN = re.compile(
|
||||||
|
r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_source(source: str, variant: str) -> Path:
|
def resolve_source(source: str, variant: str) -> Path:
|
||||||
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
|
|||||||
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
||||||
"""Infer VAE decoder config from weights."""
|
"""Infer VAE decoder config from weights."""
|
||||||
# Check for timestep conditioning keys
|
# Check for timestep conditioning keys
|
||||||
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
|
has_timestep = any(
|
||||||
|
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
|
||||||
|
)
|
||||||
|
|
||||||
# Count channel multipliers from up_blocks
|
# Count channel multipliers from up_blocks
|
||||||
max_block = -1
|
max_block = -1
|
||||||
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
config = infer_transformer_config(transformer_weights)
|
config = infer_transformer_config(transformer_weights)
|
||||||
save_config(config, output_path / "transformer")
|
save_config(config, output_path / "transformer")
|
||||||
t_params = sum(v.size for v in transformer_weights.values())
|
t_params = sum(v.size for v in transformer_weights.values())
|
||||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
print(
|
||||||
|
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
|
||||||
|
)
|
||||||
|
|
||||||
# 2. VAE Decoder
|
# 2. VAE Decoder
|
||||||
print(" [2/7] VAE Decoder...")
|
print(" [2/7] VAE Decoder...")
|
||||||
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
upscaler_files = [
|
upscaler_files = [
|
||||||
f.name for f in source_dir.iterdir()
|
f.name
|
||||||
|
for f in source_dir.iterdir()
|
||||||
if f.is_file() and UPSCALER_PATTERN.match(f.name)
|
if f.is_file() and UPSCALER_PATTERN.match(f.name)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
||||||
if all_converted < total_keys:
|
if all_converted < total_keys:
|
||||||
known_prefixes = (
|
known_prefixes = (
|
||||||
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
|
TRANSFORMER_PREFIX,
|
||||||
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
|
VAE_DECODER_PREFIX,
|
||||||
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
|
VAE_ENCODER_PREFIX,
|
||||||
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
|
VAE_STATS_PREFIX,
|
||||||
|
AUDIO_DECODER_PREFIX,
|
||||||
|
AUDIO_ENCODER_PREFIX,
|
||||||
|
AUDIO_STATS_PREFIX,
|
||||||
|
VOCODER_PREFIX,
|
||||||
|
TEXT_PROJ_PREFIX,
|
||||||
|
VIDEO_CONNECTOR_PREFIX,
|
||||||
|
AUDIO_CONNECTOR_PREFIX,
|
||||||
)
|
)
|
||||||
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
|
skipped = [
|
||||||
|
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
|
||||||
|
]
|
||||||
if skipped:
|
if skipped:
|
||||||
print(f" Skipped {len(skipped)} keys:")
|
print(f" Skipped {len(skipped)} keys:")
|
||||||
for k in sorted(skipped)[:20]:
|
for k in sorted(skipped)[:20]:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,14 @@
|
|||||||
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from pathlib import Path
|
|
||||||
|
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||||
from mlx_video.models.ltx_2.config import (
|
from mlx_video.models.ltx_2.config import (
|
||||||
LTXModelConfig,
|
LTXModelConfig,
|
||||||
LTXModelType,
|
|
||||||
LTXRopeType,
|
LTXRopeType,
|
||||||
TransformerConfig,
|
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
|
||||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||||
from mlx_video.models.ltx_2.transformer import (
|
from mlx_video.models.ltx_2.transformer import (
|
||||||
@@ -58,11 +57,17 @@ class TransformerArgsPreprocessor:
|
|||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
timestep_emb, embedded_timestep = self.adaln(
|
||||||
|
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
# Reshape to (batch, tokens, dim)
|
# Reshape to (batch, tokens, dim)
|
||||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
timestep_emb = mx.reshape(
|
||||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||||
|
)
|
||||||
|
embedded_timestep = mx.reshape(
|
||||||
|
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
)
|
||||||
|
|
||||||
return timestep_emb, embedded_timestep
|
return timestep_emb, embedded_timestep
|
||||||
|
|
||||||
@@ -74,9 +79,15 @@ class TransformerArgsPreprocessor:
|
|||||||
hidden_dtype: mx.Dtype = None,
|
hidden_dtype: mx.Dtype = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
timestep_emb, embedded_timestep = adaln(
|
||||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
)
|
||||||
|
timestep_emb = mx.reshape(
|
||||||
|
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||||
|
)
|
||||||
|
embedded_timestep = mx.reshape(
|
||||||
|
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
)
|
||||||
return timestep_emb, embedded_timestep
|
return timestep_emb, embedded_timestep
|
||||||
|
|
||||||
def _prepare_context(
|
def _prepare_context(
|
||||||
@@ -107,7 +118,9 @@ class TransformerArgsPreprocessor:
|
|||||||
# Convert boolean/int mask to float mask
|
# Convert boolean/int mask to float mask
|
||||||
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||||
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||||
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
mask = mx.reshape(
|
||||||
|
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||||
|
)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def _prepare_positional_embeddings(
|
def _prepare_positional_embeddings(
|
||||||
@@ -132,9 +145,15 @@ class TransformerArgsPreprocessor:
|
|||||||
|
|
||||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||||
x = self.patchify_proj(modality.latent)
|
x = self.patchify_proj(modality.latent)
|
||||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
timestep, embedded_timestep = self._prepare_timestep(
|
||||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
|
||||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
)
|
||||||
|
context, attention_mask = self._prepare_context(
|
||||||
|
modality.context, x, modality.context_mask
|
||||||
|
)
|
||||||
|
attention_mask = self._prepare_attention_mask(
|
||||||
|
attention_mask, modality.latent.dtype
|
||||||
|
)
|
||||||
|
|
||||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||||
if modality.positional_embeddings is not None:
|
if modality.positional_embeddings is not None:
|
||||||
@@ -152,8 +171,13 @@ class TransformerArgsPreprocessor:
|
|||||||
prompt_timestep = None
|
prompt_timestep = None
|
||||||
prompt_embedded_timestep = None
|
prompt_embedded_timestep = None
|
||||||
if self.prompt_adaln is not None and modality.sigma is not None:
|
if self.prompt_adaln is not None and modality.sigma is not None:
|
||||||
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
|
prompt_timestep, prompt_embedded_timestep = (
|
||||||
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
|
self._prepare_timestep_with_adaln(
|
||||||
|
self.prompt_adaln,
|
||||||
|
modality.sigma,
|
||||||
|
x.shape[0],
|
||||||
|
hidden_dtype=x.dtype,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return TransformerArgs(
|
return TransformerArgs(
|
||||||
@@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Prepare cross-attention timestep embeddings
|
# Prepare cross-attention timestep embeddings
|
||||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
cross_scale_shift_timestep, cross_gate_timestep = (
|
||||||
timestep=modality.timesteps,
|
self._prepare_cross_attention_timestep(
|
||||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
timestep=modality.timesteps,
|
||||||
batch_size=transformer_args.x.shape[0],
|
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||||
hidden_dtype=transformer_args.x.dtype,
|
batch_size=transformer_args.x.shape[0],
|
||||||
|
hidden_dtype=transformer_args.x.dtype,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return replace(
|
return replace(
|
||||||
@@ -254,11 +280,19 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
|
|
||||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||||
|
|
||||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
|
||||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||||
|
)
|
||||||
|
scale_shift_timestep = mx.reshape(
|
||||||
|
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
|
||||||
|
)
|
||||||
|
|
||||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
gate_timestep, _ = self.cross_gate_adaln(
|
||||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
|
||||||
|
)
|
||||||
|
gate_timestep = mx.reshape(
|
||||||
|
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
|
||||||
|
)
|
||||||
|
|
||||||
return scale_shift_timestep, gate_timestep
|
return scale_shift_timestep, gate_timestep
|
||||||
|
|
||||||
@@ -285,18 +319,25 @@ class LTXModel(nn.Module):
|
|||||||
self._init_video(config)
|
self._init_video(config)
|
||||||
|
|
||||||
if config.model_type.is_audio_enabled():
|
if config.model_type.is_audio_enabled():
|
||||||
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
|
self.audio_positional_embedding_max_pos = (
|
||||||
|
config.audio_positional_embedding_max_pos
|
||||||
|
)
|
||||||
self.audio_num_attention_heads = config.audio_num_attention_heads
|
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||||
self.audio_inner_dim = config.audio_inner_dim
|
self.audio_inner_dim = config.audio_inner_dim
|
||||||
self._init_audio(config)
|
self._init_audio(config)
|
||||||
|
|
||||||
# Initialize cross-modal components
|
# Initialize cross-modal components
|
||||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
if (
|
||||||
|
config.model_type.is_video_enabled()
|
||||||
|
and config.model_type.is_audio_enabled()
|
||||||
|
):
|
||||||
cross_pe_max_pos = max(
|
cross_pe_max_pos = max(
|
||||||
config.positional_embedding_max_pos[0],
|
config.positional_embedding_max_pos[0],
|
||||||
config.audio_positional_embedding_max_pos[0],
|
config.audio_positional_embedding_max_pos[0],
|
||||||
)
|
)
|
||||||
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
|
self.av_ca_timestep_scale_multiplier = (
|
||||||
|
config.av_ca_timestep_scale_multiplier
|
||||||
|
)
|
||||||
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||||
self._init_audio_video(config)
|
self._init_audio_video(config)
|
||||||
|
|
||||||
@@ -308,10 +349,14 @@ class LTXModel(nn.Module):
|
|||||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||||
|
|
||||||
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
|
self.adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim, embedding_coefficient=adaln_coefficient
|
||||||
|
)
|
||||||
|
|
||||||
if config.has_prompt_adaln:
|
if config.has_prompt_adaln:
|
||||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
|
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim, embedding_coefficient=2
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.caption_projection = PixArtAlphaTextProjection(
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=config.caption_channels,
|
in_features=config.caption_channels,
|
||||||
@@ -323,13 +368,19 @@ class LTXModel(nn.Module):
|
|||||||
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||||
|
|
||||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
self.audio_patchify_proj = nn.Linear(
|
||||||
|
config.audio_in_channels, self.audio_inner_dim, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
|
self.audio_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
|
||||||
|
)
|
||||||
|
|
||||||
if config.has_prompt_adaln:
|
if config.has_prompt_adaln:
|
||||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
|
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim, embedding_coefficient=2
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=config.audio_caption_channels,
|
in_features=config.audio_caption_channels,
|
||||||
@@ -338,7 +389,9 @@ class LTXModel(nn.Module):
|
|||||||
|
|
||||||
# Output components
|
# Output components
|
||||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||||
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
|
self.audio_norm_out = nn.LayerNorm(
|
||||||
|
self.audio_inner_dim, eps=config.norm_eps, affine=False
|
||||||
|
)
|
||||||
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||||
|
|
||||||
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||||
@@ -361,8 +414,13 @@ class LTXModel(nn.Module):
|
|||||||
embedding_coefficient=1,
|
embedding_coefficient=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
|
def _init_preprocessors(
|
||||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
|
||||||
|
) -> None:
|
||||||
|
if (
|
||||||
|
config.model_type.is_video_enabled()
|
||||||
|
and config.model_type.is_audio_enabled()
|
||||||
|
):
|
||||||
# Multi-modal preprocessors
|
# Multi-modal preprocessors
|
||||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||||
patchify_proj=self.patchify_proj,
|
patchify_proj=self.patchify_proj,
|
||||||
@@ -468,7 +526,8 @@ class LTXModel(nn.Module):
|
|||||||
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
||||||
for idx, block in self.transformer_blocks.items():
|
for idx, block in self.transformer_blocks.items():
|
||||||
video, audio = block(
|
video, audio = block(
|
||||||
video=video, audio=audio,
|
video=video,
|
||||||
|
audio=audio,
|
||||||
skip_video_self_attn=(idx in stg_v_set),
|
skip_video_self_attn=(idx in stg_v_set),
|
||||||
skip_audio_self_attn=(idx in stg_a_set),
|
skip_audio_self_attn=(idx in stg_a_set),
|
||||||
skip_cross_modal=skip_cross_modal,
|
skip_cross_modal=skip_cross_modal,
|
||||||
@@ -526,8 +585,12 @@ class LTXModel(nn.Module):
|
|||||||
raise ValueError("Audio is not enabled for this model")
|
raise ValueError("Audio is not enabled for this model")
|
||||||
|
|
||||||
# Preprocess arguments
|
# Preprocess arguments
|
||||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
video_args = (
|
||||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||||
|
)
|
||||||
|
audio_args = (
|
||||||
|
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
# Process transformer blocks
|
# Process transformer blocks
|
||||||
video_out, audio_out = self._process_transformer_blocks(
|
video_out, audio_out = self._process_transformer_blocks(
|
||||||
@@ -577,7 +640,10 @@ class LTXModel(nn.Module):
|
|||||||
|
|
||||||
if not key.startswith("model.diffusion_model."):
|
if not key.startswith("model.diffusion_model."):
|
||||||
continue
|
continue
|
||||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
if (
|
||||||
|
"audio_embeddings_connector" in key
|
||||||
|
or "video_embeddings_connector" in key
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Remove 'model.diffusion_model.' prefix
|
# Remove 'model.diffusion_model.' prefix
|
||||||
@@ -612,9 +678,11 @@ class LTXModel(nn.Module):
|
|||||||
for weight_file in model_path.glob("*.safetensors"):
|
for weight_file in model_path.glob("*.safetensors"):
|
||||||
weights.update(mx.load(str(weight_file)))
|
weights.update(mx.load(str(weight_file)))
|
||||||
|
|
||||||
|
|
||||||
sanitized = model.sanitize(weights)
|
sanitized = model.sanitize(weights)
|
||||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
sanitized = {
|
||||||
|
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
|
||||||
|
for k, v in sanitized.items()
|
||||||
|
}
|
||||||
|
|
||||||
model.load_weights(list(sanitized.items()), strict=strict)
|
model.load_weights(list(sanitized.items()), strict=strict)
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
@@ -639,13 +707,18 @@ class X0Model(nn.Module):
|
|||||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||||
|
|
||||||
vx, ax = self.velocity_model(
|
vx, ax = self.velocity_model(
|
||||||
video, audio,
|
video,
|
||||||
|
audio,
|
||||||
stg_video_blocks=stg_video_blocks,
|
stg_video_blocks=stg_video_blocks,
|
||||||
stg_audio_blocks=stg_audio_blocks,
|
stg_audio_blocks=stg_audio_blocks,
|
||||||
skip_cross_modal=skip_cross_modal,
|
skip_cross_modal=skip_cross_modal,
|
||||||
)
|
)
|
||||||
|
|
||||||
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
denoised_video = (
|
||||||
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||||
|
)
|
||||||
|
denoised_audio = (
|
||||||
|
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
return denoised_video, denoised_audio
|
return denoised_video, denoised_audio
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
def bilateral_filter(
|
||||||
|
image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75
|
||||||
|
) -> np.ndarray:
|
||||||
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback to simple Gaussian blur if cv2 not available
|
# Fallback to simple Gaussian blur if cv2 not available
|
||||||
@@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Simple box blur fallback
|
# Simple box blur fallback
|
||||||
from scipy.ndimage import uniform_filter
|
from scipy.ndimage import uniform_filter
|
||||||
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
|
|
||||||
|
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(
|
||||||
|
np.uint8
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
|
def unsharp_mask(
|
||||||
|
image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0
|
||||||
|
) -> np.ndarray:
|
||||||
"""Apply unsharp masking to enhance edges after blur.
|
"""Apply unsharp masking to enhance edges after blur.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
||||||
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
|
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
|
||||||
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
||||||
@@ -81,23 +90,23 @@ def reduce_grid_artifacts(
|
|||||||
if method == "bilateral":
|
if method == "bilateral":
|
||||||
d = max(3, int(5 * strength))
|
d = max(3, int(5 * strength))
|
||||||
sigma = 50 + 50 * strength
|
sigma = 50 + 50 * strength
|
||||||
processed = np.stack([
|
processed = np.stack(
|
||||||
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
[
|
||||||
for frame in video
|
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||||
])
|
for frame in video
|
||||||
|
]
|
||||||
|
)
|
||||||
elif method == "gaussian":
|
elif method == "gaussian":
|
||||||
kernel_size = max(3, int(3 + 4 * strength))
|
kernel_size = max(3, int(3 + 4 * strength))
|
||||||
if kernel_size % 2 == 0:
|
if kernel_size % 2 == 0:
|
||||||
kernel_size += 1
|
kernel_size += 1
|
||||||
processed = np.stack([
|
processed = np.stack(
|
||||||
gaussian_blur(frame, kernel_size=kernel_size)
|
[gaussian_blur(frame, kernel_size=kernel_size) for frame in video]
|
||||||
for frame in video
|
)
|
||||||
])
|
|
||||||
elif method == "frequency":
|
elif method == "frequency":
|
||||||
processed = np.stack([
|
processed = np.stack(
|
||||||
remove_grid_frequency(frame, grid_size=8)
|
[remove_grid_frequency(frame, grid_size=8) for frame in video]
|
||||||
for frame in video
|
)
|
||||||
])
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown method: {method}")
|
raise ValueError(f"Unknown method: {method}")
|
||||||
|
|
||||||
@@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
|
|||||||
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
|
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
|
|||||||
"""
|
"""
|
||||||
# x: (..., dim) where dim is even
|
# x: (..., dim) where dim is even
|
||||||
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||||
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||||
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||||
return mx.reshape(rotated, x.shape)
|
return mx.reshape(rotated, x.shape)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_1d(
|
def apply_rotary_emb_1d(
|
||||||
q: mx.array,
|
q: mx.array,
|
||||||
k: mx.array,
|
k: mx.array,
|
||||||
@@ -228,9 +228,9 @@ def get_fractional_positions(
|
|||||||
Fractional positions in range [-1, 1] after scaling
|
Fractional positions in range [-1, 1] after scaling
|
||||||
"""
|
"""
|
||||||
n_pos_dims = indices_grid.shape[1]
|
n_pos_dims = indices_grid.shape[1]
|
||||||
assert n_pos_dims == len(max_pos), (
|
assert n_pos_dims == len(
|
||||||
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
max_pos
|
||||||
)
|
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||||
|
|
||||||
# Divide each dimension by its max position
|
# Divide each dimension by its max position
|
||||||
fractional_positions = []
|
fractional_positions = []
|
||||||
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
|
|||||||
if max_pos is None:
|
if max_pos is None:
|
||||||
max_pos = [20, 2048, 2048]
|
max_pos = [20, 2048, 2048]
|
||||||
|
|
||||||
|
|
||||||
if double_precision:
|
if double_precision:
|
||||||
return _precompute_freqs_cis_double_precision(
|
return _precompute_freqs_cis_double_precision(
|
||||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
indices_grid,
|
||||||
num_attention_heads, rope_type
|
dim,
|
||||||
|
theta,
|
||||||
|
max_pos,
|
||||||
|
use_middle_indices_grid,
|
||||||
|
num_attention_heads,
|
||||||
|
rope_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keep positions in float32 for RoPE computation.
|
# Keep positions in float32 for RoPE computation.
|
||||||
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
# Compute frequencies: outer product
|
# Compute frequencies: outer product
|
||||||
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||||
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
|
||||||
|
freq_indices, (1, 1, 1, -1)
|
||||||
|
)
|
||||||
# freqs: (B, T, n_dims, num_indices)
|
# freqs: (B, T, n_dims, num_indices)
|
||||||
|
|
||||||
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||||
|
|||||||
@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def phi(j: int, neg_h: float) -> float:
|
def phi(j: int, neg_h: float) -> float:
|
||||||
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
||||||
|
|
||||||
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
|
|||||||
Returns:
|
Returns:
|
||||||
(a21, b1, b2): RK coefficients.
|
(a21, b1, b2): RK coefficients.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_phi(j: int, neg_h: float) -> float:
|
def get_phi(j: int, neg_h: float) -> float:
|
||||||
cache_key = (j, neg_h)
|
cache_key = (j, neg_h)
|
||||||
if cache_key in phi_cache:
|
if cache_key in phi_cache:
|
||||||
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
|
|||||||
# SDE noise injection
|
# SDE noise injection
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def get_sde_coeff(
|
def get_sde_coeff(
|
||||||
sigma_next: float,
|
sigma_next: float,
|
||||||
) -> tuple[float, float, float]:
|
) -> tuple[float, float, float]:
|
||||||
@@ -139,7 +140,9 @@ def sde_noise_step(
|
|||||||
denoised_next = sample_f32 - sigma * eps_next
|
denoised_next = sample_f32 - sigma * eps_next
|
||||||
|
|
||||||
# Mix deterministic and stochastic components
|
# Mix deterministic and stochastic components
|
||||||
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
x_noised = (
|
||||||
|
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||||
|
)
|
||||||
|
|
||||||
return x_noised
|
return x_noised
|
||||||
|
|
||||||
@@ -148,6 +151,7 @@ def sde_noise_step(
|
|||||||
# Noise generation
|
# Noise generation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def channelwise_normalize(x: mx.array) -> mx.array:
|
def channelwise_normalize(x: mx.array) -> mx.array:
|
||||||
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,25 @@
|
|||||||
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm, apply_quantization
|
|
||||||
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
|
||||||
|
|
||||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
|
||||||
from mlx_vlm.models.gemma3.config import TextConfig
|
from mlx_vlm.models.gemma3.config import TextConfig
|
||||||
|
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import (
|
||||||
|
BarColumn,
|
||||||
|
Progress,
|
||||||
|
SpinnerColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
TextColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
)
|
||||||
|
|
||||||
|
from mlx_video.utils import apply_quantization, rms_norm
|
||||||
|
|
||||||
# Path to system prompts
|
# Path to system prompts
|
||||||
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||||
@@ -36,7 +36,6 @@ def _load_system_prompt(prompt_name: str) -> str:
|
|||||||
|
|
||||||
class LanguageModel(nn.Module):
|
class LanguageModel(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, config: TextConfig):
|
def __init__(self, config: TextConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Create config matching LTX-2 text encoder requirements
|
# Create config matching LTX-2 text encoder requirements
|
||||||
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
|
|||||||
|
|
||||||
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||||
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
min_val = (
|
||||||
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||||
mx.full(combined.shape, min_val, dtype=dtype))
|
)
|
||||||
|
mask = mx.where(
|
||||||
|
combined,
|
||||||
|
mx.zeros(combined.shape, dtype=dtype),
|
||||||
|
mx.full(combined.shape, min_val, dtype=dtype),
|
||||||
|
)
|
||||||
return mask[:, None, :, :]
|
return mask[:, None, :, :]
|
||||||
else:
|
else:
|
||||||
# No padding mask, just causal
|
# No padding mask, just causal
|
||||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
min_val = (
|
||||||
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||||
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
)
|
||||||
|
mask = mx.where(
|
||||||
|
causal_mask,
|
||||||
|
mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||||
|
mx.full((seq_len, seq_len), min_val, dtype=dtype),
|
||||||
|
)
|
||||||
return mask[None, None, :, :] # (1, 1, seq, seq)
|
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
|
|||||||
batch_size, seq_len = inputs.shape
|
batch_size, seq_len = inputs.shape
|
||||||
|
|
||||||
# Get embeddings
|
# Get embeddings
|
||||||
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
h = (
|
||||||
|
input_embeddings
|
||||||
|
if input_embeddings is not None
|
||||||
|
else self.model.embed_tokens(inputs)
|
||||||
|
)
|
||||||
|
|
||||||
# Apply Gemma scaling
|
# Apply Gemma scaling
|
||||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||||
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
|
|||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.model.layers)
|
cache = [None] * len(self.model.layers)
|
||||||
|
|
||||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
full_causal_mask = self._create_causal_mask_with_padding(
|
||||||
|
seq_len, attention_mask, h.dtype
|
||||||
|
)
|
||||||
|
|
||||||
sliding_mask = full_causal_mask
|
sliding_mask = full_causal_mask
|
||||||
|
|
||||||
|
|
||||||
num_layers = len(self.model.layers)
|
num_layers = len(self.model.layers)
|
||||||
for i, layer in enumerate(self.model.layers):
|
for i, layer in enumerate(self.model.layers):
|
||||||
is_global = (
|
is_global = (
|
||||||
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
|
|||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
if key.startswith(prefix):
|
if key.startswith(prefix):
|
||||||
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||||
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
|
||||||
else:
|
else:
|
||||||
sanitized[key[len(prefix):]] = value
|
sanitized[key[len(prefix) :]] = value
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
|
|||||||
|
|
||||||
def make_cache(self):
|
def make_cache(self):
|
||||||
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||||
|
|
||||||
caches = []
|
caches = []
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
if (
|
if (
|
||||||
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: str):
|
def from_pretrained(cls, model_path: str):
|
||||||
import json
|
import json
|
||||||
|
|
||||||
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||||
config_file = Path(model_path) / "config.json"
|
config_file = Path(model_path) / "config.json"
|
||||||
config_dict = {}
|
config_dict = {}
|
||||||
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
|
|||||||
with open(config_file, "r") as f:
|
with open(config_file, "r") as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
|
|
||||||
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
language_model = cls(
|
||||||
|
config=TextConfig.from_dict(config_dict["text_config"])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Config file not found at {model_path}")
|
raise ValueError(f"Config file not found at {model_path}")
|
||||||
|
|
||||||
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
|
|||||||
for i, wf in enumerate(weight_files):
|
for i, wf in enumerate(weight_files):
|
||||||
weights.update(mx.load(str(wf)))
|
weights.update(mx.load(str(wf)))
|
||||||
|
|
||||||
|
|
||||||
if hasattr(language_model, "sanitize"):
|
if hasattr(language_model, "sanitize"):
|
||||||
weights = language_model.sanitize(weights=weights)
|
weights = language_model.sanitize(weights=weights)
|
||||||
|
|
||||||
|
apply_quantization(
|
||||||
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
model=language_model, weights=weights, quantization=quantization
|
||||||
|
)
|
||||||
|
|
||||||
language_model.load_weights(list(weights.items()), strict=False)
|
language_model.load_weights(list(weights.items()), strict=False)
|
||||||
|
|
||||||
return language_model
|
return language_model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectorAttention(nn.Module):
|
class ConnectorAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
# Reshape to (B, H, T, D) for SPLIT RoPE
|
# Reshape to (B, H, T, D) for SPLIT RoPE
|
||||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
q = mx.reshape(
|
||||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
q, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
).transpose(0, 2, 1, 3)
|
||||||
|
k = mx.reshape(
|
||||||
|
k, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
).transpose(0, 2, 1, 3)
|
||||||
|
v = mx.reshape(
|
||||||
|
v, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
q = self._apply_split_rope(q, pe[0], pe[1])
|
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||||
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
|
|||||||
|
|
||||||
class ConnectorTransformerBlock(nn.Module):
|
class ConnectorTransformerBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int = 3840,
|
||||||
|
num_heads: int = 30,
|
||||||
|
head_dim: int = 128,
|
||||||
|
has_gate_logits: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
self.attn1 = ConnectorAttention(
|
||||||
|
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||||
|
)
|
||||||
self.ff = ConnectorFeedForward(dim)
|
self.ff = ConnectorFeedForward(dim)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
|
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
|
||||||
|
|
||||||
self.transformer_1d_blocks = {
|
self.transformer_1d_blocks = {
|
||||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
i: ConnectorTransformerBlock(
|
||||||
|
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||||
|
)
|
||||||
for i in range(num_layers)
|
for i in range(num_layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
if num_learnable_registers > 0:
|
if num_learnable_registers > 0:
|
||||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||||
|
|
||||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
def _precompute_freqs_cis(
|
||||||
|
self, seq_len: int, dtype: mx.Dtype
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
|
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
|
||||||
|
|
||||||
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
|
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
|
||||||
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
# Binary mask: 1 for valid tokens, 0 for padded
|
# Binary mask: 1 for valid tokens, 0 for padded
|
||||||
# attention_mask is additive: 0 for valid, large negative for padded
|
# attention_mask is additive: 0 for valid, large negative for padded
|
||||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
|
||||||
|
mx.int32
|
||||||
|
) # (batch, seq)
|
||||||
|
|
||||||
# Tile registers to match sequence length, cast to hidden_states dtype
|
# Tile registers to match sequence length, cast to hidden_states dtype
|
||||||
num_tiles = seq_len // self.num_learnable_registers
|
num_tiles = seq_len // self.num_learnable_registers
|
||||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
|
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
|
||||||
|
dtype
|
||||||
|
) # (seq_len, dim)
|
||||||
|
|
||||||
# Process each batch item (PyTorch uses advanced indexing)
|
# Process each batch item (PyTorch uses advanced indexing)
|
||||||
result_list = []
|
result_list = []
|
||||||
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
# Extract valid tokens (where mask is 1)
|
# Extract valid tokens (where mask is 1)
|
||||||
# Since we have left-padded input, valid tokens are at the end
|
# Since we have left-padded input, valid tokens are at the end
|
||||||
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
|
valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
|
||||||
|
|
||||||
# Pad with zeros on the right to get back to seq_len
|
# Pad with zeros on the right to get back to seq_len
|
||||||
pad_length = seq_len - num_valid
|
pad_length = seq_len - num_valid
|
||||||
if pad_length > 0:
|
if pad_length > 0:
|
||||||
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
||||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
adjusted = mx.concatenate(
|
||||||
|
[valid_tokens, padding], axis=0
|
||||||
|
) # (seq_len, dim)
|
||||||
else:
|
else:
|
||||||
adjusted = valid_tokens
|
adjusted = valid_tokens
|
||||||
|
|
||||||
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
||||||
flipped_mask = mx.concatenate([
|
flipped_mask = mx.concatenate(
|
||||||
mx.ones((num_valid,), dtype=mx.int32),
|
[
|
||||||
mx.zeros((pad_length,), dtype=mx.int32)
|
mx.ones((num_valid,), dtype=mx.int32),
|
||||||
], axis=0) # (seq,)
|
mx.zeros((pad_length,), dtype=mx.int32),
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
) # (seq,)
|
||||||
|
|
||||||
# Combine: valid tokens at front, registers at back
|
# Combine: valid tokens at front, registers at back
|
||||||
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
||||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
combined = (
|
||||||
|
flipped_mask_expanded * adjusted
|
||||||
|
+ (1 - flipped_mask_expanded) * registers
|
||||||
|
)
|
||||||
result_list.append(combined)
|
result_list.append(combined)
|
||||||
|
|
||||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||||
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
# Process through transformer blocks
|
# Process through transformer blocks
|
||||||
for i in range(len(self.transformer_1d_blocks)):
|
for i in range(len(self.transformer_1d_blocks)):
|
||||||
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
|
hidden_states = self.transformer_1d_blocks[i](
|
||||||
|
hidden_states, attention_mask, freqs_cis
|
||||||
|
)
|
||||||
|
|
||||||
# Final RMS norm
|
# Final RMS norm
|
||||||
hidden_states = rms_norm(hidden_states)
|
hidden_states = rms_norm(hidden_states)
|
||||||
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
return hidden_states, attention_mask
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def norm_and_concat_hidden_states(
|
def norm_and_concat_hidden_states(
|
||||||
hidden_states: List[mx.array],
|
hidden_states: List[mx.array],
|
||||||
attention_mask: mx.array,
|
attention_mask: mx.array,
|
||||||
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
|
|||||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||||
|
|
||||||
# Compute masked min/max per layer
|
# Compute masked min/max per layer
|
||||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
|
x_for_min = mx.where(
|
||||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
|
mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
|
||||||
|
)
|
||||||
|
x_for_max = mx.where(
|
||||||
|
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
|
||||||
|
)
|
||||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||||
range_val = x_max - x_min
|
range_val = x_max - x_min
|
||||||
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
|
|||||||
dtype = encoded_text.dtype
|
dtype = encoded_text.dtype
|
||||||
|
|
||||||
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
|
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
|
||||||
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
|
variance = mx.mean(
|
||||||
|
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
|
||||||
|
) # (B, T, 1, L)
|
||||||
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
|
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
|
||||||
normed = normed.astype(dtype)
|
normed = normed.astype(dtype)
|
||||||
|
|
||||||
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
|
|||||||
class GemmaFeaturesExtractor(nn.Module):
|
class GemmaFeaturesExtractor(nn.Module):
|
||||||
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
|
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
|
||||||
|
|
||||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
|
def __init__(
|
||||||
|
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
|
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
|
||||||
|
|
||||||
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
|
|||||||
|
|
||||||
if mode == "video":
|
if mode == "video":
|
||||||
target_dim = self.video_aggregate_embed.weight.shape[0]
|
target_dim = self.video_aggregate_embed.weight.shape[0]
|
||||||
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
return self.video_aggregate_embed(
|
||||||
|
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
target_dim = self.audio_aggregate_embed.weight.shape[0]
|
target_dim = self.audio_aggregate_embed.weight.shape[0]
|
||||||
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
return self.audio_aggregate_embed(
|
||||||
|
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AudioEmbeddingsConnector(nn.Module):
|
class AudioEmbeddingsConnector(nn.Module):
|
||||||
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
video_output_dim = 4096
|
video_output_dim = 4096
|
||||||
audio_output_dim = 2048
|
audio_output_dim = 2048
|
||||||
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
|
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
|
||||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||||
video_output_dim=video_output_dim,
|
video_output_dim=video_output_dim,
|
||||||
audio_output_dim=audio_output_dim,
|
audio_output_dim=audio_output_dim,
|
||||||
bias=True,
|
bias=True,
|
||||||
@@ -728,33 +785,53 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
||||||
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
||||||
self.video_embeddings_connector = Embeddings1DConnector(
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
dim=video_output_dim,
|
||||||
num_layers=8, num_learnable_registers=128,
|
num_heads=32,
|
||||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
head_dim=128,
|
||||||
|
num_layers=8,
|
||||||
|
num_learnable_registers=128,
|
||||||
|
positional_embedding_max_pos=[4096],
|
||||||
|
has_gate_logits=True,
|
||||||
)
|
)
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
dim=audio_output_dim,
|
||||||
num_layers=8, num_learnable_registers=128,
|
num_heads=32,
|
||||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
head_dim=64,
|
||||||
|
num_layers=8,
|
||||||
|
num_learnable_registers=128,
|
||||||
|
positional_embedding_max_pos=[4096],
|
||||||
|
has_gate_logits=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||||
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
|
self.feature_extractor = GemmaFeaturesExtractor(
|
||||||
|
feature_input_dim, hidden_dim
|
||||||
|
)
|
||||||
|
|
||||||
self.video_embeddings_connector = Embeddings1DConnector(
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
dim=hidden_dim,
|
||||||
num_layers=2, num_learnable_registers=128,
|
num_heads=30,
|
||||||
|
head_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_learnable_registers=128,
|
||||||
positional_embedding_max_pos=[1],
|
positional_embedding_max_pos=[1],
|
||||||
)
|
)
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
dim=hidden_dim,
|
||||||
num_layers=2, num_learnable_registers=128,
|
num_heads=30,
|
||||||
|
head_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_learnable_registers=128,
|
||||||
positional_embedding_max_pos=[1],
|
positional_embedding_max_pos=[1],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
|
||||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
def load(
|
||||||
|
self,
|
||||||
|
model_path: Optional[str] = None,
|
||||||
|
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
|
||||||
|
):
|
||||||
|
|
||||||
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
||||||
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
||||||
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
if transformer_weights:
|
if transformer_weights:
|
||||||
self._load_feature_extractors(transformer_weights, is_reformatted)
|
self._load_feature_extractors(transformer_weights, is_reformatted)
|
||||||
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
|
self._load_connector(
|
||||||
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
|
"video_embeddings_connector", transformer_weights, is_reformatted
|
||||||
|
)
|
||||||
|
self._load_connector(
|
||||||
|
"audio_embeddings_connector", transformer_weights, is_reformatted
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("WARNING: No transformer weights found for text projection connectors. "
|
print(
|
||||||
"Text conditioning will use uninitialized weights!")
|
"WARNING: No transformer weights found for text projection connectors. "
|
||||||
|
"Text conditioning will use uninitialized weights!"
|
||||||
|
)
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
tokenizer_path = model_path / "tokenizer"
|
tokenizer_path = model_path / "tokenizer"
|
||||||
if tokenizer_path.exists():
|
if tokenizer_path.exists():
|
||||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(
|
||||||
|
str(tokenizer_path), trust_remote_code=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(
|
||||||
|
text_encoder_path, trust_remote_code=True
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(
|
||||||
|
"google/gemma-3-12b-it", trust_remote_code=True
|
||||||
|
)
|
||||||
# Set left padding to match official LTX-2 text encoder
|
# Set left padding to match official LTX-2 text encoder
|
||||||
self.processor.padding_side = "left"
|
self.processor.padding_side = "left"
|
||||||
|
|
||||||
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
submodule.bias = weights[b_key]
|
submodule.bias = weights[b_key]
|
||||||
else:
|
else:
|
||||||
# LTX-2: single aggregate_embed
|
# LTX-2: single aggregate_embed
|
||||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
agg_key = (
|
||||||
|
"aggregate_embed.weight"
|
||||||
|
if is_reformatted
|
||||||
|
else "text_embedding_projection.aggregate_embed.weight"
|
||||||
|
)
|
||||||
if agg_key in weights:
|
if agg_key in weights:
|
||||||
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
|
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
|
||||||
|
|
||||||
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
prefix = f"{name}."
|
prefix = f"{name}."
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
if key.startswith(prefix):
|
if key.startswith(prefix):
|
||||||
connector_weights[key[len(prefix):]] = value
|
connector_weights[key[len(prefix) :]] = value
|
||||||
else:
|
else:
|
||||||
mono_prefix = f"model.diffusion_model.{name}."
|
mono_prefix = f"model.diffusion_model.{name}."
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
if key.startswith(mono_prefix):
|
if key.startswith(mono_prefix):
|
||||||
connector_weights[key[len(mono_prefix):]] = value
|
connector_weights[key[len(mono_prefix) :]] = value
|
||||||
|
|
||||||
if not connector_weights:
|
if not connector_weights:
|
||||||
return
|
return
|
||||||
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
attention_mask = mx.array(inputs["attention_mask"])
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
_, all_hidden_states = self.language_model(
|
||||||
|
inputs=input_ids,
|
||||||
|
input_embeddings=None,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
|
||||||
if self.has_prompt_adaln:
|
if self.has_prompt_adaln:
|
||||||
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
|
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
|
||||||
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
|
video_features = self.feature_extractor_v2(
|
||||||
|
all_hidden_states, attention_mask, mode="video"
|
||||||
|
)
|
||||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
additive_mask = (
|
||||||
|
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
)
|
||||||
|
|
||||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
video_embeddings, _ = self.video_embeddings_connector(
|
||||||
|
video_features, additive_mask
|
||||||
|
)
|
||||||
|
|
||||||
if return_audio_embeddings:
|
if return_audio_embeddings:
|
||||||
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
|
audio_features = self.feature_extractor_v2(
|
||||||
|
all_hidden_states, attention_mask, mode="audio"
|
||||||
|
)
|
||||||
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
|
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
|
||||||
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
|
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||||
|
audio_features, audio_mask
|
||||||
|
)
|
||||||
return video_embeddings, audio_embeddings
|
return video_embeddings, audio_embeddings
|
||||||
else:
|
else:
|
||||||
return video_embeddings, attention_mask
|
return video_embeddings, attention_mask
|
||||||
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
video_features = self.feature_extractor(concat_hidden)
|
video_features = self.feature_extractor(concat_hidden)
|
||||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
additive_mask = (
|
||||||
|
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
)
|
||||||
|
|
||||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
video_embeddings, _ = self.video_embeddings_connector(
|
||||||
|
video_features, additive_mask
|
||||||
|
)
|
||||||
|
|
||||||
if return_audio_embeddings:
|
if return_audio_embeddings:
|
||||||
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
|
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||||
|
video_features, additive_mask
|
||||||
|
)
|
||||||
return video_embeddings, audio_embeddings
|
return video_embeddings, audio_embeddings
|
||||||
else:
|
else:
|
||||||
return video_embeddings, attention_mask
|
return video_embeddings, attention_mask
|
||||||
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
# Remove leading/trailing whitespace
|
# Remove leading/trailing whitespace
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
# Remove any leading punctuation
|
# Remove any leading punctuation
|
||||||
response = re.sub(r'^[^\w\s]+', '', response)
|
response = re.sub(r"^[^\w\s]+", "", response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _apply_chat_template(
|
def _apply_chat_template(
|
||||||
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
# Handle multimodal content (image + text)
|
# Handle multimodal content (image + text)
|
||||||
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||||
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
formatted += (
|
||||||
|
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||||
|
)
|
||||||
elif role == "assistant":
|
elif role == "assistant":
|
||||||
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||||
# Add generation prompt
|
# Add generation prompt
|
||||||
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
from mlx_lm import stream_generate
|
from mlx_lm import stream_generate
|
||||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
logging.warning(
|
||||||
|
"mlx-lm not available for prompt enhancement. Using original prompt."
|
||||||
|
)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
if self.processor is None:
|
if self.processor is None:
|
||||||
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
|
|
||||||
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
|
sampler = make_sampler(
|
||||||
|
kwargs.get("temperature", 0.7),
|
||||||
|
kwargs.get("top_p", 1.0),
|
||||||
|
top_k=kwargs.get("top_k", -1),
|
||||||
|
)
|
||||||
logits_processors = make_logits_processors(
|
logits_processors = make_logits_processors(
|
||||||
kwargs.get("logit_bias", None),
|
kwargs.get("logit_bias", None),
|
||||||
kwargs.get("repetition_penalty", 1.3),
|
kwargs.get("repetition_penalty", 1.3),
|
||||||
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# Decode only the new tokens
|
# Decode only the new tokens
|
||||||
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
enhanced_prompt = self.processor.decode(
|
||||||
|
generated_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
enhanced_prompt = self._clean_response(enhanced_prompt)
|
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||||
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||||
|
|
||||||
return enhanced_prompt
|
return enhanced_prompt
|
||||||
|
|
||||||
|
|
||||||
def enhance_i2v(
|
def enhance_i2v(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
|||||||
encoder = LTX2TextEncoder()
|
encoder = LTX2TextEncoder()
|
||||||
encoder.load(model_path=model_path)
|
encoder.load(model_path=model_path)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from typing import Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
|
||||||
from mlx_video.models.ltx_2.attention import Attention
|
from mlx_video.models.ltx_2.attention import Attention
|
||||||
|
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||||
from mlx_video.utils import rms_norm
|
from mlx_video.utils import rms_norm
|
||||||
|
|
||||||
@@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||||
timestep_reshaped = mx.reshape(
|
timestep_reshaped = mx.reshape(
|
||||||
timestep,
|
timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
|
||||||
(batch_size, timestep.shape[1], num_ada_params, -1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract the relevant indices
|
# Extract the relevant indices
|
||||||
@@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Squeeze the sequence dimension if it's 1
|
# Squeeze the sequence dimension if it's 1
|
||||||
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
scale_shift_squeezed = tuple(
|
||||||
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
|
||||||
|
)
|
||||||
|
gate_squeezed = tuple(
|
||||||
|
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
|
||||||
|
)
|
||||||
|
|
||||||
return (*scale_shift_squeezed, *gate_squeezed)
|
return (*scale_shift_squeezed, *gate_squeezed)
|
||||||
|
|
||||||
@@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
# Check which modalities to run
|
# Check which modalities to run
|
||||||
run_vx = video is not None and video.enabled and vx.size > 0
|
run_vx = video is not None and video.enabled and vx.size > 0
|
||||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
run_a2v = (
|
||||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
run_vx
|
||||||
|
and (audio is not None and audio.enabled and ax.size > 0)
|
||||||
|
and not skip_cross_modal
|
||||||
|
)
|
||||||
|
run_v2a = (
|
||||||
|
run_ax
|
||||||
|
and (video is not None and video.enabled and vx.size > 0)
|
||||||
|
and not skip_cross_modal
|
||||||
|
)
|
||||||
|
|
||||||
# Process video self-attention and cross-attention with text
|
# Process video self-attention and cross-attention with text
|
||||||
if run_vx:
|
if run_vx:
|
||||||
@@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
vx = (
|
||||||
|
vx
|
||||||
|
+ self.attn1(
|
||||||
|
norm_vx,
|
||||||
|
pe=video.positional_embeddings,
|
||||||
|
skip_attention=skip_video_self_attn,
|
||||||
|
)
|
||||||
|
* vgate_msa
|
||||||
|
)
|
||||||
|
|
||||||
# Cross-attention with text context
|
# Cross-attention with text context
|
||||||
if self.has_prompt_adaln:
|
if self.has_prompt_adaln:
|
||||||
@@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
||||||
)
|
)
|
||||||
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
||||||
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
|
self.prompt_scale_shift_table,
|
||||||
|
vx.shape[0],
|
||||||
|
video.prompt_timesteps,
|
||||||
|
slice(0, 2),
|
||||||
)
|
)
|
||||||
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
||||||
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
encoder_hidden_states = (
|
||||||
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
|
video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||||
|
)
|
||||||
|
vx = (
|
||||||
|
vx
|
||||||
|
+ self.attn2(
|
||||||
|
attn_input,
|
||||||
|
context=encoder_hidden_states,
|
||||||
|
mask=video.context_mask,
|
||||||
|
)
|
||||||
|
* vgate_q
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
vx = vx + self.attn2(
|
vx = vx + self.attn2(
|
||||||
rms_norm(vx, eps=self.norm_eps),
|
rms_norm(vx, eps=self.norm_eps),
|
||||||
@@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
ax = (
|
||||||
|
ax
|
||||||
|
+ self.audio_attn1(
|
||||||
|
norm_ax,
|
||||||
|
pe=audio.positional_embeddings,
|
||||||
|
skip_attention=skip_audio_self_attn,
|
||||||
|
)
|
||||||
|
* agate_msa
|
||||||
|
)
|
||||||
|
|
||||||
# Cross-attention with text context
|
# Cross-attention with text context
|
||||||
if self.has_prompt_adaln:
|
if self.has_prompt_adaln:
|
||||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||||
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
||||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
|
self.audio_scale_shift_table,
|
||||||
|
ax.shape[0],
|
||||||
|
audio.timesteps,
|
||||||
|
slice(6, 9),
|
||||||
)
|
)
|
||||||
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
||||||
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
|
self.audio_prompt_scale_shift_table,
|
||||||
|
ax.shape[0],
|
||||||
|
audio.prompt_timesteps,
|
||||||
|
slice(0, 2),
|
||||||
|
)
|
||||||
|
attn_input_a = (
|
||||||
|
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||||
|
)
|
||||||
|
encoder_hidden_states_a = (
|
||||||
|
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||||
|
)
|
||||||
|
ax = (
|
||||||
|
ax
|
||||||
|
+ self.audio_attn2(
|
||||||
|
attn_input_a,
|
||||||
|
context=encoder_hidden_states_a,
|
||||||
|
mask=audio.context_mask,
|
||||||
|
)
|
||||||
|
* agate_q
|
||||||
)
|
)
|
||||||
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
|
||||||
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
|
||||||
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
|
|
||||||
else:
|
else:
|
||||||
ax = ax + self.audio_attn2(
|
ax = ax + self.audio_attn2(
|
||||||
rms_norm(ax, eps=self.norm_eps),
|
rms_norm(ax, eps=self.norm_eps),
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
|
|||||||
self.groups = groups
|
self.groups = groups
|
||||||
|
|
||||||
# Weight shape: (C_out, KD, KH, KW, C_in)
|
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||||
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
scale = (
|
||||||
|
1.0
|
||||||
|
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||||
|
)
|
||||||
self.weight = mx.random.uniform(
|
self.weight = mx.random.uniform(
|
||||||
low=-scale,
|
low=-scale,
|
||||||
high=scale,
|
high=scale,
|
||||||
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
shape=(
|
||||||
|
out_channels,
|
||||||
|
kernel_size[0],
|
||||||
|
kernel_size[1],
|
||||||
|
kernel_size[2],
|
||||||
|
in_channels,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
|
|||||||
n, d, h, w, c = x.shape
|
n, d, h, w, c = x.shape
|
||||||
input_dtype = x.dtype
|
input_dtype = x.dtype
|
||||||
|
|
||||||
|
|
||||||
x = x.astype(mx.float32)
|
x = x.astype(mx.float32)
|
||||||
|
|
||||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||||
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
|
|||||||
self.den = den
|
self.den = den
|
||||||
|
|
||||||
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
|
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
|
||||||
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
|
self.conv = nn.Conv2d(
|
||||||
|
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
self.pixel_shuffle = PixelShuffle2D(num, num)
|
self.pixel_shuffle = PixelShuffle2D(num, num)
|
||||||
self.blur_down = BlurDownsample(stride=den)
|
self.blur_down = BlurDownsample(stride=den)
|
||||||
|
|
||||||
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
|
|||||||
|
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.pixel_shuffle(x) # H*num, W*num
|
x = self.pixel_shuffle(x) # H*num, W*num
|
||||||
x = self.blur_down(x) # H*num/den, W*num/den
|
x = self.blur_down(x) # H*num/den, W*num/den
|
||||||
|
|
||||||
_, h_out, w_out, _ = x.shape
|
_, h_out, w_out, _ = x.shape
|
||||||
x = mx.reshape(x, (n, d, h_out, w_out, c))
|
x = mx.reshape(x, (n, d, h_out, w_out, c))
|
||||||
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
|
|||||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
"""Convert a float scale to a rational fraction (numerator, denominator)."""
|
"""Convert a float scale to a rational fraction (numerator, denominator)."""
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
|
||||||
frac = Fraction(scale).limit_denominator(10)
|
frac = Fraction(scale).limit_denominator(10)
|
||||||
return frac.numerator, frac.denominator
|
return frac.numerator, frac.denominator
|
||||||
|
|
||||||
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
|
|||||||
self.initial_norm = GroupNorm3d(32, mid_channels)
|
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||||
|
|
||||||
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
self.res_blocks = {
|
||||||
|
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||||
|
}
|
||||||
|
|
||||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||||
if rational_resampler:
|
if rational_resampler:
|
||||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
|
self.upsampler = SpatialRationalResampler(
|
||||||
|
mid_channels=mid_channels, scale=spatial_scale
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
|
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
|
||||||
|
|
||||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
self.post_upsample_res_blocks = {
|
||||||
|
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||||
|
}
|
||||||
|
|
||||||
# Final projection
|
# Final projection
|
||||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
|
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def debug_stats(name, t):
|
def debug_stats(name, t):
|
||||||
if debug:
|
if debug:
|
||||||
mx.eval(t)
|
mx.eval(t)
|
||||||
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
print(
|
||||||
|
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(" [DEBUG] LatentUpsampler forward pass:")
|
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||||
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
|||||||
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
||||||
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
||||||
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
||||||
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
|
conv_key = (
|
||||||
|
"upsampler.conv.weight"
|
||||||
|
if "upsampler.conv.weight" in raw_weights
|
||||||
|
else "upsampler.0.weight"
|
||||||
|
)
|
||||||
if conv_key in raw_weights:
|
if conv_key in raw_weights:
|
||||||
out_channels = raw_weights[conv_key].shape[0]
|
out_channels = raw_weights[conv_key].shape[0]
|
||||||
ratio = out_channels // mid_channels
|
ratio = out_channels // mid_channels
|
||||||
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
|||||||
rational_resampler = False
|
rational_resampler = False
|
||||||
spatial_scale = 2.0
|
spatial_scale = 2.0
|
||||||
|
|
||||||
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
print(
|
||||||
|
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
upsampler = LatentUpsampler(
|
upsampler = LatentUpsampler(
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ def convert_audio_encoder(
|
|||||||
return encoder_dir
|
return encoder_dir
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
vae_path = hf_hub_download(
|
vae_path = hf_hub_download(
|
||||||
source_repo,
|
source_repo,
|
||||||
"audio_vae/diffusion_pytorch_model.safetensors",
|
"audio_vae/diffusion_pytorch_model.safetensors",
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
|
||||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
|
||||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||||
|
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||||
TilingConfig,
|
|
||||||
SpatialTilingConfig,
|
SpatialTilingConfig,
|
||||||
TemporalTilingConfig,
|
TemporalTilingConfig,
|
||||||
|
TilingConfig,
|
||||||
)
|
)
|
||||||
|
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
|||||||
# Height padding (axis 2)
|
# Height padding (axis 2)
|
||||||
if pad_h > 0:
|
if pad_h > 0:
|
||||||
# Get reflection indices - exclude boundary
|
# Get reflection indices - exclude boundary
|
||||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
|
||||||
|
:, :, ::-1, :, :
|
||||||
|
] # Flip bottom portion
|
||||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||||
|
|
||||||
# Width padding (axis 3)
|
# Width padding (axis 3)
|
||||||
if pad_w > 0:
|
if pad_w > 0:
|
||||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
|
||||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
|
||||||
|
:, :, :, ::-1, :
|
||||||
|
] # Flip right portion
|
||||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -126,7 +130,9 @@ class CausalConv3d(nn.Module):
|
|||||||
if self.time_kernel_size > 1:
|
if self.time_kernel_size > 1:
|
||||||
if use_causal:
|
if use_causal:
|
||||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
first_frame_pad = mx.repeat(
|
||||||
|
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
|
||||||
|
)
|
||||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||||
else:
|
else:
|
||||||
# Non-causal: replicate first frame at start, last frame at end
|
# Non-causal: replicate first frame at start, last frame at end
|
||||||
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
|
|||||||
"""
|
"""
|
||||||
b, d, h, w, c = x.shape
|
b, d, h, w, c = x.shape
|
||||||
|
|
||||||
|
|
||||||
total_elements = d * h * w * c
|
total_elements = d * h * w * c
|
||||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||||
|
|
||||||
@@ -191,7 +196,6 @@ class CausalConv3d(nn.Module):
|
|||||||
|
|
||||||
overlap = kernel_t - 1
|
overlap = kernel_t - 1
|
||||||
|
|
||||||
|
|
||||||
expected_output_frames = d - overlap
|
expected_output_frames = d - overlap
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|||||||
@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Dict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
|
||||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
|
|||||||
def __init__(self, embedding_dim: int):
|
def __init__(self, embedding_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.timestep_embedder = TimestepEmbedding(
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
in_channels=256,
|
in_channels=256, time_embed_dim=embedding_dim
|
||||||
time_embed_dim=embedding_dim
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
def __call__(
|
||||||
|
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
|
||||||
|
) -> mx.array:
|
||||||
timesteps_proj = get_timestep_embedding(
|
timesteps_proj = get_timestep_embedding(
|
||||||
timestep,
|
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||||
embedding_dim=256,
|
|
||||||
flip_sin_to_cos=True,
|
|
||||||
downscale_freq_shift=0
|
|
||||||
)
|
)
|
||||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||||
return timesteps_emb
|
return timesteps_emb
|
||||||
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
|
|||||||
|
|
||||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||||
|
|
||||||
class ConvWrapper(nn.Module):
|
class ConvWrapper(nn.Module):
|
||||||
def __init__(self_inner):
|
def __init__(self_inner):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
|
|||||||
padding=1,
|
padding=1,
|
||||||
spatial_padding_mode=padding_mode,
|
spatial_padding_mode=padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self_inner, x, causal=False):
|
def __call__(self_inner, x, causal=False):
|
||||||
return self_inner.conv(x, causal=causal)
|
return self_inner.conv(x, causal=causal)
|
||||||
|
|
||||||
return ConvWrapper()
|
return ConvWrapper()
|
||||||
|
|
||||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
"""Apply pixel normalization."""
|
"""Apply pixel normalization."""
|
||||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
|
|||||||
if self.timestep_conditioning and timestep_embed is not None:
|
if self.timestep_conditioning and timestep_embed is not None:
|
||||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||||
# Combine table with timestep embedding
|
# Combine table with timestep embedding
|
||||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
ada_values = self.scale_shift_table[
|
||||||
|
None, :, :, None, None, None
|
||||||
|
] # (1, 4, C, 1, 1, 1)
|
||||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||||
channels = self.scale_shift_table.shape[1]
|
channels = self.scale_shift_table.shape[1]
|
||||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||||
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
|
|||||||
|
|
||||||
# Time embedder for this block group: embed_dim = 4 * channels
|
# Time embedder for this block group: embed_dim = 4 * channels
|
||||||
if timestep_conditioning:
|
if timestep_conditioning:
|
||||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
|
||||||
embedding_dim=channels * 4
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use dict with int keys for MLX to track parameters properly
|
# Use dict with int keys for MLX to track parameters properly
|
||||||
self.res_blocks = {
|
self.res_blocks = {
|
||||||
i: ResnetBlock3DSimple(
|
i: ResnetBlock3DSimple(
|
||||||
channels,
|
channels,
|
||||||
spatial_padding_mode,
|
spatial_padding_mode,
|
||||||
timestep_conditioning=timestep_conditioning
|
timestep_conditioning=timestep_conditioning,
|
||||||
)
|
)
|
||||||
for i in range(num_layers)
|
for i in range(num_layers)
|
||||||
}
|
}
|
||||||
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
|
|||||||
if self.timestep_conditioning and timestep is not None:
|
if self.timestep_conditioning and timestep is not None:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
timestep_embed = self.time_embedder(
|
timestep_embed = self.time_embedder(
|
||||||
timestep.flatten(),
|
timestep.flatten(), hidden_dtype=x.dtype
|
||||||
hidden_dtype=x.dtype
|
|
||||||
)
|
)
|
||||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||||
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
padding=1,
|
padding=1,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self_inner, x, causal=False):
|
def __call__(self_inner, x, causal=False):
|
||||||
return self_inner.conv(x, causal=causal)
|
return self_inner.conv(x, causal=causal)
|
||||||
|
|
||||||
self.conv_in = ConvInWrapper()
|
self.conv_in = ConvInWrapper()
|
||||||
|
|
||||||
# Build up blocks from config
|
# Build up blocks from config
|
||||||
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
block_type = block_def[0]
|
block_type = block_def[0]
|
||||||
ch = block_def[1]
|
ch = block_def[1]
|
||||||
if block_type == "res":
|
if block_type == "res":
|
||||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
num_layers = (
|
||||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||||
|
)
|
||||||
|
self.up_blocks[idx] = ResBlockGroup(
|
||||||
|
ch, num_layers, spatial_padding_mode, timestep_conditioning
|
||||||
|
)
|
||||||
elif block_type == "d2s":
|
elif block_type == "d2s":
|
||||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||||
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
final_out_channels = out_channels * patch_size * patch_size
|
final_out_channels = out_channels * patch_size * patch_size
|
||||||
|
|
||||||
class ConvOutWrapper(nn.Module):
|
class ConvOutWrapper(nn.Module):
|
||||||
def __init__(self_inner):
|
def __init__(self_inner):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
padding=1,
|
padding=1,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self_inner, x, causal=False):
|
def __call__(self_inner, x, causal=False):
|
||||||
return self_inner.conv(x, causal=causal)
|
return self_inner.conv(x, causal=causal)
|
||||||
|
|
||||||
self.conv_out = ConvOutWrapper()
|
self.conv_out = ConvOutWrapper()
|
||||||
|
|
||||||
self.act = nn.SiLU()
|
self.act = nn.SiLU()
|
||||||
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
if key.startswith("vae.decoder."):
|
if key.startswith("vae.decoder."):
|
||||||
new_key = key.replace("vae.decoder.", "")
|
new_key = key.replace("vae.decoder.", "")
|
||||||
|
|
||||||
|
|
||||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||||
if ".conv.weight" in key and value.ndim == 5:
|
if ".conv.weight" in key and value.ndim == 5:
|
||||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||||
|
|
||||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
if (
|
||||||
|
".conv.conv.weight" not in new_key
|
||||||
|
and ".conv.conv.bias" not in new_key
|
||||||
|
):
|
||||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||||
|
|
||||||
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
def from_pretrained(
|
||||||
|
cls, model_path: Path, strict: bool = True
|
||||||
|
) -> "LTX2VideoDecoder":
|
||||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
for wf in weight_files:
|
for wf in weight_files:
|
||||||
weights.update(mx.load(str(wf)))
|
weights.update(mx.load(str(wf)))
|
||||||
|
|
||||||
|
|
||||||
# Infer block structure from weights
|
# Infer block structure from weights
|
||||||
decoder_blocks = cls._infer_blocks(weights)
|
decoder_blocks = cls._infer_blocks(weights)
|
||||||
|
|
||||||
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
return final_blocks
|
return final_blocks
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
"""Apply pixel normalization."""
|
"""Apply pixel normalization."""
|
||||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -552,20 +562,15 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
chunked_conv: bool = False,
|
chunked_conv: bool = False,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
|
||||||
|
|
||||||
batch_size = sample.shape[0]
|
batch_size = sample.shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Add noise if timestep conditioning is enabled
|
# Add noise if timestep conditioning is enabled
|
||||||
if self.timestep_conditioning:
|
if self.timestep_conditioning:
|
||||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||||
|
|
||||||
|
|
||||||
sample = self.per_channel_statistics.un_normalize(sample)
|
sample = self.per_channel_statistics.un_normalize(sample)
|
||||||
|
|
||||||
|
|
||||||
if timestep is None and self.timestep_conditioning:
|
if timestep is None and self.timestep_conditioning:
|
||||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||||
|
|
||||||
@@ -575,7 +580,6 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
x = self.conv_in(sample, causal=causal)
|
x = self.conv_in(sample, causal=causal)
|
||||||
|
|
||||||
|
|
||||||
for i, block in self.up_blocks.items():
|
for i, block in self.up_blocks.items():
|
||||||
if isinstance(block, ResBlockGroup):
|
if isinstance(block, ResBlockGroup):
|
||||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||||
@@ -584,18 +588,17 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = block(x, causal=causal)
|
x = block(x, causal=causal)
|
||||||
|
|
||||||
|
|
||||||
x = self.pixel_norm(x)
|
x = self.pixel_norm(x)
|
||||||
|
|
||||||
|
|
||||||
if self.timestep_conditioning and scaled_timestep is not None:
|
if self.timestep_conditioning and scaled_timestep is not None:
|
||||||
embedded_timestep = self.last_time_embedder(
|
embedded_timestep = self.last_time_embedder(
|
||||||
scaled_timestep.flatten(),
|
scaled_timestep.flatten(), hidden_dtype=x.dtype
|
||||||
hidden_dtype=x.dtype
|
|
||||||
)
|
)
|
||||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||||
|
|
||||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
ada_values = self.last_scale_shift_table[
|
||||||
|
None, :, :, None, None, None
|
||||||
|
] # (1, 2, 128, 1, 1, 1)
|
||||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||||
ada_values = ada_values + ts_reshaped
|
ada_values = ada_values + ts_reshaped
|
||||||
|
|
||||||
@@ -604,16 +607,13 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
x = x * (1 + scale) + shift
|
x = x * (1 + scale) + shift
|
||||||
|
|
||||||
|
|
||||||
x = self.act(x)
|
x = self.act(x)
|
||||||
|
|
||||||
|
|
||||||
x = self.conv_out(x, causal=causal)
|
x = self.conv_out(x, causal=causal)
|
||||||
|
|
||||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def decode_tiled(
|
def decode_tiled(
|
||||||
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
use_chunked_conv = tiling_mode in (
|
||||||
|
"conservative",
|
||||||
|
"none",
|
||||||
|
"auto",
|
||||||
|
"default",
|
||||||
|
"spatial",
|
||||||
|
)
|
||||||
|
|
||||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||||
# No tiling needed, use regular decode
|
# No tiling needed, use regular decode
|
||||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
return self(
|
||||||
|
sample,
|
||||||
|
causal=causal,
|
||||||
|
timestep=timestep,
|
||||||
|
debug=debug,
|
||||||
|
chunked_conv=use_chunked_conv,
|
||||||
|
)
|
||||||
|
|
||||||
return decode_with_tiling(
|
return decode_with_tiling(
|
||||||
decoder_fn=self,
|
decoder_fn=self,
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
|
||||||
|
|
||||||
|
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||||
|
|
||||||
|
|
||||||
def encode_image(
|
def encode_image(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Operations for Video VAE."""
|
"""Operations for Video VAE."""
|
||||||
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
|
|||||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||||
|
|
||||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
x = mx.reshape(
|
||||||
|
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
|
||||||
|
)
|
||||||
|
|
||||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||||
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||||
|
|||||||
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
def __call__(
|
||||||
|
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
b, c, d, h, w = x.shape
|
b, c, d, h, w = x.shape
|
||||||
st, sh, sw = self.stride
|
st, sh, sw = self.stride
|
||||||
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
def _chunked_conv_depth_to_space(
|
||||||
|
self, x: mx.array, causal: bool = True
|
||||||
|
) -> mx.array:
|
||||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||||
|
|
||||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||||
|
|||||||
@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
|
|||||||
# Apply right ramp (fade out)
|
# Apply right ramp (fade out)
|
||||||
if ramp_right > 0:
|
if ramp_right > 0:
|
||||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
fade_out = [
|
||||||
|
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
|
||||||
|
]
|
||||||
for i in range(ramp_right):
|
for i in range(ramp_right):
|
||||||
mask[length - ramp_right + i] *= fade_out[i]
|
mask[length - ramp_right + i] *= fade_out[i]
|
||||||
|
|
||||||
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.tile_size_in_pixels < 64:
|
if self.tile_size_in_pixels < 64:
|
||||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
raise ValueError(
|
||||||
|
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
|
||||||
|
)
|
||||||
if self.tile_size_in_pixels % 32 != 0:
|
if self.tile_size_in_pixels % 32 != 0:
|
||||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
raise ValueError(
|
||||||
|
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
|
||||||
|
)
|
||||||
if self.tile_overlap_in_pixels % 32 != 0:
|
if self.tile_overlap_in_pixels % 32 != 0:
|
||||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
raise ValueError(
|
||||||
|
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
|
||||||
|
)
|
||||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||||
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.tile_size_in_frames < 16:
|
if self.tile_size_in_frames < 16:
|
||||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
raise ValueError(
|
||||||
|
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
|
||||||
|
)
|
||||||
if self.tile_size_in_frames % 8 != 0:
|
if self.tile_size_in_frames % 8 != 0:
|
||||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
raise ValueError(
|
||||||
|
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
|
||||||
|
)
|
||||||
if self.tile_overlap_in_frames % 8 != 0:
|
if self.tile_overlap_in_frames % 8 != 0:
|
||||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
raise ValueError(
|
||||||
|
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
|
||||||
|
)
|
||||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||||
@@ -113,15 +127,21 @@ class TilingConfig:
|
|||||||
def default(cls) -> "TilingConfig":
|
def default(cls) -> "TilingConfig":
|
||||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||||
return cls(
|
return cls(
|
||||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
spatial_config=SpatialTilingConfig(
|
||||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||||
|
),
|
||||||
|
temporal_config=TemporalTilingConfig(
|
||||||
|
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||||
"""Spatial tiling only (for short videos with large resolution)."""
|
"""Spatial tiling only (for short videos with large resolution)."""
|
||||||
return cls(
|
return cls(
|
||||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
spatial_config=SpatialTilingConfig(
|
||||||
|
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
|
||||||
|
),
|
||||||
temporal_config=None,
|
temporal_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,23 +150,33 @@ class TilingConfig:
|
|||||||
"""Temporal tiling only (for long videos with small resolution)."""
|
"""Temporal tiling only (for long videos with small resolution)."""
|
||||||
return cls(
|
return cls(
|
||||||
spatial_config=None,
|
spatial_config=None,
|
||||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
temporal_config=TemporalTilingConfig(
|
||||||
|
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def aggressive(cls) -> "TilingConfig":
|
def aggressive(cls) -> "TilingConfig":
|
||||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||||
return cls(
|
return cls(
|
||||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
spatial_config=SpatialTilingConfig(
|
||||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
tile_size_in_pixels=256, tile_overlap_in_pixels=64
|
||||||
|
),
|
||||||
|
temporal_config=TemporalTilingConfig(
|
||||||
|
tile_size_in_frames=32, tile_overlap_in_frames=8
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def conservative(cls) -> "TilingConfig":
|
def conservative(cls) -> "TilingConfig":
|
||||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||||
return cls(
|
return cls(
|
||||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
spatial_config=SpatialTilingConfig(
|
||||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
tile_size_in_pixels=768, tile_overlap_in_pixels=64
|
||||||
|
),
|
||||||
|
temporal_config=TemporalTilingConfig(
|
||||||
|
tile_size_in_frames=96, tile_overlap_in_frames=24
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -186,10 +216,14 @@ class TilingConfig:
|
|||||||
temporal_config = None
|
temporal_config = None
|
||||||
|
|
||||||
if needs_spatial:
|
if needs_spatial:
|
||||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
spatial_config = SpatialTilingConfig(
|
||||||
|
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||||
|
)
|
||||||
|
|
||||||
if needs_temporal:
|
if needs_temporal:
|
||||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
temporal_config = TemporalTilingConfig(
|
||||||
|
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||||
|
)
|
||||||
|
|
||||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||||
|
|
||||||
@@ -197,16 +231,21 @@ class TilingConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DimensionIntervals:
|
class DimensionIntervals:
|
||||||
"""Intervals for splitting a single dimension."""
|
"""Intervals for splitting a single dimension."""
|
||||||
|
|
||||||
starts: List[int]
|
starts: List[int]
|
||||||
ends: List[int]
|
ends: List[int]
|
||||||
left_ramps: List[int]
|
left_ramps: List[int]
|
||||||
right_ramps: List[int]
|
right_ramps: List[int]
|
||||||
|
|
||||||
|
|
||||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
def split_in_spatial(
|
||||||
|
size: int, overlap: int, dimension_size: int
|
||||||
|
) -> DimensionIntervals:
|
||||||
"""Split a spatial dimension into intervals."""
|
"""Split a spatial dimension into intervals."""
|
||||||
if dimension_size <= size:
|
if dimension_size <= size:
|
||||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
return DimensionIntervals(
|
||||||
|
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||||
|
)
|
||||||
|
|
||||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||||
starts = [i * (size - overlap) for i in range(amount)]
|
starts = [i * (size - overlap) for i in range(amount)]
|
||||||
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
|
|||||||
left_ramps = [0] + [overlap] * (amount - 1)
|
left_ramps = [0] + [overlap] * (amount - 1)
|
||||||
right_ramps = [overlap] * (amount - 1) + [0]
|
right_ramps = [overlap] * (amount - 1) + [0]
|
||||||
|
|
||||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
return DimensionIntervals(
|
||||||
|
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
def split_in_temporal(
|
||||||
|
size: int, overlap: int, dimension_size: int
|
||||||
|
) -> DimensionIntervals:
|
||||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||||
if dimension_size <= size:
|
if dimension_size <= size:
|
||||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
return DimensionIntervals(
|
||||||
|
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||||
|
)
|
||||||
|
|
||||||
# Start with spatial split
|
# Start with spatial split
|
||||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||||
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
|
|||||||
starts[i] = starts[i] - 1
|
starts[i] = starts[i] - 1
|
||||||
left_ramps[i] = left_ramps[i] + 1
|
left_ramps[i] = left_ramps[i] + 1
|
||||||
|
|
||||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
return DimensionIntervals(
|
||||||
|
starts=starts,
|
||||||
|
ends=intervals.ends,
|
||||||
|
left_ramps=left_ramps,
|
||||||
|
right_ramps=intervals.right_ramps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
def map_temporal_slice(
|
||||||
|
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||||
|
) -> Tuple[slice, mx.array]:
|
||||||
"""Map temporal latent interval to output coordinates and mask."""
|
"""Map temporal latent interval to output coordinates and mask."""
|
||||||
start = begin * scale
|
start = begin * scale
|
||||||
stop = 1 + (end - 1) * scale
|
stop = 1 + (end - 1) * scale
|
||||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||||
right_ramp_scaled = right_ramp * scale
|
right_ramp_scaled = right_ramp * scale
|
||||||
|
|
||||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
mask = compute_trapezoidal_mask_1d(
|
||||||
|
stop - start, left_ramp_scaled, right_ramp_scaled, True
|
||||||
|
)
|
||||||
return slice(start, stop), mask
|
return slice(start, stop), mask
|
||||||
|
|
||||||
|
|
||||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
def map_spatial_slice(
|
||||||
|
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||||
|
) -> Tuple[slice, mx.array]:
|
||||||
"""Map spatial latent interval to output coordinates and mask."""
|
"""Map spatial latent interval to output coordinates and mask."""
|
||||||
start = begin * scale
|
start = begin * scale
|
||||||
stop = end * scale
|
stop = end * scale
|
||||||
left_ramp_scaled = left_ramp * scale
|
left_ramp_scaled = left_ramp * scale
|
||||||
right_ramp_scaled = right_ramp * scale
|
right_ramp_scaled = right_ramp * scale
|
||||||
|
|
||||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
mask = compute_trapezoidal_mask_1d(
|
||||||
|
stop - start, left_ramp_scaled, right_ramp_scaled, False
|
||||||
|
)
|
||||||
return slice(start, stop), mask
|
return slice(start, stop), mask
|
||||||
|
|
||||||
|
|
||||||
@@ -315,7 +373,9 @@ def decode_with_tiling(
|
|||||||
temporal_overlap = 0
|
temporal_overlap = 0
|
||||||
|
|
||||||
# Compute intervals for each dimension
|
# Compute intervals for each dimension
|
||||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
temporal_intervals = split_in_temporal(
|
||||||
|
temporal_tile_size, temporal_overlap, f_latent
|
||||||
|
)
|
||||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||||
|
|
||||||
@@ -338,7 +398,9 @@ def decode_with_tiling(
|
|||||||
t_right = temporal_intervals.right_ramps[t_idx]
|
t_right = temporal_intervals.right_ramps[t_idx]
|
||||||
|
|
||||||
# Map temporal coordinates
|
# Map temporal coordinates
|
||||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
out_t_slice, t_mask = map_temporal_slice(
|
||||||
|
t_start, t_end, t_left, t_right, temporal_scale
|
||||||
|
)
|
||||||
|
|
||||||
for h_idx in range(num_h_tiles):
|
for h_idx in range(num_h_tiles):
|
||||||
h_start = height_intervals.starts[h_idx]
|
h_start = height_intervals.starts[h_idx]
|
||||||
@@ -347,7 +409,9 @@ def decode_with_tiling(
|
|||||||
h_right = height_intervals.right_ramps[h_idx]
|
h_right = height_intervals.right_ramps[h_idx]
|
||||||
|
|
||||||
# Map height coordinates
|
# Map height coordinates
|
||||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
out_h_slice, h_mask = map_spatial_slice(
|
||||||
|
h_start, h_end, h_left, h_right, spatial_scale
|
||||||
|
)
|
||||||
|
|
||||||
for w_idx in range(num_w_tiles):
|
for w_idx in range(num_w_tiles):
|
||||||
w_start = width_intervals.starts[w_idx]
|
w_start = width_intervals.starts[w_idx]
|
||||||
@@ -356,13 +420,23 @@ def decode_with_tiling(
|
|||||||
w_right = width_intervals.right_ramps[w_idx]
|
w_right = width_intervals.right_ramps[w_idx]
|
||||||
|
|
||||||
# Map width coordinates
|
# Map width coordinates
|
||||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
out_w_slice, w_mask = map_spatial_slice(
|
||||||
|
w_start, w_end, w_left, w_right, spatial_scale
|
||||||
|
)
|
||||||
|
|
||||||
# Extract tile latents (small slice)
|
# Extract tile latents (small slice)
|
||||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
tile_latents = latents[
|
||||||
|
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||||
|
]
|
||||||
|
|
||||||
# Decode tile
|
# Decode tile
|
||||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
tile_output = decoder_fn(
|
||||||
|
tile_latents,
|
||||||
|
causal=causal,
|
||||||
|
timestep=timestep,
|
||||||
|
debug=False,
|
||||||
|
chunked_conv=chunked_conv,
|
||||||
|
)
|
||||||
mx.eval(tile_output)
|
mx.eval(tile_output)
|
||||||
|
|
||||||
# Clear tile_latents reference
|
# Clear tile_latents reference
|
||||||
@@ -385,13 +459,15 @@ def decode_with_tiling(
|
|||||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||||
|
|
||||||
blend_mask = (
|
blend_mask = (
|
||||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Slice tile output to match
|
# Slice tile output to match
|
||||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
tile_output_slice = tile_output[
|
||||||
|
:, :, :actual_t, :actual_h, :actual_w
|
||||||
|
].astype(mx.float32)
|
||||||
|
|
||||||
# Clear full tile_output
|
# Clear full tile_output
|
||||||
del tile_output
|
del tile_output
|
||||||
@@ -409,11 +485,37 @@ def decode_with_tiling(
|
|||||||
weighted_tile = tile_output_slice * blend_mask
|
weighted_tile = tile_output_slice * blend_mask
|
||||||
|
|
||||||
# Update output using slice assignment
|
# Update output using slice assignment
|
||||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
output[
|
||||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
] = (
|
||||||
|
output[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
]
|
||||||
|
+ weighted_tile
|
||||||
)
|
)
|
||||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
weights[
|
||||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
] = (
|
||||||
|
weights[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
]
|
||||||
|
+ blend_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# Force evaluation to free memory
|
# Force evaluation to free memory
|
||||||
@@ -445,10 +547,12 @@ def decode_with_tiling(
|
|||||||
if next_tile_start_latent == 0:
|
if next_tile_start_latent == 0:
|
||||||
next_tile_start_out = 0
|
next_tile_start_out = 0
|
||||||
else:
|
else:
|
||||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
next_tile_start_out = (
|
||||||
|
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||||
|
)
|
||||||
|
|
||||||
# We need to track how many frames we've already emitted
|
# We need to track how many frames we've already emitted
|
||||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||||
decode_with_tiling._emitted_frames = 0
|
decode_with_tiling._emitted_frames = 0
|
||||||
emitted = decode_with_tiling._emitted_frames
|
emitted = decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
@@ -456,7 +560,10 @@ def decode_with_tiling(
|
|||||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
finalized_output = (
|
||||||
|
output[:, :, emitted:next_tile_start_out, :, :]
|
||||||
|
/ finalized_weights
|
||||||
|
)
|
||||||
finalized_output = finalized_output.astype(latents.dtype)
|
finalized_output = finalized_output.astype(latents.dtype)
|
||||||
mx.eval(finalized_output)
|
mx.eval(finalized_output)
|
||||||
|
|
||||||
@@ -473,7 +580,7 @@ def decode_with_tiling(
|
|||||||
|
|
||||||
# Emit remaining frames if callback provided
|
# Emit remaining frames if callback provided
|
||||||
if on_frames_ready is not None:
|
if on_frames_ready is not None:
|
||||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||||
if emitted < out_f:
|
if emitted < out_f:
|
||||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||||
mx.eval(remaining_output)
|
mx.eval(remaining_output)
|
||||||
@@ -481,7 +588,7 @@ def decode_with_tiling(
|
|||||||
del remaining_output
|
del remaining_output
|
||||||
|
|
||||||
# Reset emitted frames counter for next call
|
# Reset emitted frames counter for next call
|
||||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||||
del decode_with_tiling._emitted_frames
|
del decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
# Clean up weights
|
# Clean up weights
|
||||||
|
|||||||
@@ -8,12 +8,15 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
from mlx_video.models.ltx_2.video_vae.ops import (
|
||||||
|
PerChannelStatistics,
|
||||||
|
patchify,
|
||||||
|
unpatchify,
|
||||||
|
)
|
||||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||||
NormLayerType,
|
NormLayerType,
|
||||||
ResnetBlock3D,
|
ResnetBlock3D,
|
||||||
UNetMidBlock3D,
|
UNetMidBlock3D,
|
||||||
get_norm_layer,
|
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||||
DepthToSpaceUpsample,
|
DepthToSpaceUpsample,
|
||||||
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
|
|||||||
|
|
||||||
class LogVarianceType(Enum):
|
class LogVarianceType(Enum):
|
||||||
"""Log variance mode for VAE."""
|
"""Log variance mode for VAE."""
|
||||||
|
|
||||||
PER_CHANNEL = "per_channel"
|
PER_CHANNEL = "per_channel"
|
||||||
UNIFORM = "uniform"
|
UNIFORM = "uniform"
|
||||||
CONSTANT = "constant"
|
CONSTANT = "constant"
|
||||||
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
|
|||||||
config: VideoEncoderModelConfig with encoder parameters
|
config: VideoEncoderModelConfig with encoder parameters
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
|
||||||
|
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.norm_layer = config.norm_layer
|
self.norm_layer = config.norm_layer
|
||||||
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
|
|||||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||||
|
|
||||||
# Per-channel statistics for normalizing latents
|
# Per-channel statistics for normalizing latents
|
||||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
self.per_channel_statistics = PerChannelStatistics(
|
||||||
|
latent_channels=config.out_channels
|
||||||
|
)
|
||||||
|
|
||||||
# After patchify, channels increase by patch_size^2
|
# After patchify, channels increase by patch_size^2
|
||||||
in_channels = config.in_channels * config.patch_size ** 2
|
in_channels = config.in_channels * config.patch_size**2
|
||||||
feature_channels = config.out_channels
|
feature_channels = config.out_channels
|
||||||
|
|
||||||
# Initial convolution
|
# Initial convolution
|
||||||
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
|
|||||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||||
self.down_blocks = {}
|
self.down_blocks = {}
|
||||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
block_config = (
|
||||||
|
{"num_layers": block_params}
|
||||||
|
if isinstance(block_params, int)
|
||||||
|
else block_params
|
||||||
|
)
|
||||||
|
|
||||||
block, feature_channels = _make_encoder_block(
|
block, feature_channels = _make_encoder_block(
|
||||||
block_name=block_name,
|
block_name=block_name,
|
||||||
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
|
|||||||
conv_out_channels = config.out_channels
|
conv_out_channels = config.out_channels
|
||||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||||
conv_out_channels *= 2
|
conv_out_channels *= 2
|
||||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
elif config.latent_log_var in {
|
||||||
|
LogVarianceType.UNIFORM,
|
||||||
|
LogVarianceType.CONSTANT,
|
||||||
|
}:
|
||||||
conv_out_channels += 1
|
conv_out_channels += 1
|
||||||
|
|
||||||
self.conv_out = CausalConv3d(
|
self.conv_out = CausalConv3d(
|
||||||
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
|
|||||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||||
sample = sample[:, :-1, ...]
|
sample = sample[:, :-1, ...]
|
||||||
approx_ln_0 = -30
|
approx_ln_0 = -30
|
||||||
sample = mx.concatenate([
|
sample = mx.concatenate(
|
||||||
sample,
|
[
|
||||||
mx.full_like(sample, approx_ln_0),
|
sample,
|
||||||
], axis=1)
|
mx.full_like(sample, approx_ln_0),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Split into means and logvar, normalize means
|
# Split into means and logvar, normalize means
|
||||||
means = sample[:, :self.latent_channels, ...]
|
means = sample[:, : self.latent_channels, ...]
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
|
|||||||
Loaded VideoEncoder instance
|
Loaded VideoEncoder instance
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
|
|||||||
decoder_blocks = []
|
decoder_blocks = []
|
||||||
|
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
out_channels = out_channels * patch_size ** 2
|
out_channels = out_channels * patch_size**2
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.timestep_conditioning = timestep_conditioning
|
self.timestep_conditioning = timestep_conditioning
|
||||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||||
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
|
|||||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||||
self.up_blocks = {}
|
self.up_blocks = {}
|
||||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
block_config = (
|
||||||
|
{"num_layers": block_params}
|
||||||
|
if isinstance(block_params, int)
|
||||||
|
else block_params
|
||||||
|
)
|
||||||
|
|
||||||
block, feature_channels = _make_decoder_block(
|
block, feature_channels = _make_decoder_block(
|
||||||
block_name=block_name,
|
block_name=block_name,
|
||||||
|
|||||||
@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
|
|||||||
v = self.v(x_w).reshape(b, s, n, d)
|
v = self.v(x_w).reshape(b, s, n, d)
|
||||||
|
|
||||||
# RoPE in float32 for precision (official uses float64)
|
# RoPE in float32 for precision (official uses float64)
|
||||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
q = rope_apply(
|
||||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||||
|
)
|
||||||
|
k = rope_apply(
|
||||||
|
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||||
|
)
|
||||||
|
|
||||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||||
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
q, k, v, scale=self.scale, mask=mask
|
q, k, v, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out = mx.fast.scaled_dot_product_attention(
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||||
q, k, v, scale=self.scale
|
|
||||||
)
|
|
||||||
|
|
||||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||||
return self.o(out)
|
return self.o(out)
|
||||||
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
|
|||||||
q, k, v, scale=self.scale, mask=mask
|
q, k, v, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out = mx.fast.scaled_dot_product_attention(
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||||
q, k, v, scale=self.scale
|
|
||||||
)
|
|
||||||
|
|
||||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||||
return self.o(out)
|
return self.o(out)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.utils
|
import mlx.utils
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
|
|||||||
return weights
|
return weights
|
||||||
|
|
||||||
|
|
||||||
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize_wan_transformer_weights(
|
||||||
|
weights: Dict[str, mx.array]
|
||||||
|
) -> Dict[str, mx.array]:
|
||||||
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
||||||
|
|
||||||
Wan2.2 keys follow the pattern:
|
Wan2.2 keys follow the pattern:
|
||||||
@@ -246,8 +247,8 @@ def _load_lora_configs(
|
|||||||
|
|
||||||
Shared between weight-merging and runtime-wrapping paths.
|
Shared between weight-merging and runtime-wrapping paths.
|
||||||
"""
|
"""
|
||||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
|
||||||
from mlx_video.generate_wan import Colors
|
from mlx_video.generate_wan import Colors
|
||||||
|
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||||
|
|
||||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||||
|
|
||||||
@@ -264,7 +265,9 @@ def _load_lora_configs(
|
|||||||
module_to_loras = load_multiple_loras(configs)
|
module_to_loras = load_multiple_loras(configs)
|
||||||
|
|
||||||
if not module_to_loras:
|
if not module_to_loras:
|
||||||
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
return module_to_loras
|
return module_to_loras
|
||||||
|
|
||||||
@@ -279,8 +282,8 @@ def load_and_apply_loras(
|
|||||||
|
|
||||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||||
"""
|
"""
|
||||||
from mlx_video.lora import apply_loras_to_weights
|
|
||||||
from mlx_video.generate_wan import Colors
|
from mlx_video.generate_wan import Colors
|
||||||
|
from mlx_video.lora import apply_loras_to_weights
|
||||||
|
|
||||||
if not lora_configs:
|
if not lora_configs:
|
||||||
return model_weights
|
return model_weights
|
||||||
@@ -289,12 +292,17 @@ def load_and_apply_loras(
|
|||||||
if not module_to_loras:
|
if not module_to_loras:
|
||||||
return model_weights
|
return model_weights
|
||||||
|
|
||||||
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
|
||||||
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Model has {len(model_weights)} weight keys")
|
print(f" Model has {len(model_weights)} weight keys")
|
||||||
|
|
||||||
modified_weights = apply_loras_to_weights(
|
modified_weights = apply_loras_to_weights(
|
||||||
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
|
model_weights,
|
||||||
|
module_to_loras,
|
||||||
|
verbose=verbose,
|
||||||
|
quantization_bits=quantization_bits,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
||||||
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
|
|||||||
src_model_type = src_config.get("model_type", "t2v")
|
src_model_type = src_config.get("model_type", "t2v")
|
||||||
src_text_len = src_config.get("text_len", 512)
|
src_text_len = src_config.get("text_len", 512)
|
||||||
|
|
||||||
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
print(
|
||||||
f"heads={src_num_heads}, type={src_model_type}")
|
f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||||
|
f"heads={src_num_heads}, type={src_model_type}"
|
||||||
|
)
|
||||||
|
|
||||||
# Use preset for known TI2V 5B configuration
|
# Use preset for known TI2V 5B configuration
|
||||||
if src_model_type == "ti2v" and src_dim == 3072:
|
if src_model_type == "ti2v" and src_dim == 3072:
|
||||||
@@ -513,8 +523,11 @@ def convert_wan_checkpoint(
|
|||||||
weights = load_torch_weights(str(vae_path))
|
weights = load_torch_weights(str(vae_path))
|
||||||
if is_wan22_vae:
|
if is_wan22_vae:
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||||
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
weights = sanitize_wan22_vae_weights(
|
||||||
|
weights, include_encoder=include_encoder
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weights = sanitize_wan_vae_weights(weights)
|
weights = sanitize_wan_vae_weights(weights)
|
||||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||||
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
|
|||||||
|
|
||||||
# Quantize transformer weights if requested
|
# Quantize transformer weights if requested
|
||||||
if quantize:
|
if quantize:
|
||||||
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
print(
|
||||||
|
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
|
||||||
|
)
|
||||||
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
||||||
|
|
||||||
print(f"\nConversion complete! Output: {output_dir}")
|
print(f"\nConversion complete! Output: {output_dir}")
|
||||||
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
|
|||||||
return False
|
return False
|
||||||
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
||||||
quantize_patterns = (
|
quantize_patterns = (
|
||||||
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
|
".self_attn.q",
|
||||||
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
|
".self_attn.k",
|
||||||
".ffn.fc1", ".ffn.fc2",
|
".self_attn.v",
|
||||||
|
".self_attn.o",
|
||||||
|
".cross_attn.q",
|
||||||
|
".cross_attn.k",
|
||||||
|
".cross_attn.v",
|
||||||
|
".cross_attn.o",
|
||||||
|
".ffn.fc1",
|
||||||
|
".ffn.fc2",
|
||||||
)
|
)
|
||||||
return any(path.endswith(p) for p in quantize_patterns)
|
return any(path.endswith(p) for p in quantize_patterns)
|
||||||
|
|
||||||
@@ -684,14 +706,20 @@ def quantize_mlx_model(
|
|||||||
# Build model config
|
# Build model config
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
|
config_dict = {
|
||||||
|
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||||
|
}
|
||||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||||
if key in config_dict and isinstance(config_dict[key], list):
|
if key in config_dict and isinstance(config_dict[key], list):
|
||||||
config_dict[key] = tuple(config_dict[key])
|
config_dict[key] = tuple(config_dict[key])
|
||||||
config = WanModelConfig(**config_dict)
|
config = WanModelConfig(**config_dict)
|
||||||
|
|
||||||
# Copy non-transformer files to output dir (skip large model weights)
|
# Copy non-transformer files to output dir (skip large model weights)
|
||||||
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
|
transformer_files = {
|
||||||
|
"low_noise_model.safetensors",
|
||||||
|
"high_noise_model.safetensors",
|
||||||
|
"model.safetensors",
|
||||||
|
}
|
||||||
if dst.resolve() != src.resolve():
|
if dst.resolve() != src.resolve():
|
||||||
dst.mkdir(parents=True, exist_ok=True)
|
dst.mkdir(parents=True, exist_ok=True)
|
||||||
for f in src.iterdir():
|
for f in src.iterdir():
|
||||||
@@ -763,11 +791,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.quantize_only:
|
if args.quantize_only:
|
||||||
quantize_mlx_model(
|
quantize_mlx_model(
|
||||||
args.checkpoint_dir, args.output_dir,
|
args.checkpoint_dir,
|
||||||
bits=args.bits, group_size=args.group_size,
|
args.output_dir,
|
||||||
|
bits=args.bits,
|
||||||
|
group_size=args.group_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
convert_wan_checkpoint(
|
convert_wan_checkpoint(
|
||||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
args.checkpoint_dir,
|
||||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
args.output_dir,
|
||||||
|
args.dtype,
|
||||||
|
args.model_version,
|
||||||
|
quantize=args.quantize,
|
||||||
|
bits=args.bits,
|
||||||
|
group_size=args.group_size,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,18 +4,15 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||||
from mlx_video.models.wan.loading import (
|
from mlx_video.models.wan.loading import (
|
||||||
_clean_text,
|
|
||||||
encode_text,
|
encode_text,
|
||||||
load_t5_encoder,
|
load_t5_encoder,
|
||||||
load_vae_decoder,
|
load_vae_decoder,
|
||||||
@@ -24,6 +21,7 @@ from mlx_video.models.wan.loading import (
|
|||||||
)
|
)
|
||||||
from mlx_video.models.wan.postprocess import save_video
|
from mlx_video.models.wan.postprocess import save_video
|
||||||
|
|
||||||
|
|
||||||
class Colors:
|
class Colors:
|
||||||
"""ANSI color codes for terminal output."""
|
"""ANSI color codes for terminal output."""
|
||||||
|
|
||||||
@@ -37,6 +35,7 @@ class Colors:
|
|||||||
DIM = "\033[2m"
|
DIM = "\033[2m"
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
# Backward-compat alias (tests and external code may use the old name)
|
# Backward-compat alias (tests and external code may use the old name)
|
||||||
_build_i2v_mask = build_i2v_mask
|
_build_i2v_mask = build_i2v_mask
|
||||||
|
|
||||||
@@ -143,10 +142,13 @@ def generate_video(
|
|||||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||||
if key in config_dict and isinstance(config_dict[key], list):
|
if key in config_dict and isinstance(config_dict[key], list):
|
||||||
config_dict[key] = tuple(config_dict[key])
|
config_dict[key] = tuple(config_dict[key])
|
||||||
config = WanModelConfig(**{
|
config = WanModelConfig(
|
||||||
k: v for k, v in config_dict.items()
|
**{
|
||||||
if k in WanModelConfig.__dataclass_fields__
|
k: v
|
||||||
})
|
for k, v in config_dict.items()
|
||||||
|
if k in WanModelConfig.__dataclass_fields__
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Auto-detect: dual model files → 2.2, single model → 2.1
|
# Auto-detect: dual model files → 2.2, single model → 2.1
|
||||||
if (model_dir / "low_noise_model.safetensors").exists():
|
if (model_dir / "low_noise_model.safetensors").exists():
|
||||||
@@ -182,7 +184,9 @@ def generate_video(
|
|||||||
if "patch_embedding_proj.weight" in k:
|
if "patch_embedding_proj.weight" in k:
|
||||||
actual_dim = v.shape[0]
|
actual_dim = v.shape[0]
|
||||||
if actual_dim != config.dim:
|
if actual_dim != config.dim:
|
||||||
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
|
||||||
|
)
|
||||||
if actual_dim <= 2048:
|
if actual_dim <= 2048:
|
||||||
config = WanModelConfig.wan21_t2v_1_3b()
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
else:
|
else:
|
||||||
@@ -192,13 +196,20 @@ def generate_video(
|
|||||||
|
|
||||||
# Auto-correct Wan2.2 VAE params from stale configs
|
# Auto-correct Wan2.2 VAE params from stale configs
|
||||||
if config.in_dim == 48 and config.vae_z_dim != 48:
|
if config.in_dim == 48 and config.vae_z_dim != 48:
|
||||||
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
|
print(
|
||||||
config = WanModelConfig(**{
|
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
|
||||||
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
|
)
|
||||||
"vae_z_dim": 48,
|
config = WanModelConfig(
|
||||||
"vae_stride": (4, 16, 16),
|
**{
|
||||||
"sample_fps": 24,
|
**{
|
||||||
})
|
f.name: getattr(config, f.name)
|
||||||
|
for f in config.__dataclass_fields__.values()
|
||||||
|
},
|
||||||
|
"vae_z_dim": 48,
|
||||||
|
"vae_stride": (4, 16, 16),
|
||||||
|
"sample_fps": 24,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Apply defaults from config if not overridden
|
# Apply defaults from config if not overridden
|
||||||
if steps is None:
|
if steps is None:
|
||||||
@@ -227,7 +238,9 @@ def generate_video(
|
|||||||
gen_frames = num_frames
|
gen_frames = num_frames
|
||||||
if trim_first_frames > 0:
|
if trim_first_frames > 0:
|
||||||
gen_frames = num_frames + trim_first_frames * 4
|
gen_frames = num_frames + trim_first_frames * 4
|
||||||
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
version_str = f"Wan{config.model_version}"
|
version_str = f"Wan{config.model_version}"
|
||||||
mode_str = "dual-model" if is_dual else "single-model"
|
mode_str = "dual-model" if is_dual else "single-model"
|
||||||
@@ -247,10 +260,16 @@ def generate_video(
|
|||||||
if is_i2v:
|
if is_i2v:
|
||||||
print(f" Image: {image}")
|
print(f" Image: {image}")
|
||||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
neg_display = (
|
||||||
|
neg_prompt_resolved[:60] + "..."
|
||||||
|
if len(neg_prompt_resolved) > 60
|
||||||
|
else neg_prompt_resolved
|
||||||
|
)
|
||||||
print(f" Neg prompt: {neg_display}")
|
print(f" Neg prompt: {neg_display}")
|
||||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
print(
|
||||||
|
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
|
||||||
|
)
|
||||||
if cfg_disabled:
|
if cfg_disabled:
|
||||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||||
print(f"{Colors.RESET}")
|
print(f"{Colors.RESET}")
|
||||||
@@ -275,12 +294,16 @@ def generate_video(
|
|||||||
height = align_h
|
height = align_h
|
||||||
if width == 0:
|
if width == 0:
|
||||||
width = align_w
|
width = align_w
|
||||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
# Enforce max_area constraint (model-specific resolution limit)
|
# Enforce max_area constraint (model-specific resolution limit)
|
||||||
if config.max_area > 0 and height * width > config.max_area:
|
if config.max_area > 0 and height * width > config.max_area:
|
||||||
old_h, old_w = height, width
|
old_h, old_w = height, width
|
||||||
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
|
width, height = _best_output_size(
|
||||||
|
width, height, align_w, align_h, config.max_area
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||||
@@ -309,6 +332,7 @@ def generate_video(
|
|||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||||
|
|
||||||
# Encode prompts
|
# Encode prompts
|
||||||
@@ -318,12 +342,15 @@ def generate_video(
|
|||||||
context_null = None
|
context_null = None
|
||||||
mx.eval(context)
|
mx.eval(context)
|
||||||
else:
|
else:
|
||||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
context_null = encode_text(
|
||||||
|
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
|
||||||
|
)
|
||||||
mx.eval(context, context_null)
|
mx.eval(context, context_null)
|
||||||
|
|
||||||
# Free T5 from memory
|
# Free T5 from memory
|
||||||
del t5_encoder
|
del t5_encoder
|
||||||
gc.collect(); mx.clear_cache()
|
gc.collect()
|
||||||
|
mx.clear_cache()
|
||||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
# I2V: encode image to latent space
|
# I2V: encode image to latent space
|
||||||
@@ -346,18 +373,25 @@ def generate_video(
|
|||||||
|
|
||||||
img = Image.open(image).convert("RGB")
|
img = Image.open(image).convert("RGB")
|
||||||
scale = max(width / img.width, height / img.height)
|
scale = max(width / img.width, height / img.height)
|
||||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
img = img.resize(
|
||||||
|
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||||
|
)
|
||||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||||
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
|
img_arr = mx.array(
|
||||||
|
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
|
||||||
|
) # [H, W, 3]
|
||||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||||
|
|
||||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||||
video = mx.concatenate([
|
video = mx.concatenate(
|
||||||
img_chw[:, None, :, :],
|
[
|
||||||
mx.zeros((3, num_frames - 1, height, width)),
|
img_chw[:, None, :, :],
|
||||||
], axis=1)
|
mx.zeros((3, num_frames - 1, height, width)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||||
vae_enc = load_vae_encoder(vae_path, config)
|
vae_enc = load_vae_encoder(vae_path, config)
|
||||||
@@ -367,12 +401,17 @@ def generate_video(
|
|||||||
|
|
||||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
msk = mx.concatenate(
|
||||||
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||||
|
)
|
||||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||||
msk = mx.concatenate([
|
msk = mx.concatenate(
|
||||||
mx.repeat(msk[:, :1], 4, axis=1),
|
[
|
||||||
msk[:, 1:],
|
mx.repeat(msk[:, :1], 4, axis=1),
|
||||||
], axis=1)
|
msk[:, 1:],
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||||
@@ -395,13 +434,16 @@ def generate_video(
|
|||||||
|
|
||||||
del vae_enc, img_tensor
|
del vae_enc, img_tensor
|
||||||
|
|
||||||
gc.collect(); mx.clear_cache()
|
gc.collect()
|
||||||
|
mx.clear_cache()
|
||||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
# Load transformer models
|
# Load transformer models
|
||||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||||
if quantization:
|
if quantization:
|
||||||
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
|
||||||
|
)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
|
|
||||||
# Merge per-model LoRAs with shared LoRAs
|
# Merge per-model LoRAs with shared LoRAs
|
||||||
@@ -412,10 +454,16 @@ def generate_video(
|
|||||||
if is_dual:
|
if is_dual:
|
||||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||||
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
|
low_noise_model = load_wan_model(
|
||||||
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
|
low_noise_path, config, quantization, loras=_loras_low
|
||||||
|
)
|
||||||
|
high_noise_model = load_wan_model(
|
||||||
|
high_noise_path, config, quantization, loras=_loras_high
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
|
single_model = load_wan_model(
|
||||||
|
model_dir / "model.safetensors", config, quantization, loras=_loras_single
|
||||||
|
)
|
||||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||||
@@ -437,8 +485,12 @@ def generate_video(
|
|||||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||||
mx.eval(context_emb_low, context_emb_high)
|
mx.eval(context_emb_low, context_emb_high)
|
||||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
context_cfg_low = mx.concatenate(
|
||||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
|
||||||
|
)
|
||||||
|
context_cfg_high = mx.concatenate(
|
||||||
|
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
context_emb = single_model.embed_text([context, context_null])
|
context_emb = single_model.embed_text([context, context_null])
|
||||||
mx.eval(context_emb)
|
mx.eval(context_emb)
|
||||||
@@ -534,7 +586,7 @@ def generate_video(
|
|||||||
rcs = rope_cos_sin
|
rcs = rope_cos_sin
|
||||||
|
|
||||||
# Use compiled forward when available (faster after first trace)
|
# Use compiled forward when available (faster after first trace)
|
||||||
_call = getattr(model, '_compiled', model)
|
_call = getattr(model, "_compiled", model)
|
||||||
|
|
||||||
if cfg_disabled:
|
if cfg_disabled:
|
||||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||||
@@ -552,7 +604,9 @@ def generate_video(
|
|||||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||||
|
|
||||||
if is_dual:
|
if is_dual:
|
||||||
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
ctx = (
|
||||||
|
context_cond_high if timestep_val >= boundary else context_cond_low
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ctx = context_cond
|
ctx = context_cond
|
||||||
preds = _call(
|
preds = _call(
|
||||||
@@ -571,7 +625,11 @@ def generate_video(
|
|||||||
if is_dual:
|
if is_dual:
|
||||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||||
else:
|
else:
|
||||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
gs = (
|
||||||
|
guide_scale
|
||||||
|
if isinstance(guide_scale, (int, float))
|
||||||
|
else guide_scale[0]
|
||||||
|
)
|
||||||
|
|
||||||
if is_i2v_mask_blend:
|
if is_i2v_mask_blend:
|
||||||
t_tokens = i2v_mask_tokens * timestep_val
|
t_tokens = i2v_mask_tokens * timestep_val
|
||||||
@@ -586,8 +644,10 @@ def generate_video(
|
|||||||
|
|
||||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||||
|
|
||||||
ctx = context_cfg if not is_dual else (
|
ctx = (
|
||||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
context_cfg
|
||||||
|
if not is_dual
|
||||||
|
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
|
||||||
)
|
)
|
||||||
preds = _call(
|
preds = _call(
|
||||||
[latents, latents],
|
[latents, latents],
|
||||||
@@ -618,16 +678,24 @@ def generate_video(
|
|||||||
if debug_latents:
|
if debug_latents:
|
||||||
lat_np = np.array(latents) # [C, T, H, W]
|
lat_np = np.array(latents) # [C, T, H, W]
|
||||||
n_t = lat_np.shape[1]
|
n_t = lat_np.shape[1]
|
||||||
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
|
print(
|
||||||
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
|
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
|
||||||
|
)
|
||||||
for t_pos in range(min(n_t, 8)):
|
for t_pos in range(min(n_t, 8)):
|
||||||
frame = lat_np[:, t_pos, :, :]
|
frame = lat_np[:, t_pos, :, :]
|
||||||
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
print(
|
||||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
|
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||||
|
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
|
||||||
|
)
|
||||||
if n_t > 8:
|
if n_t > 8:
|
||||||
interior = lat_np[:, 4:, :, :]
|
interior = lat_np[:, 4:, :, :]
|
||||||
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
print(
|
||||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
|
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||||
|
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Free transformer models and text embeddings
|
# Free transformer models and text embeddings
|
||||||
@@ -646,7 +714,8 @@ def generate_video(
|
|||||||
del model, kv, context
|
del model, kv, context
|
||||||
if context_null is not None:
|
if context_null is not None:
|
||||||
del context_null
|
del context_null
|
||||||
gc.collect(); mx.clear_cache()
|
gc.collect()
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
# Load VAE and decode
|
# Load VAE and decode
|
||||||
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
||||||
@@ -677,13 +746,25 @@ def generate_video(
|
|||||||
elif tiling == "temporal":
|
elif tiling == "temporal":
|
||||||
tiling_config = TilingConfig.temporal_only()
|
tiling_config = TilingConfig.temporal_only()
|
||||||
else:
|
else:
|
||||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
|
||||||
|
)
|
||||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||||
|
|
||||||
if tiling_config is not None:
|
if tiling_config is not None:
|
||||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
spatial_info = (
|
||||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
|
||||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
if tiling_config.spatial_config
|
||||||
|
else "none"
|
||||||
|
)
|
||||||
|
temporal_info = (
|
||||||
|
f"{tiling_config.temporal_config.tile_size_in_frames}f"
|
||||||
|
if tiling_config.temporal_config
|
||||||
|
else "none"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
if is_wan22_vae:
|
if is_wan22_vae:
|
||||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
@@ -718,7 +799,9 @@ def generate_video(
|
|||||||
if trim_first_frames > 0:
|
if trim_first_frames > 0:
|
||||||
trim_pixels = trim_first_frames * 4
|
trim_pixels = trim_first_frames * 4
|
||||||
video = video[trim_pixels:]
|
video = video[trim_pixels:]
|
||||||
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
|
print(
|
||||||
|
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
save_video(video, output_path, fps=config.sample_fps)
|
save_video(video, output_path, fps=config.sample_fps)
|
||||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||||
@@ -727,58 +810,124 @@ def generate_video(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
|
||||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
|
||||||
parser.add_argument("--image", type=str, default=None,
|
|
||||||
help="Path to input image for I2V (omit for T2V mode)")
|
|
||||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
|
||||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
|
||||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
|
||||||
help="Disable negative prompt (use empty string instead of config default)")
|
|
||||||
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
|
|
||||||
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
|
|
||||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
|
||||||
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
|
||||||
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
|
||||||
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
|
|
||||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
|
||||||
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--scheduler", type=str, default="unipc",
|
"--model-dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to converted MLX model directory",
|
||||||
|
)
|
||||||
|
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||||
|
parser.add_argument(
|
||||||
|
"--image",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to input image for I2V (omit for T2V mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--negative-prompt",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Negative prompt for CFG (default: official Chinese prompt from config)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-negative-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable negative prompt (use empty string instead of config default)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--width", type=int, default=1280, help="Video width (default: 1280)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--height",
|
||||||
|
type=int,
|
||||||
|
default=704,
|
||||||
|
help="Video height (default: 704; 720p models use 704)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Number of diffusion steps (default: from config)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--guide-scale",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Guidance scale: single float or low,high pair",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--shift",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Noise schedule shift (default: from config)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-path", type=str, default="output.mp4", help="Output video path"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scheduler",
|
||||||
|
type=str,
|
||||||
|
default="unipc",
|
||||||
choices=["euler", "dpm++", "unipc"],
|
choices=["euler", "dpm++", "unipc"],
|
||||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
"--lora",
|
||||||
|
nargs=2,
|
||||||
|
action="append",
|
||||||
|
metavar=("PATH", "STRENGTH"),
|
||||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
"--lora-high",
|
||||||
|
nargs=2,
|
||||||
|
action="append",
|
||||||
|
metavar=("PATH", "STRENGTH"),
|
||||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
"--lora-low",
|
||||||
|
nargs=2,
|
||||||
|
action="append",
|
||||||
|
metavar=("PATH", "STRENGTH"),
|
||||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tiling",
|
"--tiling",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
choices=[
|
||||||
|
"auto",
|
||||||
|
"none",
|
||||||
|
"default",
|
||||||
|
"aggressive",
|
||||||
|
"conservative",
|
||||||
|
"spatial",
|
||||||
|
"temporal",
|
||||||
|
],
|
||||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-compile", action="store_true",
|
"--no-compile",
|
||||||
|
action="store_true",
|
||||||
help="Disable mx.compile on models (for debugging)",
|
help="Disable mx.compile on models (for debugging)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trim-first-frames", type=int, default=0, metavar="N",
|
"--trim-first-frames",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
metavar="N",
|
||||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||||
"Default: 0 (disabled)",
|
"Default: 0 (disabled)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--debug-latents", action="store_true",
|
"--debug-latents",
|
||||||
|
action="store_true",
|
||||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
|
|||||||
|
|
||||||
# Resize so that the image covers the target size (LANCZOS)
|
# Resize so that the image covers the target size (LANCZOS)
|
||||||
scale = max(width / img.width, height / img.height)
|
scale = max(width / img.width, height / img.height)
|
||||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
img = img.resize(
|
||||||
|
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||||
|
)
|
||||||
|
|
||||||
# Center crop
|
# Center crop
|
||||||
x1 = (img.width - width) // 2
|
x1 = (img.width - width) // 2
|
||||||
|
|||||||
@@ -6,7 +6,12 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
|
def load_wan_model(
|
||||||
|
model_path: Path,
|
||||||
|
config,
|
||||||
|
quantization: dict | None = None,
|
||||||
|
loras: list | None = None,
|
||||||
|
):
|
||||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -93,9 +98,11 @@ def load_vae_decoder(model_path: Path, config=None):
|
|||||||
|
|
||||||
if is_wan22:
|
if is_wan22:
|
||||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
|
||||||
vae = Wan22VAEDecoder(z_dim=48)
|
vae = Wan22VAEDecoder(z_dim=48)
|
||||||
else:
|
else:
|
||||||
from mlx_video.models.wan.vae import WanVAE
|
from mlx_video.models.wan.vae import WanVAE
|
||||||
|
|
||||||
vae = WanVAE(z_dim=16)
|
vae = WanVAE(z_dim=16)
|
||||||
|
|
||||||
weights = mx.load(str(model_path))
|
weights = mx.load(str(model_path))
|
||||||
@@ -140,6 +147,7 @@ def _clean_text(text: str) -> str:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
|
|
||||||
text = ftfy.fix_text(text)
|
text = ftfy.fix_text(text)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -37,7 +38,9 @@ class Head(nn.Module):
|
|||||||
proj_dim = math.prod(patch_size) * out_dim
|
proj_dim = math.prod(patch_size) * out_dim
|
||||||
self.norm = WanLayerNorm(dim, eps)
|
self.norm = WanLayerNorm(dim, eps)
|
||||||
self.head = nn.Linear(dim, proj_dim)
|
self.head = nn.Linear(dim, proj_dim)
|
||||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
|
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
|
||||||
|
mx.float32
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||||
"""
|
"""
|
||||||
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
|
|||||||
# Reference computes three rope_params with different dim normalizations
|
# Reference computes three rope_params with different dim normalizations
|
||||||
# so each axis (temporal/height/width) gets its own full frequency range.
|
# so each axis (temporal/height/width) gets its own full frequency range.
|
||||||
d = dim // config.num_heads
|
d = dim // config.num_heads
|
||||||
self.freqs = mx.concatenate([
|
self.freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Precompute sinusoidal inv_freq for time embedding.
|
# Precompute sinusoidal inv_freq for time embedding.
|
||||||
half = config.freq_dim // 2
|
half = config.freq_dim // 2
|
||||||
self._inv_freq = mx.array(
|
self._inv_freq = mx.array(
|
||||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
|
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
|
||||||
).astype(np.float32)
|
np.float32
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _patchify(self, x: mx.array) -> tuple:
|
def _patchify(self, x: mx.array) -> tuple:
|
||||||
"""Convert video tensor to patch embeddings.
|
"""Convert video tensor to patch embeddings.
|
||||||
|
|
||||||
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
|
|||||||
seq_lens_list.append(p.shape[1])
|
seq_lens_list.append(p.shape[1])
|
||||||
x = mx.concatenate(
|
x = mx.concatenate(
|
||||||
[
|
[
|
||||||
mx.concatenate(
|
(
|
||||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
mx.concatenate(
|
||||||
axis=1,
|
[
|
||||||
|
p,
|
||||||
|
mx.zeros(
|
||||||
|
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
|
||||||
|
),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
if p.shape[1] < seq_len
|
||||||
|
else p
|
||||||
)
|
)
|
||||||
if p.shape[1] < seq_len
|
|
||||||
else p
|
|
||||||
for p in patches
|
for p in patches
|
||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
|
|||||||
t = t[None]
|
t = t[None]
|
||||||
|
|
||||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||||
sin_emb = mx.concatenate(
|
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||||
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
if t.ndim == 1:
|
if t.ndim == 1:
|
||||||
# Standard T2V: scalar timestep per batch element [B]
|
# Standard T2V: scalar timestep per batch element [B]
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||||
"""Save video frames to MP4.
|
"""Save video frames to MP4.
|
||||||
|
|
||||||
@@ -11,6 +13,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import imageio
|
import imageio
|
||||||
|
|
||||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
writer.append_data(frame)
|
writer.append_data(frame)
|
||||||
@@ -18,6 +21,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
h, w = frames.shape[1], frames.shape[2]
|
h, w = frames.shape[1], frames.shape[2]
|
||||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||||
@@ -27,9 +31,11 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
|||||||
except (ImportError, Exception):
|
except (ImportError, Exception):
|
||||||
# Last resort: save as individual PNGs
|
# Last resort: save as individual PNGs
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
for i, frame in enumerate(frames):
|
for i, frame in enumerate(frames):
|
||||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||||
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
print(
|
||||||
|
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
|
|||||||
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
||||||
"""
|
"""
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
|
freqs = (
|
||||||
1.0
|
np.arange(max_seq_len, dtype=np.float64)[:, None]
|
||||||
/ np.power(
|
* (
|
||||||
theta,
|
1.0
|
||||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
/ np.power(
|
||||||
)
|
theta,
|
||||||
)[None, :]
|
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||||
|
)
|
||||||
|
)[None, :]
|
||||||
|
)
|
||||||
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
||||||
cos_freqs = np.cos(freqs).astype(np.float32)
|
cos_freqs = np.cos(freqs).astype(np.float32)
|
||||||
sin_freqs = np.sin(freqs).astype(np.float32)
|
sin_freqs = np.sin(freqs).astype(np.float32)
|
||||||
@@ -46,9 +48,9 @@ def rope_apply(
|
|||||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||||
f0, h0, w0 = grid_sizes[0]
|
f0, h0, w0 = grid_sizes[0]
|
||||||
seq_len = f0 * h0 * w0
|
seq_len = f0 * h0 * w0
|
||||||
all_same_grid = all(
|
all_same_grid = (
|
||||||
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
|
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
|
||||||
) if b > 1 else True
|
)
|
||||||
|
|
||||||
if all_same_grid:
|
if all_same_grid:
|
||||||
# Vectorized path: apply RoPE to all batch elements at once
|
# Vectorized path: apply RoPE to all batch elements at once
|
||||||
@@ -57,7 +59,9 @@ def rope_apply(
|
|||||||
x_imag = x_seq[..., 1]
|
x_imag = x_seq[..., 1]
|
||||||
out_real = x_real * cos_f - x_imag * sin_f
|
out_real = x_real * cos_f - x_imag * sin_f
|
||||||
out_imag = x_real * sin_f + x_imag * cos_f
|
out_imag = x_real * sin_f + x_imag * cos_f
|
||||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
|
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
|
||||||
|
b, seq_len, n, d
|
||||||
|
)
|
||||||
if seq_len < s:
|
if seq_len < s:
|
||||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||||
return x_rotated
|
return x_rotated
|
||||||
@@ -102,17 +106,11 @@ def rope_apply(
|
|||||||
|
|
||||||
# Build per-position frequencies by expanding along grid dims
|
# Build per-position frequencies by expanding along grid dims
|
||||||
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
||||||
ft = mx.broadcast_to(
|
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||||
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
|
|
||||||
)
|
|
||||||
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
||||||
fh = mx.broadcast_to(
|
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||||
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
|
|
||||||
)
|
|
||||||
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
||||||
fw = mx.broadcast_to(
|
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||||
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Concatenate: [f*h*w, half_d, 2]
|
# Concatenate: [f*h*w, half_d, 2]
|
||||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||||
|
|||||||
@@ -7,9 +7,8 @@ for the same quality as Euler.
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def _compute_sigmas(
|
def _compute_sigmas(
|
||||||
@@ -25,9 +24,7 @@ def _compute_sigmas(
|
|||||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||||
"""
|
"""
|
||||||
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
||||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
|
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
|
||||||
::-1
|
|
||||||
]
|
|
||||||
sigmas_unshifted = 1.0 - alphas
|
sigmas_unshifted = 1.0 - alphas
|
||||||
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
||||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||||
@@ -65,7 +62,10 @@ class FlowMatchEulerScheduler:
|
|||||||
sample: mx.array,
|
sample: mx.array,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||||
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
|
dt = (
|
||||||
|
self._sigmas_float[self._step_index + 1]
|
||||||
|
- self._sigmas_float[self._step_index]
|
||||||
|
)
|
||||||
x_next = sample + dt * model_output
|
x_next = sample + dt * model_output
|
||||||
self._step_index += 1
|
self._step_index += 1
|
||||||
return x_next
|
return x_next
|
||||||
@@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler:
|
|||||||
|
|
||||||
# Decide order: 1st for first step, last step (if lower_order_final
|
# Decide order: 1st for first step, last step (if lower_order_final
|
||||||
# and few steps), otherwise 2nd
|
# and few steps), otherwise 2nd
|
||||||
use_first_order = (
|
use_first_order = self._prev_x0 is None or (
|
||||||
self._prev_x0 is None
|
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
|
||||||
or (
|
|
||||||
self.lower_order_final
|
|
||||||
and i == self._num_steps - 1
|
|
||||||
and self._num_steps < 15
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_first_order or sigma_next == 0.0:
|
if use_first_order or sigma_next == 0.0:
|
||||||
|
|||||||
@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
|
|||||||
is_small = rel_pos < max_exact
|
is_small = rel_pos < max_exact
|
||||||
|
|
||||||
rel_pos_f = rel_pos.astype(mx.float32)
|
rel_pos_f = rel_pos.astype(mx.float32)
|
||||||
rel_pos_large = (
|
rel_pos_large = max_exact + (
|
||||||
max_exact
|
mx.log(rel_pos_f / max_exact)
|
||||||
+ (
|
/ math.log(self.max_dist / max_exact)
|
||||||
mx.log(rel_pos_f / max_exact)
|
* (num_buckets - max_exact)
|
||||||
/ math.log(self.max_dist / max_exact)
|
).astype(mx.int32)
|
||||||
* (num_buckets - max_exact)
|
|
||||||
).astype(mx.int32)
|
|
||||||
)
|
|
||||||
rel_pos_large = mx.minimum(
|
rel_pos_large = mx.minimum(
|
||||||
rel_pos_large,
|
rel_pos_large,
|
||||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||||
)
|
)
|
||||||
|
|
||||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
|
rel_buckets = rel_buckets + mx.where(
|
||||||
|
is_small, rel_pos.astype(mx.int32), rel_pos_large
|
||||||
|
)
|
||||||
return rel_buckets
|
return rel_buckets
|
||||||
|
|
||||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||||
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
|
|||||||
v = v.transpose(0, 2, 1, 3)
|
v = v.transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
# QK^T (no scaling) — compute in float32 for precision
|
# QK^T (no scaling) — compute in float32 for precision
|
||||||
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
|
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
|
||||||
|
|
||||||
# Add position bias
|
# Add position bias
|
||||||
if pos_bias is not None:
|
if pos_bias is not None:
|
||||||
|
|||||||
@@ -75,7 +75,11 @@ def decode_with_tiling(
|
|||||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||||
|
|
||||||
# Compute output shape
|
# Compute output shape
|
||||||
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
|
out_f = (
|
||||||
|
(1 + (f_latent - 1) * temporal_scale)
|
||||||
|
if causal_temporal
|
||||||
|
else (f_latent * temporal_scale)
|
||||||
|
)
|
||||||
out_h = h_latent * spatial_scale
|
out_h = h_latent * spatial_scale
|
||||||
out_w = w_latent * spatial_scale
|
out_w = w_latent * spatial_scale
|
||||||
|
|
||||||
@@ -98,9 +102,13 @@ def decode_with_tiling(
|
|||||||
|
|
||||||
# Compute intervals for each dimension
|
# Compute intervals for each dimension
|
||||||
if causal_temporal:
|
if causal_temporal:
|
||||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
temporal_intervals = split_in_temporal(
|
||||||
|
temporal_tile_size, temporal_overlap, f_latent
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
|
temporal_intervals = split_in_spatial(
|
||||||
|
temporal_tile_size, temporal_overlap, f_latent
|
||||||
|
)
|
||||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||||
|
|
||||||
@@ -124,9 +132,13 @@ def decode_with_tiling(
|
|||||||
|
|
||||||
# Map temporal coordinates
|
# Map temporal coordinates
|
||||||
if causal_temporal:
|
if causal_temporal:
|
||||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
out_t_slice, t_mask = map_temporal_slice(
|
||||||
|
t_start, t_end, t_left, t_right, temporal_scale
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
out_t_slice, t_mask = map_spatial_slice(
|
||||||
|
t_start, t_end, t_left, t_right, temporal_scale
|
||||||
|
)
|
||||||
|
|
||||||
for h_idx in range(num_h_tiles):
|
for h_idx in range(num_h_tiles):
|
||||||
h_start = height_intervals.starts[h_idx]
|
h_start = height_intervals.starts[h_idx]
|
||||||
@@ -135,7 +147,9 @@ def decode_with_tiling(
|
|||||||
h_right = height_intervals.right_ramps[h_idx]
|
h_right = height_intervals.right_ramps[h_idx]
|
||||||
|
|
||||||
# Map height coordinates
|
# Map height coordinates
|
||||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
out_h_slice, h_mask = map_spatial_slice(
|
||||||
|
h_start, h_end, h_left, h_right, spatial_scale
|
||||||
|
)
|
||||||
|
|
||||||
for w_idx in range(num_w_tiles):
|
for w_idx in range(num_w_tiles):
|
||||||
w_start = width_intervals.starts[w_idx]
|
w_start = width_intervals.starts[w_idx]
|
||||||
@@ -144,13 +158,23 @@ def decode_with_tiling(
|
|||||||
w_right = width_intervals.right_ramps[w_idx]
|
w_right = width_intervals.right_ramps[w_idx]
|
||||||
|
|
||||||
# Map width coordinates
|
# Map width coordinates
|
||||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
out_w_slice, w_mask = map_spatial_slice(
|
||||||
|
w_start, w_end, w_left, w_right, spatial_scale
|
||||||
|
)
|
||||||
|
|
||||||
# Extract tile latents (small slice)
|
# Extract tile latents (small slice)
|
||||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
tile_latents = latents[
|
||||||
|
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||||
|
]
|
||||||
|
|
||||||
# Decode tile
|
# Decode tile
|
||||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
tile_output = decoder_fn(
|
||||||
|
tile_latents,
|
||||||
|
causal=causal,
|
||||||
|
timestep=timestep,
|
||||||
|
debug=False,
|
||||||
|
chunked_conv=chunked_conv,
|
||||||
|
)
|
||||||
mx.eval(tile_output)
|
mx.eval(tile_output)
|
||||||
|
|
||||||
# Clear tile_latents reference
|
# Clear tile_latents reference
|
||||||
@@ -173,13 +197,15 @@ def decode_with_tiling(
|
|||||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||||
|
|
||||||
blend_mask = (
|
blend_mask = (
|
||||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Slice tile output to match
|
# Slice tile output to match
|
||||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
tile_output_slice = tile_output[
|
||||||
|
:, :, :actual_t, :actual_h, :actual_w
|
||||||
|
].astype(mx.float32)
|
||||||
|
|
||||||
# Clear full tile_output
|
# Clear full tile_output
|
||||||
del tile_output
|
del tile_output
|
||||||
@@ -196,11 +222,37 @@ def decode_with_tiling(
|
|||||||
weighted_tile = tile_output_slice * blend_mask
|
weighted_tile = tile_output_slice * blend_mask
|
||||||
|
|
||||||
# Update output using slice assignment
|
# Update output using slice assignment
|
||||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
output[
|
||||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
] = (
|
||||||
|
output[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
]
|
||||||
|
+ weighted_tile
|
||||||
)
|
)
|
||||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
weights[
|
||||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
] = (
|
||||||
|
weights[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
t_out_start:t_out_end,
|
||||||
|
h_out_start:h_out_end,
|
||||||
|
w_out_start:w_out_end,
|
||||||
|
]
|
||||||
|
+ blend_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# Force evaluation to free memory
|
# Force evaluation to free memory
|
||||||
@@ -232,12 +284,14 @@ def decode_with_tiling(
|
|||||||
if next_tile_start_latent == 0:
|
if next_tile_start_latent == 0:
|
||||||
next_tile_start_out = 0
|
next_tile_start_out = 0
|
||||||
elif causal_temporal:
|
elif causal_temporal:
|
||||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
next_tile_start_out = (
|
||||||
|
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||||
|
|
||||||
# We need to track how many frames we've already emitted
|
# We need to track how many frames we've already emitted
|
||||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||||
decode_with_tiling._emitted_frames = 0
|
decode_with_tiling._emitted_frames = 0
|
||||||
emitted = decode_with_tiling._emitted_frames
|
emitted = decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
@@ -245,7 +299,10 @@ def decode_with_tiling(
|
|||||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
finalized_output = (
|
||||||
|
output[:, :, emitted:next_tile_start_out, :, :]
|
||||||
|
/ finalized_weights
|
||||||
|
)
|
||||||
finalized_output = finalized_output.astype(latents.dtype)
|
finalized_output = finalized_output.astype(latents.dtype)
|
||||||
mx.eval(finalized_output)
|
mx.eval(finalized_output)
|
||||||
|
|
||||||
@@ -262,7 +319,7 @@ def decode_with_tiling(
|
|||||||
|
|
||||||
# Emit remaining frames if callback provided
|
# Emit remaining frames if callback provided
|
||||||
if on_frames_ready is not None:
|
if on_frames_ready is not None:
|
||||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||||
if emitted < out_f:
|
if emitted < out_f:
|
||||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||||
mx.eval(remaining_output)
|
mx.eval(remaining_output)
|
||||||
@@ -270,7 +327,7 @@ def decode_with_tiling(
|
|||||||
del remaining_output
|
del remaining_output
|
||||||
|
|
||||||
# Reset emitted frames counter for next call
|
# Reset emitted frames counter for next call
|
||||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||||
del decode_with_tiling._emitted_frames
|
del decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
# Clean up weights
|
# Clean up weights
|
||||||
|
|||||||
@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
# Cross-attention (with optional norm on context)
|
# Cross-attention (with optional norm on context)
|
||||||
self.norm3 = (
|
self.norm3 = (
|
||||||
WanLayerNorm(dim, eps, elementwise_affine=True)
|
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
|
||||||
if cross_attn_norm
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||||
|
|
||||||
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
|
|||||||
self.ffn = WanFFN(dim, ffn_dim)
|
self.ffn = WanFFN(dim, ffn_dim)
|
||||||
|
|
||||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
|
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
|
||||||
|
mx.float32
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
y = self.self_attn(
|
||||||
|
x_mod,
|
||||||
|
seq_lens,
|
||||||
|
grid_sizes,
|
||||||
|
freqs,
|
||||||
|
rope_cos_sin=rope_cos_sin,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
)
|
||||||
x = x + y * e2
|
x = x + y * e2
|
||||||
|
|
||||||
# Cross-attention (no modulation, just norm)
|
# Cross-attention (no modulation, just norm)
|
||||||
|
|||||||
@@ -6,19 +6,45 @@ so weights load directly without key sanitization.
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
CACHE_T = 2
|
CACHE_T = 2
|
||||||
|
|
||||||
# Per-channel normalization statistics for z_dim=16
|
# Per-channel normalization statistics for z_dim=16
|
||||||
VAE_MEAN = [
|
VAE_MEAN = [
|
||||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
-0.7571,
|
||||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
|
-0.7089,
|
||||||
|
-0.9113,
|
||||||
|
0.1075,
|
||||||
|
-0.1745,
|
||||||
|
0.9653,
|
||||||
|
-0.1517,
|
||||||
|
1.5508,
|
||||||
|
0.4134,
|
||||||
|
-0.0715,
|
||||||
|
0.5517,
|
||||||
|
-0.3632,
|
||||||
|
-0.1922,
|
||||||
|
-0.9497,
|
||||||
|
0.2503,
|
||||||
|
-0.2921,
|
||||||
]
|
]
|
||||||
VAE_STD = [
|
VAE_STD = [
|
||||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
2.8184,
|
||||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
|
1.4541,
|
||||||
|
2.3275,
|
||||||
|
2.6558,
|
||||||
|
1.2196,
|
||||||
|
1.7708,
|
||||||
|
2.6052,
|
||||||
|
2.0743,
|
||||||
|
3.2687,
|
||||||
|
2.1526,
|
||||||
|
2.8652,
|
||||||
|
1.5579,
|
||||||
|
1.6382,
|
||||||
|
1.1253,
|
||||||
|
2.8251,
|
||||||
|
1.9160,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -50,7 +76,9 @@ class CausalConv3d(nn.Module):
|
|||||||
self._pad_w = padding[2]
|
self._pad_w = padding[2]
|
||||||
|
|
||||||
# MLX Conv3d: weight shape [O, D, H, W, I]
|
# MLX Conv3d: weight shape [O, D, H, W, I]
|
||||||
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
|
self.weight = mx.zeros(
|
||||||
|
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||||
|
)
|
||||||
self.bias = mx.zeros((out_channels,))
|
self.bias = mx.zeros((out_channels,))
|
||||||
|
|
||||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||||
@@ -67,8 +95,16 @@ class CausalConv3d(nn.Module):
|
|||||||
x = mx.concatenate([pad_t, x], axis=2)
|
x = mx.concatenate([pad_t, x], axis=2)
|
||||||
|
|
||||||
if self._pad_h > 0 or self._pad_w > 0:
|
if self._pad_h > 0 or self._pad_w > 0:
|
||||||
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
|
x = mx.pad(
|
||||||
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
|
x,
|
||||||
|
[
|
||||||
|
(0, 0),
|
||||||
|
(0, 0),
|
||||||
|
(0, 0),
|
||||||
|
(self._pad_h, self._pad_h),
|
||||||
|
(self._pad_w, self._pad_w),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
||||||
out = self._conv3d(x)
|
out = self._conv3d(x)
|
||||||
@@ -118,7 +154,11 @@ class RMS_norm(nn.Module):
|
|||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
norm_dim = 1 if self.channel_first else -1
|
norm_dim = 1 if self.channel_first else -1
|
||||||
# L2 normalize along channel dim (matches F.normalize)
|
# L2 normalize along channel dim (matches F.normalize)
|
||||||
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
|
norm = mx.sqrt(
|
||||||
|
mx.clip(
|
||||||
|
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
|
||||||
|
)
|
||||||
|
)
|
||||||
return (x / norm) * self.scale * self.gamma
|
return (x / norm) * self.scale * self.gamma
|
||||||
|
|
||||||
|
|
||||||
@@ -133,12 +173,12 @@ class ResidualBlock(nn.Module):
|
|||||||
def __init__(self, in_dim: int, out_dim: int):
|
def __init__(self, in_dim: int, out_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.residual = [
|
self.residual = [
|
||||||
RMS_norm(in_dim, images=False), # [0]
|
RMS_norm(in_dim, images=False), # [0]
|
||||||
None, # [1] SiLU
|
None, # [1] SiLU
|
||||||
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
||||||
RMS_norm(out_dim, images=False), # [3]
|
RMS_norm(out_dim, images=False), # [3]
|
||||||
None, # [4] SiLU
|
None, # [4] SiLU
|
||||||
None, # [5] Dropout
|
None, # [5] Dropout
|
||||||
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
||||||
]
|
]
|
||||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||||
@@ -226,13 +266,16 @@ class Resample(nn.Module):
|
|||||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||||
if mode == "upsample3d":
|
if mode == "upsample3d":
|
||||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
self.time_conv = CausalConv3d(
|
||||||
|
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||||
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
||||||
if mode == "downsample3d":
|
if mode == "downsample3d":
|
||||||
self.time_conv = CausalConv3d(
|
self.time_conv = CausalConv3d(
|
||||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||||
"""x: [B, C, T, H, W]"""
|
"""x: [B, C, T, H, W]"""
|
||||||
@@ -272,8 +315,7 @@ class Resample(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# Subsequent chunks: use cached frame as temporal context
|
# Subsequent chunks: use cached frame as temporal context
|
||||||
cache_x = x[:, :, -1:]
|
cache_x = x[:, :, -1:]
|
||||||
x = self.time_conv(
|
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
|
||||||
x, cache_x=feat_cache[idx][:, :, -1:])
|
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
@@ -328,8 +370,8 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||||
self.head = [
|
self.head = [
|
||||||
RMS_norm(dims[-1], images=False), # [0]
|
RMS_norm(dims[-1], images=False), # [0]
|
||||||
None, # [1] SiLU
|
None, # [1] SiLU
|
||||||
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -405,8 +447,7 @@ class Encoder3d(nn.Module):
|
|||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:]
|
cache_x = x[:, :, -CACHE_T:]
|
||||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||||
cache_x = mx.concatenate(
|
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
|
||||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@@ -431,8 +472,7 @@ class Encoder3d(nn.Module):
|
|||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:]
|
cache_x = x[:, :, -CACHE_T:]
|
||||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||||
cache_x = mx.concatenate(
|
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||||
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
|
||||||
x = self.head[2](x, cache_x=feat_cache[idx])
|
x = self.head[2](x, cache_x=feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@@ -583,7 +623,7 @@ class WanVAE(nn.Module):
|
|||||||
decoder_fn=tile_decode,
|
decoder_fn=tile_decode,
|
||||||
latents=z_denorm,
|
latents=z_denorm,
|
||||||
tiling_config=tiling_config,
|
tiling_config=tiling_config,
|
||||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||||
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
|
|||||||
CACHE_T = 2
|
CACHE_T = 2
|
||||||
|
|
||||||
# Per-channel normalization for z_dim=48 latent space
|
# Per-channel normalization for z_dim=48 latent space
|
||||||
VAE22_MEAN = mx.array([
|
VAE22_MEAN = mx.array(
|
||||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
[
|
||||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
-0.2289,
|
||||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
-0.0052,
|
||||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
-0.1323,
|
||||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
-0.2339,
|
||||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
-0.2799,
|
||||||
])
|
0.0174,
|
||||||
|
0.1838,
|
||||||
|
0.1557,
|
||||||
|
-0.1382,
|
||||||
|
0.0542,
|
||||||
|
0.2813,
|
||||||
|
0.0891,
|
||||||
|
0.1570,
|
||||||
|
-0.0098,
|
||||||
|
0.0375,
|
||||||
|
-0.1825,
|
||||||
|
-0.2246,
|
||||||
|
-0.1207,
|
||||||
|
-0.0698,
|
||||||
|
0.5109,
|
||||||
|
0.2665,
|
||||||
|
-0.2108,
|
||||||
|
-0.2158,
|
||||||
|
0.2502,
|
||||||
|
-0.2055,
|
||||||
|
-0.0322,
|
||||||
|
0.1109,
|
||||||
|
0.1567,
|
||||||
|
-0.0729,
|
||||||
|
0.0899,
|
||||||
|
-0.2799,
|
||||||
|
-0.1230,
|
||||||
|
-0.0313,
|
||||||
|
-0.1649,
|
||||||
|
0.0117,
|
||||||
|
0.0723,
|
||||||
|
-0.2839,
|
||||||
|
-0.2083,
|
||||||
|
-0.0520,
|
||||||
|
0.3748,
|
||||||
|
0.0152,
|
||||||
|
0.1957,
|
||||||
|
0.1433,
|
||||||
|
-0.2944,
|
||||||
|
0.3573,
|
||||||
|
-0.0548,
|
||||||
|
-0.1681,
|
||||||
|
-0.0667,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
VAE22_STD = mx.array([
|
VAE22_STD = mx.array(
|
||||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
[
|
||||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
0.4765,
|
||||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
1.0364,
|
||||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
0.4514,
|
||||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
1.1677,
|
||||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
|
0.5313,
|
||||||
])
|
0.4990,
|
||||||
|
0.4818,
|
||||||
|
0.5013,
|
||||||
|
0.8158,
|
||||||
|
1.0344,
|
||||||
|
0.5894,
|
||||||
|
1.0901,
|
||||||
|
0.6885,
|
||||||
|
0.6165,
|
||||||
|
0.8454,
|
||||||
|
0.4978,
|
||||||
|
0.5759,
|
||||||
|
0.3523,
|
||||||
|
0.7135,
|
||||||
|
0.6804,
|
||||||
|
0.5833,
|
||||||
|
1.4146,
|
||||||
|
0.8986,
|
||||||
|
0.5659,
|
||||||
|
0.7069,
|
||||||
|
0.5338,
|
||||||
|
0.4889,
|
||||||
|
0.4917,
|
||||||
|
0.4069,
|
||||||
|
0.4999,
|
||||||
|
0.6866,
|
||||||
|
0.4093,
|
||||||
|
0.5709,
|
||||||
|
0.6065,
|
||||||
|
0.6415,
|
||||||
|
0.4944,
|
||||||
|
0.5726,
|
||||||
|
1.2042,
|
||||||
|
0.5458,
|
||||||
|
1.6887,
|
||||||
|
0.3971,
|
||||||
|
1.0600,
|
||||||
|
0.3943,
|
||||||
|
0.5537,
|
||||||
|
0.5444,
|
||||||
|
0.4089,
|
||||||
|
0.7468,
|
||||||
|
0.7744,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Module):
|
class CausalConv3d(nn.Module):
|
||||||
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
|
|||||||
self._pad_w = padding[2]
|
self._pad_w = padding[2]
|
||||||
|
|
||||||
# Weight: [O, D, H, W, I] for MLX
|
# Weight: [O, D, H, W, I] for MLX
|
||||||
self.weight = mx.zeros((
|
self.weight = mx.zeros(
|
||||||
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
|
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||||
))
|
)
|
||||||
self.bias = mx.zeros((out_channels,))
|
self.bias = mx.zeros((out_channels,))
|
||||||
|
|
||||||
def __call__(self, x, cache_x=None):
|
def __call__(self, x, cache_x=None):
|
||||||
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
|
|||||||
|
|
||||||
# Spatial padding
|
# Spatial padding
|
||||||
if self._pad_h > 0 or self._pad_w > 0:
|
if self._pad_h > 0 or self._pad_w > 0:
|
||||||
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
|
x = mx.pad(
|
||||||
(self._pad_w, self._pad_w), (0, 0)])
|
x,
|
||||||
|
[
|
||||||
|
(0, 0),
|
||||||
|
(0, 0),
|
||||||
|
(self._pad_h, self._pad_h),
|
||||||
|
(self._pad_w, self._pad_w),
|
||||||
|
(0, 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
T_padded = x.shape[1]
|
T_padded = x.shape[1]
|
||||||
H_padded, W_padded = x.shape[2], x.shape[3]
|
H_padded, W_padded = x.shape[2], x.shape[3]
|
||||||
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
|
|||||||
for d in range(kd):
|
for d in range(kd):
|
||||||
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
|
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
|
||||||
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
|
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
|
||||||
conv_out = mx.conv_general(frame, w2d,
|
conv_out = mx.conv_general(
|
||||||
stride=(self.stride[1], self.stride[2]))
|
frame, w2d, stride=(self.stride[1], self.stride[2])
|
||||||
|
)
|
||||||
accum = conv_out if accum is None else accum + conv_out
|
accum = conv_out if accum is None else accum + conv_out
|
||||||
outputs.append(accum + self.bias)
|
outputs.append(accum + self.bias)
|
||||||
|
|
||||||
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim ** 0.5
|
self.scale = dim**0.5
|
||||||
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
|
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
|
||||||
self.gamma = mx.ones((dim,))
|
self.gamma = mx.ones((dim,))
|
||||||
|
|
||||||
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
|
|||||||
# x: [..., C] (channels-last)
|
# x: [..., C] (channels-last)
|
||||||
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
|
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
|
||||||
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
|
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
|
||||||
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
return (
|
||||||
|
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
|
|||||||
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
|
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
|
||||||
# We store as named layers matching PyTorch's indices
|
# We store as named layers matching PyTorch's indices
|
||||||
self.residual = ResidualBlockLayers(in_dim, out_dim)
|
self.residual = ResidualBlockLayers(in_dim, out_dim)
|
||||||
self.shortcut = (
|
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||||
CausalConv3d(in_dim, out_dim, 1)
|
|
||||||
if in_dim != out_dim
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x, feat_cache=None, feat_idx=None):
|
def __call__(self, x, feat_cache=None, feat_idx=None):
|
||||||
h = self.shortcut(x) if self.shortcut is not None else x
|
h = self.shortcut(x) if self.shortcut is not None else x
|
||||||
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
|
|||||||
# Save last CACHE_T frames before conv (for next chunk's context)
|
# Save last CACHE_T frames before conv (for next chunk's context)
|
||||||
cache_x = x[:, -CACHE_T:]
|
cache_x = x[:, -CACHE_T:]
|
||||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||||
cache_x = mx.concatenate(
|
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
|
||||||
)
|
|
||||||
out = conv(x, cache_x=feat_cache[idx])
|
out = conv(x, cache_x=feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
# QKV via 1x1 conv2d (equivalent to linear on last dim)
|
# QKV via 1x1 conv2d (equivalent to linear on last dim)
|
||||||
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
|
qkv = (
|
||||||
|
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
|
||||||
|
) # [BT, H, W, 3C]
|
||||||
qkv = qkv.reshape(B * T, H * W, 3 * C)
|
qkv = qkv.reshape(B * T, H * W, 3 * C)
|
||||||
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
|
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
|
||||||
|
|
||||||
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
|
|||||||
k = k[:, None, :, :]
|
k = k[:, None, :, :]
|
||||||
v = v[:, None, :, :]
|
v = v[:, None, :, :]
|
||||||
|
|
||||||
scale = C ** -0.5
|
scale = C**-0.5
|
||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
|
out = mx.fast.scaled_dot_product_attention(
|
||||||
|
q, k, v, scale=scale
|
||||||
|
) # [BT, 1, HW, C]
|
||||||
out = out.squeeze(1).reshape(B * T, H, W, C)
|
out = out.squeeze(1).reshape(B * T, H, W, C)
|
||||||
|
|
||||||
# Project output
|
# Project output
|
||||||
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
|
|||||||
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
|
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
|
||||||
|
|
||||||
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
|
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
|
||||||
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
|
x = x.reshape(
|
||||||
|
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
|
||||||
|
)
|
||||||
|
|
||||||
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
|
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
|
||||||
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
|
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
|
||||||
|
|
||||||
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
|
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
|
||||||
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
|
x = x.reshape(
|
||||||
|
B,
|
||||||
|
T * self.factor_t,
|
||||||
|
H * self.factor_s,
|
||||||
|
W * self.factor_s,
|
||||||
|
self.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
if first_chunk:
|
if first_chunk:
|
||||||
x = x[:, self.factor_t - 1:, :, :, :]
|
x = x[:, self.factor_t - 1 :, :, :, :]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -348,7 +452,9 @@ class Resample(nn.Module):
|
|||||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||||
self.resample_bias = mx.zeros((dim,))
|
self.resample_bias = mx.zeros((dim,))
|
||||||
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
||||||
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
self.time_conv = CausalConv3d(
|
||||||
|
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported mode: {mode}")
|
raise ValueError(f"Unsupported mode: {mode}")
|
||||||
|
|
||||||
@@ -369,7 +475,9 @@ class Resample(nn.Module):
|
|||||||
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
|
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
|
||||||
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
|
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
|
||||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||||
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
return (
|
||||||
|
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
|
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
|
||||||
# x: [B, T, H, W, C]
|
# x: [B, T, H, W, C]
|
||||||
@@ -444,14 +552,17 @@ class Resample(nn.Module):
|
|||||||
class Up_ResidualBlock(nn.Module):
|
class Up_ResidualBlock(nn.Module):
|
||||||
"""Upsampling residual block with optional DupUp3D shortcut."""
|
"""Upsampling residual block with optional DupUp3D shortcut."""
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
|
def __init__(
|
||||||
|
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up_flag = up_flag
|
self.up_flag = up_flag
|
||||||
|
|
||||||
# DupUp3D shortcut (no learnable params)
|
# DupUp3D shortcut (no learnable params)
|
||||||
if up_flag:
|
if up_flag:
|
||||||
self.avg_shortcut = DupUp3D(
|
self.avg_shortcut = DupUp3D(
|
||||||
in_dim, out_dim,
|
in_dim,
|
||||||
|
out_dim,
|
||||||
factor_t=2 if temperal_upsample else 1,
|
factor_t=2 if temperal_upsample else 1,
|
||||||
factor_s=2 if up_flag else 1,
|
factor_s=2 if up_flag else 1,
|
||||||
)
|
)
|
||||||
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
|
|||||||
class Down_ResidualBlock(nn.Module):
|
class Down_ResidualBlock(nn.Module):
|
||||||
"""Downsampling residual block with AvgDown3D shortcut."""
|
"""Downsampling residual block with AvgDown3D shortcut."""
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
num_res_blocks,
|
||||||
|
temperal_downsample=False,
|
||||||
|
down_flag=False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.down_flag = down_flag
|
self.down_flag = down_flag
|
||||||
|
|
||||||
# AvgDown3D shortcut (no learnable params, always present)
|
# AvgDown3D shortcut (no learnable params, always present)
|
||||||
self.avg_shortcut = AvgDown3D(
|
self.avg_shortcut = AvgDown3D(
|
||||||
in_dim, out_dim,
|
in_dim,
|
||||||
|
out_dim,
|
||||||
factor_t=2 if temperal_downsample else 1,
|
factor_t=2 if temperal_downsample else 1,
|
||||||
factor_s=2 if down_flag else 1,
|
factor_s=2 if down_flag else 1,
|
||||||
)
|
)
|
||||||
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
|
|||||||
self.upsamples = []
|
self.upsamples = []
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
|
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
|
||||||
self.upsamples.append(Up_ResidualBlock(
|
self.upsamples.append(
|
||||||
in_dim=in_dim,
|
Up_ResidualBlock(
|
||||||
out_dim=out_dim,
|
in_dim=in_dim,
|
||||||
num_res_blocks=num_res_blocks + 1,
|
out_dim=out_dim,
|
||||||
temperal_upsample=t_up,
|
num_res_blocks=num_res_blocks + 1,
|
||||||
up_flag=(i != len(dim_mult) - 1),
|
temperal_upsample=t_up,
|
||||||
))
|
up_flag=(i != len(dim_mult) - 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Output head: [RMS_norm, SiLU, CausalConv3d]
|
# Output head: [RMS_norm, SiLU, CausalConv3d]
|
||||||
self.head = Head22(dims[-1])
|
self.head = Head22(dims[-1])
|
||||||
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
|
|||||||
for i in range(len(dim_mult)):
|
for i in range(len(dim_mult)):
|
||||||
in_d, out_d = dims[i], dims[i + 1]
|
in_d, out_d = dims[i], dims[i + 1]
|
||||||
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
|
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||||
self.downsamples.append(Down_ResidualBlock(
|
self.downsamples.append(
|
||||||
in_dim=in_d,
|
Down_ResidualBlock(
|
||||||
out_dim=out_d,
|
in_dim=in_d,
|
||||||
num_res_blocks=num_res_blocks,
|
out_dim=out_d,
|
||||||
temperal_downsample=t_down,
|
num_res_blocks=num_res_blocks,
|
||||||
down_flag=(i < len(dim_mult) - 1),
|
temperal_downsample=t_down,
|
||||||
))
|
down_flag=(i < len(dim_mult) - 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Middle blocks (same as decoder)
|
# Middle blocks (same as decoder)
|
||||||
out_dim = dims[-1]
|
out_dim = dims[-1]
|
||||||
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
|
|||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, -CACHE_T:]
|
cache_x = x[:, -CACHE_T:]
|
||||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||||
cache_x = mx.concatenate(
|
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
|
||||||
)
|
|
||||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@@ -700,9 +821,7 @@ class Head22(nn.Module):
|
|||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, -CACHE_T:]
|
cache_x = x[:, -CACHE_T:]
|
||||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||||
cache_x = mx.concatenate(
|
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
|
||||||
[feat_cache[idx][:, -1:], cache_x], axis=1
|
|
||||||
)
|
|
||||||
x = self.layer_2(x, cache_x=feat_cache[idx])
|
x = self.layer_2(x, cache_x=feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
|
|||||||
if i == 0:
|
if i == 0:
|
||||||
chunk = x[:, :1]
|
chunk = x[:, :1]
|
||||||
else:
|
else:
|
||||||
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
|
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
|
||||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||||
if out is None:
|
if out is None:
|
||||||
out = chunk_out
|
out = chunk_out
|
||||||
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
|
|||||||
|
|
||||||
# conv1 (pointwise) + split into mu, log_var
|
# conv1 (pointwise) + split into mu, log_var
|
||||||
out = self.conv1(out)
|
out = self.conv1(out)
|
||||||
mu = out[:, :, :, :, :self.z_dim]
|
mu = out[:, :, :, :, : self.z_dim]
|
||||||
|
|
||||||
# Normalize
|
# Normalize
|
||||||
mu = normalize_latents(mu)
|
mu = normalize_latents(mu)
|
||||||
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
|
|||||||
decoder_fn=tile_decode,
|
decoder_fn=tile_decode,
|
||||||
latents=z_cf,
|
latents=z_cf,
|
||||||
tiling_config=tiling_config,
|
tiling_config=tiling_config,
|
||||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||||
causal_temporal=True,
|
causal_temporal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
import math
|
import math
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_repo: str):
|
def get_model_path(model_repo: str):
|
||||||
"""Get or download LTX-2 model path."""
|
"""Get or download LTX-2 model path."""
|
||||||
try:
|
try:
|
||||||
@@ -17,15 +18,19 @@ def get_model_path(model_repo: str):
|
|||||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Downloading LTX-2 model weights...")
|
print("Downloading LTX-2 model weights...")
|
||||||
return Path(snapshot_download(
|
return Path(
|
||||||
repo_id=model_repo,
|
snapshot_download(
|
||||||
local_files_only=False,
|
repo_id=model_repo,
|
||||||
resume_download=True,
|
local_files_only=False,
|
||||||
allow_patterns=["*.safetensors", "*.json"],
|
resume_download=True,
|
||||||
))
|
allow_patterns=["*.safetensors", "*.json"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||||
if quantization is not None:
|
if quantization is not None:
|
||||||
|
|
||||||
def get_class_predicate(p, m):
|
def get_class_predicate(p, m):
|
||||||
# Handle custom per layer quantizations
|
# Handle custom per layer quantizations
|
||||||
if p in quantization:
|
if p in quantization:
|
||||||
@@ -46,17 +51,15 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
|||||||
class_predicate=get_class_predicate,
|
class_predicate=get_class_predicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def to_denoised(
|
def to_denoised(
|
||||||
noisy: mx.array,
|
noisy: mx.array, velocity: mx.array, sigma: mx.array | float
|
||||||
velocity: mx.array,
|
|
||||||
sigma: mx.array | float
|
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Convert velocity prediction to denoised output.
|
"""Convert velocity prediction to denoised output.
|
||||||
|
|
||||||
@@ -284,7 +287,9 @@ def prepare_image_for_encoding(
|
|||||||
if image_np.max() <= 1.0:
|
if image_np.max() <= 1.0:
|
||||||
image_np = (image_np * 255).astype(np.uint8)
|
image_np = (image_np * 255).astype(np.uint8)
|
||||||
pil_image = Image.fromarray(image_np)
|
pil_image = Image.fromarray(image_np)
|
||||||
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
pil_image = pil_image.resize(
|
||||||
|
(target_width, target_height), Image.Resampling.LANCZOS
|
||||||
|
)
|
||||||
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
||||||
|
|
||||||
# Normalize to [-1, 1]
|
# Normalize to [-1, 1]
|
||||||
|
|||||||
@@ -170,19 +170,33 @@ def print_report(results, ref_path, test_path):
|
|||||||
|
|
||||||
print("AGGREGATE METRICS")
|
print("AGGREGATE METRICS")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}")
|
print(
|
||||||
print(f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}")
|
f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}"
|
||||||
print(f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}")
|
)
|
||||||
print(f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}")
|
print(
|
||||||
print(f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}")
|
f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)")
|
print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(f" Reference: {results['ref_temporal_coherence']:.2f}")
|
print(f" Reference: {results['ref_temporal_coherence']:.2f}")
|
||||||
print(f" Test: {results['test_temporal_coherence']:.2f}")
|
print(f" Test: {results['test_temporal_coherence']:.2f}")
|
||||||
ratio = results["test_temporal_coherence"] / (results["ref_temporal_coherence"] + 1e-10)
|
ratio = results["test_temporal_coherence"] / (
|
||||||
print(f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}")
|
results["ref_temporal_coherence"] + 1e-10
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Identify worst frames
|
# Identify worst frames
|
||||||
@@ -190,7 +204,9 @@ def print_report(results, ref_path, test_path):
|
|||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
worst_idx = np.argsort(psnr)[:5]
|
worst_idx = np.argsort(psnr)[:5]
|
||||||
for i in worst_idx:
|
for i in worst_idx:
|
||||||
print(f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}")
|
print(
|
||||||
|
f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Quality assessment
|
# Quality assessment
|
||||||
@@ -210,7 +226,9 @@ def print_report(results, ref_path, test_path):
|
|||||||
grade = "Very different"
|
grade = "Very different"
|
||||||
print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})")
|
print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})")
|
||||||
if mean_psnr < 30:
|
if mean_psnr < 30:
|
||||||
print(" ⚠ Videos differ significantly — likely a bug or different generation seed")
|
print(
|
||||||
|
" ⚠ Videos differ significantly — likely a bug or different generation seed"
|
||||||
|
)
|
||||||
print("=" * 72)
|
print("=" * 72)
|
||||||
|
|
||||||
|
|
||||||
@@ -242,9 +260,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--diff-video", help="Save side-by-side diff visualization to this path"
|
"--diff-video", help="Save side-by-side diff visualization to this path"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--max-frames", type=int, help="Compare only first N frames")
|
||||||
"--max-frames", type=int, help="Compare only first N frames"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ssim-win", type=int, default=7, help="SSIM window size (default: 7)"
|
"--ssim-win", type=int, default=7, help="SSIM window size (default: 7)"
|
||||||
)
|
)
|
||||||
@@ -254,26 +270,29 @@ def main():
|
|||||||
default=5.0,
|
default=5.0,
|
||||||
help="Diff heatmap amplification (default: 5.0)",
|
help="Diff heatmap amplification (default: 5.0)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--csv", help="Export per-frame metrics to CSV file")
|
||||||
"--csv", help="Export per-frame metrics to CSV file"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(f"Loading reference: {args.reference}")
|
print(f"Loading reference: {args.reference}")
|
||||||
ref_frames, ref_fps = load_video(args.reference, args.max_frames)
|
ref_frames, ref_fps = load_video(args.reference, args.max_frames)
|
||||||
print(f" → {len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}")
|
print(
|
||||||
|
f" → {len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Loading test: {args.test}")
|
print(f"Loading test: {args.test}")
|
||||||
test_frames, test_fps = load_video(args.test, args.max_frames)
|
test_frames, test_fps = load_video(args.test, args.max_frames)
|
||||||
print(f" → {len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}")
|
print(
|
||||||
|
f" → {len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
if ref_frames[0].shape != test_frames[0].shape:
|
if ref_frames[0].shape != test_frames[0].shape:
|
||||||
print(f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}")
|
print(
|
||||||
|
f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}"
|
||||||
|
)
|
||||||
print("Resizing test frames to match reference...")
|
print("Resizing test frames to match reference...")
|
||||||
h, w = ref_frames[0].shape[:2]
|
h, w = ref_frames[0].shape[:2]
|
||||||
test_frames = [
|
test_frames = [
|
||||||
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4)
|
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) for f in test_frames
|
||||||
for f in test_frames
|
|
||||||
]
|
]
|
||||||
|
|
||||||
print("Computing metrics...")
|
print("Computing metrics...")
|
||||||
@@ -282,23 +301,29 @@ def main():
|
|||||||
print_report(results, args.reference, args.test)
|
print_report(results, args.reference, args.test)
|
||||||
|
|
||||||
if args.diff_video:
|
if args.diff_video:
|
||||||
save_diff_video(ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale)
|
save_diff_video(
|
||||||
|
ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale
|
||||||
|
)
|
||||||
|
|
||||||
if args.csv:
|
if args.csv:
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
with open(args.csv, "w", newline="") as f:
|
with open(args.csv, "w", newline="") as f:
|
||||||
writer = csv.writer(f)
|
writer = csv.writer(f)
|
||||||
writer.writerow(["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"])
|
writer.writerow(
|
||||||
|
["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]
|
||||||
|
)
|
||||||
for i in range(results["num_frames"]):
|
for i in range(results["num_frames"]):
|
||||||
writer.writerow([
|
writer.writerow(
|
||||||
i,
|
[
|
||||||
f"{results['psnr'][i]:.4f}",
|
i,
|
||||||
f"{results['ssim'][i]:.6f}",
|
f"{results['psnr'][i]:.4f}",
|
||||||
f"{results['mean_diff'][i]:.4f}",
|
f"{results['ssim'][i]:.6f}",
|
||||||
f"{results['max_diff'][i]:.1f}",
|
f"{results['mean_diff'][i]:.4f}",
|
||||||
f"{results['color_dist'][i]:.6f}",
|
f"{results['max_diff'][i]:.1f}",
|
||||||
])
|
f"{results['color_dist'][i]:.6f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
print(f"Per-frame metrics saved to {args.csv}")
|
print(f"Per-frame metrics saved to {args.csv}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -158,10 +158,14 @@ def analyze_video(frames, chunk_size=None, compute_flow=False):
|
|||||||
boundary_metrics = []
|
boundary_metrics = []
|
||||||
for b in boundaries:
|
for b in boundaries:
|
||||||
if b < n and b > 0:
|
if b < n and b > 0:
|
||||||
pre = metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
|
pre = (
|
||||||
|
metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
|
||||||
|
)
|
||||||
at = metrics["frame_diff"][b]
|
at = metrics["frame_diff"][b]
|
||||||
ratio = at / (pre + 1e-10)
|
ratio = at / (pre + 1e-10)
|
||||||
brightness_jump = metrics["brightness"][b] - metrics["brightness"][b - 1]
|
brightness_jump = (
|
||||||
|
metrics["brightness"][b] - metrics["brightness"][b - 1]
|
||||||
|
)
|
||||||
contrast_jump = (
|
contrast_jump = (
|
||||||
(metrics["contrast"][b] - metrics["contrast"][b - 1])
|
(metrics["contrast"][b] - metrics["contrast"][b - 1])
|
||||||
/ (metrics["contrast"][b - 1] + 1e-10)
|
/ (metrics["contrast"][b - 1] + 1e-10)
|
||||||
@@ -198,7 +202,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
|
|||||||
print("VIDEO QUALITY REPORT")
|
print("VIDEO QUALITY REPORT")
|
||||||
print("=" * 72)
|
print("=" * 72)
|
||||||
print(f" File: {path}")
|
print(f" File: {path}")
|
||||||
print(f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}")
|
print(
|
||||||
|
f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}"
|
||||||
|
)
|
||||||
duration = total_frames / fps if fps > 0 else 0
|
duration = total_frames / fps if fps > 0 else 0
|
||||||
print(f" Duration: {duration:.1f}s")
|
print(f" Duration: {duration:.1f}s")
|
||||||
print()
|
print()
|
||||||
@@ -211,52 +217,76 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
|
|||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
if n_uniform:
|
if n_uniform:
|
||||||
frames_list = np.where(metrics["is_uniform"])[0][:10]
|
frames_list = np.where(metrics["is_uniform"])[0][:10]
|
||||||
print(f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}")
|
print(
|
||||||
|
f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}"
|
||||||
|
)
|
||||||
if n_noisy:
|
if n_noisy:
|
||||||
frames_list = np.where(metrics["is_noisy"])[0][:10]
|
frames_list = np.where(metrics["is_noisy"])[0][:10]
|
||||||
print(f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}")
|
print(
|
||||||
|
f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("SHARPNESS")
|
print("SHARPNESS")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}")
|
print(
|
||||||
print(f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}")
|
f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}"
|
||||||
|
)
|
||||||
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
|
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
|
||||||
print(" ⚠ High sharpness variation — possible blur artifacts")
|
print(" ⚠ High sharpness variation — possible blur artifacts")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("BRIGHTNESS & CONTRAST")
|
print("BRIGHTNESS & CONTRAST")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}")
|
print(
|
||||||
print(f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}")
|
f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}"
|
||||||
|
)
|
||||||
if np.std(br) > 3.0:
|
if np.std(br) > 3.0:
|
||||||
print(" ⚠ Brightness instability — may indicate chunk boundary artifacts")
|
print(" ⚠ Brightness instability — may indicate chunk boundary artifacts")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("COLOR DISTRIBUTION (BGR)")
|
print("COLOR DISTRIBUTION (BGR)")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}")
|
print(
|
||||||
print(f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}")
|
f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}"
|
||||||
print(f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}")
|
)
|
||||||
|
print(
|
||||||
|
f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("TEMPORAL STABILITY")
|
print("TEMPORAL STABILITY")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
fd_nz = fd[1:] # skip first frame (always 0)
|
fd_nz = fd[1:] # skip first frame (always 0)
|
||||||
if len(fd_nz) > 0:
|
if len(fd_nz) > 0:
|
||||||
print(f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}")
|
print(
|
||||||
|
f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}"
|
||||||
|
)
|
||||||
if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5:
|
if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5:
|
||||||
print(" ⚠ High diff variance — jitter or discontinuities")
|
print(" ⚠ High diff variance — jitter or discontinuities")
|
||||||
if "flow_mean" in metrics:
|
if "flow_mean" in metrics:
|
||||||
fm = metrics["flow_mean"][1:]
|
fm = metrics["flow_mean"][1:]
|
||||||
print(f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}")
|
print(
|
||||||
|
f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Chunk boundaries
|
# Chunk boundaries
|
||||||
if "boundaries" in metrics and metrics["boundaries"]:
|
if "boundaries" in metrics and metrics["boundaries"]:
|
||||||
print("CHUNK BOUNDARIES")
|
print("CHUNK BOUNDARIES")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}")
|
print(
|
||||||
|
f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}"
|
||||||
|
)
|
||||||
for bm in metrics["boundaries"]:
|
for bm in metrics["boundaries"]:
|
||||||
print(
|
print(
|
||||||
f" {bm['frame']:6d}"
|
f" {bm['frame']:6d}"
|
||||||
@@ -267,7 +297,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
|
|||||||
)
|
)
|
||||||
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
|
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
|
||||||
if avg_ratio > 2.0:
|
if avg_ratio > 2.0:
|
||||||
print(f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions")
|
print(
|
||||||
|
f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Overall grade
|
# Overall grade
|
||||||
@@ -303,9 +335,7 @@ def main():
|
|||||||
type=int,
|
type=int,
|
||||||
help="Frames per chunk for boundary analysis (e.g., 32)",
|
help="Frames per chunk for boundary analysis (e.g., 32)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--start", type=int, default=0, help="Start frame (default: 0)")
|
||||||
"--start", type=int, default=0, help="Start frame (default: 0)"
|
|
||||||
)
|
|
||||||
parser.add_argument("--end", type=int, help="End frame (default: all)")
|
parser.add_argument("--end", type=int, help="End frame (default: all)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--flow",
|
"--flow",
|
||||||
@@ -329,8 +359,14 @@ def main():
|
|||||||
import csv
|
import csv
|
||||||
|
|
||||||
keys = [
|
keys = [
|
||||||
"sharpness_lap", "sharpness_grad", "brightness", "contrast",
|
"sharpness_lap",
|
||||||
"color_mean_b", "color_mean_g", "color_mean_r", "frame_diff",
|
"sharpness_grad",
|
||||||
|
"brightness",
|
||||||
|
"contrast",
|
||||||
|
"color_mean_b",
|
||||||
|
"color_mean_g",
|
||||||
|
"color_mean_r",
|
||||||
|
"frame_diff",
|
||||||
]
|
]
|
||||||
if args.flow:
|
if args.flow:
|
||||||
keys += ["flow_mean", "flow_max"]
|
keys += ["flow_mean", "flow_max"]
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
"""Tests for LTX-2 dev model generation pipeline."""
|
"""Tests for LTX-2 dev model generation pipeline."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mlx_video.generate_dev import (
|
from mlx_video.generate_dev import (
|
||||||
ltx2_scheduler,
|
|
||||||
create_position_grid,
|
|
||||||
create_audio_position_grid,
|
|
||||||
compute_audio_frames,
|
|
||||||
cfg_delta,
|
|
||||||
DEFAULT_NEGATIVE_PROMPT,
|
|
||||||
AUDIO_SAMPLE_RATE,
|
|
||||||
AUDIO_LATENTS_PER_SECOND,
|
AUDIO_LATENTS_PER_SECOND,
|
||||||
|
AUDIO_SAMPLE_RATE,
|
||||||
|
DEFAULT_NEGATIVE_PROMPT,
|
||||||
|
cfg_delta,
|
||||||
|
compute_audio_frames,
|
||||||
|
create_audio_position_grid,
|
||||||
|
create_position_grid,
|
||||||
|
ltx2_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
|
|||||||
"""Scheduler should return steps+1 sigma values."""
|
"""Scheduler should return steps+1 sigma values."""
|
||||||
steps = 20
|
steps = 20
|
||||||
sigmas = ltx2_scheduler(steps=steps)
|
sigmas = ltx2_scheduler(steps=steps)
|
||||||
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"
|
assert sigmas.shape == (
|
||||||
|
steps + 1,
|
||||||
|
), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||||
|
|
||||||
def test_scheduler_starts_at_one(self):
|
def test_scheduler_starts_at_one(self):
|
||||||
"""Sigma schedule should start at 1.0."""
|
"""Sigma schedule should start at 1.0."""
|
||||||
sigmas = ltx2_scheduler(steps=20)
|
sigmas = ltx2_scheduler(steps=20)
|
||||||
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"
|
assert (
|
||||||
|
abs(sigmas[0].item() - 1.0) < 1e-5
|
||||||
|
), f"Expected 1.0, got {sigmas[0].item()}"
|
||||||
|
|
||||||
def test_scheduler_ends_at_zero(self):
|
def test_scheduler_ends_at_zero(self):
|
||||||
"""Sigma schedule should end at 0.0."""
|
"""Sigma schedule should end at 0.0."""
|
||||||
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
|
|||||||
sigmas = ltx2_scheduler(steps=20)
|
sigmas = ltx2_scheduler(steps=20)
|
||||||
sigmas_list = sigmas.tolist()
|
sigmas_list = sigmas.tolist()
|
||||||
for i in range(len(sigmas_list) - 1):
|
for i in range(len(sigmas_list) - 1):
|
||||||
assert sigmas_list[i] >= sigmas_list[i + 1], \
|
assert (
|
||||||
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
sigmas_list[i] >= sigmas_list[i + 1]
|
||||||
|
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||||
|
|
||||||
def test_scheduler_dtype(self):
|
def test_scheduler_dtype(self):
|
||||||
"""Scheduler should return float32 array."""
|
"""Scheduler should return float32 array."""
|
||||||
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
|
|||||||
num_patches = num_frames * height * width
|
num_patches = num_frames * height * width
|
||||||
|
|
||||||
expected_shape = (batch_size, 3, num_patches, 2)
|
expected_shape = (batch_size, 3, num_patches, 2)
|
||||||
assert positions.shape == expected_shape, \
|
assert (
|
||||||
f"Expected {expected_shape}, got {positions.shape}"
|
positions.shape == expected_shape
|
||||||
|
), f"Expected {expected_shape}, got {positions.shape}"
|
||||||
|
|
||||||
def test_position_grid_dtype(self):
|
def test_position_grid_dtype(self):
|
||||||
"""Position grid should be float32 for RoPE precision."""
|
"""Position grid should be float32 for RoPE precision."""
|
||||||
positions = create_position_grid(1, 5, 16, 24)
|
positions = create_position_grid(1, 5, 16, 24)
|
||||||
assert positions.dtype == mx.float32, \
|
assert (
|
||||||
f"Expected float32 for RoPE precision, got {positions.dtype}"
|
positions.dtype == mx.float32
|
||||||
|
), f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||||
|
|
||||||
def test_position_grid_batch_size(self):
|
def test_position_grid_batch_size(self):
|
||||||
"""Position grid should respect batch size."""
|
"""Position grid should respect batch size."""
|
||||||
@@ -165,7 +172,9 @@ class TestCFGDelta:
|
|||||||
mx.eval(delta)
|
mx.eval(delta)
|
||||||
|
|
||||||
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
|
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
|
||||||
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero"
|
assert (
|
||||||
|
mx.max(mx.abs(delta)).item() < 1e-6
|
||||||
|
), "CFG delta with scale=1.0 should be zero"
|
||||||
|
|
||||||
def test_cfg_delta_formula(self):
|
def test_cfg_delta_formula(self):
|
||||||
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
|
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
|
||||||
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
|
|||||||
|
|
||||||
# Check for common negative quality terms
|
# Check for common negative quality terms
|
||||||
assert "blurry" in prompt_lower, "Should contain 'blurry'"
|
assert "blurry" in prompt_lower, "Should contain 'blurry'"
|
||||||
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \
|
assert (
|
||||||
"Should contain quality-related terms"
|
"low quality" in prompt_lower or "low contrast" in prompt_lower
|
||||||
|
), "Should contain quality-related terms"
|
||||||
|
|
||||||
|
|
||||||
class TestInputValidation:
|
class TestInputValidation:
|
||||||
@@ -248,15 +258,16 @@ class TestInputValidation:
|
|||||||
(30, 33), # 30 -> nearest valid is 33
|
(30, 33), # 30 -> nearest valid is 33
|
||||||
(35, 33), # 35 -> nearest valid is 33
|
(35, 33), # 35 -> nearest valid is 33
|
||||||
(40, 41), # 40 -> nearest valid is 41
|
(40, 41), # 40 -> nearest valid is 41
|
||||||
(1, 1), # 1 is already valid
|
(1, 1), # 1 is already valid
|
||||||
(33, 33), # 33 is already valid
|
(33, 33), # 33 is already valid
|
||||||
]
|
]
|
||||||
|
|
||||||
for input_frames, expected in test_cases:
|
for input_frames, expected in test_cases:
|
||||||
if input_frames % 8 != 1:
|
if input_frames % 8 != 1:
|
||||||
adjusted = round((input_frames - 1) / 8) * 8 + 1
|
adjusted = round((input_frames - 1) / 8) * 8 + 1
|
||||||
assert adjusted == expected, \
|
assert (
|
||||||
f"Expected {expected} for input {input_frames}, got {adjusted}"
|
adjusted == expected
|
||||||
|
), f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||||
|
|
||||||
|
|
||||||
class TestDenoiseWithCFGMocked:
|
class TestDenoiseWithCFGMocked:
|
||||||
@@ -277,14 +288,16 @@ class TestTilingDefault:
|
|||||||
def test_tiling_default_is_none(self):
|
def test_tiling_default_is_none(self):
|
||||||
"""Default tiling should be 'none' for performance."""
|
"""Default tiling should be 'none' for performance."""
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from mlx_video.generate_dev import generate_video_dev
|
from mlx_video.generate_dev import generate_video_dev
|
||||||
|
|
||||||
sig = inspect.signature(generate_video_dev)
|
sig = inspect.signature(generate_video_dev)
|
||||||
|
|
||||||
tiling_param = sig.parameters.get('tiling')
|
tiling_param = sig.parameters.get("tiling")
|
||||||
assert tiling_param is not None
|
assert tiling_param is not None
|
||||||
assert tiling_param.default == "none", \
|
assert (
|
||||||
f"Expected default tiling='none', got '{tiling_param.default}'"
|
tiling_param.default == "none"
|
||||||
|
), f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||||
|
|
||||||
|
|
||||||
class TestLatentDimensions:
|
class TestLatentDimensions:
|
||||||
@@ -296,8 +309,9 @@ class TestLatentDimensions:
|
|||||||
|
|
||||||
for height, expected_latent_h in test_cases:
|
for height, expected_latent_h in test_cases:
|
||||||
latent_h = height // 32
|
latent_h = height // 32
|
||||||
assert latent_h == expected_latent_h, \
|
assert (
|
||||||
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
latent_h == expected_latent_h
|
||||||
|
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||||
|
|
||||||
def test_latent_width_calculation(self):
|
def test_latent_width_calculation(self):
|
||||||
"""Latent width should be width // 32."""
|
"""Latent width should be width // 32."""
|
||||||
@@ -305,8 +319,9 @@ class TestLatentDimensions:
|
|||||||
|
|
||||||
for width, expected_latent_w in test_cases:
|
for width, expected_latent_w in test_cases:
|
||||||
latent_w = width // 32
|
latent_w = width // 32
|
||||||
assert latent_w == expected_latent_w, \
|
assert (
|
||||||
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
latent_w == expected_latent_w
|
||||||
|
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||||
|
|
||||||
def test_latent_frames_calculation(self):
|
def test_latent_frames_calculation(self):
|
||||||
"""Latent frames should be 1 + (num_frames - 1) // 8."""
|
"""Latent frames should be 1 + (num_frames - 1) // 8."""
|
||||||
@@ -314,8 +329,9 @@ class TestLatentDimensions:
|
|||||||
|
|
||||||
for num_frames, expected_latent_f in test_cases:
|
for num_frames, expected_latent_f in test_cases:
|
||||||
latent_f = 1 + (num_frames - 1) // 8
|
latent_f = 1 + (num_frames - 1) // 8
|
||||||
assert latent_f == expected_latent_f, \
|
assert (
|
||||||
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
latent_f == expected_latent_f
|
||||||
|
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||||
|
|
||||||
def test_num_tokens_calculation(self):
|
def test_num_tokens_calculation(self):
|
||||||
"""Number of tokens should be latent_f * latent_h * latent_w."""
|
"""Number of tokens should be latent_f * latent_h * latent_w."""
|
||||||
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
|
|||||||
positions = create_audio_position_grid(batch_size, audio_frames)
|
positions = create_audio_position_grid(batch_size, audio_frames)
|
||||||
expected_shape = (batch_size, 1, audio_frames, 2)
|
expected_shape = (batch_size, 1, audio_frames, 2)
|
||||||
|
|
||||||
assert positions.shape == expected_shape, \
|
assert (
|
||||||
f"Expected {expected_shape}, got {positions.shape}"
|
positions.shape == expected_shape
|
||||||
|
), f"Expected {expected_shape}, got {positions.shape}"
|
||||||
|
|
||||||
def test_audio_position_grid_dtype(self):
|
def test_audio_position_grid_dtype(self):
|
||||||
"""Audio position grid should be float32."""
|
"""Audio position grid should be float32."""
|
||||||
positions = create_audio_position_grid(1, 34)
|
positions = create_audio_position_grid(1, 34)
|
||||||
assert positions.dtype == mx.float32, \
|
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
|
||||||
f"Expected float32, got {positions.dtype}"
|
|
||||||
|
|
||||||
def test_audio_position_grid_batch_size(self):
|
def test_audio_position_grid_batch_size(self):
|
||||||
"""Audio position grid should respect batch size."""
|
"""Audio position grid should respect batch size."""
|
||||||
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
|
|||||||
"""Audio position grid should not contain NaN or Inf."""
|
"""Audio position grid should not contain NaN or Inf."""
|
||||||
positions = create_audio_position_grid(1, 34)
|
positions = create_audio_position_grid(1, 34)
|
||||||
|
|
||||||
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
|
assert not mx.any(
|
||||||
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
|
mx.isnan(positions)
|
||||||
|
).item(), "Audio position grid contains NaN"
|
||||||
|
assert not mx.any(
|
||||||
|
mx.isinf(positions)
|
||||||
|
).item(), "Audio position grid contains Inf"
|
||||||
|
|
||||||
|
|
||||||
class TestComputeAudioFrames:
|
class TestComputeAudioFrames:
|
||||||
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
|
|||||||
audio_33 = compute_audio_frames(33, 24.0)
|
audio_33 = compute_audio_frames(33, 24.0)
|
||||||
audio_65 = compute_audio_frames(65, 24.0)
|
audio_65 = compute_audio_frames(65, 24.0)
|
||||||
|
|
||||||
assert audio_65 > audio_33, \
|
assert (
|
||||||
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
audio_65 > audio_33
|
||||||
|
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||||
|
|
||||||
def test_audio_frames_formula(self):
|
def test_audio_frames_formula(self):
|
||||||
"""Audio frames should match expected formula."""
|
"""Audio frames should match expected formula."""
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
import pytest
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.rope import (
|
|
||||||
precompute_freqs_cis,
|
|
||||||
)
|
|
||||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
||||||
|
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||||
|
|
||||||
|
|
||||||
def create_video_position_grid(
|
def create_video_position_grid(
|
||||||
@@ -20,7 +18,7 @@ def create_video_position_grid(
|
|||||||
h_coords = np.arange(0, height)
|
h_coords = np.arange(0, height)
|
||||||
w_coords = np.arange(0, width)
|
w_coords = np.arange(0, width)
|
||||||
|
|
||||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
|
||||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||||
patch_ends = patch_starts + 1
|
patch_ends = patch_starts + 1
|
||||||
|
|
||||||
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
|||||||
scaled = fractional * 2 - 1 # [-1, 1]
|
scaled = fractional * 2 - 1 # [-1, 1]
|
||||||
|
|
||||||
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
|
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
|
||||||
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
freqs = (
|
||||||
|
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||||
|
)
|
||||||
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
|
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
|
||||||
freqs = np.swapaxes(freqs, -1, -2)
|
freqs = np.swapaxes(freqs, -1, -2)
|
||||||
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
|
freqs = freqs.reshape(
|
||||||
|
freqs.shape[0], freqs.shape[1], -1
|
||||||
|
) # (B, T, num_indices * n_dims)
|
||||||
|
|
||||||
cos_ref = np.cos(freqs)
|
cos_ref = np.cos(freqs)
|
||||||
sin_ref = np.sin(freqs)
|
sin_ref = np.sin(freqs)
|
||||||
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
|||||||
pad_size = expected - cos_ref.shape[-1]
|
pad_size = expected - cos_ref.shape[-1]
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
|
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
|
||||||
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
cos_ref = np.concatenate(
|
||||||
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
|
||||||
|
)
|
||||||
|
sin_ref = np.concatenate(
|
||||||
|
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
|
||||||
|
)
|
||||||
|
|
||||||
B, T, _ = cos_ref.shape
|
B, T, _ = cos_ref.shape
|
||||||
dim_per_head = dim // num_heads
|
dim_per_head = dim // num_heads
|
||||||
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
|
|||||||
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||||
|
|
||||||
# Verify cos/sin are in valid range [-1, 1]
|
# Verify cos/sin are in valid range [-1, 1]
|
||||||
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
assert (
|
||||||
"cos_freq values out of [-1, 1] range"
|
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
|
||||||
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
), "cos_freq values out of [-1, 1] range"
|
||||||
"sin_freq values out of [-1, 1] range"
|
assert (
|
||||||
|
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
|
||||||
|
), "sin_freq values out of [-1, 1] range"
|
||||||
|
|
||||||
def test_bfloat16_positions_cause_precision_loss(self):
|
def test_bfloat16_positions_cause_precision_loss(self):
|
||||||
"""bfloat16 positions should produce different (less precise) results than float32.
|
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||||
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
|
|||||||
# The threshold here is intentionally low to catch the issue
|
# The threshold here is intentionally low to catch the issue
|
||||||
precision_threshold = 1e-6
|
precision_threshold = 1e-6
|
||||||
|
|
||||||
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
has_precision_loss = (
|
||||||
|
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||||
|
)
|
||||||
|
|
||||||
# Document the precision loss (this is expected behavior)
|
# Document the precision loss (this is expected behavior)
|
||||||
if has_precision_loss:
|
if has_precision_loss:
|
||||||
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
|
|||||||
print(f" Max sin difference: {max_sin_diff:.6e}")
|
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||||
|
|
||||||
# This assertion documents the issue - bfloat16 positions cause precision loss
|
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||||
assert has_precision_loss, \
|
assert (
|
||||||
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
has_precision_loss
|
||||||
|
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||||
|
|
||||||
def test_double_precision_converts_to_float32_internally(self):
|
def test_double_precision_converts_to_float32_internally(self):
|
||||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||||
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
|
|||||||
# Recommended: create positions in float32
|
# Recommended: create positions in float32
|
||||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||||
|
|
||||||
assert positions.dtype == mx.float32, \
|
assert (
|
||||||
"Position grids should be created in float32 for RoPE precision"
|
positions.dtype == mx.float32
|
||||||
|
), "Position grids should be created in float32 for RoPE precision"
|
||||||
|
|
||||||
# Verify the position values are reasonable
|
# Verify the position values are reasonable
|
||||||
# Temporal positions should be small (seconds)
|
# Temporal positions should be small (seconds)
|
||||||
temporal_positions = positions[:, 0, :, :]
|
temporal_positions = positions[:, 0, :, :]
|
||||||
assert mx.max(temporal_positions).item() < 100, \
|
assert (
|
||||||
"Temporal positions should be in seconds (small values)"
|
mx.max(temporal_positions).item() < 100
|
||||||
|
), "Temporal positions should be in seconds (small values)"
|
||||||
|
|
||||||
# Spatial positions should be larger (pixels)
|
# Spatial positions should be larger (pixels)
|
||||||
spatial_h = positions[:, 1, :, :]
|
spatial_h = positions[:, 1, :, :]
|
||||||
spatial_w = positions[:, 2, :, :]
|
spatial_w = positions[:, 2, :, :]
|
||||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
assert (
|
||||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
mx.max(spatial_h).item() > 0
|
||||||
|
), "Spatial height positions should be positive"
|
||||||
|
assert (
|
||||||
|
mx.max(spatial_w).item() > 0
|
||||||
|
), "Spatial width positions should be positive"
|
||||||
|
|
||||||
def test_float32_positions_match_numpy_float64_reference(self):
|
def test_float32_positions_match_numpy_float64_reference(self):
|
||||||
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
|
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
|
||||||
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# NumPy float64 reference
|
# NumPy float64 reference
|
||||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
cos_ref, sin_ref = _numpy_reference_rope(
|
||||||
|
positions_np, dim, theta, max_pos, num_heads
|
||||||
|
)
|
||||||
|
|
||||||
cos_mlx_np = np.array(cos_mlx)
|
cos_mlx_np = np.array(cos_mlx)
|
||||||
sin_mlx_np = np.array(sin_mlx)
|
sin_mlx_np = np.array(sin_mlx)
|
||||||
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
|
|||||||
# Cosine similarity (flatten for single scalar)
|
# Cosine similarity (flatten for single scalar)
|
||||||
cos_flat = cos_mlx_np.flatten()
|
cos_flat = cos_mlx_np.flatten()
|
||||||
ref_flat = cos_ref.flatten()
|
ref_flat = cos_ref.flatten()
|
||||||
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
|
cosine_sim = np.dot(cos_flat, ref_flat) / (
|
||||||
|
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
|
||||||
|
)
|
||||||
|
|
||||||
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
|
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
|
||||||
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
|
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
|
||||||
assert max_cos_diff < 0.01, \
|
assert (
|
||||||
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
max_cos_diff < 0.01
|
||||||
assert max_sin_diff < 0.01, \
|
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||||
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
assert (
|
||||||
assert cosine_sim > 0.9999, \
|
max_sin_diff < 0.01
|
||||||
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||||
|
assert (
|
||||||
|
cosine_sim > 0.9999
|
||||||
|
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||||
|
|
||||||
def test_high_frequency_amplification_regression(self):
|
def test_high_frequency_amplification_regression(self):
|
||||||
"""Regression test for the specific failure mode: high-frequency index amplification.
|
"""Regression test for the specific failure mode: high-frequency index amplification.
|
||||||
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
|
|||||||
double_precision=False,
|
double_precision=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
cos_ref, sin_ref = _numpy_reference_rope(
|
||||||
|
positions_np, dim, theta, max_pos, num_heads
|
||||||
|
)
|
||||||
|
|
||||||
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
|
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
|
||||||
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
|
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
|
||||||
|
|
||||||
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
|
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
|
||||||
assert max_cos_diff < 0.01, \
|
assert (
|
||||||
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
max_cos_diff < 0.01
|
||||||
assert max_sin_diff < 0.01, \
|
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||||
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
assert (
|
||||||
|
max_sin_diff < 0.01
|
||||||
|
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||||
|
|
||||||
|
|
||||||
class TestRoPEInterleaved:
|
class TestRoPEInterleaved:
|
||||||
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
|
|||||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
dim=128,
|
||||||
use_middle_indices_grid=True, num_attention_heads=32,
|
theta=10000.0,
|
||||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||||
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
|
|||||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
dim=128,
|
||||||
use_middle_indices_grid=True, num_attention_heads=32,
|
theta=10000.0,
|
||||||
rope_type=LTXRopeType.SPLIT, double_precision=True,
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||||
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
|
|||||||
|
|
||||||
cos_freq, sin_freq = precompute_freqs_cis(
|
cos_freq, sin_freq = precompute_freqs_cis(
|
||||||
indices_grid=positions_f16,
|
indices_grid=positions_f16,
|
||||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
dim=128,
|
||||||
use_middle_indices_grid=True, num_attention_heads=32,
|
theta=10000.0,
|
||||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert cos_freq.dtype == mx.float32
|
assert cos_freq.dtype == mx.float32
|
||||||
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
|
|||||||
def test_ltx2_forces_double_precision_rope_false(self):
|
def test_ltx2_forces_double_precision_rope_false(self):
|
||||||
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
|
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
|
||||||
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
|
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
|
||||||
assert config.double_precision_rope is False, \
|
assert (
|
||||||
"LTX-2 should force double_precision_rope=False regardless of input"
|
config.double_precision_rope is False
|
||||||
|
), "LTX-2 should force double_precision_rope=False regardless of input"
|
||||||
|
|
||||||
def test_ltx23_preserves_double_precision_rope_true(self):
|
def test_ltx23_preserves_double_precision_rope_true(self):
|
||||||
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
|
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
|
||||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
|
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
|
||||||
assert config.double_precision_rope is True, \
|
assert (
|
||||||
"LTX-2.3 should preserve double_precision_rope=True"
|
config.double_precision_rope is True
|
||||||
|
), "LTX-2.3 should preserve double_precision_rope=True"
|
||||||
|
|
||||||
def test_ltx23_preserves_double_precision_rope_false(self):
|
def test_ltx23_preserves_double_precision_rope_false(self):
|
||||||
"""LTX-2.3 with double_precision_rope=False should stay False."""
|
"""LTX-2.3 with double_precision_rope=False should stay False."""
|
||||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
|
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
|
||||||
assert config.double_precision_rope is False, \
|
assert (
|
||||||
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
config.double_precision_rope is False
|
||||||
|
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||||
|
|
||||||
def test_ltx2_default_double_precision_rope(self):
|
def test_ltx2_default_double_precision_rope(self):
|
||||||
"""LTX-2 default (double_precision_rope not set) should be False."""
|
"""LTX-2 default (double_precision_rope not set) should be False."""
|
||||||
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
|
|||||||
|
|
||||||
def test_config_from_dict_ltx2(self):
|
def test_config_from_dict_ltx2(self):
|
||||||
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
|
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
|
||||||
config = LTXModelConfig.from_dict({
|
config = LTXModelConfig.from_dict(
|
||||||
"has_prompt_adaln": False,
|
{
|
||||||
"double_precision_rope": True,
|
"has_prompt_adaln": False,
|
||||||
"rope_type": "split",
|
"double_precision_rope": True,
|
||||||
})
|
"rope_type": "split",
|
||||||
|
}
|
||||||
|
)
|
||||||
assert config.double_precision_rope is False
|
assert config.double_precision_rope is False
|
||||||
|
|
||||||
def test_config_from_dict_ltx23(self):
|
def test_config_from_dict_ltx23(self):
|
||||||
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
|
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
|
||||||
config = LTXModelConfig.from_dict({
|
config = LTXModelConfig.from_dict(
|
||||||
"has_prompt_adaln": True,
|
{
|
||||||
"double_precision_rope": True,
|
"has_prompt_adaln": True,
|
||||||
"rope_type": "split",
|
"double_precision_rope": True,
|
||||||
})
|
"rope_type": "split",
|
||||||
|
}
|
||||||
|
)
|
||||||
assert config.double_precision_rope is True
|
assert config.double_precision_rope is True
|
||||||
|
|
||||||
|
|
||||||
@@ -496,10 +543,12 @@ class TestRoPESplit:
|
|||||||
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||||
dim_per_head = dim // num_heads
|
dim_per_head = dim // num_heads
|
||||||
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||||
assert cos_freq.shape == expected_shape, \
|
assert (
|
||||||
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
cos_freq.shape == expected_shape
|
||||||
assert sin_freq.shape == expected_shape, \
|
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||||
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
assert (
|
||||||
|
sin_freq.shape == expected_shape
|
||||||
|
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Tests for VAE streaming and chunked conv features."""
|
"""Tests for VAE streaming and chunked conv features."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||||
@@ -50,7 +50,7 @@ class TestChunkedConv:
|
|||||||
np.array(out_chunked),
|
np.array(out_chunked),
|
||||||
rtol=1e-5,
|
rtol=1e-5,
|
||||||
atol=1e-5,
|
atol=1e-5,
|
||||||
err_msg="Chunked conv output differs from regular output"
|
err_msg="Chunked conv output differs from regular output",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_chunked_conv_small_input_passthrough(self):
|
def test_chunked_conv_small_input_passthrough(self):
|
||||||
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
|
|||||||
frames_received = []
|
frames_received = []
|
||||||
|
|
||||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||||
frames_received.append({
|
frames_received.append(
|
||||||
'shape': frames.shape,
|
{
|
||||||
'start_idx': start_idx,
|
"shape": frames.shape,
|
||||||
})
|
"start_idx": start_idx,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Create a mock decoder that just returns scaled input
|
# Create a mock decoder that just returns scaled input
|
||||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
def mock_decoder(
|
||||||
|
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||||
|
):
|
||||||
# Simulate VAE output: upsample 8x temporal, 32x spatial
|
# Simulate VAE output: upsample 8x temporal, 32x spatial
|
||||||
b, c, f, h, w = x.shape
|
b, c, f, h, w = x.shape
|
||||||
out_f = 1 + (f - 1) * 8
|
out_f = 1 + (f - 1) * 8
|
||||||
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
|
|||||||
|
|
||||||
# All received frames should have correct channel count
|
# All received frames should have correct channel count
|
||||||
for received in frames_received:
|
for received in frames_received:
|
||||||
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
|
assert (
|
||||||
|
received["shape"][1] == 3
|
||||||
|
), f"Expected 3 channels, got {received['shape'][1]}"
|
||||||
|
|
||||||
def test_on_frames_ready_covers_all_frames(self):
|
def test_on_frames_ready_covers_all_frames(self):
|
||||||
"""Verify all frames are emitted via callbacks."""
|
"""Verify all frames are emitted via callbacks."""
|
||||||
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
|
|||||||
for i in range(num_frames):
|
for i in range(num_frames):
|
||||||
all_frame_indices.add(start_idx + i)
|
all_frame_indices.add(start_idx + i)
|
||||||
|
|
||||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
def mock_decoder(
|
||||||
|
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||||
|
):
|
||||||
b, c, f, h, w = x.shape
|
b, c, f, h, w = x.shape
|
||||||
out_f = 1 + (f - 1) * 8
|
out_f = 1 + (f - 1) * 8
|
||||||
out_h = h * 32
|
out_h = h * 32
|
||||||
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
|
|||||||
expected_frames = 1 + (12 - 1) * 8 # 89 frames
|
expected_frames = 1 + (12 - 1) * 8 # 89 frames
|
||||||
|
|
||||||
# All frames should have been emitted
|
# All frames should have been emitted
|
||||||
assert len(all_frame_indices) == expected_frames, \
|
assert (
|
||||||
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
len(all_frame_indices) == expected_frames
|
||||||
assert all_frame_indices == set(range(expected_frames)), \
|
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||||
"Not all frame indices were covered"
|
assert all_frame_indices == set(
|
||||||
|
range(expected_frames)
|
||||||
|
), "Not all frame indices were covered"
|
||||||
|
|
||||||
|
|
||||||
class TestAutoChunkedConv:
|
class TestAutoChunkedConv:
|
||||||
"""Tests for auto-enabling chunked_conv based on tiling mode."""
|
"""Tests for auto-enabling chunked_conv based on tiling mode."""
|
||||||
|
|
||||||
@pytest.mark.parametrize("tiling_mode,should_enable", [
|
@pytest.mark.parametrize(
|
||||||
("conservative", True),
|
"tiling_mode,should_enable",
|
||||||
("none", True),
|
[
|
||||||
("auto", True),
|
("conservative", True),
|
||||||
("default", True),
|
("none", True),
|
||||||
("spatial", True),
|
("auto", True),
|
||||||
("aggressive", False),
|
("default", True),
|
||||||
("temporal", False),
|
("spatial", True),
|
||||||
])
|
("aggressive", False),
|
||||||
|
("temporal", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
|
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
|
||||||
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
|
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
|
||||||
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||||
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
|
|||||||
|
|
||||||
use_chunked_conv = tiling_mode in expected_modes
|
use_chunked_conv = tiling_mode in expected_modes
|
||||||
|
|
||||||
assert use_chunked_conv == should_enable, \
|
assert (
|
||||||
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
use_chunked_conv == should_enable
|
||||||
|
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||||
|
|
||||||
|
|
||||||
class TestTrapezoidalMask:
|
class TestTrapezoidalMask:
|
||||||
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
|
|||||||
|
|
||||||
# Right ramp should be decreasing
|
# Right ramp should be decreasing
|
||||||
right_ramp = mask_np[-8:]
|
right_ramp = mask_np[-8:]
|
||||||
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
|
assert np.all(
|
||||||
|
np.diff(right_ramp) <= 0
|
||||||
|
), "Right ramp not monotonically decreasing"
|
||||||
|
|
||||||
def test_temporal_mask_starts_from_zero(self):
|
def test_temporal_mask_starts_from_zero(self):
|
||||||
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
|
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
|
||||||
|
|||||||
@@ -2,24 +2,25 @@
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# RoPE Tests
|
# RoPE Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestRoPE:
|
class TestRoPE:
|
||||||
"""Tests for 3-way factorized RoPE."""
|
"""Tests for 3-way factorized RoPE."""
|
||||||
|
|
||||||
def test_rope_params_shape(self):
|
def test_rope_params_shape(self):
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
freqs = rope_params(1024, 64)
|
freqs = rope_params(1024, 64)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||||
|
|
||||||
def test_rope_params_different_dims(self):
|
def test_rope_params_different_dims(self):
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
for dim in [32, 64, 128]:
|
for dim in [32, 64, 128]:
|
||||||
freqs = rope_params(512, dim)
|
freqs = rope_params(512, dim)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
@@ -27,6 +28,7 @@ class TestRoPE:
|
|||||||
|
|
||||||
def test_rope_params_cos_sin_range(self):
|
def test_rope_params_cos_sin_range(self):
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
freqs = rope_params(256, 64)
|
freqs = rope_params(256, 64)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
cos_vals = np.array(freqs[:, :, 0])
|
cos_vals = np.array(freqs[:, :, 0])
|
||||||
@@ -37,13 +39,15 @@ class TestRoPE:
|
|||||||
def test_rope_params_position_zero(self):
|
def test_rope_params_position_zero(self):
|
||||||
"""At position 0, cos should be 1 and sin should be 0."""
|
"""At position 0, cos should be 1 and sin should be 0."""
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
freqs = rope_params(10, 64)
|
freqs = rope_params(10, 64)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
||||||
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||||
|
|
||||||
def test_rope_apply_output_shape(self):
|
def test_rope_apply_output_shape(self):
|
||||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||||
|
|
||||||
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||||
x = mx.random.normal((B, L, N, D))
|
x = mx.random.normal((B, L, N, D))
|
||||||
freqs = rope_params(1024, D)
|
freqs = rope_params(1024, D)
|
||||||
@@ -54,7 +58,8 @@ class TestRoPE:
|
|||||||
|
|
||||||
def test_rope_apply_preserves_norm(self):
|
def test_rope_apply_preserves_norm(self):
|
||||||
"""RoPE rotation should preserve vector norms."""
|
"""RoPE rotation should preserve vector norms."""
|
||||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||||
|
|
||||||
B, N, D = 1, 2, 16
|
B, N, D = 1, 2, 16
|
||||||
F, H, W = 2, 3, 4
|
F, H, W = 2, 3, 4
|
||||||
L = F * H * W
|
L = F * H * W
|
||||||
@@ -74,7 +79,8 @@ class TestRoPE:
|
|||||||
|
|
||||||
def test_rope_apply_with_padding(self):
|
def test_rope_apply_with_padding(self):
|
||||||
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||||
|
|
||||||
B, N, D = 1, 2, 16
|
B, N, D = 1, 2, 16
|
||||||
F, H, W = 2, 2, 2
|
F, H, W = 2, 2, 2
|
||||||
seq_len = F * H * W # 8
|
seq_len = F * H * W # 8
|
||||||
@@ -94,7 +100,8 @@ class TestRoPE:
|
|||||||
|
|
||||||
def test_rope_apply_batch(self):
|
def test_rope_apply_batch(self):
|
||||||
"""Test with batch_size > 1 and different grid sizes."""
|
"""Test with batch_size > 1 and different grid sizes."""
|
||||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||||
|
|
||||||
B, N, D = 2, 2, 16
|
B, N, D = 2, 2, 16
|
||||||
grids = [(2, 3, 4), (2, 3, 4)]
|
grids = [(2, 3, 4), (2, 3, 4)]
|
||||||
L = 2 * 3 * 4
|
L = 2 * 3 * 4
|
||||||
@@ -122,9 +129,11 @@ class TestRoPE:
|
|||||||
# Attention Tests
|
# Attention Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWanRMSNorm:
|
class TestWanRMSNorm:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.attention import WanRMSNorm
|
from mlx_video.models.wan.attention import WanRMSNorm
|
||||||
|
|
||||||
norm = WanRMSNorm(64)
|
norm = WanRMSNorm(64)
|
||||||
x = mx.random.normal((2, 10, 64))
|
x = mx.random.normal((2, 10, 64))
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -134,6 +143,7 @@ class TestWanRMSNorm:
|
|||||||
def test_zero_mean_variance(self):
|
def test_zero_mean_variance(self):
|
||||||
"""RMS norm should make RMS ≈ 1 before scaling."""
|
"""RMS norm should make RMS ≈ 1 before scaling."""
|
||||||
from mlx_video.models.wan.attention import WanRMSNorm
|
from mlx_video.models.wan.attention import WanRMSNorm
|
||||||
|
|
||||||
norm = WanRMSNorm(64)
|
norm = WanRMSNorm(64)
|
||||||
x = mx.random.normal((1, 5, 64)) * 10.0
|
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -147,6 +157,7 @@ class TestWanRMSNorm:
|
|||||||
def test_dtype_preservation(self):
|
def test_dtype_preservation(self):
|
||||||
"""RMSNorm weight is float32, so output is promoted to float32."""
|
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||||
from mlx_video.models.wan.attention import WanRMSNorm
|
from mlx_video.models.wan.attention import WanRMSNorm
|
||||||
|
|
||||||
norm = WanRMSNorm(32)
|
norm = WanRMSNorm(32)
|
||||||
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -158,6 +169,7 @@ class TestWanRMSNorm:
|
|||||||
class TestWanLayerNorm:
|
class TestWanLayerNorm:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.attention import WanLayerNorm
|
from mlx_video.models.wan.attention import WanLayerNorm
|
||||||
|
|
||||||
norm = WanLayerNorm(64)
|
norm = WanLayerNorm(64)
|
||||||
x = mx.random.normal((2, 10, 64))
|
x = mx.random.normal((2, 10, 64))
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -166,6 +178,7 @@ class TestWanLayerNorm:
|
|||||||
|
|
||||||
def test_without_affine(self):
|
def test_without_affine(self):
|
||||||
from mlx_video.models.wan.attention import WanLayerNorm
|
from mlx_video.models.wan.attention import WanLayerNorm
|
||||||
|
|
||||||
norm = WanLayerNorm(64, elementwise_affine=False)
|
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||||
x = mx.random.normal((1, 4, 64))
|
x = mx.random.normal((1, 4, 64))
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -178,6 +191,7 @@ class TestWanLayerNorm:
|
|||||||
|
|
||||||
def test_with_affine(self):
|
def test_with_affine(self):
|
||||||
from mlx_video.models.wan.attention import WanLayerNorm
|
from mlx_video.models.wan.attention import WanLayerNorm
|
||||||
|
|
||||||
norm = WanLayerNorm(32, elementwise_affine=True)
|
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||||
assert hasattr(norm, "weight")
|
assert hasattr(norm, "weight")
|
||||||
assert hasattr(norm, "bias")
|
assert hasattr(norm, "bias")
|
||||||
@@ -196,6 +210,7 @@ class TestWanSelfAttention:
|
|||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.attention import WanSelfAttention
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||||
B, L = 1, 24
|
B, L = 1, 24
|
||||||
F, H, W = 2, 3, 4
|
F, H, W = 2, 3, 4
|
||||||
@@ -207,12 +222,14 @@ class TestWanSelfAttention:
|
|||||||
|
|
||||||
def test_with_qk_norm(self):
|
def test_with_qk_norm(self):
|
||||||
from mlx_video.models.wan.attention import WanSelfAttention
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
|
||||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||||
assert attn.norm_q is not None
|
assert attn.norm_q is not None
|
||||||
assert attn.norm_k is not None
|
assert attn.norm_k is not None
|
||||||
|
|
||||||
def test_without_qk_norm(self):
|
def test_without_qk_norm(self):
|
||||||
from mlx_video.models.wan.attention import WanSelfAttention
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
|
||||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||||
assert attn.norm_q is None
|
assert attn.norm_q is None
|
||||||
assert attn.norm_k is None
|
assert attn.norm_k is None
|
||||||
@@ -221,6 +238,7 @@ class TestWanSelfAttention:
|
|||||||
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||||
from mlx_video.models.wan.attention import WanSelfAttention
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||||
B, L = 1, 24
|
B, L = 1, 24
|
||||||
F, H, W = 2, 3, 4
|
F, H, W = 2, 3, 4
|
||||||
@@ -245,6 +263,7 @@ class TestWanCrossAttention:
|
|||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.attention import WanCrossAttention
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
|
||||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
B, L_q, L_kv = 1, 24, 16
|
B, L_q, L_kv = 1, 24, 16
|
||||||
x = mx.random.normal((B, L_q, self.dim))
|
x = mx.random.normal((B, L_q, self.dim))
|
||||||
@@ -255,6 +274,7 @@ class TestWanCrossAttention:
|
|||||||
|
|
||||||
def test_with_context_mask(self):
|
def test_with_context_mask(self):
|
||||||
from mlx_video.models.wan.attention import WanCrossAttention
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
|
||||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
B, L_q, L_kv = 1, 12, 16
|
B, L_q, L_kv = 1, 12, 16
|
||||||
x = mx.random.normal((B, L_q, self.dim))
|
x = mx.random.normal((B, L_q, self.dim))
|
||||||
@@ -268,6 +288,7 @@ class TestWanCrossAttention:
|
|||||||
# bfloat16 Autocast Tests
|
# bfloat16 Autocast Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestBFloat16Autocast:
|
class TestBFloat16Autocast:
|
||||||
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
||||||
for efficient matmul, matching official PyTorch autocast behavior."""
|
for efficient matmul, matching official PyTorch autocast behavior."""
|
||||||
@@ -292,6 +313,7 @@ class TestBFloat16Autocast:
|
|||||||
"""Self-attention should cast input to weight dtype for QKV projections."""
|
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||||
from mlx_video.models.wan.attention import WanSelfAttention
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||||
attn.update(self._to_bf16(attn.parameters()))
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
@@ -305,6 +327,7 @@ class TestBFloat16Autocast:
|
|||||||
def test_cross_attn_casts_to_weight_dtype(self):
|
def test_cross_attn_casts_to_weight_dtype(self):
|
||||||
"""Cross-attention should cast input to weight dtype."""
|
"""Cross-attention should cast input to weight dtype."""
|
||||||
from mlx_video.models.wan.attention import WanCrossAttention
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
|
||||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
attn.update(self._to_bf16(attn.parameters()))
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
@@ -318,6 +341,7 @@ class TestBFloat16Autocast:
|
|||||||
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||||
"""prepare_kv should cast context to weight dtype."""
|
"""prepare_kv should cast context to weight dtype."""
|
||||||
from mlx_video.models.wan.attention import WanCrossAttention
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
|
||||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
attn.update(self._to_bf16(attn.parameters()))
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
@@ -330,6 +354,7 @@ class TestBFloat16Autocast:
|
|||||||
def test_ffn_casts_to_weight_dtype(self):
|
def test_ffn_casts_to_weight_dtype(self):
|
||||||
"""FFN should cast input to weight dtype for linear layers."""
|
"""FFN should cast input to weight dtype for linear layers."""
|
||||||
from mlx_video.models.wan.transformer import WanFFN
|
from mlx_video.models.wan.transformer import WanFFN
|
||||||
|
|
||||||
ffn = WanFFN(self.dim, 128)
|
ffn = WanFFN(self.dim, 128)
|
||||||
ffn.update(self._to_bf16(ffn.parameters()))
|
ffn.update(self._to_bf16(ffn.parameters()))
|
||||||
|
|
||||||
@@ -343,6 +368,7 @@ class TestBFloat16Autocast:
|
|||||||
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||||
from mlx_video.models.wan.attention import WanSelfAttention
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||||
attn.update(self._to_bf16(attn.parameters()))
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
@@ -355,8 +381,9 @@ class TestBFloat16Autocast:
|
|||||||
|
|
||||||
def test_block_float32_residual_with_bf16_weights(self):
|
def test_block_float32_residual_with_bf16_weights(self):
|
||||||
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||||
block.update(self._to_bf16(block.parameters()))
|
block.update(self._to_bf16(block.parameters()))
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
"""Tests for Wan model configuration."""
|
"""Tests for Wan model configuration."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Config Tests
|
# Config Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWanModelConfig:
|
class TestWanModelConfig:
|
||||||
"""Tests for WanModelConfig dataclass."""
|
"""Tests for WanModelConfig dataclass."""
|
||||||
|
|
||||||
def test_default_values(self):
|
def test_default_values(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
assert config.dim == 5120
|
assert config.dim == 5120
|
||||||
assert config.ffn_dim == 13824
|
assert config.ffn_dim == 13824
|
||||||
@@ -33,11 +33,13 @@ class TestWanModelConfig:
|
|||||||
|
|
||||||
def test_head_dim_property(self):
|
def test_head_dim_property(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
assert config.head_dim == 128 # 5120 // 40
|
assert config.head_dim == 128 # 5120 // 40
|
||||||
|
|
||||||
def test_to_dict_roundtrip(self):
|
def test_to_dict_roundtrip(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
d = config.to_dict()
|
d = config.to_dict()
|
||||||
assert isinstance(d, dict)
|
assert isinstance(d, dict)
|
||||||
@@ -47,6 +49,7 @@ class TestWanModelConfig:
|
|||||||
|
|
||||||
def test_t5_config_values(self):
|
def test_t5_config_values(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
assert config.t5_vocab_size == 256384
|
assert config.t5_vocab_size == 256384
|
||||||
assert config.t5_dim == 4096
|
assert config.t5_dim == 4096
|
||||||
@@ -61,11 +64,13 @@ class TestWanModelConfig:
|
|||||||
# Wan2.1 Config Tests
|
# Wan2.1 Config Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWan21Config:
|
class TestWan21Config:
|
||||||
"""Tests for Wan2.1 config presets."""
|
"""Tests for Wan2.1 config presets."""
|
||||||
|
|
||||||
def test_wan21_14b_factory(self):
|
def test_wan21_14b_factory(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_14b()
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
assert config.model_version == "2.1"
|
assert config.model_version == "2.1"
|
||||||
assert config.dual_model is False
|
assert config.dual_model is False
|
||||||
@@ -81,6 +86,7 @@ class TestWan21Config:
|
|||||||
|
|
||||||
def test_wan21_1_3b_factory(self):
|
def test_wan21_1_3b_factory(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_1_3b()
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
assert config.model_version == "2.1"
|
assert config.model_version == "2.1"
|
||||||
assert config.dual_model is False
|
assert config.dual_model is False
|
||||||
@@ -93,6 +99,7 @@ class TestWan21Config:
|
|||||||
|
|
||||||
def test_wan22_14b_factory(self):
|
def test_wan22_14b_factory(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan22_t2v_14b()
|
config = WanModelConfig.wan22_t2v_14b()
|
||||||
assert config.model_version == "2.2"
|
assert config.model_version == "2.2"
|
||||||
assert config.dual_model is True
|
assert config.dual_model is True
|
||||||
@@ -104,6 +111,7 @@ class TestWan21Config:
|
|||||||
|
|
||||||
def test_wan21_config_to_dict(self):
|
def test_wan21_config_to_dict(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_14b()
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
d = config.to_dict()
|
d = config.to_dict()
|
||||||
assert d["model_version"] == "2.1"
|
assert d["model_version"] == "2.1"
|
||||||
@@ -112,6 +120,7 @@ class TestWan21Config:
|
|||||||
|
|
||||||
def test_wan21_1_3b_config_to_dict(self):
|
def test_wan21_1_3b_config_to_dict(self):
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_1_3b()
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
d = config.to_dict()
|
d = config.to_dict()
|
||||||
assert d["dim"] == 1536
|
assert d["dim"] == 1536
|
||||||
@@ -120,6 +129,7 @@ class TestWan21Config:
|
|||||||
def test_default_config_is_wan22(self):
|
def test_default_config_is_wan22(self):
|
||||||
"""Default WanModelConfig() should be Wan2.2 14B."""
|
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
assert config.model_version == "2.2"
|
assert config.model_version == "2.2"
|
||||||
assert config.dual_model is True
|
assert config.dual_model is True
|
||||||
|
|||||||
@@ -3,17 +3,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Transformer Weight Conversion Tests
|
# Transformer Weight Conversion Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestSanitizeTransformerWeights:
|
class TestSanitizeTransformerWeights:
|
||||||
def test_patch_embedding_reshape(self):
|
def test_patch_embedding_reshape(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||||
@@ -25,6 +24,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_text_embedding_rename(self):
|
def test_text_embedding_rename(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||||
"text_embedding.0.bias": mx.zeros((64,)),
|
"text_embedding.0.bias": mx.zeros((64,)),
|
||||||
@@ -39,6 +39,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_time_embedding_rename(self):
|
def test_time_embedding_rename(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||||
@@ -49,6 +50,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_time_projection_rename(self):
|
def test_time_projection_rename(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||||
"time_projection.1.bias": mx.zeros((384,)),
|
"time_projection.1.bias": mx.zeros((384,)),
|
||||||
@@ -59,6 +61,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_ffn_rename(self):
|
def test_ffn_rename(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||||
"blocks.0.ffn.0.bias": mx.zeros((128,)),
|
"blocks.0.ffn.0.bias": mx.zeros((128,)),
|
||||||
@@ -73,6 +76,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_freqs_skipped(self):
|
def test_freqs_skipped(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"freqs": mx.zeros((1024, 64, 2)),
|
"freqs": mx.zeros((1024, 64, 2)),
|
||||||
"blocks.0.norm1.weight": mx.zeros((64,)),
|
"blocks.0.norm1.weight": mx.zeros((64,)),
|
||||||
@@ -83,6 +87,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_passthrough_keys(self):
|
def test_passthrough_keys(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||||
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
|
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
|
||||||
@@ -98,6 +103,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
|
|
||||||
def test_no_unconsumed_keys(self, caplog):
|
def test_no_unconsumed_keys(self, caplog):
|
||||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||||
@@ -121,6 +127,7 @@ class TestSanitizeTransformerWeights:
|
|||||||
class TestSanitizeT5Weights:
|
class TestSanitizeT5Weights:
|
||||||
def test_gate_rename(self):
|
def test_gate_rename(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||||
@@ -133,6 +140,7 @@ class TestSanitizeT5Weights:
|
|||||||
|
|
||||||
def test_passthrough(self):
|
def test_passthrough(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"token_embedding.weight": mx.zeros((100, 64)),
|
"token_embedding.weight": mx.zeros((100, 64)),
|
||||||
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
|
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
|
||||||
@@ -144,6 +152,7 @@ class TestSanitizeT5Weights:
|
|||||||
|
|
||||||
def test_no_unconsumed_keys(self, caplog):
|
def test_no_unconsumed_keys(self, caplog):
|
||||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"token_embedding.weight": mx.zeros((100, 64)),
|
"token_embedding.weight": mx.zeros((100, 64)),
|
||||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||||
@@ -159,6 +168,7 @@ class TestSanitizeT5Weights:
|
|||||||
class TestSanitizeVAEWeights:
|
class TestSanitizeVAEWeights:
|
||||||
def test_conv3d_transpose(self):
|
def test_conv3d_transpose(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
||||||
}
|
}
|
||||||
@@ -167,6 +177,7 @@ class TestSanitizeVAEWeights:
|
|||||||
|
|
||||||
def test_conv2d_transpose(self):
|
def test_conv2d_transpose(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
||||||
}
|
}
|
||||||
@@ -175,6 +186,7 @@ class TestSanitizeVAEWeights:
|
|||||||
|
|
||||||
def test_non_conv_passthrough(self):
|
def test_non_conv_passthrough(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
||||||
"decoder.bias": mx.zeros((16,)),
|
"decoder.bias": mx.zeros((16,)),
|
||||||
@@ -185,6 +197,7 @@ class TestSanitizeVAEWeights:
|
|||||||
|
|
||||||
def test_mixed_weights(self):
|
def test_mixed_weights(self):
|
||||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
||||||
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
|
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
|
||||||
@@ -199,6 +212,7 @@ class TestSanitizeVAEWeights:
|
|||||||
|
|
||||||
def test_no_unconsumed_keys(self, caplog):
|
def test_no_unconsumed_keys(self, caplog):
|
||||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
||||||
@@ -214,6 +228,7 @@ class TestSanitizeVAEWeights:
|
|||||||
# Wan2.1 Conversion Tests
|
# Wan2.1 Conversion Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWan21Convert:
|
class TestWan21Convert:
|
||||||
"""Tests for Wan2.1 conversion support."""
|
"""Tests for Wan2.1 conversion support."""
|
||||||
|
|
||||||
@@ -222,7 +237,7 @@ class TestWan21Convert:
|
|||||||
# Create a Wan2.1-style directory (no low_noise_model subdir)
|
# Create a Wan2.1-style directory (no low_noise_model subdir)
|
||||||
(tmp_path / "dummy.safetensors").touch()
|
(tmp_path / "dummy.safetensors").touch()
|
||||||
# The auto-detect logic: no low_noise_model dir → 2.1
|
# The auto-detect logic: no low_noise_model dir → 2.1
|
||||||
from pathlib import Path
|
|
||||||
low = tmp_path / "low_noise_model"
|
low = tmp_path / "low_noise_model"
|
||||||
assert not low.exists()
|
assert not low.exists()
|
||||||
# Simulates auto detection
|
# Simulates auto detection
|
||||||
@@ -233,7 +248,7 @@ class TestWan21Convert:
|
|||||||
"""Auto-detect dual-model directory as Wan2.2."""
|
"""Auto-detect dual-model directory as Wan2.2."""
|
||||||
(tmp_path / "low_noise_model").mkdir()
|
(tmp_path / "low_noise_model").mkdir()
|
||||||
(tmp_path / "high_noise_model").mkdir()
|
(tmp_path / "high_noise_model").mkdir()
|
||||||
from pathlib import Path
|
|
||||||
low = tmp_path / "low_noise_model"
|
low = tmp_path / "low_noise_model"
|
||||||
assert low.exists()
|
assert low.exists()
|
||||||
version = "2.2" if low.exists() else "2.1"
|
version = "2.2" if low.exists() else "2.1"
|
||||||
@@ -242,6 +257,7 @@ class TestWan21Convert:
|
|||||||
def test_wan21_config_saved_correctly(self):
|
def test_wan21_config_saved_correctly(self):
|
||||||
"""Verify config dict has correct fields for Wan2.1."""
|
"""Verify config dict has correct fields for Wan2.1."""
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_14b()
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
d = config.to_dict()
|
d = config.to_dict()
|
||||||
assert d["model_version"] == "2.1"
|
assert d["model_version"] == "2.1"
|
||||||
@@ -254,6 +270,7 @@ class TestWan21Convert:
|
|||||||
# Encoder Weight Sanitization Tests
|
# Encoder Weight Sanitization Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestSanitizeEncoderWeights:
|
class TestSanitizeEncoderWeights:
|
||||||
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,13 @@
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
from wan_test_helpers import _make_tiny_config
|
from wan_test_helpers import _make_tiny_config
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: end-to-end tiny model forward pass
|
# Integration: end-to-end tiny model forward pass
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestEndToEnd:
|
class TestEndToEnd:
|
||||||
"""End-to-end test with tiny model (no real weights needed)."""
|
"""End-to-end test with tiny model (no real weights needed)."""
|
||||||
|
|
||||||
@@ -78,6 +76,7 @@ class TestEndToEnd:
|
|||||||
# I2V Mask Tests
|
# I2V Mask Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestI2VMask:
|
class TestI2VMask:
|
||||||
"""Tests for _build_i2v_mask."""
|
"""Tests for _build_i2v_mask."""
|
||||||
|
|
||||||
@@ -113,6 +112,7 @@ class TestI2VMaskAlignment:
|
|||||||
def test_mask_with_ti2v_dimensions(self):
|
def test_mask_with_ti2v_dimensions(self):
|
||||||
"""Mask should work with TI2V-5B typical dimensions."""
|
"""Mask should work with TI2V-5B typical dimensions."""
|
||||||
from mlx_video.generate_wan import _build_i2v_mask
|
from mlx_video.generate_wan import _build_i2v_mask
|
||||||
|
|
||||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||||
z_shape = (48, 21, 44, 80)
|
z_shape = (48, 21, 44, 80)
|
||||||
@@ -133,6 +133,7 @@ class TestI2VMaskAlignment:
|
|||||||
def test_mask_per_token_timestep(self):
|
def test_mask_per_token_timestep(self):
|
||||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||||
from mlx_video.generate_wan import _build_i2v_mask
|
from mlx_video.generate_wan import _build_i2v_mask
|
||||||
|
|
||||||
z_shape = (4, 3, 4, 4)
|
z_shape = (4, 3, 4, 4)
|
||||||
patch_size = (1, 2, 2)
|
patch_size = (1, 2, 2)
|
||||||
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||||
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
|
|||||||
|
|
||||||
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
|
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
|
||||||
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
|
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
|
||||||
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
|
np.testing.assert_allclose(
|
||||||
|
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Dimension Alignment Tests
|
# Dimension Alignment Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestDimensionAlignment:
|
class TestDimensionAlignment:
|
||||||
"""Tests for automatic dimension alignment in generate_wan."""
|
"""Tests for automatic dimension alignment in generate_wan."""
|
||||||
|
|
||||||
@@ -198,6 +202,7 @@ class TestDimensionAlignment:
|
|||||||
def test_patchify_valid_after_alignment(self):
|
def test_patchify_valid_after_alignment(self):
|
||||||
"""After alignment, patchify should succeed without reshape errors."""
|
"""After alignment, patchify should succeed without reshape errors."""
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
|
|
||||||
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
|
|||||||
patches, grid_size = model._patchify(vid)
|
patches, grid_size = model._patchify(vid)
|
||||||
mx.eval(patches)
|
mx.eval(patches)
|
||||||
assert patches.ndim == 3 # [1, L, dim]
|
assert patches.ndim == 3 # [1, L, dim]
|
||||||
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
|
assert grid_size == (
|
||||||
|
t_latent,
|
||||||
|
h_latent // patch_size[1],
|
||||||
|
w_latent // patch_size[2],
|
||||||
|
)
|
||||||
|
|
||||||
def test_alignment_with_ti2v_config(self):
|
def test_alignment_with_ti2v_config(self):
|
||||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan22_ti2v_5b()
|
config = WanModelConfig.wan22_ti2v_5b()
|
||||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||||
align_w = config.patch_size[2] * config.vae_stride[2]
|
align_w = config.patch_size[2] * config.vae_stride[2]
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
"""Tests for Wan2.2 I2V-14B support."""
|
"""Tests for Wan2.2 I2V-14B support."""
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from wan_test_helpers import _make_tiny_config
|
from wan_test_helpers import _make_tiny_config
|
||||||
|
|
||||||
|
|
||||||
@@ -145,7 +142,10 @@ class TestModelYParameter:
|
|||||||
latents = mx.random.normal((C_noise, F, H, W))
|
latents = mx.random.normal((C_noise, F, H, W))
|
||||||
y = mx.random.normal((C_y, F, H, W))
|
y = mx.random.normal((C_y, F, H, W))
|
||||||
t = mx.array([500.0, 500.0])
|
t = mx.array([500.0, 500.0])
|
||||||
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
|
ctx = [
|
||||||
|
mx.random.normal((6, config.text_dim)),
|
||||||
|
mx.random.normal((6, config.text_dim)),
|
||||||
|
]
|
||||||
|
|
||||||
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
||||||
mx.eval(out[0], out[1])
|
mx.eval(out[0], out[1])
|
||||||
@@ -160,7 +160,9 @@ class TestVAEEncoder:
|
|||||||
def test_encoder3d_instantiation(self):
|
def test_encoder3d_instantiation(self):
|
||||||
from mlx_video.models.wan.vae import Encoder3d
|
from mlx_video.models.wan.vae import Encoder3d
|
||||||
|
|
||||||
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
enc = Encoder3d(
|
||||||
|
dim=32, z_dim=8
|
||||||
|
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||||
assert enc.conv1 is not None
|
assert enc.conv1 is not None
|
||||||
assert len(enc.downsamples) > 0
|
assert len(enc.downsamples) > 0
|
||||||
assert len(enc.middle) == 3
|
assert len(enc.middle) == 3
|
||||||
@@ -199,10 +201,10 @@ class TestVAEEncoder:
|
|||||||
from mlx_video.models.wan.vae import WanVAE
|
from mlx_video.models.wan.vae import WanVAE
|
||||||
|
|
||||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||||
assert not hasattr(vae_no_enc, 'encoder')
|
assert not hasattr(vae_no_enc, "encoder")
|
||||||
|
|
||||||
vae_enc = WanVAE(z_dim=4, encoder=True)
|
vae_enc = WanVAE(z_dim=4, encoder=True)
|
||||||
assert hasattr(vae_enc, 'encoder')
|
assert hasattr(vae_enc, "encoder")
|
||||||
|
|
||||||
|
|
||||||
class TestResampleDownsample:
|
class TestResampleDownsample:
|
||||||
@@ -258,7 +260,9 @@ class TestI2VMaskConstruction:
|
|||||||
|
|
||||||
# Build mask following reference logic
|
# Build mask following reference logic
|
||||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
msk = mx.concatenate(
|
||||||
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||||
|
)
|
||||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||||
@@ -272,7 +276,9 @@ class TestI2VMaskConstruction:
|
|||||||
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
||||||
|
|
||||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
msk = mx.concatenate(
|
||||||
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||||
|
)
|
||||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||||
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
||||||
@@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline:
|
|||||||
config = _make_tiny_i2v_config()
|
config = _make_tiny_i2v_config()
|
||||||
config.vae_z_dim = 16
|
config.vae_z_dim = 16
|
||||||
config.out_dim = 16 # must match VAE z_dim for decode
|
config.out_dim = 16 # must match VAE z_dim for decode
|
||||||
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
config.in_dim = (
|
||||||
|
16 + 4 + 16
|
||||||
|
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
|
|
||||||
# --- Tiny VAE (with encoder) ---
|
# --- Tiny VAE (with encoder) ---
|
||||||
@@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline:
|
|||||||
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
||||||
|
|
||||||
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
||||||
video = mx.concatenate([
|
video = mx.concatenate(
|
||||||
img,
|
[
|
||||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
img,
|
||||||
], axis=2)
|
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||||
|
],
|
||||||
|
axis=2,
|
||||||
|
)
|
||||||
|
|
||||||
# --- VAE encode ---
|
# --- VAE encode ---
|
||||||
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
||||||
@@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline:
|
|||||||
|
|
||||||
# --- Build I2V mask (4 channels) ---
|
# --- Build I2V mask (4 channels) ---
|
||||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
msk = mx.concatenate(
|
||||||
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||||
|
)
|
||||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||||
@@ -453,7 +466,9 @@ class TestDualModelSwitching:
|
|||||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||||
|
|
||||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||||
|
0
|
||||||
|
)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
||||||
@@ -461,9 +476,9 @@ class TestDualModelSwitching:
|
|||||||
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
||||||
# High-noise steps should come before low-noise steps (timesteps decrease)
|
# High-noise steps should come before low-noise steps (timesteps decrease)
|
||||||
if high_used_steps and low_used_steps:
|
if high_used_steps and low_used_steps:
|
||||||
assert max(high_used_steps) < min(low_used_steps) or \
|
assert max(high_used_steps) < min(low_used_steps) or min(
|
||||||
min(high_used_steps) < max(low_used_steps), \
|
high_used_steps
|
||||||
"Model switching should happen during the loop"
|
) < max(low_used_steps), "Model switching should happen during the loop"
|
||||||
|
|
||||||
assert latents.shape == (C_noise, F, H, W)
|
assert latents.shape == (C_noise, F, H, W)
|
||||||
assert not mx.any(mx.isnan(latents)).item()
|
assert not mx.any(mx.isnan(latents)).item()
|
||||||
@@ -515,7 +530,9 @@ class TestDualModelSwitching:
|
|||||||
y=[y_i2v, y_i2v],
|
y=[y_i2v, y_i2v],
|
||||||
)
|
)
|
||||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||||
|
0
|
||||||
|
)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
# Verify both guide scales were used
|
# Verify both guide scales were used
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +39,9 @@ class TestLoRATypes:
|
|||||||
|
|
||||||
lora_a = mx.ones((2, 4))
|
lora_a = mx.ones((2, 4))
|
||||||
lora_b = mx.ones((8, 2))
|
lora_b = mx.ones((8, 2))
|
||||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
w = LoRAWeights(
|
||||||
|
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||||
|
)
|
||||||
applied = AppliedLoRA(weights=w, strength=0.5)
|
applied = AppliedLoRA(weights=w, strength=0.5)
|
||||||
delta = applied.compute_delta()
|
delta = applied.compute_delta()
|
||||||
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
||||||
@@ -51,7 +52,9 @@ class TestLoRATypes:
|
|||||||
class TestLoRALoader:
|
class TestLoRALoader:
|
||||||
"""Test LoRA weight loading from safetensors."""
|
"""Test LoRA weight loading from safetensors."""
|
||||||
|
|
||||||
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
|
def _make_lora_file(
|
||||||
|
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
|
||||||
|
):
|
||||||
"""Helper to create a mock LoRA safetensors file."""
|
"""Helper to create a mock LoRA safetensors file."""
|
||||||
weights = {}
|
weights = {}
|
||||||
for name in module_names:
|
for name in module_names:
|
||||||
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
|
|||||||
"""Simulate typical Wan2.2 MLX model weight keys."""
|
"""Simulate typical Wan2.2 MLX model weight keys."""
|
||||||
keys = set()
|
keys = set()
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
|
for layer in [
|
||||||
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
|
"self_attn.q",
|
||||||
|
"self_attn.k",
|
||||||
|
"self_attn.v",
|
||||||
|
"self_attn.o",
|
||||||
|
"cross_attn.q",
|
||||||
|
"cross_attn.k",
|
||||||
|
"cross_attn.v",
|
||||||
|
"cross_attn.o",
|
||||||
|
]:
|
||||||
keys.add(f"blocks.{i}.{layer}.weight")
|
keys.add(f"blocks.{i}.{layer}.weight")
|
||||||
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
||||||
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
||||||
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
|
|||||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||||
|
|
||||||
keys = self._wan_model_keys()
|
keys = self._wan_model_keys()
|
||||||
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
|
assert (
|
||||||
|
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
|
||||||
|
== "blocks.0.self_attn.q"
|
||||||
|
)
|
||||||
|
|
||||||
def test_strip_diffusion_model_prefix(self):
|
def test_strip_diffusion_model_prefix(self):
|
||||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||||
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
|
|||||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||||
|
|
||||||
keys = self._wan_model_keys()
|
keys = self._wan_model_keys()
|
||||||
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
|
result = _normalize_wan_lora_key(
|
||||||
|
"model.diffusion_model.blocks.0.self_attn.k", keys
|
||||||
|
)
|
||||||
assert result == "blocks.0.self_attn.k"
|
assert result == "blocks.0.self_attn.k"
|
||||||
|
|
||||||
def test_ffn_key_mapping(self):
|
def test_ffn_key_mapping(self):
|
||||||
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
|
|||||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||||
|
|
||||||
keys = self._wan_model_keys()
|
keys = self._wan_model_keys()
|
||||||
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
assert (
|
||||||
|
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||||
|
)
|
||||||
|
|
||||||
def test_combined_prefix_and_ffn(self):
|
def test_combined_prefix_and_ffn(self):
|
||||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||||
@@ -219,7 +237,9 @@ class TestApplyLoRA:
|
|||||||
# LoRA weights in float32 (typical when loaded from safetensors)
|
# LoRA weights in float32 (typical when loaded from safetensors)
|
||||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
w = LoRAWeights(
|
||||||
|
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||||
|
)
|
||||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||||
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
||||||
|
|
||||||
@@ -230,7 +250,9 @@ class TestApplyLoRA:
|
|||||||
original = mx.ones((8, 4), dtype=mx.float16)
|
original = mx.ones((8, 4), dtype=mx.float16)
|
||||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
w = LoRAWeights(
|
||||||
|
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||||
|
)
|
||||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||||
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
||||||
|
|
||||||
@@ -241,7 +263,9 @@ class TestApplyLoRA:
|
|||||||
original = mx.ones((8, 4))
|
original = mx.ones((8, 4))
|
||||||
lora_a = mx.ones((2, 4)) * 0.1
|
lora_a = mx.ones((2, 4)) * 0.1
|
||||||
lora_b = mx.ones((8, 2)) * 0.1
|
lora_b = mx.ones((8, 2)) * 0.1
|
||||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
w = LoRAWeights(
|
||||||
|
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||||
|
)
|
||||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||||
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
||||||
expected = original + 0.02 * mx.ones((8, 4))
|
expected = original + 0.02 * mx.ones((8, 4))
|
||||||
@@ -255,12 +279,16 @@ class TestApplyLoRA:
|
|||||||
w1 = LoRAWeights(
|
w1 = LoRAWeights(
|
||||||
lora_A=mx.ones((2, 4)),
|
lora_A=mx.ones((2, 4)),
|
||||||
lora_B=mx.ones((8, 2)),
|
lora_B=mx.ones((8, 2)),
|
||||||
rank=2, alpha=2.0, module_name="a",
|
rank=2,
|
||||||
|
alpha=2.0,
|
||||||
|
module_name="a",
|
||||||
)
|
)
|
||||||
w2 = LoRAWeights(
|
w2 = LoRAWeights(
|
||||||
lora_A=mx.ones((2, 4)) * 2,
|
lora_A=mx.ones((2, 4)) * 2,
|
||||||
lora_B=mx.ones((8, 2)) * 2,
|
lora_B=mx.ones((8, 2)) * 2,
|
||||||
rank=2, alpha=4.0, module_name="b",
|
rank=2,
|
||||||
|
alpha=4.0,
|
||||||
|
module_name="b",
|
||||||
)
|
)
|
||||||
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
||||||
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
||||||
@@ -282,7 +310,9 @@ class TestApplyLoRA:
|
|||||||
w = LoRAWeights(
|
w = LoRAWeights(
|
||||||
lora_A=mx.ones((4, 64)) * 0.01,
|
lora_A=mx.ones((4, 64)) * 0.01,
|
||||||
lora_B=mx.ones((128, 4)) * 0.01,
|
lora_B=mx.ones((128, 4)) * 0.01,
|
||||||
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
|
rank=4,
|
||||||
|
alpha=4.0,
|
||||||
|
module_name="blocks.0.self_attn.q",
|
||||||
)
|
)
|
||||||
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
||||||
result = apply_loras_to_weights(model_weights, module_to_loras)
|
result = apply_loras_to_weights(model_weights, module_to_loras)
|
||||||
@@ -319,9 +349,7 @@ class TestEndToEnd:
|
|||||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||||
}
|
}
|
||||||
|
|
||||||
result = load_and_apply_loras(
|
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
|
||||||
model_weights, [(str(lora_path), 1.0)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# q weight should be modified, k unchanged
|
# q weight should be modified, k unchanged
|
||||||
assert not mx.array_equal(
|
assert not mx.array_equal(
|
||||||
|
|||||||
@@ -3,18 +3,17 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
from wan_test_helpers import _make_tiny_config
|
from wan_test_helpers import _make_tiny_config
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Sinusoidal Embedding Tests
|
# Sinusoidal Embedding Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestSinusoidalEmbedding:
|
class TestSinusoidalEmbedding:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.arange(10).astype(mx.float32)
|
pos = mx.arange(10).astype(mx.float32)
|
||||||
emb = sinusoidal_embedding_1d(256, pos)
|
emb = sinusoidal_embedding_1d(256, pos)
|
||||||
mx.eval(emb)
|
mx.eval(emb)
|
||||||
@@ -23,6 +22,7 @@ class TestSinusoidalEmbedding:
|
|||||||
def test_position_zero(self):
|
def test_position_zero(self):
|
||||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.array([0.0])
|
pos = mx.array([0.0])
|
||||||
emb = sinusoidal_embedding_1d(64, pos)
|
emb = sinusoidal_embedding_1d(64, pos)
|
||||||
mx.eval(emb)
|
mx.eval(emb)
|
||||||
@@ -34,6 +34,7 @@ class TestSinusoidalEmbedding:
|
|||||||
|
|
||||||
def test_different_positions_differ(self):
|
def test_different_positions_differ(self):
|
||||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.array([0.0, 100.0, 999.0])
|
pos = mx.array([0.0, 100.0, 999.0])
|
||||||
emb = sinusoidal_embedding_1d(128, pos)
|
emb = sinusoidal_embedding_1d(128, pos)
|
||||||
mx.eval(emb)
|
mx.eval(emb)
|
||||||
@@ -46,9 +47,11 @@ class TestSinusoidalEmbedding:
|
|||||||
# Head Tests
|
# Head Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestHead:
|
class TestHead:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.model import Head
|
from mlx_video.models.wan.model import Head
|
||||||
|
|
||||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||||
B, L = 1, 24
|
B, L = 1, 24
|
||||||
x = mx.random.normal((B, L, 64))
|
x = mx.random.normal((B, L, 64))
|
||||||
@@ -60,6 +63,7 @@ class TestHead:
|
|||||||
|
|
||||||
def test_modulation_shape(self):
|
def test_modulation_shape(self):
|
||||||
from mlx_video.models.wan.model import Head
|
from mlx_video.models.wan.model import Head
|
||||||
|
|
||||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||||
assert head.modulation.shape == (1, 2, 64)
|
assert head.modulation.shape == (1, 2, 64)
|
||||||
|
|
||||||
@@ -68,12 +72,14 @@ class TestHead:
|
|||||||
# WanModel (Tiny) Tests
|
# WanModel (Tiny) Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWanModel:
|
class TestWanModel:
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
|
|
||||||
def test_instantiation(self):
|
def test_instantiation(self):
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
|
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
|
||||||
@@ -81,6 +87,7 @@ class TestWanModel:
|
|||||||
|
|
||||||
def test_patchify_shape(self):
|
def test_patchify_shape(self):
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
# Input: [C=4, F=1, H=4, W=4]
|
# Input: [C=4, F=1, H=4, W=4]
|
||||||
@@ -93,6 +100,7 @@ class TestWanModel:
|
|||||||
|
|
||||||
def test_patchify_various_sizes(self):
|
def test_patchify_various_sizes(self):
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
|
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
|
||||||
@@ -108,6 +116,7 @@ class TestWanModel:
|
|||||||
def test_unpatchify_inverse(self):
|
def test_unpatchify_inverse(self):
|
||||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
C, F, H, W = config.in_dim, 2, 4, 6
|
C, F, H, W = config.in_dim, 2, 4, 6
|
||||||
@@ -123,6 +132,7 @@ class TestWanModel:
|
|||||||
|
|
||||||
def test_forward_pass(self):
|
def test_forward_pass(self):
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
C, F, H, W = config.in_dim, 1, 4, 4
|
C, F, H, W = config.in_dim, 1, 4, 4
|
||||||
@@ -140,6 +150,7 @@ class TestWanModel:
|
|||||||
|
|
||||||
def test_forward_batch(self):
|
def test_forward_batch(self):
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
C, F, H, W = config.in_dim, 1, 4, 4
|
C, F, H, W = config.in_dim, 1, 4, 4
|
||||||
@@ -148,7 +159,10 @@ class TestWanModel:
|
|||||||
|
|
||||||
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
|
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
|
||||||
t = mx.array([500.0, 200.0])
|
t = mx.array([500.0, 200.0])
|
||||||
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
|
context = [
|
||||||
|
mx.random.normal((6, config.text_dim)),
|
||||||
|
mx.random.normal((4, config.text_dim)),
|
||||||
|
]
|
||||||
|
|
||||||
out = model(x_list, t, context, seq_len)
|
out = model(x_list, t, context, seq_len)
|
||||||
mx.eval(out[0], out[1])
|
mx.eval(out[0], out[1])
|
||||||
@@ -158,12 +172,17 @@ class TestWanModel:
|
|||||||
|
|
||||||
def test_output_is_float32(self):
|
def test_output_is_float32(self):
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
C, F, H, W = config.in_dim, 1, 4, 4
|
C, F, H, W = config.in_dim, 1, 4, 4
|
||||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||||
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
|
out = model(
|
||||||
[mx.random.normal((4, config.text_dim))], seq_len)
|
[mx.random.normal((C, F, H, W))],
|
||||||
|
mx.array([100.0]),
|
||||||
|
[mx.random.normal((4, config.text_dim))],
|
||||||
|
seq_len,
|
||||||
|
)
|
||||||
mx.eval(out[0])
|
mx.eval(out[0])
|
||||||
assert out[0].dtype == mx.float32
|
assert out[0].dtype == mx.float32
|
||||||
|
|
||||||
@@ -172,6 +191,7 @@ class TestWanModel:
|
|||||||
# Wan2.1 Model Tests
|
# Wan2.1 Model Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWan21Model:
|
class TestWan21Model:
|
||||||
"""Test tiny Wan2.1-style model (single model mode)."""
|
"""Test tiny Wan2.1-style model (single model mode)."""
|
||||||
|
|
||||||
@@ -181,6 +201,7 @@ class TestWan21Model:
|
|||||||
def _make_tiny_wan21_config(self):
|
def _make_tiny_wan21_config(self):
|
||||||
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_14b()
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
# Override to tiny values
|
# Override to tiny values
|
||||||
config.dim = 64
|
config.dim = 64
|
||||||
@@ -197,6 +218,7 @@ class TestWan21Model:
|
|||||||
def _make_tiny_wan21_1_3b_config(self):
|
def _make_tiny_wan21_1_3b_config(self):
|
||||||
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig.wan21_t2v_1_3b()
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||||
config.dim = 48
|
config.dim = 48
|
||||||
@@ -271,7 +293,9 @@ class TestWan21Model:
|
|||||||
for i in range(3):
|
for i in range(3):
|
||||||
t = sched.timesteps[i]
|
t = sched.timesteps[i]
|
||||||
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||||
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
|
pred_uncond = model(
|
||||||
|
[latents], mx.array([t.item()]), [context_null], seq_len
|
||||||
|
)[0]
|
||||||
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
||||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
@@ -304,6 +328,7 @@ class TestWan21Model:
|
|||||||
# Per-Token Timestep Tests
|
# Per-Token Timestep Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestPerTokenTimestep:
|
class TestPerTokenTimestep:
|
||||||
"""Tests for per-token sinusoidal embedding."""
|
"""Tests for per-token sinusoidal embedding."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
"""Tests for Wan model quantization pipeline."""
|
"""Tests for Wan model quantization pipeline."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.utils
|
import mlx.utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
from wan_test_helpers import _make_tiny_config
|
from wan_test_helpers import _make_tiny_config
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Quantize Predicate Tests
|
# Quantize Predicate Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizePredicate:
|
class TestQuantizePredicate:
|
||||||
def test_matches_self_attention_layers(self):
|
def test_matches_self_attention_layers(self):
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_linear = nn.Linear(64, 64)
|
mock_linear = nn.Linear(64, 64)
|
||||||
for suffix in ["q", "k", "v", "o"]:
|
for suffix in ["q", "k", "v", "o"]:
|
||||||
path = f"blocks.0.self_attn.{suffix}"
|
path = f"blocks.0.self_attn.{suffix}"
|
||||||
@@ -24,6 +24,7 @@ class TestQuantizePredicate:
|
|||||||
|
|
||||||
def test_matches_cross_attention_layers(self):
|
def test_matches_cross_attention_layers(self):
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_linear = nn.Linear(64, 64)
|
mock_linear = nn.Linear(64, 64)
|
||||||
for suffix in ["q", "k", "v", "o"]:
|
for suffix in ["q", "k", "v", "o"]:
|
||||||
path = f"blocks.0.cross_attn.{suffix}"
|
path = f"blocks.0.cross_attn.{suffix}"
|
||||||
@@ -31,23 +32,31 @@ class TestQuantizePredicate:
|
|||||||
|
|
||||||
def test_matches_ffn_layers(self):
|
def test_matches_ffn_layers(self):
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_linear = nn.Linear(64, 64)
|
mock_linear = nn.Linear(64, 64)
|
||||||
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
||||||
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
||||||
|
|
||||||
def test_rejects_embeddings(self):
|
def test_rejects_embeddings(self):
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_linear = nn.Linear(64, 64)
|
mock_linear = nn.Linear(64, 64)
|
||||||
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
|
for path in [
|
||||||
|
"patch_embedding_proj",
|
||||||
|
"text_embedding_fc1",
|
||||||
|
"time_embedding.fc1",
|
||||||
|
]:
|
||||||
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
||||||
|
|
||||||
def test_rejects_norms(self):
|
def test_rejects_norms(self):
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_norm = nn.RMSNorm(64)
|
mock_norm = nn.RMSNorm(64)
|
||||||
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
||||||
|
|
||||||
def test_rejects_non_quantizable_modules(self):
|
def test_rejects_non_quantizable_modules(self):
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_norm = nn.RMSNorm(64)
|
mock_norm = nn.RMSNorm(64)
|
||||||
# Even if path matches, module must have to_quantized
|
# Even if path matches, module must have to_quantized
|
||||||
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
|
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
|
||||||
@@ -55,13 +64,19 @@ class TestQuantizePredicate:
|
|||||||
def test_all_10_patterns_covered(self):
|
def test_all_10_patterns_covered(self):
|
||||||
"""Verify exactly 10 layer patterns are targeted."""
|
"""Verify exactly 10 layer patterns are targeted."""
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
mock_linear = nn.Linear(64, 64)
|
mock_linear = nn.Linear(64, 64)
|
||||||
patterns = [
|
patterns = [
|
||||||
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
|
"blocks.0.self_attn.q",
|
||||||
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
|
"blocks.0.self_attn.k",
|
||||||
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
|
"blocks.0.self_attn.v",
|
||||||
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
|
"blocks.0.self_attn.o",
|
||||||
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
|
"blocks.0.cross_attn.q",
|
||||||
|
"blocks.0.cross_attn.k",
|
||||||
|
"blocks.0.cross_attn.v",
|
||||||
|
"blocks.0.cross_attn.o",
|
||||||
|
"blocks.0.ffn.fc1",
|
||||||
|
"blocks.0.ffn.fc2",
|
||||||
]
|
]
|
||||||
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
|
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
|
||||||
assert len(matched) == 10
|
assert len(matched) == 10
|
||||||
@@ -71,11 +86,12 @@ class TestQuantizePredicate:
|
|||||||
# Quantize Round-Trip Tests
|
# Quantize Round-Trip Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizeRoundTrip:
|
class TestQuantizeRoundTrip:
|
||||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||||
"""Helper: create model, quantize, save to tmp_path."""
|
"""Helper: create model, quantize, save to tmp_path."""
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
@@ -101,8 +117,10 @@ class TestQuantizeRoundTrip:
|
|||||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
||||||
|
|
||||||
from mlx_video.models.wan.loading import load_wan_model
|
from mlx_video.models.wan.loading import load_wan_model
|
||||||
|
|
||||||
loaded = load_wan_model(
|
loaded = load_wan_model(
|
||||||
model_path, config,
|
model_path,
|
||||||
|
config,
|
||||||
quantization={"bits": 4, "group_size": 64},
|
quantization={"bits": 4, "group_size": 64},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,8 +137,10 @@ class TestQuantizeRoundTrip:
|
|||||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
||||||
|
|
||||||
from mlx_video.models.wan.loading import load_wan_model
|
from mlx_video.models.wan.loading import load_wan_model
|
||||||
|
|
||||||
loaded = load_wan_model(
|
loaded = load_wan_model(
|
||||||
model_path, config,
|
model_path,
|
||||||
|
config,
|
||||||
quantization={"bits": 8, "group_size": 64},
|
quantization={"bits": 8, "group_size": 64},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -132,8 +152,10 @@ class TestQuantizeRoundTrip:
|
|||||||
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
||||||
|
|
||||||
from mlx_video.models.wan.loading import load_wan_model
|
from mlx_video.models.wan.loading import load_wan_model
|
||||||
|
|
||||||
loaded = load_wan_model(
|
loaded = load_wan_model(
|
||||||
model_path, config,
|
model_path,
|
||||||
|
config,
|
||||||
quantization={"bits": 4, "group_size": 64},
|
quantization={"bits": 4, "group_size": 64},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,6 +173,7 @@ class TestQuantizeRoundTrip:
|
|||||||
mx.save_safetensors(str(model_path), weights_dict)
|
mx.save_safetensors(str(model_path), weights_dict)
|
||||||
|
|
||||||
from mlx_video.models.wan.loading import load_wan_model
|
from mlx_video.models.wan.loading import load_wan_model
|
||||||
|
|
||||||
loaded = load_wan_model(model_path, config, quantization=None)
|
loaded = load_wan_model(model_path, config, quantization=None)
|
||||||
|
|
||||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
|
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
|
||||||
@@ -161,10 +184,11 @@ class TestQuantizeRoundTrip:
|
|||||||
# Quantized Inference Tests
|
# Quantized Inference Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizedInference:
|
class TestQuantizedInference:
|
||||||
def _make_quantized_model(self, config, bits=4):
|
def _make_quantized_model(self, config, bits=4):
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
@@ -214,8 +238,8 @@ class TestQuantizedInference:
|
|||||||
|
|
||||||
def test_quantized_output_differs_from_unquantized(self):
|
def test_quantized_output_differs_from_unquantized(self):
|
||||||
"""Sanity check: quantization should change the weights."""
|
"""Sanity check: quantization should change the weights."""
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
@@ -243,11 +267,12 @@ class TestQuantizedInference:
|
|||||||
# Config Metadata Tests
|
# Config Metadata Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizationConfig:
|
class TestQuantizationConfig:
|
||||||
def test_config_metadata_written(self, tmp_path):
|
def test_config_metadata_written(self, tmp_path):
|
||||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
from mlx_video.convert_wan import _quantize_saved_model
|
from mlx_video.convert_wan import _quantize_saved_model
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -270,8 +295,8 @@ class TestQuantizationConfig:
|
|||||||
assert cfg["quantization"]["group_size"] == 64
|
assert cfg["quantization"]["group_size"] == 64
|
||||||
|
|
||||||
def test_config_metadata_8bit(self, tmp_path):
|
def test_config_metadata_8bit(self, tmp_path):
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
from mlx_video.convert_wan import _quantize_saved_model
|
from mlx_video.convert_wan import _quantize_saved_model
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -291,8 +316,8 @@ class TestQuantizationConfig:
|
|||||||
|
|
||||||
def test_dual_model_quantization(self, tmp_path):
|
def test_dual_model_quantization(self, tmp_path):
|
||||||
"""Verify dual-model quantization writes both model files."""
|
"""Verify dual-model quantization writes both model files."""
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
from mlx_video.convert_wan import _quantize_saved_model
|
from mlx_video.convert_wan import _quantize_saved_model
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
|
|
||||||
|
|||||||
@@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction:
|
|||||||
|
|
||||||
d = 128 # head_dim for all Wan models
|
d = 128 # head_dim for all Wan models
|
||||||
# Reference: three separate calls
|
# Reference: three separate calls
|
||||||
correct = mx.concatenate([
|
correct = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
# Wrong: single call
|
# Wrong: single call
|
||||||
wrong = rope_params(1024, d)
|
wrong = rope_params(1024, d)
|
||||||
mx.eval(correct, wrong)
|
mx.eval(correct, wrong)
|
||||||
|
|
||||||
assert correct.shape == wrong.shape
|
assert correct.shape == wrong.shape
|
||||||
diff = np.abs(np.array(correct) - np.array(wrong)).max()
|
diff = np.abs(np.array(correct) - np.array(wrong)).max()
|
||||||
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
|
assert (
|
||||||
|
diff > 0.1
|
||||||
|
), f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||||
|
|
||||||
def test_each_axis_starts_at_frequency_one(self):
|
def test_each_axis_starts_at_frequency_one(self):
|
||||||
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
|
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
|
||||||
@@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction:
|
|||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
freqs = mx.concatenate([
|
freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
f = np.array(freqs)
|
f = np.array(freqs)
|
||||||
|
|
||||||
@@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction:
|
|||||||
|
|
||||||
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
|
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
|
||||||
# Temporal axis first freq
|
# Temporal axis first freq
|
||||||
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
|
np.testing.assert_allclose(
|
||||||
err_msg="temporal[0] cos at pos 1")
|
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
|
||||||
|
)
|
||||||
# Height axis first freq (starts at index d_t)
|
# Height axis first freq (starts at index d_t)
|
||||||
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
|
np.testing.assert_allclose(
|
||||||
err_msg="height[0] cos at pos 1")
|
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
|
||||||
|
)
|
||||||
# Width axis first freq (starts at index d_t + d_h)
|
# Width axis first freq (starts at index d_t + d_h)
|
||||||
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
|
np.testing.assert_allclose(
|
||||||
err_msg="width[0] cos at pos 1")
|
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
|
||||||
|
)
|
||||||
|
|
||||||
def test_height_width_frequencies_identical(self):
|
def test_height_width_frequencies_identical(self):
|
||||||
"""Height and width axes should have identical frequency tables.
|
"""Height and width axes should have identical frequency tables.
|
||||||
@@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction:
|
|||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
d_h_dim = 2 * (d // 6) # 42
|
d_h_dim = 2 * (d // 6) # 42
|
||||||
freqs = mx.concatenate([
|
freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, d_h_dim),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, d_h_dim),
|
rope_params(1024, d_h_dim),
|
||||||
], axis=1)
|
rope_params(1024, d_h_dim),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
f = np.array(freqs)
|
f = np.array(freqs)
|
||||||
|
|
||||||
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
|
|||||||
d_t = half_d - 2 * (half_d // 3)
|
d_t = half_d - 2 * (half_d // 3)
|
||||||
d_h = half_d // 3
|
d_h = half_d // 3
|
||||||
|
|
||||||
height_freqs = f[:, d_t:d_t + d_h]
|
height_freqs = f[:, d_t : d_t + d_h]
|
||||||
width_freqs = f[:, d_t + d_h:]
|
width_freqs = f[:, d_t + d_h :]
|
||||||
np.testing.assert_array_equal(height_freqs, width_freqs)
|
np.testing.assert_array_equal(height_freqs, width_freqs)
|
||||||
|
|
||||||
def test_frequency_range_per_axis(self):
|
def test_frequency_range_per_axis(self):
|
||||||
@@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction:
|
|||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
freqs = mx.concatenate([
|
freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
f = np.array(freqs)
|
f = np.array(freqs)
|
||||||
|
|
||||||
@@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction:
|
|||||||
pos1_h = f[1, d_t, 0] # height first freq
|
pos1_h = f[1, d_t, 0] # height first freq
|
||||||
pos1_w = f[1, d_t + d_h, 0] # width first freq
|
pos1_w = f[1, d_t + d_h, 0] # width first freq
|
||||||
|
|
||||||
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
assert (
|
||||||
|
pos1_t > 0.5
|
||||||
|
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||||
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
|
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
|
||||||
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
|
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
|
||||||
|
|
||||||
@@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction:
|
|||||||
|
|
||||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||||
d = head_dim # 16
|
d = head_dim # 16
|
||||||
freqs_manual = mx.concatenate([
|
freqs_manual = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
mx.eval(freqs_model, freqs_manual)
|
mx.eval(freqs_model, freqs_manual)
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
np.array(freqs_model), np.array(freqs_manual),
|
np.array(freqs_model),
|
||||||
err_msg="WanModel.freqs should use three-call construction"
|
np.array(freqs_manual),
|
||||||
|
err_msg="WanModel.freqs should use three-call construction",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_model_freqs_14b_dimensions(self):
|
def test_model_freqs_14b_dimensions(self):
|
||||||
@@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction:
|
|||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
freqs = mx.concatenate([
|
freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
[
|
||||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
mx.eval(freqs)
|
mx.eval(freqs)
|
||||||
|
|
||||||
assert freqs.shape == (1024, 64, 2)
|
assert freqs.shape == (1024, 64, 2)
|
||||||
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def has_torch(self):
|
def has_torch(self):
|
||||||
try:
|
try:
|
||||||
import torch
|
pass
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pytest.skip("PyTorch not installed")
|
pytest.skip("PyTorch not installed")
|
||||||
@@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference:
|
|||||||
def test_freqs_match_pytorch_reference(self, has_torch):
|
def test_freqs_match_pytorch_reference(self, has_torch):
|
||||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
|
|||||||
def pt_rope_params(max_seq_len, dim, theta=10000):
|
def pt_rope_params(max_seq_len, dim, theta=10000):
|
||||||
freqs = torch.outer(
|
freqs = torch.outer(
|
||||||
torch.arange(max_seq_len),
|
torch.arange(max_seq_len),
|
||||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
1.0
|
||||||
|
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
|
||||||
|
)
|
||||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
return freqs
|
return freqs
|
||||||
|
|
||||||
ref = torch.cat([
|
ref = torch.cat(
|
||||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
pt_rope_params(1024, 2 * (d // 6)),
|
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||||
pt_rope_params(1024, 2 * (d // 6)),
|
pt_rope_params(1024, 2 * (d // 6)),
|
||||||
], dim=1)
|
pt_rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
# MLX
|
# MLX
|
||||||
ours = mx.concatenate([
|
ours = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
mx.eval(ours)
|
mx.eval(ours)
|
||||||
|
|
||||||
our_cos = np.array(ours[:, :, 0])
|
our_cos = np.array(ours[:, :, 0])
|
||||||
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
|
|||||||
ref_cos = ref.real.float().numpy()
|
ref_cos = ref.real.float().numpy()
|
||||||
ref_sin = ref.imag.float().numpy()
|
ref_sin = ref.imag.float().numpy()
|
||||||
|
|
||||||
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
|
np.testing.assert_allclose(
|
||||||
err_msg="cos mismatch vs PyTorch reference")
|
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
|
||||||
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
|
)
|
||||||
err_msg="sin mismatch vs PyTorch reference")
|
np.testing.assert_allclose(
|
||||||
|
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestRoPEApplyWithCorrectFreqs:
|
class TestRoPEApplyWithCorrectFreqs:
|
||||||
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
|
|||||||
This is the key property that was broken by the single-call bug:
|
This is the key property that was broken by the single-call bug:
|
||||||
height/width frequencies were too low to distinguish nearby positions.
|
height/width frequencies were too low to distinguish nearby positions.
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
freqs = mx.concatenate([
|
freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
B, N = 1, 4
|
B, N = 1, 4
|
||||||
F, H, W = 1, 4, 4
|
F, H, W = 1, 4, 4
|
||||||
@@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs:
|
|||||||
|
|
||||||
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
|
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
|
||||||
# and width was ~0.002. With correct freqs, both are ~1.3.
|
# and width was ~0.002. With correct freqs, both are ~1.3.
|
||||||
assert height_diff > 0.5, (
|
assert (
|
||||||
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
height_diff > 0.5
|
||||||
)
|
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||||
assert width_diff > 0.5, (
|
assert (
|
||||||
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
width_diff > 0.5
|
||||||
)
|
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||||
# Height and width should have identical frequency tables → same diffs
|
# Height and width should have identical frequency tables → same diffs
|
||||||
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
|
np.testing.assert_allclose(
|
||||||
err_msg="Height and width should use identical frequency tables")
|
height_diff,
|
||||||
|
width_diff,
|
||||||
|
rtol=1e-5,
|
||||||
|
err_msg="Height and width should use identical frequency tables",
|
||||||
|
)
|
||||||
|
|
||||||
def test_precomputed_matches_online(self):
|
def test_precomputed_matches_online(self):
|
||||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||||
@@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
d = 128
|
d = 128
|
||||||
freqs = mx.concatenate([
|
freqs = mx.concatenate(
|
||||||
rope_params(1024, d - 4 * (d // 6)),
|
[
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, d - 4 * (d // 6)),
|
||||||
rope_params(1024, 2 * (d // 6)),
|
rope_params(1024, 2 * (d // 6)),
|
||||||
], axis=1)
|
rope_params(1024, 2 * (d // 6)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
B, N = 2, 4
|
B, N = 2, 4
|
||||||
F, H, W = 2, 3, 4
|
F, H, W = 2, 3, 4
|
||||||
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
|
|||||||
mx.eval(out_online, out_precomp)
|
mx.eval(out_online, out_precomp)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(out_online), np.array(out_precomp), atol=1e-5,
|
np.array(out_online),
|
||||||
err_msg="Precomputed and online RoPE should match"
|
np.array(out_precomp),
|
||||||
|
atol=1e-5,
|
||||||
|
err_msg="Precomputed and online RoPE should match",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,14 +6,15 @@ import mlx.core as mx
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Euler Scheduler Tests
|
# Euler Scheduler Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestFlowMatchEulerScheduler:
|
class TestFlowMatchEulerScheduler:
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
assert sched.num_train_timesteps == 1000
|
assert sched.num_train_timesteps == 1000
|
||||||
assert sched.timesteps is None
|
assert sched.timesteps is None
|
||||||
@@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
|
|
||||||
def test_set_timesteps(self):
|
def test_set_timesteps(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(40, shift=12.0)
|
sched.set_timesteps(40, shift=12.0)
|
||||||
mx.eval(sched.timesteps, sched.sigmas)
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
@@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
|
|
||||||
def test_timesteps_decreasing(self):
|
def test_timesteps_decreasing(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(40, shift=12.0)
|
sched.set_timesteps(40, shift=12.0)
|
||||||
mx.eval(sched.timesteps)
|
mx.eval(sched.timesteps)
|
||||||
@@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
|
|
||||||
def test_sigmas_decreasing(self):
|
def test_sigmas_decreasing(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(20, shift=1.0)
|
sched.set_timesteps(20, shift=1.0)
|
||||||
mx.eval(sched.sigmas)
|
mx.eval(sched.sigmas)
|
||||||
@@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
|
|
||||||
def test_terminal_sigma_is_zero(self):
|
def test_terminal_sigma_is_zero(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(20, shift=5.0)
|
sched.set_timesteps(20, shift=5.0)
|
||||||
mx.eval(sched.sigmas)
|
mx.eval(sched.sigmas)
|
||||||
@@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
def test_shift_effect(self):
|
def test_shift_effect(self):
|
||||||
"""Larger shift should push sigmas toward higher values."""
|
"""Larger shift should push sigmas toward higher values."""
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched1 = FlowMatchEulerScheduler()
|
sched1 = FlowMatchEulerScheduler()
|
||||||
sched2 = FlowMatchEulerScheduler()
|
sched2 = FlowMatchEulerScheduler()
|
||||||
sched1.set_timesteps(20, shift=1.0)
|
sched1.set_timesteps(20, shift=1.0)
|
||||||
@@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
|
|
||||||
def test_step_euler(self):
|
def test_step_euler(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(10, shift=1.0)
|
sched.set_timesteps(10, shift=1.0)
|
||||||
mx.eval(sched.sigmas)
|
mx.eval(sched.sigmas)
|
||||||
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
|
|||||||
# Euler: x_next = x + (sigma_next - sigma) * v
|
# Euler: x_next = x + (sigma_next - sigma) * v
|
||||||
expected = 1.0 + (sigma_next - sigma) * 0.5
|
expected = 1.0 + (sigma_next - sigma) * 0.5
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(result).flatten()[0], expected, rtol=1e-4,
|
np.array(result).flatten()[0],
|
||||||
|
expected,
|
||||||
|
rtol=1e-4,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_step_index_increments(self):
|
def test_step_index_increments(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
assert sched._step_index == 0
|
assert sched._step_index == 0
|
||||||
@@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
|
|
||||||
def test_reset(self):
|
def test_reset(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 1, 1, 1, 1))
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
@@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||||
def test_various_step_counts(self, steps):
|
def test_various_step_counts(self, steps):
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(steps, shift=12.0)
|
sched.set_timesteps(steps, shift=12.0)
|
||||||
mx.eval(sched.timesteps, sched.sigmas)
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
@@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler:
|
|||||||
def test_full_denoise_loop(self):
|
def test_full_denoise_loop(self):
|
||||||
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
sched = FlowMatchEulerScheduler()
|
sched = FlowMatchEulerScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 2, 1, 2, 2))
|
sample = mx.ones((1, 2, 1, 2, 2))
|
||||||
@@ -141,22 +154,26 @@ class TestComputeSigmas:
|
|||||||
|
|
||||||
def test_length(self):
|
def test_length(self):
|
||||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
sigmas = _compute_sigmas(20, shift=5.0)
|
sigmas = _compute_sigmas(20, shift=5.0)
|
||||||
assert len(sigmas) == 21 # num_steps + terminal
|
assert len(sigmas) == 21 # num_steps + terminal
|
||||||
|
|
||||||
def test_terminal_zero(self):
|
def test_terminal_zero(self):
|
||||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
sigmas = _compute_sigmas(10, shift=1.0)
|
sigmas = _compute_sigmas(10, shift=1.0)
|
||||||
assert sigmas[-1] == 0.0
|
assert sigmas[-1] == 0.0
|
||||||
|
|
||||||
def test_starts_near_one(self):
|
def test_starts_near_one(self):
|
||||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
sigmas = _compute_sigmas(20, shift=5.0)
|
sigmas = _compute_sigmas(20, shift=5.0)
|
||||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||||
|
|
||||||
def test_decreasing(self):
|
def test_decreasing(self):
|
||||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
sigmas = _compute_sigmas(20, shift=5.0)
|
sigmas = _compute_sigmas(20, shift=5.0)
|
||||||
assert np.all(np.diff(sigmas) <= 0)
|
assert np.all(np.diff(sigmas) <= 0)
|
||||||
|
|
||||||
@@ -169,6 +186,7 @@ class TestComputeSigmas:
|
|||||||
shift is applied only once (single-shift).
|
shift is applied only once (single-shift).
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
steps, shift, N = 50, 5.0, 1000
|
steps, shift, N = 50, 5.0, 1000
|
||||||
sigmas = _compute_sigmas(steps, shift, N)
|
sigmas = _compute_sigmas(steps, shift, N)
|
||||||
# Official single-shift: unshifted bounds, then shift once
|
# Official single-shift: unshifted bounds, then shift once
|
||||||
@@ -183,6 +201,7 @@ class TestComputeSigmas:
|
|||||||
|
|
||||||
def test_shift_one_is_near_linear(self):
|
def test_shift_one_is_near_linear(self):
|
||||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
sigmas = _compute_sigmas(10, shift=1.0)
|
sigmas = _compute_sigmas(10, shift=1.0)
|
||||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||||
# so schedule is nearly linear from ~0.999 to 0
|
# so schedule is nearly linear from ~0.999 to 0
|
||||||
@@ -196,6 +215,7 @@ class TestComputeSigmas:
|
|||||||
FlowMatchEulerScheduler,
|
FlowMatchEulerScheduler,
|
||||||
FlowUniPCScheduler,
|
FlowUniPCScheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheds = [
|
scheds = [
|
||||||
FlowMatchEulerScheduler(1000),
|
FlowMatchEulerScheduler(1000),
|
||||||
FlowDPMPP2MScheduler(1000),
|
FlowDPMPP2MScheduler(1000),
|
||||||
@@ -214,6 +234,7 @@ class TestComputeSigmas:
|
|||||||
FlowMatchEulerScheduler,
|
FlowMatchEulerScheduler,
|
||||||
FlowUniPCScheduler,
|
FlowUniPCScheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheds = [
|
scheds = [
|
||||||
FlowMatchEulerScheduler(1000),
|
FlowMatchEulerScheduler(1000),
|
||||||
FlowDPMPP2MScheduler(1000),
|
FlowDPMPP2MScheduler(1000),
|
||||||
@@ -235,12 +256,14 @@ class TestComputeSigmas:
|
|||||||
class TestFlowDPMPP2MScheduler:
|
class TestFlowDPMPP2MScheduler:
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
assert sched.num_train_timesteps == 1000
|
assert sched.num_train_timesteps == 1000
|
||||||
assert sched.lower_order_final is True
|
assert sched.lower_order_final is True
|
||||||
|
|
||||||
def test_set_timesteps(self):
|
def test_set_timesteps(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(20, shift=5.0)
|
sched.set_timesteps(20, shift=5.0)
|
||||||
mx.eval(sched.timesteps, sched.sigmas)
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
@@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
|
|
||||||
def test_step_index_increments(self):
|
def test_step_index_increments(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 4, 1, 2, 2))
|
sample = mx.ones((1, 4, 1, 2, 2))
|
||||||
@@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
|
|
||||||
def test_reset(self):
|
def test_reset(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 1, 1, 1, 1))
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
@@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
def test_full_loop_finite(self):
|
def test_full_loop_finite(self):
|
||||||
"""Full loop with constant velocity should produce finite output."""
|
"""Full loop with constant velocity should produce finite output."""
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(10, shift=1.0)
|
sched.set_timesteps(10, shift=1.0)
|
||||||
sample = mx.ones((1, 2, 1, 2, 2))
|
sample = mx.ones((1, 2, 1, 2, 2))
|
||||||
@@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
def test_first_step_is_first_order(self):
|
def test_first_step_is_first_order(self):
|
||||||
"""First step should use 1st-order (no prev_x0 available)."""
|
"""First step should use 1st-order (no prev_x0 available)."""
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(10, shift=5.0)
|
sched.set_timesteps(10, shift=5.0)
|
||||||
sample = mx.random.normal((1, 4, 2, 4, 4))
|
sample = mx.random.normal((1, 4, 2, 4, 4))
|
||||||
@@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
def test_second_step_uses_correction(self):
|
def test_second_step_uses_correction(self):
|
||||||
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(10, shift=5.0)
|
sched.set_timesteps(10, shift=5.0)
|
||||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||||
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
x0_after_second = sched._prev_x0
|
x0_after_second = sched._prev_x0
|
||||||
assert x0_after_second is not None
|
assert x0_after_second is not None
|
||||||
# The stored x0 should differ from the first step's
|
# The stored x0 should differ from the first step's
|
||||||
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
|
assert not np.allclose(
|
||||||
|
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
def test_denoise_to_target(self):
|
def test_denoise_to_target(self):
|
||||||
"""Perfect oracle should denoise to target with any solver."""
|
"""Perfect oracle should denoise to target with any solver."""
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(20, shift=5.0)
|
sched.set_timesteps(20, shift=5.0)
|
||||||
target = mx.zeros((1, 2, 1, 4, 4))
|
target = mx.zeros((1, 2, 1, 4, 4))
|
||||||
@@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||||
def test_various_step_counts(self, steps):
|
def test_various_step_counts(self, steps):
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(steps, shift=5.0)
|
sched.set_timesteps(steps, shift=5.0)
|
||||||
mx.eval(sched.timesteps, sched.sigmas)
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
@@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
def test_terminal_sigma_produces_x0(self):
|
def test_terminal_sigma_produces_x0(self):
|
||||||
"""When sigma_next=0 the scheduler should return x0 directly."""
|
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
sched = FlowDPMPP2MScheduler()
|
sched = FlowDPMPP2MScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
||||||
@@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler:
|
|||||||
class TestFlowUniPCScheduler:
|
class TestFlowUniPCScheduler:
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
assert sched.num_train_timesteps == 1000
|
assert sched.num_train_timesteps == 1000
|
||||||
assert sched.solver_order == 2
|
assert sched.solver_order == 2
|
||||||
@@ -369,6 +403,7 @@ class TestFlowUniPCScheduler:
|
|||||||
|
|
||||||
def test_set_timesteps(self):
|
def test_set_timesteps(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
sched.set_timesteps(30, shift=12.0)
|
sched.set_timesteps(30, shift=12.0)
|
||||||
mx.eval(sched.timesteps, sched.sigmas)
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
@@ -377,6 +412,7 @@ class TestFlowUniPCScheduler:
|
|||||||
|
|
||||||
def test_step_index_increments(self):
|
def test_step_index_increments(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 1, 1, 1, 1))
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
@@ -387,6 +423,7 @@ class TestFlowUniPCScheduler:
|
|||||||
|
|
||||||
def test_reset(self):
|
def test_reset(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 1, 1, 1, 1))
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
@@ -399,6 +436,7 @@ class TestFlowUniPCScheduler:
|
|||||||
|
|
||||||
def test_full_loop_finite(self):
|
def test_full_loop_finite(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
sched.set_timesteps(10, shift=1.0)
|
sched.set_timesteps(10, shift=1.0)
|
||||||
sample = mx.ones((1, 2, 1, 2, 2))
|
sample = mx.ones((1, 2, 1, 2, 2))
|
||||||
@@ -411,6 +449,7 @@ class TestFlowUniPCScheduler:
|
|||||||
def test_corrector_not_applied_first_step(self):
|
def test_corrector_not_applied_first_step(self):
|
||||||
"""First step should skip the corrector (no history)."""
|
"""First step should skip the corrector (no history)."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler(use_corrector=True)
|
sched = FlowUniPCScheduler(use_corrector=True)
|
||||||
sched.set_timesteps(10, shift=5.0)
|
sched.set_timesteps(10, shift=5.0)
|
||||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||||
@@ -424,6 +463,7 @@ class TestFlowUniPCScheduler:
|
|||||||
def test_corrector_applied_after_first_step(self):
|
def test_corrector_applied_after_first_step(self):
|
||||||
"""Steps after the first should use the corrector when enabled."""
|
"""Steps after the first should use the corrector when enabled."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler(use_corrector=True)
|
sched = FlowUniPCScheduler(use_corrector=True)
|
||||||
sched.set_timesteps(10, shift=5.0)
|
sched.set_timesteps(10, shift=5.0)
|
||||||
sample = mx.random.normal((1, 2, 1, 4, 4))
|
sample = mx.random.normal((1, 2, 1, 4, 4))
|
||||||
@@ -436,6 +476,7 @@ class TestFlowUniPCScheduler:
|
|||||||
|
|
||||||
def test_denoise_to_target(self):
|
def test_denoise_to_target(self):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
sched.set_timesteps(20, shift=5.0)
|
sched.set_timesteps(20, shift=5.0)
|
||||||
target = mx.zeros((1, 2, 1, 4, 4))
|
target = mx.zeros((1, 2, 1, 4, 4))
|
||||||
@@ -450,6 +491,7 @@ class TestFlowUniPCScheduler:
|
|||||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||||
def test_various_step_counts(self, steps):
|
def test_various_step_counts(self, steps):
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
sched.set_timesteps(steps, shift=5.0)
|
sched.set_timesteps(steps, shift=5.0)
|
||||||
mx.eval(sched.timesteps, sched.sigmas)
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
@@ -459,6 +501,7 @@ class TestFlowUniPCScheduler:
|
|||||||
def test_disable_corrector(self):
|
def test_disable_corrector(self):
|
||||||
"""Disabling corrector on step 0 should still work without error."""
|
"""Disabling corrector on step 0 should still work without error."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||||
sched.set_timesteps(5, shift=1.0)
|
sched.set_timesteps(5, shift=1.0)
|
||||||
sample = mx.ones((1, 1, 1, 2, 2))
|
sample = mx.ones((1, 1, 1, 2, 2))
|
||||||
@@ -471,6 +514,7 @@ class TestFlowUniPCScheduler:
|
|||||||
def test_solver_order_3(self):
|
def test_solver_order_3(self):
|
||||||
"""Order 3 should work without error."""
|
"""Order 3 should work without error."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||||
sched.set_timesteps(10, shift=5.0)
|
sched.set_timesteps(10, shift=5.0)
|
||||||
sample = mx.random.normal((1, 2, 1, 2, 2))
|
sample = mx.random.normal((1, 2, 1, 2, 2))
|
||||||
@@ -483,6 +527,7 @@ class TestFlowUniPCScheduler:
|
|||||||
def test_corrector_rhos_c_not_hardcoded(self):
|
def test_corrector_rhos_c_not_hardcoded(self):
|
||||||
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
||||||
import math
|
import math
|
||||||
|
|
||||||
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
||||||
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||||
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||||
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
|
|||||||
rhos_c = np.linalg.solve(R, b)
|
rhos_c = np.linalg.solve(R, b)
|
||||||
|
|
||||||
# History weight should be small (~0.07-0.09), not 0.5
|
# History weight should be small (~0.07-0.09), not 0.5
|
||||||
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
assert (
|
||||||
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
rhos_c[0] < 0.15
|
||||||
|
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||||
|
assert (
|
||||||
|
rhos_c[0] > 0.0
|
||||||
|
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||||
# D1_t weight should be ~0.42-0.45, not 0.5
|
# D1_t weight should be ~0.42-0.45, not 0.5
|
||||||
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
assert (
|
||||||
|
0.3 < rhos_c[1] < 0.5
|
||||||
|
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Scheduler Coherence Tests
|
# Scheduler Coherence Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestSchedulerCoherence:
|
class TestSchedulerCoherence:
|
||||||
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
||||||
|
|
||||||
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
|
|||||||
results[name] = np.array(r)
|
results[name] = np.array(r)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
results["dpm++"], results["euler"], atol=1e-5,
|
results["dpm++"],
|
||||||
|
results["euler"],
|
||||||
|
atol=1e-5,
|
||||||
err_msg="DPM++ step 0 should match Euler",
|
err_msg="DPM++ step 0 should match Euler",
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
results["unipc"], results["euler"], atol=1e-5,
|
results["unipc"],
|
||||||
|
results["euler"],
|
||||||
|
atol=1e-5,
|
||||||
err_msg="UniPC step 0 should match Euler",
|
err_msg="UniPC step 0 should match Euler",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
|
|||||||
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
||||||
mx.eval(euler_r, dpm_r, unipc_r)
|
mx.eval(euler_r, dpm_r, unipc_r)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(dpm_r), np.array(euler_r), atol=1e-5,
|
np.array(dpm_r),
|
||||||
|
np.array(euler_r),
|
||||||
|
atol=1e-5,
|
||||||
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(unipc_r), np.array(euler_r), atol=1e-5,
|
np.array(unipc_r),
|
||||||
|
np.array(euler_r),
|
||||||
|
atol=1e-5,
|
||||||
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
|
|||||||
latents = sched.step(v, sched.timesteps[i], latents)
|
latents = sched.step(v, sched.timesteps[i], latents)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(latents), 0.0, atol=1e-3,
|
np.array(latents),
|
||||||
|
0.0,
|
||||||
|
atol=1e-3,
|
||||||
err_msg=f"{name} did not converge to target with oracle",
|
err_msg=f"{name} did not converge to target with oracle",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
|
|||||||
# Higher-order solvers should not be significantly worse than Euler
|
# Higher-order solvers should not be significantly worse than Euler
|
||||||
# (add small epsilon to handle near-zero errors from floating point noise)
|
# (add small epsilon to handle near-zero errors from floating point noise)
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
|
assert (
|
||||||
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
errors["dpm++"] <= errors["euler"] * 1.5 + eps
|
||||||
)
|
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||||
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
|
assert (
|
||||||
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
errors["unipc"] <= errors["euler"] * 1.5 + eps
|
||||||
)
|
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||||
|
|
||||||
def test_multistep_trajectory_similar_magnitude(self):
|
def test_multistep_trajectory_similar_magnitude(self):
|
||||||
"""Over a full denoising loop with constant velocity, all solvers
|
"""Over a full denoising loop with constant velocity, all solvers
|
||||||
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
|
|||||||
# All solvers should produce results within the same order of magnitude
|
# All solvers should produce results within the same order of magnitude
|
||||||
vals = list(final_means.values())
|
vals = list(final_means.values())
|
||||||
ratio = max(vals) / max(min(vals), 1e-10)
|
ratio = max(vals) / max(min(vals), 1e-10)
|
||||||
assert ratio < 10.0, (
|
assert (
|
||||||
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
ratio < 10.0
|
||||||
)
|
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||||
|
|
||||||
def test_intermediate_values_finite(self):
|
def test_intermediate_values_finite(self):
|
||||||
"""Every intermediate latent value must be finite for all solvers."""
|
"""Every intermediate latent value must be finite for all solvers."""
|
||||||
@@ -712,9 +774,9 @@ class TestSchedulerCoherence:
|
|||||||
vel = mx.random.normal(shape)
|
vel = mx.random.normal(shape)
|
||||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
assert np.isfinite(np.array(latents)).all(), (
|
assert np.isfinite(
|
||||||
f"{name} produced non-finite values at step {i}"
|
np.array(latents)
|
||||||
)
|
).all(), f"{name} produced non-finite values at step {i}"
|
||||||
|
|
||||||
def test_lambda_boundary_values(self):
|
def test_lambda_boundary_values(self):
|
||||||
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||||
@@ -724,17 +786,17 @@ class TestSchedulerCoherence:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||||
assert cls._lambda(1.0) == -math.inf, (
|
assert (
|
||||||
f"{cls.__name__}._lambda(1.0) should be -inf"
|
cls._lambda(1.0) == -math.inf
|
||||||
)
|
), f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||||
assert cls._lambda(0.0) == math.inf, (
|
assert (
|
||||||
f"{cls.__name__}._lambda(0.0) should be +inf"
|
cls._lambda(0.0) == math.inf
|
||||||
)
|
), f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||||
# Interior values should be finite
|
# Interior values should be finite
|
||||||
lam = cls._lambda(0.5)
|
lam = cls._lambda(0.5)
|
||||||
assert math.isfinite(lam) and lam == 0.0, (
|
assert (
|
||||||
f"{cls.__name__}._lambda(0.5) should be 0.0"
|
math.isfinite(lam) and lam == 0.0
|
||||||
)
|
), f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||||
|
|
||||||
def test_lambda_monotonically_decreasing(self):
|
def test_lambda_monotonically_decreasing(self):
|
||||||
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||||
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
|
|||||||
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||||
mx.eval(result)
|
mx.eval(result)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(result), np.array(expected), atol=5e-4,
|
np.array(result),
|
||||||
|
np.array(expected),
|
||||||
|
atol=5e-4,
|
||||||
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
|
|||||||
results[name] = np.array(r)
|
results[name] = np.array(r)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
results["dpm++"], results["euler"], atol=1e-5,
|
results["dpm++"],
|
||||||
|
results["euler"],
|
||||||
|
atol=1e-5,
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
results["unipc"], results["euler"], atol=1e-5,
|
results["unipc"],
|
||||||
|
results["euler"],
|
||||||
|
atol=1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dpmpp_unipc_agree_on_step1(self):
|
def test_dpmpp_unipc_agree_on_step1(self):
|
||||||
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
|
|||||||
shape = (1, 2, 1, 2, 2)
|
shape = (1, 2, 1, 2, 2)
|
||||||
noise = mx.random.normal(shape)
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import (
|
||||||
|
FlowDPMPP2MScheduler,
|
||||||
|
FlowUniPCScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||||
sched = cls()
|
sched = cls()
|
||||||
@@ -857,14 +928,19 @@ class TestSchedulerCoherence:
|
|||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
result2 = np.array(latents)
|
result2 = np.array(latents)
|
||||||
|
|
||||||
np.testing.assert_allclose(result1, result2, atol=1e-5,
|
np.testing.assert_allclose(
|
||||||
err_msg=f"{cls.__name__} not reproducible after reset()")
|
result1,
|
||||||
|
result2,
|
||||||
|
atol=1e-5,
|
||||||
|
err_msg=f"{cls.__name__} not reproducible after reset()",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# UniPC Corrector Default Tests
|
# UniPC Corrector Default Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestUniPCCorrectorDefault:
|
class TestUniPCCorrectorDefault:
|
||||||
"""Tests that the UniPC corrector is enabled by default,
|
"""Tests that the UniPC corrector is enabled by default,
|
||||||
matching official FlowUniPCMultistepScheduler behavior."""
|
matching official FlowUniPCMultistepScheduler behavior."""
|
||||||
@@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault:
|
|||||||
def test_corrector_enabled_by_default(self):
|
def test_corrector_enabled_by_default(self):
|
||||||
"""Default construction should have corrector enabled."""
|
"""Default construction should have corrector enabled."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
sched = FlowUniPCScheduler()
|
sched = FlowUniPCScheduler()
|
||||||
assert sched._use_corrector is True
|
assert sched._use_corrector is True
|
||||||
|
|
||||||
def test_corrector_affects_output(self):
|
def test_corrector_affects_output(self):
|
||||||
"""Corrector should produce different results than no corrector after step 1."""
|
"""Corrector should produce different results than no corrector after step 1."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
shape = (1, 4, 1, 4, 4)
|
shape = (1, 4, 1, 4, 4)
|
||||||
noise = mx.random.normal(shape)
|
noise = mx.random.normal(shape)
|
||||||
@@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault:
|
|||||||
def test_corrector_does_not_affect_first_step(self):
|
def test_corrector_does_not_affect_first_step(self):
|
||||||
"""Step 0 should be identical regardless of corrector setting."""
|
"""Step 0 should be identical regardless of corrector setting."""
|
||||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
|
||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
shape = (1, 4, 1, 4, 4)
|
shape = (1, 4, 1, 4, 4)
|
||||||
noise = mx.random.normal(shape)
|
noise = mx.random.normal(shape)
|
||||||
|
|||||||
@@ -3,16 +3,16 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# T5 Encoder Tests
|
# T5 Encoder Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestT5LayerNorm:
|
class TestT5LayerNorm:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||||
|
|
||||||
norm = T5LayerNorm(64)
|
norm = T5LayerNorm(64)
|
||||||
x = mx.random.normal((2, 10, 64))
|
x = mx.random.normal((2, 10, 64))
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -22,6 +22,7 @@ class TestT5LayerNorm:
|
|||||||
def test_rms_normalization(self):
|
def test_rms_normalization(self):
|
||||||
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
||||||
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||||
|
|
||||||
norm = T5LayerNorm(128)
|
norm = T5LayerNorm(128)
|
||||||
x = mx.random.normal((1, 5, 128)) * 5.0
|
x = mx.random.normal((1, 5, 128)) * 5.0
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -35,6 +36,7 @@ class TestT5LayerNorm:
|
|||||||
class TestT5RelativeEmbedding:
|
class TestT5RelativeEmbedding:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||||
|
|
||||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||||
out = rel_emb(10, 10)
|
out = rel_emb(10, 10)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -42,6 +44,7 @@ class TestT5RelativeEmbedding:
|
|||||||
|
|
||||||
def test_asymmetric_lengths(self):
|
def test_asymmetric_lengths(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||||
|
|
||||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||||
out = rel_emb(8, 12)
|
out = rel_emb(8, 12)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -50,6 +53,7 @@ class TestT5RelativeEmbedding:
|
|||||||
def test_symmetry(self):
|
def test_symmetry(self):
|
||||||
"""Position bias should have structure (not all zeros/random)."""
|
"""Position bias should have structure (not all zeros/random)."""
|
||||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||||
|
|
||||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
||||||
out = rel_emb(6, 6)
|
out = rel_emb(6, 6)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -64,6 +68,7 @@ class TestT5RelativeEmbedding:
|
|||||||
class TestT5Attention:
|
class TestT5Attention:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Attention
|
from mlx_video.models.wan.text_encoder import T5Attention
|
||||||
|
|
||||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
x = mx.random.normal((1, 10, 64))
|
x = mx.random.normal((1, 10, 64))
|
||||||
out = attn(x)
|
out = attn(x)
|
||||||
@@ -73,12 +78,14 @@ class TestT5Attention:
|
|||||||
def test_no_scaling(self):
|
def test_no_scaling(self):
|
||||||
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
||||||
from mlx_video.models.wan.text_encoder import T5Attention
|
from mlx_video.models.wan.text_encoder import T5Attention
|
||||||
|
|
||||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
# No scale attribute (unlike standard attention)
|
# No scale attribute (unlike standard attention)
|
||||||
assert not hasattr(attn, "scale")
|
assert not hasattr(attn, "scale")
|
||||||
|
|
||||||
def test_with_position_bias(self):
|
def test_with_position_bias(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
|
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
|
||||||
|
|
||||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
rel_emb = T5RelativeEmbedding(32, 4)
|
rel_emb = T5RelativeEmbedding(32, 4)
|
||||||
x = mx.random.normal((1, 10, 64))
|
x = mx.random.normal((1, 10, 64))
|
||||||
@@ -89,6 +96,7 @@ class TestT5Attention:
|
|||||||
|
|
||||||
def test_with_mask(self):
|
def test_with_mask(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Attention
|
from mlx_video.models.wan.text_encoder import T5Attention
|
||||||
|
|
||||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
x = mx.random.normal((1, 10, 64))
|
x = mx.random.normal((1, 10, 64))
|
||||||
mask = mx.ones((1, 10))
|
mask = mx.ones((1, 10))
|
||||||
@@ -101,6 +109,7 @@ class TestT5Attention:
|
|||||||
class TestT5FeedForward:
|
class TestT5FeedForward:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5FeedForward
|
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||||
|
|
||||||
ffn = T5FeedForward(64, 256)
|
ffn = T5FeedForward(64, 256)
|
||||||
x = mx.random.normal((1, 10, 64))
|
x = mx.random.normal((1, 10, 64))
|
||||||
out = ffn(x)
|
out = ffn(x)
|
||||||
@@ -110,6 +119,7 @@ class TestT5FeedForward:
|
|||||||
def test_gated_structure(self):
|
def test_gated_structure(self):
|
||||||
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
||||||
from mlx_video.models.wan.text_encoder import T5FeedForward
|
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||||
|
|
||||||
ffn = T5FeedForward(32, 64)
|
ffn = T5FeedForward(32, 64)
|
||||||
assert hasattr(ffn, "gate_proj")
|
assert hasattr(ffn, "gate_proj")
|
||||||
assert hasattr(ffn, "fc1")
|
assert hasattr(ffn, "fc1")
|
||||||
@@ -122,9 +132,16 @@ class TestT5Encoder:
|
|||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
|
||||||
encoder = T5Encoder(
|
encoder = T5Encoder(
|
||||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
vocab_size=100,
|
||||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
dim=64,
|
||||||
|
dim_attn=64,
|
||||||
|
dim_ffn=128,
|
||||||
|
num_heads=4,
|
||||||
|
num_layers=2,
|
||||||
|
num_buckets=32,
|
||||||
|
shared_pos=False,
|
||||||
)
|
)
|
||||||
ids = mx.array([[1, 5, 10, 0, 0]])
|
ids = mx.array([[1, 5, 10, 0, 0]])
|
||||||
mask = mx.array([[1, 1, 1, 0, 0]])
|
mask = mx.array([[1, 1, 1, 0, 0]])
|
||||||
@@ -134,9 +151,16 @@ class TestT5Encoder:
|
|||||||
|
|
||||||
def test_shared_pos(self):
|
def test_shared_pos(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
|
||||||
encoder = T5Encoder(
|
encoder = T5Encoder(
|
||||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
vocab_size=100,
|
||||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
|
dim=64,
|
||||||
|
dim_attn=64,
|
||||||
|
dim_ffn=128,
|
||||||
|
num_heads=4,
|
||||||
|
num_layers=2,
|
||||||
|
num_buckets=32,
|
||||||
|
shared_pos=True,
|
||||||
)
|
)
|
||||||
assert encoder.pos_embedding is not None
|
assert encoder.pos_embedding is not None
|
||||||
for block in encoder.blocks:
|
for block in encoder.blocks:
|
||||||
@@ -144,9 +168,16 @@ class TestT5Encoder:
|
|||||||
|
|
||||||
def test_per_layer_pos(self):
|
def test_per_layer_pos(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
|
||||||
encoder = T5Encoder(
|
encoder = T5Encoder(
|
||||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
vocab_size=100,
|
||||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
dim=64,
|
||||||
|
dim_attn=64,
|
||||||
|
dim_ffn=128,
|
||||||
|
num_heads=4,
|
||||||
|
num_layers=2,
|
||||||
|
num_buckets=32,
|
||||||
|
shared_pos=False,
|
||||||
)
|
)
|
||||||
assert encoder.pos_embedding is None
|
assert encoder.pos_embedding is None
|
||||||
for block in encoder.blocks:
|
for block in encoder.blocks:
|
||||||
@@ -154,18 +185,32 @@ class TestT5Encoder:
|
|||||||
|
|
||||||
def test_param_count(self):
|
def test_param_count(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
|
||||||
encoder = T5Encoder(
|
encoder = T5Encoder(
|
||||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
vocab_size=100,
|
||||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
dim=64,
|
||||||
|
dim_attn=64,
|
||||||
|
dim_ffn=128,
|
||||||
|
num_heads=4,
|
||||||
|
num_layers=2,
|
||||||
|
num_buckets=32,
|
||||||
|
shared_pos=False,
|
||||||
)
|
)
|
||||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
|
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
|
||||||
assert num_params > 0
|
assert num_params > 0
|
||||||
|
|
||||||
def test_without_mask(self):
|
def test_without_mask(self):
|
||||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
|
||||||
encoder = T5Encoder(
|
encoder = T5Encoder(
|
||||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
vocab_size=100,
|
||||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
dim=64,
|
||||||
|
dim_attn=64,
|
||||||
|
dim_ffn=128,
|
||||||
|
num_heads=4,
|
||||||
|
num_layers=2,
|
||||||
|
num_buckets=32,
|
||||||
|
shared_pos=False,
|
||||||
)
|
)
|
||||||
ids = mx.array([[1, 5, 10]])
|
ids = mx.array([[1, 5, 10]])
|
||||||
out = encoder(ids)
|
out = encoder(ids)
|
||||||
|
|||||||
@@ -2,13 +2,11 @@
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.tiling import (
|
from mlx_video.models.ltx.video_vae.tiling import (
|
||||||
TilingConfig,
|
TilingConfig,
|
||||||
decode_with_tiling,
|
decode_with_tiling,
|
||||||
split_in_spatial,
|
split_in_spatial,
|
||||||
split_in_temporal,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,16 +47,24 @@ class TestNonCausalTemporal:
|
|||||||
|
|
||||||
# Causal: 1 + (4-1)*4 = 13
|
# Causal: 1 + (4-1)*4 = 13
|
||||||
out_causal = decode_with_tiling(
|
out_causal = decode_with_tiling(
|
||||||
dummy_decoder_causal, latents, config,
|
dummy_decoder_causal,
|
||||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
|
latents,
|
||||||
|
config,
|
||||||
|
spatial_scale=scale,
|
||||||
|
temporal_scale=scale,
|
||||||
|
causal_temporal=True,
|
||||||
)
|
)
|
||||||
mx.eval(out_causal)
|
mx.eval(out_causal)
|
||||||
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
||||||
|
|
||||||
# Non-causal: 4*4 = 16
|
# Non-causal: 4*4 = 16
|
||||||
out_noncausal = decode_with_tiling(
|
out_noncausal = decode_with_tiling(
|
||||||
dummy_decoder_noncausal, latents, config,
|
dummy_decoder_noncausal,
|
||||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
|
latents,
|
||||||
|
config,
|
||||||
|
spatial_scale=scale,
|
||||||
|
temporal_scale=scale,
|
||||||
|
causal_temporal=False,
|
||||||
)
|
)
|
||||||
mx.eval(out_noncausal)
|
mx.eval(out_noncausal)
|
||||||
assert out_noncausal.shape[2] == t * scale # 16
|
assert out_noncausal.shape[2] == t * scale # 16
|
||||||
@@ -100,9 +106,9 @@ class TestWan22TiledDecoding:
|
|||||||
mx.eval(out_tiled)
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
# Both should produce the same shape
|
# Both should produce the same shape
|
||||||
assert out_regular.shape == out_tiled.shape, (
|
assert (
|
||||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
out_regular.shape == out_tiled.shape
|
||||||
)
|
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||||
|
|
||||||
def test_decode_tiled_falls_through_when_small(self):
|
def test_decode_tiled_falls_through_when_small(self):
|
||||||
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
||||||
@@ -120,8 +126,10 @@ class TestWan22TiledDecoding:
|
|||||||
mx.eval(out_tiled)
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(out_regular), np.array(out_tiled),
|
np.array(out_regular),
|
||||||
rtol=1e-4, atol=1e-4,
|
np.array(out_tiled),
|
||||||
|
rtol=1e-4,
|
||||||
|
atol=1e-4,
|
||||||
err_msg="Tiled decode should match regular decode for small inputs",
|
err_msg="Tiled decode should match regular decode for small inputs",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -152,9 +160,9 @@ class TestWan21TiledDecoding:
|
|||||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||||
mx.eval(out_tiled)
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
assert out_regular.shape == out_tiled.shape, (
|
assert (
|
||||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
out_regular.shape == out_tiled.shape
|
||||||
)
|
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||||
|
|
||||||
def test_decode_tiled_falls_through_when_small(self):
|
def test_decode_tiled_falls_through_when_small(self):
|
||||||
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
||||||
@@ -171,8 +179,10 @@ class TestWan21TiledDecoding:
|
|||||||
mx.eval(out_tiled)
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
np.array(out_regular), np.array(out_tiled),
|
np.array(out_regular),
|
||||||
rtol=1e-4, atol=1e-4,
|
np.array(out_tiled),
|
||||||
|
rtol=1e-4,
|
||||||
|
atol=1e-4,
|
||||||
err_msg="Tiled decode should match regular decode for small inputs",
|
err_msg="Tiled decode should match regular decode for small inputs",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -185,8 +195,13 @@ class TestWan21TemporalScale:
|
|||||||
from mlx_video.models.wan.vae import Decoder3d
|
from mlx_video.models.wan.vae import Decoder3d
|
||||||
|
|
||||||
# Small decoder for fast test
|
# Small decoder for fast test
|
||||||
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
|
dec = Decoder3d(
|
||||||
temporal_upsample=[True, True, False])
|
dim=16,
|
||||||
|
z_dim=4,
|
||||||
|
dim_mult=[1, 1, 1, 1],
|
||||||
|
num_res_blocks=1,
|
||||||
|
temporal_upsample=[True, True, False],
|
||||||
|
)
|
||||||
mx.eval(dec.parameters())
|
mx.eval(dec.parameters())
|
||||||
|
|
||||||
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
||||||
|
|||||||
@@ -2,16 +2,16 @@
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Transformer Block Tests
|
# Transformer Block Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestWanFFN:
|
class TestWanFFN:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.transformer import WanFFN
|
from mlx_video.models.wan.transformer import WanFFN
|
||||||
|
|
||||||
ffn = WanFFN(64, 256)
|
ffn = WanFFN(64, 256)
|
||||||
x = mx.random.normal((2, 10, 64))
|
x = mx.random.normal((2, 10, 64))
|
||||||
out = ffn(x)
|
out = ffn(x)
|
||||||
@@ -21,6 +21,7 @@ class TestWanFFN:
|
|||||||
def test_gelu_activation(self):
|
def test_gelu_activation(self):
|
||||||
"""FFN should use GELU activation (non-linearity)."""
|
"""FFN should use GELU activation (non-linearity)."""
|
||||||
from mlx_video.models.wan.transformer import WanFFN
|
from mlx_video.models.wan.transformer import WanFFN
|
||||||
|
|
||||||
ffn = WanFFN(32, 128)
|
ffn = WanFFN(32, 128)
|
||||||
x = mx.ones((1, 1, 32)) * 2.0
|
x = mx.ones((1, 1, 32)) * 2.0
|
||||||
out1 = ffn(x)
|
out1 = ffn(x)
|
||||||
@@ -39,10 +40,13 @@ class TestWanAttentionBlock:
|
|||||||
self.num_heads = 4
|
self.num_heads = 4
|
||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(
|
block = WanAttentionBlock(
|
||||||
self.dim, self.ffn_dim, self.num_heads,
|
self.dim,
|
||||||
|
self.ffn_dim,
|
||||||
|
self.num_heads,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
)
|
)
|
||||||
B, L = 1, 24
|
B, L = 1, 24
|
||||||
@@ -53,37 +57,49 @@ class TestWanAttentionBlock:
|
|||||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
|
||||||
out = block(
|
out = block(
|
||||||
x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
|
x,
|
||||||
freqs=freqs, context=context,
|
e,
|
||||||
|
seq_lens=[L],
|
||||||
|
grid_sizes=[(F, H, W)],
|
||||||
|
freqs=freqs,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
assert out.shape == (B, L, self.dim)
|
assert out.shape == (B, L, self.dim)
|
||||||
|
|
||||||
def test_modulation_shape(self):
|
def test_modulation_shape(self):
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||||
assert block.modulation.shape == (1, 6, self.dim)
|
assert block.modulation.shape == (1, 6, self.dim)
|
||||||
|
|
||||||
def test_with_cross_attn_norm(self):
|
def test_with_cross_attn_norm(self):
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(
|
block = WanAttentionBlock(
|
||||||
self.dim, self.ffn_dim, self.num_heads,
|
self.dim,
|
||||||
|
self.ffn_dim,
|
||||||
|
self.num_heads,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
)
|
)
|
||||||
assert block.norm3 is not None
|
assert block.norm3 is not None
|
||||||
|
|
||||||
def test_without_cross_attn_norm(self):
|
def test_without_cross_attn_norm(self):
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(
|
block = WanAttentionBlock(
|
||||||
self.dim, self.ffn_dim, self.num_heads,
|
self.dim,
|
||||||
|
self.ffn_dim,
|
||||||
|
self.num_heads,
|
||||||
cross_attn_norm=False,
|
cross_attn_norm=False,
|
||||||
)
|
)
|
||||||
assert block.norm3 is None
|
assert block.norm3 is None
|
||||||
|
|
||||||
def test_residual_connection(self):
|
def test_residual_connection(self):
|
||||||
"""Output should differ from zero even with small random init."""
|
"""Output should differ from zero even with small random init."""
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||||
B, L = 1, 8
|
B, L = 1, 8
|
||||||
F, H, W = 2, 2, 2
|
F, H, W = 2, 2, 2
|
||||||
@@ -102,6 +118,7 @@ class TestWanAttentionBlock:
|
|||||||
# Float32 Modulation Precision Tests
|
# Float32 Modulation Precision Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestFloat32Modulation:
|
class TestFloat32Modulation:
|
||||||
"""Tests that modulation/gate operations are computed in float32,
|
"""Tests that modulation/gate operations are computed in float32,
|
||||||
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
|
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
|
||||||
@@ -113,13 +130,15 @@ class TestFloat32Modulation:
|
|||||||
def test_block_modulation_in_float32(self):
|
def test_block_modulation_in_float32(self):
|
||||||
"""Modulation param starts random but should be usable as float32."""
|
"""Modulation param starts random but should be usable as float32."""
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
||||||
assert block.modulation.dtype == mx.float32
|
assert block.modulation.dtype == mx.float32
|
||||||
|
|
||||||
def test_block_output_float32_with_bf16_modulation_input(self):
|
def test_block_output_float32_with_bf16_modulation_input(self):
|
||||||
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
||||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
|
||||||
from mlx_video.models.wan.rope import rope_params
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
|
||||||
block = WanAttentionBlock(self.dim, 128, 4)
|
block = WanAttentionBlock(self.dim, 128, 4)
|
||||||
B, L = 1, 8
|
B, L = 1, 8
|
||||||
x = mx.random.normal((B, L, self.dim))
|
x = mx.random.normal((B, L, self.dim))
|
||||||
@@ -135,6 +154,7 @@ class TestFloat32Modulation:
|
|||||||
def test_head_modulation_float32(self):
|
def test_head_modulation_float32(self):
|
||||||
"""Head modulation should be float32 even with bf16 e input."""
|
"""Head modulation should be float32 even with bf16 e input."""
|
||||||
from mlx_video.models.wan.model import Head
|
from mlx_video.models.wan.model import Head
|
||||||
|
|
||||||
head = Head(self.dim, 4, (1, 2, 2))
|
head = Head(self.dim, 4, (1, 2, 2))
|
||||||
x = mx.random.normal((1, 8, self.dim))
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
|
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
|
||||||
@@ -145,6 +165,7 @@ class TestFloat32Modulation:
|
|||||||
def test_model_time_embedding_float32(self):
|
def test_model_time_embedding_float32(self):
|
||||||
"""sinusoidal_embedding_1d output must be float32."""
|
"""sinusoidal_embedding_1d output must be float32."""
|
||||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
t = mx.array([500.0])
|
t = mx.array([500.0])
|
||||||
emb = sinusoidal_embedding_1d(256, t)
|
emb = sinusoidal_embedding_1d(256, t)
|
||||||
mx.eval(emb)
|
mx.eval(emb)
|
||||||
@@ -153,6 +174,7 @@ class TestFloat32Modulation:
|
|||||||
def test_model_per_token_time_embedding_float32(self):
|
def test_model_per_token_time_embedding_float32(self):
|
||||||
"""Per-token time embeddings (I2V) should also be float32."""
|
"""Per-token time embeddings (I2V) should also be float32."""
|
||||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||||
emb = sinusoidal_embedding_1d(256, t)
|
emb = sinusoidal_embedding_1d(256, t)
|
||||||
mx.eval(emb)
|
mx.eval(emb)
|
||||||
|
|||||||
@@ -4,16 +4,16 @@ import math
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# VAE 2.1 Tests
|
# VAE 2.1 Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestCausalConv3d:
|
class TestCausalConv3d:
|
||||||
def test_output_shape_stride1(self):
|
def test_output_shape_stride1(self):
|
||||||
from mlx_video.models.wan.vae import CausalConv3d
|
from mlx_video.models.wan.vae import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||||
@@ -29,6 +29,7 @@ class TestCausalConv3d:
|
|||||||
|
|
||||||
def test_output_shape_kernel1(self):
|
def test_output_shape_kernel1(self):
|
||||||
from mlx_video.models.wan.vae import CausalConv3d
|
from mlx_video.models.wan.vae import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||||
x = mx.random.normal((1, 4, 2, 4, 4))
|
x = mx.random.normal((1, 4, 2, 4, 4))
|
||||||
@@ -39,6 +40,7 @@ class TestCausalConv3d:
|
|||||||
def test_causal_padding(self):
|
def test_causal_padding(self):
|
||||||
"""Causal conv should only use past/current frames, not future."""
|
"""Causal conv should only use past/current frames, not future."""
|
||||||
from mlx_video.models.wan.vae import CausalConv3d
|
from mlx_video.models.wan.vae import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||||
conv.bias = mx.zeros((2,))
|
conv.bias = mx.zeros((2,))
|
||||||
@@ -55,6 +57,7 @@ class TestCausalConv3d:
|
|||||||
class TestResidualBlock:
|
class TestResidualBlock:
|
||||||
def test_same_dim(self):
|
def test_same_dim(self):
|
||||||
from mlx_video.models.wan.vae import ResidualBlock
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 8)
|
block = ResidualBlock(8, 8)
|
||||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -63,6 +66,7 @@ class TestResidualBlock:
|
|||||||
|
|
||||||
def test_different_dim(self):
|
def test_different_dim(self):
|
||||||
from mlx_video.models.wan.vae import ResidualBlock
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 16)
|
block = ResidualBlock(8, 16)
|
||||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -71,11 +75,13 @@ class TestResidualBlock:
|
|||||||
|
|
||||||
def test_shortcut_exists_when_dims_differ(self):
|
def test_shortcut_exists_when_dims_differ(self):
|
||||||
from mlx_video.models.wan.vae import ResidualBlock
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 16)
|
block = ResidualBlock(8, 16)
|
||||||
assert block.shortcut is not None
|
assert block.shortcut is not None
|
||||||
|
|
||||||
def test_no_shortcut_when_dims_same(self):
|
def test_no_shortcut_when_dims_same(self):
|
||||||
from mlx_video.models.wan.vae import ResidualBlock
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 8)
|
block = ResidualBlock(8, 8)
|
||||||
assert block.shortcut is None
|
assert block.shortcut is None
|
||||||
|
|
||||||
@@ -83,6 +89,7 @@ class TestResidualBlock:
|
|||||||
class TestAttentionBlock:
|
class TestAttentionBlock:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.vae import AttentionBlock
|
from mlx_video.models.wan.vae import AttentionBlock
|
||||||
|
|
||||||
block = AttentionBlock(8)
|
block = AttentionBlock(8)
|
||||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -91,6 +98,7 @@ class TestAttentionBlock:
|
|||||||
|
|
||||||
def test_residual_connection(self):
|
def test_residual_connection(self):
|
||||||
from mlx_video.models.wan.vae import AttentionBlock
|
from mlx_video.models.wan.vae import AttentionBlock
|
||||||
|
|
||||||
block = AttentionBlock(8)
|
block = AttentionBlock(8)
|
||||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -102,13 +110,15 @@ class TestAttentionBlock:
|
|||||||
class TestWanVAE:
|
class TestWanVAE:
|
||||||
def test_instantiation(self):
|
def test_instantiation(self):
|
||||||
from mlx_video.models.wan.vae import WanVAE
|
from mlx_video.models.wan.vae import WanVAE
|
||||||
|
|
||||||
vae = WanVAE(z_dim=16)
|
vae = WanVAE(z_dim=16)
|
||||||
assert vae.z_dim == 16
|
assert vae.z_dim == 16
|
||||||
assert vae.mean.shape == (16,)
|
assert vae.mean.shape == (16,)
|
||||||
assert vae.std.shape == (16,)
|
assert vae.std.shape == (16,)
|
||||||
|
|
||||||
def test_normalization_stats(self):
|
def test_normalization_stats(self):
|
||||||
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
|
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
|
||||||
|
|
||||||
assert len(VAE_MEAN) == 16
|
assert len(VAE_MEAN) == 16
|
||||||
assert len(VAE_STD) == 16
|
assert len(VAE_STD) == 16
|
||||||
assert all(s > 0 for s in VAE_STD)
|
assert all(s > 0 for s in VAE_STD)
|
||||||
@@ -124,6 +134,7 @@ class TestVAE22CausalConv3d:
|
|||||||
|
|
||||||
def test_output_shape_k3(self):
|
def test_output_shape_k3(self):
|
||||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||||
out = conv(x)
|
out = conv(x)
|
||||||
@@ -132,6 +143,7 @@ class TestVAE22CausalConv3d:
|
|||||||
|
|
||||||
def test_output_shape_k1(self):
|
def test_output_shape_k1(self):
|
||||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = conv(x)
|
out = conv(x)
|
||||||
@@ -141,6 +153,7 @@ class TestVAE22CausalConv3d:
|
|||||||
def test_temporal_causal(self):
|
def test_temporal_causal(self):
|
||||||
"""Output at t=0 should not depend on t>0."""
|
"""Output at t=0 should not depend on t>0."""
|
||||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||||
conv.bias = mx.zeros(conv.bias.shape)
|
conv.bias = mx.zeros(conv.bias.shape)
|
||||||
@@ -151,10 +164,13 @@ class TestVAE22CausalConv3d:
|
|||||||
t0_ref = np.array(out_zero[0, 0])
|
t0_ref = np.array(out_zero[0, 0])
|
||||||
|
|
||||||
# Modify t=2..3; output at t=0 should be unchanged
|
# Modify t=2..3; output at t=0 should be unchanged
|
||||||
x_mod = mx.concatenate([
|
x_mod = mx.concatenate(
|
||||||
x[:, :2],
|
[
|
||||||
mx.ones((1, 2, 4, 4, 2)),
|
x[:, :2],
|
||||||
], axis=1)
|
mx.ones((1, 2, 4, 4, 2)),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
out_mod = conv(x_mod)
|
out_mod = conv(x_mod)
|
||||||
mx.eval(out_mod)
|
mx.eval(out_mod)
|
||||||
t0_mod = np.array(out_mod[0, 0])
|
t0_mod = np.array(out_mod[0, 0])
|
||||||
@@ -163,6 +179,7 @@ class TestVAE22CausalConv3d:
|
|||||||
def test_channels_last_format(self):
|
def test_channels_last_format(self):
|
||||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
|
||||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||||
out = conv(x)
|
out = conv(x)
|
||||||
@@ -175,6 +192,7 @@ class TestRMSNorm:
|
|||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import RMS_norm
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
|
||||||
norm = RMS_norm(16)
|
norm = RMS_norm(16)
|
||||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||||
out = norm(x)
|
out = norm(x)
|
||||||
@@ -184,6 +202,7 @@ class TestRMSNorm:
|
|||||||
def test_l2_normalization(self):
|
def test_l2_normalization(self):
|
||||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||||
from mlx_video.models.wan.vae22 import RMS_norm
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
|
||||||
dim = 32
|
dim = 32
|
||||||
norm = RMS_norm(dim)
|
norm = RMS_norm(dim)
|
||||||
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
||||||
@@ -197,6 +216,7 @@ class TestRMSNorm:
|
|||||||
def test_scale_invariant(self):
|
def test_scale_invariant(self):
|
||||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||||
from mlx_video.models.wan.vae22 import RMS_norm
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
|
||||||
norm = RMS_norm(8)
|
norm = RMS_norm(8)
|
||||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||||
out1 = norm(x)
|
out1 = norm(x)
|
||||||
@@ -207,6 +227,7 @@ class TestRMSNorm:
|
|||||||
def test_gamma_effect(self):
|
def test_gamma_effect(self):
|
||||||
"""Non-unit gamma should scale output."""
|
"""Non-unit gamma should scale output."""
|
||||||
from mlx_video.models.wan.vae22 import RMS_norm
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
|
||||||
norm = RMS_norm(4)
|
norm = RMS_norm(4)
|
||||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||||
x = mx.ones((1, 1, 1, 1, 4))
|
x = mx.ones((1, 1, 1, 1, 4))
|
||||||
@@ -221,6 +242,7 @@ class TestDupUp3D:
|
|||||||
|
|
||||||
def test_spatial_only(self):
|
def test_spatial_only(self):
|
||||||
from mlx_video.models.wan.vae22 import DupUp3D
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
|
||||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||||
out = up(x)
|
out = up(x)
|
||||||
@@ -229,6 +251,7 @@ class TestDupUp3D:
|
|||||||
|
|
||||||
def test_temporal_and_spatial(self):
|
def test_temporal_and_spatial(self):
|
||||||
from mlx_video.models.wan.vae22 import DupUp3D
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
|
||||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||||
out = up(x)
|
out = up(x)
|
||||||
@@ -237,6 +260,7 @@ class TestDupUp3D:
|
|||||||
|
|
||||||
def test_first_chunk_trims(self):
|
def test_first_chunk_trims(self):
|
||||||
from mlx_video.models.wan.vae22 import DupUp3D
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
|
||||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||||
out_normal = up(x, first_chunk=False)
|
out_normal = up(x, first_chunk=False)
|
||||||
@@ -248,6 +272,7 @@ class TestDupUp3D:
|
|||||||
|
|
||||||
def test_no_temporal_first_chunk_noop(self):
|
def test_no_temporal_first_chunk_noop(self):
|
||||||
from mlx_video.models.wan.vae22 import DupUp3D
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
|
||||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||||
out_normal = up(x, first_chunk=False)
|
out_normal = up(x, first_chunk=False)
|
||||||
@@ -262,6 +287,7 @@ class TestVAE22Resample:
|
|||||||
|
|
||||||
def test_upsample2d_shape(self):
|
def test_upsample2d_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import Resample
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
|
||||||
r = Resample(8, "upsample2d")
|
r = Resample(8, "upsample2d")
|
||||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
@@ -271,6 +297,7 @@ class TestVAE22Resample:
|
|||||||
|
|
||||||
def test_upsample3d_shape(self):
|
def test_upsample3d_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import Resample
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
|
||||||
r = Resample(8, "upsample3d")
|
r = Resample(8, "upsample3d")
|
||||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
@@ -280,6 +307,7 @@ class TestVAE22Resample:
|
|||||||
|
|
||||||
def test_upsample3d_first_chunk(self):
|
def test_upsample3d_first_chunk(self):
|
||||||
from mlx_video.models.wan.vae22 import Resample
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
|
||||||
r = Resample(8, "upsample3d")
|
r = Resample(8, "upsample3d")
|
||||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
@@ -291,6 +319,7 @@ class TestVAE22Resample:
|
|||||||
def test_upsample3d_first_chunk_single_frame(self):
|
def test_upsample3d_first_chunk_single_frame(self):
|
||||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||||
from mlx_video.models.wan.vae22 import Resample
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
|
||||||
r = Resample(8, "upsample3d")
|
r = Resample(8, "upsample3d")
|
||||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
x = mx.random.normal((1, 1, 4, 4, 8))
|
x = mx.random.normal((1, 1, 4, 4, 8))
|
||||||
@@ -308,6 +337,7 @@ class TestVAE22Resample:
|
|||||||
the first input frame (not on time_conv parameters).
|
the first input frame (not on time_conv parameters).
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.wan.vae22 import Resample
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
|
||||||
C = 8
|
C = 8
|
||||||
r = Resample(C, "upsample3d")
|
r = Resample(C, "upsample3d")
|
||||||
# Set time_conv weights to large values so its effect is detectable
|
# Set time_conv weights to large values so its effect is detectable
|
||||||
@@ -334,8 +364,9 @@ class TestVAE22Resample:
|
|||||||
# Compare first output frame to reference
|
# Compare first output frame to reference
|
||||||
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
||||||
mx.eval(first_out)
|
mx.eval(first_out)
|
||||||
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
|
assert mx.allclose(
|
||||||
"First frame should bypass time_conv and match spatial-only upsample"
|
first_out, ref, atol=1e-5
|
||||||
|
).item(), "First frame should bypass time_conv and match spatial-only upsample"
|
||||||
|
|
||||||
|
|
||||||
class TestVAE22ResidualBlock:
|
class TestVAE22ResidualBlock:
|
||||||
@@ -343,6 +374,7 @@ class TestVAE22ResidualBlock:
|
|||||||
|
|
||||||
def test_same_dim(self):
|
def test_same_dim(self):
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 8)
|
block = ResidualBlock(8, 8)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -351,6 +383,7 @@ class TestVAE22ResidualBlock:
|
|||||||
|
|
||||||
def test_different_dim(self):
|
def test_different_dim(self):
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 16)
|
block = ResidualBlock(8, 16)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -359,11 +392,13 @@ class TestVAE22ResidualBlock:
|
|||||||
|
|
||||||
def test_shortcut_when_dims_differ(self):
|
def test_shortcut_when_dims_differ(self):
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 16)
|
block = ResidualBlock(8, 16)
|
||||||
assert block.shortcut is not None
|
assert block.shortcut is not None
|
||||||
|
|
||||||
def test_no_shortcut_same_dim(self):
|
def test_no_shortcut_same_dim(self):
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
|
||||||
block = ResidualBlock(8, 8)
|
block = ResidualBlock(8, 8)
|
||||||
assert block.shortcut is None
|
assert block.shortcut is None
|
||||||
|
|
||||||
@@ -374,6 +409,7 @@ class TestResidualBlockLayers:
|
|||||||
def test_layer_names_no_underscore_prefix(self):
|
def test_layer_names_no_underscore_prefix(self):
|
||||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||||
|
|
||||||
block = ResidualBlockLayers(8, 8)
|
block = ResidualBlockLayers(8, 8)
|
||||||
params = dict(block.parameters())
|
params = dict(block.parameters())
|
||||||
# All param keys should use layer_N, not _layer_N
|
# All param keys should use layer_N, not _layer_N
|
||||||
@@ -382,6 +418,7 @@ class TestResidualBlockLayers:
|
|||||||
|
|
||||||
def test_has_expected_layers(self):
|
def test_has_expected_layers(self):
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||||
|
|
||||||
block = ResidualBlockLayers(8, 16)
|
block = ResidualBlockLayers(8, 16)
|
||||||
assert hasattr(block, "layer_0") # first RMS_norm
|
assert hasattr(block, "layer_0") # first RMS_norm
|
||||||
assert hasattr(block, "layer_2") # first CausalConv3d
|
assert hasattr(block, "layer_2") # first CausalConv3d
|
||||||
@@ -390,6 +427,7 @@ class TestResidualBlockLayers:
|
|||||||
|
|
||||||
def test_forward_shape(self):
|
def test_forward_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||||
|
|
||||||
block = ResidualBlockLayers(8, 16)
|
block = ResidualBlockLayers(8, 16)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
@@ -402,6 +440,7 @@ class TestVAE22AttentionBlock:
|
|||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||||
|
|
||||||
block = AttentionBlock(16)
|
block = AttentionBlock(16)
|
||||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||||
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
||||||
@@ -412,6 +451,7 @@ class TestVAE22AttentionBlock:
|
|||||||
|
|
||||||
def test_residual_connection(self):
|
def test_residual_connection(self):
|
||||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||||
|
|
||||||
block = AttentionBlock(8)
|
block = AttentionBlock(8)
|
||||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||||
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
||||||
@@ -427,6 +467,7 @@ class TestHead22:
|
|||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import Head22
|
from mlx_video.models.wan.vae22 import Head22
|
||||||
|
|
||||||
head = Head22(16, out_channels=12)
|
head = Head22(16, out_channels=12)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||||
out = head(x)
|
out = head(x)
|
||||||
@@ -436,6 +477,7 @@ class TestHead22:
|
|||||||
def test_layer_names_no_underscore(self):
|
def test_layer_names_no_underscore(self):
|
||||||
"""Head layers must not use underscore prefix."""
|
"""Head layers must not use underscore prefix."""
|
||||||
from mlx_video.models.wan.vae22 import Head22
|
from mlx_video.models.wan.vae22 import Head22
|
||||||
|
|
||||||
head = Head22(8)
|
head = Head22(8)
|
||||||
assert hasattr(head, "layer_0") # RMS_norm
|
assert hasattr(head, "layer_0") # RMS_norm
|
||||||
assert hasattr(head, "layer_2") # CausalConv3d
|
assert hasattr(head, "layer_2") # CausalConv3d
|
||||||
@@ -449,6 +491,7 @@ class TestUnpatchify:
|
|||||||
|
|
||||||
def test_basic_shape(self):
|
def test_basic_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import _unpatchify
|
from mlx_video.models.wan.vae22 import _unpatchify
|
||||||
|
|
||||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||||
out = _unpatchify(x, patch_size=2)
|
out = _unpatchify(x, patch_size=2)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -456,6 +499,7 @@ class TestUnpatchify:
|
|||||||
|
|
||||||
def test_patch_size_1_noop(self):
|
def test_patch_size_1_noop(self):
|
||||||
from mlx_video.models.wan.vae22 import _unpatchify
|
from mlx_video.models.wan.vae22 import _unpatchify
|
||||||
|
|
||||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||||
out = _unpatchify(x, patch_size=1)
|
out = _unpatchify(x, patch_size=1)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -464,6 +508,7 @@ class TestUnpatchify:
|
|||||||
def test_preserves_content(self):
|
def test_preserves_content(self):
|
||||||
"""Unpatchify should be a lossless rearrangement."""
|
"""Unpatchify should be a lossless rearrangement."""
|
||||||
from mlx_video.models.wan.vae22 import _unpatchify
|
from mlx_video.models.wan.vae22 import _unpatchify
|
||||||
|
|
||||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||||
out = _unpatchify(x, patch_size=2)
|
out = _unpatchify(x, patch_size=2)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -477,6 +522,7 @@ class TestDenormalizeLatents:
|
|||||||
|
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
|
||||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||||
out = denormalize_latents(z)
|
out = denormalize_latents(z)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -484,16 +530,23 @@ class TestDenormalizeLatents:
|
|||||||
|
|
||||||
def test_custom_mean_std(self):
|
def test_custom_mean_std(self):
|
||||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
|
||||||
z = mx.ones((1, 1, 1, 1, 4))
|
z = mx.ones((1, 1, 1, 1, 4))
|
||||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||||
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
||||||
out = denormalize_latents(z, mean=mean, std=std)
|
out = denormalize_latents(z, mean=mean, std=std)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
||||||
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
|
np.testing.assert_allclose(
|
||||||
|
np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5
|
||||||
|
)
|
||||||
|
|
||||||
def test_uses_default_constants(self):
|
def test_uses_default_constants(self):
|
||||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
|
from mlx_video.models.wan.vae22 import (
|
||||||
|
VAE22_MEAN,
|
||||||
|
denormalize_latents,
|
||||||
|
)
|
||||||
|
|
||||||
# Should not raise with default constants
|
# Should not raise with default constants
|
||||||
z = mx.zeros((1, 1, 1, 1, 48))
|
z = mx.zeros((1, 1, 1, 1, 48))
|
||||||
out = denormalize_latents(z)
|
out = denormalize_latents(z)
|
||||||
@@ -511,12 +564,14 @@ class TestVAE22NormConstants:
|
|||||||
|
|
||||||
def test_dimensions(self):
|
def test_dimensions(self):
|
||||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||||
|
|
||||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||||
assert VAE22_MEAN.shape == (48,)
|
assert VAE22_MEAN.shape == (48,)
|
||||||
assert VAE22_STD.shape == (48,)
|
assert VAE22_STD.shape == (48,)
|
||||||
|
|
||||||
def test_std_positive(self):
|
def test_std_positive(self):
|
||||||
from mlx_video.models.wan.vae22 import VAE22_STD
|
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||||
|
|
||||||
mx.eval(VAE22_STD)
|
mx.eval(VAE22_STD)
|
||||||
assert (np.array(VAE22_STD) > 0).all()
|
assert (np.array(VAE22_STD) > 0).all()
|
||||||
|
|
||||||
@@ -527,6 +582,7 @@ class TestWan22VAEDecoder:
|
|||||||
def test_output_shape_small(self):
|
def test_output_shape_small(self):
|
||||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
|
||||||
# Use very small dims to keep test fast
|
# Use very small dims to keep test fast
|
||||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||||
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
||||||
@@ -542,6 +598,7 @@ class TestWan22VAEDecoder:
|
|||||||
|
|
||||||
def test_output_clipped(self):
|
def test_output_clipped(self):
|
||||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
|
||||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||||
out = dec(z)
|
out = dec(z)
|
||||||
@@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_skip_encoder(self):
|
def test_skip_encoder(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"encoder.layer.weight": mx.zeros((4,)),
|
"encoder.layer.weight": mx.zeros((4,)),
|
||||||
"conv1.weight": mx.zeros((4,)),
|
"conv1.weight": mx.zeros((4,)),
|
||||||
@@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_sequential_index_remapping(self):
|
def test_sequential_index_remapping(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||||
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
||||||
@@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_resample_conv_remapping(self):
|
def test_resample_conv_remapping(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||||
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
||||||
@@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_attention_remapping(self):
|
def test_attention_remapping(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
weights = {
|
weights = {
|
||||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||||
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
||||||
@@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_conv3d_transpose(self):
|
def test_conv3d_transpose(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||||
w = mx.zeros((16, 8, 3, 3, 3))
|
w = mx.zeros((16, 8, 3, 3, 3))
|
||||||
weights = {"decoder.conv1.weight": w}
|
weights = {"decoder.conv1.weight": w}
|
||||||
@@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_conv2d_transpose(self):
|
def test_conv2d_transpose(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||||
w = mx.zeros((8, 8, 3, 3))
|
w = mx.zeros((8, 8, 3, 3))
|
||||||
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
||||||
@@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights:
|
|||||||
|
|
||||||
def test_gamma_squeeze(self):
|
def test_gamma_squeeze(self):
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
|
||||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||||
w = mx.ones((16, 1, 1, 1))
|
w = mx.ones((16, 1, 1, 1))
|
||||||
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
||||||
@@ -635,7 +699,10 @@ class TestUpResidualBlock:
|
|||||||
|
|
||||||
def test_no_upsample(self):
|
def test_no_upsample(self):
|
||||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||||
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
|
|
||||||
|
block = Up_ResidualBlock(
|
||||||
|
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
|
||||||
|
)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -644,7 +711,10 @@ class TestUpResidualBlock:
|
|||||||
|
|
||||||
def test_spatial_upsample(self):
|
def test_spatial_upsample(self):
|
||||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
|
|
||||||
|
block = Up_ResidualBlock(
|
||||||
|
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
|
||||||
|
)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -653,7 +723,10 @@ class TestUpResidualBlock:
|
|||||||
|
|
||||||
def test_spatial_temporal_upsample(self):
|
def test_spatial_temporal_upsample(self):
|
||||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
|
|
||||||
|
block = Up_ResidualBlock(
|
||||||
|
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
|
||||||
|
)
|
||||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -720,7 +793,9 @@ class TestDownResidualBlock:
|
|||||||
def test_no_downsample(self):
|
def test_no_downsample(self):
|
||||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||||
|
|
||||||
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
|
block = Down_ResidualBlock(
|
||||||
|
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
|
||||||
|
)
|
||||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -729,7 +804,9 @@ class TestDownResidualBlock:
|
|||||||
def test_spatial_downsample(self):
|
def test_spatial_downsample(self):
|
||||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||||
|
|
||||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
|
block = Down_ResidualBlock(
|
||||||
|
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
|
||||||
|
)
|
||||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -738,7 +815,9 @@ class TestDownResidualBlock:
|
|||||||
def test_spatial_temporal_downsample(self):
|
def test_spatial_temporal_downsample(self):
|
||||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||||
|
|
||||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
|
block = Down_ResidualBlock(
|
||||||
|
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
|
||||||
|
)
|
||||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||||
out = block(x)
|
out = block(x)
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
@@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder:
|
|||||||
def test_encoder_temporal_downsample_pattern(self):
|
def test_encoder_temporal_downsample_pattern(self):
|
||||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||||
from mlx_video.models.wan.vae22 import Encoder3d
|
from mlx_video.models.wan.vae22 import Encoder3d
|
||||||
|
|
||||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||||
mx.eval(enc.parameters())
|
mx.eval(enc.parameters())
|
||||||
@@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder:
|
|||||||
|
|
||||||
def test_wrapper_uses_correct_pattern(self):
|
def test_wrapper_uses_correct_pattern(self):
|
||||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
|
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
|
||||||
|
|
||||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||||
down_blocks = enc.encoder.downsamples
|
down_blocks = enc.encoder.downsamples
|
||||||
found_modes = []
|
found_modes = []
|
||||||
@@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder:
|
|||||||
def test_single_frame_encoder(self):
|
def test_single_frame_encoder(self):
|
||||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||||
|
|
||||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||||
mx.eval(enc.parameters())
|
mx.eval(enc.parameters())
|
||||||
@@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder:
|
|||||||
def test_wrong_order_gives_different_result(self):
|
def test_wrong_order_gives_different_result(self):
|
||||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||||
from mlx_video.models.wan.vae22 import Encoder3d
|
from mlx_video.models.wan.vae22 import Encoder3d
|
||||||
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
|
||||||
|
enc_correct = Encoder3d(
|
||||||
|
dim=16, z_dim=8, temperal_downsample=(False, True, True)
|
||||||
|
)
|
||||||
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||||
|
|
||||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||||
@@ -883,12 +968,8 @@ class TestVAE21RoundTrip:
|
|||||||
z_dim = 4
|
z_dim = 4
|
||||||
dim = 8
|
dim = 8
|
||||||
# No temporal up/downsampling to keep the test simple
|
# No temporal up/downsampling to keep the test simple
|
||||||
enc = Encoder3d(
|
enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False])
|
||||||
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
|
dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False])
|
||||||
)
|
|
||||||
dec = Decoder3d(
|
|
||||||
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
|
|
||||||
)
|
|
||||||
mx.eval(enc.parameters(), dec.parameters())
|
mx.eval(enc.parameters(), dec.parameters())
|
||||||
|
|
||||||
# [B=1, C=3, T=1, H=8, W=8]
|
# [B=1, C=3, T=1, H=8, W=8]
|
||||||
@@ -937,15 +1018,12 @@ class TestVAE22RoundTrip:
|
|||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
|
|
||||||
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
||||||
assert out.shape[0] == 1 # batch
|
assert out.shape[0] == 1 # batch
|
||||||
assert out.shape[2] == 32 # H recovered
|
assert out.shape[2] == 32 # H recovered
|
||||||
assert out.shape[3] == 32 # W recovered
|
assert out.shape[3] == 32 # W recovered
|
||||||
assert out.shape[-1] == 3 # RGB
|
assert out.shape[-1] == 3 # RGB
|
||||||
|
|
||||||
out_np = np.array(out)
|
out_np = np.array(out)
|
||||||
assert np.all(np.isfinite(out_np))
|
assert np.all(np.isfinite(out_np))
|
||||||
assert out_np.min() >= -1.0 - 1e-6
|
assert out_np.min() >= -1.0 - 1e-6
|
||||||
assert out_np.max() <= 1.0 + 1e-6
|
assert out_np.max() <= 1.0 + 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
def _make_tiny_config():
|
def _make_tiny_config():
|
||||||
"""Create a tiny WanModelConfig for testing."""
|
"""Create a tiny WanModelConfig for testing."""
|
||||||
from mlx_video.models.wan.config import WanModelConfig
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
# Override to tiny values
|
# Override to tiny values
|
||||||
config.dim = 64
|
config.dim = 64
|
||||||
|
|||||||
27
uv.lock
generated
27
uv.lock
generated
@@ -622,6 +622,18 @@ http = [
|
|||||||
{ name = "aiohttp" },
|
{ name = "aiohttp" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ftfy"
|
||||||
|
version = "6.3.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "wcwidth" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h11"
|
name = "h11"
|
||||||
version = "0.16.0"
|
version = "0.16.0"
|
||||||
@@ -996,7 +1008,10 @@ wheels = [
|
|||||||
name = "mlx-video"
|
name = "mlx-video"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "ftfy" },
|
||||||
{ name = "huggingface-hub" },
|
{ name = "huggingface-hub" },
|
||||||
|
{ name = "imageio" },
|
||||||
|
{ name = "imageio-ffmpeg" },
|
||||||
{ name = "librosa" },
|
{ name = "librosa" },
|
||||||
{ name = "mlx" },
|
{ name = "mlx" },
|
||||||
{ name = "mlx-vlm" },
|
{ name = "mlx-vlm" },
|
||||||
@@ -1016,7 +1031,10 @@ dev = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "ftfy" },
|
||||||
{ name = "huggingface-hub" },
|
{ name = "huggingface-hub" },
|
||||||
|
{ name = "imageio", specifier = ">=2.37.2" },
|
||||||
|
{ name = "imageio-ffmpeg", specifier = ">=0.6.0" },
|
||||||
{ name = "librosa", specifier = ">=0.10.0" },
|
{ name = "librosa", specifier = ">=0.10.0" },
|
||||||
{ name = "mlx", specifier = ">=0.22.0" },
|
{ name = "mlx", specifier = ">=0.22.0" },
|
||||||
{ name = "mlx-vlm" },
|
{ name = "mlx-vlm" },
|
||||||
@@ -2509,6 +2527,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" },
|
{ url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wcwidth"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xxhash"
|
name = "xxhash"
|
||||||
version = "3.6.0"
|
version = "3.6.0"
|
||||||
|
|||||||
Reference in New Issue
Block a user