From 4cd58f8b267ea943062c8f6651370d9bd7fa0d3f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 19 Jan 2026 02:13:10 +0100 Subject: [PATCH] Refactor LTX2TextEncoder to utilize Rich for progress tracking during token generation. Replace tqdm with Rich's Progress for enhanced console output and user experience. Clean up imports and streamline the generation process. --- mlx_video/models/ltx/text_encoder.py | 66 ++++++++++++++++------------ 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index d6461d5..a38bb6d 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -11,6 +11,8 @@ from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn from mlx_video.utils import rms_norm, apply_quantization from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb @@ -854,7 +856,6 @@ class LTX2TextEncoder(nn.Module): Returns: Enhanced prompt string """ - from tqdm import tqdm try: from mlx_lm import stream_generate from mlx_lm.sample_utils import make_logits_processors, make_sampler @@ -878,7 +879,6 @@ class LTX2TextEncoder(nn.Module): # Use mlx-lm generate with temperature sampling mx.random.seed(seed) - # Tokenize inputs = self.processor( formatted, @@ -893,39 +893,51 @@ class LTX2TextEncoder(nn.Module): kwargs.get("repetition_penalty", 1.3), kwargs.get("repetition_context_size", 20), ) - + generated_token_count = 0 generated_tokens = [] - for i, response in enumerate( - tqdm( - stream_generate( - self.language_model, - tokenizer=self.processor, - prompt=input_ids.squeeze(0), - max_tokens=max_tokens, - sampler=sampler, - logits_processors=logits_processors, - ), - total=max_tokens, - disable=not verbose, - ) - ): - next_token = mx.array([response.token]) - input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) - generated_tokens.append(next_token.squeeze()) - generated_token_count += 1 + console = Console() - if i % 50 == 0: - mx.clear_cache() + generator = stream_generate( + self.language_model, + tokenizer=self.processor, + prompt=input_ids.squeeze(0), + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + ) - # Check for EOS - if response.token == 1 or response.token == 107: # EOS tokens - break + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=console, + disable=not verbose, + ) + + with progress: + task = progress.add_task("[cyan]Generating[/]", total=max_tokens) + + for i, response in enumerate(generator): + next_token = mx.array([response.token]) + input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) + generated_tokens.append(next_token.squeeze()) + generated_token_count += 1 + progress.update(task, advance=1) + + if i % 50 == 0: + mx.clear_cache() + + # Check for EOS + if response.token == 1 or response.token == 107: # EOS tokens + progress.update(task, completed=max_tokens) + break mx.clear_cache() # Decode only the new tokens - enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True) enhanced_prompt = self._clean_response(enhanced_prompt)