Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -250,6 +250,18 @@ class LTX2VideoDecoder(nn.Module):
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
"""
# Block definitions: ("res", channels, num_layers) or ("d2s", in_channels, reduction, stride)
# stride is (D, H, W) tuple
DEFAULT_BLOCKS = [
("res", 1024, 5),
("d2s", 1024, 2, (2, 2, 2)),
("res", 512, 5),
("d2s", 512, 2, (2, 2, 2)),
("res", 256, 5),
("d2s", 256, 2, (2, 2, 2)),
("res", 128, 5),
]
def __init__(
self,
in_channels: int = 128,
@@ -258,6 +270,7 @@ class LTX2VideoDecoder(nn.Module):
num_layers_per_block: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = True,
decoder_blocks: list = None,
):
super().__init__()
@@ -272,13 +285,17 @@ class LTX2VideoDecoder(nn.Module):
# Per-channel statistics for denormalization (loaded from weights)
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Initial conv: 128 -> 1024
blocks = decoder_blocks or self.DEFAULT_BLOCKS
first_ch = blocks[0][1]
last_ch = blocks[-1][1]
# Initial conv: in_channels -> first block channels
class ConvInWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_channels,
out_channels=1024,
out_channels=first_ch,
kernel_size=3,
stride=1,
padding=1,
@@ -288,45 +305,32 @@ class LTX2VideoDecoder(nn.Module):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
# Use dict with int keys for MLX to track parameters properly
self.up_blocks = {
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
1: DepthToSpaceUpsample(
dims=3,
in_channels=1024,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
3: DepthToSpaceUpsample(
dims=3,
in_channels=512,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
5: DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
}
# Build up blocks from config
self.up_blocks = {}
for idx, block_def in enumerate(blocks):
block_type = block_def[0]
ch = block_def[1]
if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
self.up_blocks[idx] = DepthToSpaceUpsample(
dims=3,
in_channels=ch,
stride=stride,
residual=True,
out_channels_reduction_factor=reduction,
spatial_padding_mode=spatial_padding_mode,
)
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=128,
in_channels=last_ch,
out_channels=final_out_channels,
kernel_size=3,
stride=1,
@@ -342,9 +346,9 @@ class LTX2VideoDecoder(nn.Module):
if timestep_conditioning:
self.timestep_scale_multiplier = mx.array(1000.0)
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=128 * 2 # 256, matches (2, 128) table
embedding_dim=last_ch * 2
)
self.last_scale_shift_table = mx.zeros((2, 128))
self.last_scale_shift_table = mx.zeros((2, last_ch))
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
@@ -418,11 +422,96 @@ class LTX2VideoDecoder(nn.Module):
weights.update(mx.load(str(wf)))
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
model = cls(
timestep_conditioning=config_dict.get("timestep_conditioning", False),
decoder_blocks=decoder_blocks,
)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
@staticmethod
def _infer_blocks(weights: dict) -> list:
"""Infer decoder block structure from weight keys."""
block_indices = set()
for k in weights:
if "up_blocks." in k:
idx_str = k.split("up_blocks.")[1].split(".")[0]
if idx_str.isdigit():
block_indices.add(int(idx_str))
if not block_indices:
return None
# First pass: collect block info
raw_blocks = []
for idx in sorted(block_indices):
has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights)
res_indices = set()
for k in weights:
prefix = f"up_blocks.{idx}.res_blocks."
if prefix in k:
res_idx = k.split(prefix)[1].split(".")[0]
if res_idx.isdigit():
res_indices.add(int(res_idx))
if has_conv and not res_indices:
# D2S block - get conv shape
for k, v in weights.items():
if f"up_blocks.{idx}.conv." in k and "weight" in k:
in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1]
conv_out_ch = v.shape[0]
raw_blocks.append(("d2s", in_ch, conv_out_ch))
break
elif res_indices:
num_res = max(res_indices) + 1
for k, v in weights.items():
if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k:
ch = v.shape[0]
raw_blocks.append(("res", ch, num_res))
break
# Second pass: determine d2s strides using the channel progression
# For each d2s block, the next res block tells us the expected output channels
blocks = []
for i, block in enumerate(raw_blocks):
if block[0] == "res":
blocks.append(block)
elif block[0] == "d2s":
in_ch, conv_out_ch = block[1], block[2]
# Find next res block's channels
next_ch = None
for j in range(i + 1, len(raw_blocks)):
if raw_blocks[j][0] == "res":
next_ch = raw_blocks[j][1]
break
if next_ch is None:
next_ch = in_ch // 2 # fallback
# out_ch = in_ch // reduction
reduction = in_ch // next_ch if next_ch > 0 else 2
# conv_out = next_ch * multiplier → multiplier = conv_out / next_ch
multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
# Determine stride from multiplier
if multiplier == 8:
stride = (2, 2, 2)
elif multiplier == 4:
stride = (1, 2, 2)
elif multiplier == 2:
stride = (2, 1, 1)
else:
stride = (2, 2, 2)
blocks.append(("d2s", in_ch, reduction, stride))
return blocks if blocks else None
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: