improve dev color and quality
This commit is contained in:
@@ -545,6 +545,7 @@ def denoise_dev(
|
|||||||
transformer: LTXModel,
|
transformer: LTXModel,
|
||||||
sigmas: mx.array,
|
sigmas: mx.array,
|
||||||
cfg_scale: float = 4.0,
|
cfg_scale: float = 4.0,
|
||||||
|
cfg_rescale: float = 0.0,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
state: Optional[LatentState] = None,
|
state: Optional[LatentState] = None,
|
||||||
use_apg: bool = False,
|
use_apg: bool = False,
|
||||||
@@ -554,6 +555,9 @@ def denoise_dev(
|
|||||||
"""Run denoising loop for dev pipeline with CFG or APG guidance.
|
"""Run denoising loop for dev pipeline with CFG or APG guidance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
|
||||||
|
variance relative to conditional prediction to reduce over-saturation.
|
||||||
|
PyTorch default is 0.7. Set to 0.0 to disable.
|
||||||
use_apg: Use Adaptive Projected Guidance instead of standard CFG.
|
use_apg: Use Adaptive Projected Guidance instead of standard CFG.
|
||||||
APG decomposes guidance into parallel/orthogonal components
|
APG decomposes guidance into parallel/orthogonal components
|
||||||
for more stable I2V generation.
|
for more stable I2V generation.
|
||||||
@@ -667,6 +671,14 @@ def denoise_dev(
|
|||||||
else:
|
else:
|
||||||
# Standard CFG
|
# Standard CFG
|
||||||
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
||||||
|
|
||||||
|
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
|
||||||
|
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
|
||||||
|
# pred = pred * factor
|
||||||
|
if cfg_rescale > 0.0:
|
||||||
|
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8)
|
||||||
|
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
||||||
|
x0_guided_f32 = x0_guided_f32 * v_factor
|
||||||
else:
|
else:
|
||||||
x0_guided_f32 = x0_pos_f32
|
x0_guided_f32 = x0_pos_f32
|
||||||
|
|
||||||
@@ -1381,6 +1393,7 @@ def generate_video(
|
|||||||
latents, video_positions,
|
latents, video_positions,
|
||||||
video_embeddings_pos, video_embeddings_neg,
|
video_embeddings_pos, video_embeddings_neg,
|
||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
|
cfg_rescale=cfg_rescale,
|
||||||
verbose=verbose, state=video_state,
|
verbose=verbose, state=video_state,
|
||||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
||||||
)
|
)
|
||||||
@@ -1477,6 +1490,7 @@ def generate_video(
|
|||||||
latents, positions,
|
latents, positions,
|
||||||
video_embeddings_pos, video_embeddings_neg,
|
video_embeddings_pos, video_embeddings_neg,
|
||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
|
cfg_rescale=cfg_rescale,
|
||||||
verbose=verbose, state=state1,
|
verbose=verbose, state=state1,
|
||||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user