-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_ae.py
54 lines (44 loc) · 1.67 KB
/
test_ae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import wandb
import math
import xarray as xr
import asyncio
import submitit
import pickle
import sys
import hydra
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
# import matplotlib.pyplot as plt
import numpy as np
# import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from tqdm.auto import tqdm
import glob, os, shutil
import random
import sys
from ContModeling.modeling import test_mat_autoencoder
from ContModeling.helper_classes import MatData
@hydra.main(config_path=".", config_name="mat_autoencoder_config")
def main(cfg: DictConfig) -> None:
    """Evaluate the trained matrix autoencoder on the held-out test split.

    Loads the test indices stored at
    ``<output_dir>/<experiment_name>/test_idx.npy``, selects those samples
    from ``MatData`` via a ``Subset``, ensures the output directories exist
    and delegates the evaluation to ``test_mat_autoencoder``.

    Config keys read: ``best_fold``, ``output_dir``, ``experiment_name``,
    ``targets``, ``dataset_path``, ``reconstructed_dir``, ``model_weight_dir``.

    Args:
        cfg: Hydra config composed from ``mat_autoencoder_config``.
    """
    # Print the resolved config before any file I/O so it is visible even
    # when loading the test indices or the dataset fails below.
    print("Testing model.\n", OmegaConf.to_yaml(cfg))

    best_fold = int(cfg.best_fold)
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Derive every experiment path from a single root so they stay consistent
    # (previously the index path was assembled separately with a literal "/").
    results_dir = os.path.join(cfg.output_dir, cfg.experiment_name)
    test_idx_path = os.path.join(results_dir, "test_idx.npy")
    test_idx = np.load(test_idx_path)

    targets = list(cfg.targets)
    dataset = MatData(cfg.dataset_path, targets, threshold=0)
    test_dataset = Subset(dataset, test_idx)

    os.makedirs(results_dir, exist_ok=True)
    recon_mat_dir = os.path.join(results_dir, cfg.reconstructed_dir)
    os.makedirs(recon_mat_dir, exist_ok=True)
    # NOTE(review): creating the weights directory in a *test* script means a
    # misnamed `model_weight_dir` silently yields an empty directory rather
    # than an immediate error; presumably test_mat_autoencoder loads the
    # `best_fold` weights from here — confirm and consider failing fast.
    model_params_dir = os.path.join(results_dir, cfg.model_weight_dir)
    os.makedirs(model_params_dir, exist_ok=True)

    test_mat_autoencoder(
        best_fold=best_fold,
        test_dataset=test_dataset,
        cfg=cfg,
        model_params_dir=model_params_dir,
        recon_mat_dir=recon_mat_dir,
        device=device,
    )


if __name__ == "__main__":
    main()