-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separable-PINN in DeepXDE #1776
base: master
Are you sure you want to change the base?
Conversation
There are too many modifications. We can start with a mathematically equivalent (but slow speed) implementation by repeating the n inputs to n**2. This is similar to DeepONet deepxde/deepxde/nn/tensorflow/deeponet.py Line 17 in ad6399b
vs DeepONetCartesianProd deepxde/deepxde/nn/tensorflow/deeponet.py Line 153 in ad6399b
Then the code change would be very minimal. |
OK, I'll try that when I have more time, hopefully soon. |
This is an implementation of the SPINN model: https://jwcho5576.github.io/spinn.github.io/
The code for the network architecture (
snn.py
) is directly adapted from the original paper (https://github.com/stnamjef/SPINN)I've achieved really fast convergence with this implementation of SPINN compared to PINN (similar to the paper claim), for both forward and inverse quantification on the linear elastic plate problem.
Forward comparison :
https://github.com/lululxvi/deepxde/assets/53513604/499a961d-748c-458f-be99-56156b516ace
Inverse with PINN :
https://github.com/lululxvi/deepxde/assets/53513604/ab89554d-b82b-406a-8d73-05b3f72a3961
Inverse with SPINN :
https://github.com/lululxvi/deepxde/assets/53513604/07171442-ea03-48b4-87a6-8b5094f6809c
The implementation was more complicated than expected for the following reasons:
due to its architecture, SPINN takes an input of size n and outputs an array of size n**dim (it does the cartesian product of each coordinate) :
(n,2) --> SPINN --> n**2
This brings some difficulty with how inputs are handled in
data.pde
.Indeed, all inputs are concatenated (PDE and BCS points) and throw the net simultaneously.
So if we have
n_PDE
PDE points andn_BC
BC points we will end up with(n_PDE+n_BC)**2
points instead ofn_PDE**2+n_BC**2
I tried to find a workaround with minimal changes to
model.py
, and came up with the following:adding a
list_handler
decorator to theoutputs
function in JAX so that it can handle list inputs by applying the function to each input and then concatenates.I then modified the
pde.py
file by adding ais_SPINN
argument, if true, PDE and BC inputs are put together in a list instead of stacked. Thebcs_start
should also be modified as the outputs sizes no longer equal the inputs.I understand that this brings a lot of changes to
data.pde
, so another possibility is to create a separatedata
subclass dedicated to SPINN so that thedata.pde
class isn't overly complicated.