Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Clearing of Replay Buffer #138

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
12 changes: 12 additions & 0 deletions pfrl/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def stop_current_episode(self, env_id=0):
"""
raise NotImplementedError

@abstractmethod
def clear(self):
"""Clears the replay buffer.
"""
raise NotImplementedError


class AbstractEpisodicReplayBuffer(AbstractReplayBuffer):
"""Defines a common interface of episodic replay buffer.
Expand Down Expand Up @@ -145,6 +151,12 @@ def n_episodes(self):
"""
raise NotImplementedError

@abstractmethod
def clear(self):
"""Clears the replay buffer.
"""
raise NotImplementedError


def random_subseq(seq, subseq_len):
if len(seq) <= subseq_len:
Expand Down
14 changes: 11 additions & 3 deletions pfrl/replay_buffers/episodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):
capacity: Optional[int] = None

def __init__(self, capacity=None):
self.current_episode = collections.defaultdict(list)
self.episodic_memory = RandomAccessQueue()
self.memory = RandomAccessQueue()
# initialize data structures
self.initialize_memory()
self.capacity = capacity

def append(
Expand Down Expand Up @@ -97,3 +96,12 @@ def stop_current_episode(self, env_id=0):
for _ in range(len(discarded_episode)):
self.memory.popleft()
assert not self.current_episode[env_id]

def initialize_memory(self):
self.current_episode = collections.defaultdict(list)
self.episodic_memory = RandomAccessQueue()
self.memory = RandomAccessQueue()

def clear(self):
self.initialize_memory()

6 changes: 6 additions & 0 deletions pfrl/replay_buffers/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def load(self, _):
"{}.load() has been ignored, as it is persistent replay buffer".format(self)
)

def clear(self):
raise NotImplementedError


class PersistentEpisodicReplayBuffer(EpisodicReplayBuffer):
"""Episodic version of :py:class:`PersistentReplayBuffer`
Expand Down Expand Up @@ -163,3 +166,6 @@ def load(self, _):
warnings.warn(
"PersistentEpisodicReplayBuffer.load() is called but it has not effect."
)

def clear(self):
raise NotImplementedError
31 changes: 21 additions & 10 deletions pfrl/replay_buffers/prioritized.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ def __init__(
self.capacity = capacity
assert num_steps > 0
self.num_steps = num_steps
self.beta0 = beta0
self.betasteps = betasteps
self.initialize_memory(capacity, num_steps, alpha, beta0, betasteps,
eps, normalize_by_max, error_min, error_max)

def sample(self, n):
assert len(self.memory) >= n
sampled, probabilities, min_prob = self.memory.sample(n)
weights = self.weights_from_probabilities(probabilities, min_prob)
for e, w in zip(sampled, weights):
e[0]["weight"] = w
return sampled

def update_errors(self, errors):
self.memory.set_last_priority(self.priority_from_errors(errors))

def initialize_memory(self, capacity, num_steps, alpha, beta0, betasteps,
eps, normalize_by_max, error_min, error_max):
self.memory = PrioritizedBuffer(capacity=capacity)
self.last_n_transitions = collections.defaultdict(
lambda: collections.deque([], maxlen=num_steps)
Expand All @@ -114,13 +132,6 @@ def __init__(
error_max=error_max,
)

def sample(self, n):
assert len(self.memory) >= n
sampled, probabilities, min_prob = self.memory.sample(n)
weights = self.weights_from_probabilities(probabilities, min_prob)
for e, w in zip(sampled, weights):
e[0]["weight"] = w
return sampled

def update_errors(self, errors):
self.memory.set_last_priority(self.priority_from_errors(errors))
def clear(self):
self.initialize_memory(self.capacity, self.num_steps, self.alpha, self.beta0, self.betasteps,
self.eps, self.normalize_by_max, self.error_min, self.error_max)
45 changes: 30 additions & 15 deletions pfrl/replay_buffers/prioritized_episodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,16 @@ def __init__(
error_min=None,
error_max=None,
):
self.current_episode = collections.defaultdict(list)
self.episodic_memory = PrioritizedBuffer(
capacity=None, wait_priority_after_sampling=wait_priority_after_sampling
)
self.memory = RandomAccessQueue(maxlen=capacity)
self.initialize_memory(capacity)
self.capacity = capacity
self.wait_priority_after_sampling = wait_priority_after_sampling
self.beta0 = beta0
self.betasteps = betasteps

self.capacity_left = capacity
self.default_priority_func = default_priority_func
self.uniform_ratio = uniform_ratio
self.return_sample_weights = return_sample_weights
PriorityWeightError.__init__(
self,
alpha,
beta0,
betasteps,
eps,
normalize_by_max,
error_min=error_min,
error_max=error_max,
)

def sample_episodes(self, n_episodes, max_len=None):
"""Sample n unique samples from this replay buffer"""
Expand Down Expand Up @@ -75,3 +66,27 @@ def stop_current_episode(self, env_id=0):
discarded_episode = self.episodic_memory.popleft()
self.capacity_left += len(discarded_episode)
assert not self.current_episode[env_id]

def initialize_memory(self, capacity, wait_priority_after_sampling,
alpha, beta0, betasteps, eps, normalize_by_max,
error_min, error_max):
self.current_episode = collections.defaultdict(list)
self.episodic_memory = PrioritizedBuffer(
capacity=None, wait_priority_after_sampling=wait_priority_after_sampling
)
self.memory = RandomAccessQueue(maxlen=capacity)
PriorityWeightError.__init__(
self,
alpha,
beta0,
betasteps,
eps,
normalize_by_max,
error_min=error_min,
error_max=error_max,
)

def clear(self):
self.initialize_memory(self.capacity, self.wait_priority_after_sampling,
self.alpha, self.beta0, self.betasteps, self.eps,
self.normalize_by_max, self.error_min, self.error_max)
14 changes: 10 additions & 4 deletions pfrl/replay_buffers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ def __init__(self, capacity: Optional[int] = None, num_steps: int = 1):
self.capacity = capacity
assert num_steps > 0
self.num_steps = num_steps
self.memory = RandomAccessQueue(maxlen=capacity)
self.last_n_transitions: collections.defaultdict = collections.defaultdict(
lambda: collections.deque([], maxlen=num_steps)
)
self.initialize_memory(capacity, num_steps)

def append(
self,
Expand Down Expand Up @@ -92,3 +89,12 @@ def load(self, filename):
if isinstance(self.memory, collections.deque):
# Load v0.2
self.memory = RandomAccessQueue(self.memory, maxlen=self.memory.maxlen)

def initialize_memory(self, capacity, num_steps):
self.memory = RandomAccessQueue(maxlen=capacity)
self.last_n_transitions: collections.defaultdict = collections.defaultdict(
lambda: collections.deque([], maxlen=num_steps)
)

def clear(self):
self.initialize_memory(self.capacity, self.num_steps)
11 changes: 11 additions & 0 deletions tests/replay_buffers_test/test_persistent_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def test_append_and_sample(self, capacity):
assert t0["next_state"] == t1["state"]
assert t0["next_action"] == t1["action"]

def test_clear(self, capacity):
rbuf = PersistentEpisodicReplayBuffer(self.tempdir.name, capacity)
with pytest.raises(NotImplementedError):
rbuf.clear()

def test_save_and_load(self, capacity):
tempdir = tempfile.mkdtemp()

Expand Down Expand Up @@ -172,3 +177,9 @@ def test(self):

# Finally it should have 4 + 2 + 9 = 15 transitions
assert len(rbuf) == 15

def test_clear(self):
rbuf = PersistentEpisodicReplayBuffer(self.tempdir.name, capacity=None)
with pytest.raises(NotImplementedError):
rbuf.clear()

86 changes: 86 additions & 0 deletions tests/replay_buffers_test/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ def test_append_and_sample(self):
assert s2[1] == list(correct_item)
assert s2[0] == list(correct_item2)

def test_clear(self):
capacity = self.capacity
num_steps = self.num_steps
rbuf = replay_buffers.ReplayBuffer(capacity, num_steps)

assert len(rbuf) == 0

# Add one and sample one
correct_item = collections.deque([], maxlen=num_steps)
for _ in range(num_steps):
trans1 = dict(
state=0,
action=1,
reward=2,
next_state=3,
next_action=4,
is_state_terminal=False,
)
correct_item.append(trans1)
rbuf.append(**trans1)
assert len(rbuf) == 1
rbuf.clear()
assert len(rbuf) == 0

def test_append_and_terminate(self):
capacity = self.capacity
num_steps = self.num_steps
Expand Down Expand Up @@ -248,6 +272,32 @@ def test_append_and_sample(self):
assert t0["next_state"] == t1["state"]
assert t0["next_action"] == t1["action"]

def test_clear(self):
capacity = self.capacity
rbuf = replay_buffers.EpisodicReplayBuffer(capacity)
assert len(rbuf) == 0
assert rbuf.n_episodes == 0
for n in [10, 15, 5] * 3:
transs = [
dict(
state=i,
action=100 + i,
reward=200 + i,
next_state=i + 1,
next_action=101 + i,
is_state_terminal=(i == n - 1),
)
for i in range(n)
]
for trans in transs:
rbuf.append(**trans)

assert len(rbuf) == 90
assert rbuf.n_episodes == 9
rbuf.clear()
assert len(rbuf) == 0
assert rbuf.n_episodes == 0

def test_save_and_load(self):
capacity = self.capacity

Expand Down Expand Up @@ -398,6 +448,36 @@ def test_append_and_sample(self):
s4 = rbuf.sample(2)
np.testing.assert_allclose(s4[0][0]["weight"], s4[1][0]["weight"])

def test_clear(self):
capacity = self.capacity
num_steps = self.num_steps
rbuf = replay_buffers.PrioritizedReplayBuffer(
capacity,
normalize_by_max=self.normalize_by_max,
error_max=5,
num_steps=num_steps,
)

# assert len(rbuf) == 0

# Add one and sample one
correct_item = collections.deque([], maxlen=num_steps)
for _ in range(num_steps):
trans1 = dict(
state=0,
action=1,
reward=2,
next_state=3,
next_action=4,
is_state_terminal=False,
)
correct_item.append(trans1)
rbuf.append(**trans1)
# assert len(rbuf) == 1
rbuf.update_errors([3.14])
rbuf.clear()
# assert len(rbuf) == 0

def test_normalize_by_max(self):

rbuf = replay_buffers.PrioritizedReplayBuffer(
Expand Down Expand Up @@ -701,6 +781,9 @@ def test(self):

# Finally it should have 9 + 2 + 4 = 15 transitions
assert len(rbuf) == 15
rbuf.clear()
assert len(rbuf) == 0
assert len(rbuf.last_n_transitions) == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -767,6 +850,9 @@ def test(self):

# Finally it should have 4 + 2 + 9 = 15 transitions
assert len(rbuf) == 15
rbuf.clear()
assert len(rbuf) == 0
assert len(rbuf.current_episode) == 0


class TestReplayBufferFail(unittest.TestCase):
Expand Down