fix loading
This commit is contained in:
@@ -594,11 +594,7 @@ def denoise_dev_av(
|
||||
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
|
||||
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
|
||||
|
||||
# Dynamic CFG: compute per-step effective scale
|
||||
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
|
||||
apply_cfg_this_step = step_cfg_scale > 1.0
|
||||
|
||||
if apply_cfg_this_step:
|
||||
if use_cfg:
|
||||
# Negative conditioning pass
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
@@ -620,8 +616,8 @@ def denoise_dev_av(
|
||||
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
||||
# delta = (scale - 1) * (x0_pos - x0_neg)
|
||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||
|
||||
# Apply CFG rescale if enabled
|
||||
if cfg_rescale > 0.0:
|
||||
@@ -659,8 +655,8 @@ def denoise_dev_av(
|
||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||
else:
|
||||
video_latents = video_denoised
|
||||
audio_latents = audio_denoised
|
||||
video_latents = video_denoised_f32
|
||||
audio_latents = audio_denoised_f32
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
@@ -1125,7 +1121,7 @@ def generate_video(
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
||||
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
@@ -349,6 +349,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
sanitized = {}
|
||||
if "per_channel_statistics.mean" in weights:
|
||||
return weights
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
@@ -399,21 +401,21 @@ class LTX2VideoDecoder(nn.Module):
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
config_dict = {}
|
||||
|
||||
if model_path.is_dir():
|
||||
# Load config from directory
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
# Load config from directory
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Load weights from directory
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
# Load weights from directory
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
||||
|
||||
Reference in New Issue
Block a user