Snapshot for preemption #155

Open · wants to merge 16 commits into master
39 changes: 38 additions & 1 deletion examples/atari/reproduction/dqn/train_dqn.py
@@ -31,6 +31,12 @@ def main():
" If it does not exist, it will be created."
),
)
parser.add_argument(
"--exp-id",
type=str,
default=None,
help="Experiment ID. If None, commit hash or timestamp is used.",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
parser.add_argument(
"--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
@@ -73,9 +79,22 @@ def main():
default=5 * 10 ** 4,
help="Minimum replay buffer size before " + "performing gradient updates.",
)
parser.add_argument(
"--save-snapshot",
action="store_true",
default=False,
help="Take resumable snapshot at every checkpoint",
)
parser.add_argument(
"--load-snapshot",
action="store_true",
default=False,
help="Load snapshot if exists",
)
parser.add_argument("--eval-n-steps", type=int, default=125000)
parser.add_argument("--eval-interval", type=int, default=250000)
parser.add_argument("--n-best-episodes", type=int, default=30)
parser.add_argument("--checkpoint-freq", type=int, default=2000000)
args = parser.parse_args()

import logging
@@ -89,7 +108,9 @@ def main():
train_seed = args.seed
test_seed = 2 ** 31 - 1 - args.seed

args.outdir = experiments.prepare_output_dir(args, args.outdir)
args.outdir = experiments.prepare_output_dir(
args, args.outdir, exp_id=args.exp_id, make_backup=args.exp_id is None
)
print("Output files are saved in {}".format(args.outdir))

def make_env(test):
@@ -162,6 +183,17 @@ def phi(x):
phi=phi,
)

# load snapshot
step_offset, episode_offset = 0, 0
max_score = None
if args.load_snapshot:
snapshot_dirname = experiments.latest_snapshot_dir(args.outdir)
if snapshot_dirname:
print(f"load snapshot from {snapshot_dirname}")
step_offset, episode_offset, max_score = experiments.load_snapshot(
agent, snapshot_dirname
)

if args.load or args.load_pretrained:
# either load or load_pretrained must be false
assert not args.load or not args.load_pretrained
@@ -194,6 +226,11 @@ def phi(x):
eval_n_steps=args.eval_n_steps,
eval_n_episodes=None,
eval_interval=args.eval_interval,
step_offset=step_offset,
episode_offset=episode_offset,
max_score=max_score,
checkpoint_freq=args.checkpoint_freq,
take_resumable_snapshot=args.save_snapshot,
outdir=args.outdir,
save_best_so_far_agent=True,
eval_env=eval_env,
8 changes: 8 additions & 0 deletions examples_tests/atari/reproduction/test_dqn.sh
@@ -10,3 +10,11 @@ gpu="$1"
python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu
model=$(find $outdir/atari/reproduction/dqn -name "*_finish")
python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu

# snapshot taken before any evaluation result exists
python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 0 --save-snapshot --checkpoint-freq 45
python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 0 --load-snapshot

# snapshot taken after evaluation results exist
python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 4600 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 1 --save-snapshot --checkpoint-freq 4000
python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 4700 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 1 --load-snapshot
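In a real preemptible job, the same command would typically pass both flags, so every (re)start first resumes from the latest snapshot if one exists and then keeps taking new ones; a fixed --exp-id keeps the output directory stable across restarts. A hypothetical invocation (the output directory and experiment ID below are placeholders):

python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --outdir results --exp-id my_run --save-snapshot --load-snapshot --checkpoint-freq 1000000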
6 changes: 5 additions & 1 deletion pfrl/experiments/__init__.py
@@ -7,7 +7,11 @@
from pfrl.experiments.prepare_output_dir import is_under_git_control # NOQA
from pfrl.experiments.prepare_output_dir import prepare_output_dir # NOQA
from pfrl.experiments.train_agent import train_agent # NOQA
from pfrl.experiments.train_agent import train_agent_with_evaluation # NOQA
from pfrl.experiments.train_agent import ( # NOQA
latest_snapshot_dir,
load_snapshot,
train_agent_with_evaluation,
)
from pfrl.experiments.train_agent_async import train_agent_async # NOQA
from pfrl.experiments.train_agent_batch import train_agent_batch # NOQA
from pfrl.experiments.train_agent_batch import train_agent_batch_with_evaluation # NOQA
5 changes: 4 additions & 1 deletion pfrl/experiments/evaluator.py
@@ -384,7 +384,10 @@ def write_header(outdir, agent, env):
"max", # maximum value of returns of evaluation runs
"min", # minimum value of returns of evaluation runs
)
with open(os.path.join(outdir, "scores.txt"), "w") as f:
fp = os.path.join(outdir, "scores.txt")
if os.path.exists(fp) and os.stat(fp).st_size > 0:
return
with open(fp, "w") as f:
custom_columns = tuple(t[0] for t in agent.get_statistics())
env_get_stats = getattr(env, "get_statistics", lambda: [])
assert callable(env_get_stats)
111 changes: 108 additions & 3 deletions pfrl/experiments/train_agent.py
@@ -1,5 +1,8 @@
import csv
import logging
import os
import shutil
import time

from pfrl.experiments.evaluator import Evaluator, save_agent
from pfrl.utils.ask_yes_no import ask_yes_no
@@ -21,14 +24,96 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""):
save_agent_replay_buffer(agent, t, outdir, suffix=suffix)


def snapshot(
agent,
t,
episode_idx,
outdir,
suffix="_snapshot",
logger=None,
delete_old=True,
):
"""Save a resumable snapshot of the agent, its replay buffer, and scores to outdir."""
start_time = time.time()
tmp_suffix = f"{suffix}_"
tmp_dirname = os.path.join(outdir, f"{t}{tmp_suffix}") # use until files are saved
agent.save(tmp_dirname)
if hasattr(agent, "replay_buffer"):
agent.replay_buffer.save(os.path.join(tmp_dirname, "replay.pkl"))
if os.path.exists(os.path.join(outdir, "scores.txt")):
shutil.copyfile(
os.path.join(outdir, "scores.txt"), os.path.join(tmp_dirname, "scores.txt")
)

history_path = os.path.join(outdir, "snapshot_history.txt")
if not os.path.exists(history_path): # write header
with open(history_path, "a") as f:
csv.writer(f, delimiter="\t").writerow(["step", "episode", "snapshot_time"])
with open(history_path, "a") as f:
csv.writer(f, delimiter="\t").writerow(
[t, episode_idx, time.time() - start_time]
)
shutil.copyfile(history_path, os.path.join(tmp_dirname, "snapshot_history.txt"))

real_dirname = os.path.join(outdir, f"{t}{suffix}")
os.rename(tmp_dirname, real_dirname)
if logger:
logger.info(f"Saved the snapshot to {real_dirname}")
if delete_old:
for old_dir in filter(
lambda s: s.endswith(suffix) or s.endswith(tmp_suffix), os.listdir(outdir)
):
if old_dir != f"{t}{suffix}":
shutil.rmtree(os.path.join(outdir, old_dir))


def load_snapshot(agent, dirname, logger=None):
"""Restore agent state from a snapshot directory and return (step, episode, max_score)."""
agent.load(dirname)
if hasattr(agent, "replay_buffer"):
agent.replay_buffer.load(os.path.join(dirname, "replay.pkl"))
if logger:
logger.info(f"Loaded the snapshot from {dirname}")
with open(os.path.join(dirname, "snapshot_history.txt")) as f:
step, episode = map(int, f.readlines()[-1].split()[:2])
max_score = None
if os.path.exists(os.path.join(dirname, "scores.txt")):
with open(os.path.join(dirname, "scores.txt")) as f:
lines = f.readlines()
if len(lines) > 1:
max_score = float(lines[-1].split()[3]) # mean
shutil.copyfile(
os.path.join(dirname, "snapshot_history.txt"),
os.path.join(dirname, "..", "snapshot_history.txt"),
)
# scores.txt may be absent if the snapshot was taken without an evaluator
if os.path.exists(os.path.join(dirname, "scores.txt")):
shutil.copyfile(
os.path.join(dirname, "scores.txt"),
os.path.join(dirname, "..", "scores.txt"),
)
return step, episode, max_score


def latest_snapshot_dir(search_dir, suffix="_snapshot"):
"""
return None if no snapshot exists
"""
candidates = list(filter(lambda s: s.endswith(suffix), os.listdir(search_dir)))
if len(candidates) == 0:
return None
return os.path.join(
search_dir, max(candidates, key=lambda name: int(name.split("_")[0]))
)


def train_agent(
agent,
env,
steps,
outdir,
checkpoint_freq=None,
take_resumable_snapshot=False,
max_episode_len=None,
step_offset=0,
episode_offset=0,
max_score=None,
evaluator=None,
successful_score=None,
step_hooks=(),
@@ -38,8 +123,12 @@ def train_agent(

logger = logger or logging.getLogger(__name__)

# Restore the best evaluation score recorded before the snapshot
if evaluator and max_score is not None:
evaluator.max_score = max_score

episode_r = 0
episode_idx = 0
episode_idx = episode_offset

# o_0, r_0
obs = env.reset()
@@ -100,7 +189,10 @@ def train_agent(
episode_len = 0
obs = env.reset()
if checkpoint_freq and t % checkpoint_freq == 0:
save_agent(agent, t, outdir, logger, suffix="_checkpoint")
if take_resumable_snapshot:
snapshot(agent, t, episode_idx, outdir, logger=logger)
else:
save_agent(agent, t, outdir, logger, suffix="_checkpoint")

except (Exception, KeyboardInterrupt):
# Save the current model before being killed
@@ -122,9 +214,12 @@ def train_agent_with_evaluation(
eval_interval,
outdir,
checkpoint_freq=None,
take_resumable_snapshot=False,
train_max_episode_len=None,
step_offset=0,
episode_offset=0,
eval_max_episode_len=None,
max_score=None,
eval_env=None,
successful_score=None,
step_hooks=(),
@@ -144,11 +239,18 @@
eval_n_episodes (int): Number of episodes at each evaluation phase.
eval_interval (int): Interval of evaluation.
outdir (str): Path to the directory to output data.
checkpoint_freq (int): frequency at which agents are stored.
checkpoint_freq (int): Frequency in steps at which the agent is stored.
take_resumable_snapshot (bool): If True, a resumable snapshot is saved at
each checkpoint instead of a plain agent checkpoint. Note that snapshots
currently do not cover agent statistics (e.g., for DQN: average_q,
average_loss, cumulative_steps, and n_updates), so those values in
"scores.txt" may be incorrect after resuming from a snapshot.
train_max_episode_len (int): Maximum episode length during training.
step_offset (int): Time step from which training starts.
episode_offset (int): Episode index from which training starts.
eval_max_episode_len (int or None): Maximum episode length of
evaluation runs. If None, train_max_episode_len is used instead.
max_score (float): Current maximum evaluation score, restored into the evaluator.
eval_env: Environment used for evaluation.
successful_score (float): Finish training if the mean score is greater
than or equal to this value if not None
@@ -211,8 +313,11 @@ def train_agent_with_evaluation(
steps,
outdir,
checkpoint_freq=checkpoint_freq,
take_resumable_snapshot=take_resumable_snapshot,
max_episode_len=train_max_episode_len,
step_offset=step_offset,
episode_offset=episode_offset,
max_score=max_score,
evaluator=evaluator,
successful_score=successful_score,
step_hooks=step_hooks,
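For reference, the flow wired up in train_dqn.py above can also be written directly against the experiments API. A minimal sketch, assuming agent, env, eval_env, and an argparse args namespace are constructed as in the existing DQN example and that the default "_snapshot" suffix is used:

from pfrl import experiments

# Fix exp_id so repeated runs share one output directory.
outdir = experiments.prepare_output_dir(
    args, args.outdir, exp_id=args.exp_id, make_backup=args.exp_id is None
)

# Resume from the latest snapshot, if any.
step_offset, episode_offset, max_score = 0, 0, None
snapshot_dir = experiments.latest_snapshot_dir(outdir)
if snapshot_dir is not None:
    step_offset, episode_offset, max_score = experiments.load_snapshot(agent, snapshot_dir)

experiments.train_agent_with_evaluation(
    agent=agent,
    env=env,
    steps=args.steps,
    eval_n_steps=args.eval_n_steps,
    eval_n_episodes=None,
    eval_interval=args.eval_interval,
    outdir=outdir,
    checkpoint_freq=args.checkpoint_freq,
    take_resumable_snapshot=args.save_snapshot,  # snapshots replace plain checkpoints
    step_offset=step_offset,
    episode_offset=episode_offset,
    max_score=max_score,
    eval_env=eval_env,
)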