Skip to content

Commit ea4400d

Browse files
committed
fixed phase shifting env random target shift
1 parent 1e1520b commit ea4400d

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

neurolib/control/reinforcement_learning/environments/phase_shifting.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858

5959
self.n_steps = round(self.duration / self.dt)
6060

61-
self.target = self.get_target()
61+
self.init_target()
6262

6363
self.observation_space = spaces.Dict(
6464
{
@@ -75,7 +75,7 @@ def __init__(
7575
)
7676
)
7777

78-
def get_target(self):
78+
def init_target(self):
7979
wc = WCModel()
8080
wc.params = self.model.params.copy()
8181
wc.params["duration"] = self.duration + 100.0
@@ -90,15 +90,17 @@ def get_target(self):
9090

9191
period = np.mean(p_list) * self.dt
9292
self.period = period
93+
self.raw_target = np.stack((wc.exc, wc.inh), axis=1)[0]
94+
self.target_t = wc.t
9395

94-
raw = np.stack((wc.exc, wc.inh), axis=1)[0]
96+
def get_target(self):
9597
if self.random_target_shift:
9698
target_shift = np.random.random() * 2 * np.pi
9799
else:
98100
target_shift = self.target_shift
99-
index = np.round(target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
100-
target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
101-
self.target_time = wc.t[index : index + target.shape[1]]
101+
index = np.round(target_shift * self.period / (2.0 * np.pi) / self.dt).astype(int)
102+
target = self.raw_target[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
103+
self.target_time = self.target_t[index : index + target.shape[1]]
102104
self.target_phase = (self.target_time % self.period) / self.period * 2 * np.pi
103105

104106
return target
@@ -115,6 +117,7 @@ def _get_info(self):
115117

116118
def reset(self, seed=None, options=None):
117119
super().reset(seed=seed, options=options)
120+
self.target = self.get_target()
118121
self.t_i = 0
119122
self.model.clearModelState()
120123

0 commit comments

Comments
 (0)