From 153b2422af9fa07d43d7539b89c4603d9c085da9 Mon Sep 17 00:00:00 2001 From: Zymrael Date: Sun, 28 Mar 2021 19:37:34 +0900 Subject: [PATCH 1/3] add classic dp5 tableau and test --- tests/DETEST/run.py | 4 +-- torchdiffeq/_impl/dopri5.py | 54 ++++++++++++++++++++++++++----------- torchdiffeq/_impl/odeint.py | 3 ++- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/tests/DETEST/run.py b/tests/DETEST/run.py index 8f3e9ca9c..e95748acf 100644 --- a/tests/DETEST/run.py +++ b/tests/DETEST/run.py @@ -22,8 +22,8 @@ def __call__(self, t, y): def main(): sol = dict() - for method in ['dopri5', 'adams']: - for tol in [1e-3, 1e-6, 1e-9]: + for method in ['dopri5', 'classic_dopri5']: + for tol in [1e-3, 1e-4, 1e-4, 1e-5, 1e-6]: print('======= {} | tol={:e} ======='.format(method, tol)) nfes = [] times = [] diff --git a/torchdiffeq/_impl/dopri5.py b/torchdiffeq/_impl/dopri5.py index 1a925ef1d..bcc610632 100644 --- a/torchdiffeq/_impl/dopri5.py +++ b/torchdiffeq/_impl/dopri5.py @@ -1,18 +1,39 @@ import torch from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver +alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64) +beta=[ + torch.tensor([1 / 5], dtype=torch.float64), + torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), + torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), + torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), + torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), + torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), +] +c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64) + + +_CLASSIC_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( + alpha=alpha, beta=beta, c_sol=c_sol, + c_error=torch.tensor([ + 35 / 384 - 5179 / 57600, + 0, + 500 / 1113 - 7571 / 16695, + 125 / 192 - 393 / 640, + -2187 / 6784 - -92097 / 339200, + 11 / 84 - 187 / 2100, + -1 / 40, + ], dtype=torch.float64), +) + +DPS_C_MID = torch.tensor([ + 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, + 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 +], dtype=torch.float64) + _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( - alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64), - beta=[ - torch.tensor([1 / 5], dtype=torch.float64), - torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), - torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), - torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), - torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), - torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), - ], - c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64), + alpha=alpha, beta=beta, c_sol=c_sol, c_error=torch.tensor([ 35 / 384 - 1951 / 21600, 0, @@ -24,13 +45,16 @@ ], dtype=torch.float64), ) -DPS_C_MID = torch.tensor([ - 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, - 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 -], dtype=torch.float64) - class Dopri5Solver(RKAdaptiveStepsizeODESolver): order = 5 tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU mid = DPS_C_MID + + +class ClassicDopri5Solver(RKAdaptiveStepsizeODESolver): + order = 5 + tableau = _CLASSIC_DORMAND_PRINCE_SHAMPINE_TABLEAU + mid = DPS_C_MID + + diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index a174219ad..2b47eaf48 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -1,6 +1,6 @@ import torch from torch.autograd.functional import vjp -from .dopri5 import Dopri5Solver +from .dopri5 import ClassicDopri5Solver, Dopri5Solver from .bosh3 import Bosh3Solver from .adaptive_heun import AdaptiveHeunSolver from .fehlberg2 import Fehlberg2 @@ -13,6 +13,7 @@ SOLVERS = { 'dopri8': Dopri8Solver, 'dopri5': Dopri5Solver, + 'classic_dopri5': ClassicDopri5Solver, 'bosh3': Bosh3Solver, 'fehlberg2': Fehlberg2, 'adaptive_heun': AdaptiveHeunSolver, From 6de13ec00dd086977091405d94b4dfba93173252 Mon Sep 17 00:00:00 2001 From: Zymrael Date: Sun, 28 Mar 2021 19:38:33 +0900 Subject: [PATCH 2/3] fix tol in DETEST --- .../.ipynb_checkpoints/detest-checkpoint.py | 334 ++++++++++++++++++ .../.ipynb_checkpoints/run-checkpoint.py | 59 ++++ tests/DETEST/run.py | 2 +- 3 files changed, 394 insertions(+), 1 deletion(-) create mode 100644 tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py create mode 100644 tests/DETEST/.ipynb_checkpoints/run-checkpoint.py diff --git a/tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py b/tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py new file mode 100644 index 000000000..328d5384a --- /dev/null +++ b/tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py @@ -0,0 +1,334 @@ +import math +import torch + + +#################################### +# Problem Class A. Single equations. +#################################### +def A1(): + diffeq = lambda t, y: -y + init = lambda: (torch.tensor(0.), torch.tensor(1.)) + solution = lambda t: torch.exp(-t) + return diffeq, init, solution + + +def A2(): + diffeq = lambda t, y: -y**3 / 2 + init = lambda: (torch.tensor(0.), torch.tensor(1.)) + solution = lambda t: 1 / torch.sqrt(t + 1) + return diffeq, init, solution + + +def A3(): + diffeq = lambda t, y: y * torch.cos(t) + init = lambda: (torch.tensor(0.), torch.tensor(1.)) + solution = lambda t: torch.exp(torch.sin(t)) + return diffeq, init, solution + + +def A4(): + diffeq = lambda t, y: y / 4 * (1 - y / 20) + init = lambda: (torch.tensor(0.), torch.tensor(1.)) + solution = lambda t: 20 / (1 + 19 * torch.exp(-t / 4)) + return diffeq, init, solution + + +def A5(): + diffeq = lambda t, y: (y - t) / (y + t) + init = lambda: (torch.tensor(0.), torch.tensor(4.)) + return diffeq, init, None + + +################################# +# Problem Class B. Small systems. +################################# +def B1(): + + def diffeq(t, y): + dy0 = 2 * (y[0] - y[0] * y[1]) + dy1 = -(y[1] - y[0] * y[1]) + return torch.stack([dy0, dy1]) + + def init(): + return torch.tensor(0.), torch.tensor([1., 3.]) + + return diffeq, init, None + + +def B2(): + + A = torch.tensor([[-1., 1., 0.], [1., -2., 1.], [0., 1., -1.]]) + + def diffeq(t, y): + dy = torch.mv(A, y) + return dy + + def init(): + return torch.tensor(0.), torch.tensor([2., 0., 1.]) + + return diffeq, init, None + + +def B3(): + + def diffeq(t, y): + dy0 = -y[0] + dy1 = y[0] - y[1] * y[1] + dy2 = y[1] * y[1] + return torch.stack([dy0, dy1, dy2]) + + def init(): + return torch.tensor(0.), torch.tensor([1., 0., 0.]) + + return diffeq, init, None + + +def B4(): + + def diffeq(t, y): + a = torch.sqrt(y[0] * y[0] + y[1] * y[1]) + dy0 = -y[1] - y[0] * y[2] / a + dy1 = y[0] - y[1] * y[2] / a + dy2 = y[0] / a + return torch.stack([dy0, dy1, dy2]) + + def init(): + return torch.tensor(0.), torch.tensor([3., 0., 0.]) + + return diffeq, init, None + + +def B5(): + + def diffeq(t, y): + dy0 = y[1] * y[2] + dy1 = -y[0] * y[2] + dy2 = -0.51 * y[0] * y[1] + return torch.stack([dy0, dy1, dy2]) + + def init(): + return torch.tensor(0.), torch.tensor([0., 1., 1.]) + + return diffeq, init, None + + +#################################### +# Problem Class C. Moderate systems. +#################################### +def C1(): + + A = torch.zeros(10, 10) + A.view(-1)[:-1:11] = -1 + A.view(-1)[10::11] = 1 + + def diffeq(t, y): + return torch.mv(A, y) + + def init(): + y0 = torch.zeros(10) + y0[0] = 1 + return torch.tensor(0.), y0 + + return diffeq, init, None + + +def C2(): + + A = torch.zeros(10, 10) + A.view(-1)[:-1:11] = torch.linspace(-1, -9, 9) + A.view(-1)[10::11] = torch.linspace(1, 9, 9) + + def diffeq(t, y): + return torch.mv(A, y) + + def init(): + y0 = torch.zeros(10) + y0[0] = 1 + return torch.tensor(0.), y0 + + return diffeq, init, None + + +def C3(): + n = 10 + A = torch.zeros(n, n) + A.view(-1)[::n + 1] = -2 + A.view(-1)[n::n + 1] = 1 + A.view(-1)[1::n + 1] = 1 + + def diffeq(t, y): + return torch.mv(A, y) + + def init(): + y0 = torch.zeros(n) + y0[0] = 1 + return torch.tensor(0.), y0 + + return diffeq, init, None + + +def C4(): + n = 51 + A = torch.zeros(n, n) + A.view(-1)[::n + 1] = -2 + A.view(-1)[n::n + 1] = 1 + A.view(-1)[1::n + 1] = 1 + + def diffeq(t, y): + return torch.mv(A, y) + + def init(): + y0 = torch.zeros(n) + y0[0] = 1 + return torch.tensor(0.), y0 + + return diffeq, init, None + + +def C5(): + + k2 = torch.tensor(2.95912208286) + m0 = torch.tensor(1.00000597682) + m = torch.tensor([ + 0.000954786104043, + 0.000285583733151, + 0.0000437273164546, + 0.0000517759138449, + 0.00000277777777778, + ]).view(1, 5) + + def diffeq(t, y): + # y is 2 x 3 x 5 + # y[0] contains y, y[0] contains y' + # second axis indexes space (x,y,z). + # third axis indexes 5 bodies. + + dy = y[1, :, :] + y = y[0] + r = torch.sqrt(torch.sum(y**2, 0)).view(1, 5) + d = torch.sqrt(torch.sum((y[:, :, None] - y[:, None, :])**2, 0)) + F = m.view(1, 1, 5) * ((y[:, None, :] - y[:, :, None]) / (d * d * d).view(1, 5, 5) + y.view(3, 1, 5) / + (r * r * r).view(1, 1, 5)) + F.view(3, 5 * 5)[:, ::6] = 0 + ddy = k2 * (-(m0 + m) * y / (r * r * r)) + F.sum(2) + return torch.stack([dy, ddy], 0) + + def init(): + y0 = torch.tensor([ + 3.42947415189, 3.35386959711, 1.35494901715, 6.64145542550, 5.97156957878, 2.18231499728, 11.2630437207, + 14.6952576794, 6.27960525067, -30.1552268759, 165699966404, 1.43785752721, -21.1238353380, 28.4465098142, + 15.388265967 + ]).view(5, 3).transpose(0, 1) + + dy0 = torch.tensor([ + -.557160570446, .505696783289, .230578543901, -.415570776342, .365682722812, .169143213293, -.325325669158, + .189706021964, .0877265322780, -.0240476254170, -.287659532608, -.117219543175, -.176860753121, + -.216393453025, -.0148647893090 + ]).view(5, 3).transpose(0, 1) + + return torch.tensor(0.), torch.stack([y0, dy0], 0) + + return diffeq, init, None + + +################################### +# Problem Class D. Orbit equations. +################################### +def _DTemplate(eps): + + def diffeq(t, y): + r = (y[0]**2 + y[1]**2)**(3 / 2) + dy0 = y[2] + dy1 = y[3] + dy2 = -y[0] / r + dy3 = -y[1] / r + return torch.stack([dy0, dy1, dy2, dy3]) + + def init(): + return torch.tensor(0.), torch.tensor([1 - eps, 0, 0, math.sqrt((1 + eps) / (1 - eps))]) + + return diffeq, init, None + + +D1 = lambda: _DTemplate(0.1) +D2 = lambda: _DTemplate(0.3) +D3 = lambda: _DTemplate(0.5) +D4 = lambda: _DTemplate(0.7) +D5 = lambda: _DTemplate(0.9) + + +########################################## +# Problem Class E. Higher order equations. +########################################## +def E1(): + + def diffeq(t, y): + dy0 = y[1] + dy1 = -(y[1] / (t + 1) + (1 - 0.25 / (t + 1)**2) * y[0]) + return torch.stack([dy0, dy1]) + + def init(): + return torch.tensor(0.), torch.tensor([.671396707141803, .0954005144474744]) + + return diffeq, init, None + + +def E2(): + + def diffeq(t, y): + dy0 = y[1] + dy1 = (1 - y[0]**2) * y[1] - y[0] + return torch.stack([dy0, dy1]) + + def init(): + return torch.tensor(0.), torch.tensor([2., 0.]) + + return diffeq, init, None + + +def E3(): + + def diffeq(t, y): + dy0 = y[1] + dy1 = y[0]**3 / 6 - y[0] + 2 * torch.sin(2.78535 * t) + return torch.stack([dy0, dy1]) + + def init(): + return torch.tensor(0.), torch.tensor([0., 0.]) + + return diffeq, init, None + + +def E4(): + + def diffeq(t, y): + dy0 = y[1] + dy1 = .32 - .4 * y[1]**2 + return torch.stack([dy0, dy1]) + + def init(): + return torch.tensor(0.), torch.tensor([30., 0.]) + + return diffeq, init, None + + +def E5(): + + def diffeq(t, y): + dy0 = y[1] + dy1 = torch.sqrt(1 + y[1]**2) / (25 - t) + return torch.stack([dy0, dy1]) + + def init(): + return torch.tensor(0.), torch.tensor([0., 0.]) + + return diffeq, init, None + + +################### +# Helper functions. +################### +def _to_tensor(x): + if not torch.is_tensor(x): + x = torch.tensor(x) + return x diff --git a/tests/DETEST/.ipynb_checkpoints/run-checkpoint.py b/tests/DETEST/.ipynb_checkpoints/run-checkpoint.py new file mode 100644 index 000000000..0842f53aa --- /dev/null +++ b/tests/DETEST/.ipynb_checkpoints/run-checkpoint.py @@ -0,0 +1,59 @@ +import time +import numpy as np +from scipy.stats.mstats import gmean +import torch +from torchdiffeq import odeint +import detest + +torch.set_default_tensor_type(torch.DoubleTensor) + + +class NFEDiffEq: + + def __init__(self, diffeq): + self.diffeq = diffeq + self.nfe = 0 + + def __call__(self, t, y): + self.nfe += 1 + return self.diffeq(t, y) + + +def main(): + + sol = dict() + for method in ['dopri5', 'classic_dopri5']: + for tol in [1e-3, 1e-4, 1e-5, 1e-6]: + print('======= {} | tol={:e} ======='.format(method, tol)) + nfes = [] + times = [] + errs = [] + for c in ['A', 'B', 'C', 'D', 'E']: + for i in ['1', '2', '3', '4', '5']: + diffeq, init, _ = getattr(detest, c + i)() + t0, y0 = init() + diffeq = NFEDiffEq(diffeq) + + if not c + i in sol: + sol[c + i] = odeint( + diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=1e-12, rtol=1e-12, method='dopri5' + )[1] + diffeq.nfe = 0 + + start_time = time.time() + est = odeint(diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=tol, rtol=tol, method=method) + time_spent = time.time() - start_time + + error = torch.sqrt(torch.mean((sol[c + i] - est[1])**2)) + + errs.append(error.item()) + nfes.append(diffeq.nfe) + times.append(time_spent) + + print('{}: NFE {} | Time {} | Err {:e}'.format(c + i, diffeq.nfe, time_spent, error.item())) + + print('Total NFE {} | Total Time {} | GeomAvg Error {:e}'.format(np.sum(nfes), np.sum(times), gmean(errs))) + + +if __name__ == '__main__': + main() diff --git a/tests/DETEST/run.py b/tests/DETEST/run.py index e95748acf..0842f53aa 100644 --- a/tests/DETEST/run.py +++ b/tests/DETEST/run.py @@ -23,7 +23,7 @@ def main(): sol = dict() for method in ['dopri5', 'classic_dopri5']: - for tol in [1e-3, 1e-4, 1e-4, 1e-5, 1e-6]: + for tol in [1e-3, 1e-4, 1e-5, 1e-6]: print('======= {} | tol={:e} ======='.format(method, tol)) nfes = [] times = [] From 4bd6a7a57fc895fb56ad59c67924ec08baaf221e Mon Sep 17 00:00:00 2001 From: Zymrael Date: Sun, 28 Mar 2021 19:42:27 +0900 Subject: [PATCH 3/3] rm cache --- .../.ipynb_checkpoints/detest-checkpoint.py | 334 ------------------ .../.ipynb_checkpoints/run-checkpoint.py | 59 ---- 2 files changed, 393 deletions(-) delete mode 100644 tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py delete mode 100644 tests/DETEST/.ipynb_checkpoints/run-checkpoint.py diff --git a/tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py b/tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py deleted file mode 100644 index 328d5384a..000000000 --- a/tests/DETEST/.ipynb_checkpoints/detest-checkpoint.py +++ /dev/null @@ -1,334 +0,0 @@ -import math -import torch - - -#################################### -# Problem Class A. Single equations. -#################################### -def A1(): - diffeq = lambda t, y: -y - init = lambda: (torch.tensor(0.), torch.tensor(1.)) - solution = lambda t: torch.exp(-t) - return diffeq, init, solution - - -def A2(): - diffeq = lambda t, y: -y**3 / 2 - init = lambda: (torch.tensor(0.), torch.tensor(1.)) - solution = lambda t: 1 / torch.sqrt(t + 1) - return diffeq, init, solution - - -def A3(): - diffeq = lambda t, y: y * torch.cos(t) - init = lambda: (torch.tensor(0.), torch.tensor(1.)) - solution = lambda t: torch.exp(torch.sin(t)) - return diffeq, init, solution - - -def A4(): - diffeq = lambda t, y: y / 4 * (1 - y / 20) - init = lambda: (torch.tensor(0.), torch.tensor(1.)) - solution = lambda t: 20 / (1 + 19 * torch.exp(-t / 4)) - return diffeq, init, solution - - -def A5(): - diffeq = lambda t, y: (y - t) / (y + t) - init = lambda: (torch.tensor(0.), torch.tensor(4.)) - return diffeq, init, None - - -################################# -# Problem Class B. Small systems. -################################# -def B1(): - - def diffeq(t, y): - dy0 = 2 * (y[0] - y[0] * y[1]) - dy1 = -(y[1] - y[0] * y[1]) - return torch.stack([dy0, dy1]) - - def init(): - return torch.tensor(0.), torch.tensor([1., 3.]) - - return diffeq, init, None - - -def B2(): - - A = torch.tensor([[-1., 1., 0.], [1., -2., 1.], [0., 1., -1.]]) - - def diffeq(t, y): - dy = torch.mv(A, y) - return dy - - def init(): - return torch.tensor(0.), torch.tensor([2., 0., 1.]) - - return diffeq, init, None - - -def B3(): - - def diffeq(t, y): - dy0 = -y[0] - dy1 = y[0] - y[1] * y[1] - dy2 = y[1] * y[1] - return torch.stack([dy0, dy1, dy2]) - - def init(): - return torch.tensor(0.), torch.tensor([1., 0., 0.]) - - return diffeq, init, None - - -def B4(): - - def diffeq(t, y): - a = torch.sqrt(y[0] * y[0] + y[1] * y[1]) - dy0 = -y[1] - y[0] * y[2] / a - dy1 = y[0] - y[1] * y[2] / a - dy2 = y[0] / a - return torch.stack([dy0, dy1, dy2]) - - def init(): - return torch.tensor(0.), torch.tensor([3., 0., 0.]) - - return diffeq, init, None - - -def B5(): - - def diffeq(t, y): - dy0 = y[1] * y[2] - dy1 = -y[0] * y[2] - dy2 = -0.51 * y[0] * y[1] - return torch.stack([dy0, dy1, dy2]) - - def init(): - return torch.tensor(0.), torch.tensor([0., 1., 1.]) - - return diffeq, init, None - - -#################################### -# Problem Class C. Moderate systems. -#################################### -def C1(): - - A = torch.zeros(10, 10) - A.view(-1)[:-1:11] = -1 - A.view(-1)[10::11] = 1 - - def diffeq(t, y): - return torch.mv(A, y) - - def init(): - y0 = torch.zeros(10) - y0[0] = 1 - return torch.tensor(0.), y0 - - return diffeq, init, None - - -def C2(): - - A = torch.zeros(10, 10) - A.view(-1)[:-1:11] = torch.linspace(-1, -9, 9) - A.view(-1)[10::11] = torch.linspace(1, 9, 9) - - def diffeq(t, y): - return torch.mv(A, y) - - def init(): - y0 = torch.zeros(10) - y0[0] = 1 - return torch.tensor(0.), y0 - - return diffeq, init, None - - -def C3(): - n = 10 - A = torch.zeros(n, n) - A.view(-1)[::n + 1] = -2 - A.view(-1)[n::n + 1] = 1 - A.view(-1)[1::n + 1] = 1 - - def diffeq(t, y): - return torch.mv(A, y) - - def init(): - y0 = torch.zeros(n) - y0[0] = 1 - return torch.tensor(0.), y0 - - return diffeq, init, None - - -def C4(): - n = 51 - A = torch.zeros(n, n) - A.view(-1)[::n + 1] = -2 - A.view(-1)[n::n + 1] = 1 - A.view(-1)[1::n + 1] = 1 - - def diffeq(t, y): - return torch.mv(A, y) - - def init(): - y0 = torch.zeros(n) - y0[0] = 1 - return torch.tensor(0.), y0 - - return diffeq, init, None - - -def C5(): - - k2 = torch.tensor(2.95912208286) - m0 = torch.tensor(1.00000597682) - m = torch.tensor([ - 0.000954786104043, - 0.000285583733151, - 0.0000437273164546, - 0.0000517759138449, - 0.00000277777777778, - ]).view(1, 5) - - def diffeq(t, y): - # y is 2 x 3 x 5 - # y[0] contains y, y[0] contains y' - # second axis indexes space (x,y,z). - # third axis indexes 5 bodies. - - dy = y[1, :, :] - y = y[0] - r = torch.sqrt(torch.sum(y**2, 0)).view(1, 5) - d = torch.sqrt(torch.sum((y[:, :, None] - y[:, None, :])**2, 0)) - F = m.view(1, 1, 5) * ((y[:, None, :] - y[:, :, None]) / (d * d * d).view(1, 5, 5) + y.view(3, 1, 5) / - (r * r * r).view(1, 1, 5)) - F.view(3, 5 * 5)[:, ::6] = 0 - ddy = k2 * (-(m0 + m) * y / (r * r * r)) + F.sum(2) - return torch.stack([dy, ddy], 0) - - def init(): - y0 = torch.tensor([ - 3.42947415189, 3.35386959711, 1.35494901715, 6.64145542550, 5.97156957878, 2.18231499728, 11.2630437207, - 14.6952576794, 6.27960525067, -30.1552268759, 165699966404, 1.43785752721, -21.1238353380, 28.4465098142, - 15.388265967 - ]).view(5, 3).transpose(0, 1) - - dy0 = torch.tensor([ - -.557160570446, .505696783289, .230578543901, -.415570776342, .365682722812, .169143213293, -.325325669158, - .189706021964, .0877265322780, -.0240476254170, -.287659532608, -.117219543175, -.176860753121, - -.216393453025, -.0148647893090 - ]).view(5, 3).transpose(0, 1) - - return torch.tensor(0.), torch.stack([y0, dy0], 0) - - return diffeq, init, None - - -################################### -# Problem Class D. Orbit equations. -################################### -def _DTemplate(eps): - - def diffeq(t, y): - r = (y[0]**2 + y[1]**2)**(3 / 2) - dy0 = y[2] - dy1 = y[3] - dy2 = -y[0] / r - dy3 = -y[1] / r - return torch.stack([dy0, dy1, dy2, dy3]) - - def init(): - return torch.tensor(0.), torch.tensor([1 - eps, 0, 0, math.sqrt((1 + eps) / (1 - eps))]) - - return diffeq, init, None - - -D1 = lambda: _DTemplate(0.1) -D2 = lambda: _DTemplate(0.3) -D3 = lambda: _DTemplate(0.5) -D4 = lambda: _DTemplate(0.7) -D5 = lambda: _DTemplate(0.9) - - -########################################## -# Problem Class E. Higher order equations. -########################################## -def E1(): - - def diffeq(t, y): - dy0 = y[1] - dy1 = -(y[1] / (t + 1) + (1 - 0.25 / (t + 1)**2) * y[0]) - return torch.stack([dy0, dy1]) - - def init(): - return torch.tensor(0.), torch.tensor([.671396707141803, .0954005144474744]) - - return diffeq, init, None - - -def E2(): - - def diffeq(t, y): - dy0 = y[1] - dy1 = (1 - y[0]**2) * y[1] - y[0] - return torch.stack([dy0, dy1]) - - def init(): - return torch.tensor(0.), torch.tensor([2., 0.]) - - return diffeq, init, None - - -def E3(): - - def diffeq(t, y): - dy0 = y[1] - dy1 = y[0]**3 / 6 - y[0] + 2 * torch.sin(2.78535 * t) - return torch.stack([dy0, dy1]) - - def init(): - return torch.tensor(0.), torch.tensor([0., 0.]) - - return diffeq, init, None - - -def E4(): - - def diffeq(t, y): - dy0 = y[1] - dy1 = .32 - .4 * y[1]**2 - return torch.stack([dy0, dy1]) - - def init(): - return torch.tensor(0.), torch.tensor([30., 0.]) - - return diffeq, init, None - - -def E5(): - - def diffeq(t, y): - dy0 = y[1] - dy1 = torch.sqrt(1 + y[1]**2) / (25 - t) - return torch.stack([dy0, dy1]) - - def init(): - return torch.tensor(0.), torch.tensor([0., 0.]) - - return diffeq, init, None - - -################### -# Helper functions. -################### -def _to_tensor(x): - if not torch.is_tensor(x): - x = torch.tensor(x) - return x diff --git a/tests/DETEST/.ipynb_checkpoints/run-checkpoint.py b/tests/DETEST/.ipynb_checkpoints/run-checkpoint.py deleted file mode 100644 index 0842f53aa..000000000 --- a/tests/DETEST/.ipynb_checkpoints/run-checkpoint.py +++ /dev/null @@ -1,59 +0,0 @@ -import time -import numpy as np -from scipy.stats.mstats import gmean -import torch -from torchdiffeq import odeint -import detest - -torch.set_default_tensor_type(torch.DoubleTensor) - - -class NFEDiffEq: - - def __init__(self, diffeq): - self.diffeq = diffeq - self.nfe = 0 - - def __call__(self, t, y): - self.nfe += 1 - return self.diffeq(t, y) - - -def main(): - - sol = dict() - for method in ['dopri5', 'classic_dopri5']: - for tol in [1e-3, 1e-4, 1e-5, 1e-6]: - print('======= {} | tol={:e} ======='.format(method, tol)) - nfes = [] - times = [] - errs = [] - for c in ['A', 'B', 'C', 'D', 'E']: - for i in ['1', '2', '3', '4', '5']: - diffeq, init, _ = getattr(detest, c + i)() - t0, y0 = init() - diffeq = NFEDiffEq(diffeq) - - if not c + i in sol: - sol[c + i] = odeint( - diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=1e-12, rtol=1e-12, method='dopri5' - )[1] - diffeq.nfe = 0 - - start_time = time.time() - est = odeint(diffeq, y0, torch.stack([t0, torch.tensor(20.)]), atol=tol, rtol=tol, method=method) - time_spent = time.time() - start_time - - error = torch.sqrt(torch.mean((sol[c + i] - est[1])**2)) - - errs.append(error.item()) - nfes.append(diffeq.nfe) - times.append(time_spent) - - print('{}: NFE {} | Time {} | Err {:e}'.format(c + i, diffeq.nfe, time_spent, error.item())) - - print('Total NFE {} | Total Time {} | GeomAvg Error {:e}'.format(np.sum(nfes), np.sum(times), gmean(errs))) - - -if __name__ == '__main__': - main()