fix loading
This commit is contained in:
@@ -594,11 +594,7 @@ def denoise_dev_av(
|
|||||||
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
|
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
|
||||||
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
|
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
|
||||||
|
|
||||||
# Dynamic CFG: compute per-step effective scale
|
if use_cfg:
|
||||||
step_cfg_scale = get_dynamic_cfg_scale(sigma, cfg_scale) if use_cfg else 1.0
|
|
||||||
apply_cfg_this_step = step_cfg_scale > 1.0
|
|
||||||
|
|
||||||
if apply_cfg_this_step:
|
|
||||||
# Negative conditioning pass
|
# Negative conditioning pass
|
||||||
video_modality_neg = Modality(
|
video_modality_neg = Modality(
|
||||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||||
@@ -620,8 +616,8 @@ def denoise_dev_av(
|
|||||||
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
||||||
# delta = (scale - 1) * (x0_pos - x0_neg)
|
# delta = (scale - 1) * (x0_pos - x0_neg)
|
||||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
||||||
video_x0_guided_f32 = video_x0_pos_f32 + (step_cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (step_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||||
|
|
||||||
# Apply CFG rescale if enabled
|
# Apply CFG rescale if enabled
|
||||||
if cfg_rescale > 0.0:
|
if cfg_rescale > 0.0:
|
||||||
@@ -659,8 +655,8 @@ def denoise_dev_av(
|
|||||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||||
else:
|
else:
|
||||||
video_latents = video_denoised
|
video_latents = video_denoised_f32
|
||||||
audio_latents = audio_denoised
|
audio_latents = audio_denoised_f32
|
||||||
|
|
||||||
mx.eval(video_latents, audio_latents)
|
mx.eval(video_latents, audio_latents)
|
||||||
progress.advance(task)
|
progress.advance(task)
|
||||||
@@ -1125,7 +1121,7 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / weight_file))
|
vae_decoder = VideoDecoder.from_pretrained("/Users/prince_canuma/Documents/mlx-video/LTX-2-dev/vae/decoder")
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|||||||
@@ -349,6 +349,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
# Build decoder weights dict with key remapping
|
# Build decoder weights dict with key remapping
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
if "per_channel_statistics.mean" in weights:
|
||||||
|
return weights
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
|
|
||||||
@@ -399,21 +401,21 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
config_dict = {}
|
||||||
|
|
||||||
|
# Load config from directory
|
||||||
|
config_path = model_path / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path) as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
|
||||||
if model_path.is_dir():
|
# Load weights from directory
|
||||||
# Load config from directory
|
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||||
config_path = model_path / "config.json"
|
if not weight_files:
|
||||||
if config_path.exists():
|
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||||
with open(config_path) as f:
|
weights = {}
|
||||||
config_dict = json.load(f)
|
for wf in weight_files:
|
||||||
|
weights.update(mx.load(str(wf)))
|
||||||
# Load weights from directory
|
|
||||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
|
||||||
if not weight_files:
|
|
||||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
|
||||||
weights = {}
|
|
||||||
for wf in weight_files:
|
|
||||||
weights.update(mx.load(str(wf)))
|
|
||||||
|
|
||||||
|
|
||||||
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
||||||
|
|||||||
Reference in New Issue
Block a user