Let's say I have some linear system of parametric ODEs:
    import torch

    def f(t, y, params):
        a, b, c = params
        dydt = torch.zeros_like(y)
        dydt[0] = a*y[0] - b*y[1]
        dydt[1] = b*y[0] - c*y[1]
        return dydt
How do I pass the parameters to odeint/odeint_adjoint? In scipy.integrate.odeint, this would look like this:
    odeint(f, y0, t, args=(a, b, c))
Can someone please answer this question?
Hi! I'm not one of the developers but I think you can do it this way:
    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from torchdiffeq import odeint, odeint_adjoint

    class ODEfunc(nn.Module):
        def __init__(self, params):
            super().__init__()
            # Store the parameters on the module so forward() can read them
            self.params = params

        def forward(self, t, y):
            a, b, c = self.params
            dydt = torch.zeros_like(y)
            dydt[0] = a*y[0] - b*y[1]
            dydt[1] = b*y[0] - c*y[1]
            return dydt

    time = torch.linspace(0.0, 10.0, 100)
    params = torch.tensor([1.0, 2.0, 3.0])
    y0 = torch.tensor([1.5, 0.25])

    func = ODEfunc(params)
    result = odeint(func, y0, time)
    result_adjoint = odeint_adjoint(func, y0, time)

    # Lines: odeint solution; markers: odeint_adjoint solution
    plt.plot(time, result[:, 0], color='tab:blue', zorder=0, label="odeint")
    plt.scatter(time, result_adjoint[:, 0], color='tab:blue')
    plt.plot(time, result[:, 1], color='tab:red', zorder=0)
    plt.scatter(time, result_adjoint[:, 1], color='tab:red')
    plt.xlabel('Time')
    plt.ylabel('Y')
    plt.show()
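Output: a plot of both components of y against time, with the odeint solution drawn as lines and the odeint_adjoint solution as markers on top (image not included).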