Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
wuhaixu2016 committed May 2, 2023
1 parent 872c195 commit 4c310ef
Show file tree
Hide file tree
Showing 28 changed files with 1,133 additions and 289 deletions.
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ To tackle both the approximation and computation complexities in PDE-governed ta

## LSM vs. Previous Methods

Different from previous methods, instead of learning a single operator, inspired by classical spectral methods in numerical analysis, LSM composes complex mappings into multiple basis operators. Along with the latent space projection, LSM presents favorable approximation and convergence properties.
Different from previous methods that learn a single operator directly, inspired by classical spectral methods in numerical analysis, LSM composes complex mappings into multiple basis operators. Along with the latent space projection, LSM presents favorable approximation and convergence properties.

<p align="center">
<img src=".\fig\compare.png" height = "200" alt="" align=center />
Expand All @@ -32,7 +32,8 @@ Different from previous methods, instead of learning a single operator, inspired
pip install -r requirements.txt
```

2. Prepare Data. You can obtain the datasets from the following links.
2. Prepare Data. You can obtain experimental datasets from the following links.


| Dataset | Task | Geometry | Link |
| ------------- | --------------------------------------- | --------------- | ------------------------------------------------------------ |
Expand All @@ -44,7 +45,7 @@ pip install -r requirements.txt
| AirFoil | Estimate airflow velocity around airfoil | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) |
| Pipe | Estimate fluid velocity in a pipe | Structured Mesh | [[Google Cloud]](https://drive.google.com/drive/folders/1YBuaoTdOSr_qzaow-G-iwvbUI7fiUzu8) |

2. Train and evaluate model. We provide the experiment scripts of all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as the following examples:
3. Train and evaluate model. We provide the experiment scripts of all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as the following examples:

```bash
bash scripts/elas_lsm.sh # for Elasticity-P
Expand All @@ -56,14 +57,22 @@ bash scripts/airfoil_lsm.sh # for Airfoil
bash scripts/pipe_lsm.sh # for Pipe
```

4. Develop your own model. Here are the instructions:

- You can add your model file under the folder `./models/`.
- Add the model into the `./model_dict.py`.
- Add a script file under the folder `./scripts/` and change the argument `--model`.

Note: For clearness and easy comparison, we also include the FNO in this repository.

## Results

We extensively experiment on seven benchmarks and compare LSM with 13 baselines. LSM achieves the consistent state-of-the-art in both solid and fluid physics (11.5% averaged error reduction).

<p align="center">
<img src=".\fig\main_results.png" height = "350" alt="" align=center />
<br><br>
<b>Table 1.</b> Model perfromance on seven benchmarks. MSE is recorded.
<b>Table 1.</b> Model performance on seven benchmarks. MSE is recorded.
</p>

## Showcases
Expand Down
72 changes: 36 additions & 36 deletions exp_airfoils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
from utilities3 import *
from adam import Adam
from utils.utilities3 import *
from utils.adam import Adam
from utils.params import get_args
from model_dict import get_model
import math
import os
from models.LSM_2D import LSM2d

torch.manual_seed(0)
np.random.seed(0)
Expand All @@ -15,33 +16,36 @@
################################################################
# configs
################################################################
INPUT_X = '/home/wuhaixu/airfoil/naca/NACA_Cylinder_X.npy'
INPUT_Y = '/home/wuhaixu/airfoil/naca/NACA_Cylinder_Y.npy'
OUTPUT_Sigma = '/home/wuhaixu/airfoil/naca/NACA_Cylinder_Q.npy'

ntrain = 1000
ntest = 200
in_channels = 2
out_channels = 1
r1 = 1
r2 = 1
s1 = int(((221 - 1) / r1) + 1)
s2 = int(((51 - 1) / r2) + 1)

batch_size = 20
learning_rate = 0.001
epochs = 501
step_size = 100
gamma = 0.5

num_basis = 12
num_token = 4
width = 32
patch_size = [14, 4]
padding = [13, 3]

model_save_path = './checkpoints/airfoil'
model_save_name = 'airfoil_lsm.pt'
args = get_args()

INPUT_X = os.path.join(args.data_path, './naca/NACA_Cylinder_X.npy')
INPUT_Y = os.path.join(args.data_path, './naca/NACA_Cylinder_Y.npy')
OUTPUT_Sigma = os.path.join(args.data_path, './naca/NACA_Cylinder_Q.npy')

ntrain = args.ntrain
ntest = args.ntest
N = args.ntotal
in_channels = args.in_dim
out_channels = args.out_dim
r1 = args.h_down
r2 = args.w_down
s1 = int(((args.h - 1) / r1) + 1)
s2 = int(((args.w - 1) / r2) + 1)

batch_size = args.batch_size
learning_rate = args.learning_rate
epochs = args.epochs
step_size = args.step_size
gamma = args.gamma

model_save_path = args.model_save_path
model_save_name = args.model_save_name

################################################################
# models
################################################################
model = get_model(args)
print(count_params(model))

################################################################
# load data and data normalization
Expand All @@ -68,12 +72,6 @@
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
shuffle=False)

################################################################
# models
################################################################
model = LSM2d(in_channels, out_channels, width, patch_size, num_basis, num_token, padding).cuda()
print(count_params(model))

################################################################
# training and evaluation
################################################################
Expand Down Expand Up @@ -119,6 +117,8 @@
if ep % step_size == 0:
if not os.path.exists(model_save_path):
os.makedirs(model_save_path)
print('save model')
torch.save(model.state_dict(), os.path.join(model_save_path, model_save_name))
ind = -1
X = x[ind, :, :, 0].squeeze().detach().cpu().numpy()
Y = x[ind, :, :, 1].squeeze().detach().cpu().numpy()
Expand Down
73 changes: 38 additions & 35 deletions exp_darcy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
from utilities3 import *
from adam import Adam
from utils.utilities3 import *
from utils.adam import Adam
from utils.params import get_args
from model_dict import get_model
import math
import os
from models.LSM_2D import LSM2d

torch.manual_seed(0)
np.random.seed(0)
Expand All @@ -15,32 +16,35 @@
################################################################
# configs
################################################################
TRAIN_PATH = '/home/wuhaixu/piececonst_r421_N1024_smooth1.mat'
TEST_PATH = '/home/wuhaixu/piececonst_r421_N1024_smooth2.mat'

ntrain = 1000
ntest = 200
in_channels = 1
out_channels = 1
r1 = 5
r2 = 5
s1 = int(((421 - 1) / r1) + 1)
s2 = int(((421 - 1) / r2) + 1)

batch_size = 20
learning_rate = 0.001
epochs = 501
step_size = 100
gamma = 0.5

num_basis = 12
num_token = 4
width = 64
patch_size = [3, 3]
padding = [11, 11]

model_save_path = './checkpoints/darcy'
model_save_name = 'darcy_lsm.pt'
args = get_args()

TRAIN_PATH = os.path.join(args.data_path, './piececonst_r421_N1024_smooth1.mat')
TEST_PATH = os.path.join(args.data_path, './piececonst_r421_N1024_smooth2.mat')

ntrain = args.ntrain
ntest = args.ntest
N = args.ntotal
in_channels = args.in_dim
out_channels = args.out_dim
r1 = args.h_down
r2 = args.w_down
s1 = int(((args.h - 1) / r1) + 1)
s2 = int(((args.w - 1) / r2) + 1)

batch_size = args.batch_size
learning_rate = args.learning_rate
epochs = args.epochs
step_size = args.step_size
gamma = args.gamma

model_save_path = args.model_save_path
model_save_name = args.model_save_name

################################################################
# models
################################################################
model = get_model(args)
print(count_params(model))

################################################################
# load data and data normalization
Expand Down Expand Up @@ -69,12 +73,6 @@
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
shuffle=False)

################################################################
# models
################################################################
model = LSM2d(in_channels, out_channels, width, patch_size, num_basis, num_token, padding).cuda()
print(count_params(model))

################################################################
# training and evaluation
################################################################
Expand Down Expand Up @@ -118,3 +116,8 @@

t2 = default_timer()
print(ep, t2 - t1, train_l2, test_l2)
if ep % step_size == 0:
if not os.path.exists(model_save_path):
os.makedirs(model_save_path)
print('save model')
torch.save(model.state_dict(), os.path.join(model_save_path, model_save_name))
73 changes: 36 additions & 37 deletions exp_elas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
from utilities3 import *
from adam import Adam
from utils.utilities3 import *
from utils.adam import Adam
from utils.params import get_args
from model_dict import get_model
import math
import os
from models.LSM_Irregular_Geo import LSM2d
from models.LSM_Irregular_Geo import IPHI

torch.manual_seed(0)
np.random.seed(0)
Expand All @@ -16,31 +16,33 @@
################################################################
# configs
################################################################
PATH_Sigma = '/home/wuhaixu/elasticity/Meshes/Random_UnitCell_sigma_10.npy'
PATH_XY = '/home/wuhaixu/elasticity/Meshes/Random_UnitCell_XY_10.npy'
PATH_rr = '/home/wuhaixu/elasticity/Meshes/Random_UnitCell_rr_10.npy'
N = 2000
ntrain = 1000
ntest = 200

in_channels = 2
out_channels = 1

batch_size = 20
learning_rate = 0.0005
epochs = 501
step_size = 100
gamma = 0.5

num_basis = 12
num_token = 4
width = 32
patch_size = [6, 6]
padding = [0, 0]
modes = 12

model_save_path = './checkpoints/elas'
model_save_name = 'elas_lsm.pt'
args = get_args()

PATH_Sigma = os.path.join(args.data_path, './Meshes/Random_UnitCell_sigma_10.npy')
PATH_XY = os.path.join(args.data_path, './Meshes/Random_UnitCell_XY_10.npy')
PATH_rr = os.path.join(args.data_path, './Meshes/Random_UnitCell_rr_10.npy')

ntrain = args.ntrain
ntest = args.ntest
N = args.ntotal
in_channels = args.in_dim
out_channels = args.out_dim

batch_size = args.batch_size
learning_rate = args.learning_rate
epochs = args.epochs
step_size = args.step_size
gamma = args.gamma

model_save_path = args.model_save_path
model_save_name = args.model_save_name

################################################################
# models
################################################################
model, model_iphi = get_model(args)
print(count_params(model), count_params(model_iphi))
params = list(model.parameters()) + list(model_iphi.parameters())

################################################################
# load data and data normalization
Expand Down Expand Up @@ -68,17 +70,9 @@
batch_size=batch_size,
shuffle=False)

################################################################
# models
################################################################
model = LSM2d(in_channels, out_channels, width, patch_size, num_basis, num_token, padding).cuda()
model_iphi = IPHI().cuda()
print(count_params(model), count_params(model_iphi))

################################################################
# training and evaluation
################################################################
params = list(model.parameters()) + list(model_iphi.parameters())
optimizer = Adam(params, lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

Expand Down Expand Up @@ -118,3 +112,8 @@

t2 = default_timer()
print(ep, t2 - t1, train_l2, test_l2)
if ep % step_size == 0:
if not os.path.exists(model_save_path):
os.makedirs(model_save_path)
print('save model')
torch.save(model.state_dict(), os.path.join(model_save_path, model_save_name))
Loading

0 comments on commit 4c310ef

Please sign in to comment.