This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
def __init__(
self,
embedding_dim: int,
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
self.linear = nn.Linear(
embedding_dim, embedding_coefficient * embedding_dim, bias=True
)
def __call__(
self,
@@ -56,15 +57,19 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
use_additional_conditions: bool = False,
timestep_proj_dim: int = 256,
):
super().__init__()
self.embedding_dim = embedding_dim
self.size_emb_dim = size_emb_dim
self.use_additional_conditions = use_additional_conditions
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
self.time_proj = Timesteps(
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.timestep_embedder = TimestepEmbedding(
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
)
if use_additional_conditions and size_emb_dim > 0:
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
# Add additional conditions if enabled
if self.use_additional_conditions and self.size_emb_dim > 0:
if resolution is not None and aspect_ratio is not None:
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
additional_embeds = self.additional_embedder(
resolution, aspect_ratio, hidden_dtype
)
timesteps_emb = timesteps_emb + additional_embeds
return timesteps_emb