format
This commit is contained in:
@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
||||
self.linear = nn.Linear(
|
||||
embedding_dim, embedding_coefficient * embedding_dim, bias=True
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -56,15 +57,19 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
use_additional_conditions: bool = False,
|
||||
timestep_proj_dim: int = 256,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.size_emb_dim = size_emb_dim
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
||||
self.time_proj = Timesteps(
|
||||
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
|
||||
)
|
||||
|
||||
if use_additional_conditions and size_emb_dim > 0:
|
||||
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
# Add additional conditions if enabled
|
||||
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||
if resolution is not None and aspect_ratio is not None:
|
||||
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
||||
additional_embeds = self.additional_embedder(
|
||||
resolution, aspect_ratio, hidden_dtype
|
||||
)
|
||||
timesteps_emb = timesteps_emb + additional_embeds
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
Reference in New Issue
Block a user