-
Notifications
You must be signed in to change notification settings - Fork 0
/
datacollect.py
137 lines (119 loc) · 6.73 KB
/
datacollect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
import sys
import argparse
import boirlscenarios.constants as constants
from boirlscenarios.configurations import Configurations
import os
from boirlscenarios.irlobject import IRLObject
"""
MAIN SCRIPT TO COLLECT EXPERT TRAJECTORIES
"""
def check_kernels(value):
    """Validate a kernel name and hand it back unchanged.

    The name is accepted when ``constants.isValidKernel`` approves it.
    The error type suggests this is meant to be used as an ``argparse``
    ``type=`` callable.

    :param value: kernel name to validate.
    :return: ``value`` itself when it is a valid kernel.
    :raises argparse.ArgumentTypeError: when the kernel is not recognised.
    """
    if constants.isValidKernel(value):
        return value
    raise argparse.ArgumentTypeError("%s is an invalid kernel" % value)
def check_envs(value):
    """Validate an environment name and hand it back unchanged.

    Used as the ``type=`` callable of the ``-e/--env`` command line option.
    The name is accepted when ``constants.isValidEnv`` approves it.

    :param value: environment name to validate.
    :return: ``value`` itself when it is a valid environment.
    :raises argparse.ArgumentTypeError: when the environment is not recognised.
    """
    if constants.isValidEnv(value):
        return value
    raise argparse.ArgumentTypeError("%s is an invalid env" % value)
def _invalid_environment(env):
    """Build the exception used for an environment name no branch handles."""
    # The message must be formatted *before* the exception is constructed.
    # Writing ``raise AssertionError("...%s") % env`` applies ``%`` to the
    # exception instance and fails with a TypeError instead.
    return AssertionError("Invalid Environment: %s" % env)


def _make_environment(env, configs):
    """Instantiate the MDP object matching ``env``.

    The ``mdp.*`` modules are imported lazily so that only the module needed
    for the selected environment is loaded.

    :param env: environment name (one of the ``constants`` values).
    :param configs: ``Configurations`` object providing horizon/discount.
    :return: the environment object.
    :raises AssertionError: when ``env`` is not one of the handled names.
    """
    if env == constants.GRIDWORLD2D:
        from mdp.gridworld2d import GridWorld2D
        return GridWorld2D(horizon=configs.getLTrajs())
    if env == constants.GRIDWORLD3D:
        from mdp.gridworld3d import GridWorld3D
        return GridWorld3D(horizon=configs.getLTrajs())
    if env == constants.VIRTBORLANGE or env == constants.REALBORLANGE:
        from mdp.borlangeworld import BorlangeWorld
        return BorlangeWorld(destination=7622, horizon=configs.getLTrajs(),
                             discount=configs.getDiscounts(), loadres=True)
    raise _invalid_environment(env)


def _generate_full_trajectories(env, world):
    """Obtain the complete set of expert trajectories for ``env``.

    :param env: environment name.
    :param world: environment object returned by ``_make_environment``.
    :return: tuple ``(full_trajectories, full_start_positions)``.
    :raises AssertionError: when ``env`` is not one of the handled names.
    """
    if env == constants.REALBORLANGE:
        full_trajs, full_spos = world.gather_real_trajectories()
    elif env == constants.VIRTBORLANGE:
        print("Warning! This might take a while")
        # One trajectory per non-dummy state, visited in random order.
        # The nested permutation is redundant, but it is kept so that runs
        # with a fixed NumPy seed consume the random stream exactly as before.
        full_trajs, full_spos, _ = world.generate_trajectories(
            n_trajectories=world.nodummy_states.shape[0],
            startpos=np.random.permutation(
                np.random.permutation(world.nodummy_states)))
    elif env == constants.GRIDWORLD3D or env == constants.GRIDWORLD2D:
        full_trajs, full_spos, _ = world.generate_trajectories(
            n_trajectories=50, startpos=np.random.randint(0, 6, 50))
    else:
        raise _invalid_environment(env)
    return full_trajs, full_spos


def _collect_expert_trajectories(env):
    """Generate expert trajectories for ``env`` and save them as ``.npy`` files.

    Files written to ``configs.getTrajectoryDir()``:
    ``full_opt_trajectories.npy``, ``full_start_pos.npy``,
    ``train_trajectories.npy``, ``feature_indices.npy`` and ``features.npy``.
    """
    configs = Configurations(None, env)
    world = _make_environment(env, configs)
    full_trajs, full_spos = _generate_full_trajectories(env, world)

    # Random subset of the full trajectories (used for creating the latent
    # space). The 3-way unpacking also enforces that the trajectory array is
    # three-dimensional, as the original code did.
    n_full, _, _ = np.shape(full_trajs)
    indices = np.random.permutation(np.arange(n_full))[0:configs.getNTrajs()]
    trajectories = full_trajs[indices]
    # The start positions of the subset are not stored separately; they can
    # be recovered from full_start_pos.npy using feature_indices.npy.

    out_dir = configs.getTrajectoryDir()
    np.save(os.path.join(out_dir, "full_opt_trajectories.npy"), full_trajs)
    np.save(os.path.join(out_dir, "full_start_pos.npy"), full_spos)
    np.save(os.path.join(out_dir, "train_trajectories.npy"), trajectories)
    np.save(os.path.join(out_dir, "feature_indices.npy"), indices)
    np.save(os.path.join(out_dir, "features.npy"), world.features)


def _parameter_grid(bounds, n_dims, n_points):
    """Return a dense grid over the first ``n_dims`` entries of ``bounds``.

    Each ``bounds[i]['domain']`` supplies the ``(low, high)`` pair of one axis,
    sampled with ``n_points`` evenly spaced values. The result has one row per
    grid point and one column per dimension.
    """
    axes = [np.linspace(bounds[i]['domain'][0], bounds[i]['domain'][1], n_points)
            for i in range(n_dims)]
    mesh = np.meshgrid(*axes)
    return np.hstack([m.reshape(-1, 1) for m in mesh])


def _save_reward_grid(env):
    """Save reward-function parameters within bounds as ``allw.npy``.

    The grid is used to check the rho-projection space. ``irlobj.gtheta``
    (presumably the ground-truth reward parameters, expected to be 2-D so that
    it can be appended row-wise -- TODO confirm) is appended for every
    environment except the real Borlange data.

    :raises AssertionError: when ``env`` is not one of the handled names.
    """
    if env == constants.GRIDWORLD2D:
        irlobj = IRLObject(None, env)
        allw = _parameter_grid(irlobj.bounds, 2, 500)
        allw = np.append(allw, irlobj.gtheta, axis=0)
        allw = np.append(allw, -1 * irlobj.gtheta, axis=0)
    elif env == constants.GRIDWORLD3D:
        irlobj = IRLObject(None, env)
        plane = _parameter_grid(irlobj.bounds, 2, 500)
        tempw = np.hstack((plane, np.zeros((plane.shape[0], 1))))
        tempw = np.append(tempw, irlobj.gtheta, axis=0)
        tempw = np.append(tempw, -1 * irlobj.gtheta, axis=0)
        # Three stacked copies of the 2-D plane, one per value of the third
        # parameter: 0 (first copy), -1 (second copy) and 2 (third copy).
        # As before, this also overwrites the third component of the
        # appended gtheta rows in the second and third copies.
        n_rows = tempw.shape[0]
        allw = np.vstack((tempw, tempw, tempw))
        allw[n_rows:2 * n_rows, 2] = -1
        allw[2 * n_rows:, 2] = 2
    elif env == constants.VIRTBORLANGE or env == constants.REALBORLANGE:
        irlobj = IRLObject(None, env)
        allw = _parameter_grid(irlobj.bounds, 3, 100)
        if env == constants.VIRTBORLANGE:
            allw = np.append(allw, irlobj.gtheta, axis=0)
        # NOTE(review): the grid is saved for both Borlange variants; confirm
        # that the original save was not meant to be VIRTBORLANGE-only.
    elif env == constants.MAZE or env == constants.FETCH:
        irlobj = IRLObject(None, env)
        allw = _parameter_grid(irlobj.bounds, 2, 500)
        allw = np.append(allw, irlobj.gtheta, axis=0)
    else:
        raise _invalid_environment(env)
    np.save(os.path.join(irlobj.configurations.getTrajectoryDir(), "allw.npy"), allw)


def main(argv):
    """Collect expert trajectories and the reward-parameter grid for one env.

    :param argv: command line arguments without the program name, e.g.
        ``["-e", "gridworld2d"]``. ``None`` makes ``argparse`` fall back to
        ``sys.argv[1:]``.
    :raises AssertionError: when the environment passes ``check_envs`` but is
        not handled by this script.
    """
    parser = argparse.ArgumentParser(description="Description for my parser")
    parser.add_argument("-e", "--env", help="gridworld2d, gridworld3d, vborlange or rborlange", required=True,
                        default="gridworld2d",
                        type=check_envs)
    # Parse the arguments that were actually handed to main() instead of
    # silently re-reading sys.argv.
    argument = parser.parse_args(argv)

    env = argument.env or ""
    if env:
        print("You have used '-e' or '--env' with argument: {0}".format(env))
    else:
        print("Maybe you want to use -e as argument?")

    # Create new trajectories for all environments except fetch and maze.
    # Fetch and Maze require training of an expert policy to generate expert
    # trajectories, so pre-existing trajectories are simply reused for them.
    if env == constants.FETCH or env == constants.MAZE:
        print("We will reuse expert demonstrations stored in the Data folder for %s environment." % env)
    else:
        _collect_expert_trajectories(env)

    # Store reward function parameters within bounds to check rho-projection space
    _save_reward_grid(env)
# Script entry point: forward the command line (without the program name) to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)