runner.py
import argparse
import gzip
import importlib
import json
import logging
import os

from robothor_challenge.challenge import RobothorChallenge

logging.getLogger().setLevel(logging.INFO)


def main():
    parser = argparse.ArgumentParser(
        description="Inference script for the RoboThor ObjectNav challenge."
    )
    parser.add_argument(
        "--agent", "-a",
        default="agents.random_agent",
        help="Relative module for the agent definition.",
    )
    parser.add_argument(
        "--cfg", "-c",
        default="challenge_config.yaml",
        help="Filepath to the challenge config.",
    )
    parser.add_argument(
        "--dataset-dir", "-d",
        default="dataset",
        help="Filepath to the challenge dataset.",
    )
    parser.add_argument(
        "--output", "-o",
        default="metrics.json.gz",
        help="Filepath to write the output metrics to.",
    )
    parser.add_argument(
        "--submission",
        action="store_true",
        help="Run in submission mode: evaluate the val and test splits only.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Evaluate the debug split.",
    )
    parser.add_argument(
        "--train",
        action="store_true",
        help="Evaluate the train split.",
    )
    parser.add_argument(
        "--val",
        action="store_true",
        help="Evaluate the val split.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate the test split.",
    )
    parser.add_argument(
        "--nprocesses", "-n",
        default=1,
        type=int,
        help="Number of parallel processes used to compute inference.",
    )
    args = parser.parse_args()

    # Submission mode evaluates only the val and test splits.
    if args.submission:
        args.debug = False
        args.train = False
        args.val = True
        args.test = True

    # Import the agent module and build the agent class, its constructor
    # kwargs, and whether depth frames should be rendered.
    agent = importlib.import_module(args.agent)
    agent_class, agent_kwargs, render_depth = agent.build()

    r = RobothorChallenge(args.cfg, agent_class, agent_kwargs, render_depth=render_depth)

    challenge_metrics = {}

    # Run inference on each requested split; metrics are keyed by split name.
    if args.debug:
        debug_episodes, debug_dataset = r.load_split(args.dataset_dir, "debug")
        challenge_metrics["debug"] = r.inference(
            debug_episodes,
            nprocesses=args.nprocesses,
            test=False
        )

    if args.train:
        train_episodes, train_dataset = r.load_split(args.dataset_dir, "train")
        challenge_metrics["train"] = r.inference(
            train_episodes,
            nprocesses=args.nprocesses,
            test=False
        )

    if args.val:
        val_episodes, val_dataset = r.load_split(args.dataset_dir, "val")
        challenge_metrics["val"] = r.inference(
            val_episodes,
            nprocesses=args.nprocesses,
            test=False
        )

    if args.test:
        test_episodes, test_dataset = r.load_split(args.dataset_dir, "test")
        challenge_metrics["test"] = r.inference(
            test_episodes,
            nprocesses=args.nprocesses,
            test=True
        )

    # Write the collected metrics for all evaluated splits to a gzipped JSON file.
    with gzip.open(args.output, "wt", encoding="utf-8") as zipfile:
        json.dump(challenge_metrics, zipfile)


if __name__ == "__main__":
    main()
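
For reference, a minimal sketch of what an agent module is expected to expose, inferred only from the agent.build() unpacking in runner.py. The module path, class name, and the reset/act interface shown here are assumptions; they must match whatever RobothorChallenge actually calls on the agent.

# agents/my_agent.py (hypothetical) -- build() must return a 3-tuple of
# (agent_class, agent_kwargs, render_depth), mirroring the unpacking in runner.py.
class MyAgent:
    def reset(self):
        # Assumed per-episode hook; clear any episode state here.
        pass

    def act(self, observations):
        # Assumed interface: map the current observations to an action.
        # Returning a fixed action is only a placeholder.
        return "MoveAhead"


def build():
    agent_class = MyAgent
    agent_kwargs = {}
    render_depth = False  # set True if the agent consumes depth frames
    return agent_class, agent_kwargs, render_depth

And a small snippet, assuming the default --output path, for reading back the gzipped metrics file that runner.py writes; the top-level keys are the names of the splits that were evaluated.

import gzip
import json

with gzip.open("metrics.json.gz", "rt", encoding="utf-8") as f:
    metrics = json.load(f)

print(list(metrics.keys()))  # e.g. ["val", "test"] after a --submission run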