Skip to content

Commit

Permalink
add pinnsformer
Browse files Browse the repository at this point in the history
  • Loading branch information
wuhaixu2016 committed Oct 18, 2024
1 parent 02a8f53 commit fef4811
Show file tree
Hide file tree
Showing 15 changed files with 295 additions and 93 deletions.
25 changes: 19 additions & 6 deletions 1d_reaction_point_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,24 @@
res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)

if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res = make_time_sequence(res, num_step=5, step=1e-4)
b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, 0:1], res[:, 1:2]
x_left, t_left = b_left[:, 0:1], b_left[:, 1:2]
x_right, t_right = b_right[:, 0:1], b_right[:, 1:2]
x_upper, t_upper = b_upper[:, 0:1], b_upper[:, 1:2]
x_lower, t_lower = b_lower[:, 0:1], b_lower[:, 1:2]
x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]


def init_weights(m):
Expand All @@ -50,6 +57,9 @@ def init_weights(m):
elif args.model == 'QRes':
model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
model.apply(init_weights)
else:
model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
Expand Down Expand Up @@ -96,8 +106,11 @@ def closure():
torch.save(model.state_dict(), f'./results/1dreaction_{args.model}_point.pt')

# Visualize
if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res_test = make_time_sequence(res_test, num_step=5, step=1e-4)

res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
x_test, t_test = res_test[:, 0:1], res_test[:, 1:2]
x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]

with torch.no_grad():
pred = model(x_test, t_test)[:, 0:1]
Expand Down
38 changes: 25 additions & 13 deletions 1d_reaction_region_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,24 @@
res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)

if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res = make_time_sequence(res, num_step=5, step=1e-4)
b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, 0:1], res[:, 1:2]
x_left, t_left = b_left[:, 0:1], b_left[:, 1:2]
x_right, t_right = b_right[:, 0:1], b_right[:, 1:2]
x_upper, t_upper = b_upper[:, 0:1], b_upper[:, 1:2]
x_lower, t_lower = b_lower[:, 0:1], b_lower[:, 1:2]
x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]


def init_weights(m):
Expand All @@ -53,6 +60,9 @@ def init_weights(m):
elif args.model == 'QRes':
model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device)
model.apply(init_weights)
elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
model.apply(init_weights)
else:
model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
Expand All @@ -75,16 +85,15 @@ def init_weights(m):

###### Region Optimization with Monte Carlo Approximation ######
def closure():
B, C = x_res.shape
x_res_region_sample_list = []
t_res_region_sample_list = []
for i in range(sample_num):
x_region_sample = (torch.rand(B, C).to(x_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
t_region_sample = (torch.rand(B, C).to(t_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
t_region_sample = (torch.rand(x_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
x_res_region_sample_list.append(x_res + x_region_sample)
t_res_region_sample_list.append(t_res + t_region_sample)
x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0)
Expand Down Expand Up @@ -139,8 +148,11 @@ def closure():
torch.save(model.state_dict(), f'./results/1dreaction_{args.model}_region.pt')

# Visualize
if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res_test = make_time_sequence(res_test, num_step=5, step=1e-4)

res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
x_test, t_test = res_test[:, 0:1], res_test[:, 1:2]
x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]

with torch.no_grad():
pred = model(x_test, t_test)[:, 0:1]
Expand Down
29 changes: 22 additions & 7 deletions 1d_wave_point_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,41 @@
res, b_left, b_right, b_upper, b_lower = get_data([0, 1], [0, 1], 101, 101)
res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)

if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res = make_time_sequence(res, num_step=5, step=1e-4)
b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, 0:1], res[:, 1:2]
x_left, t_left = b_left[:, 0:1], b_left[:, 1:2]
x_right, t_right = b_right[:, 0:1], b_right[:, 1:2]
x_upper, t_upper = b_upper[:, 0:1], b_upper[:, 1:2]
x_lower, t_lower = b_lower[:, 0:1], b_lower[:, 1:2]
x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]


def init_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform(m.weight)
m.bias.data.fill_(0.01)


if args.model == 'KAN':
model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \
model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
noise_scale_base=0.25, device=device).to(device)
elif args.model == 'QRes':
model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
model.apply(init_weights)
else:
model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
Expand Down Expand Up @@ -111,8 +123,11 @@ def closure():
torch.save(model.state_dict(), f'./results/1dwave_{args.model}_point.pt')

# Visualize
if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res_test = make_time_sequence(res_test, num_step=5, step=1e-4)

res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
x_test, t_test = res_test[:, 0:1], res_test[:, 1:2]
x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]

with torch.no_grad():
pred = model(x_test, t_test)[:, 0:1]
Expand Down
40 changes: 26 additions & 14 deletions 1d_wave_region_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,24 @@
res, b_left, b_right, b_upper, b_lower = get_data([0, 1], [0, 1], 101, 101)
res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)

if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res = make_time_sequence(res, num_step=5, step=1e-4)
b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, 0:1], res[:, 1:2]
x_left, t_left = b_left[:, 0:1], b_left[:, 1:2]
x_right, t_right = b_right[:, 0:1], b_right[:, 1:2]
x_upper, t_upper = b_upper[:, 0:1], b_upper[:, 1:2]
x_lower, t_lower = b_lower[:, 0:1], b_lower[:, 1:2]
x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]


def init_weights(m):
Expand All @@ -47,11 +54,14 @@ def init_weights(m):


if args.model == 'KAN':
model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \
model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
noise_scale_base=0.25, device=device).to(device)
elif args.model == 'QRes':
model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device)
model.apply(init_weights)
elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
model.apply(init_weights)
else:
model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
Expand All @@ -77,16 +87,15 @@ def init_weights(m):

###### Region Optimization with Monte Carlo Approximation ######
def closure():
B, C = x_res.shape
x_res_region_sample_list = []
t_res_region_sample_list = []
for i in range(sample_num):
x_region_sample = (torch.rand(B, C).to(x_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
t_region_sample = (torch.rand(B, C).to(t_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
t_region_sample = (torch.rand(x_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance,
a_min=0,
a_max=0.01)
x_res_region_sample_list.append(x_res + x_region_sample)
t_res_region_sample_list.append(t_res + t_region_sample)
x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0)
Expand Down Expand Up @@ -153,8 +162,11 @@ def closure():
torch.save(model.state_dict(), f'./results/1dwave_{args.model}_region.pt')

# Visualize PINNs
if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res_test = make_time_sequence(res_test, num_step=5, step=1e-4)

res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
x_test, t_test = res_test[:, 0:1], res_test[:, 1:2]
x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]

with torch.no_grad():
pred = model(x_test, t_test)[:, 0:1]
Expand Down
27 changes: 20 additions & 7 deletions convection_point_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,24 @@
res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)

if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res = make_time_sequence(res, num_step=5, step=1e-4)
b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, 0:1], res[:, 1:2]
x_left, t_left = b_left[:, 0:1], b_left[:, 1:2]
x_right, t_right = b_right[:, 0:1], b_right[:, 1:2]
x_upper, t_upper = b_upper[:, 0:1], b_upper[:, 1:2]
x_lower, t_lower = b_lower[:, 0:1], b_lower[:, 1:2]
x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]


def init_weights(m):
Expand All @@ -44,11 +51,14 @@ def init_weights(m):


if args.model == 'KAN':
model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \
model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
noise_scale_base=0.25, device=device).to(device)
elif args.model == 'QRes':
model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
model.apply(init_weights)
else:
model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
model.apply(init_weights)
Expand Down Expand Up @@ -96,8 +106,11 @@ def closure():
torch.save(model.state_dict(), f'./results/1dconvection_{args.model}_point.pt')

# Visualize
if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
res_test = make_time_sequence(res_test, num_step=5, step=1e-4)

res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
x_test, t_test = res_test[:, 0:1], res_test[:, 1:2]
x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]

with torch.no_grad():
pred = model(x_test, t_test)[:, 0:1]
Expand Down
Loading

0 comments on commit fef4811

Please sign in to comment.