-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtrainer.py
383 lines (329 loc) · 19.4 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
import numpy as np
import os
import pickle
import time
import torch
from collections import deque
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from buffer import Buffer
from model import ActorCriticModel
from utils import batched_index_select, create_env, polynomial_decay, process_episode_info
from worker import Worker
class PPOTrainer:
    def __init__(self, config:dict, run_id:str="run", device:torch.device=torch.device("cpu")) -> None:
        """Initializes all needed training components.

        Arguments:
            config {dict} -- Configuration and hyperparameters of the environment, trainer and model.
            run_id {str, optional} -- A tag used to save Tensorboard Summaries and the trained model. Defaults to "run".
            device {torch.device, optional} -- Determines the training device. Defaults to cpu.
        """
        # Set members
        self.config = config
        self.device = device
        self.run_id = run_id
        self.num_workers = config["n_workers"]
        self.lr_schedule = config["learning_rate_schedule"]
        self.beta_schedule = config["beta_schedule"]
        self.cr_schedule = config["clip_range_schedule"]
        self.memory_length = config["transformer"]["memory_length"]
        self.num_blocks = config["transformer"]["num_blocks"]
        self.embed_dim = config["transformer"]["embed_dim"]

        # Setup Tensorboard Summary Writer (one directory per run id and start time)
        os.makedirs("./summaries", exist_ok=True)
        timestamp = time.strftime("/%Y%m%d-%H%M%S" + "/")
        self.writer = SummaryWriter("./summaries/" + run_id + timestamp)

        # Init dummy environment to retrieve action space, observation space and max episode length
        print("Step 1: Init dummy environment")
        dummy_env = create_env(self.config["environment"])
        try:
            observation_space = dummy_env.observation_space
            self.action_space_shape = (dummy_env.action_space.n,)
            self.max_episode_length = dummy_env.max_episode_steps
        finally:
            # Always release the dummy environment, even if reading its spaces fails
            dummy_env.close()

        # Init buffer
        print("Step 2: Init buffer")
        self.buffer = Buffer(self.config, observation_space, self.action_space_shape, self.max_episode_length, self.device)

        # Init model
        print("Step 3: Init model and optimizer")
        self.model = ActorCriticModel(self.config, observation_space, self.action_space_shape, self.max_episode_length).to(self.device)
        self.model.train()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.lr_schedule["initial"])

        # Init workers
        print("Step 4: Init environment workers")
        self.workers = [Worker(self.config["environment"]) for _ in range(self.num_workers)]
        self.worker_ids = range(self.num_workers)
        # Episode step of each worker; used to index the episodic memory, mask and memory windows
        self.worker_current_episode_step = torch.zeros((self.num_workers, ), dtype=torch.long)

        # Reset workers (i.e. environments)
        print("Step 5: Reset workers")
        for worker in self.workers:
            worker.child.send(("reset", None))
        # Grab initial observations and store them in their respective placeholder location
        self.obs = np.zeros((self.num_workers,) + observation_space.shape, dtype=np.float32)
        for w, worker in enumerate(self.workers):
            self.obs[w] = worker.child.recv()

        # Setup placeholders for each worker's current episodic memory
        self.memory = torch.zeros((self.num_workers, self.max_episode_length, self.num_blocks, self.embed_dim), dtype=torch.float32)

        # Generate episodic memory mask used in attention: row k has ones in columns 0..k-1
        # (torch.tril with diagonal=-1). E.g. for memory_length = 6:
        #   0, 0, 0, 0, 0, 0
        #   1, 0, 0, 0, 0, 0
        #   1, 1, 0, 0, 0, 0
        #   1, 1, 1, 0, 0, 0
        #   1, 1, 1, 1, 0, 0
        #   1, 1, 1, 1, 1, 0
        self.memory_mask = torch.tril(torch.ones((self.memory_length, self.memory_length)), diagonal=-1)

        # Setup memory window indices to support a sliding window over the episodic memory.
        # Row k is the window used at episode step k. E.g. for memory_length = 4 and max_episode_length = 7:
        #   0, 1, 2, 3
        #   0, 1, 2, 3
        #   0, 1, 2, 3
        #   0, 1, 2, 3
        #   1, 2, 3, 4
        #   2, 3, 4, 5
        #   3, 4, 5, 6
        repetitions = torch.repeat_interleave(torch.arange(0, self.memory_length).unsqueeze(0), self.memory_length - 1, dim = 0).long()
        self.memory_indices = torch.stack([torch.arange(i, i + self.memory_length) for i in range(self.max_episode_length - self.memory_length + 1)]).long()
        self.memory_indices = torch.cat((repetitions, self.memory_indices))

    @staticmethod
    def _decay(schedule:dict, update:int) -> float:
        """Evaluates a polynomial decay schedule (keys: initial, final, max_decay_steps, power) at `update`."""
        return polynomial_decay(schedule["initial"], schedule["final"], schedule["max_decay_steps"], schedule["power"], update)

    def run_training(self) -> None:
        """Runs the entire training logic from sampling data to optimizing the model. Only the final model is saved."""
        print("Step 6: Starting training using " + str(self.device))
        # Store the results of the 100 most recent episodes for monitoring statistics
        episode_infos = deque(maxlen=100)

        for update in range(self.config["updates"]):
            # Decay hyperparameters polynomially based on the provided config
            learning_rate = self._decay(self.lr_schedule, update)
            beta = self._decay(self.beta_schedule, update)
            clip_range = self._decay(self.cr_schedule, update)

            # Sample training data
            sampled_episode_info = self._sample_training_data()
            # Prepare the sampled data inside the buffer (splits data into sequences)
            self.buffer.prepare_batch_dict()

            # Train epochs and average the per-mini-batch statistics
            training_stats, grad_info = self._train_epochs(learning_rate, clip_range, beta)
            training_stats = np.mean(training_stats, axis=0)

            # Store recent episode infos
            episode_infos.extend(sampled_episode_info)
            episode_result = process_episode_info(episode_infos)

            # Print training statistics
            print(self._format_progress_line(update, training_stats, episode_result))

            # Write training statistics to tensorboard
            self._write_gradient_summary(update, grad_info)
            self._write_training_summary(update, training_stats, episode_result)

        # Save the trained model at the end of the training
        self._save_model()

    def _format_progress_line(self, update:int, training_stats, episode_result:dict) -> str:
        """Builds the one-line console summary of an update.

        Arguments:
            update {int} -- Current PPO update
            training_stats {array-like} -- [policy_loss, value_loss, loss, entropy, approx_kl, clip_fraction]
            episode_result {dict} -- Statistics of completed episodes; may be empty when no episode has
                                     finished yet (the tensorboard writer guards for this case as well)

        Returns:
            {str} -- The formatted progress line
        """
        fields = ["{:4}".format(update)]
        # Episode statistics are only available once at least one episode has completed
        if episode_result:
            fields.append("reward={:.2f} std={:.2f} length={:.1f} std={:.2f}".format(
                episode_result["reward_mean"], episode_result["reward_std"],
                episode_result["length_mean"], episode_result["length_std"]))
            if "success" in episode_result:
                fields.append("success={:.2f}".format(episode_result["success"]))
        fields.append("pi_loss={:.3f} v_loss={:.3f} entropy={:.3f} loss={:.3f} value={:.3f} advantage={:.3f}".format(
            training_stats[0], training_stats[1], training_stats[3], training_stats[2],
            torch.mean(self.buffer.values), torch.mean(self.buffer.advantages)))
        return " ".join(fields)

    def _sample_training_data(self) -> list:
        """Runs all n workers for n steps to sample training data.

        Returns:
            {list} -- list of results of completed episodes.
        """
        episode_infos = []

        # Init episodic memory buffer using each workers' current episodic memory.
        # These entries are views into self.memory, so they track the ongoing episodes.
        self.buffer.memories = [self.memory[w] for w in range(self.num_workers)]
        for w in range(self.num_workers):
            self.buffer.memory_index[w] = w

        # Sample actions from the model and collect experiences for optimization
        for t in range(self.config["worker_steps"]):
            # Gradients can be omitted for sampling training data
            with torch.no_grad():
                obs_tensor = torch.tensor(self.obs)
                # Store the initial observations inside the buffer
                self.buffer.obs[:, t] = obs_tensor
                # Store mask and memory indices inside the buffer
                self.buffer.memory_mask[:, t] = self.memory_mask[torch.clip(self.worker_current_episode_step, 0, self.memory_length - 1)]
                self.buffer.memory_indices[:, t] = self.memory_indices[self.worker_current_episode_step]
                # Retrieve the memory window from the entire episodic memory
                sliced_memory = batched_index_select(self.memory, 1, self.buffer.memory_indices[:, t])
                # Forward the model to retrieve the policy, the states' value and the new memory item
                policy, value, memory = self.model(obs_tensor, sliced_memory, self.buffer.memory_mask[:, t],
                                                   self.buffer.memory_indices[:, t])

                # Add new memory item to the episodic memory at each worker's current episode step
                self.memory[self.worker_ids, self.worker_current_episode_step] = memory

                # Sample actions from each individual policy branch
                actions = []
                log_probs = []
                for action_branch in policy:
                    action = action_branch.sample()
                    actions.append(action)
                    log_probs.append(action_branch.log_prob(action))
                # Write actions, log_probs and values to buffer
                self.buffer.actions[:, t] = torch.stack(actions, dim=1)
                self.buffer.log_probs[:, t] = torch.stack(log_probs, dim=1)
                self.buffer.values[:, t] = value

            # Send actions to the environments
            for w, worker in enumerate(self.workers):
                worker.child.send(("step", self.buffer.actions[w, t].cpu().numpy()))

            # Retrieve step results from the environments
            for w, worker in enumerate(self.workers):
                obs, self.buffer.rewards[w, t], self.buffer.dones[w, t], info = worker.child.recv()
                if info: # i.e. done
                    # Reset the worker's current timestep
                    self.worker_current_episode_step[w] = 0
                    # Store the information of the completed episode (e.g. total reward, episode length)
                    episode_infos.append(info)
                    # Reset the agent (potential interface for providing reset parameters)
                    worker.child.send(("reset", None))
                    # Get data from reset
                    obs = worker.child.recv()
                    # Break the reference to the worker's memory: the in-place reset below would
                    # otherwise also wipe the finished episode's memory stored in the buffer
                    mem_index = self.buffer.memory_index[w, t]
                    self.buffer.memories[mem_index] = self.buffer.memories[mem_index].clone()
                    # Reset episodic memory (in place, so self.memory[w] stays a view of self.memory)
                    self.memory[w] = torch.zeros((self.max_episode_length, self.num_blocks, self.embed_dim), dtype=torch.float32)
                    if t < self.config["worker_steps"] - 1:
                        # Store memory inside the buffer
                        self.buffer.memories.append(self.memory[w])
                        # Store the reference to the current episodic memory inside the buffer
                        self.buffer.memory_index[w, t + 1:] = len(self.buffer.memories) - 1
                else:
                    # Increment worker timestep
                    self.worker_current_episode_step[w] += 1
                # Store latest observations
                self.obs[w] = obs

        # Compute the last value of the current observation and memory window to compute GAE
        last_value = self.get_last_value()
        # Compute advantages
        self.buffer.calc_advantages(last_value, self.config["gamma"], self.config["lamda"])

        return episode_infos

    def get_last_value(self) -> torch.Tensor:
        """Computes the value of each worker's current observation for bootstrapping GAE.

        Uses exactly the same memory window, attention mask and memory indices that the
        sampling loop would use for the next step (row `worker_current_episode_step` of
        `self.memory_indices` / the clipped row of `self.memory_mask`), so the bootstrap value
        attends to the most recent memories of the ongoing episode.

        Returns:
            {torch.tensor} -- Last value of the current observation and memory window to compute GAE
        """
        # No gradients are needed: the value is only used as a bootstrap target
        with torch.no_grad():
            step = self.worker_current_episode_step
            indices = self.memory_indices[step]
            mask = self.memory_mask[torch.clip(step, 0, self.memory_length - 1)]
            # Retrieve the memory window from the entire episode
            sliced_memory = batched_index_select(self.memory, 1, indices)
            _, last_value, _ = self.model(torch.tensor(self.obs), sliced_memory, mask, indices)
        return last_value

    def _train_epochs(self, learning_rate:float, clip_range:float, beta:float) -> tuple:
        """Trains several PPO epochs over one batch of data while dividing the batch into mini batches.

        Arguments:
            learning_rate {float} -- The current learning rate
            clip_range {float} -- The current clip range
            beta {float} -- The current entropy bonus coefficient

        Returns:
            {tuple} -- (list of per-mini-batch training statistics, dict of gradient norm lists keyed by module)
        """
        train_info, grad_info = [], {}
        for _ in range(self.config["epochs"]):
            for mini_batch in self.buffer.mini_batch_generator():
                train_info.append(self._train_mini_batch(mini_batch, learning_rate, clip_range, beta))
                for key, value in self.model.get_grad_norm().items():
                    grad_info.setdefault(key, []).append(value)
        return train_info, grad_info

    def _train_mini_batch(self, samples:dict, learning_rate:float, clip_range:float, beta:float) -> list:
        """Uses one mini batch to optimize the model.

        Arguments:
            samples {dict} -- The mini batch data used to optimize the model
            learning_rate {float} -- Current learning rate
            clip_range {float} -- Current clip range
            beta {float} -- Current entropy bonus coefficient

        Returns:
            {list} -- Training statistics in the order
                      [policy_loss, value_loss, loss, entropy, approx_kl, clip_fraction]
        """
        # Select episodic memory windows
        memory = batched_index_select(samples["memories"], 1, samples["memory_indices"])
        # Forward model
        policy, value, _ = self.model(samples["obs"], memory, samples["memory_mask"], samples["memory_indices"])

        # Retrieve and process log_probs from each policy branch
        log_probs, entropies = [], []
        for i, policy_branch in enumerate(policy):
            log_probs.append(policy_branch.log_prob(samples["actions"][:, i]))
            entropies.append(policy_branch.entropy())
        log_probs = torch.stack(log_probs, dim=1)
        entropies = torch.stack(entropies, dim=1).sum(1).reshape(-1)

        # Compute policy surrogates to establish the policy objective (maximized, hence negated in `loss`)
        normalized_advantage = (samples["advantages"] - samples["advantages"].mean()) / (samples["advantages"].std() + 1e-8)
        normalized_advantage = normalized_advantage.unsqueeze(1).repeat(1, len(self.action_space_shape)) # Repeat is necessary for multi-discrete action spaces
        log_ratio = log_probs - samples["log_probs"]
        ratio = torch.exp(log_ratio)
        surr1 = ratio * normalized_advantage
        surr2 = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * normalized_advantage
        policy_loss = torch.min(surr1, surr2).mean()

        # Clipped value function loss
        sampled_return = samples["values"] + samples["advantages"]
        clipped_value = samples["values"] + (value - samples["values"]).clamp(min=-clip_range, max=clip_range)
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2).mean()

        # Entropy Bonus
        entropy_bonus = entropies.mean()

        # Complete loss
        loss = -(policy_loss - self.config["value_loss_coefficient"] * vf_loss + beta * entropy_bonus)

        # Compute gradients
        for pg in self.optimizer.param_groups:
            pg["lr"] = learning_rate
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.config["max_grad_norm"])
        self.optimizer.step()

        # Monitor additional training stats
        approx_kl = (ratio - 1.0) - log_ratio # http://joschu.net/blog/kl-approx.html
        clip_fraction = ((ratio - 1.0).abs() > clip_range).float().mean()

        return [policy_loss.cpu().data.numpy(),
                vf_loss.cpu().data.numpy(),
                loss.cpu().data.numpy(),
                entropy_bonus.cpu().data.numpy(),
                approx_kl.mean().cpu().data.numpy(),
                clip_fraction.cpu().data.numpy()]

    def _write_training_summary(self, update, training_stats, episode_result) -> None:
        """Writes to an event file based on the run-id argument.

        Arguments:
            update {int} -- Current PPO Update
            training_stats {list} -- [policy_loss, value_loss, loss, entropy, approx_kl, clip_fraction]
            episode_result {dict} -- Statistics of completed episodes
        """
        if episode_result:
            for key in episode_result:
                # Standard deviations are only printed, not logged
                if "std" not in key:
                    self.writer.add_scalar("episode/" + key, episode_result[key], update)
        self.writer.add_scalar("losses/loss", training_stats[2], update)
        self.writer.add_scalar("losses/policy_loss", training_stats[0], update)
        self.writer.add_scalar("losses/value_loss", training_stats[1], update)
        self.writer.add_scalar("losses/entropy", training_stats[3], update)
        self.writer.add_scalar("training/value_mean", torch.mean(self.buffer.values), update)
        self.writer.add_scalar("training/advantage_mean", torch.mean(self.buffer.advantages), update)
        # Index 4 is approx_kl and index 5 is clip_fraction (see _train_mini_batch)
        self.writer.add_scalar("other/clip_fraction", training_stats[5], update)
        self.writer.add_scalar("other/kl", training_stats[4], update)

    def _write_gradient_summary(self, update, grad_info) -> None:
        """Adds gradient statistics to the tensorboard event file.

        Arguments:
            update {int} -- Current PPO Update
            grad_info {dict} -- Gradient statistics (lists of values, averaged before logging)
        """
        for key, value in grad_info.items():
            self.writer.add_scalar("gradients/" + key, np.mean(value), update)

    def _save_model(self) -> None:
        """Saves the model and the used training config to the models directory. The filename is based on the run id."""
        os.makedirs("./models", exist_ok=True)
        self.model.cpu()
        path = "./models/" + self.run_id + ".nn"
        # The context manager guarantees the file is flushed and closed
        with open(path, "wb") as model_file:
            pickle.dump((self.model.state_dict(), self.config), model_file)
        print("Model saved to " + path)

    def close(self) -> None:
        """Terminates the trainer and all related processes, then exits the interpreter with status 0.

        Shutdown is best-effort: failures while closing the writer or notifying a worker are
        ignored so that every remaining worker still receives its close request.
        """
        try:
            self.writer.close()
        except Exception:
            pass
        for worker in getattr(self, "workers", []):
            try:
                worker.child.send(("close", None))
            except Exception:
                pass
        # Give the worker processes a moment to shut down
        time.sleep(1.0)
        raise SystemExit(0)