
Commit 7a13886

inference code
adding test code along with some examples.
1 parent b3e6cf8 commit 7a13886

39 files changed, with 3,406 additions and 1 deletion

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+
+.DS_Store

README.md

Lines changed: 54 additions & 1 deletion

@@ -7,4 +7,57 @@
 > [Nima Khademi Kalantari](http://nkhademi.com/)
 > SIGGRAPH Asia 2023 (TOG)
 
-[Project](https://people.engr.tamu.edu/nimak/Papers/SIGAsia2023_Reshader) | [Paper](https://arxiv.org/abs/2309.10689) | [Video](https://youtu.be/XW-tl48D3Ok)
+[![Paper](https://img.shields.io/badge/cs.CV-Paper-b31b1b?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2309.10689)
+[![Project Page](https://img.shields.io/badge/ReShader-Website-blue?logo=googlechrome&logoColor=blue)](https://people.engr.tamu.edu/nimak/Papers/SIGAsia2023_Reshader)
+[![Video](https://img.shields.io/badge/YouTube-Video-c4302b?logo=youtube&logoColor=red)](https://youtu.be/XW-tl48D3Ok)
+
+---------------------------------------------------
+<p align="center" >
+  <a href="">
+    <img src="assets/teaser.gif?raw=true" alt="demo" width="85%">
+  </a>
+</p>
+
+## Prerequisites
+You can set up the Anaconda environment using:
+```
+conda env create -f environment.yml
+conda activate reshader
+```
+
+Download the pretrained models.
+The following script from [3D Moments](https://github.com/google-research/3d-moments) will download their pretrained models and the [RGBD-inpainting networks](https://github.com/vt-vl-lab/3d-photo-inpainting).
+```
+./download.sh
+```
+
+## Demos
+We provide some examples in the `examples/` folder. You can render novel views with view-dependent highlights using:
+
+```
+python renderer.py --input_dir examples/camera/ --config configs/render.txt
+```
+
+## Training
+Training code and dataset to be added.
+
+## Citation
+```
+@article{paliwal2023reshader,
+  author = {Paliwal, Avinash and Nguyen, Brandon G. and Tsarov, Andrii and Kalantari, Nima Khademi},
+  title = {ReShader: View-Dependent Highlights for Single Image View-Synthesis},
+  year = {2023},
+  issue_date = {December 2023},
+  volume = {42},
+  number = {6},
+  journal = {ACM Trans. Graph.},
+  month = {dec},
+  articleno = {216},
+  numpages = {9},
+}
+```
+
+
+## Acknowledgement
+The novel view synthesis part of the code is borrowed from [3D Moments](https://github.com/google-research/3d-moments).

assets/teaser.gif

10.5 MB

config.py

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import configargparse
+
+
+def config_parser():
+    parser = configargparse.ArgumentParser()
+    parser.add_argument('--config', is_config_file=True, help='config file path')
+    # general
+    parser.add_argument('--rootdir', type=str, default='./',
+                        help='the path to the project root directory.')
+    parser.add_argument("--expname", type=str, default='exp', help='experiment name')
+    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
+                        help='number of data loading workers (default: 8)')
+    parser.add_argument('--distributed', action='store_true', help='if use distributed training')
+    parser.add_argument("--local_rank", type=int, default=0, help='rank for distributed training')
+    parser.add_argument("--eval_mode", action='store_true', help='if in eval mode')
+
+    ########## dataset options ##########
+    # train and eval dataset
+    parser.add_argument("--train_dataset", type=str, default='vimeo',
+                        help='the training dataset')
+    parser.add_argument("--dataset_weights", nargs='+', type=float, default=[],
+                        help='the weights for training datasets, used when multiple datasets are used.')
+    parser.add_argument('--eval_dataset', type=str, default='vimeo', help='the dataset to evaluate')
+    parser.add_argument("--batch_size", type=int, default=1, help='batch size, currently only supports 1')
+
+    ########## network architecture ##########
+    parser.add_argument("--feature_dim", type=int, default=32, help='the dimension of the extracted features')
+
+    ########## training options ##########
+    parser.add_argument("--use_inpainting_mask_for_feature", action='store_true')
+    parser.add_argument("--inpainting", action='store_true', help='if do inpainting')
+    parser.add_argument("--train_raft", action='store_true', help='if train raft')
+    parser.add_argument('--boundary_crop_ratio', type=float, default=0, help='crop the image before computing loss')
+    parser.add_argument("--vary_pts_radius", action='store_true', help='if vary point radius as augmentation')
+    parser.add_argument("--adaptive_pts_radius", action='store_true', help='if use adaptive point radius')
+    parser.add_argument("--use_mask_for_decoding", action='store_true', help='if use mask for decoding')
+
+    ########## rendering/evaluation ##########
+    parser.add_argument("--use_depth_for_feature", action='store_true',
+                        help='if use depth map when extracting features')
+    parser.add_argument("--use_depth_for_decoding", action='store_true',
+                        help='if use depth map when decoding')
+    parser.add_argument("--point_radius", type=float, default=1.5,
+                        help='point radius for rasterization')
+    parser.add_argument("--input_dir", type=str, default='', help='input folder that contains a pair of images')
+    parser.add_argument("--visualize_rgbda_layers", action='store_true',
+                        help='if visualize rgbda layers, save in out dir')
+
+    ########### iterations & learning rate options & loss ##########
+    parser.add_argument("--n_iters", type=int, default=250000, help='num of iterations')
+    parser.add_argument("--lr", type=float, default=3e-4, help='learning rate for feature extractor')
+    parser.add_argument("--lr_raft", type=float, default=5e-6, help='learning rate for raft')
+    parser.add_argument("--lrate_decay_factor", type=float, default=0.5,
+                        help='decay learning rate by a factor every specified number of steps')
+    parser.add_argument("--lrate_decay_steps", type=int, default=50000,
+                        help='decay learning rate by a factor every specified number of steps')
+    parser.add_argument('--loss_mode', type=str, default='lpips',
+                        help='the loss function to use')
+
+    ########## checkpoints ##########
+    parser.add_argument("--ckpt_path", type=str, default="",
+                        help='specific weights npy file to reload for coarse network')
+    parser.add_argument("--no_reload", action='store_true',
+                        help='do not reload weights from saved ckpt')
+    parser.add_argument("--no_load_opt", action='store_true',
+                        help='do not load optimizer when reloading')
+    parser.add_argument("--no_load_scheduler", action='store_true',
+                        help='do not load scheduler when reloading')
+
+    ########## logging/saving options ##########
+    parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric logging')
+    parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
+    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
+
+    ############ demo parameters ##############
+    parser.add_argument("--spec", action='store_true', help='use specular frames')
+    parser.add_argument("--tung", action='store_true', help='using tungsten depths')
+    parser.add_argument("--normalize_depth", action='store_true', help='normalize depth when depth map is euclidean distance')
+    parser.add_argument("--fov", type=float, default=45.0, help='fov of camera')
+    parser.add_argument("--spd", type=str, default="246", help='spec directory suffix')
+    parser.add_argument("--dscl", type=int, default=1, help='depth scaling')
+
+    # training schedule
+    parser.add_argument('--num_iterations', type=int, default=300000, help='total iterations to train')
+    parser.add_argument('-train_batch_size', type=int, default=10)
+    parser.add_argument('-val_batch_size', type=int, default=4)
+    parser.add_argument('-checkpoint', type=int, default=10, help='save a checkpoint every <checkpoint> epochs; note that it replaces the previous checkpoint')
+    parser.add_argument('-tb_toc', type=int, default=100, help='print output to the terminal every tb_toc iterations')
+
+    # lr schedule
+    parser.add_argument('-lr', '--learning_rate', type=float, default=1e-4, help='learning rate of the network')
+
+    # loss
+    parser.add_argument('-style_coeff', type=float, default=1, help='hyperparameter for style loss')
+    parser.add_argument('-prcp_coeff', type=float, default=0.01, help='hyperparameter for perceptual loss')
+    parser.add_argument('-mse_coeff', type=float, default=1.0, help='hyperparameter for MSE loss')
+    parser.add_argument('-l1_coeff', type=float, default=0.1, help='hyperparameter for L1 loss')
+    # training and eval data
+    parser.add_argument('-dataset', type=str, default="/data2/avinash/datasets/specular_fixed/specular/", help='directory to the dataset')
+
+    # training utility
+    parser.add_argument('--model_dir', type=str, default="unet_prcp_gmm_mse", help='model (scene) directory, stored under runs/<model_dir>/')
+    parser.add_argument('-clean', action='store_true', help='delete old weights without starting the training process')
+    parser.add_argument('--clip', type=float, default=1.0)
+
+    # model
+    parser.add_argument('-multi', type=bool, default=True, help='append multi-level direction vector')
+    parser.add_argument('-use_mlp', type=bool, default=False, help='use mlp for feature vector from direction vector')
+    parser.add_argument('--start_iter', type=int, default=0, help='starting iteration')
+    parser.add_argument('-basis_out', type=int, default=8, help='num of basis functions')
+    parser.add_argument('-pos_enc_freq', type=int, default=5, help='num of freqs in positional encoding')
+    parser.add_argument('--losses', type=str, nargs='+', help='losses to use', default=['mse', 'prcp', 'gmm'])
+    parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to continue from')
+    parser.add_argument('--example_index', type=str, default=None, help='example index for testing')
+    parser.add_argument('--test_root', type=str, default="real_data/", help='test examples root dir')
+    parser.add_argument('-pad', type=bool, default=False, help='use mlp for feature vector from direction vector')
+    parser.add_argument('--use_depth_posenc', type=bool, default=False, help='use mlp for feature vector from direction vector')
+
+
+    args = parser.parse_args()
+    return args
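
For orientation: every flag above funnels through the single `config_parser()` entry point, which parses the command line together with any file passed via `--config`. A minimal usage sketch, assuming it runs from the repository root; the driver script name and the printed fields are illustrative, not part of this commit:

```
# hypothetical_driver.py -- illustrative only, not part of this commit
from config import config_parser

# Parses CLI flags plus an optional config file, e.g.:
#   python hypothetical_driver.py --config configs/render.txt --input_dir examples/camera/
args = config_parser()
print(args.input_dir, args.ckpt_path, args.dscl)
```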

configs/render.txt

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+no_load_opt = True
+no_load_scheduler = True
+distributed = False
+loss_mode = vgg19
+train_dataset = tiktok
+eval_dataset = jamie
+eval_mode = True
+
+use_depth_for_decoding = True
+adaptive_pts_radius = True
+train_raft = False
+visualize_rgbda_layers = False
+
+ckpt_path = pretrained/model_250000.pth
+
+
+model_dir = reshader
+ckpt = ""
+use_depth_posenc = True
+dscl = 2
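
A note on how this file is consumed: because `--config` is declared with `is_config_file=True` in config.py, configargparse treats each `key = value` line here as if `--key value` had been passed on the command line, with explicit command-line flags taking precedence. A self-contained sketch of that behavior; the file name and flag subset below are made up for illustration:

```
import configargparse

# Write a tiny config file in the render.txt style (illustrative name).
with open('demo.txt', 'w') as f:
    f.write('dscl = 2\neval_mode = True\n')

parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True)
parser.add_argument('--dscl', type=int, default=1)
parser.add_argument('--eval_mode', action='store_true')

# Equivalent to passing: --dscl 2 --eval_mode
args = parser.parse_args(['--config', 'demo.txt'])
assert args.dscl == 2 and args.eval_mode
```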

core/__init__.py

Whitespace-only changes.

core/depth_layering.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from sklearn.cluster import AgglomerativeClustering
+
+
+def get_depth_bins(depth=None, disparity=None, num_bins=None):
+    """
+    :param depth: [1, 1, H, W]
+    :param disparity: [1, 1, H, W]
+    :return: depth_bins
+    """
+
+    assert (disparity is not None) or (depth is not None)
+    if disparity is None:
+        assert depth.min() > 1e-2
+        disparity = 1. / depth
+
+    if depth is None:
+        depth = 1. / torch.clamp(disparity, min=1e-2)
+
+    assert depth.shape[:2] == (1, 1) and disparity.shape[:2] == (1, 1)
+    disparity_max = disparity.max().item()
+    disparity_min = disparity.min().item()
+    disparity_feat = disparity[:, :, ::10, ::10].reshape(-1, 1).cpu().numpy()
+    disparity_feat = (disparity_feat - disparity_min) / (disparity_max - disparity_min)
+    if num_bins is None:
+        n_clusters = None
+        distance_threshold = 5
+    else:
+        n_clusters = num_bins
+        distance_threshold = None
+    result = AgglomerativeClustering(n_clusters=n_clusters, distance_threshold=distance_threshold).fit(disparity_feat)
+    num_bins = result.n_clusters_ if n_clusters is None else n_clusters
+    depth_bins = [depth.min().item()]
+    for i in range(num_bins):
+        th = (disparity_feat[result.labels_ == i]).min()
+        th = th * (disparity_max - disparity_min) + disparity_min
+        depth_bins.append(1. / th)
+
+    depth_bins = sorted(depth_bins)
+    depth_bins[0] = depth.min() - 1e-6
+    depth_bins[-1] = depth.max() + 1e-6
+    return depth_bins
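
For orientation: `get_depth_bins` clusters a subsampled, normalized disparity map with scikit-learn's AgglomerativeClustering and returns sorted depth-bin boundaries spanning the full depth range. A minimal usage sketch with a synthetic depth map; the shapes and values are illustrative only:

```
import torch
from core.depth_layering import get_depth_bins

# Synthetic [1, 1, H, W] depth map, kept well above the 1e-2
# minimum the function asserts when deriving disparity from depth.
depth = 0.5 + 2.0 * torch.rand(1, 1, 120, 160)

# Requesting 4 clusters yields 5 sorted bin edges; the first and
# last edges are nudged just past depth.min() and depth.max().
bins = get_depth_bins(depth=depth, num_bins=4)
print(bins)
```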
