Add Dev Two-Stage HQ pipeline mode
This commit is contained in:
39
README.md
39
README.md
@@ -24,7 +24,7 @@ Supported models:
|
||||
## Features
|
||||
|
||||
- Text-to-video (T2V) and Image-to-video (I2V) generation
|
||||
- Three pipeline modes: Distilled, Dev, and Dev Two-Stage
|
||||
- Four pipeline modes: Distilled, Dev, Dev Two-Stage, and Dev Two-Stage HQ
|
||||
- Synchronized audio-video generation (experimental)
|
||||
- LoRA support (including HuggingFace repos)
|
||||
- Prompt enhancement via Gemma
|
||||
@@ -35,13 +35,14 @@ Supported models:
|
||||
|
||||
### Pipelines
|
||||
|
||||
mlx-video supports three pipeline types via the `--pipeline` flag:
|
||||
mlx-video supports four pipeline types via the `--pipeline` flag:
|
||||
|
||||
| Pipeline | Description | CFG | Stages | Speed |
|
||||
|----------|-------------|-----|--------|-------|
|
||||
| `distilled` (default) | Fixed sigma schedule, no CFG | No | 2 (8+3 steps) | Fastest |
|
||||
| `dev` | Dynamic sigmas, constant CFG | Yes | 1 (30 steps) | Medium |
|
||||
| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slowest, highest quality |
|
||||
| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slow |
|
||||
| `dev-two-stage-hq` | res_2s sampler + LoRA both stages | Yes (stage 1) | 2 (15+3 steps) | Slow, highest quality |
|
||||
|
||||
### Text-to-Video
|
||||
|
||||
@@ -52,13 +53,24 @@ uv run mlx_video.generate --prompt "Two dogs wearing sunglasses, cinematic, suns
|
||||
# Dev - single-stage with CFG
|
||||
uv run mlx_video.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0
|
||||
|
||||
# Dev two-stage - dev + LoRA refinement (highest quality)
|
||||
# Dev two-stage - dev + LoRA refinement
|
||||
uv run mlx_video.generate --pipeline dev-two-stage \
|
||||
--prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" \
|
||||
-n 145 --width 1024 --height 768 \
|
||||
--model-repo prince-canuma/LTX-2-dev \
|
||||
--cfg-scale 3.0 --lora-strength 0.8 \
|
||||
--enhance-prompt
|
||||
|
||||
# Dev two-stage HQ - res_2s sampler, LoRA both stages (highest quality)
|
||||
uv run mlx_video.generate --pipeline dev-two-stage-hq \
|
||||
--prompt "A cinematic scene of ocean waves at golden hour" \
|
||||
--model-repo prince-canuma/LTX-2-dev
|
||||
|
||||
# HQ with custom LoRA strengths
|
||||
uv run mlx_video.generate --pipeline dev-two-stage-hq \
|
||||
--prompt "A sunset over mountains" \
|
||||
--model-repo prince-canuma/LTX-2-dev \
|
||||
--lora-strength-stage-1 0.3 --lora-strength-stage-2 0.6
|
||||
```
|
||||
|
||||
<img src="https://github.com/Blaizzy/mlx-video/raw/main/examples/poodles.gif" width="512" alt="Poodles demo">
|
||||
@@ -124,7 +136,7 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--prompt`, `-p` | (required) | Text description of the video |
|
||||
| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, or `dev-two-stage` |
|
||||
| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, `dev-two-stage`, or `dev-two-stage-hq` |
|
||||
| `--height`, `-H` | 512 | Output height (divisible by 64 for two-stage, 32 for dev) |
|
||||
| `--width`, `-W` | 512 | Output width (divisible by 64 for two-stage, 32 for dev) |
|
||||
| `--num-frames`, `-n` | 33 | Number of frames (must be 1 + 8*k) |
|
||||
@@ -161,6 +173,15 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
|
||||
| `--lora-path` | auto-detect | Path to LoRA file, directory, or HuggingFace repo |
|
||||
| `--lora-strength` | 1.0 | LoRA merge strength |
|
||||
|
||||
**Dev-Two-Stage HQ options:**
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--lora-strength-stage-1` | 0.25 | LoRA strength for stage 1 |
|
||||
| `--lora-strength-stage-2` | 0.5 | LoRA strength for stage 2 |
|
||||
|
||||
HQ defaults: 15 steps (vs 30), `cfg-rescale` 0.45 (vs 0.7), STG disabled. Uses the res_2s second-order sampler (2 model evals per step) for better quality at the same compute budget.
|
||||
|
||||
## How It Works
|
||||
|
||||
### Distilled Pipeline (default)
|
||||
@@ -179,6 +200,14 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
|
||||
3. **Stage 2**: Distilled refinement at full resolution with LoRA weights (3 steps, no CFG)
|
||||
4. **Decode**: VAE decoder converts latents to RGB video
|
||||
|
||||
### Dev Two-Stage HQ Pipeline
|
||||
1. **Stage 1**: res_2s denoising at half resolution with CFG + LoRA@0.25 (15 steps, 2 evals/step)
|
||||
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
|
||||
3. **Stage 2**: res_2s refinement at full resolution with LoRA@0.5 (3 steps, no CFG)
|
||||
4. **Decode**: VAE decoder converts latents to RGB video
|
||||
|
||||
The res_2s sampler uses an exponential Rosenbrock-type Runge-Kutta integrator with SDE noise injection, producing higher quality results than Euler at the same compute budget (~30 total model evaluations).
|
||||
|
||||
## Requirements
|
||||
|
||||
- macOS with Apple Silicon
|
||||
|
||||
@@ -38,6 +38,7 @@ class PipelineType(Enum):
|
||||
DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG
|
||||
DEV = "dev" # Single-stage, dynamic sigmas, CFG
|
||||
DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res)
|
||||
DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages
|
||||
|
||||
|
||||
# Distilled model sigma schedules
|
||||
@@ -1012,6 +1013,329 @@ def denoise_dev_av(
|
||||
return video_latents, audio_latents
|
||||
|
||||
|
||||
def denoise_res2s_av(
|
||||
video_latents: mx.array,
|
||||
audio_latents: mx.array,
|
||||
video_positions: mx.array,
|
||||
audio_positions: mx.array,
|
||||
video_embeddings_pos: mx.array,
|
||||
video_embeddings_neg: mx.array,
|
||||
audio_embeddings_pos: mx.array,
|
||||
audio_embeddings_neg: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: mx.array,
|
||||
cfg_scale: float = 3.0,
|
||||
audio_cfg_scale: float = 7.0,
|
||||
cfg_rescale: float = 0.45,
|
||||
audio_cfg_rescale: Optional[float] = None,
|
||||
verbose: bool = True,
|
||||
video_state: Optional[LatentState] = None,
|
||||
stg_scale: float = 0.0,
|
||||
stg_video_blocks: Optional[list] = None,
|
||||
stg_audio_blocks: Optional[list] = None,
|
||||
modality_scale: float = 1.0,
|
||||
noise_seed: int = 42,
|
||||
bongmath: bool = True,
|
||||
bongmath_max_iter: int = 100,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
|
||||
|
||||
Two model evaluations per step (current point + midpoint), with SDE noise
|
||||
injection and optional bong iteration for anchor refinement.
|
||||
|
||||
Args:
|
||||
audio_cfg_rescale: Separate rescale for audio. If None, uses cfg_rescale.
|
||||
noise_seed: Seed for SDE noise generators.
|
||||
bongmath: Enable iterative anchor refinement.
|
||||
bongmath_max_iter: Max bong iterations per step.
|
||||
"""
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
from mlx_video.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise
|
||||
|
||||
if audio_cfg_rescale is None:
|
||||
audio_cfg_rescale = cfg_rescale
|
||||
|
||||
dtype = video_latents.dtype
|
||||
if video_state is not None:
|
||||
video_latents = video_state.latent
|
||||
|
||||
video_latents = video_latents.astype(mx.float32)
|
||||
audio_latents = audio_latents.astype(mx.float32)
|
||||
|
||||
sigmas_list = sigmas.tolist()
|
||||
use_cfg = cfg_scale != 1.0
|
||||
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
|
||||
use_modality = modality_scale != 1.0
|
||||
n_full_steps = len(sigmas_list) - 1
|
||||
|
||||
# Pad sigmas if last is 0 (avoid division by zero in RK steps)
|
||||
if sigmas_list[-1] == 0:
|
||||
sigmas_list = sigmas_list[:-1] + [0.0011, 0.0]
|
||||
|
||||
# Compute step sizes in log-space for the main loop steps only.
|
||||
# After padding, sigmas_list may have an extra [0.0011, 0.0] tail;
|
||||
# we only need hs for the n_full_steps pairs the loop actually uses.
|
||||
hs = [-math.log(sigmas_list[i + 1] / sigmas_list[i]) for i in range(n_full_steps)]
|
||||
|
||||
# Precompute RoPE
|
||||
precomputed_video_rope = precompute_freqs_cis(
|
||||
video_positions,
|
||||
dim=transformer.inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
||||
num_attention_heads=transformer.num_attention_heads,
|
||||
rope_type=transformer.rope_type,
|
||||
double_precision=transformer.config.double_precision_rope,
|
||||
)
|
||||
precomputed_audio_rope = precompute_freqs_cis(
|
||||
audio_positions,
|
||||
dim=transformer.audio_inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
||||
num_attention_heads=transformer.audio_num_attention_heads,
|
||||
rope_type=transformer.rope_type,
|
||||
double_precision=transformer.config.double_precision_rope,
|
||||
)
|
||||
mx.eval(precomputed_video_rope, precomputed_audio_rope)
|
||||
|
||||
phi_cache = {}
|
||||
c2 = 0.5
|
||||
|
||||
# Noise key management: step noise and substep noise use different keys
|
||||
step_noise_key = mx.random.key(noise_seed)
|
||||
substep_noise_key = mx.random.key(noise_seed + 10000)
|
||||
|
||||
def _eval_guided_denoise(v_latents, a_latents, sigma):
|
||||
"""Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format."""
|
||||
b, c, f, h, w = v_latents.shape
|
||||
num_video_tokens = f * h * w
|
||||
video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
||||
|
||||
ab, ac, at, af = a_latents.shape
|
||||
audio_flat = mx.transpose(a_latents, (0, 2, 1, 3))
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
||||
|
||||
# Timesteps
|
||||
if video_state is not None:
|
||||
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
|
||||
# Pass 1: Positive conditioning
|
||||
video_modality_pos = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_pos, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||
)
|
||||
audio_modality_pos = Modality(
|
||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||
)
|
||||
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
||||
mx.eval(video_vel_pos, audio_vel_pos)
|
||||
|
||||
# Convert velocity to x0
|
||||
video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1))
|
||||
audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af))
|
||||
video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
|
||||
audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
|
||||
|
||||
video_x0_pos = video_flat_f32 - video_ts_f32 * video_vel_pos.astype(mx.float32)
|
||||
audio_x0_pos = audio_flat_f32 - audio_ts_f32 * audio_vel_pos.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_pos
|
||||
audio_x0_guided = audio_x0_pos
|
||||
|
||||
# Pass 2: CFG
|
||||
if use_cfg:
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_neg, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||
)
|
||||
audio_modality_neg = Modality(
|
||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||
)
|
||||
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
||||
mx.eval(video_vel_neg, audio_vel_neg)
|
||||
|
||||
video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32)
|
||||
audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg)
|
||||
audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg)
|
||||
|
||||
# Pass 3: STG
|
||||
if use_stg:
|
||||
video_vel_ptb, audio_vel_ptb = transformer(
|
||||
video=video_modality_pos, audio=audio_modality_pos,
|
||||
stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
|
||||
)
|
||||
mx.eval(video_vel_ptb, audio_vel_ptb)
|
||||
|
||||
video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32)
|
||||
audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb)
|
||||
audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb)
|
||||
|
||||
# Pass 4: Modality isolation
|
||||
if use_modality:
|
||||
video_vel_iso, audio_vel_iso = transformer(
|
||||
video=video_modality_pos, audio=audio_modality_pos,
|
||||
skip_cross_modal=True,
|
||||
)
|
||||
mx.eval(video_vel_iso, audio_vel_iso)
|
||||
|
||||
video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32)
|
||||
audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso)
|
||||
audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso)
|
||||
|
||||
# Rescale (separate factors for video and audio)
|
||||
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
||||
v_factor = video_x0_pos.std() / (video_x0_guided.std() + 1e-8)
|
||||
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
||||
video_x0_guided = video_x0_guided * v_factor
|
||||
if audio_cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
||||
a_factor = audio_x0_pos.std() / (audio_x0_guided.std() + 1e-8)
|
||||
a_factor = audio_cfg_rescale * a_factor + (1.0 - audio_cfg_rescale)
|
||||
audio_x0_guided = audio_x0_guided * a_factor
|
||||
|
||||
# Reshape to spatial
|
||||
video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w))
|
||||
audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af))
|
||||
audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3))
|
||||
|
||||
# Post-process with mask
|
||||
if video_state is not None:
|
||||
clean_f32 = video_state.clean_latent.astype(mx.float32)
|
||||
mask_f32 = video_state.denoise_mask.astype(mx.float32)
|
||||
video_denoised = video_denoised * mask_f32 + clean_f32 * (1.0 - mask_f32)
|
||||
|
||||
mx.eval(video_denoised, audio_denoised)
|
||||
return video_denoised, audio_denoised
|
||||
|
||||
# Main res_2s loop
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
disable=not verbose,
|
||||
) as progress:
|
||||
passes = ["res2s"]
|
||||
if use_cfg: passes.append("CFG")
|
||||
if use_stg: passes.append("STG")
|
||||
if use_modality: passes.append("Mod")
|
||||
label = "+".join(passes)
|
||||
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps)
|
||||
|
||||
for step_idx in range(n_full_steps):
|
||||
sigma = sigmas_list[step_idx]
|
||||
sigma_next = sigmas_list[step_idx + 1]
|
||||
h = hs[step_idx]
|
||||
|
||||
# Initialize anchor
|
||||
x_anchor_video = video_latents
|
||||
x_anchor_audio = audio_latents
|
||||
|
||||
# ============================================================
|
||||
# Stage 1: Evaluate denoiser at current sigma
|
||||
# ============================================================
|
||||
denoised_video_1, denoised_audio_1 = _eval_guided_denoise(
|
||||
video_latents, audio_latents, sigma
|
||||
)
|
||||
|
||||
# RK coefficients
|
||||
a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
|
||||
|
||||
# Substep sigma (geometric midpoint for c2=0.5)
|
||||
sub_sigma = math.sqrt(sigma * sigma_next)
|
||||
|
||||
# Compute midpoint
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
|
||||
x_mid_video = x_anchor_video + h * a21 * eps_1_video
|
||||
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
|
||||
|
||||
# SDE noise injection at substep
|
||||
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
|
||||
substep_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
|
||||
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
|
||||
mx.eval(x_mid_video, x_mid_audio)
|
||||
|
||||
# ============================================================
|
||||
# Bong iteration: refine anchor (pure arithmetic, no model calls)
|
||||
# ============================================================
|
||||
if bongmath and h < 0.5 and sigma > 0.03:
|
||||
for _ in range(bongmath_max_iter):
|
||||
x_anchor_video = x_mid_video - h * a21 * eps_1_video
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
|
||||
|
||||
# ============================================================
|
||||
# Stage 2: Evaluate denoiser at midpoint sigma
|
||||
# ============================================================
|
||||
denoised_video_2, denoised_audio_2 = _eval_guided_denoise(
|
||||
x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Final combination with RK coefficients
|
||||
# ============================================================
|
||||
eps_2_video = denoised_video_2 - x_anchor_video
|
||||
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
||||
|
||||
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
|
||||
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
||||
|
||||
# SDE noise injection at step level
|
||||
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
|
||||
step_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
|
||||
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
||||
|
||||
video_latents = x_next_video.astype(mx.float32)
|
||||
audio_latents = x_next_audio.astype(mx.float32)
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
|
||||
# Final clean step if original schedule ended at 0
|
||||
if sigmas.tolist()[-1] == 0:
|
||||
denoised_video, denoised_audio = _eval_guided_denoise(
|
||||
video_latents, audio_latents, sigmas_list[n_full_steps]
|
||||
)
|
||||
video_latents = denoised_video
|
||||
audio_latents = denoised_audio
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
return video_latents, audio_latents
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Audio Loading and Processing
|
||||
# =============================================================================
|
||||
@@ -1117,13 +1441,16 @@ def generate_video(
|
||||
modality_scale: float = 1.0,
|
||||
lora_path: Optional[str] = None,
|
||||
lora_strength: float = 1.0,
|
||||
lora_strength_stage_1: Optional[float] = None,
|
||||
lora_strength_stage_2: Optional[float] = None,
|
||||
):
|
||||
"""Generate video using LTX-2 models.
|
||||
|
||||
Supports three pipelines:
|
||||
Supports four pipelines:
|
||||
- DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG
|
||||
- DEV: Single-stage generation with dynamic sigmas and CFG
|
||||
- DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG)
|
||||
- DEV_TWO_STAGE_HQ: res_2s sampler, LoRA both stages (0.25/0.5), lower rescale
|
||||
|
||||
Args:
|
||||
model_repo: Model repository ID
|
||||
@@ -1158,7 +1485,7 @@ def generate_video(
|
||||
start_time = time.time()
|
||||
|
||||
# Validate dimensions
|
||||
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE)
|
||||
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ)
|
||||
divisor = 64 if is_two_stage else 32
|
||||
assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}"
|
||||
assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}"
|
||||
@@ -1177,13 +1504,14 @@ def generate_video(
|
||||
PipelineType.DISTILLED: "DISTILLED",
|
||||
PipelineType.DEV: "DEV",
|
||||
PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE",
|
||||
PipelineType.DEV_TWO_STAGE_HQ: "DEV-TWO-STAGE-HQ",
|
||||
}
|
||||
pipeline_name = pipeline_names[pipeline]
|
||||
header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height} • {num_frames} frames[/]"
|
||||
console.print(Panel(header, expand=False))
|
||||
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
|
||||
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
||||
audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else ""
|
||||
stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else ""
|
||||
mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else ""
|
||||
@@ -1237,14 +1565,14 @@ def generate_video(
|
||||
|
||||
# Encode prompts - always get audio embeddings since the model was trained
|
||||
# with joint audio-video processing (PyTorch unconditionally generates audio)
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
||||
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
|
||||
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
||||
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
||||
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
|
||||
if pipeline == PipelineType.DEV_TWO_STAGE:
|
||||
if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
||||
text_embeddings = video_embeddings_pos
|
||||
else:
|
||||
# Distilled pipeline - single embedding
|
||||
@@ -1655,6 +1983,190 @@ def generate_video(
|
||||
audio_embeddings=audio_embeddings_pos,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
|
||||
# ======================================================================
|
||||
# DEV TWO-STAGE HQ PIPELINE:
|
||||
# Stage 1: res_2s denoising at half resolution with CFG + LoRA@0.25
|
||||
# Upsample: 2x spatial via LatentUpsampler
|
||||
# Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG
|
||||
# ======================================================================
|
||||
|
||||
# HQ defaults
|
||||
hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25
|
||||
hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5
|
||||
hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45
|
||||
hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15
|
||||
|
||||
# Load VAE encoder for I2V
|
||||
stage1_image_latent = None
|
||||
stage2_image_latent = None
|
||||
if is_i2v:
|
||||
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
|
||||
del vae_encoder
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Auto-detect and merge LoRA for stage 1 (strength 0.25)
|
||||
if lora_path is None:
|
||||
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
|
||||
if lora_files:
|
||||
lora_path = str(lora_files[0])
|
||||
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
|
||||
else:
|
||||
console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]")
|
||||
|
||||
if lora_path is not None:
|
||||
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
|
||||
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
|
||||
|
||||
# Stage 1: res_2s denoising at half resolution with CFG
|
||||
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
|
||||
num_tokens = latent_frames * stage1_h * stage1_w
|
||||
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
|
||||
|
||||
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
||||
mx.random.seed(seed)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
state1 = None
|
||||
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||
if is_i2v and stage1_image_latent is not None:
|
||||
state1 = LatentState(
|
||||
latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
||||
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
||||
state1 = apply_conditioning(state1, [conditioning])
|
||||
|
||||
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
|
||||
noise_scale = sigmas[0]
|
||||
scaled_mask = state1.denoise_mask * noise_scale
|
||||
state1 = LatentState(
|
||||
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state1.clean_latent,
|
||||
denoise_mask=state1.denoise_mask,
|
||||
)
|
||||
latents = state1.latent
|
||||
mx.eval(latents)
|
||||
else:
|
||||
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
|
||||
# Stage 1: res_2s with CFG (STG disabled for HQ by default)
|
||||
latents, audio_latents = denoise_res2s_av(
|
||||
latents, audio_latents,
|
||||
positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
audio_cfg_scale=audio_cfg_scale,
|
||||
cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0,
|
||||
verbose=verbose, video_state=state1,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
noise_seed=seed,
|
||||
)
|
||||
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Upsample latents 2x
|
||||
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if not upscaler_files:
|
||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
||||
mx.eval(latents)
|
||||
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] Latents upsampled")
|
||||
|
||||
# Merge additional LoRA for stage 2 (additive: 0.25 + 0.25 = 0.5 total)
|
||||
if lora_path is not None:
|
||||
additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1
|
||||
if additional_strength > 0:
|
||||
with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"):
|
||||
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
|
||||
|
||||
# Stage 2: res_2s refinement at full resolution (no CFG)
|
||||
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
state2 = None
|
||||
if is_i2v and stage2_image_latent is not None:
|
||||
state2 = LatentState(
|
||||
latent=latents,
|
||||
clean_latent=mx.zeros_like(latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
||||
state2 = apply_conditioning(state2, [conditioning])
|
||||
|
||||
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
scaled_mask = state2.denoise_mask * noise_scale
|
||||
state2 = LatentState(
|
||||
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state2.clean_latent,
|
||||
denoise_mask=state2.denoise_mask,
|
||||
)
|
||||
latents = state2.latent
|
||||
mx.eval(latents)
|
||||
else:
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||
latents = noise * noise_scale + latents * one_minus_scale
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement
|
||||
if audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Stage 2: res_2s with no CFG (positive embeddings only)
|
||||
stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32)
|
||||
latents, audio_latents = denoise_res2s_av(
|
||||
latents, audio_latents,
|
||||
positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2)
|
||||
audio_embeddings_pos, audio_embeddings_pos,
|
||||
transformer, stage2_sigmas, cfg_scale=1.0, # no CFG
|
||||
audio_cfg_scale=1.0,
|
||||
cfg_rescale=0.0, verbose=verbose, video_state=state2,
|
||||
noise_seed=seed + 1,
|
||||
)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
@@ -1857,8 +2369,8 @@ Examples:
|
||||
)
|
||||
|
||||
parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate")
|
||||
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"],
|
||||
help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)")
|
||||
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"],
|
||||
help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)")
|
||||
parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
|
||||
help="Negative prompt for CFG (dev pipeline only)")
|
||||
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
|
||||
@@ -1895,12 +2407,15 @@ Examples:
|
||||
parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)")
|
||||
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
|
||||
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
|
||||
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
|
||||
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
|
||||
args = parser.parse_args()
|
||||
|
||||
pipeline_map = {
|
||||
"distilled": PipelineType.DISTILLED,
|
||||
"dev": PipelineType.DEV,
|
||||
"dev-two-stage": PipelineType.DEV_TWO_STAGE,
|
||||
"dev-two-stage-hq": PipelineType.DEV_TWO_STAGE_HQ,
|
||||
}
|
||||
pipeline = pipeline_map[args.pipeline]
|
||||
|
||||
@@ -1940,6 +2455,8 @@ Examples:
|
||||
modality_scale=args.modality_scale,
|
||||
lora_path=args.lora_path,
|
||||
lora_strength=args.lora_strength,
|
||||
lora_strength_stage_1=args.lora_strength_stage_1,
|
||||
lora_strength_stage_2=args.lora_strength_stage_2,
|
||||
)
|
||||
|
||||
|
||||
|
||||
181
mlx_video/samplers.py
Normal file
181
mlx_video/samplers.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Second-order res_2s sampler for diffusion models.
|
||||
|
||||
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
|
||||
noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
|
||||
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
||||
|
||||
phi_1(z) = (e^z - 1) / z
|
||||
phi_2(z) = (e^z - 1 - z) / z^2
|
||||
phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j
|
||||
"""
|
||||
if abs(neg_h) < 1e-10:
|
||||
return 1.0 / math.factorial(j)
|
||||
|
||||
remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
|
||||
return (math.exp(neg_h) - remainder) / (neg_h**j)
|
||||
|
||||
|
||||
def get_res2s_coefficients(
|
||||
h: float,
|
||||
phi_cache: dict,
|
||||
c2: float = 0.5,
|
||||
) -> tuple[float, float, float]:
|
||||
"""Compute res_2s Runge-Kutta coefficients for a given step size.
|
||||
|
||||
Args:
|
||||
h: Step size in log-space = log(sigma / sigma_next)
|
||||
phi_cache: Dictionary to cache phi function results.
|
||||
c2: Substep position (default 0.5 = midpoint)
|
||||
|
||||
Returns:
|
||||
(a21, b1, b2): RK coefficients.
|
||||
"""
|
||||
def get_phi(j: int, neg_h: float) -> float:
|
||||
cache_key = (j, neg_h)
|
||||
if cache_key in phi_cache:
|
||||
return phi_cache[cache_key]
|
||||
result = phi(j, neg_h)
|
||||
phi_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
neg_h_c2 = -h * c2
|
||||
phi_1_c2 = get_phi(1, neg_h_c2)
|
||||
a21 = c2 * phi_1_c2
|
||||
|
||||
neg_h_full = -h
|
||||
phi_2_full = get_phi(2, neg_h_full)
|
||||
b2 = phi_2_full / c2
|
||||
|
||||
phi_1_full = get_phi(1, neg_h_full)
|
||||
b1 = phi_1_full - b2
|
||||
|
||||
return a21, b1, b2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_sde_coeff(
|
||||
sigma_next: float,
|
||||
) -> tuple[float, float, float]:
|
||||
"""Compute SDE coefficients for variance-preserving noise injection.
|
||||
|
||||
Uses sigma_up = sigma_next * 0.5 (hardcoded in PyTorch Res2sDiffusionStep).
|
||||
|
||||
Returns:
|
||||
(alpha_ratio, sigma_down, sigma_up)
|
||||
"""
|
||||
sigma_up = sigma_next * 0.5
|
||||
# Clamp sigma_up to avoid sqrt(negative)
|
||||
sigma_up = min(sigma_up, sigma_next * 0.9999)
|
||||
|
||||
sigma_signal = 1.0 - sigma_next # sigma_max=1
|
||||
sigma_residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
|
||||
alpha_ratio = sigma_signal + sigma_residual
|
||||
|
||||
if alpha_ratio == 0:
|
||||
sigma_down = sigma_next
|
||||
else:
|
||||
sigma_down = sigma_residual / alpha_ratio
|
||||
|
||||
# Handle NaN edge cases
|
||||
if math.isnan(sigma_up):
|
||||
sigma_up = 0.0
|
||||
if math.isnan(sigma_down):
|
||||
sigma_down = sigma_next
|
||||
if math.isnan(alpha_ratio):
|
||||
alpha_ratio = 1.0
|
||||
|
||||
return alpha_ratio, sigma_down, sigma_up
|
||||
|
||||
|
||||
def sde_noise_step(
|
||||
sample: mx.array,
|
||||
denoised_sample: mx.array,
|
||||
sigma: float,
|
||||
sigma_next: float,
|
||||
noise: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply SDE noise injection step.
|
||||
|
||||
Advances sample from sigma to sigma_next with stochastic noise injection.
|
||||
|
||||
Args:
|
||||
sample: Current sample (anchor point)
|
||||
denoised_sample: Denoised prediction at this step
|
||||
sigma: Current noise level
|
||||
sigma_next: Next noise level
|
||||
noise: Pre-generated noise tensor (channel-wise normalized)
|
||||
|
||||
Returns:
|
||||
Noised sample at sigma_next
|
||||
"""
|
||||
alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)
|
||||
|
||||
if sigma_up == 0 or sigma_next == 0:
|
||||
return denoised_sample
|
||||
|
||||
# Float32 arithmetic
|
||||
sample_f32 = sample.astype(mx.float32)
|
||||
denoised_f32 = denoised_sample.astype(mx.float32)
|
||||
noise_f32 = noise.astype(mx.float32)
|
||||
|
||||
# Extract epsilon prediction
|
||||
eps_next = (sample_f32 - denoised_f32) / (sigma - sigma_next)
|
||||
denoised_next = sample_f32 - sigma * eps_next
|
||||
|
||||
# Mix deterministic and stochastic components
|
||||
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
|
||||
return x_noised
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
|
||||
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
||||
|
||||
Operates on the last 2 dimensions (spatial H, W or time, freq).
|
||||
"""
|
||||
mean = mx.mean(x, axis=(-2, -1), keepdims=True)
|
||||
x = x - mean
|
||||
std = mx.sqrt(mx.mean(x * x, axis=(-2, -1), keepdims=True) + 1e-8)
|
||||
x = x / std
|
||||
return x
|
||||
|
||||
|
||||
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
|
||||
"""Generate channel-wise normalized Gaussian noise.
|
||||
|
||||
PyTorch uses float64; we use float32 (MLX doesn't support float64).
|
||||
The channel-wise normalization is the key quality-affecting step.
|
||||
|
||||
Args:
|
||||
shape: Shape of the noise tensor
|
||||
key: MLX random key for deterministic generation
|
||||
|
||||
Returns:
|
||||
Channel-wise normalized noise in float32
|
||||
"""
|
||||
noise = mx.random.normal(shape, dtype=mx.float32, key=key)
|
||||
# Global normalization
|
||||
noise = (noise - mx.mean(noise)) / (mx.sqrt(mx.mean(noise * noise)) + 1e-8)
|
||||
# Channel-wise normalization
|
||||
noise = channelwise_normalize(noise)
|
||||
return noise
|
||||
Reference in New Issue
Block a user