This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
"""
from dataclasses import dataclass
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
frame_idx: Frame index to condition (0 = first frame)
strength: Conditioning strength (1.0 = keep the conditioning latent clean, 0.0 = full denoise)
"""
latent: mx.array
frame_idx: int = 0
strength: float = 1.0
@@ -41,6 +42,7 @@ class LatentState:
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
1.0 = full denoise, 0.0 = keep clean
"""
latent: mx.array
clean_latent: mx.array
denoise_mask: mx.array
@@ -130,15 +132,15 @@ def apply_conditioning(
if frame_idx <= i < end_idx:
# Use conditioning latent
cond_idx = i - frame_idx
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
# Set mask to 1.0 - strength: the higher the strength, the less the conditioned frames are denoised
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
else:
# Keep original
latent_list.append(state.latent[:, :, i:i+1])
clean_list.append(state.clean_latent[:, :, i:i+1])
mask_list.append(state.denoise_mask[:, :, i:i+1])
latent_list.append(state.latent[:, :, i : i + 1])
clean_list.append(state.clean_latent[:, :, i : i + 1])
mask_list.append(state.denoise_mask[:, :, i : i + 1])
state.latent = mx.concatenate(latent_list, axis=2)
state.clean_latent = mx.concatenate(clean_list, axis=2)