forked from andreped/DSS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
43 lines (36 loc) · 1.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from argparse import ArgumentParser
import sys
import os
def main():
parser = ArgumentParser()
parser.add_argument('-t', '--task', type=str, nargs='?', default="train",
help="which task to perform - either 'train' or 'eval'.")
parser.add_argument('-v', '--verbose', type=int, nargs='?', default=1,
help="sets the verbose level.")
parser.add_argument('-bs', '--batch_size', type=int, nargs='?', default=8,
help="set which batch size to use for training.")
parser.add_argument('-lr', '--learning_rate', type=float, nargs='?', default=0.001,
help="set which learning rate to use for training.")
parser.add_argument('-ep', '--epochs', type=int, nargs='?', default=500,
help="number of epochs to train.")
parser.add_argument('-pa', '--patience', type=int, nargs='?', default=10,
help="number of epochs to wait (patience) for early stopping.")
parser.add_argument('-a', '--arch', type=str, nargs='?', default="rnn",
help="which architecture to use.")
parser.add_argument('-ls', '--loss', type=str, nargs='?', default="cce",
help="which loss function to use. Supportes losses are: {'cce', 'focal'}.")
args = parser.parse_known_args(sys.argv[1:])[0]
print(args)
# setup folders
os.makedirs("output/models/", exist_ok=True)
os.makedirs("output/history/", exist_ok=True)
os.makedirs("output/datasets/", exist_ok=True)
if args.task == "train":
from dss.train import Trainer
Trainer(args).fit()
elif args.task == "deploy":
raise NotImplementedError
else:
raise ValueError("Unknown task specified. Available tasks include {'train', 'eval'}.")
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if "__main__" == __name__:
    main()