From 8f3fb8e193f6f44c46ab3ede19cdd1af3b06ec89 Mon Sep 17 00:00:00 2001 From: Mishari Date: Tue, 27 Oct 2020 13:28:14 -0600 Subject: [PATCH] Add PER --- examples/torch/dqn_atari.py | 16 +- src/garage/envs/gym_env.py | 1 - src/garage/replay_buffer/__init__.py | 3 +- src/garage/replay_buffer/path_buffer.py | 14 +- src/garage/replay_buffer/per_replay_buffer.py | 145 ++++++++++++++++++ src/garage/replay_buffer/replay_buffer.py | 3 + src/garage/tf/algos/ddpg.py | 2 +- src/garage/tf/algos/dqn.py | 2 +- src/garage/tf/algos/td3.py | 2 +- src/garage/torch/algos/ddpg.py | 6 +- src/garage/torch/algos/dqn.py | 32 +++- src/garage/torch/algos/pearl.py | 4 +- src/garage/torch/algos/sac.py | 2 +- src/garage/torch/algos/td3.py | 2 +- .../replay_buffer/test_her_replay_buffer.py | 4 +- .../garage/replay_buffer/test_path_buffer.py | 10 +- .../replay_buffer/test_per_replay_buffer.py | 120 +++++++++++++++ tests/garage/torch/algos/test_dqn.py | 4 +- 18 files changed, 337 insertions(+), 35 deletions(-) create mode 100644 src/garage/replay_buffer/per_replay_buffer.py create mode 100644 tests/garage/replay_buffer/test_per_replay_buffer.py diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py index bc2707d5b0..cf4e740c91 100755 --- a/examples/torch/dqn_atari.py +++ b/examples/torch/dqn_atari.py @@ -23,7 +23,7 @@ from garage.envs.wrappers.stack_frames import StackFrames from garage.experiment.deterministic import set_seed from garage.np.exploration_policies import EpsilonGreedyPolicy -from garage.replay_buffer import PathBuffer +from garage.replay_buffer import PERReplayBuffer from garage.sampler import FragmentWorker, LocalSampler from garage.torch import set_gpu_mode from garage.torch.algos import DQN @@ -40,6 +40,9 @@ n_train_steps=125, target_update_freq=2, buffer_batch_size=32, + double_q=True, + per_beta_init=0.4, + per_alpha=0.6, max_epsilon=1.0, min_epsilon=0.01, decay_ratio=0.1, @@ -104,7 +107,7 @@ def main(env=None, # pylint: disable=unused-argument -@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30) +@wrap_experiment(snapshot_mode='none') def dqn_atari(ctxt=None, env=None, seed=24, @@ -150,8 +153,12 @@ def dqn_atari(ctxt=None, steps_per_epoch = hyperparams['steps_per_epoch'] sampler_batch_size = hyperparams['sampler_batch_size'] num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size - replay_buffer = PathBuffer( - capacity_in_transitions=hyperparams['buffer_size']) + + replay_buffer = PERReplayBuffer(hyperparams['buffer_size'], + num_timesteps, + env.spec, + alpha=hyperparams['per_alpha'], + beta_init=hyperparams['per_beta_init']) qf = DiscreteCNNQFunction( env_spec=env.spec, @@ -179,6 +186,7 @@ def dqn_atari(ctxt=None, replay_buffer=replay_buffer, steps_per_epoch=steps_per_epoch, qf_lr=hyperparams['lr'], + double_q=hyperparams['double_q'], clip_gradient=hyperparams['clip_gradient'], discount=hyperparams['discount'], min_buffer_size=hyperparams['min_buffer_size'], diff --git a/src/garage/envs/gym_env.py b/src/garage/envs/gym_env.py index 321fe0ecaa..4d5cf8f40a 100644 --- a/src/garage/envs/gym_env.py +++ b/src/garage/envs/gym_env.py @@ -13,7 +13,6 @@ # entry points don't close their viewer windows. 
KNOWN_GYM_NOT_CLOSE_VIEWER = [ # Please keep alphabetized - 'gym.envs.atari', 'gym.envs.box2d', 'gym.envs.classic_control' ] diff --git a/src/garage/replay_buffer/__init__.py b/src/garage/replay_buffer/__init__.py index c69c5a3780..6d0e8daa11 100644 --- a/src/garage/replay_buffer/__init__.py +++ b/src/garage/replay_buffer/__init__.py @@ -4,6 +4,7 @@ """ from garage.replay_buffer.her_replay_buffer import HERReplayBuffer from garage.replay_buffer.path_buffer import PathBuffer +from garage.replay_buffer.per_replay_buffer import PERReplayBuffer from garage.replay_buffer.replay_buffer import ReplayBuffer -__all__ = ['ReplayBuffer', 'HERReplayBuffer', 'PathBuffer'] +__all__ = ['PERReplayBuffer', 'ReplayBuffer', 'HERReplayBuffer', 'PathBuffer'] diff --git a/src/garage/replay_buffer/path_buffer.py b/src/garage/replay_buffer/path_buffer.py index 0f7d43e3f5..3d0adfa3c2 100644 --- a/src/garage/replay_buffer/path_buffer.py +++ b/src/garage/replay_buffer/path_buffer.py @@ -119,10 +119,15 @@ def sample_transitions(self, batch_size): Returns: dict: A dict of arrays of shape (batch_size, flat_dim). + np.ndarray: Weights of the timesteps. + np.ndarray: Indices of sampled timesteps + in the replay buffer. """ idx = np.random.randint(self._transitions_stored, size=batch_size) - return {key: buf_arr[idx] for key, buf_arr in self._buffer.items()} + w = np.ones(batch_size) + data = {key: buf_arr[idx] for key, buf_arr in self._buffer.items()} + return data, w, idx def sample_timesteps(self, batch_size): """Sample a batch of timesteps from the buffer. @@ -132,9 +137,12 @@ def sample_timesteps(self, batch_size): Returns: TimeStepBatch: The batch of timesteps. + np.ndarray: Weights of the timesteps. + np.ndarray: Indices of sampled timesteps + in the replay buffer. """ - samples = self.sample_transitions(batch_size) + samples, w, idx = self.sample_transitions(batch_size) step_types = np.array([ StepType.TERMINAL if terminal else StepType.MID for terminal in samples['terminals'].reshape(-1) ],
 dtype=StepType) return TimeStepBatch(env_spec=self._env_spec, observations=samples['observations'], actions=samples['actions'], rewards=samples['rewards'], next_observations=samples['next_observations'], step_types=step_types, env_infos={}, - agent_infos={}) + agent_infos={}), w, idx def _next_path_segments(self, n_indices): """Compute where the next path should be stored. diff --git a/src/garage/replay_buffer/per_replay_buffer.py b/src/garage/replay_buffer/per_replay_buffer.py new file mode 100644 index 0000000000..c51bd57ac7 --- /dev/null +++ b/src/garage/replay_buffer/per_replay_buffer.py @@ -0,0 +1,145 @@ +"""Prioritized Experience Replay.""" + +import numpy as np + +from garage import StepType, TimeStepBatch +from garage.replay_buffer.path_buffer import PathBuffer + + +class PERReplayBuffer(PathBuffer): + """Replay buffer for PER (Prioritized Experience Replay). + + PER assigns priorities to transitions in the buffer. Typically, + the priority of each transition is proportional to the corresponding + loss computed at each update step. The priorities are then used to create + a probability distribution when sampling such that higher-priority + transitions are sampled more frequently. For more details, see + https://arxiv.org/abs/1511.05952. + + Args: + capacity_in_transitions (int): Maximum number of transitions to store in the buffer. + env_spec (EnvSpec): Environment specification. + total_timesteps (int): Total timesteps the experiment will run for. + This is used to calculate the beta parameter when sampling. + alpha (float): Hyperparameter that controls the degree of + prioritization. 
Typically between [0, 1], where 0 corresponds to + no prioritization (uniform sampling). + beta_init (float): Initial value of beta exponent in importance + sampling. Beta is linearly annealed from beta_init to 1 + over total_timesteps. + """ + + def __init__(self, + capacity_in_transitions, + total_timesteps, + env_spec, + alpha=0.6, + beta_init=0.5): + self._alpha = alpha + self._beta_init = beta_init + self._total_timesteps = total_timesteps + self._curr_timestep = 0 + self._priorities = np.zeros((capacity_in_transitions, ), np.float32) + self._rng = np.random.default_rng() + super().__init__(capacity_in_transitions, env_spec) + + def sample_timesteps(self, batch_size): + """Sample a batch of timesteps from the buffer. + + Args: + batch_size (int): Number of timesteps to sample. + + Returns: + TimeStepBatch: The batch of timesteps. + np.ndarray: Weights of the timesteps. + np.ndarray: Indices of sampled timesteps + in the replay buffer. + + """ + samples, w, idx = self.sample_transitions(batch_size) + step_types = np.array([ + StepType.TERMINAL if terminal else StepType.MID + for terminal in samples['terminals'].reshape(-1) + ], + dtype=StepType) + return TimeStepBatch(env_spec=self._env_spec, + observations=samples['observations'], + actions=samples['actions'], + rewards=samples['rewards'], + next_observations=samples['next_observations'], + step_types=step_types, + env_infos={}, + agent_infos={}), w, idx + + def sample_transitions(self, batch_size): + """Sample a batch of transitions from the buffer. + + Args: + batch_size (int): Number of transitions to sample. + + Returns: + dict: A dict of arrays of shape (batch_size, flat_dim). + np.ndarray: Weights of the timesteps. + np.ndarray: Indices of sampled timesteps + in the replay buffer. + + """ + priorities = self._priorities + if self._transitions_stored < self._capacity: + priorities = self._priorities[:self._transitions_stored] + probs = priorities**self._alpha + probs /= probs.sum() + idx = self._rng.choice(self._transitions_stored, + size=batch_size, + p=probs) + + beta = self._beta_init + self._curr_timestep * ( + 1.0 - self._beta_init) / self._total_timesteps + beta = min(1.0, beta) + transitions = { + key: buf_arr[idx] + for key, buf_arr in self._buffer.items() + } + + w = (self._transitions_stored * probs[idx])**(-beta) + w /= w.max() + w = np.array(w) + + return transitions, w, idx + + def update_priorities(self, indices, priorities): + """Update priorities of timesteps. + + Args: + indices (np.ndarray): Array of indices corresponding to the + timesteps/priorities to update. + priorities (list[float]): new priorities to set. + + """ + for idx, priority in zip(indices, priorities): + self._priorities[int(idx)] = priority + + def add_path(self, path): + """Add a path to the buffer. + + This differs from the underlying buffer's add_path method + in that the priorities for the new samples are set to + the maximum of all priorities in the buffer. + + Args: + path (dict): A dict of array of shape (path_len, flat_dim). + + """ + path_len = len(path['observations']) + self._curr_timestep += path_len + + # find the indices where the path will be stored + first_seg, second_seg = self._next_path_segments(path_len) + + # set priorities for new timesteps = max(self._priorities) + # or 1 if buffer is empty + max_priority = self._priorities.max() or 1. 
+ self._priorities[first_seg.start:first_seg.stop] = max_priority + if second_seg != range(0, 0): + self._priorities[second_seg.start:second_seg.stop] = max_priority + super().add_path(path) diff --git a/src/garage/replay_buffer/replay_buffer.py b/src/garage/replay_buffer/replay_buffer.py index d470755248..e3e4953265 100644 --- a/src/garage/replay_buffer/replay_buffer.py +++ b/src/garage/replay_buffer/replay_buffer.py @@ -55,6 +55,9 @@ def sample(self, batch_size): Args: batch_size(int): The number of transitions to be sampled. + Returns: + np.ndarray: Weights of the sampled timesteps. + np.ndarray: Indices of the sampled timesteps in the replay buffer. """ raise NotImplementedError diff --git a/src/garage/tf/algos/ddpg.py b/src/garage/tf/algos/ddpg.py index 0da0ee42ba..a058a34f32 100644 --- a/src/garage/tf/algos/ddpg.py +++ b/src/garage/tf/algos/ddpg.py @@ -350,7 +350,7 @@ def _optimize_policy(self): float: Q value predicted by the q network. """ - timesteps = self._replay_buffer.sample_timesteps( + timesteps, _, _ = self._replay_buffer.sample_timesteps( self._buffer_batch_size) observations = timesteps.observations diff --git a/src/garage/tf/algos/dqn.py b/src/garage/tf/algos/dqn.py index 6c5ca95409..dd968fd8f1 100644 --- a/src/garage/tf/algos/dqn.py +++ b/src/garage/tf/algos/dqn.py @@ -258,7 +258,7 @@ def _optimize_policy(self): numpy.float64: Loss of policy. """ - timesteps = self._replay_buffer.sample_timesteps( + timesteps, _, _ = self._replay_buffer.sample_timesteps( self._buffer_batch_size) observations = timesteps.observations diff --git a/src/garage/tf/algos/td3.py b/src/garage/tf/algos/td3.py index 95fef35f06..f1ca997cde 100644 --- a/src/garage/tf/algos/td3.py +++ b/src/garage/tf/algos/td3.py @@ -371,7 +371,7 @@ def _optimize_policy(self, itr): float: Q value predicted by the q network. 
""" - timesteps = self._replay_buffer.sample_timesteps( + timesteps, _, _ = self._replay_buffer.sample_timesteps( self._buffer_batch_size) observations = timesteps.observations diff --git a/src/garage/torch/algos/ddpg.py b/src/garage/torch/algos/ddpg.py index 1c68ede465..e9166b445f 100644 --- a/src/garage/torch/algos/ddpg.py +++ b/src/garage/torch/algos/ddpg.py @@ -6,9 +6,7 @@ import numpy as np import torch -from garage import (_Default, - log_performance, - make_optimizer, +from garage import (_Default, log_performance, make_optimizer, obtain_evaluation_episodes) from garage.np.algos import RLAlgorithm from garage.sampler import FragmentWorker, LocalSampler @@ -188,7 +186,7 @@ def train_once(self, itr, episodes): for _ in range(self._n_train_steps): if (self.replay_buffer.n_transitions_stored >= self._min_buffer_size): - samples = self.replay_buffer.sample_transitions( + samples, _, _ = self.replay_buffer.sample_transitions( self._buffer_batch_size) samples['rewards'] *= self._reward_scale qf_loss, y, q, policy_loss = torch_to_np( diff --git a/src/garage/torch/algos/dqn.py b/src/garage/torch/algos/dqn.py index 5419c55218..d6ce5d6f4a 100644 --- a/src/garage/torch/algos/dqn.py +++ b/src/garage/torch/algos/dqn.py @@ -10,8 +10,9 @@ from garage import _Default, log_performance, make_optimizer from garage._functions import obtain_evaluation_episodes from garage.np.algos import RLAlgorithm +from garage.replay_buffer import PERReplayBuffer from garage.sampler import FragmentWorker -from garage.torch import global_device, np_to_torch +from garage.torch import global_device, np_to_torch, torch_to_np class DQN(RLAlgorithm): @@ -122,6 +123,9 @@ def __init__( self._qf_optimizer = make_optimizer(qf_optimizer, module=self._qf, lr=qf_lr) + + self._prioritized_replay = isinstance(self.replay_buffer, + PERReplayBuffer) self._eval_env = eval_env def train(self, trainer): @@ -192,10 +196,12 @@ def _train_once(self, itr, episodes): for _ in range(self._n_train_steps): if (self.replay_buffer.n_transitions_stored >= self._min_buffer_size): - timesteps = self.replay_buffer.sample_timesteps( - self._buffer_batch_size) - qf_loss, y, q = tuple(v.cpu().numpy() - for v in self._optimize_qf(timesteps)) + timesteps, weights, indices = ( + self.replay_buffer.sample_timesteps( + self._buffer_batch_size)) + qf_loss, y, q = tuple( + v.cpu().numpy() + for v in self._optimize_qf(timesteps, weights, indices)) self._episode_qf_losses.append(qf_loss) self._epoch_ys.append(y) @@ -228,11 +234,15 @@ def _log_eval_results(self, epoch): tabular.record('QFunction/AverageAbsY', np.mean(np.abs(self._epoch_ys))) - def _optimize_qf(self, timesteps): + def _optimize_qf(self, timesteps, weights=None, indices=None): """Perform algorithm optimizing. Args: timesteps (TimeStepBatch): Processed batch data. + weights (np.ndarray[float]): Weights used by PER when updating + the network. + indices (list[int or float]): Indices of the sampled + timesteps in the replay buffer. Returns: qval_loss: Loss of Q-value predicted by the Q-network. 
@@ -274,7 +284,15 @@ def _optimize_qf(self, timesteps): # optimize qf qvals = self._qf(inputs) selected_qs = torch.sum(qvals * actions, axis=1) - qval_loss = F.smooth_l1_loss(selected_qs, y_target) + qval_loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none') + + if self._prioritized_replay: + qval_loss *= np_to_torch(weights) + priorities = qval_loss + 1e-5 # offset to avoid 0 priorities + priorities = torch_to_np(priorities.data.cpu()) + self.replay_buffer.update_priorities(indices, priorities) + + qval_loss = qval_loss.mean() self._qf_optimizer.zero_grad() qval_loss.backward() diff --git a/src/garage/torch/algos/pearl.py b/src/garage/torch/algos/pearl.py index 1be8392935..acc000fa5c 100644 --- a/src/garage/torch/algos/pearl.py +++ b/src/garage/torch/algos/pearl.py @@ -478,7 +478,7 @@ def _sample_data(self, indices): # transitions sampled randomly from replay buffer initialized = False for idx in indices: - batch = self._replay_buffers[idx].sample_transitions( + batch, _, _ = self._replay_buffers[idx].sample_transitions( self._batch_size) if not initialized: o = batch['observations'][np.newaxis] @@ -522,7 +522,7 @@ def _sample_context(self, indices): initialized = False for idx in indices: - batch = self._context_replay_buffers[idx].sample_transitions( + batch, _, _ = self._context_replay_buffers[idx].sample_transitions( self._embedding_batch_size) o = batch['observations'] a = batch['actions'] diff --git a/src/garage/torch/algos/sac.py b/src/garage/torch/algos/sac.py index 3ac0c7e814..066e42d30a 100644 --- a/src/garage/torch/algos/sac.py +++ b/src/garage/torch/algos/sac.py @@ -234,7 +234,7 @@ def train_once(self, itr=None, paths=None): del itr del paths if self.replay_buffer.n_transitions_stored >= self._min_buffer_size: - samples = self.replay_buffer.sample_transitions( + samples, _, _ = self.replay_buffer.sample_transitions( self._buffer_batch_size) samples = dict_np_to_torch(samples) policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples) diff --git a/src/garage/torch/algos/td3.py b/src/garage/torch/algos/td3.py index d1bdc7550e..7f1db668ca 100644 --- a/src/garage/torch/algos/td3.py +++ b/src/garage/torch/algos/td3.py @@ -236,7 +236,7 @@ def _train_once(self, itr): if (self._replay_buffer.n_transitions_stored >= self._min_buffer_size): # Sample from buffer - samples = self._replay_buffer.sample_transitions( + samples, _, _ = self._replay_buffer.sample_transitions( self._buffer_batch_size) samples = dict_np_to_torch(samples) diff --git a/tests/garage/replay_buffer/test_her_replay_buffer.py b/tests/garage/replay_buffer/test_her_replay_buffer.py index fbc2686249..e68e86c3e9 100644 --- a/tests/garage/replay_buffer/test_her_replay_buffer.py +++ b/tests/garage/replay_buffer/test_her_replay_buffer.py @@ -81,8 +81,8 @@ def test_pickleable(self): for k in replay_buffer_pickled._buffer: assert replay_buffer_pickled._buffer[ k].shape == self.replay_buffer._buffer[k].shape - sample = self.replay_buffer.sample_transitions(1) - sample2 = replay_buffer_pickled.sample_transitions(1) + sample, _, _ = self.replay_buffer.sample_transitions(1) + sample2, _, _ = replay_buffer_pickled.sample_transitions(1) for k in sample.keys(): assert sample[k].shape == sample2[k].shape assert len(sample) == len(sample2) diff --git a/tests/garage/replay_buffer/test_path_buffer.py b/tests/garage/replay_buffer/test_path_buffer.py index e44e8bd5af..a89c6f619e 100644 --- a/tests/garage/replay_buffer/test_path_buffer.py +++ b/tests/garage/replay_buffer/test_path_buffer.py @@ -66,7 +66,7 @@ def 
test_add_path_dtype(self): 'actions': np.array([[env.action_space.sample()]]) }) - sample = replay_buffer.sample_transitions(1) + sample, _, _ = replay_buffer.sample_transitions(1) sample_obs = sample['observations'] sample_action = sample['actions'] @@ -77,7 +77,7 @@ def test_episode_batch_to_timestep_batch(self, eps_data): t = EpisodeBatch(**eps_data) replay_buffer = PathBuffer(capacity_in_transitions=100) replay_buffer.add_episode_batch(t) - timesteps = replay_buffer.sample_timesteps(10) + timesteps, _, _ = replay_buffer.sample_timesteps(10) assert len(timesteps.rewards) == 10 def test_eviction_policy(self): @@ -85,7 +85,8 @@ def test_eviction_policy(self): replay_buffer = PathBuffer(capacity_in_transitions=3) replay_buffer.add_path(dict(obs=obs)) - sampled_obs = replay_buffer.sample_transitions(3)['obs'] + samples_data, _, _ = replay_buffer.sample_transitions(3) + sampled_obs = samples_data['obs'] assert (sampled_obs == np.array([[1], [1], [1]])).all() sampled_path_obs = replay_buffer.sample_path()['obs'] @@ -106,7 +107,8 @@ def test_eviction_policy(self): assert replay_buffer.add_path(dict(obs=obs4)) # Can still sample from old path - new_sampled_obs = replay_buffer.sample_transitions(1000)['obs'] + new_samples_data, _, _ = replay_buffer.sample_transitions(1000) + new_sampled_obs = new_samples_data['obs'] assert set(new_sampled_obs.flatten()) == {1, 2, 3} # Can't sample complete old path diff --git a/tests/garage/replay_buffer/test_per_replay_buffer.py b/tests/garage/replay_buffer/test_per_replay_buffer.py new file mode 100644 index 0000000000..1f47f12529 --- /dev/null +++ b/tests/garage/replay_buffer/test_per_replay_buffer.py @@ -0,0 +1,120 @@ +import akro +import numpy as np +import pytest + +from garage import EnvSpec, EpisodeBatch, StepType +from garage.replay_buffer import PERReplayBuffer + +from tests.fixtures.envs.dummy import DummyDiscreteEnv + + +@pytest.fixture +def setup(): + obs_space = akro.Box(low=1, high=np.inf, shape=(1, ), dtype=np.float32) + act_space = akro.Discrete(1) + env_spec = EnvSpec(obs_space, act_space) + buffer = PERReplayBuffer(100, 100, env_spec) + return buffer, DummyDiscreteEnv() + + +def test_add_path(setup): + buff, env = setup + obs = env.reset() + buff.add_path({'observations': np.array([obs for _ in range(5)])}) + + # initial priorities for inserted timesteps should be 1 + assert (buff._priorities[:5] == 1.).all() + assert (buff._priorities[5:] == 0.).all() + + # test case where buffer is full and paths are split + # into two segments + num_obs = buff._capacity - buff._transitions_stored + buff.add_path( + {'observations': np.array([obs for _ in range(num_obs - 1)])}) + + # artificially set the priority of a transition to be high . + # the next path added to the buffer should wrap around the buffer + # and contain one timestep at the end and 5 at the beginning, all + # of which should have priority == max(buff._priorities). + buff._priorities[-1] = 100. + buff.add_path({'observations': np.array([obs for _ in range(6)])}) + + assert buff._priorities[-1] == 100. 
+ assert (buff._priorities[:5] == 100.).all() + + +def test_update_priorities(setup): + buff, env = setup + obs = env.reset() + buff.add_path({'observations': np.array([obs for _ in range(5)])}) + + assert (buff._priorities[:5] == 1.).all() + assert (buff._priorities[5:] == 0.).all() + + indices = list(range(2, 10)) + new_priorities = [0.5 for _ in range(2, 10)] + buff.update_priorities(indices, new_priorities) + + assert (buff._priorities[2:10] == 0.5).all() + assert (buff._priorities[:2] != 0.5).all() + assert (buff._priorities[10:] != 0.5).all() + + +@pytest.mark.parametrize('alpha, beta_init', [(0.5, 0.5), (0.4, 0.6), + (0.1, 0.9)]) +def test_sample_transitions(setup, alpha, beta_init): + buff, env = setup + obs = env.reset() + buff.add_path({ + 'observations': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + }) + + buff._beta_init = beta_init + buff._alpha = alpha + transitions, weights, indices = buff.sample_transitions(50) + obses = transitions['observations'] + + # verify the indices returned correspond to the correct samples + for o, i in zip(obses, indices): + assert (o == i).all() + + # verify the weights are correct + probs = buff._priorities**buff._alpha + probs /= probs.sum() + + beta = buff._beta_init + 50 * (1.0 - buff._beta_init) / 100 + beta = min(1.0, beta) + w = (50 * probs[indices])**(-beta) + w /= w.max() + w = np.array(w) + + assert (w == weights).all() + + +def test_sample_timesteps(setup): + buff, env = setup + obs = env.reset() + buff.add_path({ + 'observations': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + 'next_observations': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + 'actions': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + 'terminals': + np.array([[False] for _ in range(50)]), + 'rewards': + np.array([[1] for _ in range(50)]) + }) + + timesteps, weights, indices = buff.sample_timesteps(50) + + assert len(weights) == 50 + assert len(indices) == 50 + + obses, actions = timesteps.observations, timesteps.actions + + for a, o, i in zip(actions, obses, indices): + assert (o == i).all() + assert (a == i).all() diff --git a/tests/garage/torch/algos/test_dqn.py b/tests/garage/torch/algos/test_dqn.py index 03423633f7..f66768171a 100644 --- a/tests/garage/torch/algos/test_dqn.py +++ b/tests/garage/torch/algos/test_dqn.py @@ -86,7 +86,7 @@ def test_dqn_loss(setup): paths = trainer.obtain_episodes(0, batch_size=batch_size) buff.add_episode_batch(paths) - timesteps = buff.sample_timesteps(algo._buffer_batch_size) + timesteps, _, _ = buff.sample_timesteps(algo._buffer_batch_size) timesteps_copy = copy.deepcopy(timesteps) observations = np_to_torch(timesteps.observations) @@ -130,7 +130,7 @@ def test_double_dqn_loss(setup): paths = trainer.obtain_episodes(0, batch_size=batch_size) buff.add_episode_batch(paths) - timesteps = buff.sample_timesteps(algo._buffer_batch_size) + timesteps, _, _ = buff.sample_timesteps(algo._buffer_batch_size) timesteps_copy = copy.deepcopy(timesteps) observations = np_to_torch(timesteps.observations)
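
Usage note (illustrative, not part of the patch): the sketch below shows the intended call pattern for the new buffer end to end, using only the API added above (add_path, sample_transitions, update_priorities). The dummy spec mirrors the test fixture in test_per_replay_buffer.py; the capacity, timestep budget, and stand-in TD errors are made-up values, and a real algorithm would supply its per-sample Q loss as priorities, as torch/algos/dqn.py does.

import akro
import numpy as np

from garage import EnvSpec
from garage.replay_buffer import PERReplayBuffer

# Dummy 1-D observation / single-action spec, as in the tests.
obs_space = akro.Box(low=1, high=np.inf, shape=(1, ), dtype=np.float32)
act_space = akro.Discrete(1)
env_spec = EnvSpec(obs_space, act_space)

# 100-transition buffer for an experiment that will run 1000 timesteps;
# beta is annealed linearly from beta_init to 1 over those timesteps.
buffer = PERReplayBuffer(100, 1000, env_spec, alpha=0.6, beta_init=0.4)

# New transitions enter with the current maximum priority (1 if empty).
buffer.add_path({
    'observations': np.arange(1, 51, dtype=np.float32).reshape(50, 1),
    'next_observations': np.arange(2, 52, dtype=np.float32).reshape(50, 1),
    'actions': np.zeros((50, 1), dtype=np.float32),
    'rewards': np.ones((50, 1), dtype=np.float32),
    'terminals': np.zeros((50, 1), dtype=bool),
})

# Sampling returns the batch, the importance-sampling weights, and the
# buffer indices of the sampled transitions.
samples, weights, indices = buffer.sample_transitions(32)

# After computing the loss, feed updated priorities back into the buffer.
# Here |TD error| is faked with random numbers plus a small offset.
td_errors = np.abs(np.random.randn(32)) + 1e-5
buffer.update_priorities(indices, td_errors)

Note that PathBuffer.sample_transitions and sample_timesteps now return the same 3-tuple (with uniform weights of 1 for the non-prioritized buffer), which is why the algorithm call sites in this patch unpack samples, _, _ even when PER is not in use.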