Commit a0dc5c9 by Ha0Tang, committed Jul 20, 2019 (1 parent: 1af47ba).
Showing 65 changed files with 2,034 additions and 0 deletions.
New file (74 added lines): dataset discovery by name and the multi-threaded data loader wrapper.
import importlib
import torch.utils.data
from data.base_data_loader import BaseDataLoader
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    # Given the option --dataset_mode [datasetname],
    # the file "data/datasetname_dataset.py" will be imported.
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    # In that file, the class called DatasetNameDataset() will be
    # instantiated. It has to be a subclass of BaseDataset, and the
    # name match is case-insensitive.
    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        print("In %s.py, there should be a subclass of BaseDataset "
              "with a class name that matches %s in lowercase."
              % (dataset_filename, target_dataset_name))
        exit(0)

    return dataset


def get_option_setter(dataset_name):
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    dataset = find_dataset_using_name(opt.dataset_mode)
    instance = dataset()
    instance.initialize(opt)
    print("dataset [%s] was created" % (instance.name()))
    return instance


def CreateDataLoader(opt):
    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    return data_loader


# Wrapper class around the Dataset class that performs
# multi-threaded data loading.
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield data
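A minimal usage sketch, assuming this file sits at data/__init__.py (as the data.* imports suggest) and that opt is an argparse-style namespace carrying exactly the attributes the code above reads; the values and the dataset path are illustrative, not part of the repository.

# Hypothetical illustration; the real project builds `opt` through its own
# option parser, and the dataroot below is a made-up path.
from argparse import Namespace
from data import CreateDataLoader  # assumes this file is data/__init__.py

opt = Namespace(
    dataset_mode='aligned',          # resolves to data/aligned_dataset.py -> AlignedDataset
    dataroot='./datasets/demo',      # hypothetical dataset location
    phase='train',
    resize_or_crop='resize_and_crop',
    loadSize=286, fineSize=256,
    input_nc=3, output_nc=3,
    which_direction='AtoB',
    no_flip=False,
    serial_batches=False,
    batchSize=4, nThreads=2,
    max_dataset_size=float('inf'))

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
print(len(data_loader))                          # number of usable samples
for i, batch in enumerate(dataset):
    print(batch['A'].shape, batch['B'].shape)    # e.g. torch.Size([4, 3, 256, 256])
    break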
11 binary files not shown.
New file (81 added lines): AlignedDataset, which loads one four-panel image and returns aligned A/B/C/D crops.
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image


class AlignedDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir_AB))
        assert(opt.resize_or_crop == 'resize_and_crop')

    def __getitem__(self, index):
        # Each file holds four sub-images (A | B | C | D) side by side;
        # split it into equal quarters and resize each to loadSize x loadSize.
        AB_path = self.AB_paths[index]
        ABCD = Image.open(AB_path).convert('RGB')
        w, h = ABCD.size
        w2 = int(w / 4)
        A = ABCD.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        B = ABCD.crop((w2, 0, 2 * w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        C = ABCD.crop((2 * w2, 0, 3 * w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        D = ABCD.crop((3 * w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        # A.show()
        # B.show()
        # C.show()
        # D.show()

        A = transforms.ToTensor()(A)
        B = transforms.ToTensor()(B)
        C = transforms.ToTensor()(C)
        D = transforms.ToTensor()(D)

        # Apply the same random crop to all four panels so they stay aligned.
        w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
        h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))

        A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        C = C[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        D = D[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]

        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
        B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
        C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C)
        D = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(D)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        # Random horizontal flip (note: applied to A and B only).
        if (not self.opt.no_flip) and random.random() < 0.5:
            idx = [i for i in range(A.size(2) - 1, -1, -1)]
            idx = torch.LongTensor(idx)
            A = A.index_select(2, idx)
            B = B.index_select(2, idx)

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)

        return {'A': A, 'B': B, 'C': C, 'D': D,
                'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        return len(self.AB_paths)

    def name(self):
        return 'AlignedDataset'
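AlignedDataset assumes each file on disk is one wide image with the four domains laid out left to right. A small sketch, using only PIL, of how such a combined file could be assembled; the helper name, fixed panel size, and paths are illustrative, not part of the repository.

# Hypothetical data-preparation helper: paste four equally sized images side
# by side so that AlignedDataset.__getitem__ can split them back into quarters.
from PIL import Image

def make_four_panel(a_path, b_path, c_path, d_path, out_path, size=256):
    panels = [Image.open(p).convert('RGB').resize((size, size), Image.BICUBIC)
              for p in (a_path, b_path, c_path, d_path)]
    canvas = Image.new('RGB', (4 * size, size))
    for i, panel in enumerate(panels):
        canvas.paste(panel, (i * size, 0))
    canvas.save(out_path)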
New file (10 added lines): BaseDataLoader, the minimal loader interface.
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None
New file (103 added lines): BaseDataset and the get_transform pipeline builder.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        pass

    def __len__(self):
        return 0


def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'none':
        transform_list.append(transforms.Lambda(
            lambda img: __adjust(img)))
    else:
        raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


# Just adjust the width and height to be a multiple of 4.
def __adjust(img):
    ow, oh = img.size

    # The size needs to be a multiple of this number, because going through
    # the generator network may change the image size and eventually cause a
    # size mismatch error.
    mult = 4
    if ow % mult == 0 and oh % mult == 0:
        return img
    w = (ow - 1) // mult
    w = (w + 1) * mult
    h = (oh - 1) // mult
    h = (h + 1) * mult

    if ow != w or oh != h:
        __print_size_warning(ow, oh, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __scale_width(img, target_width):
    ow, oh = img.size

    # The size needs to be a multiple of this number, because going through
    # the generator network may change the image size and eventually cause a
    # size mismatch error.
    mult = 4
    assert target_width % mult == 0, "the target width needs to be a multiple of %d." % mult
    if (ow == target_width and oh % mult == 0):
        return img
    w = target_width
    target_height = int(target_width * oh / ow)
    m = (target_height - 1) // mult
    h = (m + 1) * mult

    if target_height != h:
        __print_size_warning(target_width, target_height, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
    # Print the adjustment warning only once per run.
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True
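A short sketch of get_transform in use; the Namespace fields are simply the options the function reads, the values are made up, and the blank test image stands in for a real photo.

# Illustrative only: builds the resize -> random crop -> flip -> tensor ->
# normalize pipeline defined above and applies it to a dummy image.
from argparse import Namespace
from PIL import Image

opt = Namespace(resize_or_crop='resize_and_crop',
                loadSize=286, fineSize=256,
                isTrain=True, no_flip=False)
transform = get_transform(opt)

img = Image.new('RGB', (300, 200))
out = transform(img)
print(out.shape)                              # torch.Size([3, 256, 256])
print(out.min().item(), out.max().item())     # values normalized to roughly [-1, 1]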
New file (68 added lines): image-folder helpers adapted from torchvision's folder.py.
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
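Both helpers can be used on their own; a quick sketch with a hypothetical folder path:

import torchvision.transforms as transforms

# Recursively collect image paths (the directory is a placeholder).
paths = make_dataset('./datasets/demo/testA')
print(len(paths), paths[:3])

# Or wrap the same folder as a torch Dataset that also returns the path.
folder = ImageFolder('./datasets/demo/testA',
                     transform=transforms.ToTensor(),
                     return_paths=True)
img, path = folder[0]
print(img.shape, path)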
New file (42 added lines): SingleDataset for loading a flat folder of single images at test time.
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image


class SingleDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot)

        self.A_paths = sorted(make_dataset(self.dir_A))

        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        A = self.transform(A_img)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
        else:
            input_nc = self.opt.input_nc

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        return {'A': A, 'A_paths': A_path}

    def __len__(self):
        return len(self.A_paths)

    def name(self):
        return 'SingleImageDataset'
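For reference, this class is picked up by the --dataset_mode naming convention implemented in find_dataset_using_name (shown in the first file of this commit). A small sketch, assuming that factory function is importable from the data package:

# Hypothetical check of the name-to-class mapping: '<mode>' selects the module
# data/<mode>_dataset.py and the class whose lowercased name is '<mode>dataset'.
from data import find_dataset_using_name   # assumes the first file is data/__init__.py

print(find_dataset_using_name('single').__name__)    # SingleDataset
print(find_dataset_using_name('aligned').__name__)   # AlignedDataset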