Skip to content

Commit d747343

Browse files
committed
initial commit
0 parents  commit d747343

11 files changed

+1490
-0
lines changed

.gitignore

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
outputs/
2+
src/models/
3+
data/
4+
logs/
5+
.idea/
6+
dgx/scripts/
7+
.ipynb_checkpoints/
8+
src/*jpg
9+
notebooks/.ipynb_checkpoints/*
10+
exps/unit_local/configs/
11+
exps/unit_local/backup/
12+
src/yaml_generator.py
13+
*.tar.gz
14+
*ipynb
15+
*.zip
16+
*.pkl
17+
*.pyc

Dockerfile

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
FROM nvidia/cuda:8.0-cudnn7-runtime-ubuntu16.04
2+
# Set anaconda path
3+
ENV ANACONDA /opt/anaconda
4+
ENV PATH $ANACONDA/bin:$PATH
5+
# Download anaconda and install it
6+
RUN apt-get update && apt-get install -y wget build-essential
7+
RUN apt-get update && apt-get install -y libopencv-dev python-opencv
8+
RUN apt-get update && apt-get install -y --no-install-recommends \
9+
build-essential \
10+
cmake \
11+
git \
12+
curl \
13+
ca-certificates \
14+
libjpeg-dev \
15+
libpng-dev
16+
RUN wget https://repo.continuum.io/archive/Anaconda2-5.0.1-Linux-x86_64.sh -P /tmp
17+
RUN bash /tmp/Anaconda2-5.0.1-Linux-x86_64.sh -b -p $ANACONDA
18+
RUN rm /tmp/Anaconda2-5.0.1-Linux-x86_64.sh -rf
19+
RUN conda install -y pytorch torchvision cuda80 -c pytorch
20+
RUN conda install -y -c anaconda pip
21+
RUN conda install -y -c menpo opencv
22+
RUN conda install -y -c anaconda yaml
23+
RUN pip install tensorboard
24+
25+
26+

LICENSE.md

+177
Large diffs are not rendered by default.

__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
3+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4+
"""

configs/edges2shoes.yaml

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2+
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3+
train:
4+
# logger options
5+
image_save_iter: 10 # How often do you want to save output images during training
6+
image_display_iter: 10 # How often do you want to display output images during training
7+
display_size: 8 # How many images do you want to display each time
8+
snapshot_save_iter: 10 # How often do you want to save trained models
9+
log_iter: 1 # How often do you want to log the training stats
10+
11+
# optimization options
12+
max_iter: 1000000 # maximum number of training iterations
13+
batch_size: 1 # batch size
14+
weight_decay: 0.0001 # weight decay
15+
beta1: 0.5 # Adam parameter
16+
beta2: 0.999 # Adam parameter
17+
init: kaiming # initialization [gaussian/kaiming/xavier]
18+
lr: 0.0001 # initial learning rate
19+
lr_policy: step # learning rate scheduler
20+
step_size: 100000 # how often to decay learning rate
21+
gamma: 0.5 # how much to decay learning rate
22+
gan_w: 1 # weight of adversarial loss
23+
recon_x_a_w: 10 # weight of image reconstruction loss
24+
recon_x_b_w: 10 # weight of image reconstruction loss
25+
recon_s_a_w: 1 # weight of style reconstruction loss
26+
recon_s_b_w: 1 # weight of style reconstruction loss
27+
recon_c_w: 1 # weight of content reconstruction loss
28+
recon_x_a_cyc_w: 0 # weight of explicit style augmented cycle consistency loss
29+
recon_x_b_cyc_w: 0 # weight of explicit style augmented cycle consistency loss
30+
vgg_w: 0
31+
32+
# model options
33+
gen:
34+
dim: 64 # number of filters in the bottommost layer
35+
mlp_dim: 512 # number of filters in MLP
36+
style_dim: 8 # length of style code
37+
activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
38+
style_norm: none
39+
upsample_norm: ln
40+
n_downsample: 2 # number of downsampling layers in content encoder
41+
n_res: 4 # number of residual blocks in content encoder/decoder
42+
pad_type: zero
43+
downsample_style: 0
44+
dis:
45+
dim: 64 # number of filters in the bottommost layer
46+
norm: none # normalization layer [none/bn/in/ln]
47+
activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
48+
n_layer: 4 # number of layers in D
49+
gan_type: lsgan # GAN loss [lsgan/nsgan]
50+
num_scales: 3 # number of scales
51+
pad_type: zero
52+
53+
# data options
54+
input_dim_a: 1 # number of image channels [1/3]
55+
input_dim_b: 3 # number of image channels [1/3]
56+
num_workers: 8 # number of data loading threads
57+
new_size: 256 # first resize the shortest image side to this size
58+
crop_image_height: 256 # random crop image of this height
59+
crop_image_width: 256 # random crop image of this width
60+
root_a: /cosmo/datasets/edges2shoes/ # dataset folder location
61+
train_list_a: trainA.txt # image list
62+
test_list_a: testA.txt # image list
63+
root_b: /cosmo/datasets/edges2shoes/ # dataset folder location
64+
train_list_b: trainB.txt # image list
65+
test_list_b: testB.txt # image list

data.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import torch.utils.data as data
2+
import os.path
3+
4+
def default_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly and handed to PIL inside a ``with`` block
    so the underlying file handle is closed as soon as the pixel data has
    been converted, instead of relying on PIL to release it lazily.

    Args:
        path: filesystem path of the image file.

    Returns:
        The result of ``Image.open(...).convert('RGB')`` for that file.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() forces the pixel data to be read while ``f`` is still open.
        return img.convert('RGB')
6+
7+
8+
def default_flist_reader(flist):
    """Read an image list file and return its entries in file order.

    Every non-blank line of ``flist`` is stripped of surrounding whitespace
    and used, as a whole, as one image path (no label column is split off).
    Blank lines are skipped so that a trailing empty line does not produce an
    empty path entry.

    Args:
        flist: path of the text file listing one image path per line.

    Returns:
        A list of the stripped, non-empty lines.
    """
    with open(flist, 'r') as rf:
        stripped = (line.strip() for line in rf)
        return [impath for impath in stripped if impath]
19+
20+
21+
class ImageFilelist(data.Dataset):
    """Dataset whose items are the images named in a list file under ``root``.

    ``flist`` is joined onto ``root`` and read with ``flist_reader``; each
    returned entry is treated as a path relative to ``root``.
    """

    def __init__(self, root, flist, transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.root = root
        list_path = os.path.join(self.root, flist)
        self.imlist = flist_reader(list_path)
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Load item ``index`` and apply ``transform`` when one was given."""
        full_path = os.path.join(self.root, self.imlist[index])
        image = self.loader(full_path)
        if self.transform is None:
            return image
        return self.transform(image)

    def __len__(self):
        """Number of entries read from the list file."""
        return len(self.imlist)
39+
40+
41+
class ImageLabelFilelist(data.Dataset):
    """Dataset of (image, label) pairs built from a list file under ``root``.

    The class name of an entry is the text before its first ``'/'``. Class
    names are sorted and numbered from 0 to form the integer labels.
    """

    def __init__(self, root, flist, transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.root = root
        self.imlist = flist_reader(os.path.join(self.root, flist))
        self.transform = transform
        self.loader = loader
        # Distinct leading path components, in sorted order.
        self.classes = sorted({entry.split('/')[0] for entry in self.imlist})
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.imgs = [
            (entry, self.class_to_idx[entry.split('/')[0]])
            for entry in self.imlist
        ]

    def __getitem__(self, index):
        """Return ``(image, label)`` for item ``index``; the image is transformed if requested."""
        rel_path, label = self.imgs[index]
        image = self.loader(os.path.join(self.root, rel_path))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Number of (path, label) pairs."""
        return len(self.imgs)
61+
62+
###############################################################################
63+
# Code from
64+
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
65+
# Modified the original code so that it also loads images from the current
66+
# directory as well as the subdirectories
67+
###############################################################################
68+
69+
import torch.utils.data as data
70+
71+
from PIL import Image
72+
import os
73+
import os.path
74+
75+
# Filename suffixes recognised as images. The list is also joined into the
# error message raised by ImageFolder when no images are found.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg`` are
    accepted in addition to the all-lower and all-upper spellings listed in
    ``IMG_EXTENSIONS``.

    Args:
        filename: file name or path to test.

    Returns:
        bool: whether the name carries a recognised image extension.
    """
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
83+
84+
85+
def make_dataset(dir):
    """Collect the paths of all image files found under ``dir``.

    The directory tree is walked with ``os.walk``; the walked directories are
    processed in sorted order, and within each directory files are taken in
    the order ``os.walk`` reports them. A file is kept when
    ``is_image_file`` accepts its name.

    Args:
        dir: root directory to search (searched recursively).

    Returns:
        A list of paths, each the walked directory joined with the file name.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    if not os.path.isdir(dir):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept the same.
        raise AssertionError('%s is not a valid directory' % dir)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        images.extend(
            os.path.join(root, fname) for fname in fnames if is_image_file(fname)
        )
    return images
96+
97+
class ImageFolder(data.Dataset):
    """Dataset over every image file found under ``root``.

    Paths are gathered with ``make_dataset`` and kept in sorted order. Items
    are the loaded (and optionally transformed) images, paired with their
    path when ``return_paths`` is true.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        found = sorted(make_dataset(root))
        if not found:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Return the image at ``index``, or ``(image, path)`` when ``return_paths`` is set."""
        path = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return (image, path) if self.return_paths else image

    def __len__(self):
        """Number of image paths found."""
        return len(self.imgs)

0 commit comments

Comments
 (0)