Hindsight Experience Replay Buffer #84
Open
prabhatnagarajan wants to merge 48 commits into pfnet:master from prabhatnagarajan:her
Commits (48 total; changes shown from 21 commits)
52cec56 Adds basic code for bitflip DQN and adds basic code for HER (prabhatnagarajan)
6859259 Adds hindsight to bit flip (prabhatnagarajan)
120cfa5 removes null_goals (prabhatnagarajan)
3360faf modifies total steps (prabhatnagarajan)
1e57205 Updates space sampling (prabhatnagarajan)
cc06b25 Merge branch 'master' into her (prabhatnagarajan)
a60e1f5 Cleans hindsight buffer code (prabhatnagarajan)
eece248 Modifies experiment params (prabhatnagarajan)
2111563 Merge branch 'master' into her (prabhatnagarajan)
d9cd9eb Merge branch 'master' into her (prabhatnagarajan)
5a2dfd8 Merges with eval_stats_collector (prabhatnagarajan)
e38e0d0 Applies black to pfrl (prabhatnagarajan)
d89c788 Updates docstring (prabhatnagarajan)
080916f Implements step function and success rate calculation (prabhatnagarajan)
c363dc7 Updates agent, explorer, replay start size, and phi (prabhatnagarajan)
4e15a76 Applies black (prabhatnagarajan)
71809a5 Updates optimizer, and target update interval (prabhatnagarajan)
8c616e5 Fixes minor errors (prabhatnagarajan)
9721d1a Applies black (prabhatnagarajan)
8643e0d Addresses flakes (prabhatnagarajan)
4d34f1e Cleans up code (prabhatnagarajan)
f5a1bfa Update examples/her/train_dqn_bit_flip.py (prabhatnagarajan)
3812384 experiment and hyperparameter update (prabhatnagarajan)
5cd21e0 Switches parse args (prabhatnagarajan)
035ad63 Applies black (prabhatnagarajan)
e481b85 Adds HER to the Repo readme (prabhatnagarajan)
573c7a2 Merge branch 'master' into her (prabhatnagarajan)
ed4ae2e Applies isort (prabhatnagarajan)
9841438 Make DDPG HER work for FetchReach-v1 (muupan)
d61d1dc Start updates earlier to match performance of baselines (muupan)
1c9d308 Merge pull request #2 from muupan/her-fetch (prabhatnagarajan)
18177a4 Adds Fetch DDPG to readme (prabhatnagarajan)
383585f Updates descriptions for args in bit flip (prabhatnagarajan)
88380f0 Updates docs in DDPG Fetch example (prabhatnagarajan)
453b04b Minor cleanup of hindsight replay strategies (prabhatnagarajan)
3c41b85 Merge branch 'her' into her_fetch_updates (prabhatnagarajan)
33c2d09 Merge branch 'master' into her (prabhatnagarajan)
35b07ce Merge branch 'her' into her_fetch_updates (prabhatnagarajan)
0a2efc6 Adds bit flip to examples tests (prabhatnagarajan)
eaa01e4 Applies black (prabhatnagarajan)
977993e Fixes merge conflicts (prabhatnagarajan)
5e67f4a Merge branch 'master' into her (prabhatnagarajan)
e3e99f6 Merge branch 'master' into her (prabhatnagarajan)
4150555 Merge branch 'master' into her (prabhatnagarajan)
ef230a4 Adds HER readme and tests (prabhatnagarajan)
ac952f3 Applies black (prabhatnagarajan)
35f085c Adds HER to replay buffer tests (prabhatnagarajan)
15d8a21 Merge branch 'master' into her (prabhatnagarajan)
examples/her/train_dqn_bit_flip.py (new file)
@@ -0,0 +1,248 @@
import argparse

import gym
import gym.spaces as spaces
import numpy as np
import torch
import torch.nn as nn

from pfrl.q_functions import DiscreteActionValueHead
from pfrl import agents
from pfrl import experiments
from pfrl import explorers
from pfrl import utils
from pfrl import replay_buffers

from pfrl.initializers import init_chainer_default


def reward_fn(dg, ag):
    return -1.0 if (ag != dg).any() else 0.0


class BitFlip(gym.GoalEnv):
    """BitFlip environment from https://arxiv.org/pdf/1707.01495.pdf

    Args:
        n: State space is {0,1}^n
    """

    def __init__(self, n):
        self.n = n
        self.action_space = spaces.Discrete(n)
        self.observation_space = spaces.Dict(
            dict(
                desired_goal=spaces.MultiBinary(n),
                achieved_goal=spaces.MultiBinary(n),
                observation=spaces.MultiBinary(n),
            )
        )
        self.clear_statistics()

    def compute_reward(self, achieved_goal, desired_goal, info):
        return reward_fn(desired_goal, achieved_goal)

    def _check_done(self):
        success = (
            self.observation["desired_goal"] == self.observation["achieved_goal"]
        ).all()
        return (self.steps >= self.n) or success, success

    def step(self, action):
        # Compute action outcome
        bit_new = int(not self.observation["observation"][action])
        new_obs = self.observation["observation"].copy()
        new_obs[action] = bit_new
        # Set new observation
        dg = self.observation["desired_goal"]
        self.observation["desired_goal"] = dg.copy()
        self.observation["achieved_goal"] = new_obs
        self.observation["observation"] = new_obs

        reward = self.compute_reward(
            self.observation["achieved_goal"], self.observation["desired_goal"], {}
        )
        self.steps += 1
        done, success = self._check_done()
        assert success == (reward == 0)
        if done:
            result = 1 if success else 0
            self.results.append(result)
        return self.observation, reward, done, {}

    def reset(self):
        sample_obs = self.observation_space.sample()
        state, goal = sample_obs["observation"], sample_obs["desired_goal"]
        while (state == goal).all():
            sample_obs = self.observation_space.sample()
            state, goal = sample_obs["observation"], sample_obs["desired_goal"]
        self.observation = dict()
        self.observation["desired_goal"] = goal
        self.observation["achieved_goal"] = state
        self.observation["observation"] = state
        self.steps = 0
        return self.observation

    def get_statistics(self):
        failures = self.results.count(0)
        successes = self.results.count(1)
        assert len(self.results) == failures + successes
        if not self.results:
            return [("success_rate", None)]
        success_rate = successes / float(len(self.results))
        return [("success_rate", success_rate)]

    def clear_statistics(self):
        self.results = []


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--outdir",
        type=str,
        default="results",
        help=(
            "Directory path to save output files."
            " If it does not exist, it will be created."
        ),
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
    parser.add_argument(
        "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
    )
    parser.add_argument("--demo", action="store_true", default=False)
    parser.add_argument("--load", type=str, default=None)
    parser.add_argument(
        "--log-level",
        type=int,
        default=20,
        help="Logging level. 10:DEBUG, 20:INFO etc.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=10 ** 7,
        help="Total number of timesteps to train the agent.",
    )
    parser.add_argument(
        "--replay-start-size",
        type=int,
        default=5 * 10 ** 2,
        help="Minimum replay buffer size before " + "performing gradient updates.",
    )
    parser.add_argument(
        "--num-bits",
        type=int,
        default=10,
        help="Number of bits for BitFlipping environment",
    )
    parser.add_argument("--use-hindsight", type=bool, default=True)
    parser.add_argument("--eval-n-episodes", type=int, default=100)
    parser.add_argument("--eval-interval", type=int, default=250000)
    parser.add_argument("--n-best-episodes", type=int, default=100)
    args = parser.parse_args()

    import logging

    logging.basicConfig(level=args.log_level)

    # Set a random seed used in PFRL.
    utils.set_random_seed(args.seed)

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2 ** 31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print("Output files are saved in {}".format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = test_seed if test else train_seed
        env = BitFlip(args.num_bits)
        env.seed(int(env_seed))
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)

    n_actions = env.action_space.n
    q_func = nn.Sequential(
        init_chainer_default(nn.Linear(args.num_bits * 2, 256)),
        nn.ReLU(),
        init_chainer_default(nn.Linear(256, n_actions)),
        DiscreteActionValueHead(),
    )

    opt = torch.optim.Adam(q_func.parameters(), eps=1e-3)

    if args.use_hindsight:
        # Relabel stored transitions with goals achieved later in the same
        # episode (the "future" strategy from the HER paper).
        rbuf = replay_buffers.hindsight.HindsightReplayBuffer(
            reward_fn=reward_fn,
            replay_strategy=replay_buffers.hindsight.ReplayFutureGoal(),
            capacity=10 ** 6,
        )
    else:
        rbuf = replay_buffers.ReplayBuffer(10 ** 6)

    explorer = explorers.LinearDecayEpsilonGreedy(
        start_epsilon=0.3,
        end_epsilon=0.0,
        decay_steps=5 * 10 ** 3,
        random_action_func=lambda: np.random.randint(n_actions),
    )

    def phi(observation):
        # Feature extractor: concatenate the observation and the desired goal.
        obs = np.asarray(observation["observation"], dtype=np.float32)
        dg = np.asarray(observation["desired_goal"], dtype=np.float32)
        return np.concatenate((obs, dg))

    Agent = agents.DoubleDQN
    agent = Agent(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        gamma=0.99,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        target_update_interval=10 ** 3,
        clip_delta=True,
        update_interval=4,
        batch_accumulator="sum",
        phi=phi,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        # Evaluate over a fixed number of episodes; the script defines
        # --eval-n-episodes (there is no --eval-n-steps argument).
        eval_stats = experiments.eval_performance(
            env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_episodes
        )
        print(
            "n_episodes: {} mean: {} median: {} stdev {}".format(
                eval_stats["episodes"],
                eval_stats["mean"],
                eval_stats["median"],
                eval_stats["stdev"],
            )
        )
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_episodes,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            save_best_so_far_agent=True,
            eval_env=eval_env,
        )


if __name__ == "__main__":
    main()
Review comment: I think I should actually use the ContinuingTimeLimit wrapper here. Exceeding the number of timesteps should not be a terminal signal.
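A minimal sketch of what that could look like, assuming pfrl.wrappers.ContinuingTimeLimit keeps the (env, max_episode_steps) signature and flags truncation through info["needs_reset"] rather than through done; BitFlip._check_done would then drop its step-limit check and terminate only on success:

from pfrl.wrappers import ContinuingTimeLimit

def make_env(test):
    env_seed = test_seed if test else train_seed
    env = BitFlip(args.num_bits)  # _check_done would report success only
    env.seed(int(env_seed))
    # Cut episodes off after num_bits steps without emitting a terminal signal:
    # the wrapper sets info["needs_reset"] = True instead of done=True, so the
    # time limit no longer leaks into the bootstrapped Q-targets.
    return ContinuingTimeLimit(env, max_episode_steps=args.num_bits)

The train_agent_with_evaluation loop should already reset the environment when info["needs_reset"] is set; the main remaining change would be the success bookkeeping in BitFlip.step, which currently records an episode result only when done is True.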