fix loading

This commit is contained in:
Prince Canuma
2026-01-24 01:37:38 +01:00
parent ef76ec0921
commit cb2d19c84d
2 changed files with 22 additions and 24 deletions

View File

@@ -594,11 +594,7 @@ def denoise_dev_av(
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
# Dynamic CFG: compute per-step effective scale
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
apply_cfg_this_step = step_cfg_scale > 1.0
if apply_cfg_this_step:
if use_cfg:
# Negative conditioning pass
video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
@@ -620,8 +616,8 @@ def denoise_dev_av(
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
# delta = (scale - 1) * (x0_pos - x0_neg)
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
# Apply CFG rescale if enabled
if cfg_rescale > 0.0:
@@ -659,8 +655,8 @@ def denoise_dev_av(
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
else:
video_latents = video_denoised
audio_latents = audio_denoised
video_latents = video_denoised_f32
audio_latents = audio_denoised_f32
mx.eval(video_latents, audio_latents)
progress.advance(task)
@@ -1125,7 +1121,7 @@ def generate_video(
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
del transformer
mx.clear_cache()

View File

@@ -349,6 +349,8 @@ class LTX2VideoDecoder(nn.Module):
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
@@ -399,21 +401,21 @@ class LTX2VideoDecoder(nn.Module):
import json
model_path = Path(model_path)
config_dict = {}
if model_path.is_dir():
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))