diff --git a/mlx_video/models/ltx_2/generate.py b/mlx_video/models/ltx_2/generate.py index 81b815f..08f0840 100644 --- a/mlx_video/models/ltx_2/generate.py +++ b/mlx_video/models/ltx_2/generate.py @@ -1452,9 +1452,9 @@ def generate_video( use_apg: bool = False, apg_eta: float = 1.0, apg_norm_threshold: float = 0.0, - stg_scale: float = 0.0, + stg_scale: float = 1.0, stg_blocks: Optional[list] = None, - modality_scale: float = 1.0, + modality_scale: float = 3.0, lora_path: Optional[str] = None, lora_strength: float = 1.0, lora_strength_stage_1: Optional[float] = None, @@ -2106,11 +2106,12 @@ def generate_video( # Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG # ====================================================================== - # HQ defaults + # HQ defaults: STG disabled, lower rescale, fewer steps (PyTorch LTX_2_3_HQ_PARAMS) hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25 hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5 hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45 hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15 + hq_stg_scale = stg_scale if stg_scale != 1.0 else 0.0 # Override default 1.0 → 0.0 # Load VAE encoder for I2V stage1_image_latent = None @@ -2201,7 +2202,7 @@ def generate_video( audio_cfg_scale=audio_cfg_scale, cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0, verbose=verbose, video_state=state1, - stg_scale=stg_scale, stg_video_blocks=stg_blocks, + stg_scale=hq_stg_scale, stg_video_blocks=stg_blocks, stg_audio_blocks=stg_blocks, modality_scale=modality_scale, noise_seed=seed, audio_frozen=is_a2v, @@ -2531,9 +2532,9 @@ Examples: parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") - parser.add_argument("--stg-scale", type=float, default=0.0, help="STG (Spatiotemporal Guidance) scale (default 0.0 = disabled, PyTorch default: 1.0)") + parser.add_argument("--stg-scale", type=float, default=1.0, help="STG (Spatiotemporal Guidance) scale (default 1.0, 0.0 = disabled)") parser.add_argument("--stg-blocks", type=int, nargs="+", default=None, help="Transformer block indices for STG perturbation (default: [29] for LTX-2, [28] for LTX-2.3)") - parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)") + parser.add_argument("--modality-scale", type=float, default=3.0, help="Cross-modal guidance scale (default 3.0, 1.0 = disabled)") parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")