Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -250,6 +250,18 @@ class LTX2VideoDecoder(nn.Module):
|
||||
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
|
||||
"""
|
||||
|
||||
# Block definitions: ("res", channels, num_layers) or ("d2s", in_channels, reduction, stride)
|
||||
# stride is (D, H, W) tuple
|
||||
DEFAULT_BLOCKS = [
|
||||
("res", 1024, 5),
|
||||
("d2s", 1024, 2, (2, 2, 2)),
|
||||
("res", 512, 5),
|
||||
("d2s", 512, 2, (2, 2, 2)),
|
||||
("res", 256, 5),
|
||||
("d2s", 256, 2, (2, 2, 2)),
|
||||
("res", 128, 5),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
@@ -258,6 +270,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
num_layers_per_block: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = True,
|
||||
decoder_blocks: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -272,13 +285,17 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Initial conv: 128 -> 1024
|
||||
blocks = decoder_blocks or self.DEFAULT_BLOCKS
|
||||
first_ch = blocks[0][1]
|
||||
last_ch = blocks[-1][1]
|
||||
|
||||
# Initial conv: in_channels -> first block channels
|
||||
class ConvInWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=1024,
|
||||
out_channels=first_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
@@ -288,45 +305,32 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.up_blocks = {
|
||||
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
1: DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=1024,
|
||||
stride=(2, 2, 2),
|
||||
residual=True,
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
3: DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=512,
|
||||
stride=(2, 2, 2),
|
||||
residual=True,
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
5: DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
residual=True,
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
}
|
||||
# Build up blocks from config
|
||||
self.up_blocks = {}
|
||||
for idx, block_def in enumerate(blocks):
|
||||
block_type = block_def[0]
|
||||
ch = block_def[1]
|
||||
if block_type == "res":
|
||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
self.up_blocks[idx] = DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=ch,
|
||||
stride=stride,
|
||||
residual=True,
|
||||
out_channels_reduction_factor=reduction,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=128,
|
||||
in_channels=last_ch,
|
||||
out_channels=final_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -342,9 +346,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = mx.array(1000.0)
|
||||
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=128 * 2 # 256, matches (2, 128) table
|
||||
embedding_dim=last_ch * 2
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||
self.last_scale_shift_table = mx.zeros((2, last_ch))
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
@@ -418,11 +422,96 @@ class LTX2VideoDecoder(nn.Module):
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
model = cls(
|
||||
timestep_conditioning=config_dict.get("timestep_conditioning", False),
|
||||
decoder_blocks=decoder_blocks,
|
||||
)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _infer_blocks(weights: dict) -> list:
|
||||
"""Infer decoder block structure from weight keys."""
|
||||
block_indices = set()
|
||||
for k in weights:
|
||||
if "up_blocks." in k:
|
||||
idx_str = k.split("up_blocks.")[1].split(".")[0]
|
||||
if idx_str.isdigit():
|
||||
block_indices.add(int(idx_str))
|
||||
|
||||
if not block_indices:
|
||||
return None
|
||||
|
||||
# First pass: collect block info
|
||||
raw_blocks = []
|
||||
for idx in sorted(block_indices):
|
||||
has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights)
|
||||
res_indices = set()
|
||||
for k in weights:
|
||||
prefix = f"up_blocks.{idx}.res_blocks."
|
||||
if prefix in k:
|
||||
res_idx = k.split(prefix)[1].split(".")[0]
|
||||
if res_idx.isdigit():
|
||||
res_indices.add(int(res_idx))
|
||||
|
||||
if has_conv and not res_indices:
|
||||
# D2S block - get conv shape
|
||||
for k, v in weights.items():
|
||||
if f"up_blocks.{idx}.conv." in k and "weight" in k:
|
||||
in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1]
|
||||
conv_out_ch = v.shape[0]
|
||||
raw_blocks.append(("d2s", in_ch, conv_out_ch))
|
||||
break
|
||||
elif res_indices:
|
||||
num_res = max(res_indices) + 1
|
||||
for k, v in weights.items():
|
||||
if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k:
|
||||
ch = v.shape[0]
|
||||
raw_blocks.append(("res", ch, num_res))
|
||||
break
|
||||
|
||||
# Second pass: determine d2s strides using the channel progression
|
||||
# For each d2s block, the next res block tells us the expected output channels
|
||||
blocks = []
|
||||
for i, block in enumerate(raw_blocks):
|
||||
if block[0] == "res":
|
||||
blocks.append(block)
|
||||
elif block[0] == "d2s":
|
||||
in_ch, conv_out_ch = block[1], block[2]
|
||||
# Find next res block's channels
|
||||
next_ch = None
|
||||
for j in range(i + 1, len(raw_blocks)):
|
||||
if raw_blocks[j][0] == "res":
|
||||
next_ch = raw_blocks[j][1]
|
||||
break
|
||||
|
||||
if next_ch is None:
|
||||
next_ch = in_ch // 2 # fallback
|
||||
|
||||
# out_ch = in_ch // reduction
|
||||
reduction = in_ch // next_ch if next_ch > 0 else 2
|
||||
|
||||
# conv_out = next_ch * multiplier → multiplier = conv_out / next_ch
|
||||
multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
|
||||
|
||||
# Determine stride from multiplier
|
||||
if multiplier == 8:
|
||||
stride = (2, 2, 2)
|
||||
elif multiplier == 4:
|
||||
stride = (1, 2, 2)
|
||||
elif multiplier == 2:
|
||||
stride = (2, 1, 1)
|
||||
else:
|
||||
stride = (2, 2, 2)
|
||||
|
||||
blocks.append(("d2s", in_ch, reduction, stride))
|
||||
|
||||
return blocks if blocks else None
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user