forked from openai/baselines
-
Notifications
You must be signed in to change notification settings - Fork 0
/
memory.py
executable file
·83 lines (66 loc) · 2.64 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
class RingBuffer(object):
    """Fixed-capacity circular buffer backed by a preallocated NumPy array.

    Once ``maxlen`` items have been appended, each further append overwrites
    the oldest item. Logical index 0 always refers to the oldest item still
    held; logical index ``len(self) - 1`` is the newest.

    Args:
        maxlen: Capacity (number of items) of the buffer.
        shape: Shape of a single item; a tuple/sequence, or a bare int for
            1-D items.
        dtype: NumPy dtype of the backing storage array.
    """

    def __init__(self, maxlen, shape, dtype='float32'):
        self.maxlen = maxlen
        self.start = 0   # physical slot holding logical index 0
        self.length = 0  # number of valid items; never exceeds maxlen
        # Accept a bare int for 1-D items as well as any sequence shape.
        shape = (shape,) if np.isscalar(shape) else tuple(shape)
        # Allocate directly in the target dtype: zeros(...).astype(...) first
        # builds a float64 array and then copies it, a large transient cost
        # for big replay buffers.
        self.data = np.zeros((maxlen,) + shape, dtype=dtype)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the item at logical position ``idx`` (0 = oldest).

        Raises:
            KeyError: if ``idx`` is outside ``[0, len(self))``. Negative
                indices are not supported.
        """
        if idx < 0 or idx >= self.length:
            raise KeyError(idx)
        return self.data[(self.start + idx) % self.maxlen]

    def __iter__(self):
        """Yield the stored items from oldest to newest.

        Defined explicitly because ``__getitem__`` signals out-of-range with
        KeyError rather than IndexError, so Python's fallback sequence
        iteration would raise instead of stopping.
        """
        for offset in range(self.length):
            yield self.data[(self.start + offset) % self.maxlen]

    def get_batch(self, idxs):
        """Return the items at logical positions ``idxs`` stacked in one array.

        ``idxs`` may be any integer array-like (list, tuple, ndarray). Indices
        are not checked against ``len(self)``; they are only wrapped modulo
        ``maxlen``.
        """
        idxs = np.asarray(idxs)
        return self.data[(self.start + idxs) % self.maxlen]

    def append(self, v):
        """Store ``v`` as the newest item, evicting the oldest one when full."""
        if self.length < self.maxlen:
            # We have space, simply increase the length.
            self.length += 1
        elif self.length == self.maxlen:
            # No space: "remove" the oldest item by advancing start.
            self.start = (self.start + 1) % self.maxlen
        else:
            # Invariant violated: length can never exceed maxlen.
            raise RuntimeError()
        self.data[(self.start + self.length - 1) % self.maxlen] = v
def array_min2d(x):
    """Return ``x`` as a fresh NumPy array with at least two dimensions.

    Inputs that are already 2-D or higher come back with their shape intact.
    Lower-rank inputs (scalars and 1-D vectors) are laid out as a single
    column of shape ``(N, 1)``.
    """
    arr = np.array(x)
    return arr if arr.ndim >= 2 else arr.reshape(-1, 1)
class Memory(object):
    """Fixed-capacity replay memory of (obs0, action, reward, obs1, terminal1).

    Each field is kept in its own RingBuffer of capacity ``limit``. The buffers
    are always appended together, so index i of every buffer refers to the
    same transition.
    """

    def __init__(self, limit, action_shape, observation_shape):
        self.limit = limit
        self.observations0 = RingBuffer(limit, shape=observation_shape)
        self.actions = RingBuffer(limit, shape=action_shape)
        self.rewards = RingBuffer(limit, shape=(1,))
        self.terminals1 = RingBuffer(limit, shape=(1,))
        self.observations1 = RingBuffer(limit, shape=observation_shape)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            dict with keys 'obs0', 'obs1', 'rewards', 'actions', 'terminals1',
            each passed through ``array_min2d`` so it is at least 2-D.

        Raises:
            ValueError: if the memory holds no transitions.
        """
        n_entries = self.nb_entries
        if n_entries == 0:
            raise ValueError("cannot sample from an empty Memory")
        # Every transition stores its own successor observation (obs1), so all
        # stored indices are complete transitions. Reserving trailing entries
        # (as a former "- 2" did) would never sample the newest data and would
        # fail outright while fewer than 3 transitions were stored.
        batch_idxs = np.random.randint(n_entries, size=batch_size)

        obs0_batch = self.observations0.get_batch(batch_idxs)
        obs1_batch = self.observations1.get_batch(batch_idxs)
        action_batch = self.actions.get_batch(batch_idxs)
        reward_batch = self.rewards.get_batch(batch_idxs)
        terminal1_batch = self.terminals1.get_batch(batch_idxs)

        result = {
            'obs0': array_min2d(obs0_batch),
            'obs1': array_min2d(obs1_batch),
            'rewards': array_min2d(reward_batch),
            'actions': array_min2d(action_batch),
            'terminals1': array_min2d(terminal1_batch),
        }
        return result

    def append(self, obs0, action, reward, obs1, terminal1, training=True):
        """Store one transition; does nothing when ``training`` is False."""
        if not training:
            return

        self.observations0.append(obs0)
        self.actions.append(action)
        self.rewards.append(reward)
        self.observations1.append(obs1)
        self.terminals1.append(terminal1)

    @property
    def nb_entries(self):
        """Number of transitions currently stored (at most ``limit``)."""
        return len(self.observations0)