@@ -24,13 +24,13 @@ def __init__(
         self.action_high = action_high
         self.action_dim = action_dim
         self.net = nn.Sequential(
-            nn.Linear(obs_dim, 256),
+            nn.Linear(obs_dim, 64),
             nn.ReLU(),
-            nn.Linear(256, 256),
+            nn.Linear(64, 64),
             nn.ReLU(),
-            nn.Linear(256, 256),
+            nn.Linear(64, 64),
             nn.ReLU(),
-            nn.Linear(256, action_dim),
+            nn.Linear(64, action_dim),
             nn.Tanh(),
         )

@@ -46,8 +46,7 @@ def select_action(
     ) -> np.ndarray:
         # random exploration
         if np.random.uniform() < epsilon:
-            mu = np.random.uniform(-self.action_high,
-                                   self.action_high, self.action_dim)
+            mu = np.random.uniform(-self.action_high, self.action_high, self.action_dim)

         else:

@@ -61,7 +60,7 @@ def select_action(
         noise = noise_rate * self.action_high * np.random.randn(*mu.shape)
         mu += noise
         mu = np.clip(mu, -self.action_high, self.action_high)
-        return mu
+        return mu.copy()


 class Critic(nn.Module):
@@ -75,13 +74,13 @@ def __init__(
         self.action_high = action_high
         # critic should give scores for all agents' actions
         self.net = nn.Sequential(
-            nn.Linear(sum(obs_dims) + sum(action_dims), 256),
+            nn.Linear(sum(obs_dims) + sum(action_dims), 64),
             nn.ReLU(),
-            nn.Linear(256, 256),
+            nn.Linear(64, 64),
             nn.ReLU(),
-            nn.Linear(256, 256),
+            nn.Linear(64, 64),
             nn.ReLU(),
-            nn.Linear(256, 1),
+            nn.Linear(64, 1),
         )

     def forward(self, obs: Tensor, actions: Tensor) -> Tensor:
@@ -153,27 +152,21 @@ def __init__(
         for i in range(n_agents):
             self.actors.append(Actor(action_high, obs_dims[i], action_dims[i]))
             self.critics.append(Critic(action_high, obs_dims, action_dims))
-            self.target_actors.append(
-                Actor(action_high, obs_dims[i], action_dims[i]))
-            self.target_critics.append(
-                Critic(action_high, obs_dims, action_dims))
+            self.target_actors.append(Actor(action_high, obs_dims[i], action_dims[i]))
+            self.target_critics.append(Critic(action_high, obs_dims, action_dims))
             # load_state_dict
             self.target_actors[i].load_state_dict(self.actors[i].state_dict())
-            self.target_critics[i].load_state_dict(
-                self.critics[i].state_dict())
+            self.target_critics[i].load_state_dict(self.critics[i].state_dict())
             # optimizers
-            self.optimizer_a.append(optim.Adam(
-                self.actors[i].parameters(), lr=lr_a))
-            self.optimizer_c.append(optim.Adam(
-                self.critics[i].parameters(), lr=lr_c))
+            self.optimizer_a.append(optim.Adam(self.actors[i].parameters(), lr=lr_a))
+            self.optimizer_c.append(optim.Adam(self.critics[i].parameters(), lr=lr_c))

             self.actors[i] = self.actors[i].to(self.device)
             self.critics[i] = self.critics[i].to(self.device)
             self.target_actors[i] = self.target_actors[i].to(self.device)
             self.target_critics[i] = self.target_critics[i].to(self.device)

-        self.buffer = MemoryBuffer(
-            mem_capacity, obs_dims, action_dims, self.n_agents)
+        self.buffer = MemoryBuffer(mem_capacity, obs_dims, action_dims, self.n_agents)
         self.writer = SummaryWriter(log_dir=log_dir)

     def learn(self):
@@ -260,40 +253,39 @@ def _update_policy(self, transitions: dict):

             # comput td target and use the square of td residual as the loss
             q_value = self.critics[i].forward(o, mu)
-            critic_loss = t.mean((q_target - q_value) * (q_target - q_value))
+            critic_loss = t.mean((q_target - q_value) * (q_target - q_value))

             # actor loss, Actor's goal is to make Critic's scoring higher
             mu[i] = self.actors[i].forward(o[i])
             actor_loss = -self.critics[i].forward(o, mu).mean()

             # then perform gradient descent
             self.optimizer_a[i].zero_grad()
-            self.optimizer_c[i].zero_grad()
-            critic_loss.backward()
             actor_loss.backward()
             self.optimizer_a[i].step()
+            self.optimizer_c[i].zero_grad()
+            critic_loss.backward()
             self.optimizer_c[i].step()

             actor_losses.append(actor_loss.item())
             critic_losses.append(critic_loss.item())

-        # then soft update the target network
-        self._soft_update_target()
+            # then soft update the target network
+            self._soft_update_target(i)

         return actor_losses, critic_losses

-    def _soft_update_target(self) -> None:
-        for i in range(self.n_agents):
-            for target_param, param in zip(
+    def _soft_update_target(self, i) -> None:
+
+        for target_param, param in zip(
             self.target_actors[i].parameters(), self.actors[i].parameters()
         ):
             target_param.data.copy_(
                 (1 - self.tau) * target_param.data + self.tau * param.data
             )

-        for target_param, param in zip(
-            self.target_critics[i].parameters(
-            ), self.critics[i].parameters()
+        for target_param, param in zip(
+            self.target_critics[i].parameters(), self.critics[i].parameters()
         ):
             target_param.data.copy_(
                 (1 - self.tau) * target_param.data + self.tau * param.data