Skip to content

Commit bb977e7

Browse files
authored
Merge pull request #4 from convince-project/option_learning
Option Learning Functionality (and other mods)
2 parents 4a2f9d7 + f09479a commit bb977e7

File tree

107 files changed

+515909
-18
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

107 files changed

+515909
-18
lines changed

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@ COVERAGE-PLAN requires the following dependencies:
1414
* [Numpy](https://numpy.org/) (Tested with 26.4)
1515
* [Sympy](https://www.sympy.org/en/index.html) (Tested with 1.12)
1616
* [Pyeda](https://pyeda.readthedocs.io/en/latest/) (Tested with 0.29.0)
17+
* [PyAgrum](https://pyagrum.readthedocs.io/en/1.15.1/index.html) (Tested with 1.14.1)
18+
* [PyMongo](https://pymongo.readthedocs.io/en/stable/index.html) (Tested with 4.8.0)
19+
* [Pandas](https://pandas.pydata.org/) (Tested with 2.2.1)
1720
* [Stormpy](https://moves-rwth.github.io/stormpy/index.html) (Tested with 1.8.0)
21+
* [MongoDB](https://www.mongodb.com/docs/manual/tutorial/install-mongodb-on-ubuntu/) (Tested with 7.0.12) - only required for unit tests.
1822

19-
The first three dependencies can be installed via:
23+
The first six dependencies can be installed via:
2024
```
2125
pip install -r requirements.txt
2226
```
2327

28+
`MongoDB` can be installed using the [official instructions](https://www.mongodb.com/docs/manual/tutorial/install-mongodb-on-ubuntu/).
29+
2430
Installing `stormpy` is more involved. Please see below.
2531

2632
### Installing Stormpy
@@ -111,6 +117,5 @@ If you want to clean the documentation, you can run:
111117

112118
```bash
113119
cd docs
114-
rm -r source/API
115120
make clean
116121
```

bin/fake_museum_planning.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
#!/usr/bin/env python3
2+
""" A script to run REFINE-PLAN on the fake museum simulation example
3+
4+
Author: Charlie Street
5+
Owner: Charlie Street
6+
"""
7+
8+
from refine_plan.models.condition import Label, EqCondition, AndCondition, OrCondition
9+
from refine_plan.learning.option_learning import mongodb_to_yaml, learn_dbns
10+
from refine_plan.algorithms.semi_mdp_solver import synthesise_policy
11+
from refine_plan.models.state_factor import StateFactor
12+
from refine_plan.models.dbn_option import DBNOption
13+
from refine_plan.models.semi_mdp import SemiMDP
14+
from refine_plan.models.state import State
15+
import sys
16+
17+
# Global map setup
18+
19+
GRAPH = {
20+
"v1": {"e12": "v2", "e13": "v3", "e14": "v4"},
21+
"v2": {"e12": "v1", "e23": "v3", "e25": "v5", "e26": "v6"},
22+
"v3": {
23+
"e13": "v1",
24+
"e23": "v2",
25+
"e34": "v4",
26+
"e35": "v5",
27+
"e36": "v6",
28+
"e37": "v7",
29+
},
30+
"v4": {"e14": "v1", "e34": "v3", "e46": "v6", "e47": "v7"},
31+
"v5": {"e25": "v2", "e35": "v3", "e56": "v6", "e58": "v8"},
32+
"v6": {
33+
"e26": "v2",
34+
"e36": "v3",
35+
"e46": "v4",
36+
"e56": "v5",
37+
"e67": "v7",
38+
"e68": "v8",
39+
},
40+
"v7": {
41+
"e37": "v3",
42+
"e47": "v4",
43+
"e67": "v6",
44+
"e78": "v8",
45+
},
46+
"v8": {"e58": "v5", "e68": "v6", "e78": "v7"},
47+
}
48+
49+
CORRESPONDING_DOOR = {
50+
"e12": None,
51+
"e14": None,
52+
"e58": "v5",
53+
"e78": "v7",
54+
"e13": None,
55+
"e36": "v3",
56+
"e68": "v6",
57+
"e25": "v2",
58+
"e47": "v4",
59+
"e26": "v2",
60+
"e35": "v3",
61+
"e46": "v4",
62+
"e37": "v3",
63+
"e23": None,
64+
"e34": None,
65+
"e56": None,
66+
"e67": None,
67+
}
68+
69+
# Problem Setup
70+
INITIAL_LOC = "v1"
71+
GOAL_LOC = "v8"
72+
73+
74+
def _get_enabled_cond(sf_list, option):
75+
"""Get the enabled condition for an option.
76+
77+
Args:
78+
sf_list: The list of state factors
79+
option: The option we want the condition for
80+
81+
Returns:
82+
The enabled condition for the option
83+
"""
84+
sf_dict = {sf.get_name(): sf for sf in sf_list}
85+
86+
door_locs = ["v{}".format(i) for i in range(2, 8)]
87+
88+
if option == "check_door" or option == "open_door":
89+
enabled_cond = OrCondition()
90+
door_status = "unknown" if option == "check_door" else "closed"
91+
for door in door_locs:
92+
enabled_cond.add_cond(
93+
AndCondition(
94+
EqCondition(sf_dict["location"], door),
95+
EqCondition(sf_dict["{}_door".format(door)], door_status),
96+
)
97+
)
98+
return enabled_cond
99+
else: # edge navigation option
100+
enabled_cond = OrCondition()
101+
for node in GRAPH:
102+
if option in GRAPH[node]:
103+
enabled_cond.add_cond(EqCondition(sf_dict["location"], node))
104+
door = CORRESPONDING_DOOR[option]
105+
if door != None:
106+
enabled_cond = AndCondition(
107+
enabled_cond, EqCondition(sf_dict["{}_door".format(door)], "open")
108+
)
109+
return enabled_cond
110+
111+
112+
def write_mongodb_to_yaml(mongo_connection_str):
113+
"""Learn the DBNOptions from the database.
114+
115+
Args:
116+
mongo_connection_str: The MongoDB conenction string"""
117+
118+
loc_sf = StateFactor("location", ["v{}".format(i) for i in range(1, 9)])
119+
door_sfs = [
120+
StateFactor("v2_door", ["unknown", "closed", "open"]),
121+
StateFactor("v3_door", ["unknown", "closed", "open"]),
122+
StateFactor("v4_door", ["unknown", "closed", "open"]),
123+
StateFactor("v5_door", ["unknown", "closed", "open"]),
124+
StateFactor("v6_door", ["unknown", "closed", "open"]),
125+
StateFactor("v7_door", ["unknown", "closed", "open"]),
126+
]
127+
128+
print("Writing mongo database to yaml file")
129+
mongodb_to_yaml(
130+
mongo_connection_str,
131+
"refine-plan",
132+
"fake-museum-data",
133+
[loc_sf] + door_sfs,
134+
"../data/fake_museum/dataset.yaml",
135+
)
136+
137+
138+
def learn_options():
139+
"""Learn the options from the YAML file."""
140+
dataset_path = "../data/fake_museum/dataset.yaml"
141+
output_dir = "../data/fake_museum/"
142+
143+
loc_sf = StateFactor("location", ["v{}".format(i) for i in range(1, 9)])
144+
door_sfs = [
145+
StateFactor("v2_door", ["unknown", "closed", "open"]),
146+
StateFactor("v3_door", ["unknown", "closed", "open"]),
147+
StateFactor("v4_door", ["unknown", "closed", "open"]),
148+
StateFactor("v5_door", ["unknown", "closed", "open"]),
149+
StateFactor("v6_door", ["unknown", "closed", "open"]),
150+
StateFactor("v7_door", ["unknown", "closed", "open"]),
151+
]
152+
153+
learn_dbns(dataset_path, output_dir, [loc_sf] + door_sfs)
154+
155+
156+
def run_planner():
157+
"""Run refine-plan and synthesise a BT.
158+
159+
Returns:
160+
The refined BT
161+
"""
162+
163+
loc_sf = StateFactor("location", ["v{}".format(i) for i in range(1, 9)])
164+
door_sfs = [
165+
StateFactor("v2_door", ["unknown", "closed", "open"]),
166+
StateFactor("v3_door", ["unknown", "closed", "open"]),
167+
StateFactor("v4_door", ["unknown", "closed", "open"]),
168+
StateFactor("v5_door", ["unknown", "closed", "open"]),
169+
StateFactor("v6_door", ["unknown", "closed", "open"]),
170+
StateFactor("v7_door", ["unknown", "closed", "open"]),
171+
]
172+
sf_list = [loc_sf] + door_sfs
173+
174+
labels = [Label("goal", EqCondition(loc_sf, "v8"))]
175+
176+
option_names = [
177+
"e12",
178+
"e14",
179+
"e58",
180+
"e78",
181+
"e13",
182+
"e36",
183+
"e68",
184+
"e25",
185+
"e47",
186+
"e26",
187+
"e35",
188+
"e46",
189+
"e37",
190+
"e23",
191+
"e34",
192+
"e56",
193+
"e67",
194+
"check_door",
195+
"open_door",
196+
]
197+
198+
assert len(set(option_names)) == 19 # Quick safety check
199+
200+
init_state_dict = {sf: "unknown" for sf in door_sfs}
201+
init_state_dict[loc_sf] = "v1"
202+
init_state = State(init_state_dict)
203+
204+
option_list = []
205+
for option in option_names:
206+
print("Reading in option: {}".format(option))
207+
t_path = "../data/fake_museum/{}_transition.bifxml".format(option)
208+
r_path = "../data/fake_museum/{}_reward.bifxml".format(option)
209+
option_list.append(
210+
DBNOption(
211+
option, t_path, r_path, sf_list, _get_enabled_cond(sf_list, option)
212+
)
213+
)
214+
215+
print("Creating MDP...")
216+
semi_mdp = SemiMDP(sf_list, option_list, labels, initial_state=init_state)
217+
print("Synthesising Policy...")
218+
policy = synthesise_policy(semi_mdp, prism_prop='Rmin=?[F "goal"]')
219+
policy.write_policy("../data/fake_museum/fake_museum_refined_policy.yaml")
220+
221+
222+
if __name__ == "__main__":
223+
224+
# write_mongodb_to_yaml(sys.argv[1])
225+
# learn_options()
226+
run_planner()

bin/plot_fake_museum_results.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#!/usr/bin/env python
2+
""" Script for plotting the fake museum REFINE-PLAN results.
3+
4+
Author: Charlie Street
5+
"""
6+
7+
from scipy.stats import mannwhitneyu
8+
from pymongo import MongoClient
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
import matplotlib
12+
import sys
13+
14+
plt.rcParams["pdf.fonttype"] = 42
15+
matplotlib.rcParams.update({"font.size": 40})
16+
17+
18+
def read_results_for_method(collection, sf_names):
19+
"""Read the mongo results for a single method (i.e. a collection).
20+
21+
Args:
22+
collection: The MongoDB collection
23+
sf_names: A list of state factor names
24+
25+
Returns:
26+
results: The list of run durations
27+
"""
28+
29+
# Group docs together by run_id
30+
docs_per_run = {}
31+
for doc in collection.find({}):
32+
if doc["run_id"] not in docs_per_run:
33+
docs_per_run[doc["run_id"]] = []
34+
docs_per_run[doc["run_id"]].append(doc)
35+
36+
# Sanity check each run
37+
results = []
38+
for run_id in docs_per_run:
39+
total_duration = 0.0
40+
in_order = sorted(docs_per_run[run_id], key=lambda d: d["date_started"])
41+
assert in_order[0]["location0"] == "v1"
42+
assert in_order[-1]["locationt"] == "v8"
43+
44+
for i in range(len(in_order) - 1):
45+
total_duration += in_order[i]["duration"]
46+
for sf in sf_names:
47+
assert (
48+
in_order[i]["{}t".format(sf)] == in_order[i + 1]["{}0".format(sf)]
49+
)
50+
total_duration += in_order[-1]["duration"]
51+
results.append(total_duration)
52+
53+
assert len(results) == 100
54+
return results
55+
56+
57+
def print_stats(init_results, refined_results):
58+
"""Print the statistics for the initial and refined results.
59+
60+
Args:
61+
init_results: The durations for the initial behaviour
62+
refined_results: The durations for the refined behaviour
63+
"""
64+
print(
65+
"INITIAL BEHAVIOUR: AVG COST: {}; VARIANCE: {}".format(
66+
np.mean(init_results), np.var(init_results)
67+
)
68+
)
69+
print(
70+
"REFINED BEHAVIOUR: AVG COST: {}; VARIANCE: {}".format(
71+
np.mean(refined_results), np.var(refined_results)
72+
)
73+
)
74+
p = mannwhitneyu(
75+
refined_results,
76+
init_results,
77+
alternative="less",
78+
)[1]
79+
print(
80+
"REFINED BT BETTER THAN INITIAL BT: p = {}, stat sig better = {}".format(
81+
p, p < 0.05
82+
)
83+
)
84+
85+
86+
def set_box_colors(bp):
87+
plt.setp(bp["boxes"][0], color="tab:blue", linewidth=8.0)
88+
plt.setp(bp["caps"][0], color="tab:blue", linewidth=8.0)
89+
plt.setp(bp["caps"][1], color="tab:blue", linewidth=8.0)
90+
plt.setp(bp["whiskers"][0], color="tab:blue", linewidth=8.0)
91+
plt.setp(bp["whiskers"][1], color="tab:blue", linewidth=8.0)
92+
plt.setp(bp["fliers"][0], color="tab:blue")
93+
plt.setp(bp["medians"][0], color="tab:blue", linewidth=8.0)
94+
95+
plt.setp(bp["boxes"][1], color="tab:red", linewidth=8.0)
96+
plt.setp(bp["caps"][2], color="tab:red", linewidth=8.0)
97+
plt.setp(bp["caps"][3], color="tab:red", linewidth=8.0)
98+
plt.setp(bp["whiskers"][2], color="tab:red", linewidth=8.0)
99+
plt.setp(bp["whiskers"][3], color="tab:red", linewidth=8.0)
100+
plt.setp(bp["medians"][1], color="tab:red", linewidth=8.0)
101+
102+
103+
def plot_box_plot(init_results, refined_results):
104+
"""Plot a box plot showing the initial and refined results.
105+
106+
Args:
107+
init_results: The durations for the initial behaviour
108+
refined_results: The durations for the refined behaviour
109+
"""
110+
111+
box = plt.boxplot(
112+
[init_results, refined_results],
113+
whis=[0, 100],
114+
positions=[1, 2],
115+
widths=0.6,
116+
)
117+
set_box_colors(box)
118+
119+
plt.tick_params(
120+
axis="x", # changes apply to the x-axis
121+
which="both", # both major and minor ticks are affected
122+
bottom=True, # ticks along the bottom edge are off
123+
top=False, # ticks along the top edge are off
124+
labelbottom=True, # labels along the bottom edge are offcd
125+
labelsize=40,
126+
)
127+
plt.ylabel("Time to Reach Goal (s)")
128+
129+
plt.xticks([1, 2], ["Initial BT", "Refined BT"])
130+
131+
plt.show()
132+
133+
134+
if __name__ == "__main__":
135+
136+
sf_names = ["v{}_door".format(v) for v in range(2, 7)]
137+
sf_names = ["location"] + sf_names
138+
client = MongoClient(sys.argv[1])
139+
db = client["refine-plan"]
140+
init_results = read_results_for_method(db["fake-museum-initial"], sf_names)
141+
refined_results = read_results_for_method(db["fake-museum-refined"], sf_names)
142+
print_stats(init_results, refined_results)
143+
plot_box_plot(init_results, refined_results)

0 commit comments

Comments
 (0)