From 78f5ef76fb117e2ba24af1056de95f5f37a5bd70 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Mon, 30 Aug 2021 10:14:18 +0000
Subject: [PATCH 01/16] Save agent and replay_buffer as snapshot

---
 pfrl/experiments/__init__.py    |  2 +-
 pfrl/experiments/train_agent.py | 52 +++++++++++++++++++++++++++++++--
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/pfrl/experiments/__init__.py b/pfrl/experiments/__init__.py
index e4e79e49d..a14a6f9b5 100644
--- a/pfrl/experiments/__init__.py
+++ b/pfrl/experiments/__init__.py
@@ -7,7 +7,7 @@ from pfrl.experiments.prepare_output_dir import is_under_git_control  # NOQA
 from pfrl.experiments.prepare_output_dir import prepare_output_dir  # NOQA
 from pfrl.experiments.train_agent import train_agent  # NOQA
-from pfrl.experiments.train_agent import train_agent_with_evaluation  # NOQA
+from pfrl.experiments.train_agent import train_agent_with_evaluation, load_snapshot, latest_snapshot_dir  # NOQA
 from pfrl.experiments.train_agent_async import train_agent_async  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch_with_evaluation  # NOQA
diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index 210c7ed24..1e10f7647 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -1,5 +1,8 @@
 import logging
 import os
+import shutil
+
+import numpy as np
 
 from pfrl.experiments.evaluator import Evaluator, save_agent
 from pfrl.utils.ask_yes_no import ask_yes_no
@@ -21,6 +24,43 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""):
         save_agent_replay_buffer(agent, t, outdir, suffix=suffix)
 
 
+def snapshot(agent, evaluator, t, outdir, suffix="_snapshot", logger=None, delete_old=True):
+    tmp_suffix = f"{suffix}_"
+    dirname = os.path.join(outdir, f"{t}{suffix}")
+    tmp_dirname = os.path.join(outdir, f"{t}{tmp_suffix}")  # temporary filename until files are saved
+    agent.save(tmp_dirname)
+    if hasattr(agent, "replay_buffer"):
+        agent.replay_buffer.save(os.path.join(tmp_dirname, "replay.pkl"))
+    if evaluator:
+        np.save(os.path.join(tmp_dirname, "max_score"), evaluator.max_score)
+    os.rename(tmp_dirname, dirname)
+    if logger:
+        logger.info(f"Saved the snapshot to {dirname}")
+    if delete_old:
+        for old_dir in filter(lambda s: s.endswith(suffix) or s.endswith(tmp_suffix), os.listdir(outdir)):
+            if old_dir != f"{t}{suffix}":
+                shutil.rmtree(os.path.join(outdir, old_dir))
+
+
+def load_snapshot(agent, dirname, logger=None):
+    agent.load(dirname)
+    if hasattr(agent, "replay_buffer"):
+        agent.replay_buffer.load(os.path.join(dirname, "replay.pkl"))
+    if logger:
+        logger.info(f"Loaded the snapshot from {dirname}")
+
+
+def latest_snapshot_dir(search_dir, suffix="_snapshot"):
+    """
+    returns (steps, dirname)
+    (0, None) if no snapshot exists
+    """
+    candidates = list(filter(lambda s: s.endswith(suffix), os.listdir(search_dir)))
+    if len(candidates) == 0:
+        return 0, None
+    return max([(int(name.split("_")[0]), os.path.join(search_dir, name)) for name in candidates])
+
+
 def train_agent(
     agent,
     env,
     steps,
     outdir,
     checkpoint_freq=None,
     max_episode_len=None,
     step_offset=0,
+    max_score=None,
     evaluator=None,
     successful_score=None,
     step_hooks=(),
@@ -38,6 +79,10 @@
 
     logger = logger or logging.getLogger(__name__)
 
+    # restore max_score
+    if evaluator and max_score:
+        evaluator.max_score = max_score
+
     episode_r = 0
     episode_idx = 0
 
     # o_0, r_0
     obs = env.reset()
@@ -100,7 +145,7 @@
                 episode_len = 0
                 obs = env.reset()
            if checkpoint_freq and t % checkpoint_freq == 0:
-                save_agent(agent, t, outdir, logger, suffix="_checkpoint")
+                snapshot(agent, evaluator, t, outdir, logger=logger)
 
     except (Exception, KeyboardInterrupt):
         # Save the current model before being killed
@@ -125,6 +170,7 @@ def train_agent_with_evaluation(
     train_max_episode_len=None,
     step_offset=0,
     eval_max_episode_len=None,
+    max_score=None,
     eval_env=None,
     successful_score=None,
     step_hooks=(),
@@ -144,11 +190,12 @@
         eval_n_episodes (int): Number of episodes at each evaluation phase.
         eval_interval (int): Interval of evaluation.
         outdir (str): Path to the directory to output data.
-        checkpoint_freq (int): frequency at which agents are stored.
+        checkpoint_freq (int): frequency in steps at which agents are stored.
         train_max_episode_len (int): Maximum episode length during training.
         step_offset (int): Time step from which training starts.
         eval_max_episode_len (int or None): Maximum episode length of
             evaluation runs. If None, train_max_episode_len is used instead.
+        max_score (float): Current max score.
         eval_env: Environment used for evaluation.
         successful_score (float): Finish training if the mean score is greater
             than or equal to this value if not None
@@ -213,6 +260,7 @@
         checkpoint_freq=checkpoint_freq,
         max_episode_len=train_max_episode_len,
         step_offset=step_offset,
+        max_score=max_score,
         evaluator=evaluator,
         successful_score=successful_score,
         step_hooks=step_hooks,

From 5070ddcf41bca4ad4d08d6a1ef18f9da6cd68608 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Mon, 30 Aug 2021 10:24:13 +0000
Subject: [PATCH 02/16] Add load_snapshot option to train_dqn_gym.py

---
 examples/gym/train_dqn_gym.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py
index acbf8cc70..c72c90b5a 100644
--- a/examples/gym/train_dqn_gym.py
+++ b/examples/gym/train_dqn_gym.py
@@ -60,6 +60,7 @@ def main():
     parser.add_argument("--update-interval", type=int, default=1)
     parser.add_argument("--eval-n-runs", type=int, default=100)
     parser.add_argument("--eval-interval", type=int, default=10 ** 4)
+    parser.add_argument("--checkpoint-freq", type=int, default=10 ** 4)
     parser.add_argument("--n-hidden-channels", type=int, default=100)
     parser.add_argument("--n-hidden-layers", type=int, default=2)
     parser.add_argument("--gamma", type=float, default=0.99)
@@ -68,6 +69,7 @@ def main():
     parser.add_argument("--render-eval", action="store_true")
     parser.add_argument("--monitor", action="store_true")
     parser.add_argument("--reward-scale-factor", type=float, default=1e-3)
+    parser.add_argument('--load_snapshot', action='store_true')
     parser.add_argument(
         "--actor-learner",
         action="store_true",
@@ -187,6 +189,16 @@ def make_env(idx=0, test=False):
         soft_update_tau=args.soft_update_tau,
     )
 
+    # load snapshot
+    step_offset = 0
+    max_score = None
+    if args.load_snapshot:
+        step_offset, snapshot_dirname = experiments.latest_snapshot_dir(args.outdir)
+        if snapshot_dirname:
+            experiments.load_snapshot(agent, snapshot_dirname)
+            if os.path.exists(os.path.join(snapshot_dirname, "max_score.npy")):
+                max_score = np.load(os.path.join(snapshot_dirname, "max_score.npy"))
+
     if args.load:
         agent.load(args.load)
 
@@ -225,6 +237,9 @@ def make_env(idx=0, test=False):
             eval_n_episodes=args.eval_n_runs,
             eval_interval=args.eval_interval,
             outdir=args.outdir,
+            step_offset=step_offset,
+            max_score=max_score,
+            checkpoint_freq=args.checkpoint_freq,
             eval_env=eval_env,
            train_max_episode_len=timestep_limit,
            eval_during_episode=True,
@@ -258,6 +273,9 @@ def make_env(idx=0, test=False):
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
+           step_offset=step_offset,
+           max_score=max_score,
+           checkpoint_freq=args.checkpoint_freq,
            stop_event=learner.stop_event,
            exception_event=exception_event,
        )

From 27bf4979eec3619ef143a0b9803be480bc2d2528 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Mon, 30 Aug 2021 10:25:02 +0000
Subject: [PATCH 03/16] Temporarily fix output directory

---
 examples/gym/train_dqn_gym.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py
index c72c90b5a..b69c298ef 100644
--- a/examples/gym/train_dqn_gym.py
+++ b/examples/gym/train_dqn_gym.py
@@ -90,6 +90,9 @@ def main():
     utils.set_random_seed(args.seed)
 
     args.outdir = experiments.prepare_output_dir(args, args.outdir, argv=sys.argv)
+    # TODO: fix output directory for now!!
+    args.outdir = "results/tmp"
+    os.makedirs(args.outdir, exist_ok=True)
     print("Output files are saved in {}".format(args.outdir))
 
     # Set different random seeds for different subprocesses.

From 7cf0dd9be3cf02a6552fcb17130a4c98f68353b9 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Mon, 30 Aug 2021 10:36:03 +0000
Subject: [PATCH 04/16] Format

---
 examples/gym/train_dqn_gym.py   |  2 +-
 pfrl/experiments/__init__.py    |  6 +++++-
 pfrl/experiments/train_agent.py | 19 +++++++++++++++----
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py
index b69c298ef..60ed30117 100644
--- a/examples/gym/train_dqn_gym.py
+++ b/examples/gym/train_dqn_gym.py
@@ -69,7 +69,7 @@ def main():
     parser.add_argument("--render-eval", action="store_true")
     parser.add_argument("--monitor", action="store_true")
     parser.add_argument("--reward-scale-factor", type=float, default=1e-3)
-    parser.add_argument('--load_snapshot', action='store_true')
+    parser.add_argument("--load_snapshot", action="store_true")
     parser.add_argument(
         "--actor-learner",
         action="store_true",
diff --git a/pfrl/experiments/__init__.py b/pfrl/experiments/__init__.py
index a14a6f9b5..04bd7d3ab 100644
--- a/pfrl/experiments/__init__.py
+++ b/pfrl/experiments/__init__.py
@@ -7,7 +7,11 @@ from pfrl.experiments.prepare_output_dir import is_under_git_control  # NOQA
 from pfrl.experiments.prepare_output_dir import prepare_output_dir  # NOQA
 from pfrl.experiments.train_agent import train_agent  # NOQA
-from pfrl.experiments.train_agent import train_agent_with_evaluation, load_snapshot, latest_snapshot_dir  # NOQA
+from pfrl.experiments.train_agent import (
+    train_agent_with_evaluation,
+    load_snapshot,
+    latest_snapshot_dir,
+)  # NOQA
 from pfrl.experiments.train_agent_async import train_agent_async  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch_with_evaluation  # NOQA
diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index 1e10f7647..4fae7d56a 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -24,10 +24,14 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""):
         save_agent_replay_buffer(agent, t, outdir, suffix=suffix)
 
 
-def snapshot(agent, evaluator, t, outdir, suffix="_snapshot", logger=None, delete_old=True):
+def snapshot(
+    agent, evaluator, t, outdir, suffix="_snapshot", logger=None, delete_old=True
+):
     tmp_suffix = f"{suffix}_"
f"{t}{suffix}") - tmp_dirname = os.path.join(outdir, f"{t}{tmp_suffix}") # temporary filename until files are saved + tmp_dirname = os.path.join( + outdir, f"{t}{tmp_suffix}" + ) # temporary filename until files are saved agent.save(tmp_dirname) if hasattr(agent, "replay_buffer"): agent.replay_buffer.save(os.path.join(tmp_dirname, "replay.pkl")) @@ -37,7 +41,9 @@ def snapshot(agent, evaluator, t, outdir, suffix="_snapshot", logger=None, delet if logger: logger.info(f"Saved the snapshot to {dirname}") if delete_old: - for old_dir in filter(lambda s: s.endswith(suffix) or s.endswith(tmp_suffix), os.listdir(outdir)): + for old_dir in filter( + lambda s: s.endswith(suffix) or s.endswith(tmp_suffix), os.listdir(outdir) + ): if old_dir != f"{t}{suffix}": shutil.rmtree(os.path.join(outdir, old_dir)) @@ -58,7 +64,12 @@ def latest_snapshot_dir(search_dir, suffix="_snapshot"): candidates = list(filter(lambda s: s.endswith(suffix), os.listdir(search_dir))) if len(candidates) == 0: return 0, None - return max([(int(name.split("_")[0]), os.path.join(search_dir, name)) for name in candidates]) + return max( + [ + (int(name.split("_")[0]), os.path.join(search_dir, name)) + for name in candidates + ] + ) def train_agent( From e6bdd04e4899a1518d4b11c0946bd7d2a7d23f92 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 9 Sep 2021 10:27:15 +0000 Subject: [PATCH 05/16] Revert examples/gym/train_dqn_gym.py --- examples/gym/train_dqn_gym.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py index 60ed30117..acbf8cc70 100644 --- a/examples/gym/train_dqn_gym.py +++ b/examples/gym/train_dqn_gym.py @@ -60,7 +60,6 @@ def main(): parser.add_argument("--update-interval", type=int, default=1) parser.add_argument("--eval-n-runs", type=int, default=100) parser.add_argument("--eval-interval", type=int, default=10 ** 4) - parser.add_argument("--checkpoint-freq", type=int, default=10 ** 4) parser.add_argument("--n-hidden-channels", type=int, default=100) parser.add_argument("--n-hidden-layers", type=int, default=2) parser.add_argument("--gamma", type=float, default=0.99) @@ -69,7 +68,6 @@ def main(): parser.add_argument("--render-eval", action="store_true") parser.add_argument("--monitor", action="store_true") parser.add_argument("--reward-scale-factor", type=float, default=1e-3) - parser.add_argument("--load_snapshot", action="store_true") parser.add_argument( "--actor-learner", action="store_true", @@ -90,9 +88,6 @@ def main(): utils.set_random_seed(args.seed) args.outdir = experiments.prepare_output_dir(args, args.outdir, argv=sys.argv) - # TODO: fix output directory for now!! - args.outdir = "results/tmp" - os.makedirs(args.outdir, exist_ok=True) print("Output files are saved in {}".format(args.outdir)) # Set different random seeds for different subprocesses. 
@@ -192,16 +187,6 @@ def make_env(idx=0, test=False):
         soft_update_tau=args.soft_update_tau,
     )
 
-    # load snapshot
-    step_offset = 0
-    max_score = None
-    if args.load_snapshot:
-        step_offset, snapshot_dirname = experiments.latest_snapshot_dir(args.outdir)
-        if snapshot_dirname:
-            experiments.load_snapshot(agent, snapshot_dirname)
-            if os.path.exists(os.path.join(snapshot_dirname, "max_score.npy")):
-                max_score = np.load(os.path.join(snapshot_dirname, "max_score.npy"))
-
     if args.load:
         agent.load(args.load)
 
@@ -240,9 +225,6 @@ def make_env(idx=0, test=False):
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
-           step_offset=step_offset,
-           max_score=max_score,
-           checkpoint_freq=args.checkpoint_freq,
            eval_env=eval_env,
            train_max_episode_len=timestep_limit,
            eval_during_episode=True,
@@ -276,9 +258,6 @@ def make_env(idx=0, test=False):
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
-           step_offset=step_offset,
-           max_score=max_score,
-           checkpoint_freq=args.checkpoint_freq,
            stop_event=learner.stop_event,
            exception_event=exception_event,
        )

From 62000832fc72f66c7f0ab7c16dbe459a5069457d Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Thu, 9 Sep 2021 10:29:45 +0000
Subject: [PATCH 06/16] Snapshot only when take_resumable_snapshot is True

---
 pfrl/experiments/train_agent.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index 4fae7d56a..e2ab18a3a 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -78,6 +78,7 @@ def train_agent(
     steps,
     outdir,
     checkpoint_freq=None,
+    take_resumable_snapshot=False,
     max_episode_len=None,
     step_offset=0,
     max_score=None,
     evaluator=None,
     successful_score=None,
     step_hooks=(),
@@ -156,7 +157,10 @@ def train_agent(
                 episode_len = 0
                 obs = env.reset()
             if checkpoint_freq and t % checkpoint_freq == 0:
-                snapshot(agent, evaluator, t, outdir, logger=logger)
+                if take_resumable_snapshot:
+                    snapshot(agent, evaluator, t, outdir, logger=logger)
+                else:
+                    save_agent(agent, t, outdir, logger, suffix="_checkpoint")
 
     except (Exception, KeyboardInterrupt):
         # Save the current model before being killed
@@ -178,6 +182,7 @@ def train_agent_with_evaluation(
     eval_interval,
     outdir,
     checkpoint_freq=None,
+    take_resumable_snapshot=False,
     train_max_episode_len=None,
     step_offset=0,
     eval_max_episode_len=None,
     max_score=None,
     eval_env=None,
     successful_score=None,
     step_hooks=(),
@@ -269,6 +274,7 @@ def train_agent_with_evaluation(
         steps,
         outdir,
         checkpoint_freq=checkpoint_freq,
+        take_resumable_snapshot=take_resumable_snapshot,
         max_episode_len=train_max_episode_len,
         step_offset=step_offset,
         max_score=max_score,
         evaluator=evaluator,
         successful_score=successful_score,

From 2a4b12cf7a9f940d41e9520d2b02f3f2c7fecf6f Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Thu, 9 Sep 2021 10:39:45 +0000
Subject: [PATCH 07/16] Refactor

---
 pfrl/experiments/train_agent.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index e2ab18a3a..a8918f7ac 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -28,18 +28,16 @@ def snapshot(
     agent, evaluator, t, outdir, suffix="_snapshot", logger=None, delete_old=True
 ):
     tmp_suffix = f"{suffix}_"
-    dirname = os.path.join(outdir, f"{t}{suffix}")
-    tmp_dirname = os.path.join(
-        outdir, f"{t}{tmp_suffix}"
-    )  # temporary filename until files are saved
+    tmp_dirname = os.path.join(outdir, f"{t}{tmp_suffix}")  # use until files are saved
     agent.save(tmp_dirname)
     if hasattr(agent, "replay_buffer"):
"replay.pkl")) if evaluator: np.save(os.path.join(tmp_dirname, "max_score"), evaluator.max_score) - os.rename(tmp_dirname, dirname) + real_dirname = os.path.join(outdir, f"{t}{suffix}") + os.rename(tmp_dirname, real_dirname) if logger: - logger.info(f"Saved the snapshot to {dirname}") + logger.info(f"Saved the snapshot to {real_dirname}") if delete_old: for old_dir in filter( lambda s: s.endswith(suffix) or s.endswith(tmp_suffix), os.listdir(outdir) From ab9abb0c0e9478f94acd808fd1253a0d4c4cdfe9 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Fri, 10 Sep 2021 09:34:11 +0000 Subject: [PATCH 08/16] Do not write header in scores.txt if exists --- pfrl/experiments/evaluator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pfrl/experiments/evaluator.py b/pfrl/experiments/evaluator.py index 75691784c..12664a211 100644 --- a/pfrl/experiments/evaluator.py +++ b/pfrl/experiments/evaluator.py @@ -384,7 +384,10 @@ def write_header(outdir, agent, env): "max", # maximum value of returns of evaluation runs "min", # minimum value of returns of evaluation runs ) - with open(os.path.join(outdir, "scores.txt"), "w") as f: + fp = os.path.join(outdir, "scores.txt") + if os.path.exists(fp) and os.stat(fp).st_size > 0: + return + with open(fp, "w") as f: custom_columns = tuple(t[0] for t in agent.get_statistics()) env_get_stats = getattr(env, "get_statistics", lambda: []) assert callable(env_get_stats) From aa488944ffb5a4ccb59693204e939f0de9b132d1 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Fri, 10 Sep 2021 09:44:26 +0000 Subject: [PATCH 09/16] Add episode_offset to train_agent_with_evaluation --- pfrl/experiments/train_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py index a8918f7ac..ef9f76e22 100644 --- a/pfrl/experiments/train_agent.py +++ b/pfrl/experiments/train_agent.py @@ -79,6 +79,7 @@ def train_agent( take_resumable_snapshot=False, max_episode_len=None, step_offset=0, + episode_offset=0, max_score=None, evaluator=None, successful_score=None, @@ -94,7 +95,7 @@ def train_agent( evaluator.max_score = max_score episode_r = 0 - episode_idx = 0 + episode_idx = episode_offset # o_0, r_0 obs = env.reset() @@ -183,6 +184,7 @@ def train_agent_with_evaluation( take_resumable_snapshot=False, train_max_episode_len=None, step_offset=0, + episode_offset=0, eval_max_episode_len=None, max_score=None, eval_env=None, @@ -207,6 +209,7 @@ def train_agent_with_evaluation( checkpoint_freq (int): frequency in step at which agents are stored. train_max_episode_len (int): Maximum episode length during training. step_offset (int): Time step from which training starts. + episode_offset (int): Episode index from which training starts, eval_max_episode_len (int or None): Maximum episode length of evaluation runs. If None, train_max_episode_len is used instead. max_score (int): Current max socre. 
@@ -275,6 +278,7 @@ def train_agent_with_evaluation(
         take_resumable_snapshot=take_resumable_snapshot,
         max_episode_len=train_max_episode_len,
         step_offset=step_offset,
+        episode_offset=episode_offset,
         max_score=max_score,
         evaluator=evaluator,
         successful_score=successful_score,

From 8e9fcaf3d4b59bf20cfa0db0e2c8b290c4618e9f Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Fri, 10 Sep 2021 09:46:04 +0000
Subject: [PATCH 10/16] Add comment of take_resumable_snapshot

---
 pfrl/experiments/train_agent.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index ef9f76e22..660b2ddb4 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -207,6 +207,7 @@ def train_agent_with_evaluation(
         eval_interval (int): Interval of evaluation.
         outdir (str): Path to the directory to output data.
         checkpoint_freq (int): frequency in steps at which agents are stored.
+        take_resumable_snapshot (bool): If True, a resumable snapshot is saved at each checkpoint.
         train_max_episode_len (int): Maximum episode length during training.
         step_offset (int): Time step from which training starts.
         episode_offset (int): Episode index from which training starts.

From e72a4244a8d8e708d5d852ca2e0a14274015eaf8 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Fri, 10 Sep 2021 09:49:27 +0000
Subject: [PATCH 11/16] Save scores.txt and snapshot_history in snapshot

---
 pfrl/experiments/train_agent.py | 59 +++++++++++++++++++++++++++-------
 1 file changed, 45 insertions(+), 14 deletions(-)

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index 660b2ddb4..c96e6c08a 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -1,8 +1,8 @@
 import logging
 import os
 import shutil
-
-import numpy as np
+import csv
+import time
 
 from pfrl.experiments.evaluator import Evaluator, save_agent
 from pfrl.utils.ask_yes_no import ask_yes_no
@@ -25,15 +25,35 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""):
 
 
 def snapshot(
-    agent, evaluator, t, outdir, suffix="_snapshot", logger=None, delete_old=True
+    agent,
+    t,
+    episode_idx,
+    outdir,
+    suffix="_snapshot",
+    logger=None,
+    delete_old=True,
 ):
+    start_time = time.time()
     tmp_suffix = f"{suffix}_"
     tmp_dirname = os.path.join(outdir, f"{t}{tmp_suffix}")  # use until files are saved
     agent.save(tmp_dirname)
     if hasattr(agent, "replay_buffer"):
         agent.replay_buffer.save(os.path.join(tmp_dirname, "replay.pkl"))
-    if evaluator:
-        np.save(os.path.join(tmp_dirname, "max_score"), evaluator.max_score)
+    if os.path.exists(os.path.join(outdir, "scores.txt")):
+        shutil.copyfile(
+            os.path.join(outdir, "scores.txt"), os.path.join(tmp_dirname, "scores.txt")
+        )
+
+    history_path = os.path.join(outdir, "snapshot_history.txt")
+    if not os.path.exists(history_path):  # write header
+        with open(history_path, "a") as f:
+            csv.writer(f, delimiter="\t").writerow(["step", "episode", "snapshot_time"])
+    with open(history_path, "a") as f:
+        csv.writer(f, delimiter="\t").writerow(
+            [t, episode_idx, time.time() - start_time]
+        )
+    shutil.copyfile(history_path, os.path.join(tmp_dirname, "snapshot_history.txt"))
+
     real_dirname = os.path.join(outdir, f"{t}{suffix}")
     os.rename(tmp_dirname, real_dirname)
     if logger:
@@ -52,21 +72,32 @@ def load_snapshot(agent, dirname, logger=None):
         agent.replay_buffer.load(os.path.join(dirname, "replay.pkl"))
     if logger:
         logger.info(f"Loaded the snapshot from {dirname}")
+    with open(os.path.join(dirname, "snapshot_history.txt")) as f:
+        step, episode = map(int, f.readlines()[-1].split()[:2])
+    max_score = None
+    if os.path.exists(os.path.join(dirname, "scores.txt")):
+        with open(os.path.join(dirname, "scores.txt")) as f:
+            max_score = float(f.readlines()[-1].split()[3])  # mean
+    shutil.copyfile(
+        os.path.join(dirname, "snapshot_history.txt"),
+        os.path.join(dirname, "..", "snapshot_history.txt"),
+    )
+    shutil.copyfile(
+        os.path.join(dirname, "scores.txt"),
+        os.path.join(dirname, "..", "scores.txt"),
+    )
+    return step, episode, max_score
 
 
 def latest_snapshot_dir(search_dir, suffix="_snapshot"):
     """
-    returns (steps, dirname)
-    (0, None) if no snapshot exists
+    returns the latest snapshot directory; None if no snapshot exists
     """
     candidates = list(filter(lambda s: s.endswith(suffix), os.listdir(search_dir)))
     if len(candidates) == 0:
-        return 0, None
-    return max(
-        [
-            (int(name.split("_")[0]), os.path.join(search_dir, name))
-            for name in candidates
-        ]
+        return None
+    return os.path.join(
+        search_dir, max(candidates, key=lambda name: int(name.split("_")[0]))
     )
 
 
 def train_agent(
@@ -157,7 +188,7 @@ def train_agent(
                 obs = env.reset()
             if checkpoint_freq and t % checkpoint_freq == 0:
                 if take_resumable_snapshot:
-                    snapshot(agent, evaluator, t, outdir, logger=logger)
+                    snapshot(agent, t, episode_idx, outdir, logger=logger)
                 else:
                     save_agent(agent, t, outdir, logger, suffix="_checkpoint")

From c975ffce94d9ac2afeb27cb5884154b7efbd0c65 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Fri, 10 Sep 2021 09:53:05 +0000
Subject: [PATCH 12/16] Add save & load snapshot option in atari example

---
 examples/atari/reproduction/dqn/train_dqn.py | 39 +++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/examples/atari/reproduction/dqn/train_dqn.py b/examples/atari/reproduction/dqn/train_dqn.py
index 72c210ad5..6402201bb 100644
--- a/examples/atari/reproduction/dqn/train_dqn.py
+++ b/examples/atari/reproduction/dqn/train_dqn.py
@@ -31,6 +31,12 @@ def main():
             " If it does not exist, it will be created."
         ),
     )
+    parser.add_argument(
+        "--exp-id",
+        type=str,
+        default=None,
+        help="Experiment ID. If None, commit hash or timestamp is used.",
+    )
     parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
     parser.add_argument(
         "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
@@ -73,9 +79,22 @@ def main():
         default=5 * 10 ** 4,
         help="Minimum replay buffer size before " + "performing gradient updates.",
     )
+    parser.add_argument(
+        "--save-snapshot",
+        action="store_true",
+        default=False,
+        help="Take resumable snapshot at every checkpoint",
+    )
+    parser.add_argument(
+        "--load-snapshot",
+        action="store_true",
+        default=False,
+        help="Load snapshot if exists",
+    )
     parser.add_argument("--eval-n-steps", type=int, default=125000)
     parser.add_argument("--eval-interval", type=int, default=250000)
     parser.add_argument("--n-best-episodes", type=int, default=30)
+    parser.add_argument("--checkpoint-freq", type=int, default=2000000)
     args = parser.parse_args()
 
     import logging
@@ -89,7 +108,9 @@ def main():
     train_seed = args.seed
     test_seed = 2 ** 31 - 1 - args.seed
 
-    args.outdir = experiments.prepare_output_dir(args, args.outdir)
+    args.outdir = experiments.prepare_output_dir(
+        args, args.outdir, exp_id=args.exp_id, make_backup=args.exp_id is None
+    )
     print("Output files are saved in {}".format(args.outdir))
 
     def make_env(test):
@@ -162,6 +183,17 @@ def phi(x):
         phi=phi,
     )
 
+    # load snapshot
+    step_offset, episode_offset = 0, 0
+    max_score = None
+    if args.load_snapshot:
+        snapshot_dirname = experiments.latest_snapshot_dir(args.outdir)
+        if snapshot_dirname:
+            print(f"load snapshot from {snapshot_dirname}")
+            step_offset, episode_offset, max_score = experiments.load_snapshot(
+                agent, snapshot_dirname
+            )
+
     if args.load or args.load_pretrained:
         # either load or load_pretrained must be false
         assert not args.load or not args.load_pretrained
@@ -194,6 +226,11 @@ def phi(x):
         eval_n_steps=args.eval_n_steps,
         eval_n_episodes=None,
         eval_interval=args.eval_interval,
+        step_offset=step_offset,
+        episode_offset=episode_offset,
+        max_score=max_score,
+        checkpoint_freq=args.checkpoint_freq,
+        take_resumable_snapshot=args.save_snapshot,
         outdir=args.outdir,
         save_best_so_far_agent=True,
         eval_env=eval_env,

From c4ed3df97dcc59a7c95cbea196179d02de511c64 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Tue, 14 Sep 2021 10:32:15 +0000
Subject: [PATCH 13/16] Fix bug when scores.txt has only header

---
 pfrl/experiments/train_agent.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index c96e6c08a..d0a82aad5 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -77,7 +77,9 @@ def load_snapshot(agent, dirname, logger=None):
     max_score = None
     if os.path.exists(os.path.join(dirname, "scores.txt")):
         with open(os.path.join(dirname, "scores.txt")) as f:
-            max_score = float(f.readlines()[-1].split()[3])  # mean
+            lines = f.readlines()
+            if len(lines) > 1:
+                max_score = float(lines[-1].split()[3])  # mean
     shutil.copyfile(
         os.path.join(dirname, "snapshot_history.txt"),
         os.path.join(dirname, "..", "snapshot_history.txt"),

From ca5c450063a826b5c8d10b6b2a924f0185372c65 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Tue, 14 Sep 2021 10:33:26 +0000
Subject: [PATCH 14/16] Add snapshot test

---
 examples_tests/atari/reproduction/test_dqn.sh | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/examples_tests/atari/reproduction/test_dqn.sh b/examples_tests/atari/reproduction/test_dqn.sh
index faa0e9368..3e99633ac 100644
--- a/examples_tests/atari/reproduction/test_dqn.sh
+++ b/examples_tests/atari/reproduction/test_dqn.sh
@@ -10,3 +10,11 @@ gpu="$1"
 python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu
 model=$(find $outdir/atari/reproduction/dqn -name "*_finish")
 python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu
+
+# snapshot without eval
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 0 --save-snapshot --checkpoint-freq 45
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 0 --load-snapshot
+
+# snapshot after eval
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 4600 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 1 --save-snapshot --checkpoint-freq 4000
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 4700 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 1 --load-snapshot

From dde7ebf388ecfef023fb5abff0e05ec573c5da85 Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Tue, 14 Sep 2021 10:40:07 +0000
Subject: [PATCH 15/16] Add warnings of agent analytics in docstring

---
 pfrl/experiments/train_agent.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index d0a82aad5..cc587080f 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -241,6 +241,10 @@ def train_agent_with_evaluation(
         outdir (str): Path to the directory to output data.
         checkpoint_freq (int): frequency in steps at which agents are stored.
         take_resumable_snapshot (bool): If True, a resumable snapshot is saved at each checkpoint.
+            Note that currently, snapshot does not support agent analytics (e.g.,
+            for DQN, average_q, average_loss, cumulative_steps, and n_updates), and
+            those values in "scores.txt" might be incorrect after resuming from a
+            snapshot.
         train_max_episode_len (int): Maximum episode length during training.
         step_offset (int): Time step from which training starts.
         episode_offset (int): Episode index from which training starts.

From 2d4e67dc4cdd2427ee29833444065e5606110a5b Mon Sep 17 00:00:00 2001
From: Kenshin Abe
Date: Tue, 14 Sep 2021 12:16:26 +0000
Subject: [PATCH 16/16] Apply isort

---
 pfrl/experiments/__init__.py    | 8 ++++----
 pfrl/experiments/train_agent.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pfrl/experiments/__init__.py b/pfrl/experiments/__init__.py
index 04bd7d3ab..20f307aaf 100644
--- a/pfrl/experiments/__init__.py
+++ b/pfrl/experiments/__init__.py
@@ -7,11 +7,11 @@ from pfrl.experiments.prepare_output_dir import is_under_git_control  # NOQA
 from pfrl.experiments.prepare_output_dir import prepare_output_dir  # NOQA
 from pfrl.experiments.train_agent import train_agent  # NOQA
-from pfrl.experiments.train_agent import (
-    train_agent_with_evaluation,
-    load_snapshot,
-    latest_snapshot_dir,
-)  # NOQA
+from pfrl.experiments.train_agent import (  # NOQA
+    latest_snapshot_dir,
+    load_snapshot,
+    train_agent_with_evaluation,
+)
 from pfrl.experiments.train_agent_async import train_agent_async  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch_with_evaluation  # NOQA
diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index cc587080f..e89031413 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -1,7 +1,7 @@
+import csv
 import logging
 import os
 import shutil
-import csv
 import time
 
 from pfrl.experiments.evaluator import Evaluator, save_agent
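
Usage note (appended by the editor, not part of the patch series): the sketch below shows how a training script might resume from the latest snapshot using the API these patches add, mirroring what examples/atari/reproduction/dqn/train_dqn.py does after [PATCH 12/16]. It is a minimal sketch, not a definitive recipe: make_agent() and make_env() are hypothetical placeholders, the step and interval numbers are arbitrary, and only latest_snapshot_dir, load_snapshot, and the new keyword arguments of train_agent_with_evaluation come from this series.

    from pfrl import experiments

    agent = make_agent()  # hypothetical: build a pfrl agent with a replay buffer
    env = make_env()  # hypothetical: build the training environment
    outdir = "results/exp0"  # assumed to be a fixed, reusable output directory

    # Start from scratch unless a resumable snapshot already exists in outdir.
    step_offset, episode_offset, max_score = 0, 0, None
    snapshot_dirname = experiments.latest_snapshot_dir(outdir)
    if snapshot_dirname:
        # Restores the agent (and its replay buffer) and returns the step and
        # episode counters plus the best evaluation score recorded so far.
        step_offset, episode_offset, max_score = experiments.load_snapshot(
            agent, snapshot_dirname
        )

    experiments.train_agent_with_evaluation(
        agent=agent,
        env=env,
        steps=10 ** 6,
        eval_n_steps=None,
        eval_n_episodes=10,
        eval_interval=10 ** 4,
        outdir=outdir,
        checkpoint_freq=10 ** 5,
        take_resumable_snapshot=True,  # save a resumable snapshot at each checkpoint
        step_offset=step_offset,
        episode_offset=episode_offset,
        max_score=max_score,
    )

As the docstring warning added in [PATCH 15/16] notes, agent statistics such as average_q and average_loss are not part of the snapshot, so the corresponding columns of scores.txt can be inaccurate right after a resume.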