- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
hidden_size: int,
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
act_fn: str = "gelu_tanh",
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user