# from torch_cluster import random_walk
import functools

+import torch_scatter
import common
import normalization
import encode_process_decode
-import encode_process_decode_max_pooling
-import encode_process_decode_lstm
-import encode_process_decode_graph_structure_watcher

device = torch.device('cuda')

@@ -45,20 +43,14 @@ def __init__(self, params, core_model_name=encode_process_decode, message_passin
        self._output_normalizer = normalization.Normalizer(size=3, name='output_normalizer')
        self._node_normalizer = normalization.Normalizer(
            size=3 + common.NodeType.SIZE, name='node_normalizer')
-        self._edge_normalizer = normalization.Normalizer(
+        self._node_dynamic_normalizer = normalization.Normalizer(size=1, name='node_dynamic_normalizer')
+        self._mesh_edge_normalizer = normalization.Normalizer(
            size=7, name='edge_normalizer')  # 2D coord + 3D coord + 2*length = 7
+        self._world_edge_normalizer = normalization.Normalizer(size=4, name='world_edge_normalizer')
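+        # size = 4: presumably 3D relative world_pos + its norm, mirroring the
+        # mesh-edge comment above (world edges are not visible in this diff)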
        self._model_type = params['model'].__name__

-        # for stochastic message passing
-        '''
-        self.random_walk_generation_interval = 399
-        self.input_count = 0
-        self.sto_mat = None
-        self.normalized_adj_mat = None
-        '''
-
        self.core_model_name = core_model_name
-        self.core_model = self.select_core_model(core_model_name)
+        self.core_model = encode_process_decode
        self.message_passing_steps = message_passing_steps
        self.message_passing_aggregator = message_passing_aggregator
        self._attention = attention
@@ -70,8 +62,6 @@ def __init__(self, params, core_model_name=encode_process_decode, message_passin
        self._ripple_node_selection_random_top_n = ripple_node_selection_random_top_n
        self._ripple_node_connection = ripple_node_connection
        self._ripple_node_ncross = ripple_node_ncross
-        # self.stochastic_message_passing_used = False
-        if self._ripple_used:
        self.learned_model = self.core_model.EncodeProcessDecode(
            output_size=params['size'],
            latent_size=128,
@@ -93,13 +83,38 @@ def __init__(self, params, core_model_name=encode_process_decode, message_passin
            message_passing_aggregator=self.message_passing_aggregator, attention=self._attention,
            ripple_used=self._ripple_used)

-    def select_core_model(self, core_model_name):
-        return {
-            'encode_process_decode': encode_process_decode,
-            'encode_process_decode_graph_structure_watcher': encode_process_decode_graph_structure_watcher,
-            'encode_process_decode_max_pooling': encode_process_decode_max_pooling,
-            'encode_process_decode_lstm': encode_process_decode_lstm,
-        }.get(core_model_name, encode_process_decode)
+    def unsorted_segment_operation(self, data, segment_ids, num_segments, operation):
+        """
+        Applies a segment-wise reduction (sum, max, mean or min) along the first
+        dimension of a tensor. Analogous to tf.unsorted_segment_sum and its variants.
+
+        :param data: A tensor whose segments are to be reduced.
+        :param segment_ids: The segment indices tensor.
+        :param num_segments: The number of segments.
+        :param operation: One of 'sum', 'max', 'mean' or 'min'.
+        :return: A tensor of the same data type as the data argument.
+        """
+        assert data.shape[:len(segment_ids.shape)] == segment_ids.shape, \
+            "segment_ids.shape should be a prefix of data.shape"
+
+        # segment_ids is a 1-D tensor: repeat it so it has the same shape as data
+        if len(segment_ids.shape) == 1:
+            s = torch.prod(torch.tensor(data.shape[1:])).long().to(device)
+            segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]).to(device)
+
+        assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"
+
+        if operation == 'sum':
+            result = torch_scatter.scatter_add(data.float(), segment_ids, dim=0, dim_size=num_segments)
+        elif operation == 'max':
+            result, _ = torch_scatter.scatter_max(data.float(), segment_ids, dim=0, dim_size=num_segments)
+        elif operation == 'mean':
+            result = torch_scatter.scatter_mean(data.float(), segment_ids, dim=0, dim_size=num_segments)
+        elif operation == 'min':
+            result, _ = torch_scatter.scatter_min(data.float(), segment_ids, dim=0, dim_size=num_segments)
+        else:
+            raise Exception('Invalid operation type!')
+        return result.type(data.dtype)
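+        # Illustrative example (hypothetical values):
+        #   data = torch.tensor([1., 2., 3., 4.]); segment_ids = torch.tensor([0, 0, 1, 1])
+        #   self.unsorted_segment_operation(data, segment_ids, 2, operation='max')
+        #   -> tensor([2., 4.]), one reduced value per segment.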

    def _build_graph(self, inputs, is_training):
        """Builds input graph."""
@@ -114,29 +129,7 @@ def _build_graph(self, inputs, is_training):
        cells = inputs['cells']
        decomposed_cells = common.triangles_to_edges(cells)
        senders, receivers = decomposed_cells['two_way_connectivity']
-        '''
-        Stochastic matrix and adjacency matrix
-        Reference: a simple and general graph neural network with stochastic message passing
-        '''
-        '''
-        if self.stochastic_message_passing_used and self.input_count % self.random_walk_generation_interval == 0:
-            start = torch.tensor(range(node_type.shape[0]), device=device)
-            self.sto_mat = random_walk(receivers, senders, start, walk_length=20)
-
-            adj_index = torch.stack((receivers, senders), dim=0)
-            adj_index = adj_index.tolist()
-            adj_mat = torch.sparse_coo_tensor(adj_index, [1] * receivers.shape[0],
-                                              (node_type.shape[0], node_type.shape[0]), device=device)
-            self_loop_mat = torch.diag(torch.tensor([1.0] * node_type.shape[0], device=device))
-            self_loop_adj_mat = self_loop_mat + adj_mat
-            adj_mat = torch.sparse.sum(adj_mat, dim=1)
-            adj_mat = torch.sqrt(adj_mat).to_dense()
-            square_root_degree_mat = torch.diag(adj_mat)
-            inversed_square_root_degree_mat = torch.inverse(square_root_degree_mat)
-            self.normalized_adj_mat = torch.matmul(inversed_square_root_degree_mat, self_loop_adj_mat)
-            self.normalized_adj_mat = torch.matmul(self.normalized_adj_mat, inversed_square_root_degree_mat)
-        self.input_count += 1
-        '''
+
        mesh_pos = inputs['mesh_pos']
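+        # Per-edge offsets in world space (sender minus receiver); their norms also feed
+        # the node_dynamic measure below.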
        relative_world_pos = (torch.index_select(input=world_pos, dim=0, index=senders) -
                              torch.index_select(input=world_pos, dim=0, index=receivers))
@@ -150,24 +143,37 @@ def _build_graph(self, inputs, is_training):

        mesh_edges = self.core_model.EdgeSet(
            name='mesh_edges',
-            features=self._edge_normalizer(edge_features, is_training),
+            features=self._mesh_edge_normalizer(edge_features, is_training),
            receivers=receivers,
            senders=senders)

-        if self.core_model == encode_process_decode and self._ripple_used == True:
-            return self.core_model.MultiGraphWithPos(node_features=self._node_normalizer(node_features, is_training),
-                                                     edge_sets=[mesh_edges], target_feature=world_pos,
-                                                     mesh_pos=mesh_pos, model_type=self._model_type)
+        if self._ripple_used:
+            num_nodes = node_type.shape[0]
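+            # node_dynamic scores each node by the spread (max - min) of the norms of its
+            # incident relative world positions, aggregated per receiver.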
+            max_node_dynamic = self.unsorted_segment_operation(torch.norm(relative_world_pos, dim=-1), receivers,
+                                                               num_nodes,
+                                                               operation='max').to(device)
+            min_node_dynamic = self.unsorted_segment_operation(torch.norm(relative_world_pos, dim=-1), receivers,
+                                                               num_nodes,
+                                                               operation='min').to(device)
+            node_dynamic = self._node_dynamic_normalizer(max_node_dynamic - min_node_dynamic)
+
+            return self.core_model.MultiGraphWithPos(node_features=node_features,
+                                                     edge_sets=[mesh_edges], target_feature=world_pos,
+                                                     model_type=self._model_type,
+                                                     node_dynamic=node_dynamic)
        else:
-            return self.core_model.MultiGraph(node_features=self._node_normalizer(node_features, is_training),
-                                              edge_sets=[mesh_edges])
+            return self.core_model.MultiGraph(node_features=node_features,
+                                              edge_sets=[mesh_edges])

    def forward(self, inputs, is_training):
        graph = self._build_graph(inputs, is_training=is_training)
        if is_training:
-            return self.learned_model(graph, self._edge_normalizer, is_training=is_training)
+            return self.learned_model(graph,
+                                      world_edge_normalizer=self._world_edge_normalizer, is_training=is_training)
        else:
-            return self._update(inputs, self.learned_model(graph, self._edge_normalizer, is_training=is_training))
+            return self._update(inputs, self.learned_model(graph,
+                                                           world_edge_normalizer=self._world_edge_normalizer,
+                                                           is_training=is_training))
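+    # The world-edge normalizer is handed to the learned model rather than applied in
+    # _build_graph, presumably because world edges are assembled inside
+    # EncodeProcessDecode (not visible in this diff).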

    def _update(self, inputs, per_node_network_output):
        """Integrate model outputs."""
@@ -186,13 +192,15 @@ def get_output_normalizer(self):
    def save_model(self, path):
        torch.save(self.learned_model, path + "_learned_model.pth")
        torch.save(self._output_normalizer, path + "_output_normalizer.pth")
-        torch.save(self._edge_normalizer, path + "_edge_normalizer.pth")
+        torch.save(self._mesh_edge_normalizer, path + "_mesh_edge_normalizer.pth")
+        torch.save(self._world_edge_normalizer, path + "_world_edge_normalizer.pth")
        torch.save(self._node_normalizer, path + "_node_normalizer.pth")

    def load_model(self, path):
        self.learned_model = torch.load(path + "_learned_model.pth")
        self._output_normalizer = torch.load(path + "_output_normalizer.pth")
-        self._edge_normalizer = torch.load(path + "_edge_normalizer.pth")
+        self._mesh_edge_normalizer = torch.load(path + "_mesh_edge_normalizer.pth")
+        self._world_edge_normalizer = torch.load(path + "_world_edge_normalizer.pth")
        self._node_normalizer = torch.load(path + "_node_normalizer.pth")
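+        # Usage sketch (hypothetical prefix): load_model('ckpt/cloth') restores what
+        # save_model('ckpt/cloth') wrote, e.g. ckpt/cloth_learned_model.pth; note that
+        # `path` is a filename prefix, not a directory.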

    def evaluate(self):