-
Notifications
You must be signed in to change notification settings - Fork 1
/
structure_utils.py
168 lines (132 loc) · 5.39 KB
/
structure_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import numpy as np
# Column indices into each row of a structure list, whose layout is
# [left_name, right_name, k, binary_func, activation, name].
(LEFT_NAME, RIGHT_NAME, K, BINARY_FUNC, ACTIVATION, NAME) = range(6)

# Reference structure for a GRU cell, one row per node in the layout above.
# Each row's children are either a leaf ('x', 'h', '0') or a row defined
# earlier in the list; the last row ('ht') is the root.
# NOTE: k values are wrong right now (flagged by the original author).
GRU_STRUCTURE = [
    # left      right              k  binary  activation   name
    ['x',       'h',               0, 'add', 'sigmoid',  'z1'],
    ['x',       'h',               1, 'add', 'sigmoid',  'r'],
    ['x',       'h',               1, 'add', 'sigmoid',  'z2'],
    ['r',       'h',               2, 'mul', 'identity', 'r*h'],
    ['x',       'r*h',             3, 'add', 'tanh',     'h_tilde'],
    ['z1',      '0',               2, 'add', 'minus',    '1-z'],
    ['1-z',     'h_tilde',         3, 'mul', 'identity', '(1-z)*h_tilde'],
    ['z2',      'h',               3, 'mul', 'identity', 'z*h'],
    ['z*h',     '(1-z)*h_tilde',   3, 'add', 'identity', 'ht'],
]
class TreeNode:
    """Simple class to store nodes of a structure tree.

    A structure tree just describes the structure of the update equations, and
    doesn't record vector information.
    """

    def __init__(self):
        self.left = None         # left child TreeNode, or None for a leaf
        self.right = None        # right child TreeNode, or None for a leaf
        self.k = None            # k value copied from the structure row
        self.binary_func = None  # name of the binary function, e.g. 'add'
        self.activation = None   # name of the activation, e.g. 'tanh'
        self.name = None         # name other rows use to refer to this node

    def __eq__(self, tree2):
        """Returns True if the two trees have the same basic attributes.

        These attributes are children (compared recursively), name, binary
        function, and activation function. The k value is NOT compared.

        Input:
            tree2 - The tree to compare equality with. Non-TreeNode operands
                (including None) yield NotImplemented, so Python falls back
                to its default handling instead of raising AttributeError.
        """
        if not isinstance(tree2, TreeNode):
            # Without this guard, comparing a child against None (e.g. a leaf
            # vs. an internal node) reflected into this method and crashed
            # on `None.left`.
            return NotImplemented
        return (self.left == tree2.left and self.right == tree2.right and
                self.name == tree2.name and
                self.binary_func == tree2.binary_func and
                self.activation == tree2.activation)

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def __repr__(self):
        return ('TreeNode(name={!r}, binary_func={!r}, activation={!r}, '
                'k={!r})'.format(self.name, self.binary_func,
                                 self.activation, self.k))

    def is_leaf(self):
        """Return True if this node has no children."""
        return self.left is None and self.right is None
def leaf_nodes():
    """Build fresh leaf nodes for the basic vectors the equations may use.

    Returns:
        dict mapping each leaf name ('x', 'h', and '0' for the zero vector)
        to a new childless TreeNode carrying that name.
    """
    leaves = {}
    for leaf_name in ('x', 'h', '0'):
        leaf = TreeNode()
        leaf.name = leaf_name
        leaves[leaf_name] = leaf
    return leaves
def structure2tree(structure):
    """Returns the tree representation of the given structure.

    Input:
        structure - List of nodes, where each node is a list of
            [leftname, rightname, k, binary_func, activation, name]. Child
            nodes must appear before their parents; the leaf names 'x', 'h'
            and '0' are always available.
    Returns:
        node - TreeNode built from the last row of the structure (the root),
            or None if the structure is empty.
    Raises:
        ValueError - if a row names a child that is neither a leaf nor the
            name of an earlier row.
    """
    nodes = leaf_nodes()
    # Stays None for an empty structure (previously an UnboundLocalError).
    node = None
    for n in structure:
        missing = [child for child in (n[LEFT_NAME], n[RIGHT_NAME])
                   if child not in nodes]
        if missing:
            # Report the offending names in the exception rather than
            # printing them to stdout.
            raise ValueError(
                'Children must appear before their parents in the structure '
                '(undefined: {})'.format(', '.join(map(repr, missing))))
        node = TreeNode()
        node.left = nodes[n[LEFT_NAME]]
        node.right = nodes[n[RIGHT_NAME]]
        node.k = n[K]
        node.binary_func = n[BINARY_FUNC]
        node.activation = n[ACTIVATION]
        node.name = n[NAME]
        # Later rows may reference this node by name.
        nodes[node.name] = node
    return node
def trees_are_isomorphic(tree1, tree2):
    """Returns True if the two trees are isomorphic.

    Only the branching shape is compared. Node names, binary functions,
    activations and leaf identities are deliberately ignored. At every node
    the children may match either in order (left/left, right/right) or
    swapped (left/right, right/left).

    Input:
        tree1 - TreeNode or None
        tree2 - TreeNode or None
    """
    if tree1 is None or tree2 is None:
        # Two empty subtrees match; an empty and a non-empty one do not.
        return tree1 is None and tree2 is None
    in_order = (trees_are_isomorphic(tree1.left, tree2.left) and
                trees_are_isomorphic(tree1.right, tree2.right))
    if in_order:
        return True
    return (trees_are_isomorphic(tree1.left, tree2.right) and
            trees_are_isomorphic(tree1.right, tree2.left))
def structures_are_equal(structure1, structure2):
    """Return True when the trees built from two structures are isomorphic.

    Isomorphic structures represent the same underlying equations and are,
    for all practical purposes, the same. Structures with a different number
    of rows are never equal.

    Inputs:
        structure1 - A tree structure (must be a list)
        structure2 - A tree structure (must be a list)
    Raises:
        TypeError - if either argument is not a list.
    """
    both_lists = isinstance(structure1, list) and isinstance(structure2, list)
    if not both_lists:
        raise TypeError('structures_are_equal takes type list.')
    # TODO: Do different permutations of Lk and Rk count as the same tree?
    if len(structure1) == len(structure2):
        return trees_are_isomorphic(structure2tree(structure1),
                                    structure2tree(structure2))
    return False
def structure_is_gru(structure):
    """Return True if `structure` is isomorphic to GRU_STRUCTURE."""
    return structures_are_equal(structure1=structure, structure2=GRU_STRUCTURE)
def n_differences(structure1, structure2):
    """Count node-level differences between two isomorphic structures.

    The count comes from n_tree_differences applied to the trees built from
    each structure.

    Raises:
        TypeError - if either structure is not a list (via structures_are_equal).
        ValueError - if the structures are not isomorphic, or (via
            structure2tree) if a structure references an undefined child.
    """
    if structures_are_equal(structure1, structure2):
        return n_tree_differences(structure2tree(structure1),
                                  structure2tree(structure2))
    raise ValueError('Structures are not isomorphic')
# TODO: Account for Lk, Rk, bk
def n_tree_differences(tree1, tree2):
    """Return the fewest binary-function/activation mismatches between trees.

    At each node, a differing binary_func and a differing activation each
    count as one difference. Children may be paired in order or swapped, and
    the cheaper pairing is taken, so the result is the minimum over all
    child-swap choices.

    Input:
        tree1 - TreeNode or None
        tree2 - TreeNode or None
    Returns:
        Number of differences, or np.inf if the shapes cannot be matched
        (one subtree is empty where the other is not).
    """
    if tree1 is None and tree2 is None:
        return 0
    if tree1 is None or tree2 is None:
        # np.inf (not the np.Inf alias, removed in NumPy 2.0) so the sum
        # below propagates an impossible match.
        return np.inf
    differences = (int(tree1.binary_func != tree2.binary_func) +
                   int(tree1.activation != tree2.activation))
    in_order = (n_tree_differences(tree1.left, tree2.left) +
                n_tree_differences(tree1.right, tree2.right))
    swapped = (n_tree_differences(tree1.left, tree2.right) +
               n_tree_differences(tree1.right, tree2.left))
    return differences + min(in_order, swapped)