Refactor LTX2TextEncoder to utilize Rich for progress tracking during token generation. Replace tqdm with Rich's Progress for enhanced console output and user experience. Clean up imports and streamline the generation process.
This commit is contained in:
@@ -11,6 +11,8 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm, apply_quantization
|
from mlx_video.utils import rms_norm, apply_quantization
|
||||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||||
@@ -854,7 +856,6 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Enhanced prompt string
|
Enhanced prompt string
|
||||||
"""
|
"""
|
||||||
from tqdm import tqdm
|
|
||||||
try:
|
try:
|
||||||
from mlx_lm import stream_generate
|
from mlx_lm import stream_generate
|
||||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
@@ -878,7 +879,6 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
# Use mlx-lm generate with temperature sampling
|
# Use mlx-lm generate with temperature sampling
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
# Tokenize
|
# Tokenize
|
||||||
inputs = self.processor(
|
inputs = self.processor(
|
||||||
formatted,
|
formatted,
|
||||||
@@ -893,39 +893,51 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
kwargs.get("repetition_penalty", 1.3),
|
kwargs.get("repetition_penalty", 1.3),
|
||||||
kwargs.get("repetition_context_size", 20),
|
kwargs.get("repetition_context_size", 20),
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_token_count = 0
|
generated_token_count = 0
|
||||||
generated_tokens = []
|
generated_tokens = []
|
||||||
for i, response in enumerate(
|
console = Console()
|
||||||
tqdm(
|
|
||||||
stream_generate(
|
|
||||||
self.language_model,
|
|
||||||
tokenizer=self.processor,
|
|
||||||
prompt=input_ids.squeeze(0),
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
sampler=sampler,
|
|
||||||
logits_processors=logits_processors,
|
|
||||||
),
|
|
||||||
total=max_tokens,
|
|
||||||
disable=not verbose,
|
|
||||||
)
|
|
||||||
):
|
|
||||||
next_token = mx.array([response.token])
|
|
||||||
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
|
|
||||||
generated_tokens.append(next_token.squeeze())
|
|
||||||
generated_token_count += 1
|
|
||||||
|
|
||||||
if i % 50 == 0:
|
generator = stream_generate(
|
||||||
mx.clear_cache()
|
self.language_model,
|
||||||
|
tokenizer=self.processor,
|
||||||
|
prompt=input_ids.squeeze(0),
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
sampler=sampler,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
)
|
||||||
|
|
||||||
# Check for EOS
|
progress = Progress(
|
||||||
if response.token == 1 or response.token == 107: # EOS tokens
|
SpinnerColumn(),
|
||||||
break
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
console=console,
|
||||||
|
disable=not verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
with progress:
|
||||||
|
task = progress.add_task("[cyan]Generating[/]", total=max_tokens)
|
||||||
|
|
||||||
|
for i, response in enumerate(generator):
|
||||||
|
next_token = mx.array([response.token])
|
||||||
|
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
|
||||||
|
generated_tokens.append(next_token.squeeze())
|
||||||
|
generated_token_count += 1
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
|
||||||
|
if i % 50 == 0:
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Check for EOS
|
||||||
|
if response.token == 1 or response.token == 107: # EOS tokens
|
||||||
|
progress.update(task, completed=max_tokens)
|
||||||
|
break
|
||||||
|
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# Decode only the new tokens
|
# Decode only the new tokens
|
||||||
|
|
||||||
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
enhanced_prompt = self._clean_response(enhanced_prompt)
|
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||||
|
|||||||
Reference in New Issue
Block a user