# Helper functions
import torch
import pickle
from torchvision import utils
import matplotlib.pyplot as plt
import pyro
from sklearn.manifold import TSNE


def test_model(model, guide, loss):
    """Smoke-test a model/guide pair by evaluating the loss once on a fresh parameter store."""
    pyro.clear_param_store()
    loss.loss(model, guide)
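# Minimal usage sketch (the `vae` object below is a hypothetical Pyro-style model
# wrapper, not defined in this file):
#
#     test_model(vae.model, vae.guide, pyro.infer.Trace_ELBO())
#
# If the model and guide disagree on shapes or sample sites, the single loss
# evaluation above fails early, before any real training is attempted.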


def show_batch(images, nrow=4, npadding=10):
    """Visualize a torch tensor of shape (batch x ch x width x height) as a grid of images."""
    batch, ch, width, height = images.shape
    # make_grid expects a CPU tensor
    if images.device.type != "cpu":
        images = images.cpu()
    grid = utils.make_grid(images, nrow=nrow, padding=npadding,
                           normalize=True, scale_each=True, pad_value=1)
    plt.imshow(grid.detach().numpy().transpose((1, 2, 0)))


def show_2_batch(images1, images2, nrow=4, npadding=10):
    """Visualize two torch tensors of shape (batch x ch x width x height) in a single grid."""
    assert images1.shape == images2.shape
    if images1.device.type != "cpu":
        images1 = images1.cpu()
    if images2.device.type != "cpu":
        images2 = images2.cpu()
    tmp = torch.cat((images1, images2), dim=0)
    grid = utils.make_grid(tmp, nrow=nrow, padding=npadding,
                           normalize=True, scale_each=True, pad_value=1)
    plt.imshow(grid.detach().numpy().transpose((1, 2, 0)))
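# Usage sketch for the plotting helpers (the tensors below are hypothetical
# placeholders, not defined in this file):
#
#     x = torch.rand(16, 1, 28, 28)             # e.g. a mini-batch of MNIST-sized images
#     show_batch(x, nrow=4)                     # plot the batch as a 4-column grid
#     show_2_batch(x, x_reconstructed, nrow=4)  # compare inputs with reconstructions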


def train(svi, loader, use_cuda=False):
    """Run one training epoch and return the average ELBO loss per data point."""
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned by the data loader
    for x, _ in loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # take an ELBO gradient step and accumulate the loss
        epoch_loss += svi.step(x)
    # return the loss normalized by the number of data points
    return epoch_loss / len(loader.dataset)


def evaluate(svi, loader, use_cuda=False):
    """Compute the average ELBO loss per data point over an evaluation set (no gradient steps)."""
    # initialize loss accumulator
    test_loss = 0.
    # compute the loss over the entire test set
    for x, _ in loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)
    return test_loss / len(loader.dataset)
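# A minimal training-loop sketch, assuming a Pyro VAE-style object `vae` and
# torch DataLoaders `train_loader` / `test_loader` (none of which are defined here):
#
#     optimizer = pyro.optim.Adam({"lr": 1.0e-3})
#     svi = pyro.infer.SVI(vae.model, vae.guide, optimizer, loss=pyro.infer.Trace_ELBO())
#     for epoch in range(10):
#         train_elbo = train(svi, train_loader, use_cuda=True)
#         test_elbo = evaluate(svi, test_loader, use_cuda=True)
#         print(f"epoch {epoch}: train {train_elbo:.3f}, test {test_elbo:.3f}")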


def save_obj(obj, root_dir, name):
    """Pickle an object to <root_dir><name>.pkl."""
    with open(root_dir + name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(root_dir, name):
    """Load a pickled object from <root_dir><name>.pkl."""
    with open(root_dir + name + '.pkl', 'rb') as f:
        return pickle.load(f)


def save_model(model, root_dir, name):
    """Save a model's state_dict to <root_dir><name>.pkl."""
    full_file_path = root_dir + name + '.pkl'
    torch.save(model.state_dict(), full_file_path)


def load_model(model, root_dir, name):
    """Load a saved state_dict into an existing model instance (in place)."""
    full_file_path = root_dir + name + '.pkl'
    model.load_state_dict(torch.load(full_file_path))
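# Usage sketch for the persistence helpers (the paths and the `MyVAE` class are
# hypothetical, not defined in this file):
#
#     save_model(vae, "./checkpoints/", "vae_epoch10")           # writes ./checkpoints/vae_epoch10.pkl
#     vae_restored = MyVAE()                                      # re-instantiate the same architecture
#     load_model(vae_restored, "./checkpoints/", "vae_epoch10")   # mutates vae_restored in place, returns None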