-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c1651d8
commit a5bfd37
Showing
6 changed files
with
549 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,28 @@ | ||
# SpatialAttentionGAN | ||
PyTorch implementation of "Generative Adversarial Network with Spatial Attention for Face Attribute Editing" | ||
# Spatial Attention Generative Adversarial Network | ||
|
||
This repository contains the PyTorch implementation of the ECCV 2018 paper "Generative Adversarial Network with Spatial Attention for Face Attribute Editing" ([pdf](http://openaccess.thecvf.com/content_ECCV_2018/papers/Gang_Zhang_Generative_Adversarial_Network_ECCV_2018_paper.pdf)). | ||
|
||
## Requirements | ||
|
||
* Python 3.5 | ||
* PyTorch 1.0.0 | ||
|
||
```bash | ||
pip3 install -r requirements.txt | ||
``` | ||
|
||
The training procedure takes 5.5GB memory on a single GPU. | ||
|
||
## Usage | ||
|
||
Train a model with a target attribute | ||
|
||
```bash | ||
python3 train.py --experiment-name celeba_128_eyeglasses --target-attr Eyeglasses --gpu | ||
``` | ||
|
||
Generate images from trained models | ||
|
||
```bash | ||
python3 generate.py --experiment-name celeba_128_eyeglasses --gpu | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright (C) 2019 Elvis Yu-Jing Lin <[email protected]> | ||
# | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# visit https://opensource.org/licenses/MIT. | ||
|
||
"""Custom datasets for CelebA and CelebA-HQ.""" | ||
|
||
import numpy as np | ||
import os | ||
from skimage import io | ||
|
||
import torch | ||
import torch.utils.data as data | ||
import torchvision.transforms as transforms | ||
|
||
|
||
class CelebA(data.Dataset): | ||
def __init__(self, data_path, attr_path, image_size, mode, selected_attrs): | ||
super(CelebA, self).__init__() | ||
self.data_path = data_path | ||
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split() | ||
atts = [att_list.index(att) + 1 for att in selected_attrs] | ||
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=np.str) | ||
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int) | ||
if len(labels.shape) == 1: | ||
labels = labels[:, None] | ||
|
||
if mode == 'train': | ||
self.images = images[:200000] | ||
self.labels = labels[:200000] | ||
if mode == 'test': | ||
self.images = images[200000:] | ||
self.labels = labels[200000:] | ||
|
||
self.tf = transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.CenterCrop(170), | ||
transforms.Resize(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
]) | ||
|
||
self.length = len(self.images) | ||
def __getitem__(self, index): | ||
img = self.tf(io.imread(os.path.join(self.data_path, self.images[index]))) | ||
att = torch.tensor((self.labels[index] + 1) // 2) | ||
return img, att | ||
def __len__(self): | ||
return self.length |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (C) 2019 Elvis Yu-Jing Lin <[email protected]> | ||
# | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# visit https://opensource.org/licenses/MIT. | ||
|
||
"""Generate images from trained models""" | ||
|
||
import argparse | ||
import json | ||
import os | ||
from os import listdir | ||
from os.path import join | ||
from tqdm import tqdm | ||
|
||
import torch | ||
import torch.utils.data as data | ||
import torchvision.utils as vutils | ||
|
||
from data import CelebA | ||
from sagan import Generator | ||
|
||
|
||
def parse(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--data-path', type=str, default=argparse.SUPPRESS) | ||
parser.add_argument('--attr-path', type=str, default=argparse.SUPPRESS) | ||
parser.add_argument('--batch-size', type=int, default=argparse.SUPPRESS) | ||
parser.add_argument('--test-nimg', type=int, default=None) | ||
parser.add_argument('--experiment-name', type=str, required=True) | ||
parser.add_argument('--gpu', action='store_true') | ||
return parser.parse_args() | ||
|
||
if __name__ == '__main__': | ||
# Arguments | ||
args = parse() | ||
print(args) | ||
|
||
# Load training setting | ||
with open(join('results', args.experiment_name, 'setting.json'), 'r', encoding='utf-8') as f: | ||
setting = json.load(f) | ||
for key, value in vars(args).items(): | ||
setting[key] = value | ||
args = argparse.Namespace(**setting) | ||
print(args) | ||
|
||
# Device | ||
device = torch.device('cuda') if args.gpu and torch.cuda.is_available() else torch.device('cpu') | ||
|
||
# Paths | ||
checkpoint_path = join('results', args.experiment_name, 'checkpoint') | ||
test_path = join('results', args.experiment_name, 'test') | ||
os.makedirs(test_path, exist_ok=True) | ||
|
||
# Data | ||
selected_attrs = [args.target_attr] | ||
test_dset = CelebA(args.data_path, args.attr_path, args.image_size, 'test', selected_attrs) | ||
test_data = data.DataLoader(test_dset, args.batch_size) | ||
|
||
# Model | ||
G = Generator(3) | ||
G.to(device) | ||
|
||
# Load from checkpoints | ||
load_nimg = args.test_nimg | ||
if load_nimg is None: # Use the lastest model | ||
load_nimg = max(int(path.split('.')[0]) for path in listdir(join(checkpoint_path)) if path.split('.')[0].isdigit()) | ||
print('Loading generator from nimg {:06d}'.format(load_nimg)) | ||
G.load_state_dict(torch.load( | ||
join(checkpoint_path, '{:d}.G.pth'.format(load_nimg)), | ||
map_location=lambda storage, loc: storage | ||
)) | ||
|
||
G.eval() | ||
for batch_idx, (reals, labels) in enumerate(tqdm(test_data)): | ||
reals, labels = reals.to(device), labels.to(device).type(reals.dtype) | ||
target_labels = 1 - labels | ||
with torch.no_grad(): | ||
# Modify images | ||
samples, masks = G(reals, target_labels) | ||
|
||
# Put images together | ||
masks = masks.repeat(1, 3, 1, 1) * 2 - 1 | ||
images_out = torch.stack((reals, samples, masks)) # 3, N, 3, S, S | ||
images_out = images_out.transpose(0, 1) # N, 3, 3, S, S | ||
|
||
# Save images separately | ||
for idx, image_out in enumerate(images_out): | ||
vutils.save_image( | ||
image_out, | ||
join(test_path, '{:06d}.jpg'.format(batch_idx*args.batch_size+idx+200000)), | ||
nrow=3, | ||
normalize=True, | ||
range=(-1.,1.) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
tensorboardX | ||
torch | ||
torchvision | ||
torchsummary | ||
tqdm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (C) 2019 Elvis Yu-Jing Lin <[email protected]> | ||
# | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# visit https://opensource.org/licenses/MIT. | ||
|
||
"""Models of Spatial Attention Generative Adversarial Network""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def get_norm(name, nc): | ||
if name == 'batchnorm': | ||
return nn.BatchNorm2d(nc) | ||
if name == 'instancenorm': | ||
return nn.InstanceNorm2d(nc) | ||
raise ValueError('Unsupported normalization layer: {:s}'.format(name)) | ||
|
||
def get_nonlinear(name): | ||
if name == 'relu': | ||
return nn.ReLU(inplace=True) | ||
if name == 'lrelu': | ||
return nn.LeakyReLU(inplace=True) | ||
if name == 'sigmoid': | ||
return nn.Sigmoid() | ||
if name == 'tanh': | ||
return nn.Tanh() | ||
raise ValueError('Unsupported activation layer: {:s}'.format(name)) | ||
|
||
class ResBlk(nn.Module): | ||
def __init__(self, n_in, n_out): | ||
super(ResBlk, self).__init__() | ||
self.layers = nn.Sequential( | ||
nn.Conv2d(n_in, n_out, 3, 1, 1), | ||
get_norm('batchnorm', n_out), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(n_out, n_out, 3, 1, 1), | ||
get_norm('batchnorm', n_out), | ||
) | ||
|
||
def forward(self, x): | ||
return self.layers(x) | ||
|
||
class _Generator(nn.Module): | ||
def __init__(self, input_channels, output_channels, last_nonlinear): | ||
super(_Generator, self).__init__() | ||
self.conv = nn.Sequential( | ||
nn.Conv2d(input_channels, 32, 7, 1, 3), # n_in, n_out, kernel_size, stride, padding | ||
get_norm('instancenorm', 32), | ||
get_nonlinear('relu'), | ||
nn.Conv2d(32, 64, 4, 2, 1), | ||
get_norm('instancenorm', 64), | ||
get_nonlinear('relu'), | ||
nn.Conv2d(64, 128, 4, 2, 1), | ||
get_norm('instancenorm', 128), | ||
get_nonlinear('relu'), | ||
nn.Conv2d(128, 256, 4, 2, 1), | ||
get_norm('instancenorm', 256), | ||
get_nonlinear('relu'), | ||
) | ||
self.resblk = nn.Sequential( | ||
ResBlk(256, 256), | ||
ResBlk(256, 256), | ||
ResBlk(256, 256), | ||
ResBlk(256, 256), | ||
) | ||
self.deconv = nn.Sequential( | ||
nn.ConvTranspose2d(256, 128, 4, 2, 1), | ||
get_norm('instancenorm', 128), | ||
get_nonlinear('relu'), | ||
nn.ConvTranspose2d(128, 64, 4, 2, 1), | ||
get_norm('instancenorm', 64), | ||
get_nonlinear('relu'), | ||
nn.ConvTranspose2d(64, 32, 4, 2, 1), | ||
get_norm('instancenorm', 32), | ||
get_nonlinear('relu'), | ||
nn.ConvTranspose2d(32, output_channels, 7, 1, 3), | ||
get_nonlinear(last_nonlinear), | ||
) | ||
|
||
def forward(self, x, a=None): | ||
if a is not None: | ||
assert len(a.size()) == 2 and x.size(0) == a.size(0) | ||
a = a.type(x.dtype) | ||
a = a.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.size(2), x.size(3)) | ||
x = torch.cat((x, a), dim=1) | ||
h = self.conv(x) | ||
h = self.resblk(h) | ||
y = self.deconv(h) | ||
return y | ||
|
||
class Generator(nn.Module): | ||
def __init__(self, input_channels): | ||
super(Generator, self).__init__() | ||
self.AMN = _Generator(input_channels + 1, input_channels, 'tanh') | ||
self.SAN = _Generator(input_channels, 1, 'sigmoid') | ||
def forward(self, x, a): | ||
y = self.AMN(x, a) | ||
m = self.SAN(x) | ||
y_ = y * m + x * (1-m) | ||
return y_, m | ||
|
||
class Discriminator(nn.Module): | ||
def __init__(self, input_channels): | ||
super(Discriminator, self).__init__() | ||
self.conv = nn.Sequential( | ||
nn.Conv2d(input_channels, 32, 4, 2, 1), | ||
get_nonlinear('lrelu'), | ||
nn.Conv2d(32, 64, 4, 2, 1), | ||
get_nonlinear('lrelu'), | ||
nn.Conv2d(64, 128, 4, 2, 1), | ||
get_nonlinear('lrelu'), | ||
nn.Conv2d(128, 256, 4, 2, 1), | ||
get_nonlinear('lrelu'), | ||
nn.Conv2d(256, 512, 4, 2, 1), | ||
get_nonlinear('lrelu'), | ||
nn.Conv2d(512, 1024, 4, 2, 1), | ||
get_nonlinear('lrelu'), | ||
) | ||
self.src = nn.Conv2d(1024, 1, 3, 1, 1) | ||
self.cls = nn.Sequential( | ||
nn.Conv2d(1024, 1, 2, 1, 0), | ||
get_nonlinear('sigmoid'), | ||
) | ||
|
||
def forward(self, x): | ||
h = self.conv(x) | ||
return self.src(h), self.cls(h).squeeze().unsqueeze(1) | ||
|
||
if __name__ == '__main__': | ||
from torchsummary import summary | ||
AMN = Generator(4, 3, 'tanh') | ||
summary(AMN, (4, 128, 128), device='cpu') | ||
SAN = Generator(3, 1, 'sigmoid') | ||
summary(SAN, (3, 128, 128), device='cpu') | ||
D = Discriminator(3) | ||
summary(D, (3, 128, 128), device='cpu') |
Oops, something went wrong.