Skip to content

Commit

Permalink
Revisions
Browse files Browse the repository at this point in the history
  • Loading branch information
elvisyjlin committed Mar 21, 2019
1 parent a5bfd37 commit 46b8078
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This repository contains the PyTorch implementation of the ECCV 2018 paper "Gene
pip3 install -r requirements.txt
```

The training procedure takes 5.5GB memory on a single GPU.
The training procedure described in paper takes 5.5GB memory on a single GPU.

## Usage

Expand Down
14 changes: 8 additions & 6 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,18 @@ def parse():
load_nimg = args.test_nimg
if load_nimg is None: # Use the lastest model
load_nimg = max(int(path.split('.')[0]) for path in listdir(join(checkpoint_path)) if path.split('.')[0].isdigit())
print('Loading generator from nimg {:06d}'.format(load_nimg))
print('Loading generator from nimg {:07d}'.format(load_nimg))
G.load_state_dict(torch.load(
join(checkpoint_path, '{:d}.G.pth'.format(load_nimg)),
map_location=lambda storage, loc: storage
))

G.eval()
for batch_idx, (reals, labels) in enumerate(tqdm(test_data)):
reals, labels = reals.to(device), labels.to(device).type(reals.dtype)
target_labels = 1 - labels
with torch.no_grad():
with torch.no_grad():
for batch_idx, (reals, labels) in enumerate(tqdm(test_data)):
reals, labels = reals.to(device), labels.to(device).type(reals.dtype)
target_labels = 1 - labels

# Modify images
samples, masks = G(reals, target_labels)

Expand All @@ -87,8 +88,9 @@ def parse():
for idx, image_out in enumerate(images_out):
vutils.save_image(
image_out,
join(test_path, '{:06d}.jpg'.format(batch_idx*args.batch_size+idx+200000)),
join(test_path, '{:07d}.jpg'.format(batch_idx*args.batch_size+idx+200000)),
nrow=3,
padding=0,
normalize=True,
range=(-1.,1.)
)
16 changes: 8 additions & 8 deletions sagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_nonlinear(name):
if name == 'relu':
return nn.ReLU(inplace=True)
if name == 'lrelu':
return nn.LeakyReLU(inplace=True)
return nn.LeakyReLU(0.2, inplace=True)
if name == 'sigmoid':
return nn.Sigmoid()
if name == 'tanh':
Expand All @@ -39,7 +39,7 @@ def __init__(self, n_in, n_out):
)

def forward(self, x):
return self.layers(x)
return self.layers(x) + x

class _Generator(nn.Module):
def __init__(self, input_channels, output_channels, last_nonlinear):
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(self, input_channels, output_channels, last_nonlinear):

def forward(self, x, a=None):
if a is not None:
assert len(a.size()) == 2 and x.size(0) == a.size(0)
assert a.dim() == 2 and x.size(0) == a.size(0)
a = a.type(x.dtype)
a = a.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.size(2), x.size(3))
x = torch.cat((x, a), dim=1)
Expand All @@ -90,21 +90,21 @@ def forward(self, x, a=None):
return y

class Generator(nn.Module):
def __init__(self, input_channels):
def __init__(self):
super(Generator, self).__init__()
self.AMN = _Generator(input_channels + 1, input_channels, 'tanh')
self.SAN = _Generator(input_channels, 1, 'sigmoid')
self.AMN = _Generator(4, 3, 'tanh')
self.SAN = _Generator(3, 1, 'sigmoid')
def forward(self, x, a):
y = self.AMN(x, a)
m = self.SAN(x)
y_ = y * m + x * (1-m)
return y_, m

class Discriminator(nn.Module):
def __init__(self, input_channels):
def __init__(self):
super(Discriminator, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_channels, 32, 4, 2, 1),
nn.Conv2d(3, 32, 4, 2, 1),
get_nonlinear('lrelu'),
nn.Conv2d(32, 64, 4, 2, 1),
get_nonlinear('lrelu'),
Expand Down
18 changes: 9 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ def parse():
break
del test_dset
del test_data
vutils.save_image(fixed_reals, join(sample_path, '{:06d}_real.jpg'.format(0)), nrow=8, normalize=True, range=(-1., 1.))
vutils.save_image(fixed_reals, join(sample_path, '{:07d}_real.jpg'.format(0)), nrow=8, padding=0, normalize=True, range=(-1., 1.))

# Models
G = Generator(3)
G = Generator()
G.apply(init_weights)
G.to(device)

D = Discriminator(3)
D = Discriminator()
D.apply(init_weights)
D.to(device)

Expand Down Expand Up @@ -225,12 +225,12 @@ def parse():
G.eval()
with torch.no_grad():
samples, masks = G(fixed_reals, fixed_target_labels)
vutils.save_image(samples, join(sample_path, '{:06d}_fake.jpg'.format(cur_nimg)), nrow=8, normalize=True, range=(-1., 1.))
vutils.save_image(masks.repeat(1, 3, 1, 1), join(sample_path, '{:06d}_mask.jpg'.format(cur_nimg)), nrow=8)
vutils.save_image(samples, join(sample_path, '{:07d}_fake.jpg'.format(cur_nimg)), nrow=8, padding=0, normalize=True, range=(-1., 1.))
vutils.save_image(masks.repeat(1, 3, 1, 1), join(sample_path, '{:07d}_mask.jpg'.format(cur_nimg)), nrow=8, padding=0)

# Model checkpoints
if cur_tick % args.save_ticks == 0 or done:
torch.save(G.state_dict(), join(checkpoint_path, '{:06}.G.pth'.format(cur_nimg)))
torch.save(D.state_dict(), join(checkpoint_path, '{:06}.D.pth'.format(cur_nimg)))
torch.save(G_opt.state_dict(), join(checkpoint_path, '{:06}.G_opt.pth'.format(cur_nimg)))
torch.save(D_opt.state_dict(), join(checkpoint_path, '{:06}.D_opt.pth'.format(cur_nimg)))
torch.save(G.state_dict(), join(checkpoint_path, '{:07}.G.pth'.format(cur_nimg)))
torch.save(D.state_dict(), join(checkpoint_path, '{:07}.D.pth'.format(cur_nimg)))
torch.save(G_opt.state_dict(), join(checkpoint_path, '{:07}.G_opt.pth'.format(cur_nimg)))
torch.save(D_opt.state_dict(), join(checkpoint_path, '{:07}.D_opt.pth'.format(cur_nimg)))

0 comments on commit 46b8078

Please sign in to comment.