add audio
This commit is contained in:
@@ -34,10 +34,11 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||
x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))
|
||||
|
||||
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
|
||||
# Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
|
||||
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
||||
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user