-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix external variable initialization #1775
base: master
Are you sure you want to change the base?
Conversation
Could you point out an example for using this code? |
Here : https://github.com/bonneted/sbinn/blob/main/sbinn/sbinn_jax.py The implementation of sbinn using JAX.
For this first train, we want to use the default |
The code modification seems necessary. But there is another example https://github.com/lululxvi/deepxde/blob/master/examples/pinn_inverse/Lorenz_inverse.py , which works well (at least worked earlier). |
This one was already working well because there is no pertaining without the external variables.
The problem occurs when we compile without the external trainable variables, which is when we want the PDE to use the default |
The code seems OK. But the underlying logic becomes extremely complicated now. In fact, you can simply add |
That's true in that case, but it can be interesting to start training the model with frozen parameters (for example https://doi.org/10.1126/sciadv.abk0644) |
Please resolve the conflicts. |
I've resolved the conflict based on your improved logic. |
if len(aux) == 2: | ||
# External trainable variables in aux[1] are used for unknowns | ||
f = self.pde(inputs, outputs_pde, unknowns=aux[1]) | ||
if len(aux) == 1 and has_default_values(self.pde)[-1]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if
-->elif
- How about the case
has_default_values(self.pde)[-1]
is False
I've faced two bugs when trying to implement :
https://github.com/lu-group/sbinn/blob/b2c1c94d6564732189722f6e6772af0f63cb0d8c/sbinn/sbinn_tf.py#L8
in
model.py
it's because the external variables were not initialized on the second compile as the parameters of the net were alreadyin
pde.py
if I don't compile with external variables I still want the code to work with the default values of unknownsI think this code can be safely modified only for jax because the line after was already only for jax