Hindsight Experience Replay Buffer #84

Status: Open. Wants to merge 48 commits into base: master. Changes shown are from 21 commits.

Commits (48)
52cec56
Adds basic code for bitflip DQN and adds basic code for HER
prabhatnagarajan Jul 10, 2020
6859259
Adds hindsight to bit flip
prabhatnagarajan Jul 11, 2020
120cfa5
removes null_goals
prabhatnagarajan Jul 11, 2020
3360faf
modifies total steps
prabhatnagarajan Jul 11, 2020
1e57205
Updates space sampling
prabhatnagarajan Jul 11, 2020
cc06b25
Merge branch 'master' into her
prabhatnagarajan Jul 17, 2020
a60e1f5
Cleans hindsight buffer code
prabhatnagarajan Jul 17, 2020
eece248
Modifies experiment params
prabhatnagarajan Jul 19, 2020
2111563
Merge branch 'master' into her
prabhatnagarajan Jul 31, 2020
d9cd9eb
Merge branch 'master' into her
prabhatnagarajan Oct 28, 2020
5a2dfd8
Merges with eval_stats_collector
prabhatnagarajan Oct 28, 2020
e38e0d0
Applies black to pfrl
prabhatnagarajan Oct 28, 2020
d89c788
Updates docstring
prabhatnagarajan Oct 28, 2020
080916f
Implements step function and success rate calculation
prabhatnagarajan Oct 28, 2020
c363dc7
Updates agent, explorer, replay start size, and phi
prabhatnagarajan Oct 28, 2020
4e15a76
Applies black
prabhatnagarajan Oct 28, 2020
71809a5
Updates optimizer, and target update interval
prabhatnagarajan Oct 28, 2020
8c616e5
Fixes minor errors
prabhatnagarajan Oct 28, 2020
9721d1a
Applies black
prabhatnagarajan Oct 29, 2020
8643e0d
Addresses flakes
prabhatnagarajan Oct 29, 2020
4d34f1e
Cleans up code
prabhatnagarajan Oct 29, 2020
f5a1bfa
Update examples/her/train_dqn_bit_flip.py
prabhatnagarajan Oct 29, 2020
3812384
experiment and hyperparameter update
prabhatnagarajan Oct 30, 2020
5cd21e0
Switches parse args
prabhatnagarajan Oct 30, 2020
035ad63
Applies black
prabhatnagarajan Oct 30, 2020
e481b85
Adds HER to the Repo readme
prabhatnagarajan Nov 5, 2020
573c7a2
Merge branch 'master' into her
prabhatnagarajan Nov 5, 2020
ed4ae2e
Applies isort
prabhatnagarajan Nov 5, 2020
9841438
Make DDPG HER work for FetchReach-v1
muupan Nov 6, 2020
d61d1dc
Start updates earlier to match performance of baselines
muupan Nov 9, 2020
1c9d308
Merge pull request #2 from muupan/her-fetch
prabhatnagarajan Nov 10, 2020
18177a4
Adds Fetch DDPG to readme
prabhatnagarajan Nov 10, 2020
383585f
Updates descriptions for args in bit flip
prabhatnagarajan Nov 10, 2020
88380f0
Updates docs in DDPG Fetch example
prabhatnagarajan Nov 10, 2020
453b04b
Minor cleanup of hindsight replay strategies
prabhatnagarajan Nov 10, 2020
3c41b85
Merge branch 'her' into her_fetch_updates
prabhatnagarajan Nov 10, 2020
33c2d09
Merge branch 'master' into her
prabhatnagarajan Nov 12, 2020
35b07ce
Merge branch 'her' into her_fetch_updates
prabhatnagarajan Nov 12, 2020
0a2efc6
Adds bit flip to examples tests
prabhatnagarajan Nov 12, 2020
eaa01e4
Applies black
prabhatnagarajan Nov 12, 2020
977993e
Fixes merge conflicts
prabhatnagarajan Nov 12, 2020
5e67f4a
Merge branch 'master' into her
prabhatnagarajan Jan 7, 2021
e3e99f6
Merge branch 'master' into her
prabhatnagarajan Jan 29, 2021
4150555
Merge branch 'master' into her
prabhatnagarajan Mar 12, 2021
ef230a4
Adds HER readme and tests
prabhatnagarajan Mar 22, 2021
ac952f3
Applies black
prabhatnagarajan Mar 30, 2021
35f085c
Adds HER to replay buffer tests
prabhatnagarajan Mar 30, 2021
15d8a21
Merge branch 'master' into her
prabhatnagarajan Jul 17, 2021
examples/her/train_dqn_bit_flip.py (248 additions, 0 deletions)
@@ -0,0 +1,248 @@
import argparse

import gym
import gym.spaces as spaces
import numpy as np
import torch
import torch.nn as nn

from pfrl.q_functions import DiscreteActionValueHead
from pfrl import agents
from pfrl import experiments
from pfrl import explorers
from pfrl import utils
from pfrl import replay_buffers

from pfrl.initializers import init_chainer_default


def reward_fn(dg, ag):
    return -1.0 if (ag != dg).any() else 0.0


class BitFlip(gym.GoalEnv):
    """BitFlip environment from https://arxiv.org/pdf/1707.01495.pdf

    Args:
        n: State space is {0,1}^n
    """

    def __init__(self, n):
        self.n = n
        self.action_space = spaces.Discrete(n)
        self.observation_space = spaces.Dict(
            dict(
                desired_goal=spaces.MultiBinary(n),
                achieved_goal=spaces.MultiBinary(n),
                observation=spaces.MultiBinary(n),
            )
        )
        self.clear_statistics()

    def compute_reward(self, achieved_goal, desired_goal, info):
        return reward_fn(desired_goal, achieved_goal)

    def _check_done(self):
        success = (
            self.observation["desired_goal"] == self.observation["achieved_goal"]
        ).all()
        return (self.steps >= self.n) or success, success
Review comment from the PR author: I think I should use the wrapper ContinuingTimeLimit actually. Exceeding the number of timesteps should not be a terminal signal.
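
A minimal sketch of that suggestion, for reference only (it is not part of this diff): pfrl.wrappers.ContinuingTimeLimit reports the time limit through info["needs_reset"] instead of done=True, so truncation is not bootstrapped as a terminal state. The helper name below, and the assumption that BitFlip would then terminate only on success, are hypothetical.

import pfrl


def make_bit_flip_env(num_bits):
    # Hypothetical variant: BitFlip itself would end episodes only on success,
    # leaving the step cap to the wrapper.
    env = BitFlip(num_bits)
    # After num_bits steps, the wrapper sets info["needs_reset"] = True
    # instead of returning done=True, and the training loop resets the env
    # without treating the last state as terminal.
    return pfrl.wrappers.ContinuingTimeLimit(env, max_episode_steps=num_bits)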


    def step(self, action):
        # Compute action outcome
        bit_new = int(not self.observation["observation"][action])
        new_obs = self.observation["observation"].copy()
        new_obs[action] = bit_new
        # Set new observation
        dg = self.observation["desired_goal"]
        self.observation["desired_goal"] = dg.copy()
        self.observation["achieved_goal"] = new_obs
        self.observation["observation"] = new_obs

        reward = self.compute_reward(
            self.observation["achieved_goal"], self.observation["desired_goal"], {}
        )
        self.steps += 1
        done, success = self._check_done()
        assert success == (reward == 0)
        if done:
            result = 1 if success else 0
            self.results.append(result)
        return self.observation, reward, done, {}

    def reset(self):
        sample_obs = self.observation_space.sample()
        state, goal = sample_obs["observation"], sample_obs["desired_goal"]
        while (state == goal).all():
            sample_obs = self.observation_space.sample()
            state, goal = sample_obs["observation"], sample_obs["desired_goal"]
        self.observation = dict()
        self.observation["desired_goal"] = goal
        self.observation["achieved_goal"] = state
        self.observation["observation"] = state
        self.steps = 0
        return self.observation

    def get_statistics(self):
        failures = self.results.count(0)
        successes = self.results.count(1)
        assert len(self.results) == failures + successes
        if not self.results:
            return [("success_rate", None)]
        success_rate = successes / float(len(self.results))
        return [("success_rate", success_rate)]

    def clear_statistics(self):
        self.results = []

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--outdir",
        type=str,
        default="results",
        help=(
            "Directory path to save output files."
            " If it does not exist, it will be created."
        ),
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
    )
    parser.add_argument("--demo", action="store_true", default=False)
    parser.add_argument("--load", type=str, default=None)
    parser.add_argument(
        "--log-level",
        type=int,
        default=20,
        help="Logging level. 10:DEBUG, 20:INFO etc.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=10 ** 7,
        help="Total number of timesteps to train the agent.",
    )
    parser.add_argument(
        "--replay-start-size",
        type=int,
        default=5 * 10 ** 2,
        help="Minimum replay buffer size before " + "performing gradient updates.",
    )
    parser.add_argument(
        "--num-bits",
        type=int,
        default=10,
        help="Number of bits for BitFlipping environment",
    )
    parser.add_argument("--use-hindsight", type=bool, default=True)
    parser.add_argument("--eval-n-episodes", type=int, default=100)
    parser.add_argument("--eval-interval", type=int, default=250000)
    parser.add_argument("--n-best-episodes", type=int, default=100)
    args = parser.parse_args()

    import logging

    logging.basicConfig(level=args.log_level)

    # Set a random seed used in PFRL.
    utils.set_random_seed(args.seed)

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2 ** 31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print("Output files are saved in {}".format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = test_seed if test else train_seed
        env = BitFlip(args.num_bits)
        env.seed(int(env_seed))
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)

    n_actions = env.action_space.n
    q_func = nn.Sequential(
        init_chainer_default(nn.Linear(args.num_bits * 2, 256)),
        nn.ReLU(),
        init_chainer_default(nn.Linear(256, n_actions)),
        DiscreteActionValueHead(),
    )

    opt = torch.optim.Adam(q_func.parameters(), eps=1e-3)

    if args.use_hindsight:
        rbuf = replay_buffers.hindsight.HindsightReplayBuffer(
            reward_fn=reward_fn,
            replay_strategy=replay_buffers.hindsight.ReplayFutureGoal(),
            capacity=10 ** 6,
        )
    else:
        rbuf = replay_buffers.ReplayBuffer(10 ** 6)

    explorer = explorers.LinearDecayEpsilonGreedy(
        start_epsilon=0.3,
        end_epsilon=0.0,
        decay_steps=5 * 10 ** 3,
        random_action_func=lambda: np.random.randint(n_actions),
    )

    def phi(observation):
        # Feature extractor
        obs = np.asarray(observation["observation"], dtype=np.float32)
        dg = np.asarray(observation["desired_goal"], dtype=np.float32)
        return np.concatenate((obs, dg))

    Agent = agents.DoubleDQN
    agent = Agent(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        gamma=0.99,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        target_update_interval=10 ** 3,
        clip_delta=True,
        update_interval=4,
        batch_accumulator="sum",
        phi=phi,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_episodes
        )
        print(
            "n_episodes: {} mean: {} median: {} stdev {}".format(
                eval_stats["episodes"],
                eval_stats["mean"],
                eval_stats["median"],
                eval_stats["stdev"],
            )
        )
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_episodes,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            save_best_so_far_agent=True,
            eval_env=eval_env,
        )


if __name__ == "__main__":
    main()
pfrl/replay_buffers/__init__.py (4 additions, 0 deletions)
@@ -1,4 +1,6 @@
from pfrl.replay_buffers.episodic import EpisodicReplayBuffer # NOQA
from pfrl.replay_buffers.hindsight import HindsightReplayStrategy # NOQA
from pfrl.replay_buffers.hindsight import HindsightReplayBuffer # NOQA
from pfrl.replay_buffers.persistent import PersistentEpisodicReplayBuffer # NOQA
from pfrl.replay_buffers.persistent import PersistentReplayBuffer # NOQA
from pfrl.replay_buffers.prioritized import PrioritizedReplayBuffer # NOQA
@@ -7,3 +9,5 @@
    PrioritizedEpisodicReplayBuffer,
)
from pfrl.replay_buffers.replay_buffer import ReplayBuffer # NOQA
from pfrl.replay_buffers.hindsight import ReplayFinalGoal # NOQA
from pfrl.replay_buffers.hindsight import ReplayFutureGoal # NOQA
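
For reference, the hunks above export the new hindsight classes at the package level. A rough usage sketch follows, relying only on the constructor keywords already shown in the bit-flip example (reward_fn, replay_strategy, capacity); the class names presumably correspond to the "final" and "future" relabeling strategies from the HER paper.

from pfrl import replay_buffers


def reward_fn(dg, ag):
    # Same sparse goal-conditioned reward as in the bit-flip example above.
    return -1.0 if (ag != dg).any() else 0.0


# Relabel stored transitions with the episode's final achieved goal.
final_goal_buffer = replay_buffers.HindsightReplayBuffer(
    reward_fn=reward_fn,
    replay_strategy=replay_buffers.ReplayFinalGoal(),
    capacity=10 ** 6,
)

# Relabel with achieved goals sampled from later steps of the same episode,
# as done in examples/her/train_dqn_bit_flip.py.
future_goal_buffer = replay_buffers.HindsightReplayBuffer(
    reward_fn=reward_fn,
    replay_strategy=replay_buffers.ReplayFutureGoal(),
    capacity=10 ** 6,
)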