Skip to content

Commit

Permalink
Fix device issue
Browse files Browse the repository at this point in the history
  • Loading branch information
elvisyjlin committed Jul 6, 2019
1 parent 7d3d7cd commit 3d7ac55
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def parse():
test_data = data.DataLoader(test_dset, args.num_samples)
for fixed_reals, fixed_labels in test_data:
# Get the first batch of images from the testing set
fixed_reals, fixed_labels = fixed_reals.to(device), fixed_labels.to(device).type_as(fixed_reals)
fixed_reals, fixed_labels = fixed_reals.to(device), fixed_labels.type_as(fixed_reals).to(device)
fixed_target_labels = 1 - fixed_labels
break
del test_dset
Expand Down Expand Up @@ -145,7 +145,7 @@ def parse():
trainable(D, True)

reals, labels = next(train_data)
reals, labels = reals.to(device), labels.to(device).type_as(reals)
reals, labels = reals.to(device), labels.type_as(reals).to(device)
target_labels = 1 - labels

fakes, _ = G(reals, target_labels)
Expand Down Expand Up @@ -180,7 +180,7 @@ def parse():
trainable(D, False)

reals, labels = next(train_data)
reals, labels = reals.to(device), labels.to(device).type_as(reals)
reals, labels = reals.to(device), labels.type_as(reals).to(device)
target_labels = 1 - labels

fakes, _ = G(reals, target_labels)
Expand Down Expand Up @@ -233,4 +233,4 @@ def parse():
torch.save(G.state_dict(), join(checkpoint_path, '{:07}.G.pth'.format(cur_nimg)))
torch.save(D.state_dict(), join(checkpoint_path, '{:07}.D.pth'.format(cur_nimg)))
torch.save(G_opt.state_dict(), join(checkpoint_path, '{:07}.G_opt.pth'.format(cur_nimg)))
torch.save(D_opt.state_dict(), join(checkpoint_path, '{:07}.D_opt.pth'.format(cur_nimg)))
torch.save(D_opt.state_dict(), join(checkpoint_path, '{:07}.D_opt.pth'.format(cur_nimg)))

0 comments on commit 3d7ac55

Please sign in to comment.