Add support for DEV_TWO_STAGE pipeline and implement LoRA merging functionality in generate.py. Enhance video generation capabilities by allowing LoRA weights to be loaded and merged into the model, improving flexibility in model configurations. Update pipeline handling to accommodate the new two-stage generation process.

This commit is contained in:
Prince Canuma
2026-03-13 01:22:45 +01:00
parent e0aafd72fc
commit 7435facc52

View File

@@ -37,6 +37,7 @@ class PipelineType(Enum):
"""Pipeline type selector.""" """Pipeline type selector."""
DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG
DEV = "dev" # Single-stage, dynamic sigmas, CFG DEV = "dev" # Single-stage, dynamic sigmas, CFG
DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res)
# Distilled model sigma schedules # Distilled model sigma schedules
@@ -61,6 +62,111 @@ AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_L
DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted" DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
def load_and_merge_lora(
model: LTXModel,
lora_path: str,
strength: float = 1.0,
) -> None:
"""Load LoRA weights and merge them into the transformer model in-place.
Supports two formats:
- Raw PyTorch: keys like diffusion_model.{module}.lora_A.weight (needs sanitization)
- Pre-converted MLX: keys like {module}.lora_A.weight (already sanitized)
Merge formula: weight += (lora_B * strength) @ lora_A
Args:
model: The LTXModel transformer to merge into
lora_path: Path to the LoRA safetensors file or directory containing one
strength: LoRA strength/coefficient (default 1.0)
"""
# Resolve path: if directory, find the safetensors file inside
lora_file = Path(lora_path)
if lora_file.is_dir():
candidates = sorted(lora_file.glob("*.safetensors"))
if not candidates:
raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
lora_file = candidates[0]
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
# Load LoRA weights
lora_weights = mx.load(str(lora_file))
# Detect format: raw PyTorch has 'diffusion_model.' prefix
has_prefix = any(k.startswith("diffusion_model.") for k in lora_weights)
# Group into A/B pairs by module name
lora_pairs = {}
for key in lora_weights:
module_key = key
if has_prefix:
if not key.startswith("diffusion_model."):
continue
module_key = key.replace("diffusion_model.", "")
if module_key.endswith(".lora_A.weight"):
base_key = module_key.replace(".lora_A.weight", "")
lora_pairs.setdefault(base_key, {})["A"] = lora_weights[key]
elif module_key.endswith(".lora_B.weight"):
base_key = module_key.replace(".lora_B.weight", "")
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
# Apply key sanitization only for raw PyTorch format
if has_prefix:
sanitized_pairs = {}
for key, pair in lora_pairs.items():
new_key = key
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
sanitized_pairs[new_key] = pair
else:
sanitized_pairs = lora_pairs
# Get current model weights as a flat dict
def flatten_params(params, prefix=""):
flat = {}
for k, v in params.items():
full_key = f"{prefix}.{k}" if prefix else k
if isinstance(v, dict):
flat.update(flatten_params(v, full_key))
else:
flat[full_key] = v
return flat
flat_weights = flatten_params(dict(model.parameters()))
# Merge LoRA deltas
merged_count = 0
updates = []
for module_key, pair in sanitized_pairs.items():
if "A" not in pair or "B" not in pair:
continue
weight_key = f"{module_key}.weight"
if weight_key not in flat_weights:
continue
lora_a = pair["A"].astype(mx.float32) # (rank, in_features)
lora_b = pair["B"].astype(mx.float32) # (out_features, rank)
# delta = (lora_B * strength) @ lora_A
delta = (lora_b * strength) @ lora_a
base_weight = flat_weights[weight_key].astype(mx.float32)
merged_weight = base_weight + delta
updates.append((weight_key, merged_weight.astype(mx.bfloat16)))
merged_count += 1
model.load_weights(updates, strict=False)
mx.eval(model.parameters())
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
"""Compute CFG delta for classifier-free guidance. """Compute CFG delta for classifier-free guidance.
@@ -888,12 +994,15 @@ def generate_video(
use_apg: bool = False, use_apg: bool = False,
apg_eta: float = 1.0, apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0, apg_norm_threshold: float = 0.0,
lora_path: Optional[str] = None,
lora_strength: float = 1.0,
): ):
"""Generate video using LTX-2 models. """Generate video using LTX-2 models.
Supports two pipelines: Supports three pipelines:
- DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG - DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG
- DEV: Single-stage generation with dynamic sigmas and CFG - DEV: Single-stage generation with dynamic sigmas and CFG
- DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG)
Args: Args:
model_repo: Model repository ID model_repo: Model repository ID
@@ -928,7 +1037,8 @@ def generate_video(
start_time = time.time() start_time = time.time()
# Validate dimensions # Validate dimensions
divisor = 64 if pipeline == PipelineType.DISTILLED else 32 is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE)
divisor = 64 if is_two_stage else 32
assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}" assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}"
assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}" assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}"
@@ -942,12 +1052,17 @@ def generate_video(
if audio: if audio:
mode_str += "+Audio" mode_str += "+Audio"
pipeline_name = "DEV" if pipeline == PipelineType.DEV else "DISTILLED" pipeline_names = {
PipelineType.DISTILLED: "DISTILLED",
PipelineType.DEV: "DEV",
PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE",
}
pipeline_name = pipeline_names[pipeline]
header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height}{num_frames} frames[/]" header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height}{num_frames} frames[/]"
console.print(Panel(header, expand=False)) console.print(Panel(header, expand=False))
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if pipeline == PipelineType.DEV: if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]") console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]")
if is_i2v: if is_i2v:
@@ -962,9 +1077,8 @@ def generate_video(
model_path = get_model_path(model_repo) model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Model weight file
# Calculate latent dimensions # Calculate latent dimensions
if pipeline == PipelineType.DISTILLED: if is_two_stage:
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32 stage2_h, stage2_w = height // 32, width // 32
else: else:
@@ -996,8 +1110,8 @@ def generate_video(
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
# Encode prompts # Encode prompts
if pipeline == PipelineType.DEV: if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
# Dev pipeline needs positive and negative embeddings # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
if audio: if audio:
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
@@ -1009,6 +1123,9 @@ def generate_video(
audio_embeddings_pos = audio_embeddings_neg = None audio_embeddings_pos = audio_embeddings_neg = None
model_dtype = video_embeddings_pos.dtype model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg) mx.eval(video_embeddings_pos, video_embeddings_neg)
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
if pipeline == PipelineType.DEV_TWO_STAGE:
text_embeddings = video_embeddings_pos
else: else:
# Distilled pipeline - single embedding # Distilled pipeline - single embedding
if audio: if audio:
@@ -1172,7 +1289,7 @@ def generate_video(
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
) )
else: elif pipeline == PipelineType.DEV:
# ====================================================================== # ======================================================================
# DEV PIPELINE: Single-stage with CFG # DEV PIPELINE: Single-stage with CFG
# ====================================================================== # ======================================================================
@@ -1193,7 +1310,6 @@ def generate_video(
console.print("[green]✓[/] VAE encoder loaded and image encoded") console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Generate sigma schedule with token-count-dependent shifting # Generate sigma schedule with token-count-dependent shifting
num_tokens = latent_frames * latent_h * latent_w
sigmas = ltx2_scheduler(steps=num_inference_steps) sigmas = ltx2_scheduler(steps=num_inference_steps)
mx.eval(sigmas) mx.eval(sigmas)
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]") console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
@@ -1261,6 +1377,181 @@ def generate_video(
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
elif pipeline == PipelineType.DEV_TWO_STAGE:
# ======================================================================
# DEV TWO-STAGE PIPELINE:
# Stage 1: Dev denoising at half resolution with CFG
# Upsample: 2x spatial via LatentUpsampler
# Stage 2: Distilled denoising at full resolution with LoRA, no CFG
# ======================================================================
# Load VAE encoder for I2V
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
del vae_encoder
mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Stage 1: Dev denoising at half resolution with CFG
sigmas = ltx2_scheduler(steps=num_inference_steps)
mx.eval(sigmas)
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = None
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
state1 = None
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
if is_i2v and stage1_image_latent is not None:
state1 = LatentState(
latent=mx.zeros(stage1_shape, dtype=model_dtype),
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
state1 = apply_conditioning(state1, [conditioning])
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
noise_scale = sigmas[0]
scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask,
)
latents = state1.latent
mx.eval(latents)
else:
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents)
# Run stage 1 with dev-style CFG denoising
if audio:
latents, audio_latents = denoise_dev_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
)
else:
latents = denoise_dev(
latents, positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
verbose=verbose, state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
)
# Upsample latents 2x
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
mx.eval(latents)
del upsampler
mx.clear_cache()
console.print("[green]✓[/] Latents upsampled")
# Merge LoRA weights for stage 2 (distilled refinement)
if lora_path is None:
# Auto-detect LoRA file in model directory
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
if lora_files:
lora_path = str(lora_files[0])
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
else:
console.print("[yellow]⚠️ No LoRA file found. Stage 2 will use base weights.[/]")
if lora_path is not None:
with console.status("[blue]🔧 Merging distilled LoRA weights...[/]", spinner="dots"):
load_and_merge_lora(transformer, lora_path, strength=lora_strength)
# Stage 2: Distilled refinement at full resolution (no CFG)
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
state2 = None
if is_i2v and stage2_image_latent is not None:
state2 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
state2 = apply_conditioning(state2, [conditioning])
noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask,
)
latents = state2.latent
mx.eval(latents)
if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
else:
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents)
if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
# Stage 2 uses distilled denoising (no CFG)
latents, audio_latents = denoise_distilled(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos if audio else None,
)
del transformer del transformer
mx.clear_cache() mx.clear_cache()
@@ -1445,6 +1736,9 @@ Examples:
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0 python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0
python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40 python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40
# Dev two-stage pipeline (dev + LoRA refinement)
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev-two-stage --cfg-scale 3.0
# Image-to-Video (works with both pipelines) # Image-to-Video (works with both pipelines)
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --pipeline dev python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --pipeline dev
@@ -1456,8 +1750,8 @@ Examples:
) )
parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate") parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate")
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev"], parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"],
help="Pipeline type: distilled (two-stage, fast) or dev (single-stage, CFG)") help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)")
parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
help="Negative prompt for CFG (dev pipeline only)") help="Negative prompt for CFG (dev pipeline only)")
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
@@ -1488,9 +1782,16 @@ Examples:
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)")
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
args = parser.parse_args() args = parser.parse_args()
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED pipeline_map = {
"distilled": PipelineType.DISTILLED,
"dev": PipelineType.DEV,
"dev-two-stage": PipelineType.DEV_TWO_STAGE,
}
pipeline = pipeline_map[args.pipeline]
generate_video( generate_video(
model_repo=args.model_repo, model_repo=args.model_repo,
@@ -1522,6 +1823,8 @@ Examples:
use_apg=args.apg, use_apg=args.apg,
apg_eta=args.apg_eta, apg_eta=args.apg_eta,
apg_norm_threshold=args.apg_norm_threshold, apg_norm_threshold=args.apg_norm_threshold,
lora_path=args.lora_path,
lora_strength=args.lora_strength,
) )