Skip to content

Why not write it using nn.Linear? #6

Open
@crazydogen

Description

@crazydogen

Here is a simpler version that uses nn.Linear for the weights and biases:

import torch
import torch.nn as nn

class Time2vec(nn.Module):
    """Time2Vec embedding layer built from two ``nn.Linear`` projections.

    Maps the last dimension of the input from ``c_in`` to ``c_out`` features.
    Output feature 0 is a learned affine (non-periodic) term ``w0 . x + b0``.
    Output features 1..c_out-1 are ``act(W . x + b)``, where ``act`` is the
    chosen periodic function.

    Args:
        c_in: Size of the input's last dimension.
        c_out: Total number of output features: 1 linear feature plus
            ``c_out - 1`` periodic features.
        activation: ``"cos"`` (default) or ``"sin"``. This is the periodic
            function applied to the ``c_out - 1`` periodic features.

    Raises:
        ValueError: If ``activation`` is neither ``"cos"`` nor ``"sin"``.
    """

    # Supported periodic activations, keyed by the ``activation`` argument.
    _ACTIVATIONS = {"cos": torch.cos, "sin": torch.sin}

    def __init__(self, c_in, c_out, activation="cos"):
        super().__init__()
        if activation not in self._ACTIVATIONS:
            # Previously any unknown string silently fell back to sin.
            raise ValueError(
                f"activation must be 'cos' or 'sin', got {activation!r}"
            )
        # Weights/biases for the c_out - 1 periodic components.
        self.wnbn = nn.Linear(c_in, c_out - 1, bias=True)
        # Weight/bias for the single linear (non-periodic) component.
        self.w0b0 = nn.Linear(c_in, 1, bias=True)
        self.act = self._ACTIVATIONS[activation]

    def forward(self, x):
        """Embed ``x`` (last dim ``c_in``) into shape ``(..., c_out)``.

        The linear component comes first, followed by the periodic components.
        """
        # Time2Vec leaves component 0 un-activated so the embedding can model
        # non-periodic trends; only the remaining components are periodic.
        linear = self.w0b0(x)
        periodic = self.act(self.wnbn(x))
        return torch.cat([linear, periodic], dim=-1)


if __name__ == "__main__":
    # Dummy input laid out as [N, C, L] -> batch, channel, length.
    batch = torch.randn((1, 3, 3000))
    model = Time2vec(3, 10)
    # nn.Linear works on the last dimension, so move channels last: [N, L, C].
    features = model(batch.transpose(1, 2))
    print(features.shape)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions