From ea4400d64c635215afe5123b621dd79f185ea5cc Mon Sep 17 00:00:00 2001 From: Georg Reich Date: Thu, 13 Jun 2024 14:47:16 +0100 Subject: [PATCH] fixed phase shifting env random target shift --- .../environments/phase_shifting.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/neurolib/control/reinforcement_learning/environments/phase_shifting.py b/neurolib/control/reinforcement_learning/environments/phase_shifting.py index a3b75d4b..3112a99d 100644 --- a/neurolib/control/reinforcement_learning/environments/phase_shifting.py +++ b/neurolib/control/reinforcement_learning/environments/phase_shifting.py @@ -58,7 +58,7 @@ def __init__( self.n_steps = round(self.duration / self.dt) - self.target = self.get_target() + self.init_target() self.observation_space = spaces.Dict( { @@ -75,7 +75,7 @@ def __init__( ) ) - def get_target(self): + def init_target(self): wc = WCModel() wc.params = self.model.params.copy() wc.params["duration"] = self.duration + 100.0 @@ -90,15 +90,17 @@ def get_target(self): period = np.mean(p_list) * self.dt self.period = period + self.raw_target = np.stack((wc.exc, wc.inh), axis=1)[0] + self.target_t = wc.t - raw = np.stack((wc.exc, wc.inh), axis=1)[0] + def get_target(self): if self.random_target_shift: target_shift = np.random.random() * 2 * np.pi else: target_shift = self.target_shift - index = np.round(target_shift * period / (2.0 * np.pi) / self.dt).astype(int) - target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)] - self.target_time = wc.t[index : index + target.shape[1]] + index = np.round(target_shift * self.period / (2.0 * np.pi) / self.dt).astype(int) + target = self.raw_target[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)] + self.target_time = self.target_t[index : index + target.shape[1]] self.target_phase = (self.target_time % self.period) / self.period * 2 * np.pi return target @@ -115,6 +117,7 @@ def _get_info(self): def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) + self.target = self.get_target() self.t_i = 0 self.model.clearModelState()