@@ -58,7 +58,7 @@ def __init__(
58
58
59
59
self .n_steps = round (self .duration / self .dt )
60
60
61
- self .target = self . get_target ()
61
+ self .init_target ()
62
62
63
63
self .observation_space = spaces .Dict (
64
64
{
@@ -75,7 +75,7 @@ def __init__(
75
75
)
76
76
)
77
77
78
- def get_target (self ):
78
+ def init_target (self ):
79
79
wc = WCModel ()
80
80
wc .params = self .model .params .copy ()
81
81
wc .params ["duration" ] = self .duration + 100.0
@@ -90,15 +90,17 @@ def get_target(self):
90
90
91
91
period = np .mean (p_list ) * self .dt
92
92
self .period = period
93
+ self .raw_target = np .stack ((wc .exc , wc .inh ), axis = 1 )[0 ]
94
+ self .target_t = wc .t
93
95
94
- raw = np . stack (( wc . exc , wc . inh ), axis = 1 )[ 0 ]
96
+ def get_target ( self ):
95
97
if self .random_target_shift :
96
98
target_shift = np .random .random () * 2 * np .pi
97
99
else :
98
100
target_shift = self .target_shift
99
- index = np .round (target_shift * period / (2.0 * np .pi ) / self .dt ).astype (int )
100
- target = raw [:, index : index + np .round (1 + self .duration / self .dt , 1 ).astype (int )]
101
- self .target_time = wc . t [index : index + target .shape [1 ]]
101
+ index = np .round (target_shift * self . period / (2.0 * np .pi ) / self .dt ).astype (int )
102
+ target = self . raw_target [:, index : index + np .round (1 + self .duration / self .dt , 1 ).astype (int )]
103
+ self .target_time = self . target_t [index : index + target .shape [1 ]]
102
104
self .target_phase = (self .target_time % self .period ) / self .period * 2 * np .pi
103
105
104
106
return target
@@ -115,6 +117,7 @@ def _get_info(self):
115
117
116
118
def reset (self , seed = None , options = None ):
117
119
super ().reset (seed = seed , options = options )
120
+ self .target = self .get_target ()
118
121
self .t_i = 0
119
122
self .model .clearModelState ()
120
123
0 commit comments