
Commit 6109825

added phase shifting env

Parent: 899616c

File tree: 5 files changed, +419 −3 lines

examples/example-6.2-rl-phaseshifting.ipynb

Lines changed: 277 additions & 0 deletions
Large diffs are not rendered by default.

neurolib/control/reinforcement_learning/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -4,3 +4,8 @@
     id="StateSwitching-v0",
     entry_point="neurolib.control.reinforcement_learning.environments.state_switching:StateSwitchingEnv",
 )
+
+register(
+    id="PhaseShifting-v0",
+    entry_point="neurolib.control.reinforcement_learning.environments.phase_shifting:PhaseShiftingEnv",
+)
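
With this registration in place, the environment should be constructible through Gymnasium's registry once the reinforcement_learning package has been imported. A minimal usage sketch, not part of the commit; it assumes that importing the package below executes the register() calls shown above:

# Sketch: create the newly registered environment via the Gymnasium registry.
import gymnasium as gym
import neurolib.control.reinforcement_learning  # noqa: F401  # side effect: runs register() above

env = gym.make("PhaseShifting-v0")
obs, info = env.reset()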
neurolib/control/reinforcement_learning/environments/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1 +1,2 @@
 from neurolib.control.reinforcement_learning.environments.state_switching import StateSwitchingEnv
+from neurolib.control.reinforcement_learning.environments.phase_shifting import PhaseShiftingEnv
neurolib/control/reinforcement_learning/environments/phase_shifting.py

Lines changed: 133 additions & 0 deletions (new file)
from neurolib.utils.stimulus import ZeroInput

import numpy as np
import scipy.signal

import gymnasium as gym
from gymnasium import spaces

from neurolib.models.wc import WCModel


class PhaseShiftingEnv(gym.Env):
    """Gymnasium environment: shift the phase of an oscillating Wilson-Cowan node.

    The agent applies external inputs to a single WC node and is rewarded for
    tracking a target trajectory equal to the uncontrolled oscillation shifted
    by `target_shift` radians.
    """

    def __init__(
        self,
        duration=300,
        dt=0.1,
        target_shift=1 * np.pi,
        exc_ext_baseline=2.8,
        inh_ext_baseline=1.2,
        x_init=0.04201540010391125,
        y_init=0.1354067401509556,
        sigma_ou=0.0,
        c_inhexc=16,
        c_excinh=10,
        c_inhinh=1,
        control_strength_loss_scale=0.005,
    ):
        self.exc_ext_baseline = exc_ext_baseline
        self.inh_ext_baseline = inh_ext_baseline

        self.duration = duration
        self.dt = dt
        self.target_shift = target_shift
        self.x_init = x_init
        self.y_init = y_init
        self.control_strength_loss_scale = control_strength_loss_scale

        # the requested phase shift must be a proper fraction of one cycle
        assert 0 < self.target_shift < 2 * np.pi

        self.model = WCModel()
        self.model.params["dt"] = self.dt
        self.model.params["sigma_ou"] = sigma_ou
        self.model.params["duration"] = self.dt  # one step at a time
        self.model.params["exc_init"] = np.array([[x_init]])
        self.model.params["inh_init"] = np.array([[y_init]])
        self.model.params["exc_ext_baseline"] = self.exc_ext_baseline
        self.model.params["inh_ext_baseline"] = self.inh_ext_baseline

        self.model.params["c_inhexc"] = c_inhexc
        self.model.params["c_excinh"] = c_excinh
        self.model.params["c_inhinh"] = c_inhinh
        self.params = self.model.params.copy()

        self.n_steps = round(self.duration / self.dt)

        self.target = self.get_target()

        # observations: current excitatory and inhibitory activity of the node
        self.observation_space = spaces.Dict(
            {
                "exc": spaces.Box(0, 1, shape=(1,), dtype=float),
                "inh": spaces.Box(0, 1, shape=(1,), dtype=float),
            }
        )

        # actions: external inputs exc_ext and inh_ext applied to the node
        self.action_space = spaces.Tuple(
            (
                spaces.Box(-5, 5, shape=(1,), dtype=float),  # exc
                spaces.Box(-5, 5, shape=(1,), dtype=float),  # inh
            )
        )

    def get_target(self):
        """Simulate the uncontrolled model and return its trajectory shifted by `target_shift`."""
        wc = WCModel()
        wc.params = self.model.params.copy()
        wc.params["duration"] = self.duration + 100.0
        wc.run()

        # estimate the oscillation period (in ms) from the mean inter-peak interval
        peaks = scipy.signal.find_peaks(wc.exc[0, :])[0]
        p_list = []
        for i in range(3, len(peaks)):
            p_list.append(peaks[i] - peaks[i - 1])
        period = np.mean(p_list) * self.dt
        self.period = period

        # shift the (exc, inh) trajectory by target_shift / (2 pi) of one period
        raw = np.stack((wc.exc, wc.inh), axis=1)[0]
        index = np.round(self.target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
        target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]

        return target

    def _get_obs(self):
        return {"exc": self.model.exc[0], "inh": self.model.inh[0]}

    def _get_info(self):
        return {"t": self.t_i * self.dt}

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        self.t_i = 0
        self.model.clearModelState()

        self.model.params = self.params.copy()
        self.model.exc = np.array([[self.x_init]])
        self.model.inh = np.array([[self.y_init]])

        observation = self._get_obs()
        info = self._get_info()
        return observation, info

    def _loss(self, obs, action):
        # Euclidean distance to the phase-shifted target state at the current step ...
        control_loss = np.sqrt(
            (self.target[0, self.t_i] - obs["exc"].item()) ** 2 + (self.target[1, self.t_i] - obs["inh"].item()) ** 2
        )
        # ... plus a penalty on the control amplitude
        control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
        return control_loss + control_strength_loss

    def step(self, action):
        assert self.action_space.contains(action)
        exc, inh = action
        self.model.params["exc_ext"] = np.array([exc])
        self.model.params["inh_ext"] = np.array([inh])
        self.model.run(continue_run=True)  # advance the model by one dt

        observation = self._get_obs()

        reward = -self._loss(observation, action)

        self.t_i += 1
        terminated = self.t_i >= self.n_steps
        info = self._get_info()

        return observation, reward, terminated, False, info
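
For orientation, a random-action rollout against the new environment might look as follows. This is a hypothetical usage sketch, not part of the commit; the constructor arguments simply restate the defaults above, and the per-step reward is the negative of the tracking error to the phase-shifted target plus the control-strength penalty defined in _loss().

# Sketch: one episode with random actions sampled from the Tuple action space.
import numpy as np
from neurolib.control.reinforcement_learning.environments.phase_shifting import PhaseShiftingEnv

env = PhaseShiftingEnv(duration=300, dt=0.1, target_shift=np.pi)
obs, info = env.reset()
episode_return, terminated = 0.0, False
while not terminated:
    action = env.action_space.sample()  # tuple of two (1,)-shaped arrays in [-5, 5]
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward
print("episode return:", episode_return, "| estimated period (ms):", env.period)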

neurolib/control/reinforcement_learning/environments/state_switching.py

Lines changed: 3 additions & 3 deletions
@@ -101,12 +101,12 @@ def reset(self, seed=None, options=None):
         return observation, info
 
     def _loss(self, obs, action):
-        accuracy_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())
-        # exc_ext, inh_ext = action
+        control_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())
         control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
-        return accuracy_loss + control_strength_loss
+        return control_loss + control_strength_loss
 
     def step(self, action):
+        assert self.action_space.contains(action)
         exc, inh = action
         self.model.params["exc_ext"] = np.array([exc])
         self.model.params["inh_ext"] = np.array([inh])
