Skip to content

Commit

Permalink
Revisions
Browse files Browse the repository at this point in the history
  • Loading branch information
elvisyjlin committed Mar 22, 2019
1 parent 46b8078 commit 9431701
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, data_path, attr_path, image_size, mode, selected_attrs):
transforms.CenterCrop(170),
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

self.length = len(self.images)
Expand Down
8 changes: 4 additions & 4 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def parse():
test_data = data.DataLoader(test_dset, args.batch_size)

# Model
G = Generator(3)
G = Generator()
G.to(device)

# Load from checkpoints
Expand All @@ -66,14 +66,14 @@ def parse():
load_nimg = max(int(path.split('.')[0]) for path in listdir(join(checkpoint_path)) if path.split('.')[0].isdigit())
print('Loading generator from nimg {:07d}'.format(load_nimg))
G.load_state_dict(torch.load(
join(checkpoint_path, '{:d}.G.pth'.format(load_nimg)),
join(checkpoint_path, '{:07d}.G.pth'.format(load_nimg)),
map_location=lambda storage, loc: storage
))

G.eval()
with torch.no_grad():
for batch_idx, (reals, labels) in enumerate(tqdm(test_data)):
reals, labels = reals.to(device), labels.to(device).type(reals.dtype)
reals, labels = reals.to(device), labels.to(device).type_as(reals)
target_labels = 1 - labels

# Modify images
Expand All @@ -88,7 +88,7 @@ def parse():
for idx, image_out in enumerate(images_out):
vutils.save_image(
image_out,
join(test_path, '{:07d}.jpg'.format(batch_idx*args.batch_size+idx+200000)),
join(test_path, '{:06d}.jpg'.format(batch_idx*args.batch_size+idx+200000)),
nrow=3,
padding=0,
normalize=True,
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def parse():
test_data = data.DataLoader(test_dset, args.num_samples)
for fixed_reals, fixed_labels in test_data:
# Get the first batch of images from the testing set
fixed_reals, fixed_labels = fixed_reals.to(device), fixed_labels.to(device).type(fixed_reals.dtype)
fixed_reals, fixed_labels = fixed_reals.to(device), fixed_labels.to(device).type_as(fixed_reals)
fixed_target_labels = 1 - fixed_labels
break
del test_dset
Expand Down Expand Up @@ -145,7 +145,7 @@ def parse():
trainable(D, True)

reals, labels = next(train_data)
reals, labels = reals.to(device), labels.to(device).type(reals.dtype)
reals, labels = reals.to(device), labels.to(device).type_as(reals)
target_labels = 1 - labels

fakes, _ = G(reals, target_labels)
Expand Down Expand Up @@ -180,7 +180,7 @@ def parse():
trainable(D, False)

reals, labels = next(train_data)
reals, labels = reals.to(device), labels.to(device).type(reals.dtype)
reals, labels = reals.to(device), labels.to(device).type_as(reals)
target_labels = 1 - labels

fakes, _ = G(reals, target_labels)
Expand Down

0 comments on commit 9431701

Please sign in to comment.