-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathmain.py
105 lines (82 loc) · 3.61 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from nam.config import defaults
from nam.data import FoldedDataset
from nam.data import NAMDataset
from nam.models import NAM
from nam.models import get_num_units
from nam.trainer import LitNAM
from nam.types import Config
from nam.utils import parse_args
from nam.utils import plot_mean_feature_importance
from nam.utils import plot_nams
def get_config() -> Config:
    """Return the run configuration.

    Command-line arguments from ``parse_args()`` are applied on top of the
    library ``defaults()``, so any option given on the command line wins.
    """
    cli_args = parse_args()
    merged = defaults()
    merged.update(**vars(cli_args))
    return merged
def _make_dataset(dataset_cls, config):
    """Instantiate ``dataset_cls`` on ``config.data_path`` with this script's column layout."""
    return dataset_cls(
        config,
        data_path=config.data_path,
        features_columns=["income_2", "WP1219", "WP1220", "year", "weo_gdpc_con_ppp"],
        targets_column="WP16",
        weights_column="wgt",
    )


def _build_model(config, dataset):
    """Create a new, untrained NAM whose input count is taken from the first sample of ``dataset``."""
    return NAM(
        config=config,
        name=config.experiment_name,
        num_inputs=len(dataset[0][0]),
        num_units=get_num_units(config, dataset.features),
    )


def _fit(config, model, trainloader, valloader, version):
    """Train ``model`` on one train/validation split and return the ``LitNAM`` wrapper.

    TensorBoard logs go to ``<config.logdir>/<model.name>/<version>`` and the
    best ``config.save_top_k`` checkpoints (lowest ``val_loss``) are written
    into that same directory.
    """
    tb_logger = TensorBoardLogger(save_dir=config.logdir, name=str(model.name), version=version)
    # Keep checkpoints next to this run's TensorBoard logs. Passing the directory
    # via `dirpath` (rather than prefixing it onto `filename`) avoids Lightning
    # joining it onto its own default checkpoint directory.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tb_logger.log_dir,
        filename="{epoch:02d}-{val_loss:.4f}",
        monitor="val_loss",
        save_top_k=config.save_top_k,
        mode="min",
    )
    litmodel = LitNAM(config, model)
    trainer = pl.Trainer(
        logger=tb_logger,
        max_epochs=config.num_epochs,
        callbacks=[checkpoint_callback],
    )
    # Positional loaders: the `train_dataloader=` keyword was renamed across
    # Lightning releases, positional arguments work in all of them.
    trainer.fit(litmodel, trainloader, valloader)
    return litmodel


def _show_plots(litmodel, dataset):
    """Draw nam.utils' feature-importance and NAM plots for the trained model, then display them."""
    plot_mean_feature_importance(litmodel.model, dataset)
    plot_nams(litmodel.model, dataset, num_cols=1)
    # NOTE(review): `plt` is never imported in this file, so this line raises
    # NameError. Add `import matplotlib.pyplot as plt` to the module imports
    # (presumably matplotlib is what nam.utils plots with — confirm).
    plt.show()


def main():
    """Train NAM model(s) on the configured dataset and plot the result.

    With ``config.cross_val`` set, a freshly initialised model is trained on each
    (train, val) pair yielded by ``FoldedDataset.train_dataloaders()``, logged
    under version ``fold_<k>``, and the last fold's model is plotted. Otherwise
    one model is trained on the ``NAMDataset`` train/val split (version ``"0"``).

    Raises:
        ValueError: if cross-validation is enabled but no folds are produced.
    """
    config = get_config()
    pl.seed_everything(config.seed)
    print(config)

    if config.cross_val:
        dataset = _make_dataset(FoldedDataset, config)
        litmodel = None
        for fold, (trainloader, valloader) in enumerate(dataset.train_dataloaders()):
            # A new model per fold: reusing one instance would carry weights
            # trained on earlier folds (which include this fold's validation
            # data) into the next fold and bias its validation loss.
            model = _build_model(config, dataset)
            litmodel = _fit(config, model, trainloader, valloader, version=f"fold_{fold + 1}")
        if litmodel is None:
            raise ValueError("FoldedDataset produced no folds; nothing was trained")
    else:
        dataset = _make_dataset(NAMDataset, config)
        # The test loader is not used by this script.
        trainloader, valloader, _testloader = dataset.get_dataloaders()
        model = _build_model(config, dataset)
        litmodel = _fit(config, model, trainloader, valloader, version="0")

    _show_plots(litmodel, dataset)


if __name__ == "__main__":
    main()