Refactor LTX2TextEncoder to utilize Rich for progress tracking during token generation. Replace tqdm with Rich's Progress for enhanced console output and user experience. Clean up imports and streamline the generation process.

This commit is contained in:
Prince Canuma
2026-01-19 02:13:10 +01:00
parent ac67ee8b1e
commit 4cd58f8b26

View File

@@ -11,6 +11,8 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from mlx_video.utils import rms_norm, apply_quantization from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@@ -854,7 +856,6 @@ class LTX2TextEncoder(nn.Module):
Returns: Returns:
Enhanced prompt string Enhanced prompt string
""" """
from tqdm import tqdm
try: try:
from mlx_lm import stream_generate from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.sample_utils import make_logits_processors, make_sampler
@@ -878,7 +879,6 @@ class LTX2TextEncoder(nn.Module):
# Use mlx-lm generate with temperature sampling # Use mlx-lm generate with temperature sampling
mx.random.seed(seed) mx.random.seed(seed)
# Tokenize # Tokenize
inputs = self.processor( inputs = self.processor(
formatted, formatted,
@@ -896,36 +896,48 @@ class LTX2TextEncoder(nn.Module):
generated_token_count = 0 generated_token_count = 0
generated_tokens = [] generated_tokens = []
for i, response in enumerate( console = Console()
tqdm(
stream_generate( generator = stream_generate(
self.language_model, self.language_model,
tokenizer=self.processor, tokenizer=self.processor,
prompt=input_ids.squeeze(0), prompt=input_ids.squeeze(0),
max_tokens=max_tokens, max_tokens=max_tokens,
sampler=sampler, sampler=sampler,
logits_processors=logits_processors, logits_processors=logits_processors,
), )
total=max_tokens,
progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
disable=not verbose, disable=not verbose,
) )
):
with progress:
task = progress.add_task("[cyan]Generating[/]", total=max_tokens)
for i, response in enumerate(generator):
next_token = mx.array([response.token]) next_token = mx.array([response.token])
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
generated_tokens.append(next_token.squeeze()) generated_tokens.append(next_token.squeeze())
generated_token_count += 1 generated_token_count += 1
progress.update(task, advance=1)
if i % 50 == 0: if i % 50 == 0:
mx.clear_cache() mx.clear_cache()
# Check for EOS # Check for EOS
if response.token == 1 or response.token == 107: # EOS tokens if response.token == 1 or response.token == 107: # EOS tokens
progress.update(task, completed=max_tokens)
break break
mx.clear_cache() mx.clear_cache()
# Decode only the new tokens # Decode only the new tokens
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True) enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
enhanced_prompt = self._clean_response(enhanced_prompt) enhanced_prompt = self._clean_response(enhanced_prompt)