-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_model.py
34 lines (29 loc) · 1.29 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import csv
import os

import numpy as np
import torch
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.celeba import CelebAZipDatasetWithFilter, CelebAZipDataModule
from experiment import VAELightningModule
from models import *
from utils.utils import *
## Load model
# Read the experiment configuration (model hyper-parameters, data and
# trainer settings) from the repo's YAML config.
config = get_config(os.path.join(os.getcwd(), 'configs/vae.yaml'))

# Checkpoint path follows the Lightning default logs layout.
# NOTE(review): `version_12` pins this to one specific training run —
# consider making the version configurable.
chk_path = os.path.join(
    os.getcwd(),
    f"logs/{config['model_params']['name']}/version_12/checkpoints/last.ckpt")

# Restore the LightningModule on CPU. The wrapped VAE is rebuilt from the
# config and its weights are then loaded from the checkpoint.
model = VAELightningModule.load_from_checkpoint(
    checkpoint_path=chk_path,
    map_location=torch.device('cpu'),
    vae_model=vae_models[config['model_params']['name']](**config['model_params']),
    params=config['exp_params'])
model.eval()  # inference mode: freeze dropout / batch-norm statistics

## For generating embeddings
# Pin host memory only when the config requests at least one GPU.
data = CelebAZipDataModule(
    **config["data_params"],
    pin_memory=len(config['trainer_params']['gpus']) != 0)
data.setup()
dl = data.train_dataloader()

# Advance to the `times`-th batch of the loader; after the loop `inputs` /
# `classes` hold that batch. (The previous extra `next(iter(dl))` warm-up
# batch was dead work — its result was immediately shadowed by this loop.)
times = 1
inputs, classes = None, None
for step, (inputs, classes) in enumerate(dl):
    if step == times:
        break

# Run the forward pass through __call__ (so hooks fire) without tracking
# gradients — this is inference only.
with torch.no_grad():
    f = model(inputs)