Add LTXModel with a from_pretrained class method for loading model weights from a specified path. Update weight sanitization to handle positional embeddings and dtype consistency. Refactor timestep and context preparation methods to accept hidden_dtype, improving flexibility in model processing.
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
from mlx_video.models.ltx.config import (
|
from mlx_video.models.ltx.config import (
|
||||||
LTXModelConfig,
|
LTXModelConfig,
|
||||||
LTXModelType,
|
LTXModelType,
|
||||||
@@ -52,10 +52,11 @@ class TransformerArgsPreprocessor:
|
|||||||
self,
|
self,
|
||||||
timestep: mx.array,
|
timestep: mx.array,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
hidden_dtype: mx.Dtype = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||||
|
|
||||||
# Reshape to (batch, tokens, dim)
|
# Reshape to (batch, tokens, dim)
|
||||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||||
@@ -70,6 +71,9 @@ class TransformerArgsPreprocessor:
|
|||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# Context is already processed through embeddings connector in text encoder
|
||||||
|
# Here we just apply the caption projection
|
||||||
context = self.caption_projection(context)
|
context = self.caption_projection(context)
|
||||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||||
return context, attention_mask
|
return context, attention_mask
|
||||||
@@ -114,9 +118,14 @@ class TransformerArgsPreprocessor:
|
|||||||
|
|
||||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||||
x = self.patchify_proj(modality.latent)
|
x = self.patchify_proj(modality.latent)
|
||||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||||
|
|
||||||
|
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||||
|
if modality.positional_embeddings is not None:
|
||||||
|
pe = modality.positional_embeddings
|
||||||
|
else:
|
||||||
pe = self._prepare_positional_embeddings(
|
pe = self._prepare_positional_embeddings(
|
||||||
positions=modality.positions,
|
positions=modality.positions,
|
||||||
inner_dim=self.inner_dim,
|
inner_dim=self.inner_dim,
|
||||||
@@ -198,6 +207,7 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
timestep=modality.timesteps,
|
timestep=modality.timesteps,
|
||||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||||
batch_size=transformer_args.x.shape[0],
|
batch_size=transformer_args.x.shape[0],
|
||||||
|
hidden_dtype=transformer_args.x.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
return replace(
|
return replace(
|
||||||
@@ -212,15 +222,16 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
timestep: mx.array,
|
timestep: mx.array,
|
||||||
timestep_scale_multiplier: int,
|
timestep_scale_multiplier: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
hidden_dtype: mx.Dtype = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
timestep = timestep * timestep_scale_multiplier
|
timestep = timestep * timestep_scale_multiplier
|
||||||
|
|
||||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||||
|
|
||||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||||
|
|
||||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||||
|
|
||||||
return scale_shift_timestep, gate_timestep
|
return scale_shift_timestep, gate_timestep
|
||||||
@@ -282,6 +293,8 @@ class LTXModel(nn.Module):
|
|||||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||||
|
|
||||||
|
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
|
||||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=config.audio_caption_channels,
|
in_features=config.audio_caption_channels,
|
||||||
hidden_size=self.audio_inner_dim,
|
hidden_size=self.audio_inner_dim,
|
||||||
@@ -384,8 +397,9 @@ class LTXModel(nn.Module):
|
|||||||
video_config = config.get_video_config()
|
video_config = config.get_video_config()
|
||||||
audio_config = config.get_audio_config()
|
audio_config = config.get_audio_config()
|
||||||
|
|
||||||
self.transformer_blocks = [
|
|
||||||
BasicAVTransformerBlock(
|
self.transformer_blocks = {
|
||||||
|
idx: BasicAVTransformerBlock(
|
||||||
idx=idx,
|
idx=idx,
|
||||||
video=video_config,
|
video=video_config,
|
||||||
audio=audio_config,
|
audio=audio_config,
|
||||||
@@ -393,7 +407,7 @@ class LTXModel(nn.Module):
|
|||||||
norm_eps=config.norm_eps,
|
norm_eps=config.norm_eps,
|
||||||
)
|
)
|
||||||
for idx in range(config.num_layers)
|
for idx in range(config.num_layers)
|
||||||
]
|
}
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -401,7 +415,7 @@ class LTXModel(nn.Module):
|
|||||||
audio: Optional[TransformerArgs],
|
audio: Optional[TransformerArgs],
|
||||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||||
"""Process through all transformer blocks."""
|
"""Process through all transformer blocks."""
|
||||||
for block in self.transformer_blocks:
|
for block in self.transformer_blocks.values():
|
||||||
video, audio = block(video=video, audio=audio)
|
video, audio = block(video=video, audio=audio)
|
||||||
return video, audio
|
return video, audio
|
||||||
|
|
||||||
@@ -483,19 +497,58 @@ class LTXModel(nn.Module):
|
|||||||
|
|
||||||
def sanitize(self, weights: dict) -> dict:
|
def sanitize(self, weights: dict) -> dict:
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
|
||||||
|
if "model.diffusion_model." not in weights:
|
||||||
|
return weights
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
|
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||||
|
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove 'model.diffusion_model.' prefix
|
||||||
|
new_key = new_key.replace("model.diffusion_model.", "")
|
||||||
|
|
||||||
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
|
|
||||||
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||||
|
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||||
|
|
||||||
|
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||||
|
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||||
|
|
||||||
# Handle common remappings
|
|
||||||
# transformer_blocks.X -> transformer_blocks[X]
|
|
||||||
if "transformer_blocks." in new_key:
|
|
||||||
# Keep as-is for now, MLX handles this
|
|
||||||
pass
|
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTXModel":
|
||||||
|
import json
|
||||||
|
|
||||||
|
config_dict = {}
|
||||||
|
with open(model_path / "config.json", "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
config = LTXModelConfig(**config_dict)
|
||||||
|
model = cls(config)
|
||||||
|
|
||||||
|
weights = {}
|
||||||
|
|
||||||
|
for weight_file in model_path.glob("*.safetensors"):
|
||||||
|
weights.update(mx.load(str(weight_file)))
|
||||||
|
|
||||||
|
|
||||||
|
sanitized = model.sanitize(weights)
|
||||||
|
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||||
|
|
||||||
|
model.load_weights(list(sanitized.items()), strict=strict)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
class X0Model(nn.Module):
|
class X0Model(nn.Module):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user