forked from snagcliffs/PDE-FIND
-
Notifications
You must be signed in to change notification settings - Fork 0
/
siren_pytorch.py
154 lines (115 loc) · 4.64 KB
/
siren_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
# helpers
def exists(val):
    """Return True when ``val`` is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing;
    only the ``None`` singleton is treated as missing.
    """
    if val is None:
        return False
    return True
def cast_tuple(val, repeat = 1):
    """Return ``val`` as a tuple.

    A value that already is a tuple is returned unchanged (``repeat`` is
    ignored). Anything else, including lists and ``None``, is wrapped in a
    tuple that contains ``val`` exactly ``repeat`` times.
    """
    if isinstance(val, tuple):
        return val
    return (val,) * repeat
# sin activation
class Sine(nn.Module):
    """Sine activation: ``forward(x)`` returns ``sin(w0 * x)``.

    Args:
        w0: Scalar frequency multiplier applied to the input before the sine.
    """

    def __init__(self, w0 = 1.):
        super().__init__()
        # Stored as a plain attribute, so it is not a learnable parameter.
        self.w0 = w0

    def forward(self, x):
        """Scale ``x`` by ``w0`` and apply the sine function."""
        scaled = self.w0 * x
        return scaled.sin()
# siren layer
class Siren(nn.Module):
    """One SIREN layer: an affine map followed by an activation.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_out: Size of the last dimension of the output.
        w0: Frequency multiplier used both by the default ``Sine`` activation
            and by the weight-initialisation bound of non-first layers.
        c: Constant inside the square root of the initialisation bound of
            non-first layers.
        is_first: Selects the first-layer initialisation bound ``1 / dim_in``.
        use_bias: When False the layer has no bias (``self.bias`` is ``None``).
        activation: Module applied after the affine map; ``Sine(w0)`` is used
            when this is ``None``.
    """

    def __init__(self, dim_in, dim_out, w0 = 1., c = 6., is_first = False, use_bias = True, activation = None):
        super().__init__()
        # init_ reads both attributes, so they must be set before it is called.
        self.dim_in = dim_in
        self.is_first = is_first

        # Initialise plain tensors first and wrap them as parameters afterwards.
        raw_weight = torch.zeros(dim_out, dim_in)
        raw_bias = None
        if use_bias:
            raw_bias = torch.zeros(dim_out)
        self.init_(raw_weight, raw_bias, c = c, w0 = w0)

        self.weight = nn.Parameter(raw_weight)
        if use_bias:
            self.bias = nn.Parameter(raw_bias)
        else:
            self.bias = None

        if activation is None:
            activation = Sine(w0)
        self.activation = activation

    def init_(self, weight, bias, c, w0):
        """Fill ``weight`` (and ``bias`` if given) in place from U(-b, b).

        The bound ``b`` is ``1 / dim_in`` for the first layer and
        ``sqrt(c / dim_in) / w0`` for every other layer.
        """
        fan_in = self.dim_in
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        # Weight is sampled before bias; keep this order for reproducible seeding.
        weight.uniform_(-bound, bound)
        if exists(bias):
            bias.uniform_(-bound, bound)

    def forward(self, x):
        """Return ``activation(x @ weight.T + bias)``."""
        return self.activation(F.linear(x, self.weight, self.bias))
# siren network
class SirenNet(nn.Module):
    """A stack of hidden ``Siren`` layers followed by one output ``Siren`` layer.

    Args:
        dim_in: Size of the last dimension of the network input.
        dim_hidden: Width of every hidden layer.
        dim_out: Size of the last dimension of the network output.
        num_layers: Number of hidden layers (the output layer is extra).
        w0: Frequency multiplier for every layer except the first.
        w0_initial: Frequency multiplier for the first hidden layer.
        use_bias: Whether the layers carry a bias term.
        final_activation: Module applied by the output layer; ``nn.Identity()``
            is used when this is ``None``.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None):
        super().__init__()
        # Both attributes are read by SirenWrapper when it builds a Modulator.
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            # Only the first layer sees the raw input and uses w0_initial.
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            self.layers.append(Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first
            ))

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    def forward(self, x, mods = None):
        """Run ``x`` through the hidden layers and then the output layer.

        Args:
            x: Input whose last dimension is ``dim_in``.
            mods: Optional per-layer modulation. Either a single value (or
                ``None``) that is reused for every hidden layer, or a tuple
                with exactly one entry per hidden layer. Each non-``None``
                entry must be a 1-D tensor; it is broadcast over the leading
                dimension and multiplied into that layer's output.

        Raises:
            ValueError: If ``mods`` is a tuple whose length differs from the
                number of hidden layers.
        """
        mods = cast_tuple(mods, self.num_layers)

        # zip() stops at the shorter sequence, so a wrong-length tuple would
        # silently skip hidden layers or ignore modulations. Fail loudly instead.
        if len(mods) != len(self.layers):
            raise ValueError(
                f'expected {len(self.layers)} modulation entries (one per hidden layer), got {len(mods)}'
            )

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)
# modulatory feed forward
class Modulator(nn.Module):
    """Feed-forward network that turns a latent vector into per-layer modulations.

    Every layer is ``Linear -> ReLU``. The first layer consumes the latent
    alone; each later layer consumes the previous hidden output concatenated
    with the latent along the last dimension.

    Args:
        dim_in: Size of the last dimension of the latent input.
        dim_hidden: Width of every hidden output.
        num_layers: Number of layers, and therefore of returned hidden outputs.
    """

    def __init__(self, dim_in, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
            is_first = ind == 0
            # Later layers receive [hidden, latent], hence the wider input.
            dim = dim_in if is_first else (dim_hidden + dim_in)

            self.layers.append(nn.Sequential(
                nn.Linear(dim, dim_hidden),
                nn.ReLU()
            ))

    def forward(self, z):
        """Return a tuple with the hidden output of every layer, in order.

        Args:
            z: Latent tensor whose last dimension is ``dim_in``. Leading batch
                dimensions are allowed and are preserved in every output.
        """
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)
            # Join on the feature axis. For a 1-D latent this equals the
            # default dim=0; for batched latents dim=0 would stack rows instead.
            x = torch.cat((x, z), dim = -1)

        return tuple(hiddens)
# wrapper
class SirenWrapper(nn.Module):
    """Evaluate a ``SirenNet`` on a fixed 2-D coordinate grid to produce an image.

    Args:
        net: The ``SirenNet`` to evaluate.
        image_width: Number of grid columns.
        image_height: Number of grid rows.
        latent_dim: When given, a ``Modulator`` of that input size is created
            and ``forward`` then requires a ``latent`` argument.
    """

    def __init__(self, net, image_width, image_height, latent_dim = None):
        super().__init__()
        assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

        self.net = net
        self.image_width = image_width
        self.image_height = image_height

        # A modulator exists only when a latent size was requested.
        if exists(latent_dim):
            self.modulator = Modulator(
                dim_in = latent_dim,
                dim_hidden = net.dim_hidden,
                num_layers = net.num_layers
            )
        else:
            self.modulator = None

        # Build an (image_height * image_width, 2) table of coordinates in
        # [-1, 1], with the row coordinate first and the column coordinate second.
        rows = torch.linspace(-1, 1, steps = image_height)
        cols = torch.linspace(-1, 1, steps = image_width)
        axes = [rows, cols]
        coord_grid = torch.stack(torch.meshgrid(*axes, indexing = 'ij'), dim=-1)
        flat_grid = rearrange(coord_grid, 'h w c -> (h w) c')
        self.register_buffer('grid', flat_grid)

    def forward(self, img = None, *, latent = None):
        """Render the image, or return its MSE against ``img`` when one is given.

        Args:
            img: Optional target image. When supplied, the return value is
                ``F.mse_loss(img, rendered)`` instead of the rendered image.
            latent: Latent passed to the modulator. Must be supplied if and
                only if ``latent_dim`` was set at construction.

        Returns:
            The rendered tensor laid out as ``(1, c, image_height, image_width)``,
            or a loss value when ``img`` is provided.
        """
        has_modulator = exists(self.modulator)
        has_latent = exists(latent)
        assert has_modulator == has_latent, 'latent vector must be only supplied if `latent_dim` was passed in on instantiation'

        mods = None
        if has_modulator:
            mods = self.modulator(latent)

        # Fresh detached copy of the grid that tracks gradients for this call.
        coords = self.grid.clone().detach().requires_grad_()
        flat_out = self.net(coords, mods)
        image = rearrange(flat_out, '(h w) c -> () c h w', h = self.image_height, w = self.image_width)

        if not exists(img):
            return image
        return F.mse_loss(img, image)