Skip to content

Commit

Permalink
add source code
Browse files Browse the repository at this point in the history
  • Loading branch information
Ha0Tang committed Jul 20, 2019
1 parent 1af47ba commit a0dc5c9
Show file tree
Hide file tree
Showing 65 changed files with 2,034 additions and 0 deletions.
Binary file modified .DS_Store
Binary file not shown.
74 changes: 74 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import importlib
import torch.utils.data
from data.base_data_loader import BaseDataLoader
from data.base_dataset import BaseDataset

def find_dataset_using_name(dataset_name):
# Given the option --dataset_mode [datasetname],
# the file "data/datasetname_dataset.py"
# will be imported.
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)

# In the file, the class called DatasetNameDataset() will
# be instantiated. It has to be a subclass of BaseDataset,
# and it is case-insensitive.
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls

if dataset is None:
print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
exit(0)

return dataset


def get_option_setter(dataset_name):
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options


def create_dataset(opt):
dataset = find_dataset_using_name(opt.dataset_mode)
instance = dataset()
instance.initialize(opt)
print("dataset [%s] was created" % (instance.name()))
return instance


def CreateDataLoader(opt):
data_loader = CustomDatasetDataLoader()
data_loader.initialize(opt)
return data_loader


## Wrapper class of Dataset class that performs
## multi-threaded data loading
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'

def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = create_dataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))

def load_data(self):
return self

def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)

def __iter__(self):
for i, data in enumerate(self.dataloader):
if i * self.opt.batchSize >= self.opt.max_dataset_size:
break
yield data
Binary file added data/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added data/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/aligned_dataset.cpython-36.pyc
Binary file not shown.
Binary file added data/__pycache__/aligned_dataset.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/base_data_loader.cpython-36.pyc
Binary file not shown.
Binary file added data/__pycache__/base_data_loader.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/base_dataset.cpython-36.pyc
Binary file not shown.
Binary file added data/__pycache__/base_dataset.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/image_folder.cpython-36.pyc
Binary file not shown.
Binary file added data/__pycache__/image_folder.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/single_dataset.cpython-36.pyc
Binary file not shown.
81 changes: 81 additions & 0 deletions data/aligned_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image


class AlignedDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_AB = os.path.join(opt.dataroot, opt.phase)
self.AB_paths = sorted(make_dataset(self.dir_AB))
assert(opt.resize_or_crop == 'resize_and_crop')

def __getitem__(self, index):
AB_path = self.AB_paths[index]
ABCD = Image.open(AB_path).convert('RGB')
w, h = ABCD.size
w2 = int(w / 4)
A = ABCD.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
B = ABCD.crop((w2, 0, w2+w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
C = ABCD.crop((w2+w2, 0, w2+w2+w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
D = ABCD.crop((w2+w2+w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
# A.show()
# B.show()
# C.show()
# D.show()

A = transforms.ToTensor()(A)
B = transforms.ToTensor()(B)
C = transforms.ToTensor()(C)
D = transforms.ToTensor()(D)
w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))

A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
C = C[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
D = D[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]

A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C)
D = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(D)

if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
output_nc = self.opt.input_nc
else:
input_nc = self.opt.input_nc
output_nc = self.opt.output_nc

if (not self.opt.no_flip) and random.random() < 0.5:
idx = [i for i in range(A.size(2) - 1, -1, -1)]
idx = torch.LongTensor(idx)
A = A.index_select(2, idx)
B = B.index_select(2, idx)

if input_nc == 1: # RGB to gray
tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
A = tmp.unsqueeze(0)

if output_nc == 1: # RGB to gray
tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
B = tmp.unsqueeze(0)

return {'A': A, 'B': B, 'C': C, 'D': D,
'A_paths': AB_path, 'B_paths': AB_path}

def __len__(self):
return len(self.AB_paths)

def name(self):
return 'AlignedDataset'
10 changes: 10 additions & 0 deletions data/base_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class BaseDataLoader():
def __init__(self):
pass

def initialize(self, opt):
self.opt = opt
pass

def load_data():
return None
103 changes: 103 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()

def name(self):
return 'BaseDataset'

@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
pass

def __len__(self):
return 0


def get_transform(opt):
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Resize(osize, Image.BICUBIC))
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'crop':
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'scale_width':
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.fineSize)))
elif opt.resize_or_crop == 'scale_width_and_crop':
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.loadSize)))
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'none':
transform_list.append(transforms.Lambda(
lambda img: __adjust(img)))
else:
raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)

if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.RandomHorizontalFlip())

transform_list += [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)

# just modify the width and height to be multiple of 4
def __adjust(img):
ow, oh = img.size

# the size needs to be a multiple of this number,
# because going through generator network may change img size
# and eventually cause size mismatch error
mult = 4
if ow % mult == 0 and oh % mult == 0:
return img
w = (ow - 1) // mult
w = (w + 1) * mult
h = (oh - 1) // mult
h = (h + 1) * mult

if ow != w or oh != h:
__print_size_warning(ow, oh, w, h)

return img.resize((w, h), Image.BICUBIC)


def __scale_width(img, target_width):
ow, oh = img.size

# the size needs to be a multiple of this number,
# because going through generator network may change img size
# and eventually cause size mismatch error
mult = 4
assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
if (ow == target_width and oh % mult == 0):
return img
w = target_width
target_height = int(target_width * oh / ow)
m = (target_height - 1) // mult
h = (m + 1) * mult

if target_height != h:
__print_size_warning(target_width, target_height, w, h)

return img.resize((w, h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
if not hasattr(__print_size_warning, 'has_printed'):
print("The image size needs to be a multiple of 4. "
"The loaded image size was (%d, %d), so it was adjusted to "
"(%d, %d). This adjustment will be done to all images "
"whose sizes are not multiples of 4" % (ow, oh, w, h))
__print_size_warning.has_printed = True


68 changes: 68 additions & 0 deletions data/image_folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir

for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)

return images


def default_loader(path):
return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))

self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader

def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img

def __len__(self):
return len(self.imgs)
42 changes: 42 additions & 0 deletions data/single_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image


class SingleDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = os.path.join(opt.dataroot)

self.A_paths = make_dataset(self.dir_A)

self.A_paths = sorted(self.A_paths)

self.transform = get_transform(opt)

def __getitem__(self, index):
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
A = self.transform(A_img)
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
else:
input_nc = self.opt.input_nc

if input_nc == 1: # RGB to gray
tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
A = tmp.unsqueeze(0)

return {'A': A, 'A_paths': A_path}

def __len__(self):
return len(self.A_paths)

def name(self):
return 'SingleImageDataset'
Loading

0 comments on commit a0dc5c9

Please sign in to comment.