
Commit e84c2f8

adjust number of envs per env_runner
1 parent b92ad89 commit e84c2f8

File tree

3 files changed: +44 -27 lines changed

.devcontainer/devcontainer.json

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
     "image": "tensorflow/tensorflow:2.15.0-gpu-jupyter",

     "runArgs": ["--gpus=all",
-        "--shm-size=24gb"
+        "--shm-size=50gb"
     ],

     // Features to add to the dev container. More info: https://containers.dev/features.
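The shared-memory bump matters because Ray keeps its object store in /dev/shm inside the container; if --shm-size is too small, Ray falls back to slower disk-backed storage. As an illustrative check (not part of this commit), the allocation can be verified from inside the container with the standard library:

    import shutil

    # /dev/shm backs Ray's object store; confirm the new --shm-size took effect.
    total, used, free = shutil.disk_usage("/dev/shm")
    print(f"/dev/shm total: {total / 1e9:.1f} GB, free: {free / 1e9:.1f} GB")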

README.md

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 ### Requirements Overview
 This notebook uses RLLib, an open-source scalable reinforcement learning library in the Ray framework.
 RLLib currently supports Python 3.9 - 3.12.
-RLLib supports both PyTorch and Tensorflow, so either may be used. Preferably, the library used will be CUDA-enabled to utilize the GPU, but it is optional. Nvidia GPU support only.
+RLLib supports both PyTorch and Tensorflow, so either may be used. This setup assumes a GPU will be used, but one is not necessary for most algorithms; training with a GPU was found to be slightly slower than CPU-only training for DQN. A GPU is most likely only useful for large models that take longer for inference or backpropagation.

 ### Tensorflow GPU Support
 A dev container is provided that sets up a Linux Tensorflow 2.15.0-gpu-jupyter Docker container with everything configured for Tensorflow GPU support; it also starts its own local Pokemon Showdown server on launch. The Showdown server is port-forwarded to be visible on the host at http://localhost:8000.
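Since the README now treats the GPU as optional, a quick way to confirm what Tensorflow actually sees (an illustrative snippet, not part of this commit):

    import tensorflow as tf

    # An empty list simply means CPU-only training, which the README notes is fine for DQN.
    gpus = tf.config.list_physical_devices("GPU")
    print(f"GPUs visible to Tensorflow: {len(gpus)}")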

notebooks/basic_rl.ipynb

Lines changed: 42 additions & 25 deletions
@@ -198,6 +198,7 @@
 "from ray.rllib.algorithms.dqn import DQNConfig\n",
 "from ray import tune, train\n",
 "import os\n",
+"import ray\n",
 "\n",
 "# This is passed to each environment (SimpleRLPlayer) during training.\n",
 "# 'player_config' is passed as a kwarg to the super().__init__() of SimpleRLPlayer's Gen9EnvSinglePlayer superclass.\n",
@@ -208,7 +209,7 @@
 "        'start_challenging': True,\n",
 "    },\n",
 "    'opponent_class': MaxBasePowerPlayer,\n",
-"    'opponent_username': 'tr_MaxBasePower',\n",
+"    'opponent_username': 'tr_MaxBP',\n",
 "    'opponent_config': {\n",
 "        'battle_format': \"gen9randombattle\",\n",
 "    },\n",
@@ -221,33 +222,40 @@
 "        'start_challenging': True,\n",
 "    },\n",
 "    'opponent_class': MaxBasePowerPlayer,\n",
-"    'opponent_username': 'ev_MaxBasePower',\n",
+"    'opponent_username': 'ev_MaxBP',\n",
 "    'opponent_config': {\n",
 "        'battle_format': \"gen9randombattle\",\n",
 "    },\n",
 "}\n",
 "\n",
-"# Guide to RLLib parameters: https://docs.ray.io/en/latest/rllib/rllib-training.html#common-parameters \n",
+"# Guide to RLLib parameters: https://docs.ray.io/en/latest/rllib/rllib-training.html#common-parameters\n",
+"\n",
 "config = DQNConfig()\n",
 "config = config.environment(env=SimpleRLPlayer, env_config=train_env_config)\n",
 "# Set the framework to use: \"tf2\" for Tensorflow, \"torch\" for PyTorch. The dev container is set up for Tensorflow 2.15.\n",
 "config = config.framework(framework=\"tf2\")\n",
 "config = config.resources(\n",
-"    num_cpus_for_main_process=2,\n",
-"    num_gpus=1,\n",
+"    num_cpus_for_main_process=4,\n",
+"    num_gpus=0,\n",
+")\n",
+"config = config.learners(\n",
+"    num_learners=0,\n",
+"    # num_gpus_per_learner=0\n",
 ")\n",
 "config = config.env_runners(\n",
+"    # Number of CPUs assigned to each env_runner. Does not improve sampling speed very much on its own.\n",
 "    num_cpus_per_env_runner=1,\n",
 "    # Number of workers to run environments. 0 forces rollouts onto the local worker.\n",
-"    num_env_runners=20,\n",
-"    num_envs_per_env_runner=1,\n",
+"    num_env_runners=4,\n",
+"    # Number of environments on each env_runner worker; raising this drastically improves sampling speed.\n",
+"    num_envs_per_env_runner=4,\n",
 "    # Don't cut off episodes before they finish when batching.\n",
 "    # As a result, the batch size hyperparameter acts as a minimum and batches may vary in size.\n",
 "    batch_mode=\"complete_episodes\",\n",
 "    # Validation creates environments and does not close them, which causes problems.\n",
 "    # validate_env_runners_after_construction=False,\n",
-"    # rollout_fragment_length=300,\n",
-"    rollout_fragment_length=\"auto\",\n",
+"    rollout_fragment_length=50,\n",
+"    # rollout_fragment_length=\"auto\",\n",
 "    explore=True,\n",
 "    exploration_config={\n",
 "        \"type\": \"EpsilonGreedy\",\n",
@@ -271,8 +279,8 @@
 "        \"capacity\": 100000,\n",
 "    },\n",
 "    num_steps_sampled_before_learning_starts=1000,\n",
-"    # v_min=-48,  # minimum reward\n",
-"    # v_max=48,  # maximum reward\n",
+"    v_min=-48,  # minimum reward\n",
+"    v_max=48,  # maximum reward\n",
 "    # n_step=1,\n",
 "    double_q=False,\n",
 "    # double_q=tune.grid_search([True, False]),\n",
@@ -281,11 +289,11 @@
 "    # noisy=tune.grid_search([True, False]),\n",
 "    dueling=False,\n",
 "    # dueling=tune.grid_search([True, False]),\n",
-"    train_batch_size=300\n",
+"    train_batch_size=1200,\n",
 ")\n",
 "config = config.evaluation(\n",
 "    evaluation_interval=1,\n",
-"    evaluation_num_env_runners=2,\n",
+"    evaluation_num_env_runners=4,\n",
 "    # evaluation_parallel_to_training=True,\n",
 "    evaluation_duration=30,\n",
 "    evaluation_config={\n",
@@ -296,13 +304,22 @@
 "    },\n",
 ")\n",
 "# These settings allow runs to continue after a worker fails for whatever reason.\n",
-"# config = config.fault_tolerance(recreate_failed_env_runners=True)\n",
+"config = config.fault_tolerance(recreate_failed_env_runners=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"## Set stopping criteria for the trials\n",
+"from ray.tune.stopper import CombinedStopper, MaximumIterationStopper, TrialPlateauStopper\n",
 "\n",
-"# This sets the stopping criteria for the run.\n",
-"stop = {\n",
-"    # \"evaluation/env_runners/episode_reward_mean\": 30,\n",
-"    \"training_iteration\": 100,\n",
-"}"
+"stopper = CombinedStopper(\n",
+"    MaximumIterationStopper(max_iter=120),\n",
+"    TrialPlateauStopper(metric=\"evaluation/env_runners/episode_reward_mean\"),\n",
+")"
 ]
 },
 {
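A note on the new stopper cell: TrialPlateauStopper with only a metric uses its defaults for sensitivity. If trials stop too eagerly or too late, the window is tunable; an illustrative variant (the values are hypothetical, the keyword arguments are real ray.tune.stopper parameters):

    from ray.tune.stopper import TrialPlateauStopper

    plateau = TrialPlateauStopper(
        metric="evaluation/env_runners/episode_reward_mean",
        std=0.5,          # stop when the metric's std over the window drops below this (default 0.01)
        num_results=8,    # number of recent results in the window (default 4)
        grace_period=10,  # minimum results before a trial can be stopped (default 4)
    )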
@@ -329,18 +346,18 @@
 "        # scheduler=NoneProvided,  # When using concurrent trials, this ends or changes poorly performing trials early.\n",
 "    ),\n",
 "    run_config=train.RunConfig(\n",
-"        name=\"DQN_SimpleRL_vs_MaxBP\",\n",
+"        name=\"DQN_SimpleRL_v_MaxBP_1\",\n",
 "        storage_path=os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'results')),\n",
-"        stop=stop,\n",
+"        stop=stopper,\n",
 "        checkpoint_config=train.CheckpointConfig(\n",
-"            checkpoint_frequency=10,\n",
+"            checkpoint_frequency=1,\n",
 "            # checkpoint_score_attribute is the metric used to determine which checkpoints to keep.\n",
 "            checkpoint_score_attribute=\"evaluation/env_runners/episode_reward_mean\",\n",
 "            # Only the best num_to_keep checkpoints are saved, using checkpoint_score_attribute as the metric to compare.\n",
 "            num_to_keep=1,\n",
 "            # checkpoint_score_order determines whether a higher (\"max\") or lower (\"min\") checkpoint_score_attribute is better.\n",
 "            checkpoint_score_order=\"max\",\n",
-"            checkpoint_at_end=True\n",
+"            # checkpoint_at_end=True\n",
 "        ),\n",
 "    ),\n",
 "\n",
@@ -376,7 +393,7 @@
 "# If manually loading a checkpoint from a path, you can skip all of the cells above after the SimpleRLPlayer class creation.\n",
 "# The test_checkpoint path should end with the checkpoint_XXXXXX directory, where the X's are the checkpoint number with leading 0s.\n",
 "\n",
-"# test_checkpoint = \"../results/DQN_SimpleRL_vs_MaxBP/DQN_SimpleRLPlayer_4345e_00024_24_lr=0.0002,weight_decay=0.0194_2024-07-17_13-00-24/checkpoint_000000\""
+"# test_checkpoint = \"../results/DQN_SimpleRL_v_MaxBP_1/DQN_SimpleRLPlayer_9fc91_00010_10_train_batch_size=900_2024-07-26_13-50-33/checkpoint_000019\""
 ]
 },
 {
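Once test_checkpoint points at a checkpoint_XXXXXX directory, the algorithm can be rebuilt from it; Algorithm.from_checkpoint is RLlib's documented entry point for this (a sketch only, assuming the checkpoint was produced by this notebook's config):

    from ray.rllib.algorithms.algorithm import Algorithm

    # Rebuilds the DQN (config, weights, and all) from the saved checkpoint directory.
    algo = Algorithm.from_checkpoint(test_checkpoint)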
@@ -525,7 +542,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.0"
+"version": "3.11.0rc1"
 }
 },
 "nbformat": 4,
