Skip to content

Commit

Permalink
Backend JAX: Transform bug fix (#1717)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonneted authored Apr 20, 2024
1 parent d4bc99f commit 862b67f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
4 changes: 2 additions & 2 deletions deepxde/nn/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def transform_handling_flat(x):
if isinstance(x, (list, tuple)):
return transform(x)
if x.ndim == 1:
return transform(x.reshape(1, -1)).squeeze()
return transform(x.reshape(1, -1)).reshape(-1)
return transform(x)

self._input_transform = transform_handling_flat
Expand All @@ -36,7 +36,7 @@ def transform_handling_flat(inputs, outputs):
if isinstance(inputs, (list, tuple)):
return transform(inputs, outputs)
if inputs.ndim == 1:
return transform(inputs.reshape(1, -1), outputs.reshape(1, -1)).squeeze()
return transform(inputs.reshape(1, -1), outputs.reshape(1, -1)).reshape(-1)
return transform(inputs, outputs)

self._output_transform = transform_handling_flat
25 changes: 11 additions & 14 deletions examples/pinn_forward/Helmholtz_Dirichlet_2d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle"""
import deepxde as dde
import numpy as np

Expand All @@ -12,24 +12,21 @@
parameters = [1e-3, 3, 150, "sin"]

# Define sine function
if dde.backend.backend_name == "pytorch":
sin = dde.backend.pytorch.sin
elif dde.backend.backend_name == "paddle":
sin = dde.backend.paddle.sin
else:
from deepxde.backend import tf

sin = tf.sin
sin = dde.backend.sin

learning_rate, num_dense_layers, num_dense_nodes, activation = parameters


def pde(x, y):
dy_xx = dde.grad.hessian(y, x, i=0, j=0)
dy_yy = dde.grad.hessian(y, x, i=1, j=1)

f = k0 ** 2 * sin(k0 * x[:, 0:1]) * sin(k0 * x[:, 1:2])
return -dy_xx - dy_yy - k0 ** 2 * y - f
if dde.backend.backend_name == "jax":
y = y[0]
dy_xx = dy_xx[0]
dy_yy = dy_yy[0]

f = k0**2 * sin(k0 * x[:, 0:1]) * sin(k0 * x[:, 1:2])
return -dy_xx - dy_yy - k0**2 * y - f


def func(x):
Expand Down Expand Up @@ -65,10 +62,10 @@ def boundary(_, on_boundary):
geom,
pde,
bc,
num_domain=nx_train ** 2,
num_domain=nx_train**2,
num_boundary=4 * nx_train,
solution=func,
num_test=nx_test ** 2,
num_test=nx_test**2,
)

net = dde.nn.FNN(
Expand Down

0 comments on commit 862b67f

Please sign in to comment.