-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
633 lines (536 loc) · 21.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
import glob
import math # MODIFIED:
import os
import platform
import random
import re
import sys # MODIFIED:
from collections import deque
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union
import cloudpickle
import gymnasium as gym
import numpy as np
import stable_baselines3 as sb3
import torch as th
from gymnasium import spaces
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None # type: ignore[misc, assignment]
from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import (
GymEnv,
Schedule,
TensorDict,
TrainFreq,
TrainFrequencyUnit,
)
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the Python, NumPy and PyTorch random generators with one value.

    :param seed: value handed to each generator's seeding function
    :param using_cuda: when true, additionally make CuDNN deterministic
        and disable its benchmark mode
    """
    # Python RNG, NumPy RNG, then PyTorch (its ``manual_seed`` seeds the RNG
    # for all devices, both CPU and CUDA).
    for seed_generator in (random.seed, np.random.seed, th.manual_seed):
        seed_generator(seed)

    if not using_cuda:
        return

    # Deterministic operations for CuDNN, it may impact performances
    cudnn = th.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Compute the fraction of the variance of ``y_true`` that ``y_pred`` explains,
    i.e. ``1 - Var[y_true - y_pred] / Var[y_true]``.

    Interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    :param y_pred: the prediction (must be one-dimensional)
    :param y_true: the expected value (must be one-dimensional)
    :return: explained variance of ``y_pred`` with respect to ``y_true``,
        NaN when the variance of ``y_true`` is zero
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1

    # MODIFIED: both variances are clamped to the finite float range before
    # dividing; this replaces the plain ``1 - var / var_y`` expression in order
    # to avoid the overflow warning.
    float_limit = sys.float_info.max
    target_variance = np.clip(np.var(y_true), -float_limit, float_limit)
    residual_variance = np.clip(
        np.var(y_true - y_pred), -float_limit, float_limit
    )

    # The ratio is undefined when the targets do not vary at all.
    if math.isclose(target_variance, 0):
        return np.nan
    return 1 - residual_variance / target_variance
def update_learning_rate(
    optimizer: th.optim.Optimizer, learning_rate: float
) -> None:
    """
    Set the learning rate of every parameter group of an optimizer.
    Useful when doing linear schedule.

    :param optimizer: Pytorch optimizer whose ``param_groups`` are updated in place
    :param learning_rate: New learning rate value
    """
    for group in optimizer.param_groups:
        group.update(lr=learning_rate)
def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
    """
    Turn a learning rate or clip range (for PPO) into a callable, if needed.

    :param value_schedule: Constant value or schedule function
    :return: Schedule function (can return constant value)
    """
    if not isinstance(value_schedule, (float, int)):
        # Anything that is not a plain number must already be a schedule.
        assert callable(value_schedule)
        return value_schedule

    # A number becomes a constant function; cast to float to avoid errors.
    return constant_fn(float(value_schedule))
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
    """
    Create a function that interpolates linearly from ``start`` to ``end``
    while ``progress_remaining`` goes from 1 down to ``1 - end_fraction``,
    and returns ``end`` for the rest of training.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training progress at which ``end``
        is reached, e.g. 0.1 means ``end`` is reached after 10%
        of the complete training process. A value of 0 (or less) yields
        ``end`` for every progress value.
    :return: Linear schedule function.
    """

    def func(progress_remaining: float) -> float:
        elapsed = 1 - progress_remaining
        # ``>=`` (rather than ``>``) returns ``end`` exactly at the boundary;
        # the explicit ``end_fraction <= 0`` guard keeps the division below
        # from ever seeing a zero denominator.
        if end_fraction <= 0 or elapsed >= end_fraction:
            return end
        return start + elapsed * (end - start) / end_fraction

    return func
def constant_fn(val: float) -> Schedule:
    """
    Build a schedule that ignores its argument and always returns ``val``.
    It is useful for learning rate schedule (to avoid code duplication).

    :param val: constant value
    :return: Constant schedule function.
    """

    def func(_progress_remaining):
        # The argument is deliberately unused.
        return val

    return func
def get_device(device: Union[th.device, str] = "auto") -> th.device:
    """
    Retrieve PyTorch device.
    ``"auto"`` is treated as a request for cuda; a cuda request falls back
    to cpu when ``th.cuda.is_available()`` is false.
    For now, it supports only cpu and cuda.

    :param device: One of 'auto', 'cuda', 'cpu' (or a ``th.device``)
    :return: Supported Pytorch device
    """
    # Cuda by default, then force conversion to th.device
    requested = th.device("cuda" if device == "auto" else device)

    wants_cuda = requested.type == th.device("cuda").type
    if wants_cuda and not th.cuda.is_available():
        # Cuda not available
        return th.device("cpu")
    return requested
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the directories.

    :param log_path: Path to the log folder containing several runs.
    :param log_name: Name of the experiment. Each run is stored
        in a folder named ``log_name_1``, ``log_name_2``, ...
    :return: latest run number, 0 when no matching folder exists
    """
    # Both user-supplied parts are escaped so that glob metacharacters in the
    # folder path or in the experiment name (e.g. ``logs[1]``) are matched
    # literally; only the trailing ``_[0-9]*`` is a real pattern.
    pattern = os.path.join(
        glob.escape(os.fspath(log_path)), f"{glob.escape(log_name)}_[0-9]*"
    )

    max_run_id = 0
    for path in glob.glob(pattern):
        # basename copes with whichever separator the platform/glob produced.
        folder_name = os.path.basename(path)
        name, _, run_suffix = folder_name.rpartition("_")
        # The glob only guarantees one leading digit, hence the isdigit check.
        if name == log_name and run_suffix.isdigit():
            max_run_id = max(max_run_id, int(run_suffix))
    return max_run_id
def configure_logger(
    verbose: int = 0,
    tensorboard_log: Optional[str] = None,
    tb_log_name: str = "",
    reset_num_timesteps: bool = True,
) -> Logger:
    """
    Configure the logger's outputs.

    :param verbose: Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param tb_log_name: tensorboard log
    :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
        It allows to continue a previous learning curve (``reset_num_timesteps=False``)
        or start from t=0 (``reset_num_timesteps=True``, the default).
    :return: The logger object
    :raises ImportError: if a tensorboard location is given but
        ``SummaryWriter`` could not be imported
    """
    if tensorboard_log is None:
        # No tensorboard: stdout only, or the empty format when verbose is 0.
        plain_formats = [""] if verbose == 0 else ["stdout"]
        return configure(None, format_strings=plain_formats)

    if SummaryWriter is None:
        raise ImportError(
            "Trying to log data to tensorboard but tensorboard is not installed."
        )

    run_id = get_latest_run_id(tensorboard_log, tb_log_name)
    if not reset_num_timesteps:
        # Continue training in the same directory
        run_id -= 1
    save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{run_id + 1}")

    tb_formats = ["stdout", "tensorboard"] if verbose >= 1 else ["tensorboard"]
    return configure(save_path, format_strings=tb_formats)
def check_for_correct_spaces(
    env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space
) -> None:
    """
    Verify that an environment exposes the given observation and action spaces.
    Used by BaseAlgorithm to check if spaces match after loading the model with given env.

    Checked parameters:
    - observation_space
    - action_space

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    :raises ValueError: if either space differs from the environment's one
    """
    env_observation_space = env.observation_space
    if observation_space != env_observation_space:
        raise ValueError(
            f"Observation spaces do not match: {observation_space} != {env_observation_space}"
        )

    env_action_space = env.action_space
    if action_space != env_action_space:
        raise ValueError(
            f"Action spaces do not match: {action_space} != {env_action_space}"
        )
def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
    """
    Assert that two spaces have matching shapes.
    Box spaces must share the same shape; Dict spaces must share the same
    keys and are checked recursively, subspace by subspace.
    Any other type of ``space1`` is accepted without a check.

    :param space1: Space
    :param space2: Other space
    """
    if isinstance(space1, spaces.Dict):
        assert isinstance(
            space2, spaces.Dict
        ), "spaces must be of the same type"
        assert (
            space1.spaces.keys() == space2.spaces.keys()
        ), "spaces must have the same keys"
        for key, subspace in space1.spaces.items():
            check_shape_equal(subspace, space2.spaces[key])
        return

    if isinstance(space1, spaces.Box):
        assert space1.shape == space2.shape, "spaces must have the same shape"
def is_vectorized_box_observation(
    observation: np.ndarray, observation_space: spaces.Box
) -> bool:
    """
    For box observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False when the shapes are equal, True when the observation has
        one extra leading dimension
    :raises ValueError: for any other observation shape
    """
    obs_shape, space_shape = observation.shape, observation_space.shape
    if obs_shape == space_shape:
        return False
    if obs_shape[1:] == space_shape:
        return True

    # Spell out the batched shape the caller could have used.
    space_dims = ", ".join(map(str, space_shape))
    raise ValueError(
        f"Error: Unexpected observation shape {obs_shape} for "
        f"Box environment, please use {space_shape} "
        f"or (n_env, {space_dims}) for the observation shape."
    )
def is_vectorized_discrete_observation(
    observation: Union[int, np.ndarray], observation_space: spaces.Discrete
) -> bool:
    """
    For discrete observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space (not inspected here)
    :return: False for a Python int or an array of shape ``()``,
        True for a one-dimensional array
    :raises ValueError: for any other observation shape
    """
    # A numpy array of a number has the empty tuple '()' as shape.
    is_scalar = isinstance(observation, int) or observation.shape == ()
    if is_scalar:
        return False
    if len(observation.shape) == 1:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {observation.shape} for "
        "Discrete environment, please use () or (n_env,) for the observation shape."
    )
def is_vectorized_multidiscrete_observation(
    observation: np.ndarray, observation_space: spaces.MultiDiscrete
) -> bool:
    """
    For multidiscrete observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False for shape ``(len(nvec),)``, True for a two-dimensional
        array whose second dimension is ``len(nvec)``
    :raises ValueError: for any other observation shape
    """
    n_components = len(observation_space.nvec)
    obs_shape = observation.shape

    if obs_shape == (n_components,):
        return False
    if len(obs_shape) == 2 and obs_shape[1] == n_components:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {obs_shape} for MultiDiscrete "
        f"environment, please use ({n_components},) or "
        f"(n_env, {n_components}) for the observation shape."
    )
def is_vectorized_multibinary_observation(
    observation: np.ndarray, observation_space: spaces.MultiBinary
) -> bool:
    """
    For multibinary observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False when the shapes are equal, True when the observation has
        exactly one extra leading dimension
    :raises ValueError: for any other observation shape
    """
    obs_shape, space_shape = observation.shape, observation_space.shape
    if obs_shape == space_shape:
        return False
    if len(obs_shape) == len(space_shape) + 1 and obs_shape[1:] == space_shape:
        return True

    # The batched-shape hint is built from the space's shape (as the Box
    # variant does) instead of ``observation_space.n``: for multi-dimensional
    # spaces ``n`` is itself a sequence and would render as a nested shape.
    space_dims = ", ".join(map(str, space_shape))
    raise ValueError(
        f"Error: Unexpected observation shape {obs_shape} for MultiBinary "
        f"environment, please use {space_shape} or "
        f"(n_env, {space_dims}) for the observation shape."
    )
def is_vectorized_dict_observation(
    observation: np.ndarray, observation_space: spaces.Dict
) -> bool:
    """
    For dict observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate, indexed by the
        keys of ``observation_space.spaces``
    :param observation_space: the observation space
    :return: False when every entry has exactly its subspace's shape,
        True when every entry has its subspace's shape after the first dimension
    :raises ValueError: when neither of the above holds
    """
    subspaces = observation_space.spaces

    # First possibility: no entry is vectorized.
    if all(
        observation[name].shape == subspace.shape
        for name, subspace in subspaces.items()
    ):
        return False

    # Second possibility: every entry is vectorized with the correct shape.
    # Look for the first entry that contradicts this.
    not_found = object()
    key = next(
        (
            name
            for name, subspace in subspaces.items()
            if observation[name].shape[1:] != subspace.shape
        ),
        not_found,
    )
    if key is not_found:
        return True

    # Retrieve the per-space error message for the offending entry, if any.
    error_msg = ""
    try:
        is_vectorized_observation(observation[key], subspaces[key])
    except ValueError as e:
        error_msg = f"{e}"
    raise ValueError(
        f"There seems to be a mix of vectorized and non-vectorized observations. "
        f"Unexpected observation shape {observation[key].shape} for key {key} "
        f"of type {subspaces[key]}. {error_msg}"
    )
def is_vectorized_observation(
    observation: Union[int, np.ndarray], observation_space: spaces.Space
) -> bool:
    """
    For every observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.
    The work is delegated to the checker of the first matching space type.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: if the space is of none of the handled types
    """
    # (space type, checker) pairs, tried in this order.
    checkers = (
        (spaces.Box, is_vectorized_box_observation),
        (spaces.Discrete, is_vectorized_discrete_observation),
        (spaces.MultiDiscrete, is_vectorized_multidiscrete_observation),
        (spaces.MultiBinary, is_vectorized_multibinary_observation),
        (spaces.Dict, is_vectorized_dict_observation),
    )
    for space_type, checker in checkers:
        if isinstance(observation_space, space_type):
            return checker(observation, observation_space)  # type: ignore[operator]

    # Reached only when no space type matched.
    raise ValueError(
        f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}."
    )
def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
    """
    Compute the mean of an array if there is at least one element.
    For empty array, return NaN. It is used for logging only.

    :param arr: Numpy array or list of values
    :return: the mean as a Python float, or NaN for an empty input
    """
    if len(arr) == 0:
        return np.nan
    return float(np.mean(arr))  # type: ignore[arg-type]
def get_parameters_by_name(
    model: th.nn.Module, included_names: Iterable[str]
) -> List[th.Tensor]:
    """
    Extract parameters from the state dict of ``model``
    if the name contains one of the strings in ``included_names``.

    :param model: the model where the parameters come from.
    :param included_names: substrings of names to include.
        Any iterable is accepted, including a one-shot iterator.
    :return: List of parameters values (Pytorch tensors)
        that matches the queried names.
    """
    # The substrings are scanned once per state-dict entry, so they are
    # materialized first: a generator/iterator argument would otherwise be
    # exhausted after the first entry and every later match would be missed.
    wanted = tuple(included_names)
    return [
        param
        for name, param in model.state_dict().items()
        if any(key in name for key in wanted)
    ]
def zip_strict(*iterables: Iterable) -> Iterable:
    r"""
    ``zip()`` function but enforces that iterables are of equal length.
    Raises ``ValueError`` if iterables not of equal length.
    Code inspired by Stackoverflow answer for question #32954486.

    :param \*iterables: iterables to ``zip()``
    :return: generator yielding one tuple per position, like ``zip()``
    :raises ValueError: as soon as one iterable runs out before the others
    """
    # As in Stackoverflow #32954486, use
    # new object for "empty" in case we have
    # Nones in iterable.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        # Identity test on purpose: ``sentinel in combo`` would fall back to
        # ``==`` on every element, which misbehaves for elements whose
        # equality raises or does not return a plain boolean.
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo
def polyak_update(
    params: Iterable[th.Tensor],
    target_params: Iterable[th.Tensor],
    tau: float,
) -> None:
    """
    Perform a Polyak average update on ``target_params`` using ``params``:
    target parameters are slowly updated towards the main parameters.
    ``tau``, the soft update coefficient controls the interpolation:
    ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
    The update is done in place and under ``no_grad``: each target tensor is
    first scaled by ``1 - tau``, then ``tau`` times the matching source tensor
    is added, with the sum written back into the target tensor.
    See https://github.com/DLR-RM/stable-baselines3/issues/93

    :param params: parameters to use to update the target params
    :param target_params: parameters to update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    """
    keep = 1 - tau
    with th.no_grad():
        # zip_strict is used because zip does not raise an exception
        # if length of parameters does not match.
        for source, target in zip_strict(params, target_params):
            target_data = target.data
            target_data.mul_(keep)
            th.add(target_data, source.data, alpha=tau, out=target_data)
def obs_as_tensor(
    obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device
) -> Union[th.Tensor, TensorDict]:
    """
    Moves the observation to the given device.

    :param obs: a numpy array, or a dict mapping keys to numpy arrays
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on a desired device
        (a dict of tensors with the same keys for a dict observation).
    :raises TypeError: if ``obs`` is neither a numpy array nor a dict
    """
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    if isinstance(obs, dict):
        return {
            key: th.as_tensor(_obs, device=device)
            for (key, _obs) in obs.items()
        }
    # TypeError is the precise category for a wrong argument type; it is a
    # subclass of Exception, so handlers catching ``Exception`` still fire.
    raise TypeError(f"Unrecognized type of observation {type(obs)}")
def should_collect_more_steps(
    train_freq: TrainFreq,
    num_collected_steps: int,
    num_collected_episodes: int,
) -> bool:
    """
    Helper used in ``collect_rollouts()`` of off-policy algorithms
    to determine the termination condition.

    :param train_freq: How much experience should be collected before updating the policy.
    :param num_collected_steps: The number of already collected steps.
    :param num_collected_episodes: The number of already collected episodes.
    :return: Whether to continue or not collecting experience
        by doing rollouts of the current policy.
    :raises ValueError: if the unit of ``train_freq`` is neither STEP nor EPISODE
    """
    unit = train_freq.unit

    # Compare the counter that corresponds to the unit with the frequency.
    if unit == TrainFrequencyUnit.STEP:
        collected = num_collected_steps
    elif unit == TrainFrequencyUnit.EPISODE:
        collected = num_collected_episodes
    else:
        raise ValueError(
            "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
            f"or TrainFrequencyUnit.EPISODE not '{unit}'!"
        )
    return collected < train_freq.frequency
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
    """
    Retrieve system and python env info for the current system.

    :param print_info: Whether to print or not those infos
    :return: Dictionary summing up the version for each relevant package
        and a formatted string (one ``- key: value`` line per entry).
    """
    # In OS, a regex is used to add a space between a "#" and a number to avoid
    # wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42".
    os_description = re.sub(
        r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"
    )
    env_info = {
        "OS": os_description,
        "Python": platform.python_version(),
        "Stable-Baselines3": sb3.__version__,
        "PyTorch": th.__version__,
        "GPU Enabled": str(th.cuda.is_available()),
        "Numpy": np.__version__,
        "Cloudpickle": cloudpickle.__version__,
        "Gymnasium": gym.__version__,
    }

    # The ``gym`` package is optional: report it only when it can be imported.
    try:
        import gym as openai_gym

        env_info["OpenAI Gym"] = openai_gym.__version__
    except ImportError:
        pass

    env_info_str = "".join(
        f"- {key}: {value}\n" for key, value in env_info.items()
    )
    if print_info:
        print(env_info_str)
    return env_info, env_info_str