Fourier expansion #114

Open · wants to merge 4 commits into base: master
16 changes: 16 additions & 0 deletions configs/fern_mf.txt
@@ -0,0 +1,16 @@
expname = fern_test_mf
basedir = ./logs
datadir = ./data/nerf_llff_data/fern
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0

i_embed = 2
21 changes: 21 additions & 0 deletions configs/lego_mf.txt
@@ -0,0 +1,21 @@
expname = blender_paper_lego_mf
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender

no_batching = True

use_viewdirs = True
white_bkgd = True
lrate_decay = 500

N_samples = 64
N_importance = 128
N_rand = 1024

precrop_iters = 500
precrop_frac = 0.5

half_res = True

i_embed = 2
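
Aside from the expname and the added i_embed = 2 line (which selects the new multivariable Fourier embedding, per the --i_embed help text below), these two configs follow the stock fern/lego configs. Assuming the repository's usual entry point, a run would be launched as, e.g.:

    python run_nerf.py --config configs/fern_mf.txt
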
15 changes: 12 additions & 3 deletions run_nerf.py
@@ -18,6 +18,9 @@
from load_blender import load_blender_data
from load_LINEMOD import load_LINEMOD_data

from torch.utils.tensorboard import SummaryWriter
from torchstat import stat


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.random.seed(0)
@@ -467,7 +470,7 @@ def config_parser():
parser.add_argument("--use_viewdirs", action='store_true',
help='use full 5D input instead of 3D')
parser.add_argument("--i_embed", type=int, default=0,
help='set 0 for default positional encoding, -1 for none')
help='set 0 for default positional encoding, -1 for none, 2 for multivariable fourier')
parser.add_argument("--multires", type=int, default=10,
help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4,
@@ -650,6 +653,10 @@ def train():
    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Check model size
    # stat(render_kwargs_train['network_fn'], (65536, 2003))
    # stat(render_kwargs_train['network_fine'], (65536, 2003))

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
@@ -705,7 +712,7 @@ def train():
    print('VAL views are', i_val)

    # Summary writers
-    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
+    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname, 'base'))

    start = start + 1
    for i in trange(start, N_iters):
@@ -827,6 +834,8 @@ def train():

        if i%args.i_print==0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
            writer.add_scalar("Train/Loss", loss.item(), i)
            writer.add_scalar("Train/PSNR", psnr.item(), i)
        """
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
@@ -875,4 +884,4 @@ def train():
if __name__=='__main__':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

-    train()
+    train()
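
Taken together, the run_nerf.py changes wire the training loop into TensorBoard: Train/Loss and Train/PSNR are written every i_print iterations under basedir/summaries/expname/base. With basedir = ./logs from the configs above, the curves can be viewed with the standard TensorBoard CLI, e.g.:

    tensorboard --logdir ./logs/summaries

The stat calls for checking model size are left commented out; note that the from torchstat import stat line assumes the torchstat package is installed (it is a separate pip package).
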
65 changes: 64 additions & 1 deletion run_nerf_helpers.py
@@ -3,8 +3,12 @@
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools


DEBUG = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
@@ -13,17 +17,38 @@

# Positional encoding (section 5.1)
class Embedder:
    def spherical_features(sqrt_dim=10, rand=True):
        if rand: # Random
            # np.random.seed(0) # generate consistent random features
            u, v = np.random.rand(2, sqrt_dim ** 2)
        else: # Stratified
            segs = np.linspace(0, 1, sqrt_dim)
            u, v = np.array(list(itertools.product(segs, segs))).transpose()
        # Spherical sampling
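        # u maps to the azimuth angle i in [0, 2*pi); taking j = arccos(1 - 2v) makes cos(j)
        # uniform on [-1, 1], so the (x, y, z) directions below cover the unit sphere uniformly
        # rather than clustering at the poles.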
        i = 2 * np.pi * u
        j = np.arccos(1 - 2 * v)
        x = np.sin(j) * np.cos(i)
        y = np.sin(j) * np.sin(i)
        z = np.cos(j)
        return torch.from_numpy(np.stack((x, y, z)).transpose()).to(device)

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        # Multivariable Fourier Basis
        embed_mffns = []

        d = self.kwargs['input_dims']
        out_dim = 0
        out_mf_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            embed_mffns.append(lambda x : x)
            out_dim += d
            out_mf_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
Expand All @@ -37,13 +62,45 @@ def create_embedding_fn(self):
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d


        # # Multivariable Fourier Basis
        # for freq_x in freq_bands:
        #     for freq_y in freq_bands:
        #         for freq_z in freq_bands:
        #             for p_fn in self.kwargs['periodic_fns']:
        #                 embed_mffns.append(lambda x, p_fn=p_fn, freq_x=freq_x, freq_y=freq_y,
        #                                    freq_z=freq_z : p_fn(x[:, 0:1] * freq_x + x[:, 1:2] * freq_y + x[:, 2:3] * freq_z))
        #                 out_mf_dim += 1

        # Spherical Fourier Basis
        np.random.seed(0) # generate consistent random features
        for freq in freq_bands:
            for unit_vec in Embedder.spherical_features():
                for p_fn in self.kwargs['periodic_fns']:
                    embed_mffns.append(lambda x, p_fn=p_fn, freq=freq, vec=unit_vec :
                                       p_fn(freq * (vec[0] * x[:, 0:1] + vec[1] * x[:, 1:2] + vec[2] * x[:, 2:3])))
                    out_mf_dim += 1


        self.embed_fns = embed_fns
        self.embed_mffns = embed_mffns
        self.out_dim = out_dim
        self.out_mf_dim = out_mf_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

    def embed_mf(self, inputs):
        if DEBUG:
            fns = [fn(inputs) for fn in self.embed_fns]
            mffns = [fn(inputs) for fn in self.embed_mffns]
            print("fns", torch.cat(fns, -1).shape)
            print("out_dim", self.out_dim)
            print("mffns", torch.cat(mffns, -1).shape)
            print("out_mf_dim", self.out_mf_dim)
            exit(-1)
        return torch.cat([fn(inputs) for fn in self.embed_mffns], -1)


def get_embedder(multires, i=0):
    if i == -1:
@@ -59,6 +116,12 @@ def get_embedder(multires, i=0):
    }

    embedder_obj = Embedder(**embed_kwargs)

    # Multivariable Fourier Basis
    if i == 2:
        embed = lambda x, eo=embedder_obj : eo.embed_mf(x)
        return embed, embedder_obj.out_mf_dim

    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim
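
For orientation, a quick sketch of the embedding width produced by the i == 2 branch under the defaults above (multires = 10, so num_freqs = 10; sqrt_dim = 10; periodic_fns = [torch.sin, torch.cos]; include_input = True with 3D inputs):

    input_dims = 3       # raw (x, y, z) passed through by include_input
    n_freqs = 10         # num_freqs = multires
    n_dirs = 10 ** 2     # spherical_features() returns sqrt_dim ** 2 unit vectors
    n_periodic = 2       # sin and cos
    out_mf_dim = input_dims + n_freqs * n_dirs * n_periodic
    print(out_mf_dim)    # 2003

So get_embedder(10, i=2) returns an embedding of width 2003, the same figure that appears in the commented-out stat(..., (65536, 2003)) calls in run_nerf.py.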
