format
This commit is contained in:
@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
|
||||
frame_idx: Frame index to condition (0 = first frame)
|
||||
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
||||
"""
|
||||
|
||||
latent: mx.array
|
||||
frame_idx: int = 0
|
||||
strength: float = 1.0
|
||||
@@ -41,6 +42,7 @@ class LatentState:
|
||||
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
||||
1.0 = full denoise, 0.0 = keep clean
|
||||
"""
|
||||
|
||||
latent: mx.array
|
||||
clean_latent: mx.array
|
||||
denoise_mask: mx.array
|
||||
@@ -130,15 +132,15 @@ def apply_conditioning(
|
||||
if frame_idx <= i < end_idx:
|
||||
# Use conditioning latent
|
||||
cond_idx = i - frame_idx
|
||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||
else:
|
||||
# Keep original
|
||||
latent_list.append(state.latent[:, :, i:i+1])
|
||||
clean_list.append(state.clean_latent[:, :, i:i+1])
|
||||
mask_list.append(state.denoise_mask[:, :, i:i+1])
|
||||
latent_list.append(state.latent[:, :, i : i + 1])
|
||||
clean_list.append(state.clean_latent[:, :, i : i + 1])
|
||||
mask_list.append(state.denoise_mask[:, :, i : i + 1])
|
||||
|
||||
state.latent = mx.concatenate(latent_list, axis=2)
|
||||
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
||||
|
||||
Reference in New Issue
Block a user