Open
Description
Here is a simple PyTorch implementation of Time2Vec.
import torch
import torch.nn as nn
class Time2vec(nn.Module):
    """Time2Vec embedding layer (Kazemi et al., 2019).

    Projects the last dimension of the input (size ``c_in``) to a
    ``c_out``-dimensional embedding. Entry 0 is a learned *linear*
    (non-periodic) term ``w0 . x + b0``. The remaining ``c_out - 1``
    entries are periodic terms ``act(W x + b)``.

    Args:
        c_in: Size of the last dimension of the input tensor.
        c_out: Size of the output embedding, counting the single linear term.
        activation: Periodic function for the non-linear terms; either
            ``"cos"`` or ``"sin"``.

    Raises:
        ValueError: If ``activation`` is neither ``"cos"`` nor ``"sin"``.
    """

    def __init__(self, c_in, c_out, activation="cos"):
        super().__init__()
        # Reject unknown names instead of silently falling back to sin.
        if activation not in ("cos", "sin"):
            raise ValueError(
                f"activation must be 'cos' or 'sin', got {activation!r}"
            )
        # Periodic components: c_out - 1 learned frequencies and phases.
        # Created before w0b0 to keep the original parameter/RNG order.
        self.wnbn = nn.Linear(c_in, c_out - 1, bias=True)
        # Linear (trend) component: one learned weight vector and bias.
        self.w0b0 = nn.Linear(c_in, 1, bias=True)
        self.act = torch.cos if activation == "cos" else torch.sin

    def forward(self, x):
        """Embed ``x`` along its last dimension.

        Args:
            x: Tensor whose last dimension has size ``c_in``. nn.Linear
                treats all leading dimensions as batch dimensions.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``c_out``: ``[linear term, periodic terms...]``.
        """
        # Time2Vec defines component 0 as purely linear. It is not passed
        # through the periodic activation, so it can capture non-periodic trends.
        part0 = self.w0b0(x)
        part1 = self.act(self.wnbn(x))
        return torch.cat([part0, part1], -1)
if __name__ == "__main__":
    # Demo: embed a random toy batch and print the resulting shape.
    # Input layout is (batch, channels, length).
    signal = torch.randn((1, 3, 3000))
    encoder = Time2vec(3, 10)
    # nn.Linear acts on the last dimension, so move channels last:
    # (batch, channels, length) -> (batch, length, channels).
    channels_last = signal.permute(0, 2, 1)
    embedded = encoder(channels_last)
    print(embedded.shape)
Metadata
Metadata
Assignees
Labels
No labels