-
Notifications
You must be signed in to change notification settings - Fork 1
/
memory.py
33 lines (26 loc) · 1.17 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
from random import sample, randint, random
class ReplayMemory:
    """Fixed-capacity ring buffer of (s1, action, s2, isterminal, reward) transitions.

    Storage is preallocated as NumPy arrays. Once ``capacity`` transitions
    have been added, each new transition overwrites the oldest slot.

    Attributes:
        s1: float32 array of shape (capacity, stack_size, resolution[0],
            resolution[1]) holding the state each transition started from.
        s2: float32 array of the same shape holding the following state;
            all zeros for terminal transitions.
        a: int32 array of length ``capacity`` holding the actions.
        r: float32 array of length ``capacity`` holding the rewards.
        isterminal: float32 array of length ``capacity`` holding the
            terminal flags (nonzero where the transition ended an episode).
        capacity: maximum number of stored transitions.
        stack_size: size of the second axis of ``s1``/``s2``.
        size: number of valid transitions currently stored.
        pos: index of the slot the next transition will be written to.
    """

    def __init__(self, resolution, capacity = 10000, stack_size = 4):
        """Allocate zero-filled storage for ``capacity`` transitions.

        Args:
            resolution: indexable whose first two items give the last two
                dimensions of every stored state.
            capacity: maximum number of transitions kept.
            stack_size: size of the second axis of every stored state.
        """
        # Keep the stack size on the instance; the original self-assignment
        # of the local was a no-op.
        self.stack_size = stack_size
        state_shape = (capacity, stack_size, resolution[0], resolution[1])
        self.s1 = np.zeros(state_shape, dtype=np.float32)
        self.s2 = np.zeros(state_shape, dtype=np.float32)
        self.a = np.zeros(capacity, dtype=np.int32)
        self.r = np.zeros(capacity, dtype=np.float32)
        self.isterminal = np.zeros(capacity, dtype=np.float32)

        self.capacity = capacity
        self.size = 0
        self.pos = 0

    def add_transition(self, s1, action, s2, isterminal, reward):
        """Write one transition at the current position and advance it.

        Args:
            s1: state before the action; must be assignable to one slot of
                ``self.s1``.
            action: value stored in the int32 action array.
            s2: state after the action. Ignored when ``isterminal`` is
                truthy, so callers may pass ``None`` in that case.
            isterminal: truthy if the transition ended the episode.
            reward: value stored in the float32 reward array.
        """
        slot = self.pos
        self.s1[slot, :, :, :] = s1
        self.a[slot] = action
        if isterminal:
            # After the buffer wraps, this slot still holds the next-state
            # of the transition being overwritten. Clear it so a terminal
            # sample never carries another transition's frames.
            self.s2[slot, :, :, :] = 0
        else:
            self.s2[slot, :, :, :] = s2
        self.isterminal[slot] = isterminal
        self.r[slot] = reward

        # Advance the write cursor, wrapping to 0 at capacity; the count of
        # valid entries saturates at capacity.
        self.pos = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def get_sample(self, sample_size):
        """Return ``sample_size`` distinct stored transitions chosen at random.

        Indices are drawn without replacement from the valid range
        ``[0, self.size)`` using ``random.sample``.

        Args:
            sample_size: number of transitions to draw.

        Returns:
            Tuple ``(s1, a, s2, isterminal, r)`` of arrays, each indexed by
            the same sampled positions. Note the order differs from the
            argument order of ``add_transition``.

        Raises:
            ValueError: raised by ``random.sample`` if ``sample_size`` is
                negative or larger than ``self.size``.
        """
        indices = sample(range(self.size), sample_size)
        return (
            self.s1[indices],
            self.a[indices],
            self.s2[indices],
            self.isterminal[indices],
            self.r[indices],
        )