Skip to content

Commit eed20dd

Browse files
author
Mark
committed
complete deforming plate task; add hpc run script
1 parent b21ed8d commit eed20dd

File tree

208 files changed

+1775
-1314
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

208 files changed

+1775
-1314
lines changed

all_job.sh

Lines changed: 406 additions & 0 deletions
Large diffs are not rendered by default.

cfd_model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
import torch.nn as nn
2424
import torch.nn.functional as F
2525
import encode_process_decode
26-
import encode_process_decode_max_pooling
27-
import encode_process_decode_lstm
28-
import encode_process_decode_graph_structure_watcher
2926

3027
device = torch.device('cuda')
3128

cloth_model.py

Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222
# from torch_cluster import random_walk
2323
import functools
2424

25+
import torch_scatter
2526
import common
2627
import normalization
2728
import encode_process_decode
28-
import encode_process_decode_max_pooling
29-
import encode_process_decode_lstm
30-
import encode_process_decode_graph_structure_watcher
3129

3230
device = torch.device('cuda')
3331

@@ -45,20 +43,14 @@ def __init__(self, params, core_model_name=encode_process_decode, message_passin
4543
self._output_normalizer = normalization.Normalizer(size=3, name='output_normalizer')
4644
self._node_normalizer = normalization.Normalizer(
4745
size=3 + common.NodeType.SIZE, name='node_normalizer')
48-
self._edge_normalizer = normalization.Normalizer(
46+
self._node_dynamic_normalizer = normalization.Normalizer(size=1, name='node_normalizer')
47+
self._mesh_edge_normalizer = normalization.Normalizer(
4948
size=7, name='edge_normalizer') # 2D coord + 3D coord + 2*length = 7
49+
self._world_edge_normalizer = normalization.Normalizer(size=4, name='world_edge_normalizer')
5050
self._model_type = params['model'].__name__
5151

52-
# for stochastic message passing
53-
'''
54-
self.random_walk_generation_interval = 399
55-
self.input_count = 0
56-
self.sto_mat = None
57-
self.normalized_adj_mat = None
58-
'''
59-
6052
self.core_model_name = core_model_name
61-
self.core_model = self.select_core_model(core_model_name)
53+
self.core_model = encode_process_decode
6254
self.message_passing_steps = message_passing_steps
6355
self.message_passing_aggregator = message_passing_aggregator
6456
self._attention = attention
@@ -70,8 +62,6 @@ def __init__(self, params, core_model_name=encode_process_decode, message_passin
7062
self._ripple_node_selection_random_top_n = ripple_node_selection_random_top_n
7163
self._ripple_node_connection = ripple_node_connection
7264
self._ripple_node_ncross = ripple_node_ncross
73-
# self.stochastic_message_passing_used = False
74-
if self._ripple_used:
7565
self.learned_model = self.core_model.EncodeProcessDecode(
7666
output_size=params['size'],
7767
latent_size=128,
@@ -93,13 +83,38 @@ def __init__(self, params, core_model_name=encode_process_decode, message_passin
9383
message_passing_aggregator=self.message_passing_aggregator, attention=self._attention,
9484
ripple_used=self._ripple_used)
9585

96-
def select_core_model(self, core_model_name):
97-
return {
98-
'encode_process_decode': encode_process_decode,
99-
'encode_process_decode_graph_structure_watcher': encode_process_decode_graph_structure_watcher,
100-
'encode_process_decode_max_pooling': encode_process_decode_max_pooling,
101-
'encode_process_decode_lstm': encode_process_decode_lstm,
102-
}.get(core_model_name, encode_process_decode)
86+
def unsorted_segment_operation(self, data, segment_ids, num_segments, operation):
87+
"""
88+
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
89+
90+
:param data: A tensor whose segments are to be summed.
91+
:param segment_ids: The segment indices tensor.
92+
:param num_segments: The number of segments.
93+
:return: A tensor of same data type as the data argument.
94+
"""
95+
assert all([i in data.shape for i in segment_ids.shape]), "segment_ids.shape should be a prefix of data.shape"
96+
97+
# segment_ids is a 1-D tensor repeat it to have the same shape as data
98+
if len(segment_ids.shape) == 1:
99+
s = torch.prod(torch.tensor(data.shape[1:])).long().to(device)
100+
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]).to(device)
101+
102+
assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"
103+
104+
shape = [num_segments] + list(data.shape[1:])
105+
result = torch.zeros(*shape)
106+
if operation == 'sum':
107+
result = torch_scatter.scatter_add(data.float(), segment_ids, dim=0, dim_size=num_segments)
108+
elif operation == 'max':
109+
result, _ = torch_scatter.scatter_max(data.float(), segment_ids, dim=0, dim_size=num_segments)
110+
elif operation == 'mean':
111+
result = torch_scatter.scatter_mean(data.float(), segment_ids, dim=0, dim_size=num_segments)
112+
elif operation == 'min':
113+
result, _ = torch_scatter.scatter_min(data.float(), segment_ids, dim=0, dim_size=num_segments)
114+
else:
115+
raise Exception('Invalid operation type!')
116+
result = result.type(data.dtype)
117+
return result
103118

104119
def _build_graph(self, inputs, is_training):
105120
"""Builds input graph."""
@@ -114,29 +129,7 @@ def _build_graph(self, inputs, is_training):
114129
cells = inputs['cells']
115130
decomposed_cells = common.triangles_to_edges(cells)
116131
senders, receivers = decomposed_cells['two_way_connectivity']
117-
'''
118-
Stochastic matrix and adjacency matrix
119-
Reference: a simple and general graph neural network with stochastic message passing
120-
'''
121-
'''
122-
if self.stochastic_message_passing_used and self.input_count % self.random_walk_generation_interval == 0:
123-
start = torch.tensor(range(node_type.shape[0]), device=device)
124-
self.sto_mat = random_walk(receivers, senders, start, walk_length=20)
125-
126-
adj_index = torch.stack((receivers, senders), dim=0)
127-
adj_index = adj_index.tolist()
128-
adj_mat = torch.sparse_coo_tensor(adj_index, [1] * receivers.shape[0],
129-
(node_type.shape[0], node_type.shape[0]), device=device)
130-
self_loop_mat = torch.diag(torch.tensor([1.0] * node_type.shape[0], device=device))
131-
self_loop_adj_mat = self_loop_mat + adj_mat
132-
adj_mat = torch.sparse.sum(adj_mat, dim=1)
133-
adj_mat = torch.sqrt(adj_mat).to_dense()
134-
square_root_degree_mat = torch.diag(adj_mat)
135-
inversed_square_root_degree_mat = torch.inverse(square_root_degree_mat)
136-
self.normalized_adj_mat = torch.matmul(inversed_square_root_degree_mat, self_loop_adj_mat)
137-
self.normalized_adj_mat = torch.matmul(self.normalized_adj_mat, inversed_square_root_degree_mat)
138-
self.input_count += 1
139-
'''
132+
140133
mesh_pos = inputs['mesh_pos']
141134
relative_world_pos = (torch.index_select(input=world_pos, dim=0, index=senders) -
142135
torch.index_select(input=world_pos, dim=0, index=receivers))
@@ -150,24 +143,37 @@ def _build_graph(self, inputs, is_training):
150143

151144
mesh_edges = self.core_model.EdgeSet(
152145
name='mesh_edges',
153-
features=self._edge_normalizer(edge_features, is_training),
146+
features=self._mesh_edge_normalizer(edge_features, is_training),
154147
receivers=receivers,
155148
senders=senders)
156149

157-
if self.core_model == encode_process_decode and self._ripple_used == True:
158-
return self.core_model.MultiGraphWithPos(node_features=self._node_normalizer(node_features, is_training),
159-
edge_sets=[mesh_edges], target_feature=world_pos,
160-
mesh_pos=mesh_pos, model_type=self._model_type)
150+
if self._ripple_used:
151+
num_nodes = node_type.shape[0]
152+
max_node_dynamic = self.unsorted_segment_operation(torch.norm(relative_world_pos, dim=-1), receivers,
153+
num_nodes,
154+
operation='max').to(device)
155+
min_node_dynamic = self.unsorted_segment_operation(torch.norm(relative_world_pos, dim=-1), receivers,
156+
num_nodes,
157+
operation='min').to(device)
158+
node_dynamic = self._node_dynamic_normalizer(max_node_dynamic - min_node_dynamic)
159+
160+
return (self.core_model.MultiGraphWithPos(node_features=node_features,
161+
edge_sets=[mesh_edges], target_feature=world_pos,
162+
model_type=self._model_type,
163+
node_dynamic=node_dynamic))
161164
else:
162-
return self.core_model.MultiGraph(node_features=self._node_normalizer(node_features, is_training),
163-
edge_sets=[mesh_edges])
165+
return (self.core_model.MultiGraph(node_features=node_features,
166+
edge_sets=[mesh_edges]))
164167

165168
def forward(self, inputs, is_training):
166169
graph = self._build_graph(inputs, is_training=is_training)
167170
if is_training:
168-
return self.learned_model(graph, self._edge_normalizer, is_training=is_training)
171+
return self.learned_model(graph,
172+
world_edge_normalizer=self._world_edge_normalizer, is_training=is_training)
169173
else:
170-
return self._update(inputs, self.learned_model(graph, self._edge_normalizer, is_training=is_training))
174+
return self._update(inputs, self.learned_model(graph,
175+
world_edge_normalizer=self._world_edge_normalizer,
176+
is_training=is_training))
171177

172178
def _update(self, inputs, per_node_network_output):
173179
"""Integrate model outputs."""
@@ -186,13 +192,15 @@ def get_output_normalizer(self):
186192
def save_model(self, path):
187193
torch.save(self.learned_model, path + "_learned_model.pth")
188194
torch.save(self._output_normalizer, path + "_output_normalizer.pth")
189-
torch.save(self._edge_normalizer, path + "_edge_normalizer.pth")
195+
torch.save(self._mesh_edge_normalizer, path + "_mesh_edge_normalizer.pth")
196+
torch.save(self._world_edge_normalizer, path + "_world_edge_normalizer.pth")
190197
torch.save(self._node_normalizer, path + "_node_normalizer.pth")
191198

192199
def load_model(self, path):
193200
self.learned_model = torch.load(path + "_learned_model.pth")
194201
self._output_normalizer = torch.load(path + "_output_normalizer.pth")
195-
self._edge_normalizer = torch.load(path + "_edge_normalizer.pth")
202+
self._mesh_edge_normalizer = torch.load(path + "_mesh_edge_normalizer.pth")
203+
self._world_edge_normalizer = torch.load(path + "_world_edge_normalizer.pth")
196204
self._node_normalizer = torch.load(path + "_node_normalizer.pth")
197205

198206
def evaluate(self):

cloth_noripple_max_attention_10.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=10
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=True --ripple_used=False --use_prev_config=True

cloth_noripple_max_attention_15.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=15
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=True --ripple_used=False --use_prev_config=True

cloth_noripple_max_attention_3.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=3
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=True --ripple_used=False --use_prev_config=True

cloth_noripple_max_attention_5.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=5
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=True --ripple_used=False --use_prev_config=True

cloth_noripple_max_attention_7.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=7
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=True --ripple_used=False --use_prev_config=True

cloth_noripple_max_noattention_10.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=10
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=False --ripple_used=False --use_prev_config=True

cloth_noripple_max_noattention_15.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=deform_noripple_sum_noattention
3+
#SBATCH --partition=gpu_8
4+
5+
i=15
6+
srun --exclusive -N1 -p gpu_8 --gres=gpu python run_model.py --model=cloth --mode=all --rollout_split=valid --dataset=flag_simple --epochs=25 --trajectories=1000 --num_rollouts=100 --core_model=encode_process_decode --message_passing_aggregator=max --message_passing_steps=${i} --attention=False --ripple_used=False --use_prev_config=True

0 commit comments

Comments
 (0)