Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix external variable initialization #1775

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

bonneted
Copy link
Contributor

I've faced two bugs when trying to implement :
https://github.com/lu-group/sbinn/blob/b2c1c94d6564732189722f6e6772af0f63cb0d8c/sbinn/sbinn_tf.py#L8

  • in model.py it's because the external variables were not initialized on the second compile as the parameters of the net were already

  • in pde.py if I don't compile with external variables I still want the code to work with the default values of unknowns
    I think this code can be safely modified only for jax because the line after was already only for jax

deepxde/data/pde.py Outdated Show resolved Hide resolved
@lululxvi
Copy link
Owner

Could you point out an example for using this code?

@bonneted
Copy link
Contributor Author

Here : https://github.com/bonneted/sbinn/blob/main/sbinn/sbinn_jax.py

The implementation of sbinn using JAX.
We first train the model without the external variables :

    def ODE(t, y, unknowns=[var.value for var in var_list_]):
    ...

    model.compile("adam", lr=1e-3, loss_weights=[0, 0, 0, 0, 0, 0, 1e-2])
    model.train(epochs=firsttrain, display_every=1000)
    model.compile(
        "adam",
        lr=1e-3,
        loss_weights=[1, 1, 1e-2, 1, 1, 1, 1e-2],
        external_trainable_variables=var_list_,
    )
    variablefilename = "variables.csv"
    variable = dde.callbacks.VariableValue(
        var_list_, period=callbackperiod, filename=variablefilename
    )
    losshistory, train_state = model.train(
        epochs=maxepochs, display_every=1000, callbacks=[variable]
    )

For this first train, we want to use the default unknowns argument for the ODE

@lululxvi
Copy link
Owner

The code modification seems necessary. But there is another example https://github.com/lululxvi/deepxde/blob/master/examples/pinn_inverse/Lorenz_inverse.py , which works well (at least worked earlier).

@bonneted
Copy link
Contributor Author

This one was already working well because there is no pertaining without the external variables.
The model is only compiled with the external trainable variables :

model.compile(
    "adam", lr=0.001, external_trainable_variables=external_trainable_variables
)
losshistory, train_state = model.train(iterations=20000, callbacks=[variable])

The problem occurs when we compile without the external trainable variables, which is when we want the PDE to use the default unknowns argument.

@lululxvi
Copy link
Owner

The code seems OK. But the underlying logic becomes extremely complicated now.

In fact, you can simply add external_trainable_variables in the first compile. As the PDE loss weight is 0, those variables won't get updated any way.

@bonneted
Copy link
Contributor Author

bonneted commented Jul 2, 2024

That's true in that case, but it can be interesting to start training the model with frozen parameters (for example https://doi.org/10.1126/sciadv.abk0644)
Moreover, it would mean that putting default unknowns values for the PDE is useless and misleading as they can never be used.

@lululxvi
Copy link
Owner

lululxvi commented Jul 3, 2024

Please resolve the conflicts.

@bonneted
Copy link
Contributor Author

bonneted commented Jul 8, 2024

I've resolved the conflict based on your improved logic.
In the JAX backend conditional I added the possibility that there are no external trainable variables but a default value available.

if len(aux) == 2:
# External trainable variables in aux[1] are used for unknowns
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
if len(aux) == 1 and has_default_values(self.pde)[-1]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • if --> elif
  • How about the case has_default_values(self.pde)[-1] is False

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants