Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separable-PINN in DeepXDE #1776

Draft
wants to merge 35 commits into
base: master
Choose a base branch
from
Draft

Conversation

bonneted
Copy link
Contributor

This is an implementation of the SPINN model: https://jwcho5576.github.io/spinn.github.io/

The code for the network architecture (snn.py) is directly adapted from the original paper (https://github.com/stnamjef/SPINN)

I've achieved really fast convergence with this implementation of SPINN compared to PINN (similar to the paper claim), for both forward and inverse quantification on the linear elastic plate problem.
Forward comparison :
https://github.com/lululxvi/deepxde/assets/53513604/499a961d-748c-458f-be99-56156b516ace

Inverse with PINN :
https://github.com/lululxvi/deepxde/assets/53513604/ab89554d-b82b-406a-8d73-05b3f72a3961

Inverse with SPINN :
https://github.com/lululxvi/deepxde/assets/53513604/07171442-ea03-48b4-87a6-8b5094f6809c

The implementation was more complicated than expected for the following reasons:
due to its architecture, SPINN takes an input of size n and outputs an array of size n**dim (it does the cartesian product of each coordinate) :
(n,2) --> SPINN --> n**2

This brings some difficulty with how inputs are handled in data.pde.
Indeed, all inputs are concatenated (PDE and BCS points) and throw the net simultaneously.
So if we have n_PDE PDE points and n_BC BC points we will end up with (n_PDE+n_BC)**2 points instead of n_PDE**2+n_BC**2

I tried to find a workaround with minimal changes to model.py, and came up with the following:
adding a list_handler decorator to the outputs function in JAX so that it can handle list inputs by applying the function to each input and then concatenates.

I then modified the pde.py file by adding a is_SPINN argument, if true, PDE and BC inputs are put together in a list instead of stacked. The bcs_start should also be modified as the outputs sizes no longer equal the inputs.

I understand that this brings a lot of changes to data.pde, so another possibility is to create a separate data subclass dedicated to SPINN so that the data.pde class isn't overly complicated.

@lululxvi
Copy link
Owner

lululxvi commented Jun 17, 2024

There are too many modifications. We can start with a mathematically equivalent (but slow speed) implementation by repeating the n inputs to n**2. This is similar to DeepONet

class DeepONet(NN):

vs DeepONetCartesianProd

class DeepONetCartesianProd(NN):

Then the code change would be very minimal.

@bonneted
Copy link
Contributor Author

bonneted commented Jul 2, 2024

OK, I'll try that when I have more time, hopefully soon.

@bonneted bonneted marked this pull request as draft July 2, 2024 09:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants