fix tiling, rope precision and weights

This commit is contained in:
Prince Canuma
2026-03-15 22:58:55 +01:00
parent ebcd5dd4e4
commit cecd68197c
5 changed files with 86 additions and 149 deletions

View File

@@ -1192,9 +1192,11 @@ def generate_video(
if is_i2v:
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
audio_frames = None
# Always compute audio frames - PyTorch distilled pipeline unconditionally
# generates audio alongside video (model was trained with joint audio-video).
# The --audio flag only controls whether audio is decoded and saved to output.
audio_frames = compute_audio_frames(num_frames, fps)
if audio:
audio_frames = compute_audio_frames(num_frames, fps)
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
# Get model path
@@ -1233,32 +1235,21 @@ def generate_video(
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
# Encode prompts
# Encode prompts - always get audio embeddings since the model was trained
# with joint audio-video processing (PyTorch unconditionally generates audio)
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
if audio:
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
else:
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
audio_embeddings_pos = audio_embeddings_neg = None
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg)
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
if pipeline == PipelineType.DEV_TWO_STAGE:
text_embeddings = video_embeddings_pos
else:
# Distilled pipeline - single embedding
if audio:
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
mx.eval(text_embeddings, audio_embeddings)
else:
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
audio_embeddings = None
mx.eval(text_embeddings)
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
mx.eval(text_embeddings, audio_embeddings)
model_dtype = text_embeddings.dtype
del text_encoder
@@ -1317,12 +1308,10 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = None
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents)
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning
state1 = None
@@ -1406,7 +1395,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio and audio_latents is not None:
if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1417,7 +1406,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings if audio else None,
audio_embeddings=audio_embeddings,
)
elif pipeline == PipelineType.DEV:
@@ -1451,12 +1440,10 @@ def generate_video(
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
mx.eval(video_positions)
audio_positions = None
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Initialize latents with optional I2V conditioning
video_state = None
@@ -1484,31 +1471,19 @@ def generate_video(
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
mx.eval(latents)
# Denoise with CFG/APG/STG/modality
if audio:
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
video_positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
)
else:
# Use original denoise_dev with computed sigmas
latents = denoise_dev(
latents, video_positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
# Always use A/V denoising - PyTorch always processes audio+video jointly
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
video_positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
@@ -1553,12 +1528,10 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = None
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
state1 = None
@@ -1586,33 +1559,21 @@ def generate_video(
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents)
# Stage 1: Joint AV denoising at half resolution (matches PyTorch)
if audio:
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
)
else:
latents = denoise_dev(
latents, positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
# Stage 1: Always use joint AV denoising (matches PyTorch)
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
)
if audio and audio_latents is not None:
mx.eval(audio_latents)
mx.eval(audio_latents)
# Upsample latents 2x
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
@@ -1680,7 +1641,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio and audio_latents is not None:
if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1691,7 +1652,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos if audio else None,
audio_embeddings=audio_embeddings_pos,
)
del transformer