Add custom text encoder with quantization

Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
Prince Canuma
2026-01-13 22:56:51 +01:00
parent 01d895bc77
commit fc6ef20c1b
3 changed files with 87 additions and 85 deletions

View File

@@ -150,6 +150,7 @@ def denoise(
def generate_video(
model_repo: str,
text_encoder_repo: str,
prompt: str,
height: int = 512,
width: int = 512,
@@ -189,6 +190,7 @@ def generate_video(
# Get model path
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
@@ -200,8 +202,8 @@ def generate_video(
# Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
text_embeddings, _ = text_encoder(prompt)
@@ -317,7 +319,7 @@ def generate_video(
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np
@@ -387,6 +389,12 @@ Examples:
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
parser.add_argument(
"--text-encoder-repo",
type=str,
default=None,
help="Text encoder repository to use (default: None)"
)
parser.add_argument(
"--verbose",
action="store_true",
@@ -396,6 +404,7 @@ Examples:
generate_video(
model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,