-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolver.py
174 lines (126 loc) · 5.1 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import numpy as np
import time
from scipy.integrate import solve_ivp
from my_typing import *
try:
    import jax
    import jax.numpy as jnp
    from jax import config
except ImportError as e:
    # JAX is optional: when it cannot be imported, alias `jnp` to NumPy so that
    # `jnp.*` calls elsewhere resolve to NumPy functions instead.
    print('Jax not installed. Falling back to numpy')
    print(e)
    jax = None
    jnp = np
    is_jax_installed = False
else:
    # Turn on JAX's 64-bit (double precision) mode.
    config.update("jax_enable_x64", True)
    is_jax_installed = True
from RK import RK
def calc_time(f: Callable):
    """Decorate ``f`` so that every call also reports how long it took.

    The wrapped callable accepts the same arguments as ``f`` but returns a
    ``(result, elapsed_seconds)`` tuple instead of just ``result``.
    ``functools.wraps`` keeps ``f``'s name and docstring on the wrapper.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted, corrupting the measurement.
        s_time = time.perf_counter()
        ret = f(*args, **kwargs)
        el_time = time.perf_counter() - s_time
        return ret, el_time

    return wrapper
class SolverAbstr:
    '''
    Abstract interface for a pair of ODE propagators, F (fine) and G (coarse).

    run_F / run_G must return the ODE solution at time t1, given an initial
    condition u0 at time t0. run_F_full / run_G_full integrate over the same
    interval but are expected to return the solution at intermediate times as
    well (see the subclasses for the exact output layout). F is expected to be
    slower but more accurate than G.

    Every ``*_timed`` method forwards its arguments to the matching untimed
    method and returns a ``(result, elapsed_seconds)`` tuple.
    '''

    def run_F(self, t0: float, t1: float, u0: np.ndarray):
        """Return the fine-solver solution at t1 starting from u0 at t0."""
        raise NotImplementedError('run_F not implemented')

    # Uses the shared calc_time decorator so all *_timed methods measure time
    # the same way (this one used to duplicate the timing code inline).
    @calc_time
    def run_F_timed(self, t0: float, t1: float, u0: np.ndarray, *args, **kwargs):
        return self.run_F(t0, t1, u0, *args, **kwargs)

    def run_F_full(self, t0: float, t1: float, u0: np.ndarray):
        """Return the fine-solver trajectory from t0 to t1 starting at u0."""
        raise NotImplementedError('run_F_full not implemented')

    @calc_time
    def run_F_full_timed(self, t0: float, t1: float, u0: np.ndarray, *args, **kwargs):
        # Returns (run_F_full result, elapsed_seconds) once decorated.
        return self.run_F_full(t0, t1, u0, *args, **kwargs)

    def run_G(self, t0: float, t1: float, u0: np.ndarray):
        """Return the coarse-solver solution at t1 starting from u0 at t0."""
        raise NotImplementedError('run_G not implemented')

    @calc_time
    def run_G_timed(self, t0, t1, u0, *args, **kwargs):
        return self.run_G(t0, t1, u0, *args, **kwargs)

    def run_G_full(self, t0, t1, u0):
        """Return the coarse-solver trajectory from t0 to t1 starting at u0."""
        raise NotImplementedError('run_G_full not implemented')

    @calc_time
    def run_G_full_timed(self, t0, t1, u0, *args, **kwargs):
        return self.run_G_full(t0, t1, u0, *args, **kwargs)
class SolverRK(SolverAbstr):
    '''
    Solver whose fine (F) and coarse (G) propagators are both instances of the
    project's RK integrator, run with Nf and Ng steps respectively.

    ``thresh`` caps how many steps are handed to the integrator in a single
    call by run_F / run_G: above it, the interval is integrated in consecutive
    chunks (presumably to bound memory use -- TODO confirm).
    '''

    def __init__(self, f, Ng, Nf, F, G, thresh=1e7, use_jax=True, **kwargs):
        self.f = f
        self.Ng = int(Ng)
        self.Nf = int(Nf)
        self.F = F
        self.G = G
        # Kept as given (may be a float such as the default 1e7, or inf to
        # disable paging); _run_RK_paged converts it to an int chunk size.
        self.thresh = thresh
        self.RK_F = RK(f, F, use_jax)
        self.RK_G = RK(f, G, use_jax)

    def _run_RK_paged(self, t0, t1, u0, steps, solver):
        '''
        Integrate from t0 to t1 with ``steps`` steps of ``solver`` and return
        the state at t1. If ``steps`` exceeds ``self.thresh``, the interval is
        split into consecutive chunks of at most ``thresh`` intervals each.
        '''
        thresh = self.thresh
        if steps <= thresh:
            return solver.run_get_last(t0, t1, steps, u0)

        # NOTE(review): `steps` is treated as a number of grid points, i.e.
        # steps - 1 intervals, so a chunk of `chunk` intervals is handed
        # `chunk + 1` points. Confirm against RK.run_get_last.
        chunk_cap = max(1, int(thresh))
        n_intervals = steps - 1
        dt = (t1 - t0) / n_intervals
        n_full, remainder = divmod(n_intervals, chunk_cap)
        chunks = [chunk_cap] * n_full + ([remainder] if remainder else [])

        done = 0
        t_start = t0
        for chunk in chunks:
            done += chunk
            # Compute each end time from t0 (and land exactly on t1) rather
            # than accumulating floating-point round-off chunk by chunk.
            t_end = t1 if done == n_intervals else t0 + dt * done
            # Each chunk gets its own step count; passing the total `steps`
            # here would redo the full amount of work for every chunk.
            u0 = solver.run_get_last(t_start, t_end, chunk + 1, u0)
            t_start = t_end
        return u0

    def run_F(self, t0, t1, u0):
        """Fine solution at t1 using Nf steps (paged if Nf > thresh)."""
        return self._run_RK_paged(t0, t1, u0, self.Nf, self.RK_F)

    def run_G(self, t0, t1, u0):
        """Coarse solution at t1 using Ng steps (paged if Ng > thresh)."""
        return self._run_RK_paged(t0, t1, u0, self.Ng, self.RK_G)

    def run_F_full(self, t0, t1, u0):
        # Not paged; output layout is whatever RK.run returns.
        return self.RK_F.run(t0, t1, self.Nf, u0)

    def run_G_full(self, t0, t1, u0):
        # Not paged; output layout is whatever RK.run returns.
        return self.RK_G.run(t0, t1, self.Ng, u0)
class SolverScipy(SolverAbstr):
    def __init__(self, f, Ng, Nf, G, F='RK45', use_jax=True, verbose=True, **kwargs):
        '''
        Fine solver F uses ``scipy.integrate.solve_ivp``; for the coarse
        solver G it uses my own RK implementation (via SolverRK).

        Note: Nf is interpreted as a soft constraint only. It sets
        ``max_step = |t1 - t0| / Nf``, so the algorithm may do more steps.
        Extra ``kwargs`` are forwarded to every ``solve_ivp`` call.
        '''
        self.f = f
        self.Ng = Ng
        self.Nf = Nf
        self.F = self._map_solver(F)
        self.G = G
        self.kwargs = kwargs
        self.rk_solver = SolverRK(f, Ng, Nf, F, G, use_jax=use_jax)
        self.verbose = verbose

    def _map_solver(self, solver):
        """Translate RK-order shorthands to scipy method names; pass others through."""
        map_dict = {'RK2': 'RK23', 'RK4': 'RK45', 'RK8': 'DOP853'}
        return map_dict.get(solver, solver)

    def run_F(self, t0, t1, u0):
        """Fine solution at t1 as a flat array."""
        return self._run_F_steps(t0, t1, u0, (t1, )).reshape(-1)

    def run_G(self, t0, t1, u0):
        """Coarse solution at t1, delegated to the internal SolverRK."""
        return self.rk_solver.run_G(t0, t1, u0)

    def _run_F_steps(self, t0, t1, u0, t_steps):
        '''
        Integrate with solve_ivp from t0 to t1 and return ``res.y``, the
        solution evaluated at the times in ``t_steps``.

        Raises ValueError if solve_ivp reports failure.
        '''
        # solve_ivp requires a strictly positive max_step, also when
        # integrating backwards in time (t1 < t0).
        max_step = abs(t1 - t0) / self.Nf
        res = solve_ivp(self.f, [t0, t1], u0, method=self.F,
                        t_eval=t_steps, max_step=max_step,
                        **self.kwargs)
        if not res.success:
            raise ValueError(f'F solver did not converge. Message: {res.message}')
        # NOTE(review): nfev counts right-hand-side evaluations, not steps;
        # explicit RK methods evaluate f several times per step, so this ratio
        # overstates the step count. Behaviour kept as-is.
        if res.nfev > self.Nf*1.5 and self.verbose:
            print(f'Warning: F solver did {res.nfev/self.Nf:0.1f}x more steps than expected')
        return res.y

    def run_F_full(self, t0, t1, u0, t_steps=None):
        '''
        Fine solution sampled at several times.

        t_steps: None (100 evenly spaced times), an integer count of evenly
        spaced times, or an explicit list/tuple/array of times.
        Raises ValueError for any other t_steps value.
        '''
        if t_steps is None:
            t_steps = np.linspace(t0, t1, num=100)
        elif isinstance(t_steps, (int, np.integer)):
            # Also accept NumPy integer scalars, not only Python ints.
            t_steps = np.linspace(t0, t1, num=t_steps)
        elif isinstance(t_steps, (list, tuple, np.ndarray)):
            pass
        else:
            raise ValueError(f'Unknown input value for t_steps {t_steps}.')
        return self._run_F_steps(t0, t1, u0, t_steps)