diff --git a/mlx_video/models/ltx_2/text_encoder.py b/mlx_video/models/ltx_2/text_encoder.py index c5d7aff..4f14c8a 100644 --- a/mlx_video/models/ltx_2/text_encoder.py +++ b/mlx_video/models/ltx_2/text_encoder.py @@ -1079,7 +1079,7 @@ class LTX2TextEncoder(nn.Module): for i, response in enumerate(generator): next_token = mx.array([response.token]) input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) - generated_tokens.append(next_token.squeeze()) + generated_tokens.append(response.token) generated_token_count += 1 progress.update(task, advance=1)