diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 4df6fbe..7383464 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -545,6 +545,7 @@ def denoise_dev( transformer: LTXModel, sigmas: mx.array, cfg_scale: float = 4.0, + cfg_rescale: float = 0.0, verbose: bool = True, state: Optional[LatentState] = None, use_apg: bool = False, @@ -554,6 +555,9 @@ def denoise_dev( """Run denoising loop for dev pipeline with CFG or APG guidance. Args: + cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction + variance relative to conditional prediction to reduce over-saturation. + PyTorch default is 0.7. Set to 0.0 to disable. use_apg: Use Adaptive Projected Guidance instead of standard CFG. APG decomposes guidance into parallel/orthogonal components for more stable I2V generation. @@ -667,6 +671,14 @@ def denoise_dev( else: # Standard CFG x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) + + # Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) + # factor = rescale * (cond_std / pred_std) + (1 - rescale) + # pred = pred * factor + if cfg_rescale > 0.0: + v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8) + v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) + x0_guided_f32 = x0_guided_f32 * v_factor else: x0_guided_f32 = x0_pos_f32 @@ -1381,6 +1393,7 @@ def generate_video( latents, video_positions, video_embeddings_pos, video_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, state=video_state, use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold ) @@ -1477,6 +1490,7 @@ def generate_video( latents, positions, video_embeddings_pos, video_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, state=state1, use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold )