Skip to content

Commit 968d6be

Browse files
authored
Merge pull request #198 from AmoghSInamdar/addAI
Add AI support and bug fixes
2 parents b6eca7a + e6ff1fd commit 968d6be

25 files changed

+21417
-422
lines changed

logictools/AI/astar_heuristics.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from Levenshtein import distance
66
import numpy as np
77

8-
from logictools.AI.neural_embedding_heuristic import NeuralEmbeddingHeuristic
8+
from logictools.AI.neural_heuristic.neural_embedding_heuristic import NeuralEmbeddingHeuristic
99

1010

11-
# all heuristics expect Tuple<expr: str, law:str> as inputs. Change typing to a StepNode : {expr: str, law: str} object
11+
# all heuristics expect Tuple<expr: str, law:str> as inputs. Change typing
12+
# to a StepNode : {expr: str, law: str} object
1213

1314
def random_weight(n1, n2):
1415
return random() * 10
@@ -19,23 +20,28 @@ def levenshtein_distance(n1, n2):
1920

2021

2122
def len_distance(n1, n2):
22-
return abs(len(n1[0])-len(n2[0]))
23+
return abs(len(n1[0]) - len(n2[0]))
2324

2425

2526
def unitary_distance(n1, n2):
2627
return 1
2728

2829

2930
def variable_mismatch(n1, n2): # vars in n1 but not in n2 and vice versa
30-
cfunc = lambda x: 97 <= ord(x) <= 122 and x != 'v'
31+
def cfunc(x): return 97 <= ord(x) <= 122 and x != 'v'
3132
n1v, n2v = set(filter(cfunc, n1[0])), set(filter(cfunc, n2[0]))
3233
return len((n1v | n2v) - (n1v & n2v))
3334

3435

3536
class RuleDists:
3637

3738
def __init__(self):
38-
self.all_dists = list(filter(lambda r: not r[0].startswith("__"), inspect.getmembers(self, predicate=inspect.ismethod)))
39+
self.all_dists = list(
40+
filter(
41+
lambda r: not r[0].startswith("__"),
42+
inspect.getmembers(
43+
self,
44+
predicate=inspect.ismethod)))
3945
self.all_dists = [r[1] for r in self.all_dists]
4046

4147
def start_dist(self, n1, n2, d=1):
@@ -123,7 +129,8 @@ def load(self, weight_file):
123129
for l in lines[2:]:
124130
try:
125131
heur, val, file = l.split(": ")
126-
self.heuristics.append(getattr(NeuralEmbeddingHeuristic(file[:-1], is_state_dict=True), heur))
132+
self.heuristics.append(
133+
getattr(NeuralEmbeddingHeuristic(file[:-1], is_state_dict=True), heur))
127134
except ValueError:
128135
heur, val = l.split(": ")
129136
if heur in globals():
@@ -134,10 +141,11 @@ def load(self, weight_file):
134141

135142
def save(self, out_file):
136143
with open(out_file, "w") as f:
137-
f.write(str(self.params)+"\n\n")
144+
f.write(str(self.params) + "\n\n")
138145
for i, h in enumerate(self.heuristics):
139146
if self.model_files[i]:
140-
f.write(f"{h.__name__}: {self.weights[i]}: {self.model_files[i]}\n")
147+
f.write(
148+
f"{h.__name__}: {self.weights[i]}: {self.model_files[i]}\n")
141149
else:
142150
f.write(f"{h.__name__}: {self.weights[i]}\n")
143151

@@ -147,4 +155,3 @@ def save(self, out_file):
147155
gh = GeneHeuristic()
148156
gh.load("astar_heuristic_weights.txt")
149157
print(gh.gene_meta_dist(n1, n2))
150-

logictools/AI/astar_search.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def __init__(self, data, fscore=inf, gscore=inf):
4242
def __lt__(self, other):
4343
return self.fscore < other.fscore
4444

45-
class NodeDict(dict): # can't replace with defaultdict because it doesn't accept args in lambda :(
45+
class NodeDict(
46+
dict): # can't replace with defaultdict because it doesn't accept args in lambda :(
4647

4748
def __missing__(self, key):
4849
value = SearchNode(key)
@@ -56,7 +57,9 @@ def __missing__(self, key):
5657
return True, [start]
5758

5859
search_dict = NodeDict()
59-
start_node = search_dict[start] = SearchNode(start, fscore=goal_heuristic(start, goal, *args, **kwargs), gscore=.0)
60+
start_node = search_dict[start] = SearchNode(
61+
start, fscore=goal_heuristic(
62+
start, goal, *args, **kwargs), gscore=.0)
6063
start_node.depth = 0
6164
open_set = []
6265
heappush(open_set, start_node)
@@ -77,15 +80,18 @@ def __missing__(self, key):
7780
current_node = current_node.prev
7881
return True, list(reversed(rev_sol))
7982

80-
for neighbor in map(lambda n: search_dict[n], frontier_func(current_node.data)):
83+
for neighbor in map(
84+
lambda n: search_dict[n], frontier_func(current_node.data)):
8185
if neighbor.completed:
8286
continue
83-
tentative_gscore = current_node.gscore + neighbor_dist(current_node.data, neighbor.data)
87+
tentative_gscore = current_node.gscore + \
88+
neighbor_dist(current_node.data, neighbor.data)
8489
if tentative_gscore < neighbor.gscore:
8590
neighbor.prev = current_node
8691
neighbor.depth = current_node.depth + 1
8792
neighbor.gscore = tentative_gscore
88-
neighbor.fscore = tentative_gscore + goal_heuristic(neighbor.data, goal)
93+
neighbor.fscore = tentative_gscore + \
94+
goal_heuristic(neighbor.data, goal)
8995
if neighbor.out_of_openset:
9096
neighbor.out_of_openset = False
9197
heappush(open_set, neighbor)
@@ -112,18 +118,35 @@ def __missing__(self, key):
112118
questions = json.load(f)['questions']
113119

114120
import logictools.expression_parser as ep
121+
115122
def frontier_func(x):
116-
fr = ep.get_frontier(x[0], simplify_paren=True, include_paren=False, allowed_ops=lrt.search_operations)
123+
fr = ep.get_frontier(
124+
x[0],
125+
simplify_paren=True,
126+
include_paren=False,
127+
allowed_ops=lrt.search_operations)
117128
print(fr)
118129
return fr
130+
119131
def goal_func(x, target):
120132
return x[0] == target[0]
121133

122134
for q in questions[4:5]:
123135
q["premise"] = "(qvp)^(qv~q)"
124-
gp = astar_search(q['premise'], q['target'], levenshtein_distance, frontier_func, goal_func)
136+
gp = astar_search(
137+
q['premise'],
138+
q['target'],
139+
levenshtein_distance,
140+
frontier_func,
141+
goal_func)
125142
print(gp)
126143
gh = GeneHeuristic()
127144
gh.load("astar_heuristic_weights.txt")
128-
gp2 = astar_search("~(pvq)", "~p^~q", gh.gene_meta_dist, frontier_func, goal_func, max_timeout=1)
145+
gp2 = astar_search(
146+
"~(pvq)",
147+
"~p^~q",
148+
gh.gene_meta_dist,
149+
frontier_func,
150+
goal_func,
151+
max_timeout=1)
129152
print(gp2)

logictools/AI/evaluate_astar_heuristic.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def get_questions():
1414
return questions
1515

1616

17-
def evaluate_against_question_bank(heuristic, results_file="heuristic_test_results.txt", max_timeout=5):
17+
def evaluate_against_question_bank(
18+
heuristic, results_file="heuristic_test_results.txt", max_timeout=5):
1819
with open(results_file, "w") as rf:
1920
est = asctime(localtime())
2021
rf.write(
@@ -23,7 +24,10 @@ def evaluate_against_question_bank(heuristic, results_file="heuristic_test_resul
2324
)
2425

2526
def frontier_func(x):
26-
fr = get_frontier(x[0], include_paren=False, allowed_ops=search_operations)
27+
fr = get_frontier(
28+
x[0],
29+
include_paren=False,
30+
allowed_ops=search_operations)
2731
return fr
2832

2933
def goal_func(x, target):
@@ -35,19 +39,29 @@ def goal_func(x, target):
3539
start = q['premise']
3640
goal = q['target']
3741
start_time = time()
38-
is_solved, result = astar_search(start, goal, heuristic, frontier_func, goal_func, max_timeout=max_timeout)
42+
is_solved, result = astar_search(
43+
start, goal, heuristic, frontier_func, goal_func, max_timeout=max_timeout)
3944
end_time = time()
4045
with open(results_file, "a") as rf:
41-
info_str = "{}. Premise: {}, Target: {}. ".format(i + 1, start, goal)
46+
info_str = "{}. Premise: {}, Target: {}. ".format(
47+
i + 1, start, goal)
4248
if is_solved:
4349
solve_time = end_time - start_time
44-
rf.write(info_str + f"Solved in {solve_time:.4f} seconds. Solution: {result}\n.")
45-
print(info_str + f"Solved in {solve_time:.4f} seconds. Solution: {result}")
50+
rf.write(
51+
info_str +
52+
f"Solved in {solve_time:.4f} seconds. Solution: {result}\n.")
53+
print(
54+
info_str +
55+
f"Solved in {solve_time:.4f} seconds. Solution: {result}")
4656
num_solved += 1
4757
else:
4858
if result:
49-
rf.write(info_str + f" Timeout occurred. Path to best node: {result}\n.")
50-
print(info_str + f" Timeout occurred. Path to best node: {result}")
59+
rf.write(
60+
info_str +
61+
f" Timeout occurred. Path to best node: {result}\n.")
62+
print(
63+
info_str +
64+
f" Timeout occurred. Path to best node: {result}")
5165
else:
5266
rf.write(info_str + " Error occurred.")
5367
print(info_str + " Error occurred.")
@@ -56,8 +70,26 @@ def goal_func(x, target):
5670
print("\nSolved {}/{} questions.".format(num_solved, len(questions)))
5771

5872

73+
def frontier_func(x):
74+
fr = get_frontier(x[0], include_paren=False, allowed_ops=search_operations)
75+
return fr
76+
77+
78+
def goal_func(x, target):
79+
return x[0] == target[0]
80+
81+
82+
def get_heuristic_score(heuristic, questions, max_timeout=5):
83+
return sum([is_solved for is_solved, result in
84+
[astar_search(q["premise"], q["target"], heuristic, frontier_func, goal_func, max_timeout=max_timeout)
85+
for q in questions]])
86+
87+
5988
if __name__ == "__main__":
6089
gh = GeneHeuristic()
6190
gh.load("astar_heuristic_weights.txt")
62-
evaluate_against_question_bank(gh.gene_meta_dist)
91+
# evaluate_against_question_bank(gh.gene_meta_dist)
6392

93+
questions = get_questions()
94+
score = get_heuristic_score(gh.gene_meta_dist, questions, max_timeout=5)
95+
print(f"Score: {score}")

0 commit comments

Comments
 (0)