Skip to content

Commit

Permalink
Add diffusion reaction PI-DeepONet examples
Browse files Browse the repository at this point in the history
  • Loading branch information
lululxvi committed Jul 10, 2023
1 parent 746fced commit 0b83d1b
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/demos/operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ PI-DeepONet
- `Advection equation with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet.py>`_
- `Advection equation 2D with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_aligned_pideeponet_2d.py>`_
- `Advection equation 2D with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet_2d.py>`_
- `Diffusion reaction equation with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/diff_rec_aligned_pideeponet.py>`_
- `Diffusion reaction equation with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/diff_rec_unaligned_pideeponet.py>`_
76 changes: 76 additions & 0 deletions examples/operator/ADR_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import matplotlib.pyplot as plt
import numpy as np


def solve_ADR(xmin, xmax, tmin, tmax, k, v, g, dg, f, u0, Nx, Nt):
"""Solve 1D
u_t = (k(x) u_x)_x - v(x) u_x + g(u) + f(x, t)
with zero boundary condition.
"""
x = np.linspace(xmin, xmax, Nx)
t = np.linspace(tmin, tmax, Nt)
h = x[1] - x[0]
dt = t[1] - t[0]
h2 = h**2

D1 = np.eye(Nx, k=1) - np.eye(Nx, k=-1)
D2 = -2 * np.eye(Nx) + np.eye(Nx, k=-1) + np.eye(Nx, k=1)
D3 = np.eye(Nx - 2)
k = k(x)
M = -np.diag(D1 @ k) @ D1 - 4 * np.diag(k) @ D2
m_bond = 8 * h2 / dt * D3 + M[1:-1, 1:-1]
v = v(x)
v_bond = 2 * h * np.diag(v[1:-1]) @ D1[1:-1, 1:-1] + 2 * h * np.diag(
v[2:] - v[: Nx - 2]
)
mv_bond = m_bond + v_bond
c = 8 * h2 / dt * D3 - M[1:-1, 1:-1] - v_bond
f = f(x[:, None], t)

u = np.zeros((Nx, Nt))
u[:, 0] = u0(x)
for i in range(Nt - 1):
gi = g(u[1:-1, i])
dgi = dg(u[1:-1, i])
h2dgi = np.diag(4 * h2 * dgi)
A = mv_bond - h2dgi
b1 = 8 * h2 * (0.5 * f[1:-1, i] + 0.5 * f[1:-1, i + 1] + gi)
b2 = (c - h2dgi) @ u[1:-1, i].T
u[1:-1, i + 1] = np.linalg.solve(A, b1 + b2)
return x, t, u


def main():
xmin, xmax = -1, 1
tmin, tmax = 0, 1
k = lambda x: x**2 - x**2 + 1
v = lambda x: np.ones_like(x)
g = lambda u: u**3
dg = lambda u: 3 * u**2
f = (
lambda x, t: np.exp(-t) * (1 + x**2 - 2 * x)
- (np.exp(-t) * (1 - x**2)) ** 3
)
u0 = lambda x: (x + 1) * (1 - x)
u_true = lambda x, t: np.exp(-t) * (1 - x**2)

# xmin, xmax = 0, 1
# tmin, tmax = 0, 1
# k = lambda x: np.ones_like(x)
# v = lambda x: np.zeros_like(x)
# g = lambda u: u ** 2
# dg = lambda u: 2 * u
# f = lambda x, t: x * (1 - x) + 2 * t - t ** 2 * (x - x ** 2) ** 2
# u0 = lambda x: np.zeros_like(x)
# u_true = lambda x, t: t * x * (1 - x)

Nx, Nt = 100, 100
x, t, u = solve_ADR(xmin, xmax, tmin, tmax, k, v, g, dg, f, u0, Nx, Nt)

print(np.max(abs(u - u_true(x[:, None], t))))
plt.plot(x, u)
plt.show()


if __name__ == "__main__":
main()
88 changes: 88 additions & 0 deletions examples/operator/diff_rec_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Backend supported: tensorflow.compat.v1"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np

from ADR_solver import solve_ADR


# PDE
def pde(x, y, v):
D = 0.01
k = 0.01
dy_t = dde.grad.jacobian(y, x, j=1)
dy_xx = dde.grad.hessian(y, x, j=0)
return dy_t - D * dy_xx + k * y**2 - v


geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc = dde.icbc.DirichletBC(geomtime, lambda _: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(geomtime, lambda _: 0, lambda _, on_initial: on_initial)

pde = dde.data.TimePDE(
geomtime,
pde,
[bc, ic],
num_domain=200,
num_boundary=40,
num_initial=20,
num_test=500,
)

# Function space
func_space = dde.data.GRF(length_scale=0.2)

# Data
eval_pts = np.linspace(0, 1, num=50)[:, None]
data = dde.data.PDEOperatorCartesianProd(
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=100, batch_size=50
)

# Net
net = dde.nn.DeepONetCartesianProd(
[50, 128, 128, 128],
[2, 128, 128, 128],
"tanh",
"Glorot normal",
)

model = dde.Model(data, net)
model.compile("adam", lr=0.0005)
losshistory, train_state = model.train(epochs=20000)
dde.utils.plot_loss_history(losshistory)

func_feats = func_space.random(1)
xs = np.linspace(0, 1, num=100)[:, None]
v = func_space.eval_batch(func_feats, xs)[0]
x, t, u_true = solve_ADR(
0,
1,
0,
1,
lambda x: 0.01 * np.ones_like(x),
lambda x: np.zeros_like(x),
lambda u: 0.01 * u**2,
lambda u: 0.02 * u,
lambda x, t: np.tile(v[:, None], (1, len(t))),
lambda x: np.zeros_like(x),
100,
100,
)
u_true = u_true.T
plt.figure()
plt.imshow(u_true)
plt.colorbar()

v_branch = func_space.eval_batch(func_feats, np.linspace(0, 1, num=50)[:, None])
xv, tv = np.meshgrid(x, t)
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
u_pred = model.predict((v_branch, x_trunk))
u_pred = u_pred.reshape((100, 100))
print(dde.metrics.l2_relative_error(u_true, u_pred))
plt.figure()
plt.imshow(u_pred)
plt.colorbar()
plt.show()
88 changes: 88 additions & 0 deletions examples/operator/diff_rec_unaligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Backend supported: tensorflow.compat.v1"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np

from ADR_solver import solve_ADR


# PDE
def pde(x, y, v):
D = 0.01
k = 0.01
dy_t = dde.grad.jacobian(y, x, j=1)
dy_xx = dde.grad.hessian(y, x, j=0)
return dy_t - D * dy_xx + k * y**2 - v


geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc = dde.icbc.DirichletBC(geomtime, lambda _: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(geomtime, lambda _: 0, lambda _, on_initial: on_initial)

pde = dde.data.TimePDE(
geomtime,
pde,
[bc, ic],
num_domain=200,
num_boundary=40,
num_initial=20,
num_test=500,
)

# Function space
func_space = dde.data.GRF(length_scale=0.2)

# Data
eval_pts = np.linspace(0, 1, num=50)[:, None]
data = dde.data.PDEOperator(
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=1000
)

# Net
net = dde.nn.DeepONet(
[50, 128, 128, 128],
[2, 128, 128, 128],
"tanh",
"Glorot normal",
)

model = dde.Model(data, net)
model.compile("adam", lr=0.0005)
losshistory, train_state = model.train(epochs=50000)
dde.utils.plot_loss_history(losshistory)

func_feats = func_space.random(1)
xs = np.linspace(0, 1, num=100)[:, None]
v = func_space.eval_batch(func_feats, xs)[0]
x, t, u_true = solve_ADR(
0,
1,
0,
1,
lambda x: 0.01 * np.ones_like(x),
lambda x: np.zeros_like(x),
lambda u: 0.01 * u**2,
lambda u: 0.02 * u,
lambda x, t: np.tile(v[:, None], (1, len(t))),
lambda x: np.zeros_like(x),
100,
100,
)
u_true = u_true.T
plt.figure()
plt.imshow(u_true)
plt.colorbar()

v_branch = func_space.eval_batch(func_feats, np.linspace(0, 1, num=50)[:, None])[0]
xv, tv = np.meshgrid(x, t)
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
u_pred = model.predict((np.tile(v_branch, (100 * 100, 1)), x_trunk))
u_pred = u_pred.reshape((100, 100))
print(dde.metrics.l2_relative_error(u_true, u_pred))
plt.figure()
plt.imshow(u_pred)
plt.colorbar()
plt.show()

0 comments on commit 0b83d1b

Please sign in to comment.