diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index d6461d5..a38bb6d 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -11,6 +11,8 @@ from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn from mlx_video.utils import rms_norm, apply_quantization from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb @@ -854,7 +856,6 @@ class LTX2TextEncoder(nn.Module): Returns: Enhanced prompt string """ - from tqdm import tqdm try: from mlx_lm import stream_generate from mlx_lm.sample_utils import make_logits_processors, make_sampler @@ -878,7 +879,6 @@ class LTX2TextEncoder(nn.Module): # Use mlx-lm generate with temperature sampling mx.random.seed(seed) - # Tokenize inputs = self.processor( formatted, @@ -893,39 +893,51 @@ class LTX2TextEncoder(nn.Module): kwargs.get("repetition_penalty", 1.3), kwargs.get("repetition_context_size", 20), ) - + generated_token_count = 0 generated_tokens = [] - for i, response in enumerate( - tqdm( - stream_generate( - self.language_model, - tokenizer=self.processor, - prompt=input_ids.squeeze(0), - max_tokens=max_tokens, - sampler=sampler, - logits_processors=logits_processors, - ), - total=max_tokens, - disable=not verbose, - ) - ): - next_token = mx.array([response.token]) - input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) - generated_tokens.append(next_token.squeeze()) - generated_token_count += 1 + console = Console() - if i % 50 == 0: - mx.clear_cache() + generator = stream_generate( + self.language_model, + tokenizer=self.processor, + prompt=input_ids.squeeze(0), + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + ) - # Check for EOS - if response.token == 1 or response.token == 107: # EOS tokens - break + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=console, + disable=not verbose, + ) + + with progress: + task = progress.add_task("[cyan]Generating[/]", total=max_tokens) + + for i, response in enumerate(generator): + next_token = mx.array([response.token]) + input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) + generated_tokens.append(next_token.squeeze()) + generated_token_count += 1 + progress.update(task, advance=1) + + if i % 50 == 0: + mx.clear_cache() + + # Check for EOS + if response.token == 1 or response.token == 107: # EOS tokens + progress.update(task, completed=max_tokens) + break mx.clear_cache() # Decode only the new tokens - enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True) enhanced_prompt = self._clean_response(enhanced_prompt)