|
198 | 198 | "from ray.rllib.algorithms.dqn import DQNConfig\n",
|
199 | 199 | "from ray import tune, train\n",
|
200 | 200 | "import os\n",
|
| 201 | + "import ray\n", |
201 | 202 | "\n",
|
202 | 203 | "# This is passed to each environment (SimpleRLPlayer) during training.\n",
|
203 | 204 | "# 'player_config' is passed as a kwarg to the super().__init__() of SimpleRLPlayer's Gen9EnvSinglePlayer superclass.\n",
|
|
208 | 209 | " 'start_challenging': True,\n",
|
209 | 210 | " },\n",
|
210 | 211 | " 'opponent_class': MaxBasePowerPlayer,\n",
|
211 | | - " 'opponent_username': 'tr_MaxBasePower',\n", |
| 212 | + " 'opponent_username': 'tr_MaxBP',\n", |
212 | 213 | " 'opponent_config': {\n",
|
213 | 214 | " 'battle_format': \"gen9randombattle\",\n",
|
214 | 215 | " },\n",
|
|
221 | 222 | " 'start_challenging': True,\n",
|
222 | 223 | " },\n",
|
223 | 224 | " 'opponent_class': MaxBasePowerPlayer,\n",
|
224 | | - " 'opponent_username': 'ev_MaxBasePower',\n", |
| 225 | + " 'opponent_username': 'ev_MaxBP',\n", |
225 | 226 | " 'opponent_config': {\n",
|
226 | 227 | " 'battle_format': \"gen9randombattle\",\n",
|
227 | 228 | " },\n",
|
228 | 229 | "}\n",
|
229 | 230 | "\n",
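For orientation, below is a minimal sketch of how SimpleRLPlayer might unpack these keys. The real class is defined in an earlier cell, so the constructor body, the AccountConfiguration usage, and the opponent wiring here are illustrative assumptions rather than the notebook's actual code.

# Hypothetical sketch only -- the actual SimpleRLPlayer lives in an earlier cell.
from poke_env import AccountConfiguration
from poke_env.player import Gen9EnvSinglePlayer

class SimpleRLPlayerSketch(Gen9EnvSinglePlayer):
    def __init__(self, config):
        # Instantiate the opponent from the keys in the config dictionaries above.
        opponent = config["opponent_class"](
            account_configuration=AccountConfiguration(config["opponent_username"], None),
            **config["opponent_config"],
        )
        # 'player_config' is forwarded as kwargs to the Gen9EnvSinglePlayer superclass.
        super().__init__(opponent=opponent, **config["player_config"])

RLlib passes env_config to each environment instance as a dict-like EnvContext, which is why the keys can be read directly in the constructor.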
|
230 | | - "# Guide to RLLib parameters: https://docs.ray.io/en/latest/rllib/rllib-training.html#common-parameters \n", |
| 231 | + "# Guide to RLLib parameters: https://docs.ray.io/en/latest/rllib/rllib-training.html#common-parameters\n", |
| 232 | + "\n", |
231 | 233 | "config = DQNConfig()\n",
|
232 | 234 | "config = config.environment(env = SimpleRLPlayer, env_config = train_env_config)\n",
|
233 | 235 | "# Set the framework to use: \"tf2\" for TensorFlow, \"torch\" for PyTorch. The dev container is set up for TensorFlow 2.13.\n",
|
234 | 236 | "config = config.framework(framework=\"tf2\")\n",
|
235 | 237 | "config = config.resources(\n",
|
236 | | - " num_cpus_for_main_process=2,\n", |
237 | | - " num_gpus=1,\n", |
| 238 | + " num_cpus_for_main_process=4,\n", |
| 239 | + " num_gpus=0,\n", |
| 240 | + ")\n", |
| 241 | + "config = config.learners(\n", |
| 242 | + " num_learners=0,\n", |
| 243 | + " # num_gpus_per_learner=0\n", |
238 | 244 | ")\n",
|
239 | 245 | "config = config.env_runners(\n",
|
| 246 | + " # Number of CPUs assigned to each env_runner. On its own, this does not improve sampling speed much.\n", |
240 | 247 | " num_cpus_per_env_runner=1,\n",
|
241 | 248 | " # Number of workers to run environments. 0 forces rollouts onto the local worker.\n",
|
242 | | - " num_env_runners=20,\n", |
243 | | - " num_envs_per_env_runner=1,\n", |
| 249 | + " num_env_runners=4,\n", |
| 250 | + " # Number of environments on each env_runner worker; a higher value drastically improves sampling speed.\n", |
| 251 | + " num_envs_per_env_runner=4,\n", |
244 | 252 | " # Don't cut off episodes before they finish when batching.\n",
|
245 | 253 | " # As a result, the batch size hyperparameter acts as a minimum and batches may vary in size.\n",
|
246 | 254 | " batch_mode=\"complete_episodes\",\n",
|
247 | 255 | " # Validation creates environments and does not close them, which causes problems.\n",
|
248 | 256 | " # validate_env_runners_after_construction=False,\n",
|
249 | | - " # rollout_fragment_length=300,\n", |
250 | | - " rollout_fragment_length=\"auto\",\n", |
| 257 | + " rollout_fragment_length=50,\n", |
| 258 | + " # rollout_fragment_length=\"auto\",\n", |
251 | 259 | " explore=True,\n",
|
252 | 260 | " exploration_config = {\n",
|
253 | 261 | " \"type\": \"EpsilonGreedy\",\n",
|
|
271 | 279 | " \"capacity\": 100000,\n",
|
272 | 280 | " },\n",
|
273 | 281 | " num_steps_sampled_before_learning_starts=1000,\n",
|
274 | | - " # v_min=-48, # minimum reward\n", |
275 | | - " # v_max=48, # maximum reward\n", |
| 282 | + " v_min=-48, # minimum reward\n", |
| 283 | + " v_max=48, # maximum reward\n", |
276 | 284 | " # n_step=1,\n",
|
277 | 285 | " double_q=False,\n",
|
278 | 286 | " # double_q=tune.grid_search([True, False]),\n",
|
|
281 | 289 | " # noisy=tune.grid_search([True, False]),\n",
|
282 | 290 | " dueling=False,\n",
|
283 | 291 | " # dueling=tune.grid_search([True, False]),\n",
|
284 | | - " train_batch_size=300\n", |
| 292 | + " train_batch_size=1200,\n", |
285 | 293 | ")\n",
|
286 | 294 | "config = config.evaluation(\n",
|
287 | 295 | " evaluation_interval=1,\n",
|
288 | | - " evaluation_num_env_runners=2,\n", |
| 296 | + " evaluation_num_env_runners=4,\n", |
289 | 297 | " # evaluation_parallel_to_training=True,\n",
|
290 | 298 | " evaluation_duration=30,\n",
|
291 | 299 | " evaluation_config={\n",
|
|
296 | 304 | " },\n",
|
297 | 305 | ")\n",
|
298 | 306 | "# This setting allows runs to continue after a worker fails for whatever reason.\n",
|
299 | | - "# config = config.fault_tolerance(recreate_failed_env_runners=True)\n", |
| 307 | + "config = config.fault_tolerance(recreate_failed_env_runners=True)" |
| 308 | + ] |
| 309 | + }, |
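
As a rough resource estimate under these settings (assuming the evaluation env runners also get one CPU each), a single trial requests about 4 + 4*1 + 4*1 = 12 CPUs for the main process, the training env runners, and the evaluation env runners, while sampling from 4 * 4 = 16 training environments in parallel.
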
| 310 | + { |
| 311 | + "cell_type": "code", |
| 312 | + "execution_count": null, |
| 313 | + "metadata": {}, |
| 314 | + "outputs": [], |
| 315 | + "source": [ |
| 316 | + "## Set stopping criteria for the trials\n", |
| 317 | + "from ray.tune.stopper import CombinedStopper, MaximumIterationStopper, TrialPlateauStopper\n", |
300 | 318 | "\n",
|
301 | | - "# This sets the stopping criteria for the run.\n", |
302 | | - "stop = {\n", |
303 | | - " # \"evaluation/env_runners/episode_reward_mean\": 30,\n", |
304 | | - " \"training_iteration\": 100,\n", |
305 | | - "}" |
| 319 | + "stopper = CombinedStopper(\n", |
| 320 | + " MaximumIterationStopper(max_iter=120),\n", |
| 321 | + " TrialPlateauStopper(metric=\"evaluation/env_runners/episode_reward_mean\"),\n", |
| 322 | + ")" |
306 | 323 | ]
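If the plateau criterion fires too early with its defaults, TrialPlateauStopper also accepts tolerance arguments; the values below are illustrative only, not tuned for this task.

# Illustrative only: widen the plateau window and tolerance before stopping a trial.
stopper = CombinedStopper(
    MaximumIterationStopper(max_iter=120),
    TrialPlateauStopper(
        metric="evaluation/env_runners/episode_reward_mean",
        num_results=8,    # compare the last 8 reported results
        std=0.5,          # treat a standard deviation below 0.5 as a plateau
        grace_period=10,  # never stop before 10 results have been reported
    ),
)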
|
307 | 324 | },
|
308 | 325 | {
|
|
329 | 346 | " # scheduler= NoneProvided, # When using concurrent trials, this ends or changes poorly performing trials early.\n",
|
330 | 347 | " ),\n",
|
331 | 348 | " run_config=train.RunConfig(\n",
|
332 | | - " name=\"DQN_SimpleRL_vs_MaxBP\",\n", |
| 349 | + " name=\"DQN_SimpleRL_v_MaxBP_1\",\n", |
333 | 350 | " storage_path=os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'results')),\n",
|
334 | | - " stop=stop,\n", |
| 351 | + " stop=stopper,\n", |
335 | 352 | " checkpoint_config=train.CheckpointConfig(\n",
|
336 | | - " checkpoint_frequency=10,\n", |
| 353 | + " checkpoint_frequency=1,\n", |
337 | 354 | " # checkpoint_score_attribute is the metric to use to determine which checkpoints to keep.\n",
|
338 | 355 | " checkpoint_score_attribute=\"evaluation/env_runners/episode_reward_mean\",\n",
|
339 | 356 | " # Only the best num_to_keep checkpoints are saved, using checkpoint_score_attribute as the metric to compare.\n",
|
340 | 357 | " num_to_keep=1,\n",
|
341 | 358 | " # checkpoint_score_order determines whether a higher (\"max\") or lower (\"min\") checkpoint_score_attribute is better.\n",
|
342 | 359 | " checkpoint_score_order=\"max\",\n",
|
343 | | - " checkpoint_at_end=True\n", |
| 360 | + " # checkpoint_at_end=True\n", |
344 | 361 | " ),\n",
|
345 | 362 | " ),\n",
|
346 | 363 | "\n",
|
|
376 | 393 | "# If manually loading a checkpoint from a path, you can skip every cell above this one that comes after the SimpleRLPlayer class definition.\n",
|
377 | 394 | "# The test_checkpoint path should end with the checkpoint_XXXXXX directory, where XXXXXX is the zero-padded checkpoint number.\n",
|
378 | 395 | "\n",
|
379 | | - "# test_checkpoint = \"../results/DQN_SimpleRL_vs_MaxBP/DQN_SimpleRLPlayer_4345e_00024_24_lr=0.0002,weight_decay=0.0194_2024-07-17_13-00-24/checkpoint_000000\"" |
| 396 | + "# test_checkpoint = \"../results/DQN_SimpleRL_v_MaxBP_1/DQN_SimpleRLPlayer_9fc91_00010_10_train_batch_size=900_2024-07-26_13-50-33/checkpoint_000019\"" |
380 | 397 | ]
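As a hedged sketch, a checkpoint directory like the one referenced above is typically restored with Algorithm.from_checkpoint; the helper name and placeholder path below are hypothetical, not the notebook's own code.

# Sketch: restore a trained algorithm from a saved checkpoint directory.
from ray.rllib.algorithms.algorithm import Algorithm

def load_algorithm(checkpoint_path: str):
    # checkpoint_path should point at the checkpoint_XXXXXX directory described above.
    return Algorithm.from_checkpoint(checkpoint_path)

# algo = load_algorithm(test_checkpoint)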
|
381 | 398 | },
|
382 | 399 | {
|
|
525 | 542 | "name": "python",
|
526 | 543 | "nbconvert_exporter": "python",
|
527 | 544 | "pygments_lexer": "ipython3",
|
528 | | - "version": "3.11.0" |
| 545 | + "version": "3.11.0rc1" |
529 | 546 | }
|
530 | 547 | },
|
531 | 548 | "nbformat": 4,
|
|