-
Notifications
You must be signed in to change notification settings - Fork 0
/
rls.py
41 lines (26 loc) · 891 Bytes
/
rls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np


class RLS:
    """Recursive least-squares (RLS) estimator with exponential forgetting.

    Tracks the parameter vector ``theta`` of the linear model
    ``y ~= phi @ theta`` and refines it one observation at a time via
    :meth:`sample`.

    Args:
        theta0: Initial parameter estimate (array-like with N elements).
            If ``None``, the estimate starts at zeros and N is taken from
            ``P0``.
        P0: Initial N x N ``P`` matrix. Defaults to ``1000 * I``; a large
            value means the initial estimate is trusted only weakly.
        lamb: Forgetting factor. ``1`` weighs all samples equally; values
            below 1 discount older samples geometrically. Must be positive.

    Raises:
        ValueError: if ``lamb`` is not positive, if both ``theta0`` and
            ``P0`` are ``None`` (the dimension would be unknown), or if
            ``P0`` is not N x N.
    """

    def __init__(self, theta0, P0=None, lamb=1):
        if lamb <= 0:
            raise ValueError(f"forgetting factor lamb must be positive, got {lamb!r}")
        self.lamb = lamb

        # Resolve theta first so N is known before it is used.
        if theta0 is not None:
            # Copy so later updates never alias the caller's array.
            self.theta = np.array(theta0, dtype=float)
        elif P0 is not None:
            self.theta = np.zeros(np.shape(P0)[0])
        else:
            raise ValueError(
                "theta0 and P0 cannot both be None: parameter dimension is unknown"
            )
        self.N = self.theta.size

        if P0 is not None:
            self.P = np.array(P0, dtype=float)
            if self.P.shape != (self.N, self.N):
                raise ValueError(
                    f"P0 must have shape {(self.N, self.N)}, got {self.P.shape}"
                )
        else:
            self.P = np.eye(self.N) * 1000

    def sample(self, phi, y):
        """Incorporate one observation ``(phi, y)`` and return the new estimate.

        Args:
            phi: Regressor with N elements (1-D vector or column vector).
            y: Scalar measurement.

        Returns:
            The updated parameter estimate, with the same shape as the
            stored ``theta``. It is also stored on ``self.theta``.
        """
        lamb = self.lamb
        P = self.P
        # Work with flat vectors so column/row inputs cannot broadcast into
        # an N x N result; the original theta shape is restored below.
        phi = np.asarray(phi, dtype=float).ravel()
        theta = np.ravel(self.theta)

        P_phi = P @ phi
        gain = P_phi / (lamb + phi @ P_phi)
        err = y - phi @ theta

        # (I - K phi^T) P / lamb rewritten as (P - K (phi^T P)) / lamb:
        # algebraically identical, but O(N^2) instead of an O(N^3) matmul.
        self.P = (P - np.outer(gain, phi @ P)) / lamb
        self.theta = (theta + gain * err).reshape(np.shape(self.theta))
        return self.theta

    def get_estimate(self):
        """Return the current parameter estimate ``theta``."""
        return self.theta

    def get_variance(self):
        """Return the current ``P`` matrix maintained by the RLS update.

        NOTE(review): ``P`` is commonly interpreted as a scaled parameter
        covariance; the scale depends on the measurement-noise level, which
        this class does not track.
        """
        return self.P