diff --git a/data/__init__.py b/data/__init__.py new file mode 100755 index 0000000..b44ca1f --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,74 @@ +import importlib +import torch.utils.data +from data.base_data_loader import BaseDataLoader +from data.base_dataset import BaseDataset + +def find_dataset_using_name(dataset_name): + # Given the option --dataset_mode [datasetname], + # the file "data/datasetname_dataset.py" + # will be imported. + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + # In the file, the class called DatasetNameDataset() will + # be instantiated. It has to be a subclass of BaseDataset, + # and it is case-insensitive. + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + print("In %s.py, there should be a subclass of BaseDataset with a class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + exit(1) + + return dataset + + +def get_option_setter(dataset_name): + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + dataset = find_dataset_using_name(opt.dataset_mode) + instance = dataset() + instance.initialize(opt) + print("dataset [%s] was created" % (instance.name())) + return instance + + +def CreateDataLoader(opt): + data_loader = CustomDatasetDataLoader() + data_loader.initialize(opt) + return data_loader + + +## Wrapper class of Dataset class that performs +## multi-threaded data loading +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = create_dataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i * self.opt.batchSize >= self.opt.max_dataset_size: + break + yield data
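The lookup above is purely name-based: --dataset_mode aligned imports data/aligned_dataset.py and picks the class whose lowercased name equals 'aligneddataset'. A standalone check of that convention, followed by a sketch of how the loader is typically consumed from a training script (here `opt` stands for the options object produced by options/train_options.py later in this patch):

dataset_mode = 'aligned'
assert 'AlignedDataset'.lower() == dataset_mode.replace('_', '') + 'dataset'

# Typical consumption (a sketch, not part of this patch):
#   from data import CreateDataLoader
#   data_loader = CreateDataLoader(opt)
#   dataset = data_loader.load_data()
#   print('#training images = %d' % len(data_loader))
#   for i, batch in enumerate(dataset):
#       real_A, real_B = batch['A'], batch['B']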
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py new file mode 100755 index 0000000..6e4c66e --- /dev/null +++ b/data/aligned_dataset.py @@ -0,0 +1,81 @@ +import os.path +import random +import torchvision.transforms as transforms +import torch +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image + + +class AlignedDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_AB = os.path.join(opt.dataroot, opt.phase) + self.AB_paths = sorted(make_dataset(self.dir_AB)) + assert(opt.resize_or_crop == 'resize_and_crop') + + def __getitem__(self, index): + AB_path = self.AB_paths[index] + ABCD = Image.open(AB_path).convert('RGB') + w, h = ABCD.size + w2 = int(w / 4) + A = ABCD.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) + B = ABCD.crop((w2, 0, w2+w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) + C = ABCD.crop((w2+w2, 0, w2+w2+w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) + D = ABCD.crop((w2+w2+w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) + # A.show() + # B.show() + # C.show() + # D.show() + + A = transforms.ToTensor()(A) + B = transforms.ToTensor()(B) + C = transforms.ToTensor()(C) + D = transforms.ToTensor()(D) + w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) + h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) + + A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] + B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] + C = C[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] + D = D[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] + + A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) + B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) + C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C) + D = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(D) + + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + output_nc = self.opt.input_nc + else: + input_nc = self.opt.input_nc + output_nc = self.opt.output_nc + + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A = A.index_select(2, idx) + B = B.index_select(2, idx) + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + if output_nc == 1: # RGB to gray + tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 + B = tmp.unsqueeze(0) + + return {'A': A, 'B': B, 'C': C, 'D': D, + 'A_paths': AB_path, 'B_paths': AB_path} + + def __len__(self): + return len(self.AB_paths) + + def name(self): + return 'AlignedDataset'
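AlignedDataset expects every file under dataroot/phase to be four panels stitched side by side (the A|B|C|D layout used by this project's datasets). A small standalone check of the quarter-crop arithmetic used in __getitem__ above, on a blank 1024x256 stand-in image:

from PIL import Image

abcd = Image.new('RGB', (1024, 256))  # stand-in for a real A|B|C|D strip
w, h = abcd.size
w2 = int(w / 4)  # same arithmetic as __getitem__
panels = [abcd.crop((i * w2, 0, (i + 1) * w2, h)) for i in range(4)]
assert [p.size for p in panels] == [(256, 256)] * 4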
diff --git a/data/base_data_loader.py b/data/base_data_loader.py new file mode 100755 index 0000000..ae5a168 --- /dev/null +++ b/data/base_data_loader.py @@ -0,0 +1,10 @@ +class BaseDataLoader(): + def __init__(self): + pass + + def initialize(self, opt): + self.opt = opt + pass + + def load_data(self): + return None diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100755 index 0000000..1a82083 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,103 @@ +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms + + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + pass + + def __len__(self): + return 0 + + +def get_transform(opt): + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Resize(osize, Image.BICUBIC)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'scale_width': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.fineSize))) + elif opt.resize_or_crop == 'scale_width_and_crop': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.loadSize))) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'none': + transform_list.append(transforms.Lambda( + lambda img: __adjust(img))) + else: + raise ValueError('--resize_or_crop %s is not a valid option.'
% opt.resize_or_crop) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +# just modify the width and height to be multiple of 4 +def __adjust(img): + ow, oh = img.size + + # the size needs to be a multiple of this number, + # because going through generator network may change img size + # and eventually cause size mismatch error + mult = 4 + if ow % mult == 0 and oh % mult == 0: + return img + w = (ow - 1) // mult + w = (w + 1) * mult + h = (oh - 1) // mult + h = (h + 1) * mult + + if ow != w or oh != h: + __print_size_warning(ow, oh, w, h) + + return img.resize((w, h), Image.BICUBIC) + + +def __scale_width(img, target_width): + ow, oh = img.size + + # the size needs to be a multiple of this number, + # because going through generator network may change img size + # and eventually cause size mismatch error + mult = 4 + assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult + if (ow == target_width and oh % mult == 0): + return img + w = target_width + target_height = int(target_width * oh / ow) + m = (target_height - 1) // mult + h = (m + 1) * mult + + if target_height != h: + __print_size_warning(target_width, target_height, w, h) + + return img.resize((w, h), Image.BICUBIC) + + +def __print_size_warning(ow, oh, w, h): + if not hasattr(__print_size_warning, 'has_printed'): + print("The image size needs to be a multiple of 4. " + "The loaded image size was (%d, %d), so it was adjusted to " + "(%d, %d). This adjustment will be done to all images " + "whose sizes are not multiples of 4" % (ow, oh, w, h)) + __print_size_warning.has_printed = True + + diff --git a/data/image_folder.py b/data/image_folder.py new file mode 100755 index 0000000..898200b --- /dev/null +++ b/data/image_folder.py @@ -0,0 +1,68 @@ +############################################################################### +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +############################################################################### + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img 
= self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/data/single_dataset.py b/data/single_dataset.py new file mode 100755 index 0000000..c9f515b --- /dev/null +++ b/data/single_dataset.py @@ -0,0 +1,42 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class SingleDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot) + + self.A_paths = make_dataset(self.dir_A) + + self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index] + A_img = Image.open(A_path).convert('RGB') + A = self.transform(A_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + else: + input_nc = self.opt.input_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + return {'A': A, 'A_paths': A_path} + + def __len__(self): + return len(self.A_paths) + + def name(self): + return 'SingleImageDataset' diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py new file mode 100755 index 0000000..06938b7 --- /dev/null +++ b/data/unaligned_dataset.py @@ -0,0 +1,62 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image +import random + + +class UnalignedDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index % self.A_size] + if self.opt.serial_batches: + index_B = index % self.B_size + else: + index_B = random.randint(0, self.B_size - 1) + B_path = self.B_paths[index_B] + # print('(A, B) = (%d, %d)' % (index_A, index_B)) + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(B_path).convert('RGB') + + A = self.transform(A_img) + B = self.transform(B_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + output_nc = self.opt.input_nc + else: + input_nc = self.opt.input_nc + output_nc = self.opt.output_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + if output_nc == 1: # RGB to gray + tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 + B = tmp.unsqueeze(0) + return {'A': A, 'B': B, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return max(self.A_size, self.B_size) + + def name(self): + return 'UnalignedDataset' diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000..4d92091 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,39 @@ +import importlib +from models.base_model import BaseModel + + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of BaseModel, + # and it is case-insensitive. + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(1) + + return model + + +def get_option_setter(model_name): + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + model = find_model_using_name(opt.model) + instance = model() + instance.initialize(opt) + print("model [%s] was created" % (instance.name())) + return instance
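models/__init__.py mirrors the dataset lookup: --model gesturegan_twocycle imports models/gesturegan_twocycle_model.py and matches its class case-insensitively. A standalone check against the model class added later in this patch:

model_name = 'gesturegan_twocycle'
assert 'GestureGANTwoCycleModel'.lower() == model_name.replace('_', '') + 'model'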
diff --git a/models/base_model.py b/models/base_model.py new file mode 100755 index 0000000..461727c --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,159 @@ +import os +import torch +from collections import OrderedDict +from . import networks + + +class BaseModel(): + + # modify parser to add command line options, + # and also change the default values if needed + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + if opt.resize_or_crop != 'scale_width': + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.image_paths = [] + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # load and print networks; create schedulers + def setup(self, opt, parser=None): + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + self.load_networks(opt.which_epoch) + self.print_networks(opt.verbose) + + # make models eval mode during test time + def eval(self): + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + # used in test time, wrapping `forward` in no_grad() so we don't save + # intermediate steps for backprop + def test(self): + with torch.no_grad(): + self.forward() + + # get image paths + def get_image_paths(self): + return self.image_paths + + def optimize_parameters(self): + pass + + # update learning rate (called once every epoch) + def update_learning_rate(self): + for scheduler in self.schedulers: + scheduler.step() + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + # return visualization images. train.py will display these images, and save the images to a html + def get_current_visuals(self): + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + # return training losses/errors. train.py will print out these errors as debugging information + def get_current_losses(self): + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + # float(...) works for both scalar tensor and float number + errors_ret[name] = float(getattr(self, 'loss_' + name)) + return errors_ret + + # save models to the disk + def save_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (which_epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + # load models from the disk + def load_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (which_epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + print('loading the model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) + + # print network information + def print_networks(self, verbose): + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + # set requires_grad=False to avoid computation + def set_requires_grad(self, nets, requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad \ No newline at end of file
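BaseModel wires everything through naming conventions: each entry N in loss_names must be backed by an attribute loss_N (read by get_current_losses), each entry in model_names by netN (used for the '%s_net_%s.pth' checkpoint filenames above), and each entry in visual_names by an image attribute of the same name. A minimal hypothetical subclass sketching that contract (ToyModel is not part of this patch; it assumes the repo layout above is on the import path):

import torch
from models.base_model import BaseModel

class ToyModel(BaseModel):
    def name(self):
        return 'ToyModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.loss_names = ['G']          # read back as self.loss_G
        self.visual_names = ['fake_B']   # read back as self.fake_B
        self.model_names = []            # nothing for save/load_networks to touch
        self.loss_G = torch.tensor(0.0)
        self.fake_B = torch.zeros(1, 3, 256, 256)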
diff --git a/models/gesturegan_twocycle_model.py b/models/gesturegan_twocycle_model.py new file mode 100755 index 0000000..c8cd3ae --- /dev/null +++ b/models/gesturegan_twocycle_model.py @@ -0,0 +1,219 @@ +import torch +from util.image_pool import ImagePool +from .base_model import BaseModel +from .
import networks +import itertools + +class GestureGANTwoCycleModel(BaseModel): + def name(self): + return 'GestureGANTwoCycleModel' + + @staticmethod + def modify_commandline_options(parser, is_train=True): + + # changing the default values to match the pix2pix paper + # (https://phillipi.github.io/pix2pix/) + # parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch') + parser.set_defaults(pool_size=0, no_lsgan=True, norm='instance') + parser.set_defaults(dataset_mode='aligned') + parser.set_defaults(which_model_netG='resnet_9blocks') + parser.add_argument('--REGULARIZATION', type=float, default=1e-6) + if is_train: + parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + parser.add_argument('--cyc_L1', type=float, default=100.0, help='weight for L1 loss') + parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') + parser.add_argument('--lambda_identity', type=float, default=5.0, help='weight for identity loss') + + return parser + + def initialize(self, opt): + BaseModel.initialize(self, opt) + self.isTrain = opt.isTrain + # specify the training losses you want to print out. The program will call base_model.get_current_losses + # self.loss_names = ['G_GAN_D1', 'Gi_L1', 'G' , 'D1_real', 'D1_fake','D1'] + self.loss_names = ['G_GAN_D1', 'G_GAN_D2', 'G_L1', 'G_VGG', 'reg', 'G','D1','D2'] + # specify the images you want to save/display. The program will call base_model.get_current_visuals + self.visual_names = ['real_A', 'real_D', 'fake_B', 'real_B', 'real_C', 'recovery_A'] + # self.visual_names = ['fake_B', 'fake_D'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks + if self.isTrain: + self.model_names = ['Gi','D1','D2'] + else: # during test time, only load Gs + self.model_names = ['Gi'] + # load/define networks + self.netGi = networks.define_G(6, 3, opt.ngf, + opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD1 = networks.define_D(6, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) + self.netD2 = networks.define_D(9, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) + + + if self.isTrain: + self.fake_AB_pool = ImagePool(opt.pool_size) + + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + self.criterionVGG = networks.VGGLoss(self.gpu_ids) + + # initialize optimizers + self.optimizers = [] + # self.optimizer_G = torch.optim.Adam(self.netG.parameters(), + # lr=opt.lr, betas=(opt.beta1, 0.999)) + # self.optimizer_D = torch.optim.Adam(self.netD.parameters(), + # lr=opt.lr, betas=(opt.beta1, 0.999)) + + self.optimizer_G = torch.optim.Adam(self.netGi.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(),self.netD2.parameters()), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) + self.real_C = input['C'].to(self.device) + self.real_D = 
input['D'].to(self.device) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + combine_realA_realD=torch.cat((self.real_A, self.real_D), 1) + # combine_ACD=torch.cat((self.real_A, self.real_D), 1) + self.fake_B = self.netGi(combine_realA_realD) + combine_fakeB_realC=torch.cat((self.fake_B, self.real_C), 1) + self.recovery_A = self.netGi(combine_fakeB_realC) + + combine_realB_real_C=torch.cat((self.real_B, self.real_C), 1) + self.fake_A = self.netGi(combine_realB_real_C) + combine_fakeA_realD=torch.cat((self.fake_A, self.real_D), 1) + self.recovery_B = self.netGi(combine_fakeA_realD) + + + combine_realA_realC=torch.cat((self.real_A, self.real_C), 1) + self.identity_A = self.netGi(combine_realA_realC) + combine_realB_realD=torch.cat((self.real_B, self.real_D), 1) + self.identity_B = self.netGi(combine_realB_realD) + + def backward_D1(self): + # Fake + # stop backprop to the generator by detaching fake_B + realA_fakeB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) + pred_D1_realA_fakeB = self.netD1(realA_fakeB.detach()) + self.loss_D1_realA_fakeB = self.criterionGAN(pred_D1_realA_fakeB, False) + + # Real + realA_realB = torch.cat((self.real_A, self.real_B), 1) + pred_D1_realA_realB = self.netD1(realA_realB) + self.loss_D1_realA_realB = self.criterionGAN(pred_D1_realA_realB, True) + + # Combined loss + self.loss_D1 = (self.loss_D1_realA_fakeB + self.loss_D1_realA_realB) * 0.5 + + + realB_fakeA = self.fake_AB_pool.query(torch.cat((self.real_B, self.fake_A), 1)) + pred_D1_realB_fakeA = self.netD1(realB_fakeA.detach()) + self.loss_D1_realB_fakeA = self.criterionGAN(pred_D1_realB_fakeA, False) + + # Combined loss + self.loss_D1 = (self.loss_D1_realB_fakeA + self.loss_D1_realA_realB) * 0.5 + self.loss_D1 + + self.loss_D1.backward() + + def backward_D2(self): + # Fake + # stop backprop to the generator by detaching fake_B + realA_realD_fakeB = self.fake_AB_pool.query(torch.cat((self.real_A, self.real_D, self.fake_B), 1)) + pred_D2_realA_realD_fakeB = self.netD2(realA_realD_fakeB.detach()) + self.loss_D2_realA_realD_fakeB = self.criterionGAN(pred_D2_realA_realD_fakeB, False) + + # Real + realA_realD_realB = torch.cat((self.real_A, self.real_D, self.real_B), 1) + pred_D2_realA_realD_realB = self.netD2(realA_realD_realB) + self.loss_D2_realA_realD_realB = self.criterionGAN(pred_D2_realA_realD_realB, True) + + # Combined loss + self.loss_D2 = (self.loss_D2_realA_realD_fakeB + self.loss_D2_realA_realD_realB) * 0.5 + + realB_realC_fakeA = self.fake_AB_pool.query(torch.cat((self.real_B, self.real_C, self.fake_A), 1)) + pred_D2_realB_realC_fakeA = self.netD2(realB_realC_fakeA.detach()) + self.loss_D2_realB_realC_fakeA = self.criterionGAN(pred_D2_realB_realC_fakeA, False) + + # Real + realB_realC_realA = torch.cat((self.real_B, self.real_C, self.real_A), 1) + pred_D2_realB_realC_realA = self.netD2(realB_realC_realA) + self.loss_D2_realB_realC_realA = self.criterionGAN(pred_D2_realB_realC_realA, True) + + # Combined loss + self.loss_D2 = (self.loss_D2_realB_realC_fakeA + self.loss_D2_realB_realC_realA) * 0.5 + self.loss_D2 + + self.loss_D2.backward() + + + def backward_G(self): + # First, G(A) should fake the discriminator + realA_fakeB = torch.cat((self.real_A, self.fake_B), 1) + pred_D1_realA_fakeB = self.netD1(realA_fakeB) + self.loss_G_GAN_D1 = self.criterionGAN(pred_D1_realA_fakeB, True) + + realB_fakeA = torch.cat((self.real_B, self.fake_A), 1) + pred_D1_realB_fakeA = self.netD1(realB_fakeA) + self.loss_G_GAN_D1 = 
self.criterionGAN(pred_D1_realB_fakeA, True) + self.loss_G_GAN_D1 + + realA_realD_fakeB = torch.cat((self.real_A, self.real_D, self.fake_B), 1) + pred_D2_realA_realD_fakeB = self.netD2(realA_realD_fakeB) + self.loss_G_GAN_D2 = self.criterionGAN(pred_D2_realA_realD_fakeB, True) + + realB_realC_fakeA = torch.cat((self.real_B, self.real_C, self.fake_A), 1) + pred_D2_realB_realC_fakeA = self.netD2(realB_realC_fakeA) + self.loss_G_GAN_D2 = self.criterionGAN(pred_D2_realB_realC_fakeA, True) + self.loss_G_GAN_D2 + + self.fake_B_red = self.fake_B[:,0:1,:,:] + self.fake_B_green = self.fake_B[:,1:2,:,:] + self.fake_B_blue = self.fake_B[:,2:3,:,:] + # print(self.fake_A_red.size()) + self.real_B_red = self.real_B[:,0:1,:,:] + self.real_B_green = self.real_B[:,1:2,:,:] + self.real_B_blue = self.real_B[:,2:3,:,:] + + self.fake_A_red = self.fake_A[:,0:1,:,:] + self.fake_A_green = self.fake_A[:,1:2,:,:] + self.fake_A_blue = self.fake_A[:,2:3,:,:] + # print(self.fake_A_red.size()) + self.real_A_red = self.real_A[:,0:1,:,:] + self.real_A_green = self.real_A[:,1:2,:,:] + self.real_A_blue = self.real_A[:,2:3,:,:] + + # second, G(A)=B + self.loss_G_L1 = (self.criterionL1(self.fake_B_red, self.real_B_red) + self.criterionL1(self.fake_B_green, self.real_B_green) + self.criterionL1(self.fake_B_blue, self.real_B_blue)) * self.opt.lambda_L1 + self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + self.criterionL1(self.recovery_A, self.real_A) * self.opt.cyc_L1 + self.criterionL1(self.identity_A, self.real_A) * self.opt.lambda_identity + (self.criterionL1(self.fake_A_red, self.real_A_red) + self.criterionL1(self.fake_A_green, self.real_A_green) + self.criterionL1(self.fake_A_blue, self.real_A_blue)) * self.opt.lambda_L1 + self.criterionL1(self.fake_A, self.real_A) * self.opt.lambda_L1 + self.criterionL1(self.recovery_B, self.real_B) * self.opt.cyc_L1 + self.criterionL1(self.identity_B, self.real_B) * self.opt.lambda_identity + + self.loss_G_VGG = self.criterionVGG(self.fake_B, self.real_B) * self.opt.lambda_feat + self.criterionVGG(self.fake_A, self.real_A) * self.opt.lambda_feat + + self.loss_reg = self.opt.REGULARIZATION * (torch.sum(torch.abs(self.fake_B[:, :, :, :-1] - self.fake_B[:, :, :, 1:])) + torch.sum(torch.abs(self.fake_B[:, :, :-1, :] - self.fake_B[:, :, 1:, :]))) + self.opt.REGULARIZATION * (torch.sum(torch.abs(self.fake_A[:, :, :, :-1] - self.fake_A[:, :, :, 1:])) + torch.sum(torch.abs(self.fake_A[:, :, :-1, :] - self.fake_A[:, :, 1:, :]))) + + self.loss_G = self.loss_G_GAN_D1 + self.loss_G_GAN_D2 + self.loss_G_L1 + self.loss_G_VGG + self.loss_reg + + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() + # update D + self.set_requires_grad([self.netD1, self.netD2], True) + self.optimizer_D.zero_grad() + self.backward_D1() + self.backward_D2() + self.optimizer_D.step() + + # update G + self.set_requires_grad([self.netD1, self.netD2], False) + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() diff --git a/models/networks.py b/models/networks.py new file mode 100755 index 0000000..ad8bc06 --- /dev/null +++ b/models/networks.py @@ -0,0 +1,429 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler + +############################################################################### +# Helper Functions +############################################################################### + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = 
functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + else: + raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) + init_weights(net, init_type, gain=init_gain) + return net + +class VGGLoss(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss, self).__init__() + self.vgg = Vgg19().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + netG = None + norm_layer = get_norm_layer(norm_type=norm) + + if which_model_netG == 'resnet_9blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + elif which_model_netG == 'resnet_6blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + elif which_model_netG == 'unet_128': + netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif which_model_netG == 'unet_256': + netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) + return init_net(netG, init_type, init_gain, gpu_ids)
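The 'lambda' policy above keeps the learning rate flat for the first opt.niter epochs and then decays it linearly to zero over opt.niter_decay epochs. A standalone check of the rule (niter=100, niter_decay=100, epoch_count=1 are illustrative values, not defaults asserted by this patch):

def lambda_rule(epoch, epoch_count=1, niter=100, niter_decay=100):
    return 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)

assert lambda_rule(0) == 1.0    # flat phase: epochs 0..98
assert lambda_rule(199) == 0.0  # fully decayed after niter + niter_decay epochs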
+ + +def define_D(input_nc, ndf, which_model_netD, + n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + netD = None + norm_layer = get_norm_layer(norm_type=norm) + + if which_model_netD == 'basic': + netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + elif which_model_netD == 'n_layers': + netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + elif which_model_netD == 'pixel': + netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % + which_model_netD) + return init_net(netD, init_type, init_gain, gpu_ids) + + +############################################################################## +# Classes +############################################################################## + + +# Defines the GAN loss which uses either LSGAN or the regular GAN. +# When LSGAN is used, it is basically the same as MSELoss, +# but it abstracts away the need to create the target label tensor +# that has the same size as the input +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(input) + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) + + +# Defines the generator that consists of Resnet blocks between a few +# downsampling/upsampling operations. +# Code and idea originally from Justin Johnson's architecture.
+# https://github.com/jcjohnson/fast-neural-style/ +class ResnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.ngf = ngf + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, + bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, + stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +# Defines the Unet generator. +# |num_downs|: number of downsamplings in UNet. 
For example, +# if |num_downs| == 7, image of size 128x128 will become of size 1x1 +# at the bottleneck +class UnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, num_downs, ngf=64, + norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetGenerator, self).__init__() + + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) + for i in range(num_downs - 5): + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + return self.model(input) + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class UnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) + + +# Defines the PatchGAN discriminator with the specified arguments. 
+class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + return self.model(input) + + +class PixelDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + if use_sigmoid: + self.net.append(nn.Sigmoid()) + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + return self.net(input) + +from torchvision import models +class Vgg19(torch.nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out \ No newline at end of file diff --git a/models/test_model.py b/models/test_model.py new file mode 100755 index 0000000..4d73f70 --- /dev/null +++ b/models/test_model.py @@ -0,0 +1,46 @@ +from .base_model import BaseModel +from . 
import networks +from .cycle_gan_model import CycleGANModel + + +class TestModel(BaseModel): + def name(self): + return 'TestModel' + + @staticmethod + def modify_commandline_options(parser, is_train=True): + assert not is_train, 'TestModel cannot be used in train mode' + parser = CycleGANModel.modify_commandline_options(parser, is_train=False) + parser.set_defaults(dataset_mode='single') + + parser.add_argument('--model_suffix', type=str, default='', + help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will' + ' be loaded as the generator of TestModel') + + return parser + + def initialize(self, opt): + assert(not opt.isTrain) + BaseModel.initialize(self, opt) + + # specify the training losses you want to print out. The program will call base_model.get_current_losses + self.loss_names = [] + # specify the images you want to save/display. The program will call base_model.get_current_visuals + self.visual_names = ['real_A', 'fake_B'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks + self.model_names = ['G' + opt.model_suffix] + + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, + opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + # assigns the model to self.netG_[suffix] so that it can be loaded + # please see BaseModel.load_networks + setattr(self, 'netG' + opt.model_suffix, self.netG) + + def set_input(self, input): + # we need to use single_dataset mode + self.real_A = input['A'].to(self.device) + self.image_paths = input['A_paths'] + + def forward(self): + self.fake_B = self.netG(self.real_A) \ No newline at end of file diff --git a/options/__init__.py b/options/__init__.py new file mode 100755 index 0000000..e69de29
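TestModel leans on the '%s_net_%s.pth' filename convention from save_networks/load_networks in base_model.py: registering the generator as 'G' + opt.model_suffix makes load_networks look for [which_epoch]_net_G[model_suffix].pth. A standalone check of the naming ('_A' is an example suffix, not a default):

which_epoch, model_suffix = 'latest', '_A'
filename = '%s_net_%s.pth' % (which_epoch, 'G' + model_suffix)
assert filename == 'latest_net_G_A.pth'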
diff --git a/options/base_options.py b/options/base_options.py new file mode 100755 index 0000000..3e8fc64 --- /dev/null +++ b/options/base_options.py @@ -0,0 +1,119 @@ +import argparse +import os +from util import util +import torch +import models +import data + + +class BaseOptions(): + def __init__(self): + self.initialized = False + + def initialize(self, parser): + parser.add_argument('--dataroot', type=str, default='/home/csdept/projects/pytorch-CycleGAN-and-pix2pix_sg2/datasets/dayton', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--batchSize', type=int, default=4, help='input batch size') + parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') + parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') + parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') + parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD') + parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG') + parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--name', type=str, default='ego2top_x_sg2_1', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') + parser.add_argument('--model', type=str, default='pix2pix', + help='chooses which model to use.
cycle_gan, pix2pix, test') + parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size') + parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') + parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') + parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') + parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') + parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}') + self.initialized = True + return parser + + def gather_options(self): + # initialize parser with basic options + if not self.initialized: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with the new defaults + + # modify dataset-related parser options + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + self.parser = parser + + return parser.parse_args() + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = 
os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + if len(opt.gpu_ids) > 0: + torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt diff --git a/options/test_options.py b/options/test_options.py new file mode 100755 index 0000000..e114ec7 --- /dev/null +++ b/options/test_options.py @@ -0,0 +1,21 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') + parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + # Dropout and BatchNorm have different behavior during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--how_many', type=int, default=10000000, help='how many test images to run') + + parser.set_defaults(model='pix2pix') + # To avoid cropping, the loadSize should be the same as fineSize + parser.set_defaults(loadSize=parser.get_default('fineSize')) + + self.isTrain = False + return parser
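A quick sanity check of how these option classes compose: gather_options lets the chosen model and dataset register extra flags on top of the ones declared above. A minimal sketch, assuming the full repository is importable and a CUDA device is available for the default --gpu_ids 0 (pass --gpu_ids -1 otherwise); the asserted values are simply the defaults declared above:

    from options.test_options import TestOptions

    opt = TestOptions().parse()    # parses sys.argv via gather_options()
    assert opt.isTrain is False    # forced by TestOptions.initialize
    assert opt.phase == 'test'     # default declared above
    assert opt.model == 'pix2pix'  # set_defaults(model='pix2pix')

diff --git a/options/train_options.py b/options/train_options.py new file mode 100755 index 0000000..fdfb4d6 --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,28 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + parser.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen') + parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + parser.add_argument('--update_html_freq', type=int, default=100, help='frequency of saving training results to html') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--which_epoch', type=str,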
default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate') + parser.add_argument('--niter_decay', type=int, default=15, help='# of iter to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') + parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least-squares GAN; if set, the vanilla GAN loss is used instead') + parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') + parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + + self.isTrain = True + return parser
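For intuition about the schedule these flags define: with epoch_count=1, niter=20 and niter_decay=15, train.py runs 35 epochs in total -- a constant learning rate for the first 20, then a linear decay to zero over the last 15. The real scheduler lives in models/networks.py (not part of this excerpt); the function below is only an illustrative stand-in with that shape:

    def linear_decay_factor(epoch, niter=20, niter_decay=15):
        # multiplier on the base lr: 1.0 up to and including epoch `niter`,
        # then linearly down to 0.0 at epoch `niter + niter_decay`
        return 1.0 - max(0, epoch - niter) / float(niter_decay)

    # linear_decay_factor(1) == linear_decay_factor(20) == 1.0
    # linear_decay_factor(35) == 0.0

diff --git a/test.py b/test.py new file mode 100755 index 0000000..3df0965 --- /dev/null +++ b/test.py @@ -0,0 +1,41 @@ +import os +from options.test_options import TestOptions +from data import CreateDataLoader +from models import create_model +from util.visualizer import save_images +from util import html + + +if __name__ == '__main__': + opt = TestOptions().parse() + opt.nThreads = 1 # test code only supports nThreads = 1 + opt.batchSize = 1 # test code only supports batchSize = 1 + opt.serial_batches = True # no shuffle + opt.no_flip = True # no flip + opt.display_id = -1 # no visdom display + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + model = create_model(opt) + model.setup(opt) + # create website + web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) + # test + + # Set eval mode. + # This only affects layers like batch norm and dropout; note that pix2pix does use batch norm. + if opt.eval: + model.eval() + + for i, data in enumerate(dataset): + if i >= opt.how_many: + break + model.set_input(data) + model.test() + visuals = model.get_current_visuals() + img_path = model.get_image_paths() + if i % 5 == 0: + print('processing (%04d)-th image... 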
%s' % (i, img_path)) + save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) + + webpage.save() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100755 index 0000000..0877a35 --- /dev/null +++ b/train.py @@ -0,0 +1,59 @@ +import time +from options.train_options import TrainOptions +from data import CreateDataLoader +from models import create_model +from util.visualizer import Visualizer + +if __name__ == '__main__': + opt = TrainOptions().parse() + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + dataset_size = len(data_loader) + print('#training images = %d' % dataset_size) + + model = create_model(opt) + model.setup(opt) + visualizer = Visualizer(opt) + total_steps = 0 + + for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): + epoch_start_time = time.time() + iter_data_time = time.time() + epoch_iter = 0 + + for i, data in enumerate(dataset): + iter_start_time = time.time() + if total_steps % opt.print_freq == 0: + t_data = iter_start_time - iter_data_time + visualizer.reset() + total_steps += opt.batchSize + epoch_iter += opt.batchSize + model.set_input(data) + model.optimize_parameters() + + if total_steps % opt.display_freq == 0: + save_result = total_steps % opt.update_html_freq == 0 + visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) + + if total_steps % opt.print_freq == 0: + losses = model.get_current_losses() + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) + if opt.display_id > 0: + visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) + + if total_steps % opt.save_latest_freq == 0: + print('saving the latest model (epoch %d, total_steps %d)' % + (epoch, total_steps)) + model.save_networks('latest') + + iter_data_time = time.time() + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % + (epoch, total_steps)) + model.save_networks('latest') + model.save_networks(epoch) + + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + model.update_learning_rate() diff --git a/util/__init__.py b/util/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/util/__pycache__/__init__.cpython-36.pyc b/util/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..4deaeba Binary files /dev/null and b/util/__pycache__/__init__.cpython-36.pyc differ diff --git a/util/__pycache__/__init__.cpython-37.pyc b/util/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..c32e50c Binary files /dev/null and b/util/__pycache__/__init__.cpython-37.pyc differ diff --git a/util/__pycache__/html.cpython-36.pyc b/util/__pycache__/html.cpython-36.pyc new file mode 100644 index 0000000..fc2abd9 Binary files /dev/null and b/util/__pycache__/html.cpython-36.pyc differ diff --git a/util/__pycache__/html.cpython-37.pyc b/util/__pycache__/html.cpython-37.pyc new file mode 100755 index 0000000..0cc8707 Binary files /dev/null and b/util/__pycache__/html.cpython-37.pyc differ diff --git a/util/__pycache__/image_pool.cpython-36.pyc b/util/__pycache__/image_pool.cpython-36.pyc new file mode 100644 index 0000000..20a6e57 Binary files /dev/null and b/util/__pycache__/image_pool.cpython-36.pyc differ diff --git a/util/__pycache__/image_pool.cpython-37.pyc b/util/__pycache__/image_pool.cpython-37.pyc new file mode 
100755 index 0000000..bf96847 Binary files /dev/null and b/util/__pycache__/image_pool.cpython-37.pyc differ diff --git a/util/__pycache__/util.cpython-36.pyc b/util/__pycache__/util.cpython-36.pyc new file mode 100644 index 0000000..6880cdf Binary files /dev/null and b/util/__pycache__/util.cpython-36.pyc differ diff --git a/util/__pycache__/util.cpython-37.pyc b/util/__pycache__/util.cpython-37.pyc new file mode 100644 index 0000000..8c1cfe8 Binary files /dev/null and b/util/__pycache__/util.cpython-37.pyc differ diff --git a/util/__pycache__/visualizer.cpython-36.pyc b/util/__pycache__/visualizer.cpython-36.pyc new file mode 100644 index 0000000..d32cc87 Binary files /dev/null and b/util/__pycache__/visualizer.cpython-36.pyc differ diff --git a/util/__pycache__/visualizer.cpython-37.pyc b/util/__pycache__/visualizer.cpython-37.pyc new file mode 100755 index 0000000..9fcd740 Binary files /dev/null and b/util/__pycache__/visualizer.cpython-37.pyc differ diff --git a/util/get_data.py b/util/get_data.py new file mode 100755 index 0000000..6325605 --- /dev/null +++ b/util/get_data.py @@ -0,0 +1,115 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """ + + Download CycleGAN or Pix2Pix Data. + + Args: + technique : str + One of: 'cyclegan' or 'pix2pix'. + verbose : bool + If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Args: + save_path : str + A directory to save the data to. + dataset : str, optional + A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. 
+ + Returns: + save_path_full : str + The absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Skipping download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/util/html.py b/util/html.py new file mode 100755 index 0000000..c7956f1 --- /dev/null +++ b/util/html.py @@ -0,0 +1,64 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + def __init__(self, web_dir, title, refresh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if refresh > 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, text): + with self.doc: + h3(text) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/image_pool.py b/util/image_pool.py new file mode 100755 index 0000000..52413e0 --- /dev/null +++ b/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import torch + + +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = torch.cat(return_images, 0) + return return_images
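A usage sketch for the buffer above: query() stores fakes until the pool is full, then returns a 50/50 mix of fresh and previously generated images, so the discriminator also sees older generator outputs (the image-history trick used by CycleGAN). The batch below is a random stand-in for real generator output:

    import torch
    from util.image_pool import ImagePool

    pool = ImagePool(pool_size=50)
    fake_B = torch.randn(4, 3, 256, 256)  # placeholder generator batch
    fake_for_D = pool.query(fake_B)       # mixes fresh and replayed fakes
    assert fake_for_D.shape == fake_B.shape

diff --git a/util/util.py b/util/util.py new file mode 100755 index 0000000..ba7b083 --- /dev/null +++ b/util/util.py @@ -0,0 +1,60 @@ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +# Converts a Tensor into an image array (numpy) +# |imtype|: the desired type of the converted numpy array +def 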
tensor2im(input_image, imtype=np.uint8): + if isinstance(input_image, torch.Tensor): + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() + if image_numpy.shape[0] == 1: + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100755 index 0000000..69cf926 --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,163 @@ +import numpy as np +import os +import ntpath +import time +from . import util +from . import html +from scipy.misc import imresize + + +# save image to the disk +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + h, w, _ = im.shape + if aspect_ratio > 1.0: + im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') + if aspect_ratio < 1.0: + im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') + util.save_image(im, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + def __init__(self, opt): + self.display_id = opt.display_id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.opt = opt + self.saved = False + if self.display_id > 0: + import visdom + self.ncols = opt.display_ncols + self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True, use_incoming_socket=False) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' 
% self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + self.saved = False + + def throw_visdom_connection_error(self): + print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n') + exit(1) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, save_result): + if self.display_id > 0: # show images in the browser + ncols = self.ncols + if ncols > 0: + ncols = min(ncols, len(visuals)) + h, w = next(iter(visuals.values())).shape[:2] + table_css = """<style> table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} </style>""" % (w, h) + title = self.name + label_html = '' + label_html_row = '' + images = [] + idx = 0 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + label_html_row += '<td>%s</td>' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '<tr>%s</tr>' % label_html_row + label_html_row = '' + white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '<td></td>' + idx += 1 + if label_html_row != '': + label_html += '<tr>%s</tr>' % label_html_row + # pane col = image row + try: + self.vis.images(images, nrow=ncols, win=self.display_id + 1, + padding=2, opts=dict(title=title + ' images')) + label_html = '<table>%s</table>' % label_html + self.vis.text(table_css + label_html, win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except ConnectionError: + self.throw_visdom_connection_error() + + else: + idx = 1 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + + if self.use_html and (save_result or not self.saved): # save images to a html file + self.saved = True + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + # losses: dictionary of error labels and values + def plot_current_losses(self, epoch, counter_ratio, opt, losses): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) + try: + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + except ConnectionError: + self.throw_visdom_connection_error() + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, i, losses, t, t_data): + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message)
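A minimal smoke test of the loss-logging path above. It assumes util.visualizer imports cleanly (it needs the old scipy.misc.imresize, i.e. scipy < 1.3) and writes under ./checkpoints/demo_run; the Namespace fields are hand-picked stand-ins for what TrainOptions would normally supply:

    import os
    from argparse import Namespace
    from util import util
    from util.visualizer import Visualizer

    opt = Namespace(display_id=0, isTrain=True, no_html=False,
                    display_winsize=256, name='demo_run',
                    checkpoints_dir='./checkpoints')
    util.mkdirs(os.path.join(opt.checkpoints_dir, opt.name))
    vis = Visualizer(opt)  # appends a header line to loss_log.txt
    vis.print_current_losses(epoch=1, i=4,
                             losses={'G_GAN': 0.693, 'D_real': 0.512},
                             t=0.21, t_data=0.05)
    # prints: (epoch: 1, iters: 4, time: 0.210, data: 0.050) G_GAN: 0.693 D_real: 0.512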