diff --git a/README.md b/README.md
index 24680f61..9e5c3ccb 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,6 @@
- Benchmark of the algorithms is conducted in many RL environment
-
## :arrow_down: Installation
```
@@ -30,7 +29,7 @@
## :rocket: QuickStart
-
+
@@ -42,24 +41,33 @@
+## :mag: How to
+
+- [How to use](./docs/How_to_use.md)
+- [How to customize config](./config/README.md)
+- [How to customize agent](./core/agent/README.md)
+- [How to customize environment](./core/env/README.md)
+- [How to customize network](./core/network/README.md)
+- [How to customize buffer](./core/buffer/README.md)
+
+
+
## :page_facing_up: Documentation
-- [Implementation List](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/Implementation_list.md)
-- [Benchmark](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/Benchmark.md)
-- [Distributed Architecture](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/Distributed_Architecture.md)
-- [Reference](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/Reference.md)
+- [Distributed Architecture](./docs/Distributed_Architecture.md)
+- [Role of Managers](./manager/README.md)
+- [Implementation List](./docs/Implementation_list.md)
+- [Naming Convention](./docs/Naming_convention.md)
+- [Benchmark](https://www.notion.so/rlnote/Benchmark-c7642d152cad4980bc03fe804fe9e88a)
+- [Reference](./docs/Reference.md)
-- [How to use](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/How_to_use.md)
-- [How to add RL algorithm](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/How_to_add_rl_algorithm.md)
-- [How to add environment](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/How_to_add_environment.md)
-- [How to add network](https://github.kakaocorp.com/leonard-q/RL_Algorithms/blob/master/docs/How_to_add_network.md)
## :busts_in_silhouette: Contributors
-:mailbox: Contact: [Leonard.Q](leonard.q@kakaoenterprise.com), [Ramanuzan.Lee](ramanuzan.lee@kakaoenterprise.com), [Royce.Choi](royce.choi@kakaoenterprise.com)
+:mailbox: Contact: atech.rl@kakaocorp.com
-
+
## :copyright: License
diff --git a/async_distributed_train.py b/async_distributed_train.py
index f5cb8a8b..abbae91e 100644
--- a/async_distributed_train.py
+++ b/async_distributed_train.py
@@ -37,7 +37,7 @@
path_queue = mp.Queue(1)
record_period = config.train.record_period if config.train.record_period else config.train.run_step//10
- test_manager_config = (Env(**config.env), config.train.test_iteration, config.train.record, record_period)
+ eval_manager_config = (Env(**config.env), config.train.eval_iteration, config.train.record, record_period)
log_id = config.train.id if config.train.id else config.agent.name
log_manager_config = (config.env.name, log_id, config.train.experiment)
agent_config['device'] = "cpu"
@@ -45,7 +45,7 @@
args=(Agent, agent_config,
result_queue, manage_sync_queue, path_queue,
config.train.run_step, config.train.print_period,
- MetricManager, TestManager, test_manager_config,
+ MetricManager, EvalManager, eval_manager_config,
LogManager, log_manager_config, config_manager))
distributed_manager_config = (Env, config.env, Agent, agent_config, config.train.num_workers, 'async')
interact = mp.Process(target=interact_process,
diff --git a/config/README.md b/config/README.md
new file mode 100644
index 00000000..9865f733
--- /dev/null
+++ b/config/README.md
@@ -0,0 +1,43 @@
+# How to customize config
+
+## Config file management rules
+- The config files provided by default are mainly managed in the form config/\[agent\]/\[env\].py.
+- For a group of environments that share parameters, manage the config in the form config/\[agent\]/\[env_group\] and specify the environment name with --env.name in the run command.
+
+reference: [dqn/cartpole.py](./dqn/cartpole.py), [dqn/atari.py](./dqn/atari.py)
+
+## Config setting
+- The config file is managed with a total of four dictionary variables: agent, env, optim, and train.
+
+ ### agent
+ - The agent dictionary manages input parameters used by the agent class.
+ - name: The key of the agent class you want to use.
+ - others: You can check them in the agent class.
+
+ ### env
+ - The env dictionary manages input parameters used by the env class.
+ - name: The key of the env class you want to use.
+ - others: You can check them in the env class.
+
+ ### optim
+ - The optim dictionary manages input parameters used by the optimizer class. Since PyTorch optimizers are used as-is, any optimizer supported by PyTorch can be used.
+ - name: The key of the optimizer class you want to use.
+ - others: You can check them in the optimizer classes supported by PyTorch.
+
+ ### train
+ - The train dictionary manages parameters used in the main scripts.
+ - training: Whether to train. Set to False in the eval.py script and True otherwise.
+ - load_path: The path from which to load the model. Set it when you want to load a model or when running the eval.py script; otherwise set it to None.
+ - run_step: The total number of interactions to run.
+ - print_period: The cycle (unit: step) at which the progress is printed.
+ - save_period: The cycle (unit: step) at which the model is saved.
+ - eval_iteration: The number of episodes to run in total to get the evaluation score.
+ - record: Whether to record the simulation as the evaluation proceeds. If set to True and the env is recordable, the simulation is saved as a gif file in save_path. (Note that this does not work for non-recordable environments.)
+ - record_period: The cycle (unit: step) at which recording is done.
+ - distributed_batch_size: In distributed scripts, distributed_batch_size is used instead of agent.batch_size.
+ - update_period: The cycle (unit: step) at which actors pass transition data to the learner.
+ - num_workers: Total number of distributed actors that interact with the env.
+
+ __distributed_batch_size, update_period and num_workers are only used in distributed scripts.__
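+
+For illustration, here is a minimal config sketch in the style of the provided files (the values below are placeholders, not a tuned configuration):
+
+```python
+agent = {
+    "name": "dqn",        # key of the agent class
+    "batch_size": 32,
+}
+
+env = {
+    "name": "cartpole",   # key of the env class
+}
+
+optim = {
+    "name": "adam",       # any optimizer supported by PyTorch
+    "lr": 2.5e-4,
+}
+
+train = {
+    "training": True,
+    "load_path": None,
+    "run_step": 100000,
+    "print_period": 1000,
+    "save_period": 10000,
+    "eval_iteration": 10,
+}
+```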
+
+reference: [ppo/atari.py](./ppo/atari.py)
\ No newline at end of file
diff --git a/config/ape_x/atari.py b/config/ape_x/atari.py
index 262f84f4..975d6540 100644
--- a/config/ape_x/atari.py
+++ b/config/ape_x/atari.py
@@ -43,7 +43,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/ape_x/cartpole.py b/config/ape_x/cartpole.py
index e874fcef..5f14d6ef 100644
--- a/config/ape_x/cartpole.py
+++ b/config/ape_x/cartpole.py
@@ -36,7 +36,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 512,
"update_period" : 16,
diff --git a/config/ape_x/pong_mlagent.py b/config/ape_x/pong_mlagent.py
index e4310661..e437183a 100644
--- a/config/ape_x/pong_mlagent.py
+++ b/config/ape_x/pong_mlagent.py
@@ -35,7 +35,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 512,
"update_period" : 16,
diff --git a/config/ape_x/procgen.py b/config/ape_x/procgen.py
index 624249bc..011d324e 100644
--- a/config/ape_x/procgen.py
+++ b/config/ape_x/procgen.py
@@ -40,7 +40,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/c51/atari.py b/config/c51/atari.py
index 51f22798..d65a55e4 100644
--- a/config/c51/atari.py
+++ b/config/c51/atari.py
@@ -41,7 +41,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/c51/cartpole.py b/config/c51/cartpole.py
index 71565f76..54a50f8e 100644
--- a/config/c51/cartpole.py
+++ b/config/c51/cartpole.py
@@ -34,7 +34,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 32,
"num_workers" : 8,
diff --git a/config/c51/pong_mlagent.py b/config/c51/pong_mlagent.py
index d95edfe8..8430cf4d 100644
--- a/config/c51/pong_mlagent.py
+++ b/config/c51/pong_mlagent.py
@@ -33,7 +33,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/c51/procgen.py b/config/c51/procgen.py
index 5e5474b1..d6218c3b 100644
--- a/config/c51/procgen.py
+++ b/config/c51/procgen.py
@@ -38,7 +38,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/ddpg/cartpole.py b/config/ddpg/cartpole.py
index 065c960a..ca212ed6 100644
--- a/config/ddpg/cartpole.py
+++ b/config/ddpg/cartpole.py
@@ -34,7 +34,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period": 1,
"num_workers": 8,
diff --git a/config/ddpg/hopper_mlagent.py b/config/ddpg/hopper_mlagent.py
index e4aad883..e6ae4517 100644
--- a/config/ddpg/hopper_mlagent.py
+++ b/config/ddpg/hopper_mlagent.py
@@ -33,7 +33,7 @@
"run_step" : 300000,
"print_period" : 5000,
"save_period" : 10000,
- "test_iteration" : 10,
+ "eval_iteration" : 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : 1,
diff --git a/config/ddpg/pendulum.py b/config/ddpg/pendulum.py
index 13cf4cd3..f921e76d 100644
--- a/config/ddpg/pendulum.py
+++ b/config/ddpg/pendulum.py
@@ -33,7 +33,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 128,
"update_period" : 1,
diff --git a/config/double/atari.py b/config/double/atari.py
index ce72ef8e..25b8e0bb 100644
--- a/config/double/atari.py
+++ b/config/double/atari.py
@@ -37,7 +37,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/double/cartpole.py b/config/double/cartpole.py
index ebe883fa..83bfa5aa 100644
--- a/config/double/cartpole.py
+++ b/config/double/cartpole.py
@@ -30,7 +30,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 32,
"num_workers" : 8,
diff --git a/config/double/pong_mlagent.py b/config/double/pong_mlagent.py
index bfaa1b8c..e0b59985 100644
--- a/config/double/pong_mlagent.py
+++ b/config/double/pong_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 200000,
"print_period" : 2000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/double/procgen.py b/config/double/procgen.py
index ada1c0e6..8a7694cb 100644
--- a/config/double/procgen.py
+++ b/config/double/procgen.py
@@ -34,7 +34,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/dqn/atari.py b/config/dqn/atari.py
index 9005ba6e..990e0f8a 100644
--- a/config/dqn/atari.py
+++ b/config/dqn/atari.py
@@ -37,7 +37,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/dqn/cartpole.py b/config/dqn/cartpole.py
index f260243e..a60f8da9 100644
--- a/config/dqn/cartpole.py
+++ b/config/dqn/cartpole.py
@@ -29,7 +29,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 32,
"num_workers" : 8,
diff --git a/config/dqn/mario.py b/config/dqn/mario.py
index 3153ac05..b905d196 100644
--- a/config/dqn/mario.py
+++ b/config/dqn/mario.py
@@ -34,7 +34,7 @@
"run_step" : 100000000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 1,
+ "eval_iteration": 1,
"record" : True,
"record_period" : 200000,
# distributed setting
diff --git a/config/dqn/pong_mlagent.py b/config/dqn/pong_mlagent.py
index 5d752283..5eb320e1 100644
--- a/config/dqn/pong_mlagent.py
+++ b/config/dqn/pong_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/dqn/procgen.py b/config/dqn/procgen.py
index 1387d9d2..41cf733f 100644
--- a/config/dqn/procgen.py
+++ b/config/dqn/procgen.py
@@ -34,7 +34,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/dueling/atari.py b/config/dueling/atari.py
index 882ae9e5..b9245968 100644
--- a/config/dueling/atari.py
+++ b/config/dueling/atari.py
@@ -37,7 +37,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/dueling/cartpole.py b/config/dueling/cartpole.py
index b0bcc06a..a619f889 100644
--- a/config/dueling/cartpole.py
+++ b/config/dueling/cartpole.py
@@ -30,7 +30,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 32,
"num_workers" : 8,
diff --git a/config/dueling/pong_mlagent.py b/config/dueling/pong_mlagent.py
index 6bdddb6f..976a016b 100644
--- a/config/dueling/pong_mlagent.py
+++ b/config/dueling/pong_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/dueling/procgen.py b/config/dueling/procgen.py
index 33d8c189..ef6beb02 100644
--- a/config/dueling/procgen.py
+++ b/config/dueling/procgen.py
@@ -34,7 +34,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/icm_ppo/atari.py b/config/icm_ppo/atari.py
index 9d05b85e..7684d440 100644
--- a/config/icm_ppo/atari.py
+++ b/config/icm_ppo/atari.py
@@ -45,7 +45,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/icm_ppo/cartpole.py b/config/icm_ppo/cartpole.py
index 6bbe5410..886cde41 100644
--- a/config/icm_ppo/cartpole.py
+++ b/config/icm_ppo/cartpole.py
@@ -38,7 +38,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/icm_ppo/mario.py b/config/icm_ppo/mario.py
index 291a270a..b7296614 100644
--- a/config/icm_ppo/mario.py
+++ b/config/icm_ppo/mario.py
@@ -49,7 +49,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 500000,
- "test_iteration": 1,
+ "eval_iteration": 1,
"record": True,
"record_period": 500000,
# distributed setting
diff --git a/config/icm_ppo/procgen.py b/config/icm_ppo/procgen.py
index 2514ab75..6114b5e6 100644
--- a/config/icm_ppo/procgen.py
+++ b/config/icm_ppo/procgen.py
@@ -41,7 +41,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/iqn/atari.py b/config/iqn/atari.py
index f0c4f450..aa98f8a3 100644
--- a/config/iqn/atari.py
+++ b/config/iqn/atari.py
@@ -43,7 +43,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/iqn/cartpole.py b/config/iqn/cartpole.py
index ec540290..96d3194a 100644
--- a/config/iqn/cartpole.py
+++ b/config/iqn/cartpole.py
@@ -36,7 +36,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 32,
"num_workers" : 8,
diff --git a/config/iqn/mario.py b/config/iqn/mario.py
index 2772e1ce..d3b16eef 100644
--- a/config/iqn/mario.py
+++ b/config/iqn/mario.py
@@ -37,7 +37,7 @@
"run_step" : 100000000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/iqn/pong_mlagent.py b/config/iqn/pong_mlagent.py
index ecf5b68d..2e083251 100644
--- a/config/iqn/pong_mlagent.py
+++ b/config/iqn/pong_mlagent.py
@@ -35,7 +35,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/iqn/procgen.py b/config/iqn/procgen.py
index 125e0acb..517240eb 100644
--- a/config/iqn/procgen.py
+++ b/config/iqn/procgen.py
@@ -40,7 +40,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/mpo/atari.py b/config/mpo/atari.py
index 354dd88f..b363cb52 100644
--- a/config/mpo/atari.py
+++ b/config/mpo/atari.py
@@ -51,7 +51,7 @@
"run_step" : 30000000,
"print_period" : 1000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/mpo/cartpole.py b/config/mpo/cartpole.py
index fbb3beae..fa102f51 100644
--- a/config/mpo/cartpole.py
+++ b/config/mpo/cartpole.py
@@ -43,7 +43,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : 128,
diff --git a/config/mpo/hopper_mlagent.py b/config/mpo/hopper_mlagent.py
index d1c2efcc..11e31681 100644
--- a/config/mpo/hopper_mlagent.py
+++ b/config/mpo/hopper_mlagent.py
@@ -42,7 +42,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration" : 10,
+ "eval_iteration" : 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : 128,
diff --git a/config/mpo/mountaincar.py b/config/mpo/mountaincar.py
index 409ac4c1..0b41c9db 100644
--- a/config/mpo/mountaincar.py
+++ b/config/mpo/mountaincar.py
@@ -42,7 +42,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : 128,
diff --git a/config/mpo/pendulum.py b/config/mpo/pendulum.py
index 83891ecb..17bc4986 100644
--- a/config/mpo/pendulum.py
+++ b/config/mpo/pendulum.py
@@ -42,7 +42,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : 128,
diff --git a/config/mpo/pong_mlagent.py b/config/mpo/pong_mlagent.py
index 2adb60f5..1f197fc1 100644
--- a/config/mpo/pong_mlagent.py
+++ b/config/mpo/pong_mlagent.py
@@ -42,7 +42,7 @@
"run_step" : 200000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : 128,
diff --git a/config/mpo/procgen.py b/config/mpo/procgen.py
index e69038f6..d0fba2d4 100644
--- a/config/mpo/procgen.py
+++ b/config/mpo/procgen.py
@@ -48,7 +48,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/multistep/atari.py b/config/multistep/atari.py
index e5f72da2..49aa113a 100644
--- a/config/multistep/atari.py
+++ b/config/multistep/atari.py
@@ -38,7 +38,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/multistep/cartpole.py b/config/multistep/cartpole.py
index 93367efb..dc6fcfa7 100644
--- a/config/multistep/cartpole.py
+++ b/config/multistep/cartpole.py
@@ -30,7 +30,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 8,
"num_workers" : 8,
diff --git a/config/multistep/pong_mlagent.py b/config/multistep/pong_mlagent.py
index d64849f2..facec30e 100644
--- a/config/multistep/pong_mlagent.py
+++ b/config/multistep/pong_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/multistep/procgen.py b/config/multistep/procgen.py
index 6e95c856..67e373fb 100644
--- a/config/multistep/procgen.py
+++ b/config/multistep/procgen.py
@@ -35,7 +35,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/noisy/atari.py b/config/noisy/atari.py
index c1a15525..02cb39be 100644
--- a/config/noisy/atari.py
+++ b/config/noisy/atari.py
@@ -37,7 +37,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/noisy/cartpole.py b/config/noisy/cartpole.py
index a2e323dc..636ca36f 100644
--- a/config/noisy/cartpole.py
+++ b/config/noisy/cartpole.py
@@ -30,7 +30,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : 32,
"num_workers" : 8,
diff --git a/config/noisy/mario.py b/config/noisy/mario.py
index 2385460d..d5dd5e1a 100644
--- a/config/noisy/mario.py
+++ b/config/noisy/mario.py
@@ -34,7 +34,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/noisy/pong_mlagent.py b/config/noisy/pong_mlagent.py
index e44e32ff..4b943d1e 100644
--- a/config/noisy/pong_mlagent.py
+++ b/config/noisy/pong_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/noisy/procgen.py b/config/noisy/procgen.py
index eb03fe2d..9b2b993b 100644
--- a/config/noisy/procgen.py
+++ b/config/noisy/procgen.py
@@ -34,7 +34,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/per/atari.py b/config/per/atari.py
index 53024475..4825c674 100644
--- a/config/per/atari.py
+++ b/config/per/atari.py
@@ -41,7 +41,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/per/cartpole.py b/config/per/cartpole.py
index d2e438da..04f0264e 100644
--- a/config/per/cartpole.py
+++ b/config/per/cartpole.py
@@ -34,7 +34,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"update_period" : agent["learn_period"],
"num_workers" : 8,
diff --git a/config/per/pong_mlagent.py b/config/per/pong_mlagent.py
index 4e4db21c..808defb3 100644
--- a/config/per/pong_mlagent.py
+++ b/config/per/pong_mlagent.py
@@ -33,7 +33,7 @@
"run_step" : 200000,
"print_period" : 2000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : agent["learn_period"],
"num_workers" : 16,
diff --git a/config/per/procgen.py b/config/per/procgen.py
index ec167ae0..60e361f8 100644
--- a/config/per/procgen.py
+++ b/config/per/procgen.py
@@ -38,7 +38,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/ppo/atari.py b/config/ppo/atari.py
index 021f507b..0d4aadc8 100644
--- a/config/ppo/atari.py
+++ b/config/ppo/atari.py
@@ -38,7 +38,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/ppo/cartpole.py b/config/ppo/cartpole.py
index be370cef..7a04c911 100644
--- a/config/ppo/cartpole.py
+++ b/config/ppo/cartpole.py
@@ -31,7 +31,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/ppo/hopper_mlagent.py b/config/ppo/hopper_mlagent.py
index 7c0acde7..30bf99c2 100644
--- a/config/ppo/hopper_mlagent.py
+++ b/config/ppo/hopper_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 300000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/ppo/mario.py b/config/ppo/mario.py
index 28a28464..77a9c570 100644
--- a/config/ppo/mario.py
+++ b/config/ppo/mario.py
@@ -39,7 +39,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 1,
+ "eval_iteration": 1,
"record" : True,
"record_period" : 250000,
# distributed setting
diff --git a/config/ppo/pendulum.py b/config/ppo/pendulum.py
index ce1029ec..8739ef25 100644
--- a/config/ppo/pendulum.py
+++ b/config/ppo/pendulum.py
@@ -30,7 +30,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/ppo/pong_mlagent.py b/config/ppo/pong_mlagent.py
index 0b215d34..9d8d608a 100644
--- a/config/ppo/pong_mlagent.py
+++ b/config/ppo/pong_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 200000,
"print_period" : 1000,
"save_period" : 50000,
- "test_iteration": 5,
+ "eval_iteration": 5,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/ppo/procgen.py b/config/ppo/procgen.py
index 152e76d8..9597b629 100644
--- a/config/ppo/procgen.py
+++ b/config/ppo/procgen.py
@@ -35,7 +35,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/ppo/worm_mlagent.py b/config/ppo/worm_mlagent.py
index 47668056..a287e198 100644
--- a/config/ppo/worm_mlagent.py
+++ b/config/ppo/worm_mlagent.py
@@ -29,7 +29,7 @@
"run_step" : 300000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/qrdqn/atari.py b/config/qrdqn/atari.py
index e4bb46e9..20887056 100644
--- a/config/qrdqn/atari.py
+++ b/config/qrdqn/atari.py
@@ -39,7 +39,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/qrdqn/cartpole.py b/config/qrdqn/cartpole.py
index d90c7a56..0f318a4a 100644
--- a/config/qrdqn/cartpole.py
+++ b/config/qrdqn/cartpole.py
@@ -32,5 +32,5 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
}
diff --git a/config/qrdqn/pong_mlagent.py b/config/qrdqn/pong_mlagent.py
index 1f863fd6..71ec88ec 100644
--- a/config/qrdqn/pong_mlagent.py
+++ b/config/qrdqn/pong_mlagent.py
@@ -31,7 +31,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/qrdqn/procgen.py b/config/qrdqn/procgen.py
index 53b80b7f..32636f1f 100644
--- a/config/qrdqn/procgen.py
+++ b/config/qrdqn/procgen.py
@@ -39,7 +39,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/r2d2/atari.py b/config/r2d2/atari.py
index 96c79f4e..849a5ec2 100644
--- a/config/r2d2/atari.py
+++ b/config/r2d2/atari.py
@@ -46,7 +46,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/r2d2/cartpole.py b/config/r2d2/cartpole.py
index 242f9f5c..c7bde4b6 100644
--- a/config/r2d2/cartpole.py
+++ b/config/r2d2/cartpole.py
@@ -40,7 +40,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 512,
"update_period" : 16,
diff --git a/config/r2d2/pong_mlagent.py b/config/r2d2/pong_mlagent.py
index 828527b0..fa242f20 100644
--- a/config/r2d2/pong_mlagent.py
+++ b/config/r2d2/pong_mlagent.py
@@ -38,7 +38,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 512,
"update_period" : 16,
diff --git a/config/r2d2/procgen.py b/config/r2d2/procgen.py
index 54657eae..d732b4b6 100644
--- a/config/r2d2/procgen.py
+++ b/config/r2d2/procgen.py
@@ -41,7 +41,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/rainbow/atari.py b/config/rainbow/atari.py
index 08b7f744..43a8865f 100644
--- a/config/rainbow/atari.py
+++ b/config/rainbow/atari.py
@@ -48,7 +48,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/rainbow/cartpole.py b/config/rainbow/cartpole.py
index b1449ae0..468cf1e8 100644
--- a/config/rainbow/cartpole.py
+++ b/config/rainbow/cartpole.py
@@ -42,7 +42,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 8,
diff --git a/config/rainbow/mario.py b/config/rainbow/mario.py
index 7a7e8e32..20f2493d 100644
--- a/config/rainbow/mario.py
+++ b/config/rainbow/mario.py
@@ -47,7 +47,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/rainbow/pong_mlagent.py b/config/rainbow/pong_mlagent.py
index f7f17126..a10dc433 100644
--- a/config/rainbow/pong_mlagent.py
+++ b/config/rainbow/pong_mlagent.py
@@ -40,7 +40,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/rainbow/procgen.py b/config/rainbow/procgen.py
index 88bbc766..ff11607f 100644
--- a/config/rainbow/procgen.py
+++ b/config/rainbow/procgen.py
@@ -45,7 +45,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/rainbow_iqn/atari.py b/config/rainbow_iqn/atari.py
index d156c516..a800686a 100644
--- a/config/rainbow_iqn/atari.py
+++ b/config/rainbow_iqn/atari.py
@@ -50,7 +50,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/rainbow_iqn/cartpole.py b/config/rainbow_iqn/cartpole.py
index 01442bd6..cb8813ac 100644
--- a/config/rainbow_iqn/cartpole.py
+++ b/config/rainbow_iqn/cartpole.py
@@ -44,7 +44,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 8,
diff --git a/config/rainbow_iqn/mario.py b/config/rainbow_iqn/mario.py
index 64a5c999..88e38546 100644
--- a/config/rainbow_iqn/mario.py
+++ b/config/rainbow_iqn/mario.py
@@ -49,7 +49,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/rainbow_iqn/pong_mlagent.py b/config/rainbow_iqn/pong_mlagent.py
index 2302890a..8223cc4d 100644
--- a/config/rainbow_iqn/pong_mlagent.py
+++ b/config/rainbow_iqn/pong_mlagent.py
@@ -42,7 +42,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period" : 8,
"num_workers" : 16,
diff --git a/config/rainbow_iqn/procgen.py b/config/rainbow_iqn/procgen.py
index 69c35c25..5fb4ccbb 100644
--- a/config/rainbow_iqn/procgen.py
+++ b/config/rainbow_iqn/procgen.py
@@ -47,7 +47,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/reinforce/cartpole.py b/config/reinforce/cartpole.py
index 145d3392..6f563e17 100644
--- a/config/reinforce/cartpole.py
+++ b/config/reinforce/cartpole.py
@@ -24,5 +24,5 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
}
\ No newline at end of file
diff --git a/config/rnd_ppo/atari.py b/config/rnd_ppo/atari.py
index cd572718..8ae07e35 100644
--- a/config/rnd_ppo/atari.py
+++ b/config/rnd_ppo/atari.py
@@ -47,7 +47,7 @@
"run_step" : 100000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 1,
+ "eval_iteration": 1,
"record" : True,
"record_period" : 1000000,
# distributed setting
diff --git a/config/rnd_ppo/cartpole.py b/config/rnd_ppo/cartpole.py
index 2007eba9..2e52a19e 100644
--- a/config/rnd_ppo/cartpole.py
+++ b/config/rnd_ppo/cartpole.py
@@ -39,7 +39,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 512,
"update_period" : agent["n_step"],
diff --git a/config/rnd_ppo/mario.py b/config/rnd_ppo/mario.py
index 32223737..0347262a 100644
--- a/config/rnd_ppo/mario.py
+++ b/config/rnd_ppo/mario.py
@@ -47,7 +47,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 1,
+ "eval_iteration": 1,
"record" : True,
"record_period" : 250000,
# distributed setting
diff --git a/config/rnd_ppo/pong_mlagent.py b/config/rnd_ppo/pong_mlagent.py
index 609d2933..aca807ab 100644
--- a/config/rnd_ppo/pong_mlagent.py
+++ b/config/rnd_ppo/pong_mlagent.py
@@ -38,7 +38,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 512,
"update_period" : agent["n_step"],
diff --git a/config/rnd_ppo/procgen.py b/config/rnd_ppo/procgen.py
index acb67e90..246b73a6 100644
--- a/config/rnd_ppo/procgen.py
+++ b/config/rnd_ppo/procgen.py
@@ -43,7 +43,7 @@
"run_step" : 100000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 1,
+ "eval_iteration": 1,
"record" : True,
"record_period" : 1000000,
# distributed setting
diff --git a/config/sac/cartpole.py b/config/sac/cartpole.py
index 1b3d2d8a..424b4ba9 100644
--- a/config/sac/cartpole.py
+++ b/config/sac/cartpole.py
@@ -34,7 +34,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period": 32,
"num_workers": 8,
diff --git a/config/sac/hopper_mlagent.py b/config/sac/hopper_mlagent.py
index 5a84b3a8..a1074451 100644
--- a/config/sac/hopper_mlagent.py
+++ b/config/sac/hopper_mlagent.py
@@ -33,7 +33,7 @@
"run_step" : 1000000,
"print_period" : 10000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period": 128,
"num_workers": 16,
diff --git a/config/sac/pendulum.py b/config/sac/pendulum.py
index edc34c7d..c615e6ca 100644
--- a/config/sac/pendulum.py
+++ b/config/sac/pendulum.py
@@ -33,7 +33,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"update_period": 32,
"num_workers": 8,
diff --git a/config/vmpo/atari.py b/config/vmpo/atari.py
index 47666667..6dc5a024 100644
--- a/config/vmpo/atari.py
+++ b/config/vmpo/atari.py
@@ -47,7 +47,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/config/vmpo/cartpole.py b/config/vmpo/cartpole.py
index 9f5d350f..04182c55 100644
--- a/config/vmpo/cartpole.py
+++ b/config/vmpo/cartpole.py
@@ -39,7 +39,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/vmpo/hopper_mlagent.py b/config/vmpo/hopper_mlagent.py
index adae57c7..0d63a15d 100644
--- a/config/vmpo/hopper_mlagent.py
+++ b/config/vmpo/hopper_mlagent.py
@@ -38,7 +38,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/vmpo/mountaincar.py b/config/vmpo/mountaincar.py
index de77447b..68c95b67 100644
--- a/config/vmpo/mountaincar.py
+++ b/config/vmpo/mountaincar.py
@@ -38,7 +38,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/vmpo/pendulum.py b/config/vmpo/pendulum.py
index 7e518d6f..489f68ce 100644
--- a/config/vmpo/pendulum.py
+++ b/config/vmpo/pendulum.py
@@ -38,7 +38,7 @@
"run_step" : 100000,
"print_period" : 1000,
"save_period" : 10000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/vmpo/pong_mlagent.py b/config/vmpo/pong_mlagent.py
index ca5dcd72..884a6d8b 100644
--- a/config/vmpo/pong_mlagent.py
+++ b/config/vmpo/pong_mlagent.py
@@ -38,7 +38,7 @@
"run_step" : 200000,
"print_period" : 5000,
"save_period" : 50000,
- "test_iteration": 10,
+ "eval_iteration": 10,
# distributed setting
"distributed_batch_size" : 256,
"update_period" : agent["n_step"],
diff --git a/config/vmpo/procgen.py b/config/vmpo/procgen.py
index 7d587f5d..e5859058 100644
--- a/config/vmpo/procgen.py
+++ b/config/vmpo/procgen.py
@@ -44,7 +44,7 @@
"run_step" : 30000000,
"print_period" : 10000,
"save_period" : 100000,
- "test_iteration": 5,
+ "eval_iteration": 5,
"record" : True,
"record_period" : 300000,
# distributed setting
diff --git a/core/agent/README.md b/core/agent/README.md
new file mode 100644
index 00000000..9c56d6c5
--- /dev/null
+++ b/core/agent/README.md
@@ -0,0 +1,28 @@
+# How to customize agent
+
+## 1. Inherit BaseAgent class.
+- If you want to add a new agent without inheriting the provided agents, you must inherit the base agent.
+
+reference: [dqn.py](./dqn.py), [reinforce.py](./reinforce.py), [sac.py](./sac.py), ...
+
+## 2. Implement abstract methods.
+- Abstract methods (__act__, __learn__, __process__, __save__, __load__) should be implemented. Implement these methods by referring to the comments.
+- When implementing __process__, it is easy to manage events using __time_t__, __delta_t__, __event_period__, and __event_stamp__ (see the sketch below).
+ - __time_t__ means the timestamp up to which the agent has interacted.
+ - __delta_t__ means the difference between the new timestamp received when __process__ is executed and the previous __time_t__.
+ - __event_period__ means the period at which the event should be executed.
+ - __event_stamp__ is increased by __delta_t__ each time __process__ is executed, and when it reaches __event_period__ or more, the corresponding event is fired.
+
+
+reference: [dqn.py](./dqn.py), ...
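+
+A minimal sketch of this event pattern inside __process__ (the signature and the __update_target__ call are assumptions for illustration; storing transitions and learning are omitted, see dqn.py for the actual implementation):
+
+```python
+class MyAgent(BaseAgent):
+    def process(self, transitions, step):
+        result = {}
+        # delta_t: steps elapsed since the previous process call
+        delta_t = step - self.time_t
+        self.time_t = step
+        # accumulate elapsed steps; fire the event once event_period is reached
+        self.event_stamp += delta_t
+        if self.event_stamp >= self.event_period:
+            self.update_target()  # hypothetical event, e.g. target network sync
+            self.event_stamp = 0
+        return result
+```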
+
+## 3. If necessary, implement another method.
+- The __sync_in__ and __sync_out__ methods are implemented based on __self.network__. If the agent class does not use self.network (e.g. it uses self.actor), these methods should be overridden.
+
+reference: [ddpg.py](./ddpg.py), [sac.py](./sac.py), ...
+
+- Override __set_distributed__ if you need additional work during worker initialization.
+- Override __interact_callback__ if you need additional work after an interaction (agent.act and env.step).
+
+reference: [ape_x.py](./ape_x.py), ...
+
diff --git a/core/agent/__init__.py b/core/agent/__init__.py
index b224c723..0ac2ba65 100644
--- a/core/agent/__init__.py
+++ b/core/agent/__init__.py
@@ -7,21 +7,21 @@
file_list = os.listdir(working_path)
module_list = [file.replace(".py", "") for file in file_list
if file.endswith(".py") and file.replace(".py","") not in ["__init__", "base", "utils"]]
-class_dict = {}
+agent_dict = {}
+naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
for module_name in module_list:
module_path = f"{__name__}.{module_name}"
module = __import__(module_path, fromlist=[None])
for class_name, _class in inspect.getmembers(module, inspect.isclass):
if module_path in str(_class):
- naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
- class_dict[naming_rule(class_name)] = _class
+ agent_dict[naming_rule(class_name)] = _class
-class_dict = OrderedDict(sorted(class_dict.items()))
-with open(os.path.join(working_path, "_class_dict.txt"), 'w') as f:
- f.write('### Class Dictionary ###\n')
+agent_dict = OrderedDict(sorted(agent_dict.items()))
+with open(os.path.join(working_path, "_agent_dict.txt"), 'w') as f:
+ f.write('### Agent Dictionary ###\n')
f.write('format: (key, class)\n')
f.write('------------------------\n')
- for item in class_dict.items():
+ for item in agent_dict.items():
f.write(str(item) + '\n')
class Agent:
@@ -31,7 +31,7 @@ def __new__(self, name, *args, **kwargs):
print("### name variable must be string! ###")
raise Exception
name = name.lower()
- if not name in class_dict.keys():
- print(f"### can use only follows {[opt for opt in class_dict.keys()]}")
+ if not name in agent_dict.keys():
+ print(f"### can use only follows {[opt for opt in agent_dict.keys()]}")
raise Exception
- return class_dict[name](*args, **kwargs)
+ return agent_dict[name](*args, **kwargs)
diff --git a/core/agent/_class_dict.txt b/core/agent/_agent_dict.txt
similarity index 93%
rename from core/agent/_class_dict.txt
rename to core/agent/_agent_dict.txt
index 1590d350..c1942bb4 100644
--- a/core/agent/_class_dict.txt
+++ b/core/agent/_agent_dict.txt
@@ -1,4 +1,4 @@
-### Class Dictionary ###
+### Agent Dictionary ###
format: (key, class)
------------------------
('ape_x', )
@@ -12,7 +12,6 @@ format: (key, class)
('mpo', )
('multistep', )
('noisy', )
-('ou_noise', )
('per', )
('ppo', )
('qrdqn', )
diff --git a/core/agent/ape_x.py b/core/agent/ape_x.py
index 4cdaab7c..d256d6e2 100644
--- a/core/agent/ape_x.py
+++ b/core/agent/ape_x.py
@@ -1,12 +1,8 @@
from collections import deque
import torch
torch.backends.cudnn.benchmark = True
-import torch.nn.functional as F
import numpy as np
-import copy
-from core.network import Network
-from core.optimizer import Optimizer
from core.buffer import PERBuffer
from .dqn import DQN
@@ -76,7 +72,6 @@ def learn(self):
max_Q = torch.max(q).item()
next_q = self.network(next_state)
max_a = torch.argmax(next_q, axis=1)
- max_eye = torch.eye(self.action_size).to(self.device)
max_one_hot_action = eye[max_a.view(-1).long()]
next_target_q = self.target_network(next_state)
diff --git a/core/agent/ddpg.py b/core/agent/ddpg.py
index cf9675b8..037f0882 100644
--- a/core/agent/ddpg.py
+++ b/core/agent/ddpg.py
@@ -1,33 +1,13 @@
import torch
torch.backends.cudnn.benchmark = True
import torch.nn.functional as F
-from torch.distributions import Normal
import os
-import numpy as np
from core.network import Network
from core.optimizer import Optimizer
from core.buffer import ReplayBuffer
from .base import BaseAgent
-
-# OU noise class
-class OU_noise:
- def __init__(self, action_size, mu, theta, sigma):
- self.action_size = action_size
-
- self.mu = mu
- self.theta = theta
- self.sigma = sigma
-
- self.reset()
-
- def reset(self):
- self.X = np.ones((1, self.action_size), dtype=np.float32) * self.mu
-
- def sample(self):
- dx = self.theta * (self.mu - self.X) + self.sigma * np.random.randn(len(self.X))
- self.X = self.X + dx
- return self.X
+from .utils import OU_Noise
class DDPG(BaseAgent):
def __init__(self,
@@ -62,7 +42,7 @@ def __init__(self,
self.actor_optimizer = Optimizer(optim_config.actor, self.actor.parameters(), lr=optim_config.actor_lr)
self.critic_optimizer = Optimizer(optim_config.critic, self.critic.parameters(), lr=optim_config.critic_lr)
- self.OU = OU_noise(action_size, mu, theta, sigma)
+ self.OU = OU_Noise(action_size, mu, theta, sigma)
self.gamma = gamma
self.tau = tau
diff --git a/core/agent/dqn.py b/core/agent/dqn.py
index d2452b8b..c173e403 100644
--- a/core/agent/dqn.py
+++ b/core/agent/dqn.py
@@ -3,7 +3,6 @@
import torch.nn.functional as F
import numpy as np
import os
-from collections import OrderedDict
from core.network import Network
from core.optimizer import Optimizer
diff --git a/core/agent/mpo.py b/core/agent/mpo.py
index 5e7b0170..9cb39083 100644
--- a/core/agent/mpo.py
+++ b/core/agent/mpo.py
@@ -1,9 +1,9 @@
from collections import deque
-import os, copy
+import os
import numpy as np
import torch
import torch.nn.functional as F
-from torch.distributions import Normal, Categorical
+from torch.distributions import Normal
from .base import BaseAgent
from core.network import Network
diff --git a/core/agent/noisy.py b/core/agent/noisy.py
index 09083921..62dae65d 100644
--- a/core/agent/noisy.py
+++ b/core/agent/noisy.py
@@ -2,11 +2,9 @@
torch.backends.cudnn.benchmark = True
import torch.nn.functional as F
import numpy as np
-import copy
from core.network import Network
from core.optimizer import Optimizer
-from core.buffer import ReplayBuffer
from .dqn import DQN
class Noisy(DQN):
diff --git a/core/agent/per.py b/core/agent/per.py
index 8e9f9329..f0121186 100644
--- a/core/agent/per.py
+++ b/core/agent/per.py
@@ -1,6 +1,5 @@
import torch
torch.backends.cudnn.benchmark = True
-import torch.nn.functional as F
from .dqn import DQN
from core.buffer import PERBuffer
diff --git a/core/agent/r2d2.py b/core/agent/r2d2.py
index 79e2976a..57afb38c 100644
--- a/core/agent/r2d2.py
+++ b/core/agent/r2d2.py
@@ -2,13 +2,8 @@
from itertools import islice
import torch
torch.backends.cudnn.benchmark = True
-import torch.nn.functional as F
import numpy as np
-import copy
-from core.network import Network
-from core.optimizer import Optimizer
-from core.buffer import PERBuffer
from .ape_x import ApeX
class R2D2(ApeX):
diff --git a/core/agent/rainbow_iqn.py b/core/agent/rainbow_iqn.py
index 3f2b1a83..b7fd801e 100644
--- a/core/agent/rainbow_iqn.py
+++ b/core/agent/rainbow_iqn.py
@@ -3,8 +3,6 @@
torch.backends.cudnn.benchmark = True
import torch.nn.functional as F
import numpy as np
-import copy
-import time
from core.network import Network
from core.optimizer import Optimizer
diff --git a/core/agent/reinforce.py b/core/agent/reinforce.py
index a10ff249..1b2eb139 100644
--- a/core/agent/reinforce.py
+++ b/core/agent/reinforce.py
@@ -1,6 +1,6 @@
import torch
torch.backends.cudnn.benchmark = True
-from torch.distributions import Normal, Categorical
+from torch.distributions import Normal
import numpy as np
import os
diff --git a/core/agent/rnd_ppo.py b/core/agent/rnd_ppo.py
index a930f849..416764b7 100644
--- a/core/agent/rnd_ppo.py
+++ b/core/agent/rnd_ppo.py
@@ -8,8 +8,6 @@
from .ppo import PPO
from core.network import Network
-import torch.optim as optim
-
class RND_PPO(PPO):
def __init__(self,
state_size,
diff --git a/core/agent/utils.py b/core/agent/utils.py
index c733f1bf..3bbd371e 100644
--- a/core/agent/utils.py
+++ b/core/agent/utils.py
@@ -1,6 +1,24 @@
-import torch
import numpy as np
+# OU noise class
+class OU_Noise:
+ def __init__(self, action_size, mu, theta, sigma):
+ self.action_size = action_size
+
+ self.mu = mu
+ self.theta = theta
+ self.sigma = sigma
+
+ self.reset()
+
+ def reset(self):
+ self.X = np.ones((1, self.action_size), dtype=np.float32) * self.mu
+
+ def sample(self):
+ dx = self.theta * (self.mu - self.X) + self.sigma * np.random.randn(len(self.X))
+ self.X = self.X + dx
+ return self.X
+
# d: dictionary of pytorch tensors which you want to check if they are INF of NaN
# vs: dictionary of pytorch tensors, min and max of which you want to check, when any value of d is INF or NaN
def check_explode(d, vs={}):
diff --git a/core/agent/vmpo.py b/core/agent/vmpo.py
index badbd74d..6b9b1d33 100644
--- a/core/agent/vmpo.py
+++ b/core/agent/vmpo.py
@@ -1,10 +1,9 @@
import torch
import torch.nn.functional as F
-from torch.distributions import Normal, Categorical
+from torch.distributions import Normal
import numpy as np
from .reinforce import REINFORCE
-from core.network import Network
from core.optimizer import Optimizer
class VMPO(REINFORCE):
diff --git a/core/buffer/README.md b/core/buffer/README.md
new file mode 100644
index 00000000..438de328
--- /dev/null
+++ b/core/buffer/README.md
@@ -0,0 +1,12 @@
+# How to customize buffer
+
+## 1. Inherit BaseBuffer class.
+- If you want to add a new buffer without inheriting the provided buffers, you must inherit the base buffer.
+
+reference: [replay_buffer.py](./replay_buffer.py), [rollout_buffer.py](./rollout_buffer.py), ...
+
+## 2. Implement abstract methods.
+- Abstract methods(__store__, __sample__) should be implemented. Implement these methods by referring to the comments.
+- When implementing __store__, it is recommended to check the transition data dimensions using __check_dim__. To use __check_dim__, call __super().\_\_init\_\_()__ in your buffer's __\_\_init\_\___ (see the sketch below).
+
+reference: [replay_buffer.py](./replay_buffer.py), [rollout_buffer.py](./rollout_buffer.py), ...
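+
+For illustration, a hypothetical buffer sketch using __check_dim__ (SimpleBuffer and its fields are assumptions; the provided buffers additionally batch sampled transitions, see replay_buffer.py):
+
+```python
+import numpy as np
+
+from .base import BaseBuffer
+
+class SimpleBuffer(BaseBuffer):
+    def __init__(self, buffer_size):
+        super(SimpleBuffer, self).__init__()  # sets self.first_store = True
+        self.buffer = []
+        self.buffer_size = buffer_size
+
+    def store(self, transitions):
+        if self.first_store:
+            self.check_dim(transitions[0])    # print transition shapes once
+        self.buffer += transitions
+        self.buffer = self.buffer[-self.buffer_size:]
+
+    def sample(self, batch_size):
+        indices = np.random.randint(len(self.buffer), size=batch_size)
+        return [self.buffer[i] for i in indices]
+```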
diff --git a/core/buffer/base.py b/core/buffer/base.py
index dc3eb418..7d8ea304 100644
--- a/core/buffer/base.py
+++ b/core/buffer/base.py
@@ -1,6 +1,17 @@
from abc import *
class BaseBuffer(ABC):
+ def __init__(self):
+ self.first_store = True
+
+ def check_dim(self, transition):
+ print("########################################")
+ print("You should check dimension of transition")
+ for key, val in transition.items():
+ print(f"{key}: {val.shape}")
+ print("########################################")
+ self.first_store = False
+
@abstractmethod
def store(self, transitions):
"""
@@ -22,10 +33,3 @@ def sample(self, batch_size):
transitions = [{}]
return transitions
- def check_dim(self, transition):
- print("########################################")
- print("You should check dimension of transition")
- for key, val in transition.items():
- print(f"{key}: {val.shape}")
- print("########################################")
- self.first_store = False
\ No newline at end of file
diff --git a/core/buffer/replay_buffer.py b/core/buffer/replay_buffer.py
index c4a63d81..19957ac6 100644
--- a/core/buffer/replay_buffer.py
+++ b/core/buffer/replay_buffer.py
@@ -6,12 +6,11 @@
class ReplayBuffer(BaseBuffer):
def __init__(self, buffer_size):
+ super(ReplayBuffer, self).__init__()
self.buffer = np.zeros(buffer_size, dtype=dict) # define replay buffer
self.buffer_index = 0
self.buffer_size = buffer_size
self.buffer_counter = 0
-
- self.first_store = True
def store(self, transitions):
if self.first_store:
diff --git a/core/buffer/rollout_buffer.py b/core/buffer/rollout_buffer.py
index e58cec9d..925c27b2 100644
--- a/core/buffer/rollout_buffer.py
+++ b/core/buffer/rollout_buffer.py
@@ -4,8 +4,8 @@
class RolloutBuffer(BaseBuffer):
def __init__(self):
+ super(RolloutBuffer, self).__init__()
self.buffer = list()
- self.first_store = True
def store(self, transitions):
if self.first_store:
diff --git a/core/env/README.md b/core/env/README.md
new file mode 100644
index 00000000..39fdc26a
--- /dev/null
+++ b/core/env/README.md
@@ -0,0 +1,23 @@
+# How to customize environment
+
+## 1. Inherit BaseEnv class.
+- If you want to implement a new environment without inheriting the provided environments, you must inherit the base environment.
+
+reference: [base.py](./base.py), [gym_env.py](./gym_env.py), [atari.py](./atari.py), ...
+
+## 2. Implement abstract methods.
+- Abstract methods(__reset__, __step__, __close__) should be implemented. Implement these methods by referring to the comments.
+- When you implement the __step__ method, you should expand the dimensions of state, reward, and done from (d,) to (1, d) using __expand_dim__ (see the sketch below).
+
+reference: [gym_env.py](./gym_env.py)
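+
+A rough sketch of these methods for a hypothetical gym-style env (the class name is an assumption, and the __init__ that creates self.env is omitted; see gym_env.py for the real implementation):
+
+```python
+import numpy as np
+
+from .base import BaseEnv
+
+class MyEnv(BaseEnv):
+    def reset(self):
+        state = self.env.reset()
+        return np.expand_dims(state, 0)                  # (d,) -> (1, d)
+
+    def step(self, action):
+        next_state, reward, done, _ = self.env.step(action)
+        # expand state, reward, and done from (d,) to (1, d)
+        next_state = np.expand_dims(next_state, 0)
+        reward = np.expand_dims([reward], 0)
+        done = np.expand_dims([done], 0)
+        return next_state, reward, done
+
+    def close(self):
+        self.env.close()
+```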
+
+## 3. If necessary, implement another method.
+- __recordable__ indicates whether the environment can be recorded as a gif. If a newly implemented environment has no visual state, set __recordable__ to return False.
+- If you want to record evaluation episodes as gif files, make sure __recordable__ returns True and set train.record in the config file to True (see the sketch below).
+
+reference: [atari.py](./atari.py), [procgen.py](./procgen.py), [nes.py](./nes.py)
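+
+A hedged sketch of making such an env recordable (the render call is an assumption for a gym-style env; atari.py returns the frame from ale.getScreenRGB2 instead):
+
+```python
+class MyVisualEnv(MyEnv):
+    def recordable(self):
+        return True
+
+    def get_frame(self):
+        # RGB frame used when saving the evaluation episode as a gif
+        return self.env.render(mode="rgb_array")
+```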
+
+## When adding an open source environment.
+- If you want to add an open source environment, refer to the provided environments: atari, procgen, and nes.
+
+reference: [atari.py](./atari.py), [procgen.py](./procgen.py), [nes.py](./nes.py)
diff --git a/core/env/__init__.py b/core/env/__init__.py
index c71f6469..c8bd3f5d 100644
--- a/core/env/__init__.py
+++ b/core/env/__init__.py
@@ -9,21 +9,21 @@
file_list = os.listdir(working_path)
module_list = [file.replace(".py", "") for file in file_list
if file.endswith(".py") and file.replace(".py","") not in ["__init__", "base", "utils"]]
-class_dict = {}
+env_dict = {}
+naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
for module_name in module_list:
module_path = f"{__name__}.{module_name}"
module = __import__(module_path, fromlist=[None])
for class_name, _class in inspect.getmembers(module, inspect.isclass):
- if module_path in str(_class):
- naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
- class_dict[naming_rule(class_name)] = _class
+ if module_path in str(_class) and '_' != class_name[0]:
+ env_dict[naming_rule(class_name)] = _class
-class_dict = OrderedDict(sorted(class_dict.items()))
-with open(os.path.join(working_path, "_class_dict.txt"), 'w') as f:
- f.write('### Class Dictionary ###\n')
+env_dict = OrderedDict(sorted(env_dict.items()))
+with open(os.path.join(working_path, "_env_dict.txt"), 'w') as f:
+ f.write('### Env Dictionary ###\n')
f.write('format: (key, class)\n')
f.write('------------------------\n')
- for item in class_dict.items():
+ for item in env_dict.items():
f.write(str(item) + '\n')
class Env:
@@ -33,7 +33,7 @@ def __new__(self, name, *args, **kwargs):
print("### name variable must be string! ###")
raise Exception
name = name.lower()
- if not name in class_dict.keys():
- print(f"### can use only follows {[opt for opt in class_dict.keys()]}")
+ if not name in env_dict.keys():
+ print(f"### can use only follows {[opt for opt in env_dict.keys()]}")
raise Exception
- return class_dict[name](*args, **kwargs)
\ No newline at end of file
+ return env_dict[name](*args, **kwargs)
\ No newline at end of file
diff --git a/core/env/_class_dict.txt b/core/env/_env_dict.txt
similarity index 88%
rename from core/env/_class_dict.txt
rename to core/env/_env_dict.txt
index f335b25f..2d181c96 100644
--- a/core/env/_class_dict.txt
+++ b/core/env/_env_dict.txt
@@ -1,10 +1,9 @@
-### Class Dictionary ###
+### Env Dictionary ###
format: (key, class)
------------------------
('alien', )
('assault', )
('asterix', )
-('atari', )
('bigfish', )
('bossfight', )
('breakout', )
@@ -17,7 +16,6 @@ format: (key, class)
('dodgeball', )
('enduro', )
('fruitbot', )
-('gym', )
('heist', )
('hopper_mlagent', )
('jumper', )
@@ -25,17 +23,14 @@ format: (key, class)
('mario', )
('maze', )
('miner', )
-('mlagent', )
('montezuma_revenge', )
('mountain_car', )
-('nes', )
('ninja', )
('pendulum', )
('plunder', )
('pong', )
('pong_mlagent', )
('private_eye', )
-('procgen', )
('qbert', )
('seaquest', )
('spaceinvaders', )
diff --git a/core/env/atari.py b/core/env/atari.py
index 73f46a7b..e0c5a29d 100644
--- a/core/env/atari.py
+++ b/core/env/atari.py
@@ -6,7 +6,7 @@
COMMON_VERSION = 'Deterministic-v4'
-class Atari(BaseEnv):
+class _Atari(BaseEnv):
def __init__(self,
name,
render=False,
@@ -106,50 +106,50 @@ def recordable(self):
def get_frame(self):
return self.env.ale.getScreenRGB2()
-class Breakout(Atari):
+class Breakout(_Atari):
def __init__(self, **kwargs):
super(Breakout, self).__init__(f"Breakout{COMMON_VERSION}", **kwargs)
-class Pong(Atari):
+class Pong(_Atari):
def __init__(self, **kwargs):
super(Pong, self).__init__(f"Pong{COMMON_VERSION}", **kwargs)
-class Asterix(Atari):
+class Asterix(_Atari):
def __init__(self, **kwargs):
super(Asterix, self).__init__(f"Asterix{COMMON_VERSION}", **kwargs)
-class Assault(Atari):
+class Assault(_Atari):
def __init__(self, **kwargs):
super(Assault, self).__init__(f"Assault{COMMON_VERSION}", **kwargs)
-class Seaquest(Atari):
+class Seaquest(_Atari):
def __init__(self, **kwargs):
super(Seaquest, self).__init__(f"Seaquest{COMMON_VERSION}", **kwargs)
-class Spaceinvaders(Atari):
+class Spaceinvaders(_Atari):
def __init__(self, **kwargs):
super(Spaceinvaders, self).__init__(f"SpaceInvaders{COMMON_VERSION}", **kwargs)
-class Alien(Atari):
+class Alien(_Atari):
def __init__(self, **kwargs):
super(Alien, self).__init__(f"Alien{COMMON_VERSION}", **kwargs)
-class CrazyClimber(Atari):
+class CrazyClimber(_Atari):
def __init__(self, **kwargs):
super(CrazyClimber, self).__init__(f"CrazyClimber{COMMON_VERSION}", **kwargs)
-class Enduro(Atari):
+class Enduro(_Atari):
def __init__(self, **kwargs):
super(Enduro, self).__init__(f"Enduro{COMMON_VERSION}", **kwargs)
-class Qbert(Atari):
+class Qbert(_Atari):
def __init__(self, **kwargs):
super(Qbert, self).__init__(f"Qbert{COMMON_VERSION}", **kwargs)
-class PrivateEye(Atari):
+class PrivateEye(_Atari):
def __init__(self, **kwargs):
super(PrivateEye, self).__init__(f"PrivateEye{COMMON_VERSION}", **kwargs)
-class MontezumaRevenge(Atari):
+class MontezumaRevenge(_Atari):
def __init__(self, **kwargs):
super(MontezumaRevenge, self).__init__(f"MontezumaRevenge{COMMON_VERSION}", **kwargs)
diff --git a/core/env/gym_env.py b/core/env/gym_env.py
index 03b4f7fa..e10b040b 100644
--- a/core/env/gym_env.py
+++ b/core/env/gym_env.py
@@ -2,7 +2,7 @@
import numpy as np
from .base import BaseEnv
-class Gym(BaseEnv):
+class _Gym(BaseEnv):
def __init__(self,
name,
mode,
@@ -42,7 +42,7 @@ def step(self, action):
def close(self):
self.env.close()
-class Cartpole(Gym):
+class Cartpole(_Gym):
def __init__(self,
mode='discrete',
**kwargs):
@@ -65,13 +65,13 @@ def step(self, action):
next_state, reward, done = map(lambda x: np.expand_dims(x, 0), [next_state, [reward], [done]]) # for (1, ?)
return (next_state, reward, done)
-class Pendulum(Gym):
+class Pendulum(_Gym):
def __init__(self,
**kwargs
):
super(Pendulum, self).__init__('Pendulum-v0', 'continuous', **kwargs)
-class MountainCar(Gym):
+class MountainCar(_Gym):
def __init__(self,
**kwargs
):
diff --git a/core/env/ml_agent.py b/core/env/ml_agent.py
index 7231e388..6db9316f 100644
--- a/core/env/ml_agent.py
+++ b/core/env/ml_agent.py
@@ -13,7 +13,7 @@ def match_build():
return {"Windows": "Windows",
"Darwin" : "Mac"}[os]
-class MLAgent(BaseEnv):
+class _MLAgent(BaseEnv):
def __init__(self, env_name, train_mode=True, id=None):
env_path = f"./core/env/mlagents/{env_name}/{match_build()}/{env_name}"
id = np.random.randint(65534) if id is None else id
@@ -69,7 +69,7 @@ def step(self, action):
def close(self):
self.env.close()
-class HopperMLAgent(MLAgent):
+class HopperMLAgent(_MLAgent):
def __init__(self, **kwargs):
env_name = "Hopper"
super(HopperMLAgent, self).__init__(env_name, **kwargs)
@@ -77,7 +77,7 @@ def __init__(self, **kwargs):
self.state_size = 19*4
self.action_size = 3
-class PongMLAgent(MLAgent):
+class PongMLAgent(_MLAgent):
def __init__(self, **kwargs):
env_name = "Pong"
super(PongMLAgent, self).__init__(env_name, **kwargs)
@@ -85,7 +85,7 @@ def __init__(self, **kwargs):
self.state_size = 8*1
self.action_size = 3
-class WormMLAgent(MLAgent):
+class WormMLAgent(_MLAgent):
def __init__(self, **kwargs):
env_name = "Worm"
super(WormMLAgent, self).__init__(env_name, **kwargs)
diff --git a/core/env/nes.py b/core/env/nes.py
index 3cdb24e2..d1149bbe 100644
--- a/core/env/nes.py
+++ b/core/env/nes.py
@@ -5,11 +5,11 @@
from gym_super_mario_bros.actions import RIGHT_ONLY
import numpy as np
-from .atari import Atari
+from .atari import _Atari
-class Nes(Atari):
+class _Nes(_Atari):
def __init__(self, name, **kwargs):
- super(Nes, self).__init__(name=name, life_key='life', **kwargs)
+ super(_Nes, self).__init__(name=name, life_key='life', **kwargs)
self.env = JoypadSpace(self.env, RIGHT_ONLY)
print(f"action size changed: {self.action_size} -> {self.env.action_space.n}")
self.action_size = self.env.action_space.n
@@ -17,7 +17,7 @@ def __init__(self, name, **kwargs):
def get_frame(self):
return np.copy(self.env.screen)
-class Mario(Nes):
+class Mario(_Nes):
def __init__(self, **kwargs):
reward_scale = 15.
super(Mario, self).__init__('SuperMarioBros-v0', reward_scale=reward_scale, **kwargs)
diff --git a/core/env/procgen.py b/core/env/procgen.py
index 39f0456f..acc9546b 100644
--- a/core/env/procgen.py
+++ b/core/env/procgen.py
@@ -5,7 +5,7 @@
from .utils import ImgProcessor
from .base import BaseEnv
-class Procgen(BaseEnv):
+class _Procgen(BaseEnv):
def __init__(self,
name,
render=False,
@@ -81,66 +81,66 @@ def get_frame(self):
raw_image = self.env.render(mode="rgb_array")
return cv2.resize(raw_image, dsize=(256, 256))
-class Coinrun(Procgen):
+class Coinrun(_Procgen):
def __init__(self, **kwargs):
super(Coinrun, self).__init__('coinrun', **kwargs)
-class Bigfish(Procgen):
+class Bigfish(_Procgen):
def __init__(self, **kwargs):
super(Bigfish, self).__init__('bigfish', **kwargs)
-class Bossfight(Procgen):
+class Bossfight(_Procgen):
def __init__(self, **kwargs):
super(Bossfight, self).__init__('bossfight', **kwargs)
-class Caveflyer(Procgen):
+class Caveflyer(_Procgen):
def __init__(self, **kwargs):
super(Caveflyer, self).__init__('caveflyer', **kwargs)
-class Chaser(Procgen):
+class Chaser(_Procgen):
def __init__(self, **kwargs):
super(Chaser, self).__init__('chaser', **kwargs)
-class Climber(Procgen):
+class Climber(_Procgen):
def __init__(self, **kwargs):
super(Climber, self).__init__('climber', **kwargs)
-class Dodgeball(Procgen):
+class Dodgeball(_Procgen):
def __init__(self, **kwargs):
super(Dodgeball, self).__init__('dodgeball', **kwargs)
-class Fruitbot(Procgen):
+class Fruitbot(_Procgen):
def __init__(self, **kwargs):
super(Fruitbot, self).__init__('fruitbot', **kwargs)
-class Heist(Procgen):
+class Heist(_Procgen):
def __init__(self, **kwargs):
super(Heist, self).__init__('heist', **kwargs)
-class Jumper(Procgen):
+class Jumper(_Procgen):
def __init__(self, **kwargs):
super(Jumper, self).__init__('jumper', **kwargs)
-class Leaper(Procgen):
+class Leaper(_Procgen):
def __init__(self, **kwargs):
super(Leaper, self).__init__('leaper', **kwargs)
-class Maze(Procgen):
+class Maze(_Procgen):
def __init__(self, **kwargs):
super(Maze, self).__init__('maze', **kwargs)
-class Miner(Procgen):
+class Miner(_Procgen):
def __init__(self, **kwargs):
super(Miner, self).__init__('miner', **kwargs)
-class Ninja(Procgen):
+class Ninja(_Procgen):
def __init__(self, **kwargs):
super(Ninja, self).__init__('ninja', **kwargs)
-class Plunder(Procgen):
+class Plunder(_Procgen):
def __init__(self, **kwargs):
super(Plunder, self).__init__('plunder', **kwargs)
-class Starpilot(Procgen):
+class Starpilot(_Procgen):
def __init__(self, **kwargs):
super(Starpilot, self).__init__('starpilot', **kwargs)
\ No newline at end of file
diff --git a/core/network/README.md b/core/network/README.md
new file mode 100644
index 00000000..a61ea1f3
--- /dev/null
+++ b/core/network/README.md
@@ -0,0 +1,18 @@
+# How to customize network
+
+## 1. Inherit BaseNetwork class.
+- If you want to add a new network that uses a head from [head.py](./head.py), you must inherit the base network.
+
+reference: [dqn.py](./dqn.py), [policy_value.py](./policy_value.py), ...
+
+- If not, inherit torch.nn.Module.
+
+reference: [icm.py](./icm.py), [rnd.py](./rnd.py)
+
+## 2. If you inherit BaseNetwork, override all methods.
+- The __\_\_init\_\___ and __forward__ methods should be overridden.
+- When overriding __\_\_init\_\___, take the head class into account. D_head_out is the dimension of the feature embedded through the head.
+- When overriding __forward__, pass the input through the head network using super().forward(x) (see the sketch below).
+
+reference: [dqn.py](./dqn.py), [policy_value.py](./policy_value.py), ...
+
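+Below is a minimal sketch of a custom network that follows the steps above. It assumes, as described, that BaseNetwork's constructor takes (D_in, D_hidden, head), builds the head, and exposes self.D_head_out; the class name CustomQNetwork and the output layout are placeholders, not part of the provided code.
+
+```python
+import torch
+import torch.nn.functional as F
+
+from .base import BaseNetwork
+
+class CustomQNetwork(BaseNetwork):
+    def __init__(self, D_in, D_out, D_hidden=512, head="mlp"):
+        # BaseNetwork builds the head; D_head_out is the embedded feature size
+        super(CustomQNetwork, self).__init__(D_in, D_hidden, head)
+        self.l1 = torch.nn.Linear(self.D_head_out, D_hidden)
+        self.q = torch.nn.Linear(D_hidden, D_out)
+
+    def forward(self, x):
+        x = super(CustomQNetwork, self).forward(x)  # pass through the head network
+        x = F.relu(self.l1(x))
+        return self.q(x)
+```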
diff --git a/core/network/__init__.py b/core/network/__init__.py
index 32fd1878..29fc4625 100644
--- a/core/network/__init__.py
+++ b/core/network/__init__.py
@@ -4,22 +4,22 @@
working_path = os.path.dirname(os.path.realpath(__file__))
file_list = os.listdir(working_path)
module_list = [file.replace(".py", "") for file in file_list
- if file.endswith(".py") and file.replace(".py","") not in ["__init__", "base", "utils"]]
-class_dict = {}
+ if file.endswith(".py") and file.replace(".py","") not in ["__init__", "base", "head", "utils"]]
+network_dict = {}
+naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
for module_name in module_list:
module_path = f"{__name__}.{module_name}"
module = __import__(module_path, fromlist=[None])
for class_name, _class in inspect.getmembers(module, inspect.isclass):
if module_path in str(_class):
- naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
- class_dict[naming_rule(class_name)] = _class
+ network_dict[naming_rule(class_name)] = _class
-class_dict = OrderedDict(sorted(class_dict.items()))
-with open(os.path.join(working_path, "_class_dict.txt"), 'w') as f:
- f.write('### Class Dictionary ###\n')
+network_dict = OrderedDict(sorted(network_dict.items()))
+with open(os.path.join(working_path, "_network_dict.txt"), 'w') as f:
+ f.write('### Network Dictionary ###\n')
f.write('format: (key, class)\n')
f.write('------------------------\n')
- for item in class_dict.items():
+ for item in network_dict.items():
f.write(str(item) + '\n')
class Network:
@@ -29,7 +29,7 @@ def __new__(self, name, *args, **kwargs):
print("### name variable must be string! ###")
raise Exception
name = name.lower()
- if not name in class_dict.keys():
- print(f"### can use only follows {[opt for opt in class_dict.keys()]}")
+ if not name in network_dict.keys():
+ print(f"### can use only follows {[opt for opt in network_dict.keys()]}")
raise Exception
- return class_dict[name](*args, **kwargs)
+ return network_dict[name](*args, **kwargs)
diff --git a/core/network/_head_dict.txt b/core/network/_head_dict.txt
new file mode 100644
index 00000000..56a0eb15
--- /dev/null
+++ b/core/network/_head_dict.txt
@@ -0,0 +1,7 @@
+### Head Dictionary ###
+format: (key, class)
+------------------------
+('cnn', )
+('cnn_lstm', )
+('mlp', )
+('mlp_lstm', )
diff --git a/core/network/_class_dict.txt b/core/network/_network_dict.txt
similarity index 88%
rename from core/network/_class_dict.txt
rename to core/network/_network_dict.txt
index 56e471f3..7d333987 100644
--- a/core/network/_class_dict.txt
+++ b/core/network/_network_dict.txt
@@ -1,4 +1,4 @@
-### Class Dictionary ###
+### Network Dictionary ###
format: (key, class)
------------------------
('continuous_policy', )
@@ -18,8 +18,6 @@ format: (key, class)
('r2d2', )
('rainbow', )
('rainbow_iqn', )
-('reward_forward_filter', )
('rnd_cnn', )
('rnd_mlp', )
-('running_mean_std', )
('sac_critic', )
diff --git a/core/network/base.py b/core/network/base.py
index 3fd6181a..88060371 100644
--- a/core/network/base.py
+++ b/core/network/base.py
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
-from .utils import head_dict
+from .head import head_dict
class BaseNetwork(torch.nn.Module):
def __init__(self, D_in, D_hidden, head):
diff --git a/core/network/head.py b/core/network/head.py
new file mode 100644
index 00000000..ce2f758c
--- /dev/null
+++ b/core/network/head.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn.functional as F
+
+class CNN(torch.nn.Module):
+ def __init__(self, D_in, D_hidden=512):
+ super(CNN, self).__init__()
+
+ self.conv1 = torch.nn.Conv2d(in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4)
+ dim1 = ((D_in[1] - 8)//4 + 1, (D_in[2] - 8)//4 + 1)
+ self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
+ dim2 = ((dim1[0] - 4)//2 + 1, (dim1[1] - 4)//2 + 1)
+ self.conv3 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
+ dim3 = ((dim2[0] - 3)//1 + 1, (dim2[1] - 3)//1 + 1)
+
+ self.D_head_out = 64*dim3[0]*dim3[1]
+
+ def forward(self, x):
+ x = (x-(255.0/2))/(255.0/2)
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.conv2(x))
+ x = F.relu(self.conv3(x))
+ x = x.view(x.size(0), -1)
+ return x
+
+class MLP(torch.nn.Module):
+ def __init__(self, D_in, D_hidden=512):
+ super(MLP, self).__init__()
+
+ self.l = torch.nn.Linear(D_in, D_hidden)
+ self.D_head_out = D_hidden
+
+ def forward(self, x):
+ x = F.relu(self.l(x))
+ return x
+
+class CNN_LSTM(torch.nn.Module):
+ def __init__(self, D_in, D_hidden=512):
+ super(CNN_LSTM, self).__init__()
+
+ self.conv1 = torch.nn.Conv2d(in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4)
+ dim1 = ((D_in[1] - 8)//4 + 1, (D_in[2] - 8)//4 + 1)
+ self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
+ dim2 = ((dim1[0] - 4)//2 + 1, (dim1[1] - 4)//2 + 1)
+ self.conv3 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
+ dim3 = ((dim2[0] - 3)//1 + 1, (dim2[1] - 3)//1 + 1)
+
+ self.D_conv_out = 64*dim3[0]*dim3[1]
+
+ self.lstm = torch.nn.LSTM(input_size=self.D_conv_out, hidden_size=D_hidden, batch_first=True)
+
+ self.D_head_out = D_hidden
+
+ def forward(self, x, hidden_in=None):
+ x = (x-(255.0/2))/(255.0/2)
+
+ seq_len = x.size(1)
+
+ if hidden_in is None:
+ hidden_in = (torch.zeros(1, x.size(0), self.D_head_out).to(x.device),
+ torch.zeros(1, x.size(0), self.D_head_out).to(x.device))
+
+ x = x.reshape(-1, *x.shape[2:])
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.conv2(x))
+ x = F.relu(self.conv3(x))
+ x = x.view(-1, seq_len, self.D_conv_out)
+ x, hidden_out = self.lstm(x, hidden_in)
+
+ return x, hidden_in, hidden_out
+
+class MLP_LSTM(torch.nn.Module):
+ def __init__(self, D_in, D_hidden=512):
+ super(MLP_LSTM, self).__init__()
+
+ self.l = torch.nn.Linear(D_in, D_hidden)
+ self.lstm = torch.nn.LSTM(input_size=D_hidden, hidden_size=D_hidden, batch_first=True)
+ self.D_head_out = D_hidden
+
+ def forward(self, x, hidden_in=None):
+ if hidden_in is None:
+ hidden_in = (torch.zeros(1, x.size(0), self.D_head_out).to(x.device),
+ torch.zeros(1, x.size(0), self.D_head_out).to(x.device))
+
+ x = F.relu(self.l(x))
+ x, hidden_out = self.lstm(x, hidden_in)
+
+ return x, hidden_in, hidden_out
+
+import os, sys, inspect, re
+from collections import OrderedDict
+
+working_path = os.path.dirname(os.path.realpath(__file__))
+head_dict = {}
+naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
+for class_name, _class in inspect.getmembers(sys.modules[__name__], inspect.isclass):
+ if __name__ in str(_class):
+ head_dict[naming_rule(class_name)] = _class
+
+head_dict = OrderedDict(sorted(head_dict.items()))
+with open(os.path.join(working_path, "_head_dict.txt"), 'w') as f:
+ f.write('### Head Dictionary ###\n')
+ f.write('format: (key, class)\n')
+ f.write('------------------------\n')
+ for item in head_dict.items():
+ f.write(str(item) + '\n')
\ No newline at end of file
diff --git a/core/network/rnd.py b/core/network/rnd.py
index c1733420..b819241e 100644
--- a/core/network/rnd.py
+++ b/core/network/rnd.py
@@ -1,47 +1,7 @@
import torch
import torch.nn.functional as F
-# codes from https://github.com/openai/random-network-distillation
-class RewardForwardFilter(torch.nn.Module):
- def __init__(self, gamma, num_workers):
- super(RewardForwardFilter, self).__init__()
- self.rewems = torch.nn.Parameter(torch.zeros(num_workers), requires_grad=False)
- self.gamma = gamma
-
- def update(self, rews):
- self.rewems.data = self.rewems * self.gamma + rews
- return self.rewems
-
-# codes modified from https://github.com/openai/random-network-distillation
-class RunningMeanStd(torch.nn.Module):
- def __init__(self, shape, epsilon=1e-4):
- super(RunningMeanStd, self).__init__()
-
- self.mean = torch.nn.Parameter(torch.zeros(shape), requires_grad=False)
- self.var = torch.nn.Parameter(torch.zeros(shape), requires_grad=False)
- self.count = torch.nn.Parameter(torch.tensor(epsilon), requires_grad=False)
-
- def update(self, x):
- batch_mean, batch_std, batch_count = x.mean(axis=0), x.std(axis=0), x.shape[0]
- batch_var = torch.square(batch_std)
- self.update_from_moments(batch_mean, batch_var, batch_count)
-
- def update_from_moments(self, batch_mean, batch_var, batch_count):
- delta = batch_mean - self.mean
- tot_count = self.count + batch_count
-
- new_mean = self.mean + delta * batch_count / tot_count
- m_a = self.var * self.count
- m_b = batch_var * (batch_count)
- M2 = m_a + m_b + torch.square(delta) * self.count * batch_count / (self.count + batch_count)
- new_var = M2 / (self.count + batch_count)
-
- new_count = batch_count + self.count
-
- self.mean.data = new_mean
- self.var.data = new_var
- self.count.data = new_count
-
+from .utils import RewardForwardFilter, RunningMeanStd
# normalize observation
# assumed state shape: (batch_size, dim_state)
def normalize_obs(obs, m, v):
diff --git a/core/network/utils.py b/core/network/utils.py
index d4ae9a6e..83bee5cf 100644
--- a/core/network/utils.py
+++ b/core/network/utils.py
@@ -1,94 +1,44 @@
import torch
import torch.nn.functional as F
-class CNN(torch.nn.Module):
- def __init__(self, D_in, D_hidden=512):
- super(CNN, self).__init__()
-
- self.conv1 = torch.nn.Conv2d(in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4)
- dim1 = ((D_in[1] - 8)//4 + 1, (D_in[2] - 8)//4 + 1)
- self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
- dim2 = ((dim1[0] - 4)//2 + 1, (dim1[1] - 4)//2 + 1)
- self.conv3 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
- dim3 = ((dim2[0] - 3)//1 + 1, (dim2[1] - 3)//1 + 1)
-
- self.D_head_out = 64*dim3[0]*dim3[1]
-
- def forward(self, x):
- x = (x-(255.0/2))/(255.0/2)
- x = F.relu(self.conv1(x))
- x = F.relu(self.conv2(x))
- x = F.relu(self.conv3(x))
- x = x.view(x.size(0), -1)
- return x
-
-class MLP(torch.nn.Module):
- def __init__(self, D_in, D_hidden=512):
- super(MLP, self).__init__()
-
- self.l = torch.nn.Linear(D_in, D_hidden)
- self.D_head_out = D_hidden
-
- def forward(self, x):
- x = F.relu(self.l(x))
- return x
-
-class CNN_LSTM(torch.nn.Module):
- def __init__(self, D_in, D_hidden=512):
- super(CNN_LSTM, self).__init__()
-
- self.conv1 = torch.nn.Conv2d(in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4)
- dim1 = ((D_in[1] - 8)//4 + 1, (D_in[2] - 8)//4 + 1)
- self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
- dim2 = ((dim1[0] - 4)//2 + 1, (dim1[1] - 4)//2 + 1)
- self.conv3 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
- dim3 = ((dim2[0] - 3)//1 + 1, (dim2[1] - 3)//1 + 1)
-
- self.D_conv_out = 64*dim3[0]*dim3[1]
-
- self.lstm = torch.nn.LSTM(input_size=self.D_conv_out, hidden_size=D_hidden, batch_first=True)
-
- self.D_head_out = D_hidden
-
- def forward(self, x, hidden_in=None):
- x = (x-(255.0/2))/(255.0/2)
-
- seq_len = x.size(1)
-
- if hidden_in is None:
- hidden_in = (torch.zeros(1, x.size(0), self.D_head_out).to(x.device),
- torch.zeros(1, x.size(0), self.D_head_out).to(x.device))
-
- x = x.reshape(-1, *x.shape[2:])
- x = F.relu(self.conv1(x))
- x = F.relu(self.conv2(x))
- x = F.relu(self.conv3(x))
- x = x.view(-1, seq_len, self.D_conv_out)
- x, hidden_out = self.lstm(x, hidden_in)
-
- return x, hidden_in, hidden_out
+# codes from https://github.com/openai/random-network-distillation
+class RewardForwardFilter(torch.nn.Module):
+ def __init__(self, gamma, num_workers):
+ super(RewardForwardFilter, self).__init__()
+ self.rewems = torch.nn.Parameter(torch.zeros(num_workers), requires_grad=False)
+ self.gamma = gamma
+
+ def update(self, rews):
+ self.rewems.data = self.rewems * self.gamma + rews
+ return self.rewems
-class MLP_LSTM(torch.nn.Module):
- def __init__(self, D_in, D_hidden=512):
- super(MLP_LSTM, self).__init__()
-
- self.l = torch.nn.Linear(D_in, D_hidden)
- self.lstm = torch.nn.LSTM(input_size=D_hidden, hidden_size=D_hidden, batch_first=True)
- self.D_head_out = D_hidden
-
- def forward(self, x, hidden_in=None):
- if hidden_in is None:
- hidden_in = (torch.zeros(1, x.size(0), self.D_head_out).to(x.device),
- torch.zeros(1, x.size(0), self.D_head_out).to(x.device))
-
- x = F.relu(self.l(x))
- x, hidden_out = self.lstm(x, hidden_in)
-
- return x, hidden_in, hidden_out
-
-import sys, inspect, re
-
-head_dict = {}
-for class_name, _class in inspect.getmembers(sys.modules[__name__], inspect.isclass):
- naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
- head_dict[naming_rule(class_name)] = _class
+# codes modified from https://github.com/openai/random-network-distillation
+class RunningMeanStd(torch.nn.Module):
+ def __init__(self, shape, epsilon=1e-4):
+ super(RunningMeanStd, self).__init__()
+
+ self.mean = torch.nn.Parameter(torch.zeros(shape), requires_grad=False)
+ self.var = torch.nn.Parameter(torch.zeros(shape), requires_grad=False)
+ self.count = torch.nn.Parameter(torch.tensor(epsilon), requires_grad=False)
+
+ def update(self, x):
+ batch_mean, batch_std, batch_count = x.mean(axis=0), x.std(axis=0), x.shape[0]
+ batch_var = torch.square(batch_std)
+ self.update_from_moments(batch_mean, batch_var, batch_count)
+
+ def update_from_moments(self, batch_mean, batch_var, batch_count):
+ delta = batch_mean - self.mean
+ tot_count = self.count + batch_count
+
+ new_mean = self.mean + delta * batch_count / tot_count
+ m_a = self.var * self.count
+ m_b = batch_var * (batch_count)
+ M2 = m_a + m_b + torch.square(delta) * self.count * batch_count / (self.count + batch_count)
+ new_var = M2 / (self.count + batch_count)
+
+ new_count = batch_count + self.count
+
+ self.mean.data = new_mean
+ self.var.data = new_var
+ self.count.data = new_count
+
\ No newline at end of file
diff --git a/core/optimizer/__init__.py b/core/optimizer/__init__.py
index 4d36d57c..8ec6e6b5 100644
--- a/core/optimizer/__init__.py
+++ b/core/optimizer/__init__.py
@@ -3,18 +3,18 @@
from torch.optim import *
-class_dict = {}
+optimizer_dict = {}
+naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
for class_name, _class in inspect.getmembers(sys.modules[__name__], inspect.isclass):
- naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
- class_dict[naming_rule(class_name)] = _class
+ optimizer_dict[naming_rule(class_name)] = _class
working_path = os.path.dirname(os.path.realpath(__file__))
-class_dict = OrderedDict(sorted(class_dict.items()))
-with open(os.path.join(working_path, "_class_dict.txt"), 'w') as f:
- f.write('### Class Dictionary ###\n')
+optimizer_dict = OrderedDict(sorted(optimizer_dict.items()))
+with open(os.path.join(working_path, "_optimizer_dict.txt"), 'w') as f:
+ f.write('### Optimizer Dictionary ###\n')
f.write('format: (key, class)\n')
f.write('------------------------\n')
- for item in class_dict.items():
+ for item in optimizer_dict.items():
f.write(str(item) + '\n')
class Optimizer:
@@ -24,7 +24,7 @@ def __new__(self, name, *args, **kwargs):
print("### name variable must be string! ###")
raise Exception
name = name.lower()
- if not name in class_dict.keys():
- print(f"### can use only follows {[opt for opt in class_dict.keys()]}")
+ if not name in optimizer_dict.keys():
+ print(f"### can use only follows {[opt for opt in optimizer_dict.keys()]}")
raise Exception
- return class_dict[name](*args, **kwargs)
+ return optimizer_dict[name](*args, **kwargs)
diff --git a/core/optimizer/_class_dict.txt b/core/optimizer/_optimizer_dict.txt
similarity index 95%
rename from core/optimizer/_class_dict.txt
rename to core/optimizer/_optimizer_dict.txt
index bc49c1d3..8deba314 100644
--- a/core/optimizer/_class_dict.txt
+++ b/core/optimizer/_optimizer_dict.txt
@@ -1,4 +1,4 @@
-### Class Dictionary ###
+### Optimizer Dictionary ###
format: (key, class)
------------------------
('adadelta', )
diff --git a/docs/Benchmark.md b/docs/Benchmark.md
deleted file mode 100644
index b8366f92..00000000
--- a/docs/Benchmark.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# Benchmark
-
diff --git a/docs/Distributed_Architecture.md b/docs/Distributed_Architecture.md
index bc32c35b..70463ec6 100644
--- a/docs/Distributed_Architecture.md
+++ b/docs/Distributed_Architecture.md
@@ -1,2 +1,35 @@
# Distributed Architecture
+In addition to single actor training, JORLDY supports distributed reinforcement learning, both synchronous and asynchronous. To implement it, we use __ray__ (in particular, to let actors interact in parallel) and __multiprocessing__. See the flowchart and timeline for each script (single, sync, and async): the flowchart shows the flow of data between processes and components, and the timeline shows the work progress and data communication between processes.
+
+## Single actor train
+
+The single actor train script has a __main process__ and a __manage process__. In the __main process__, a single agent interacts with the environment to collect transition data and trains the network from it. The __manage process__ evaluates the latest network to obtain a score and records this score together with the training results from the main process.
+
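+The following is a minimal, self-contained toy sketch of this two-process pattern (it is not the actual script; see [single_train.py](../single_train.py) and [process.py](../process.py) for the real implementation): the main process pushes training results into a queue and the manage process consumes and logs them.
+
+```python
+import multiprocessing as mp
+
+def manage_process(result_queue):
+    # manage process: consume results produced by the main process and log them
+    while True:
+        item = result_queue.get()
+        if item is None:  # sentinel: training finished
+            break
+        print(f"[manage] step {item['step']} score {item['score']}")
+
+if __name__ == "__main__":
+    result_queue = mp.Queue()
+    manage = mp.Process(target=manage_process, args=(result_queue,))
+    manage.start()
+    for step in range(3):  # the actual training loop runs here in the main process
+        result_queue.put({"step": step, "score": float(step)})
+    result_queue.put(None)
+    manage.join()
+```
+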
+### Flow chart
+
+
+### Timeline
+
+
+## Sync distributed train
+
+The sync distributed train script also has a __main process__ and a __manage process__. In the __main process__, multiple actors interact in parallel to collect transition data and the learner trains the model from it. The __manage process__ evaluates the latest model to obtain a score and records this score together with the training results from the main process.
+
+### Flow chart
+
+
+### Timeline
+
+
+## Async distributed train
+
+The async distributed train script has an __interact process__, a __main process__, and a __manage process__. In the __interact process__, multiple actors interact in parallel to collect transition data. Unlike the sync distributed train script, each actor interacts asynchronously: data is transferred only for the actors that have finished interacting within a specific time window. In the __main process__, the learner trains the model on the transition data. The __manage process__ evaluates the latest model to obtain a score and records this score together with the training results from the main process.
+
+### Flow chart
+
+
+### Timeline
+
+
+reference: [manager/distributed_manager.py](../manager/distributed_manager.py), [process.py](../process.py)
diff --git a/docs/How_to_add_environment.md b/docs/How_to_add_environment.md
deleted file mode 100644
index 0d7472d4..00000000
--- a/docs/How_to_add_environment.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# How to add new environment
-
diff --git a/docs/How_to_add_network.md b/docs/How_to_add_network.md
deleted file mode 100644
index 17c8f59a..00000000
--- a/docs/How_to_add_network.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# How to add new network
-
diff --git a/docs/How_to_add_rl_algorithm.md b/docs/How_to_add_rl_algorithm.md
deleted file mode 100644
index 942438ba..00000000
--- a/docs/How_to_add_rl_algorithm.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# How to add new RL algorithm
-
diff --git a/docs/How_to_use.md b/docs/How_to_use.md
index 420be308..6921f858 100644
--- a/docs/How_to_use.md
+++ b/docs/How_to_use.md
@@ -1,2 +1,127 @@
# How to use
+## How to Check Implemented List
+- In order to use the various agents (algorithms), environments, and networks provided by JORLDY, you need to know the key that calls each of them. JORLDY lists the names of the provided agents, envs, and networks in **_agent_dict.txt**, **_env_dict.txt**, and **_network_dict.txt**, respectively.
+- Each dictionary file lists *(key, class)* pairs. You can call the desired element by writing its key in the config file.
+- **Note**: If you implement a new environment, agent, or network according to our documentation and run **main.py**, the corresponding dictionary file will be updated automatically.
+
+### Agents
+- A list of implemented agents can be found in [_agent_dict.txt](../core/agent/_agent_dict.txt).
+
+- Example: You can check the key of the Ape-X agent in [_agent_dict.txt](../core/agent/_agent_dict.txt): *('ape_x', )*. If you want to use the Ape-X agent, write agent.name as *ape_x* in the config file.
+
+```python
+agent = {
+ "name": "ape_x",
+ "network": "dueling",
+ ...
+}
+```
+
+### Environments
+- A list of provided environments can be found in [_env_dict.txt](../core/env/_env_dict.txt).
+
+- Example: You can check the key of the Procgen starpilot environment in [_env_dict.txt](../core/env/_env_dict.txt): *('starpilot', )*. If you want to use the starpilot environment, specify its key in the run command, e.g. python main.py --config config.dqn.procgen --env.name starpilot.
+
+### Networks
+- A list of implemented networks can be found in [_network_dict.txt](../core/network/_network_dict.txt).
+- If the network you want to use requires a head element, you should also include head in the config. A list of implemented heads can be found in [_head_dict.txt](../core/network/_head_dict.txt).
+- **Note**: To use a head in your customized network, you should inherit the [BaseNetwork class](../core/network/base.py). See [How to customize network](../core/network/README.md).
+
+- Example 1: You can check the key of the PPO discrete policy network in [_network_dict.txt](../core/network/_network_dict.txt): *('discrete_policy_value', )*. If you want to use the PPO discrete policy network, write agent.network as *discrete_policy_value* in the config file.
+```python
+agent = {
+ "name":"ppo",
+ "network":"discrete_policy_value",
+ ...
+}
+```
+
+- Example 2: Using a head (image state): add the key "head" with the value "cnn" to the agent dictionary in the config file.
+
+```python
+agent = {
+ "name":"ppo",
+ "network":"discrete_policy_value",
+ "head": "cnn",
+ ...
+}
+```
+## Run Command Example
+- The default command line consists of a script part and a config part. When you type __*config path*__, omit '.py' from the name of the config file. If you do not type __*config path*__, it runs with the default config in the script.
+ ```
+ python [script name].py --config [config path]
+ ```
+ - Example:
+ ```
+ python single_train.py --config config.dqn.cartpole
+ ```
+
+- If you want to load an environment from the atari (or procgen) suite, use the atari (or procgen) config path and specify the environment with the env.name option.
+ ```
+ python [script name].py --config [config path] --env.name [env name]
+ ```
+ - Example:
+ ```
+ python single_train.py --config config.dqn.atari --env.name assault
+ ```
+- All parameters in the config file can be changed from the command line using the parser, without modifying the config file.
+ ```
+ python [script name].py --config [config path] --[optional parameter key] [optional parameter value]
+ ```
+ - Example:
+ ```
+ python single_train.py --config config.dqn.cartpole --agent.batch_size 64
+ ```
+ ```
+ python sync_distributed_train.py --config config.ppo.cartpole --train.num_workers 8
+ ```
+
+- Executable script list: **single_train.py**, **sync_distributed_train.py**, **async_distributed_train.py**.
+
+## Inference
+
+### Saved Files
+
+- The files are saved in the path **logs/[env name]/[Algorithm]/[Datetime]/**
+ - Ex) logs/breakout/rainbow/20211014152800
+- The saved files are as follows
+ - **gif files**: gifs of the evaluation episodes
+ - **ckpt**: saved PyTorch checkpoint file
+ - **config.py**: configuration used for the run
+ - **events.out.tfevents...**: saved TensorBoard event file
+
+### How to add data in Tensorboard
+
+- TensorBoard data can be added by modifying the **core/agent** code.
+- For example, the Noisy algorithm adds the mean value of sigma to TensorBoard. To do this, add sigma to the result dictionary inside the process function of the agent as follows.
+
+```python
+result = {
+"loss" : loss.item(),
+"max_Q": max_Q,
+"sig_w1": sig1_mean.item(),
+"sig_w2": sig2_mean.item(),
+}
+```
+
+- If you check the TensorBoard after performing the above process, you can see that the sigma values are added as follows.
+
+
+
+### How to load trained model
+
+- If you want to load a trained model, set the path of the saved model in the train part of the config.
+ - If you do not want to load a saved model, set "load_path" to None.
+- Example
+ - env: space invaders
+ - algorithm: rainbow (if you want to test the model without training, set "training" to False)
+
+```python
+train = {
+ "training" : False,
+ "load_path" : "./logs/spaceinvaders/rainbow/20211015110908/",
+ ...
+}
+```
+
diff --git a/docs/Implementation_list.md b/docs/Implementation_list.md
index bfb392e3..38e42c22 100644
--- a/docs/Implementation_list.md
+++ b/docs/Implementation_list.md
@@ -29,6 +29,11 @@
- [Rainbow [DQN, IQN]](https://arxiv.org/abs/1710.02298)
+**Distributed**
+
+- [APE-X](https://arxiv.org/pdf/1803.00933.pdf)
+- [R2D2](https://openreview.net/pdf?id=r1lyTjAqYX)
+
**Policy Optimization, Actor-Critic**
- [REINFORCE [Discrete, Continuous]](https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)
diff --git a/docs/Naming_convention.md b/docs/Naming_convention.md
new file mode 100644
index 00000000..774b212a
--- /dev/null
+++ b/docs/Naming_convention.md
@@ -0,0 +1,23 @@
+# Class naming convention
+- We basically follow the PascalCase naming conventions in which the first letter of each word in a compound word is capitalized.
+ - Example: DiscretePolicy, MountainCar, ...
+- Write all acronyms in uppercase.
+ - Example: DQN, CNN, SAC, PPO, PongMLAgent
+- Acronyms and words (including abbreviations) are separated by '_'.
+ - Example: ICM_PPO, SAC_Critic, ...
+- Exceptional cases:
+ - CartPole -> Cartpole
+ - The exception rule is applied to speed up debugging.
+ - SuperMarioBros -> Mario
+ - The exception rule is applied because the class name is too long.
+
+# Class calling convention
+- If a lowercase letter is immediately followed by an uppercase letter, add '_' between them.
+- Change all uppercase letters to lowercase (see the sketch after the examples below).
+- Example:
+ - DiscretePolicyValue -> discrete_policy_value
+ - SAC_Critic -> sac_critic
+ - PongMLAgent -> pong_mlagent
+
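+A minimal sketch of this rule as code (the helper naming_rule is the same lambda used in the core module __init__.py files):
+
+```python
+import re
+
+# insert '_' between a lowercase letter and the uppercase letter that follows, then lowercase everything
+naming_rule = lambda x: re.sub('([a-z])([A-Z])', r'\1_\2', x).lower()
+
+print(naming_rule("DiscretePolicyValue"))  # discrete_policy_value
+print(naming_rule("SAC_Critic"))           # sac_critic
+print(naming_rule("PongMLAgent"))          # pong_mlagent
+```
+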
+# Function, Variable naming convention
+- We follow snake_case, in which words are separated by an underscore (_) character and written entirely in lowercase.
\ No newline at end of file
diff --git a/main.py b/main.py
index d790a8e8..9d5b92d4 100644
--- a/main.py
+++ b/main.py
@@ -24,7 +24,7 @@
agent.load(config.train.load_path)
record_period = config.train.record_period if config.train.record_period else config.train.run_step//10
- test_manager = TestManager(Env(**config.env), config.train.test_iteration,
+ eval_manager = EvalManager(Env(**config.env), config.train.eval_iteration,
config.train.record, record_period)
metric_manager = MetricManager()
log_id = config.train.id if config.train.id else config.agent.name
@@ -53,7 +53,7 @@
state = env.reset()
if step % config.train.print_period == 0:
- score, frames = test_manager.test(agent, step)
+ score, frames = eval_manager.evaluate(agent, step)
metric_manager.append({"score": score})
statistics = metric_manager.get_statistics()
print(f"{episode} Episode / Step : {step} / {statistics}")
diff --git a/manager/README.md b/manager/README.md
new file mode 100644
index 00000000..5c62ac9d
--- /dev/null
+++ b/manager/README.md
@@ -0,0 +1,20 @@
+# Role of managers
+
+Managers are responsible for the non-learning aspects of training. The roles of each manager are as follows, and a short usage sketch is given at the end of this document.
+
+### config_manager
+- It processes the config file and the optional parameter of run_command, and dumps the config to the storage path.
+
+### distributed_manager
+- It manages the actors in the distributed scripts: it lets the actors interact for update_period and syncs the actors' networks.
+
+### eval_manager
+- It runs evaluation for a certain number of episodes to obtain an evaluation score.
+- It saves the frames for recording.
+
+### log_manager
+- It records the learning progress in TensorBoard.
+- It receives the frames for recording and creates a gif from them.
+
+### metric_manager
+- It manages metrics and calculates and provides statistics.
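+
+As a rough illustration of how these managers are used together, the following excerpt is simplified from [main.py](../main.py) (evaluation runs every print_period steps inside the training loop):
+
+```python
+# record_period is derived from the config, as in main.py
+eval_manager = EvalManager(Env(**config.env), config.train.eval_iteration,
+                           config.train.record, record_period)
+metric_manager = MetricManager()
+
+# inside the training loop, every print_period steps:
+score, frames = eval_manager.evaluate(agent, step)
+metric_manager.append({"score": score})
+statistics = metric_manager.get_statistics()
+print(f"{episode} Episode / Step : {step} / {statistics}")
+```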
diff --git a/manager/test_manager.py b/manager/eval_manager.py
similarity index 96%
rename from manager/test_manager.py
rename to manager/eval_manager.py
index c07406c0..1c0644ef 100644
--- a/manager/test_manager.py
+++ b/manager/eval_manager.py
@@ -1,6 +1,6 @@
import numpy as np
-class TestManager:
+class EvalManager:
def __init__(self, env, iteration=10, record=None, record_period=None):
self.env = env
self.iteration = iteration if iteration else 10
@@ -10,7 +10,7 @@ def __init__(self, env, iteration=10, record=None, record_period=None):
self.record_stamp = 0
self.time_t = 0
- def test(self, agent, step):
+ def evaluate(self, agent, step):
scores = []
frames = []
self.record_stamp += step - self.time_t
diff --git a/manager/time_manager.py b/manager/time_manager.py
deleted file mode 100644
index c43b4732..00000000
--- a/manager/time_manager.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import time
-from collections import deque
-
-class TimeManager:
- def __init__(self, n_mean = 20):
- self.n_mean = n_mean
- self.reset()
-
- def reset(self):
- self.timedict = dict()
-
- def start(self, keyword):
- if keyword not in self.timedict:
- self.timedict[keyword] = {
- 'start_timestamp': time.time(),
- 'deque': deque(maxlen=self.n_mean),
- 'mean': -1,
- 'last_time': -1,
- }
- else:
- self.timedict[keyword]['start_timestamp'] = time.time()
-
- def end(self, keyword):
- if keyword in self.timedict:
- time_current = time.time() - self.timedict[keyword]['start_timestamp']
- self.timedict[keyword]['last_time'] = time_current
- self.timedict[keyword]['deque'].append(time_current)
- self.timedict[keyword]['start_timestamp'] = -1
- self.timedict[keyword]['mean'] = sum(self.timedict[keyword]['deque']) / len(self.timedict[keyword]['deque'])
-
- return self.timedict[keyword]['last_time'], self.timedict[keyword]['mean']
-
- def get_statistics(self):
- return {k: self.timedict[k]['mean'] for k in self.timedict}
\ No newline at end of file
diff --git a/process.py b/process.py
index 3f6b0803..776a2ef6 100644
--- a/process.py
+++ b/process.py
@@ -26,10 +26,10 @@ def interact_process(DistributedManager, distributed_manager_config,
def manage_process(Agent, agent_config,
result_queue, sync_queue, path_queue,
run_step, print_period, MetricManager,
- TestManager, test_manager_config,
+ EvalManager, eval_manager_config,
LogManager, log_manager_config, config_manager):
agent = Agent(**agent_config)
- test_manager = TestManager(*test_manager_config)
+ eval_manager = EvalManager(*eval_manager_config)
metric_manager = MetricManager()
log_manager = LogManager(*log_manager_config)
path_queue.put(log_manager.path)
@@ -47,7 +47,7 @@ def manage_process(Agent, agent_config,
step = _step
if print_stamp >= print_period or step >= run_step:
agent.sync_in(**sync_queue.get())
- score, frames = test_manager.test(agent, step)
+ score, frames = eval_manager.evaluate(agent, step)
metric_manager.append({"score": score})
statistics = metric_manager.get_statistics()
print(f"Step : {step} / {statistics}")
diff --git a/resrc/async_distributed_train_flowchart.png b/resrc/async_distributed_train_flowchart.png
new file mode 100644
index 00000000..5b5c5807
Binary files /dev/null and b/resrc/async_distributed_train_flowchart.png differ
diff --git a/resrc/async_distributed_train_timeline.png b/resrc/async_distributed_train_timeline.png
new file mode 100644
index 00000000..226d7c50
Binary files /dev/null and b/resrc/async_distributed_train_timeline.png differ
diff --git a/img/breakout_result.gif b/resrc/breakout_result.gif
similarity index 100%
rename from img/breakout_result.gif
rename to resrc/breakout_result.gif
diff --git a/img/breakout_score.png b/resrc/breakout_score.png
similarity index 100%
rename from img/breakout_score.png
rename to resrc/breakout_score.png
diff --git a/img/contributors.png b/resrc/contributors.png
similarity index 100%
rename from img/contributors.png
rename to resrc/contributors.png
diff --git a/img/hopper_mlagent_score.png b/resrc/hopper_mlagent_score.png
similarity index 100%
rename from img/hopper_mlagent_score.png
rename to resrc/hopper_mlagent_score.png
diff --git a/img/hopper_result.gif b/resrc/hopper_result.gif
similarity index 100%
rename from img/hopper_result.gif
rename to resrc/hopper_result.gif
diff --git a/resrc/noisy_tensorboard.png b/resrc/noisy_tensorboard.png
new file mode 100644
index 00000000..14167984
Binary files /dev/null and b/resrc/noisy_tensorboard.png differ
diff --git a/img/pong_mlagent_score.png b/resrc/pong_mlagent_score.png
similarity index 100%
rename from img/pong_mlagent_score.png
rename to resrc/pong_mlagent_score.png
diff --git a/img/pong_result.gif b/resrc/pong_result.gif
similarity index 100%
rename from img/pong_result.gif
rename to resrc/pong_result.gif
diff --git a/img/quickstart.png b/resrc/quickstart.png
similarity index 100%
rename from img/quickstart.png
rename to resrc/quickstart.png
diff --git a/resrc/single_actor_train_flowchart.png b/resrc/single_actor_train_flowchart.png
new file mode 100644
index 00000000..52db1c58
Binary files /dev/null and b/resrc/single_actor_train_flowchart.png differ
diff --git a/resrc/single_actor_train_timeline.png b/resrc/single_actor_train_timeline.png
new file mode 100644
index 00000000..8136293c
Binary files /dev/null and b/resrc/single_actor_train_timeline.png differ
diff --git a/resrc/sync_distributed_train_flowchart.png b/resrc/sync_distributed_train_flowchart.png
new file mode 100644
index 00000000..1e24a6fa
Binary files /dev/null and b/resrc/sync_distributed_train_flowchart.png differ
diff --git a/resrc/sync_distributed_train_timeline.png b/resrc/sync_distributed_train_timeline.png
new file mode 100644
index 00000000..52f89596
Binary files /dev/null and b/resrc/sync_distributed_train_timeline.png differ
diff --git a/single_train.py b/single_train.py
index f0860db6..81b72d14 100644
--- a/single_train.py
+++ b/single_train.py
@@ -32,7 +32,7 @@
path_queue = mp.Queue(1)
record_period = config.train.record_period if config.train.record_period else config.train.run_step//10
- test_manager_config = (Env(**config.env), config.train.test_iteration, config.train.record, record_period)
+ eval_manager_config = (Env(**config.env), config.train.eval_iteration, config.train.record, record_period)
log_id = config.train.id if config.train.id else config.agent.name
log_manager_config = (config.env.name, log_id, config.train.experiment)
agent_config['device'] = "cpu"
@@ -40,7 +40,7 @@
args=(Agent, agent_config,
result_queue, manage_sync_queue, path_queue,
config.train.run_step, config.train.print_period,
- MetricManager, TestManager, test_manager_config,
+ MetricManager, EvalManager, eval_manager_config,
LogManager, log_manager_config, config_manager))
manage.start()
try:
diff --git a/sync_distributed_train.py b/sync_distributed_train.py
index e778a882..f8ec09e8 100644
--- a/sync_distributed_train.py
+++ b/sync_distributed_train.py
@@ -35,7 +35,7 @@
path_queue = mp.Queue(1)
record_period = config.train.record_period if config.train.record_period else config.train.run_step//10
- test_manager_config = (Env(**config.env), config.train.test_iteration, config.train.record, record_period)
+ eval_manager_config = (Env(**config.env), config.train.eval_iteration, config.train.record, record_period)
log_id = config.train.id if config.train.id else config.agent.name
log_manager_config = (config.env.name, log_id, config.train.experiment)
agent_config['device'] = "cpu"
@@ -43,7 +43,7 @@
args=(Agent, agent_config,
result_queue, manage_sync_queue, path_queue,
config.train.run_step, config.train.print_period,
- MetricManager, TestManager, test_manager_config,
+ MetricManager, EvalManager, eval_manager_config,
LogManager, log_manager_config, config_manager))
distributed_manager = DistributedManager(Env, config.env, Agent, agent_config, config.train.num_workers, 'sync')