forked from tancik/fourier-feature-networks
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathfourier_feature_transform.py
45 lines (31 loc) · 1.56 KB
/
fourier_feature_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from math import pi
import torch
class GaussianFourierFeatureTransform(torch.nn.Module):
    """
    An implementation of Gaussian Fourier feature mapping.

    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
       https://arxiv.org/abs/2006.10739
       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    Given an input of size [batches, num_input_channels, width, height],
    returns a tensor of size [batches, mapping_size*2, width, height].

    Each spatial position's channel vector v is mapped to
    ``[sin(2*pi * v @ B), cos(2*pi * v @ B)]``, concatenated along the channel
    dimension (all sine channels first, then all cosine channels). ``B`` is a
    fixed ``(num_input_channels, mapping_size)`` matrix drawn once at
    construction as ``torch.randn(...) * scale``.

    ``B`` is registered as a (non-trainable) buffer under the name ``_B``.
    It is therefore saved in / restored from ``state_dict()`` and follows
    ``module.to(...)``. Checkpoints written before ``_B`` was a buffer lack
    this key and need ``load_state_dict(..., strict=False)``.

    Args:
        num_input_channels: number of channels C expected at dim 1 of the input.
        mapping_size: number of random frequencies; the output has
            ``2 * mapping_size`` channels.
        scale: standard deviation multiplier applied to the Gaussian samples
            in ``B``.
    """

    def __init__(self, num_input_channels, mapping_size=256, scale=10):
        super().__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # A buffer rather than a plain attribute, so the random projection is
        # persisted with the model and moved/cast along with it. A plain
        # attribute would be re-drawn on every construction, so a reloaded
        # model would compute a different mapping.
        self.register_buffer(
            "_B", torch.randn((num_input_channels, mapping_size)) * scale
        )

    def forward(self, x):
        """Map ``x`` of shape [B, C, W, H] to Fourier features [B, 2*M, W, H].

        Raises:
            AssertionError: if ``x`` is not 4D or its channel count differs
                from ``num_input_channels``.
        """
        assert x.dim() == 4, f'Expected 4D input (got {x.dim()}D input)'
        _, channels, _, _ = x.shape
        assert channels == self._num_input_channels, (
            f"Expected input to have {self._num_input_channels} channels "
            f"(got {channels} channels)"
        )

        # Fallback in case the module itself was not moved to x's device;
        # a no-op when the buffer is already there.
        projection = self._B.to(device=x.device)
        if x.is_floating_point():
            # Match the input precision so float16/float64 inputs work.
            # Only floating inputs are matched: casting B to an integer
            # dtype would silently truncate it.
            projection = projection.to(dtype=x.dtype)

        # [B, C, W, H] -> [B, W, H, C]; matmul then contracts the trailing
        # C dim against B's rows: [B, W, H, C] @ [C, M] -> [B, W, H, M].
        features = x.permute(0, 2, 3, 1) @ projection
        # [B, W, H, M] -> [B, M, W, H]
        features = 2 * pi * features.permute(0, 3, 1, 2)
        return torch.cat([torch.sin(features), torch.cos(features)], dim=1)