From 84f426d1c2d9345651ab89ac40c41a2a3a8ca0d7 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Wed, 21 Apr 2021 22:14:38 +0900 Subject: [PATCH 1/8] Adds buffer clearing --- pfrl/replay_buffers/replay_buffer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pfrl/replay_buffers/replay_buffer.py b/pfrl/replay_buffers/replay_buffer.py index 0db496dd0..976ae9b03 100644 --- a/pfrl/replay_buffers/replay_buffer.py +++ b/pfrl/replay_buffers/replay_buffer.py @@ -25,7 +25,7 @@ def __init__(self, capacity: Optional[int] = None, num_steps: int = 1): self.capacity = capacity assert num_steps > 0 self.num_steps = num_steps - self.memory = RandomAccessQueue(maxlen=capacity) + self.initialize_memory(capacity) self.last_n_transitions: collections.defaultdict = collections.defaultdict( lambda: collections.deque([], maxlen=num_steps) ) @@ -92,3 +92,10 @@ def load(self, filename): if isinstance(self.memory, collections.deque): # Load v0.2 self.memory = RandomAccessQueue(self.memory, maxlen=self.memory.maxlen) + + def initialize_memory(self, capacity): + self.memory = RandomAccessQueue(maxlen=capacity) + + def clear(self): + self.initialize_memory(self.capacity) + From cd9699690cf16f13c42a1def7fef4de018199cfc Mon Sep 17 00:00:00 2001 From: Prabhat Date: Wed, 21 Apr 2021 22:35:28 +0900 Subject: [PATCH 2/8] Adds a test for buffer clearing --- .../replay_buffers_test/test_replay_buffer.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py index bf2b2b037..9b6c704e9 100644 --- a/tests/replay_buffers_test/test_replay_buffer.py +++ b/tests/replay_buffers_test/test_replay_buffer.py @@ -66,6 +66,30 @@ def test_append_and_sample(self): assert s2[1] == list(correct_item) assert s2[0] == list(correct_item2) + def test_clear(self): + capacity = self.capacity + num_steps = self.num_steps + rbuf = replay_buffers.ReplayBuffer(capacity, num_steps) + + assert len(rbuf) == 0 + + # Add one and sample one + correct_item = collections.deque([], maxlen=num_steps) + for _ in range(num_steps): + trans1 = dict( + state=0, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + correct_item.append(trans1) + rbuf.append(**trans1) + assert len(rbuf) == 1 + rbuf.clear() + assert len(rbuf) == 0 + def test_append_and_terminate(self): capacity = self.capacity num_steps = self.num_steps From 51dc3b72b05ca3c0e64edf59134ef958d1d2a414 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Wed, 21 Apr 2021 22:43:36 +0900 Subject: [PATCH 3/8] Applies black --- pfrl/replay_buffers/replay_buffer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pfrl/replay_buffers/replay_buffer.py b/pfrl/replay_buffers/replay_buffer.py index 976ae9b03..dda1a7940 100644 --- a/pfrl/replay_buffers/replay_buffer.py +++ b/pfrl/replay_buffers/replay_buffer.py @@ -98,4 +98,3 @@ def initialize_memory(self, capacity): def clear(self): self.initialize_memory(self.capacity) - From df9e2c0fe4d397006b434cf365fc10191940eac0 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Wed, 28 Apr 2021 19:29:18 +0900 Subject: [PATCH 4/8] Adds clear to episodic, adds tests, adds clear to abstract replay buffer --- pfrl/replay_buffer.py | 12 +++++++++ pfrl/replay_buffers/episodic.py | 11 +++++--- .../replay_buffers_test/test_replay_buffer.py | 26 +++++++++++++++++++ 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/pfrl/replay_buffer.py b/pfrl/replay_buffer.py index 7da0fd3f9..583528099 100644 --- a/pfrl/replay_buffer.py +++ b/pfrl/replay_buffer.py @@ -113,6 +113,12 @@ def stop_current_episode(self, env_id=0): """ raise NotImplementedError + @abstractmethod + def clear(self): + """Clears the replay buffer. + """ + raise NotImplementedError + class AbstractEpisodicReplayBuffer(AbstractReplayBuffer): """Defines a common interface of episodic replay buffer. @@ -145,6 +151,12 @@ def n_episodes(self): """ raise NotImplementedError + @abstractmethod + def clear(self): + """Clears the replay buffer. + """ + raise NotImplementedError + def random_subseq(seq, subseq_len): if len(seq) <= subseq_len: diff --git a/pfrl/replay_buffers/episodic.py b/pfrl/replay_buffers/episodic.py index 31e88b0e4..ef1fe7f0b 100644 --- a/pfrl/replay_buffers/episodic.py +++ b/pfrl/replay_buffers/episodic.py @@ -12,9 +12,8 @@ class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer): capacity: Optional[int] = None def __init__(self, capacity=None): - self.current_episode = collections.defaultdict(list) - self.episodic_memory = RandomAccessQueue() - self.memory = RandomAccessQueue() + # initialize data structures + self.clear() self.capacity = capacity def append( @@ -97,3 +96,9 @@ def stop_current_episode(self, env_id=0): for _ in range(len(discarded_episode)): self.memory.popleft() assert not self.current_episode[env_id] + + def clear(self): + self.current_episode = collections.defaultdict(list) + self.episodic_memory = RandomAccessQueue() + self.memory = RandomAccessQueue() + diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py index 9b6c704e9..4d84bada2 100644 --- a/tests/replay_buffers_test/test_replay_buffer.py +++ b/tests/replay_buffers_test/test_replay_buffer.py @@ -272,6 +272,32 @@ def test_append_and_sample(self): assert t0["next_state"] == t1["state"] assert t0["next_action"] == t1["action"] + def test_clear(self): + capacity = self.capacity + rbuf = replay_buffers.EpisodicReplayBuffer(capacity) + assert len(rbuf) == 0 + assert rbuf.n_episodes == 0 + for n in [10, 15, 5] * 3: + transs = [ + dict( + state=i, + action=100 + i, + reward=200 + i, + next_state=i + 1, + next_action=101 + i, + is_state_terminal=(i == n - 1), + ) + for i in range(n) + ] + for trans in transs: + rbuf.append(**trans) + + assert len(rbuf) == 90 + assert rbuf.n_episodes == 9 + rbuf.clear() + assert len(rbuf) == 0 + assert rbuf.n_episodes == 0 + def test_save_and_load(self): capacity = self.capacity From 51549830e2e94ff8a44aaa087e2a547f40ada6fd Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 30 Apr 2021 20:08:38 +0900 Subject: [PATCH 5/8] Makes persistent buffers raise an error when clear() is called --- pfrl/replay_buffers/episodic.py | 7 +++++-- pfrl/replay_buffers/persistent.py | 6 ++++++ .../test_persistent_replay_buffer.py | 11 +++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/pfrl/replay_buffers/episodic.py b/pfrl/replay_buffers/episodic.py index ef1fe7f0b..c480430b0 100644 --- a/pfrl/replay_buffers/episodic.py +++ b/pfrl/replay_buffers/episodic.py @@ -13,7 +13,7 @@ class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer): def __init__(self, capacity=None): # initialize data structures - self.clear() + self.initialize_memory() self.capacity = capacity def append( @@ -97,8 +97,11 @@ def stop_current_episode(self, env_id=0): self.memory.popleft() assert not self.current_episode[env_id] - def clear(self): + def initialize_memory(self): self.current_episode = collections.defaultdict(list) self.episodic_memory = RandomAccessQueue() self.memory = RandomAccessQueue() + def clear(self): + self.initialize_memory() + diff --git a/pfrl/replay_buffers/persistent.py b/pfrl/replay_buffers/persistent.py index 19342df2e..fe7dfea74 100644 --- a/pfrl/replay_buffers/persistent.py +++ b/pfrl/replay_buffers/persistent.py @@ -80,6 +80,9 @@ def load(self, _): "{}.load() has been ignored, as it is persistent replay buffer".format(self) ) + def clear(self): + raise NotImplementedError + class PersistentEpisodicReplayBuffer(EpisodicReplayBuffer): """Episodic version of :py:class:`PersistentReplayBuffer` @@ -163,3 +166,6 @@ def load(self, _): warnings.warn( "PersistentEpisodicReplayBuffer.load() is called but it has not effect." ) + + def clear(self): + raise NotImplementedError diff --git a/tests/replay_buffers_test/test_persistent_replay_buffer.py b/tests/replay_buffers_test/test_persistent_replay_buffer.py index ce94ccfa4..0a5015788 100644 --- a/tests/replay_buffers_test/test_persistent_replay_buffer.py +++ b/tests/replay_buffers_test/test_persistent_replay_buffer.py @@ -51,6 +51,11 @@ def test_append_and_sample(self, capacity): assert t0["next_state"] == t1["state"] assert t0["next_action"] == t1["action"] + def test_clear(self, capacity): + rbuf = PersistentEpisodicReplayBuffer(self.tempdir.name, capacity) + with pytest.raises(NotImplementedError): + rbuf.clear() + def test_save_and_load(self, capacity): tempdir = tempfile.mkdtemp() @@ -172,3 +177,9 @@ def test(self): # Finally it should have 4 + 2 + 9 = 15 transitions assert len(rbuf) == 15 + + def test_clear(self): + rbuf = PersistentEpisodicReplayBuffer(self.tempdir.name, capacity=None) + with pytest.raises(NotImplementedError): + rbuf.clear() + From 26efc9b0ee81fff4759a1a5a66c094690cb49678 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 30 Apr 2021 20:53:07 +0900 Subject: [PATCH 6/8] Adds tests for replay buffers with env ids --- pfrl/replay_buffers/replay_buffer.py | 13 +++++++------ tests/replay_buffers_test/test_replay_buffer.py | 6 ++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pfrl/replay_buffers/replay_buffer.py b/pfrl/replay_buffers/replay_buffer.py index dda1a7940..bce2aee75 100644 --- a/pfrl/replay_buffers/replay_buffer.py +++ b/pfrl/replay_buffers/replay_buffer.py @@ -25,10 +25,7 @@ def __init__(self, capacity: Optional[int] = None, num_steps: int = 1): self.capacity = capacity assert num_steps > 0 self.num_steps = num_steps - self.initialize_memory(capacity) - self.last_n_transitions: collections.defaultdict = collections.defaultdict( - lambda: collections.deque([], maxlen=num_steps) - ) + self.initialize_memory(capacity, num_steps) def append( self, @@ -93,8 +90,12 @@ def load(self, filename): # Load v0.2 self.memory = RandomAccessQueue(self.memory, maxlen=self.memory.maxlen) - def initialize_memory(self, capacity): + def initialize_memory(self, capacity, num_steps): self.memory = RandomAccessQueue(maxlen=capacity) + self.last_n_transitions: collections.defaultdict = collections.defaultdict( + lambda: collections.deque([], maxlen=num_steps) + ) def clear(self): - self.initialize_memory(self.capacity) + self.initialize_memory(self.capacity, self.num_steps) + diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py index 4d84bada2..ef7ec6f5b 100644 --- a/tests/replay_buffers_test/test_replay_buffer.py +++ b/tests/replay_buffers_test/test_replay_buffer.py @@ -751,6 +751,9 @@ def test(self): # Finally it should have 9 + 2 + 4 = 15 transitions assert len(rbuf) == 15 + rbuf.clear() + assert len(rbuf) == 0 + assert len(rbuf.last_n_transitions) == 0 @pytest.mark.parametrize( @@ -817,6 +820,9 @@ def test(self): # Finally it should have 4 + 2 + 9 = 15 transitions assert len(rbuf) == 15 + rbuf.clear() + assert len(rbuf) == 0 + assert len(rbuf.current_episode) == 0 class TestReplayBufferFail(unittest.TestCase): From 219d1a783b7984c4643e1c189ed1eedd83473e94 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Mon, 3 May 2021 15:11:05 +0900 Subject: [PATCH 7/8] Adds resetting of priority weight error for prioritized buffers --- pfrl/replay_buffers/prioritized.py | 31 +++++++++++++++++++--------- pfrl/replay_buffers/replay_buffer.py | 1 - 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/pfrl/replay_buffers/prioritized.py b/pfrl/replay_buffers/prioritized.py index 5814d2e71..3ca014469 100644 --- a/pfrl/replay_buffers/prioritized.py +++ b/pfrl/replay_buffers/prioritized.py @@ -99,6 +99,24 @@ def __init__( self.capacity = capacity assert num_steps > 0 self.num_steps = num_steps + self.beta0 = beta0 + self.betasteps = betasteps + self.initialize_memory(capacity, num_steps, alpha, beta0, betasteps, + eps, normalize_by_max, error_min, error_max) + + def sample(self, n): + assert len(self.memory) >= n + sampled, probabilities, min_prob = self.memory.sample(n) + weights = self.weights_from_probabilities(probabilities, min_prob) + for e, w in zip(sampled, weights): + e[0]["weight"] = w + return sampled + + def update_errors(self, errors): + self.memory.set_last_priority(self.priority_from_errors(errors)) + + def initialize_memory(self, capacity, num_steps, alpha, beta0, betasteps, + eps, normalize_by_max, error_min, error_max): self.memory = PrioritizedBuffer(capacity=capacity) self.last_n_transitions = collections.defaultdict( lambda: collections.deque([], maxlen=num_steps) @@ -114,13 +132,6 @@ def __init__( error_max=error_max, ) - def sample(self, n): - assert len(self.memory) >= n - sampled, probabilities, min_prob = self.memory.sample(n) - weights = self.weights_from_probabilities(probabilities, min_prob) - for e, w in zip(sampled, weights): - e[0]["weight"] = w - return sampled - - def update_errors(self, errors): - self.memory.set_last_priority(self.priority_from_errors(errors)) + def clear(self): + self.initialize_memory(self.capacity, self.num_steps, self.alpha, self.beta0, self.betasteps, + self.eps, self.normalize_by_max, self.error_min, self.error_max) diff --git a/pfrl/replay_buffers/replay_buffer.py b/pfrl/replay_buffers/replay_buffer.py index bce2aee75..060d538cd 100644 --- a/pfrl/replay_buffers/replay_buffer.py +++ b/pfrl/replay_buffers/replay_buffer.py @@ -98,4 +98,3 @@ def initialize_memory(self, capacity, num_steps): def clear(self): self.initialize_memory(self.capacity, self.num_steps) - From 41274ebafb07d84407bd13db7f05c2f1358c3e2e Mon Sep 17 00:00:00 2001 From: Prabhat Date: Wed, 26 May 2021 19:09:36 +0900 Subject: [PATCH 8/8] Adds additional support for buffer clearing --- pfrl/replay_buffers/prioritized_episodic.py | 45 ++++++++++++------- .../replay_buffers_test/test_replay_buffer.py | 30 +++++++++++++ 2 files changed, 60 insertions(+), 15 deletions(-) diff --git a/pfrl/replay_buffers/prioritized_episodic.py b/pfrl/replay_buffers/prioritized_episodic.py index e31a74863..652ce18cb 100644 --- a/pfrl/replay_buffers/prioritized_episodic.py +++ b/pfrl/replay_buffers/prioritized_episodic.py @@ -22,25 +22,16 @@ def __init__( error_min=None, error_max=None, ): - self.current_episode = collections.defaultdict(list) - self.episodic_memory = PrioritizedBuffer( - capacity=None, wait_priority_after_sampling=wait_priority_after_sampling - ) - self.memory = RandomAccessQueue(maxlen=capacity) + self.initialize_memory(capacity) + self.capacity = capacity + self.wait_priority_after_sampling = wait_priority_after_sampling + self.beta0 = beta0 + self.betasteps = betasteps + self.capacity_left = capacity self.default_priority_func = default_priority_func self.uniform_ratio = uniform_ratio self.return_sample_weights = return_sample_weights - PriorityWeightError.__init__( - self, - alpha, - beta0, - betasteps, - eps, - normalize_by_max, - error_min=error_min, - error_max=error_max, - ) def sample_episodes(self, n_episodes, max_len=None): """Sample n unique samples from this replay buffer""" @@ -75,3 +66,27 @@ def stop_current_episode(self, env_id=0): discarded_episode = self.episodic_memory.popleft() self.capacity_left += len(discarded_episode) assert not self.current_episode[env_id] + + def initialize_memory(self, capacity, wait_priority_after_sampling, + alpha, beta0, betasteps, eps, normalize_by_max, + error_min, error_max): + self.current_episode = collections.defaultdict(list) + self.episodic_memory = PrioritizedBuffer( + capacity=None, wait_priority_after_sampling=wait_priority_after_sampling + ) + self.memory = RandomAccessQueue(maxlen=capacity) + PriorityWeightError.__init__( + self, + alpha, + beta0, + betasteps, + eps, + normalize_by_max, + error_min=error_min, + error_max=error_max, + ) + + def clear(self): + self.initialize_memory(self.capacity, self.wait_priority_after_sampling, + self.alpha, self.beta0, self.betasteps, self.eps, + self.normalize_by_max, self.error_min, self.error_max) diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py index ef7ec6f5b..00c90c4c1 100644 --- a/tests/replay_buffers_test/test_replay_buffer.py +++ b/tests/replay_buffers_test/test_replay_buffer.py @@ -448,6 +448,36 @@ def test_append_and_sample(self): s4 = rbuf.sample(2) np.testing.assert_allclose(s4[0][0]["weight"], s4[1][0]["weight"]) + def test_clear(self): + capacity = self.capacity + num_steps = self.num_steps + rbuf = replay_buffers.PrioritizedReplayBuffer( + capacity, + normalize_by_max=self.normalize_by_max, + error_max=5, + num_steps=num_steps, + ) + + # assert len(rbuf) == 0 + + # Add one and sample one + correct_item = collections.deque([], maxlen=num_steps) + for _ in range(num_steps): + trans1 = dict( + state=0, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + correct_item.append(trans1) + rbuf.append(**trans1) + # assert len(rbuf) == 1 + rbuf.update_errors([3.14]) + rbuf.clear() + # assert len(rbuf) == 0 + def test_normalize_by_max(self): rbuf = replay_buffers.PrioritizedReplayBuffer(