format
This commit is contained in:
@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -103,6 +105,8 @@ def make_attn(
|
||||
elif attn_type == AttentionType.NONE:
|
||||
return Identity()
|
||||
elif attn_type == AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
raise NotImplementedError(
|
||||
f"Attention type {attn_type.value} is not supported yet."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
Reference in New Issue
Block a user