Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
elvisyjlin committed Mar 19, 2019
1 parent c1651d8 commit a5bfd37
Show file tree
Hide file tree
Showing 6 changed files with 549 additions and 2 deletions.
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,28 @@
# SpatialAttentionGAN
PyTorch implementation of "Generative Adversarial Network with Spatial Attention for Face Attribute Editing"
# Spatial Attention Generative Adversarial Network

This repository contains the PyTorch implementation of the ECCV 2018 paper "Generative Adversarial Network with Spatial Attention for Face Attribute Editing" ([pdf](http://openaccess.thecvf.com/content_ECCV_2018/papers/Gang_Zhang_Generative_Adversarial_Network_ECCV_2018_paper.pdf)).

## Requirements

* Python 3.5
* PyTorch 1.0.0

```bash
pip3 install -r requirements.txt
```

Training requires 5.5 GB of memory on a single GPU.

## Usage

Train a model for a target attribute:

```bash
python3 train.py --experiment-name celeba_128_eyeglasses --target-attr Eyeglasses --gpu
```

Generate images from a trained model:

```bash
python3 generate.py --experiment-name celeba_128_eyeglasses --gpu
```
49 changes: 49 additions & 0 deletions data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2019 Elvis Yu-Jing Lin <[email protected]>
#
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.

"""Custom datasets for CelebA and CelebA-HQ."""

import numpy as np
import os
from skimage import io

import torch
import torch.utils.data as data
import torchvision.transforms as transforms


class CelebA(data.Dataset):
    """CelebA images paired with binary labels for the selected attributes.

    The attribute file is read as follows: the first line is skipped, the
    second line holds the attribute names, and every following row holds an
    image file name followed by one integer per attribute.

    Args:
        data_path: Directory containing the image files.
        attr_path: Path to the attribute annotation text file.
        image_size: Size handed to ``transforms.Resize`` after the center crop.
        mode: ``'train'`` keeps the first ``TRAIN_SIZE`` rows, ``'test'``
            keeps the remaining rows.
        selected_attrs: Names of the attributes whose labels are returned.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``, or if
            a name in ``selected_attrs`` is not listed in the attribute file.
    """

    # Number of leading rows of the attribute file used for training.
    TRAIN_SIZE = 200000

    def __init__(self, data_path, attr_path, image_size, mode, selected_attrs):
        super(CelebA, self).__init__()
        # Fail early: an unknown mode would otherwise leave self.images unset
        # and only surface later as an AttributeError.
        if mode not in ('train', 'test'):
            raise ValueError(
                "Unsupported mode: {!r} (expected 'train' or 'test')".format(mode)
            )
        self.data_path = data_path

        # Attribute names live on the second line of the annotation file.
        with open(attr_path, 'r', encoding='utf-8') as f:
            f.readline()
            att_list = f.readline().split()
        # +1 because column 0 of every data row is the image file name.
        atts = [att_list.index(att) + 1 for att in selected_attrs]

        # Builtin str/int replace the np.str/np.int aliases, which were
        # deprecated in NumPy 1.20 and later removed.
        images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
        labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)
        if len(labels.shape) == 1:
            # A single selected attribute comes back 1-D; keep labels 2-D.
            labels = labels[:, None]

        if mode == 'train':
            split = slice(None, self.TRAIN_SIZE)
        else:
            split = slice(self.TRAIN_SIZE, None)
        self.images = images[split]
        self.labels = labels[split]

        self.tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop(170),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.length = len(self.images)

    def __getitem__(self, index):
        """Return ``(image, labels)`` for the sample at ``index``.

        The image is read from ``data_path`` and passed through ``self.tf``.
        Labels are mapped with ``(v + 1) // 2``, i.e. -1 -> 0 and 1 -> 1.
        """
        img = self.tf(io.imread(os.path.join(self.data_path, self.images[index])))
        att = torch.tensor((self.labels[index] + 1) // 2)
        return img, att

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.length
94 changes: 94 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (C) 2019 Elvis Yu-Jing Lin <[email protected]>
#
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.

"""Generate images from trained models"""

import argparse
import json
import os
from os import listdir
from os.path import join
from tqdm import tqdm

import torch
import torch.utils.data as data
import torchvision.utils as vutils

from data import CelebA
from sagan import Generator


def parse():
    """Parse the command-line options of the generation script.

    ``--data-path``, ``--attr-path`` and ``--batch-size`` use
    ``argparse.SUPPRESS`` as their default, so they are absent from the
    returned namespace unless given explicitly.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()

    # Options that are left out of the namespace when not supplied
    suppressed_options = (
        ('--data-path', str),
        ('--attr-path', str),
        ('--batch-size', int),
    )
    for flag, value_type in suppressed_options:
        parser.add_argument(flag, type=value_type, default=argparse.SUPPRESS)

    # Options that always appear in the namespace
    parser.add_argument('--test-nimg', type=int, default=None)
    parser.add_argument('--experiment-name', type=str, required=True)
    parser.add_argument('--gpu', action='store_true')

    parsed = parser.parse_args()
    return parsed

if __name__ == '__main__':
    # Arguments
    args = parse()
    print(args)

    # Load the training settings; options present on the command line win
    with open(join('results', args.experiment_name, 'setting.json'), 'r', encoding='utf-8') as f:
        setting = json.load(f)
    setting.update(vars(args))
    args = argparse.Namespace(**setting)
    print(args)

    # Device
    device = torch.device('cuda') if args.gpu and torch.cuda.is_available() else torch.device('cpu')

    # Paths
    checkpoint_path = join('results', args.experiment_name, 'checkpoint')
    test_path = join('results', args.experiment_name, 'test')
    os.makedirs(test_path, exist_ok=True)

    # Data
    selected_attrs = [args.target_attr]
    test_dset = CelebA(args.data_path, args.attr_path, args.image_size, 'test', selected_attrs)
    test_data = data.DataLoader(test_dset, args.batch_size)

    # Model
    G = Generator(3)
    G.to(device)

    # Load from checkpoints
    load_nimg = args.test_nimg
    if load_nimg is None:  # Use the latest model
        # Checkpoint files are named '<nimg>.<...>'; collect the numeric prefixes
        saved_nimgs = [
            int(name.split('.')[0])
            for name in listdir(checkpoint_path)
            if name.split('.')[0].isdigit()
        ]
        if not saved_nimgs:
            raise FileNotFoundError('No checkpoint found in {:s}'.format(checkpoint_path))
        load_nimg = max(saved_nimgs)
    print('Loading generator from nimg {:06d}'.format(load_nimg))
    G.load_state_dict(torch.load(
        join(checkpoint_path, '{:d}.G.pth'.format(load_nimg)),
        map_location=lambda storage, loc: storage
    ))

    G.eval()
    for batch_idx, (reals, labels) in enumerate(tqdm(test_data)):
        reals, labels = reals.to(device), labels.to(device).type(reals.dtype)
        # Flip the attribute so that every image gets edited
        target_labels = 1 - labels
        with torch.no_grad():
            # Modify images
            samples, masks = G(reals, target_labels)

        # Put images together; masks are rescaled so all three share one range
        masks = masks.repeat(1, 3, 1, 1) * 2 - 1
        images_out = torch.stack((reals, samples, masks), dim=1)  # N, 3, 3, S, S

        # Rescale [-1, 1] -> [0, 1] here instead of passing normalize/range to
        # save_image: the 'range' keyword was renamed to 'value_range' in
        # torchvision, so doing it by hand works across versions.
        images_out = images_out.add(1).div(2).clamp(0, 1)

        # Save images separately
        for idx, image_out in enumerate(images_out):
            vutils.save_image(
                image_out,
                # File names are offset by 200000, the start of the 'test' split
                join(test_path, '{:06d}.jpg'.format(batch_idx*args.batch_size+idx+200000)),
                nrow=3
            )
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
numpy
scikit-image
tensorboardX
torch
torchvision
torchsummary
tqdm
137 changes: 137 additions & 0 deletions sagan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (C) 2019 Elvis Yu-Jing Lin <[email protected]>
#
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.

"""Models of Spatial Attention Generative Adversarial Network"""

import torch
import torch.nn as nn


def get_norm(name, nc):
    """Build the 2-D normalization layer registered under ``name``.

    Args:
        name: ``'batchnorm'`` or ``'instancenorm'``.
        nc: Number of channels the layer normalizes.

    Returns:
        nn.Module: A new ``nn.BatchNorm2d`` or ``nn.InstanceNorm2d``.

    Raises:
        ValueError: If ``name`` is not one of the supported layer names.
    """
    layer_types = {
        'batchnorm': nn.BatchNorm2d,
        'instancenorm': nn.InstanceNorm2d,
    }
    layer_type = layer_types.get(name)
    if layer_type is None:
        raise ValueError('Unsupported normalization layer: {:s}'.format(name))
    return layer_type(nc)

def get_nonlinear(name):
    """Build the activation layer registered under ``name``.

    Args:
        name: One of ``'relu'``, ``'lrelu'``, ``'sigmoid'`` or ``'tanh'``.

    Returns:
        nn.Module: A new activation module. ReLU and LeakyReLU are created
        with ``inplace=True``.

    Raises:
        ValueError: If ``name`` is not one of the supported activations.
    """
    # Factories rather than instances, so every call returns a fresh module
    builders = {
        'relu': lambda: nn.ReLU(inplace=True),
        'lrelu': lambda: nn.LeakyReLU(inplace=True),
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
    }
    build = builders.get(name)
    if build is None:
        raise ValueError('Unsupported activation layer: {:s}'.format(name))
    return build()

class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a skip connection.

    The block computes ``layers(x) + shortcut(x)``. When ``n_in == n_out`` the
    shortcut is the identity and adds no parameters, so the ``state_dict``
    layout is the same as without the skip connection. Otherwise a 1x1
    convolution projects the input to ``n_out`` channels.

    NOTE: weights trained before the skip connection was added still load,
    but will not reproduce their original outputs.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
    """

    def __init__(self, n_in, n_out):
        super(ResBlk, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(n_in, n_out, 3, 1, 1),
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_out, n_out, 3, 1, 1),
            nn.BatchNorm2d(n_out),
        )
        # None means identity; a projection is needed only to match channels
        self.shortcut = None if n_in == n_out else nn.Conv2d(n_in, n_out, 1, 1, 0)

    def forward(self, x):
        """Return the residual branch output added to the (projected) input."""
        skip = x if self.shortcut is None else self.shortcut(x)
        return self.layers(x) + skip

class _Generator(nn.Module):
    """Convolutional encoder, four residual blocks and a transposed-conv decoder.

    The encoder downsamples three times with stride 2 and the decoder
    upsamples three times with stride 2.

    Args:
        input_channels: Channels of the tensor fed to the first convolution
            (including any attribute channels appended in ``forward``).
        output_channels: Channels produced by the last transposed convolution.
        last_nonlinear: Name of the final activation, resolved by
            ``get_nonlinear``.
    """

    def __init__(self, input_channels, output_channels, last_nonlinear):
        super(_Generator, self).__init__()

        # Encoder: each entry is (n_in, n_out, kernel_size, stride, padding)
        encoder_specs = (
            (input_channels, 32, 7, 1, 3),
            (32, 64, 4, 2, 1),
            (64, 128, 4, 2, 1),
            (128, 256, 4, 2, 1),
        )
        encoder = []
        for spec in encoder_specs:
            encoder.append(nn.Conv2d(*spec))
            encoder.append(get_norm('instancenorm', spec[1]))
            encoder.append(get_nonlinear('relu'))
        self.conv = nn.Sequential(*encoder)

        # Bottleneck: four residual blocks at 256 channels
        self.resblk = nn.Sequential(*[ResBlk(256, 256) for _ in range(4)])

        # Decoder: mirrors the encoder; the last layer has no normalization
        decoder = []
        for n_in, n_out in ((256, 128), (128, 64), (64, 32)):
            decoder.append(nn.ConvTranspose2d(n_in, n_out, 4, 2, 1))
            decoder.append(get_norm('instancenorm', n_out))
            decoder.append(get_nonlinear('relu'))
        decoder.append(nn.ConvTranspose2d(32, output_channels, 7, 1, 3))
        decoder.append(get_nonlinear(last_nonlinear))
        self.deconv = nn.Sequential(*decoder)

    def forward(self, x, a=None):
        """Run the network on ``x``, optionally conditioned on ``a``.

        Args:
            x: Input tensor; dimensions 2 and 3 are used as the spatial size.
            a: Optional 2-D tensor with one row per sample in ``x``. Each
                column is broadcast over the spatial size and concatenated to
                ``x`` along dimension 1.

        Returns:
            The decoder output.
        """
        inputs = x
        if a is not None:
            assert len(a.size()) == 2 and x.size(0) == a.size(0)
            height, width = x.size(2), x.size(3)
            # Turn each attribute value into a constant height x width plane
            planes = a.type(x.dtype)[:, :, None, None].expand(-1, -1, height, width)
            inputs = torch.cat((x, planes), dim=1)
        return self.deconv(self.resblk(self.conv(inputs)))

class Generator(nn.Module):
    """Two-branch generator that blends an edited image with its input.

    ``AMN`` receives the image plus the attribute and outputs an image with
    ``input_channels`` channels; ``SAN`` receives the image only and outputs a
    single-channel sigmoid mask. (The names presumably stand for the
    attribute manipulation and spatial attention networks of the paper.)

    Args:
        input_channels: Number of image channels. ``AMN`` is built with one
            extra input channel, which leaves room for a single attribute
            column.
    """

    def __init__(self, input_channels):
        super(Generator, self).__init__()
        # Image + one attribute plane in, image out
        self.AMN = _Generator(input_channels + 1, input_channels, 'tanh')
        # Image in, one-channel mask out
        self.SAN = _Generator(input_channels, 1, 'sigmoid')

    def forward(self, x, a):
        """Edit ``x`` towards attribute ``a``.

        Returns:
            tuple: ``(blended, mask)`` where ``blended`` takes the ``AMN``
            output where the mask is 1 and the input ``x`` where it is 0.
        """
        edited = self.AMN(x, a)
        mask = self.SAN(x)
        blended = edited * mask + x * (1-mask)
        return blended, mask

class Discriminator(nn.Module):
    """Discriminator with a source head and an attribute head.

    Six stride-2 convolutions reduce the input, then two heads are applied to
    the shared features: ``src`` (a 3x3 convolution without activation) and
    ``cls`` (a 2x2 convolution followed by a sigmoid).

    Args:
        input_channels: Number of channels of the input images.
    """

    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, 4, 2, 1),
            get_nonlinear('lrelu'),
            nn.Conv2d(32, 64, 4, 2, 1),
            get_nonlinear('lrelu'),
            nn.Conv2d(64, 128, 4, 2, 1),
            get_nonlinear('lrelu'),
            nn.Conv2d(128, 256, 4, 2, 1),
            get_nonlinear('lrelu'),
            nn.Conv2d(256, 512, 4, 2, 1),
            get_nonlinear('lrelu'),
            nn.Conv2d(512, 1024, 4, 2, 1),
            get_nonlinear('lrelu'),
        )
        self.src = nn.Conv2d(1024, 1, 3, 1, 1)
        self.cls = nn.Sequential(
            nn.Conv2d(1024, 1, 2, 1, 0),
            get_nonlinear('sigmoid'),
        )

    def forward(self, x):
        """Return ``(src, cls)`` for a batch of images.

        ``cls`` has its size-1 spatial axes removed, giving shape ``(N, 1)``
        when the ``cls`` head output is 1x1 spatially.
        """
        h = self.conv(x)
        # Squeeze only the spatial axes. A bare squeeze() would also drop the
        # batch axis when N == 1, and the following unsqueeze(1) would fail.
        return self.src(h), self.cls(h).squeeze(3).squeeze(2)

if __name__ == '__main__':
    from torchsummary import summary
    # Summarize each branch on its own. The three-argument constructor belongs
    # to _Generator; Generator only accepts input_channels.
    AMN = _Generator(4, 3, 'tanh')
    summary(AMN, (4, 128, 128), device='cpu')
    SAN = _Generator(3, 1, 'sigmoid')
    summary(SAN, (3, 128, 128), device='cpu')
    D = Discriminator(3)
    summary(D, (3, 128, 128), device='cpu')
Loading

0 comments on commit a5bfd37

Please sign in to comment.