diff --git a/.github/unittest/helpers/coverage_run_parallel.py b/.github/unittest/helpers/coverage_run_parallel.py
index 8c6251cf82b..9b97b848f53 100644
--- a/.github/unittest/helpers/coverage_run_parallel.py
+++ b/.github/unittest/helpers/coverage_run_parallel.py
@@ -28,8 +28,8 @@ def write_config(config_path: Path, argv: List[str]) -> None:
         argv: Arguments passed to this script, which need to be converted to config file entries
     """
     assert not config_path.exists(), "Temporary coverage config exists already"
-    cmdline = " ".join(shlex.quote(arg) for arg in argv[1:])
-    with open(str(config_path), "wt", encoding="utf-8") as fh:
+    cmdline = shlex.join(argv[1:])
+    with open(str(config_path), "w", encoding="utf-8") as fh:
         fh.write(
             f"""# .coveragerc to control coverage.py
 [run]
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 7b710a751fe..25ac1c49e43 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -35,7 +35,7 @@ jobs:
         echo '::endgroup::'

         echo '::group::Install lint tools'
-        pip install --progress-bar=off pre-commit
+        pip install --progress-bar=off pre-commit autoflake
         echo '::endgroup::'

         echo '::group::Lint Python source and configs'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 532445125aa..37adaef7979 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,3 +35,17 @@ repos:
     hooks:
       - id: pydocstyle
         files: ^torchrl/
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.9.0
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
+
+  - repo: local
+    hooks:
+      - id: autoflake
+        name: autoflake
+        entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports
+        language: system
+        types: [python]
diff --git a/build_tools/setup_helpers/__init__.py b/build_tools/setup_helpers/__init__.py
index 6c424ebba14..52c1db79251 100644
--- a/build_tools/setup_helpers/__init__.py
+++ b/build_tools/setup_helpers/__init__.py
@@ -3,4 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .extension import CMakeBuild, get_ext_modules  # noqa
+from .extension import CMakeBuild, get_ext_modules
+
+__all__ = ["CMakeBuild", "get_ext_modules"]
diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py
index 6e950caa237..82df53d3af0 100644
--- a/build_tools/setup_helpers/extension.py
+++ b/build_tools/setup_helpers/extension.py
@@ -14,7 +14,6 @@
 from setuptools import Extension
 from setuptools.command.build_ext import build_ext

-
 _THIS_DIR = Path(__file__).parent.resolve()
 _ROOT_DIR = _THIS_DIR.parent.parent.resolve()
 _TORCHRL_DIR = _ROOT_DIR / "torchrl"
@@ -130,7 +129,7 @@ def build_extension(self, ext):
         # using -j in the build_ext call, not supported by pip or PyPA-build.
         if hasattr(self, "parallel") and self.parallel:
             # CMake 3.12+ only.
-            build_args += ["-j{}".format(self.parallel)]
+            build_args += [f"-j{self.parallel}"]

         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index b53ac84585d..71bc2b2219f 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -1220,7 +1220,7 @@ Recorders are transforms that register data as they come in, for logging purpose
 Helpers
 -------

-.. currentmodule:: torchrl.envs.utils
+.. currentmodule:: torchrl.envs

 .. autosummary::
    :toctree: generated/
diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index 3d88536e3d9..f2741809bd3 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -111,6 +111,7 @@ auto-completion to make their choice.
     :template: rl_template_noinherit.rst

     LossModule
+    add_random_module

 DQN
 ---
diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py
index 3de34d55166..b5be188fbd9 100644
--- a/examples/rlhf/models/actor_critic.py
+++ b/examples/rlhf/models/actor_critic.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
 from torchrl.modules.tensordict_module.common import VmapModule
diff --git a/setup.cfg b/setup.cfg
index 985c68e5af9..0649a97497f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,3 +45,7 @@ ignore-decorators = test_*
 ; test/*.py
 ; .circleci/*
+
+[autoflake]
+per-file-ignores =
+    torchrl/trainers/helpers/envs.py *
diff --git a/setup.py b/setup.py
index aebbf09037a..252b67e3187 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@ def get_version():
     version_txt = os.path.join(cwd, "version.txt")
-    with open(version_txt, "r") as f:
+    with open(version_txt) as f:
         version = f.readline().strip()
     if os.getenv("TORCHRL_BUILD_VERSION"):
         version = os.getenv("TORCHRL_BUILD_VERSION")
@@ -64,8 +64,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
 def write_version_file(version):
     version_path = os.path.join(cwd, "torchrl", "version.py")
     with open(version_path, "w") as f:
-        f.write("__version__ = '{}'\n".format(version))
-        f.write("git_version = {}\n".format(repr(sha)))
+        f.write(f"__version__ = '{version}'\n")
+        f.write(f"git_version = {repr(sha)}\n")


 def _get_pytorch_version(is_nightly, is_local):
@@ -185,7 +185,7 @@ def _main(argv):
     version = get_version()
     write_version_file(version)
     TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
-    logging.info("Building wheel {}-{}".format(package_name, version))
+    logging.info(f"Building wheel {package_name}-{version}")
     logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")

     is_local = TORCHRL_BUILD_VERSION is None
diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 3279d6e0a2b..4d12a75ea0f 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -13,7 +13,7 @@


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     from copy import deepcopy
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 41e05dc1326..d07ee6621af 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -13,7 +13,7 @@


 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     from copy import deepcopy
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index 2e1a20ad7a2..fc388399878 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -15,16 +15,13 @@
 import hydra
 import numpy as np
-
 import torch
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -39,7 +36,7 @@


 @hydra.main(config_path="", config_name="offline_config", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     # Create logger
     exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
     logger = None
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index e992bdb5939..5d25a34ba10 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -21,12 +21,10 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -42,7 +40,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="online_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     # Create logger
     exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
     logger = None
diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py
index d45ce3745fe..2f7441ee4eb 100644
--- a/sota-implementations/cql/discrete_cql_online.py
+++ b/sota-implementations/cql/discrete_cql_online.py
@@ -16,16 +16,12 @@
 import hydra
 import numpy as np
-
 import torch
 import torch.cuda
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     log_metrics,
@@ -41,7 +37,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.optim.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py
index d84613e6876..619f2395fb1 100644
--- a/sota-implementations/crossq/crossq.py
+++ b/sota-implementations/crossq/crossq.py
@@ -15,19 +15,15 @@
 import warnings

 import hydra
-
 import numpy as np
-
 import torch
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     log_metrics,
@@ -43,7 +39,7 @@


 @hydra.main(version_base="1.1", config_path=".", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py
index bcb7ee6ef54..5b6d308aba2 100644
--- a/sota-implementations/ddpg/ddpg.py
+++ b/sota-implementations/ddpg/ddpg.py
@@ -15,16 +15,13 @@
 import warnings

 import hydra
-
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -41,7 +38,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.optim.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py
index 9e8446ed82f..f565aafeafc 100644
--- a/sota-implementations/decision_transformer/dt.py
+++ b/sota-implementations/decision_transformer/dt.py
@@ -19,11 +19,9 @@
 from tensordict.nn import CudaGraphModule
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.libs.gym import set_gym_backend
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 from torchrl.record import VideoRecorder
-
 from utils import (
     dump_video,
     log_metrics,
@@ -37,7 +35,7 @@


 @hydra.main(config_path="", config_name="dt_config", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     model_device = cfg.optim.device
diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py
index 1404cb7ebc0..baab8bbb9a6 100644
--- a/sota-implementations/decision_transformer/online_dt.py
+++ b/sota-implementations/decision_transformer/online_dt.py
@@ -20,7 +20,6 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 from torchrl.record import VideoRecorder
-
 from utils import (
     dump_video,
     log_metrics,
@@ -34,7 +33,7 @@


 @hydra.main(config_path="", config_name="odt_config", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     model_device = cfg.optim.device
diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py
index 9ff50902887..1c97163b95a 100644
--- a/sota-implementations/discrete_sac/discrete_sac.py
+++ b/sota-implementations/discrete_sac/discrete_sac.py
@@ -38,7 +38,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index 786e5d2ebb0..c2bffd91869 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -18,7 +18,6 @@
 import tqdm
 from tensordict.nn import CudaGraphModule, TensorDictSequential
 from torchrl._utils import timeit
-
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.envs import ExplorationType, set_exploration_type
@@ -32,7 +31,7 @@


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     device = cfg.device
     if device in ("", None):
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index 4fde452fba9..87be7fd603a 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -11,7 +11,6 @@
 import torch.nn
 import torch.optim
 import tqdm
-
 from tensordict.nn import CudaGraphModule, TensorDictSequential
 from torchrl._utils import timeit
 from torchrl.collectors import SyncDataCollector
@@ -27,7 +26,7 @@


 @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     device = cfg.device
     if device in ("", None):
diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py
index a197796e978..a39c8904916 100644
--- a/sota-implementations/dreamer/dreamer.py
+++ b/sota-implementations/dreamer/dreamer.py
@@ -11,6 +11,7 @@
 import torch
 import torch.cuda
 import tqdm
+
 from dreamer_utils import (
     _default_device,
     dump_video,
@@ -27,7 +28,6 @@
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import RSSMRollout
-
 from torchrl.objectives.dreamer import (
     DreamerActorLoss,
     DreamerModelLoss,
@@ -37,7 +37,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     # cfg = correct_for_frame_skip(cfg)
     device = _default_device(cfg.networks.device)
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
index bdb8843aaf6..c7fa393a2bd 100644
--- a/sota-implementations/gail/gail.py
+++ b/sota-implementations/gail/gail.py
@@ -17,16 +17,13 @@
 import numpy as np
 import torch
 import tqdm
-
 from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
 from ppo_utils import eval_model, make_env, make_ppo_models
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
@@ -34,12 +31,11 @@
 from torchrl.record import VideoRecorder
 from torchrl.record.loggers import generate_exp_name, get_logger

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     device = cfg.gail.device
diff --git a/sota-implementations/gail/gail_utils.py b/sota-implementations/gail/gail_utils.py
index ce09292cc47..328a0864f31 100644
--- a/sota-implementations/gail/gail_utils.py
+++ b/sota-implementations/gail/gail_utils.py
@@ -6,11 +6,9 @@
 import torch.nn as nn
 import torch.optim
-
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
 from torchrl.data.replay_buffers import SamplerWithoutReplacement
 from torchrl.envs import DoubleToFloat
-
 from torchrl.modules import SafeModule


@@ -45,7 +43,7 @@ def make_gail_discriminator(cfg, train_env, device="cpu"):
     # Define Discriminator Network
     class Discriminator(nn.Module):
         def __init__(self, state_dim, action_dim):
-            super(Discriminator, self).__init__()
+            super().__init__()
             self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
             self.fc2 = nn.Linear(hidden_dim, hidden_dim)
             self.fc3 = nn.Linear(hidden_dim, 1)
diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py
index dcf908c2cd2..5364c82c7b2 100644
--- a/sota-implementations/impala/impala_multi_node_ray.py
+++ b/sota-implementations/impala/impala_multi_node_ray.py
@@ -14,7 +14,7 @@


 @hydra.main(config_path="", config_name="config_multi_node_ray", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import time
diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py
index 4d90e9053bd..527821820ca 100644
--- a/sota-implementations/impala/impala_multi_node_submitit.py
+++ b/sota-implementations/impala/impala_multi_node_submitit.py
@@ -16,7 +16,7 @@
 @hydra.main(
     config_path="", config_name="config_multi_node_submitit", version_base="1.1"
 )
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import time
diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py
index cda63ac0919..b7af2adbc38 100644
--- a/sota-implementations/impala/impala_single_node.py
+++ b/sota-implementations/impala/impala_single_node.py
@@ -14,7 +14,7 @@


 @hydra.main(config_path="", config_name="config_single_node", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import time
diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py
index aa4cea04024..43a8dcafa6e 100644
--- a/sota-implementations/iql/discrete_iql.py
+++ b/sota-implementations/iql/discrete_iql.py
@@ -21,14 +21,11 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -40,12 +37,11 @@
     make_replay_buffer,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="discrete_iql")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     # Create logger
diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py
index eaf791438cc..6585534ff68 100644
--- a/sota-implementations/iql/iql_offline.py
+++ b/sota-implementations/iql/iql_offline.py
@@ -18,14 +18,11 @@
 import torch
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -36,12 +33,11 @@
     make_offline_replay_buffer,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="offline_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     # Create logger
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index 5b90f00c467..eaa37f29176 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -20,14 +20,11 @@
 import torch
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -39,12 +36,11 @@
     make_replay_buffer,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(config_path="", config_name="online_config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()

     # Create logger
diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py
index 2692c1c24b5..56673ac9590 100644
--- a/sota-implementations/multiagent/iql.py
+++ b/sota-implementations/multiagent/iql.py
@@ -8,7 +8,6 @@
 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
@@ -31,7 +30,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="iql")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py
index f04ccb19071..eb6700d766f 100644
--- a/sota-implementations/multiagent/maddpg_iddpg.py
+++ b/sota-implementations/multiagent/maddpg_iddpg.py
@@ -8,7 +8,6 @@
 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
@@ -36,7 +35,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="maddpg_iddpg")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py
index 924ea12272a..0d80896fc9b 100644
--- a/sota-implementations/multiagent/mappo_ippo.py
+++ b/sota-implementations/multiagent/mappo_ippo.py
@@ -8,7 +8,6 @@
 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 from torch import nn
@@ -32,7 +31,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="mappo_ippo")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py
index a832a29e6dd..4fed4fea5f5 100644
--- a/sota-implementations/multiagent/qmix_vdn.py
+++ b/sota-implementations/multiagent/qmix_vdn.py
@@ -8,7 +8,6 @@
 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
@@ -31,7 +30,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="qmix_vdn")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py
index 31106bdd2a0..cc30011f4a0 100644
--- a/sota-implementations/multiagent/sac.py
+++ b/sota-implementations/multiagent/sac.py
@@ -8,7 +8,6 @@
 import hydra
 import torch
-
 from tensordict.nn import TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 from torch import nn
@@ -33,7 +32,7 @@ def rendering_callback(env, td):


 @hydra.main(version_base="1.1", config_path="", config_name="sac")
-def train(cfg: "DictConfig"):  # noqa: F821
+def train(cfg: DictConfig):  # noqa: F821
     # Device
     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
     cfg.env.device = cfg.train.device
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 8ecb675535b..25b6f63e893 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -12,12 +12,11 @@
 import warnings

 import hydra
-
 from torchrl._utils import compile_with_warmup


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import torch.optim
     import tqdm
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index 27ae7e57848..a17d0b90339 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -12,12 +12,11 @@
 import warnings

 import hydra
-
 from torchrl._utils import compile_with_warmup


 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     import torch.optim
     import tqdm
diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py
index 3dec888145c..58072701663 100644
--- a/sota-implementations/redq/redq.py
+++ b/sota-implementations/redq/redq.py
@@ -42,7 +42,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821

     cfg = correct_for_frame_skip(cfg)
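> Note on the recurring `-def main(cfg: "DictConfig")` → `+def main(cfg: DictConfig)` hunks: the quoted forward reference is dropped while the `# noqa: F821` marker stays, since `DictConfig` is still never imported in these scripts. A minimal sketch of why this stays safe at runtime, assuming postponed evaluation of annotations (PEP 563) is in effect for these files, e.g. via a `from __future__ import annotations` added elsewhere in the full patch:

```python
# Minimal sketch, assuming PEP 563 postponed evaluation. With the future
# import, annotations are stored as strings and never evaluated at function
# definition time, so the undefined name DictConfig raises nothing here.
from __future__ import annotations


def main(cfg: DictConfig):  # noqa: F821 -- DictConfig is a hint only
    return cfg


print(main.__annotations__)  # {'cfg': 'DictConfig'} -- kept as a plain string
main({"lr": 3e-4})  # any object passes; hydra would normally supply the cfg
```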
diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py
index b67b02c42f9..0528e2b809e 100644
--- a/sota-implementations/redq/utils.py
+++ b/sota-implementations/redq/utils.py
@@ -5,7 +5,7 @@
 from __future__ import annotations

 from copy import copy
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Sequence

 import torch
 from omegaconf import OmegaConf
@@ -20,7 +20,6 @@
 from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.collectors.collectors import DataCollectorBase
-
 from torchrl.data import (
     LazyMemmapStorage,
     MultiStep,
@@ -105,7 +104,7 @@
 }


-def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig":  # noqa: F821
+def correct_for_frame_skip(cfg: DictConfig) -> DictConfig:  # noqa: F821
     """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.

     This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
@@ -172,7 +171,7 @@ def make_trainer(
     policy_exploration: TensorDictModuleWrapper | TensorDictModule | None,
     replay_buffer: ReplayBuffer | None,
     logger: Logger | None,
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
 ) -> Trainer:
     """Creates a Trainer instance given its constituents.

     Args:
@@ -377,7 +376,7 @@ def make_trainer(

 def make_redq_model(
     proof_environment: EnvBase,
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     device: DEVICE_TYPING = "cpu",
     in_keys: Sequence[str] | None = None,
     actor_net_kwargs=None,
@@ -555,7 +554,7 @@


 def transformed_env_constructor(
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     video_tag: str = "",
     logger: Logger | None = None,
     stats: dict | None = None,
@@ -568,7 +567,7 @@ def transformed_env_constructor(
     state_dim_gsde: int | None = None,
     batch_dims: int | None = 0,
     obs_norm_state_dict: dict | None = None,
-) -> Union[Callable, EnvCreator]:
+) -> Callable | EnvCreator:
     """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

     Args:
@@ -688,7 +687,7 @@ def get_norm_state_dict(env):

 def initialize_observation_norm_transforms(
     proof_environment: EnvBase,
     num_iter: int = 1000,
-    key: Union[str, Tuple[str, ...]] = None,
+    key: str | tuple[str, ...] = None,
 ):
     """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
@@ -729,8 +728,8 @@


 def parallel_env_constructor(
-    cfg: "DictConfig", **kwargs  # noqa: F821
-) -> Union[ParallelEnv, EnvCreator]:
+    cfg: DictConfig, **kwargs  # noqa: F821
+) -> ParallelEnv | EnvCreator:
     """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

     Args:
@@ -916,9 +915,7 @@ def make_env_transforms(
     return env


-def make_redq_loss(
-    model, cfg
-) -> Tuple[REDQLoss_deprecated, Optional[TargetNetUpdater]]:
+def make_redq_loss(model, cfg) -> tuple[REDQLoss_deprecated, TargetNetUpdater | None]:
     """Builds the REDQ loss module."""
     loss_kwargs = {}
     loss_kwargs.update({"loss_function": cfg.loss.loss_function})
@@ -950,7 +947,7 @@


 def make_target_updater(
-    cfg: "DictConfig", loss_module: LossModule  # noqa: F821
+    cfg: DictConfig, loss_module: LossModule  # noqa: F821
 ) -> TargetNetUpdater | None:
     """Builds a target network weight update object."""
     if cfg.loss.type == "double":
@@ -976,8 +973,8 @@ def make_collector_offpolicy(
     make_env: Callable[[], EnvBase],
     actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential,
-    cfg: "DictConfig",  # noqa: F821
-    make_env_kwargs: Dict | None = None,
+    cfg: DictConfig,  # noqa: F821
+    make_env_kwargs: dict | None = None,
 ) -> DataCollectorBase:
     """Returns a data collector for off-policy sota-implementations.
@@ -1037,7 +1034,7 @@ def make_collector_offpolicy(


 def make_replay_buffer(
-    device: DEVICE_TYPING, cfg: "DictConfig"  # noqa: F821
+    device: DEVICE_TYPING, cfg: DictConfig  # noqa: F821
 ) -> ReplayBuffer:  # noqa: F821
     """Builds a replay buffer using the config built from ReplayArgsConfig."""
     device = torch.device(device)
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index e159824f9cd..7fd6284037e 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -15,18 +15,15 @@
 import warnings

 import hydra
-
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,
@@ -43,7 +40,7 @@


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index 3a741735a1c..f7b10e8cdf9 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -20,11 +20,8 @@
 import torch.cuda
 import tqdm
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,
@@ -37,12 +34,11 @@
     make_td3_agent,
 )

-
 torch.set_float32_matmul_precision("high")


 @hydra.main(version_base="1.1", config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     device = cfg.network.device
     if device in ("", None):
         if torch.cuda.is_available():
diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py
index ac65f2875cf..6c628904908 100644
--- a/sota-implementations/td3_bc/td3_bc.py
+++ b/sota-implementations/td3_bc/td3_bc.py
@@ -19,13 +19,10 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-
 from torchrl._utils import compile_with_warmup, timeit
-
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
-
 from utils import (
     dump_video,
     log_metrics,
@@ -38,7 +35,7 @@


 @hydra.main(config_path="", config_name="config")
-def main(cfg: "DictConfig"):  # noqa: F821
+def main(cfg: DictConfig):  # noqa: F821
     set_gym_backend(cfg.env.library).set()

     # Create logger
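> Note on the `sota-implementations/redq/utils.py` hunks: this is the mechanical pyupgrade rewrite, turning `typing.Union`/`Optional`/`Dict`/`Tuple` into PEP 604 unions and PEP 585 builtin generics. A hedged before/after sketch with a hypothetical function (the real signatures are the ones in the hunks):

```python
# Hypothetical example of the rewrite pattern. Because of the future import,
# these annotations are never evaluated at runtime, so the `X | Y` and
# dict[...] spellings parse and run on Python 3.8+ even though bare runtime
# unions would need 3.10.
from __future__ import annotations

from typing import Callable


# before: def build(cfg, kwargs: "Optional[Dict[str, int]]" = None) -> "Union[Callable, int]":
def build(cfg, kwargs: dict[str, int] | None = None) -> Callable | int:
    return len(kwargs or {})


print(build(None, {"frames": 1}))  # 1
```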
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 5c4b9930089..05fdada16d2 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -7,7 +7,6 @@
 import contextlib
 import logging
 import os
-
 import os.path
 import sys
 import time
@@ -15,16 +14,13 @@
 import warnings
 from functools import wraps

-# Get relative file path
-# this returns relative path from current file.
-
 import pytest
 import torch
 import torch.cuda
-
 from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModuleBase
 from torch import nn, vmap
+
 from torchrl._utils import (
     implement_for,
     logger as torchrl_logger,
@@ -32,7 +28,6 @@
     seed_generator,
 )
 from torchrl.data.utils import CloudpickleWrapper
-
 from torchrl.envs import MultiThreadedEnv, ObservationNorm
 from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
 from torchrl.envs.libs.envpool import _has_envpool
@@ -46,6 +41,9 @@
 from torchrl.modules import MLP
 from torchrl.objectives.value.advantages import _vmap_func

+# Get relative file path
+# this returns relative path from current file.
+
 # Specified for test_utils.py
 __version__ = "0.3"
@@ -671,7 +669,7 @@ def _lstm(
         if hidden1_in is None and hidden0_in is None:
             shape = (batch, steps) if not squeeze1 else (batch,)
-            hidden0_in, hidden1_in = [
+            hidden0_in, hidden1_in = (
                 torch.zeros(
                     *shape,
                     self.lstm.num_layers,
@@ -680,7 +678,7 @@
                     dtype=input.dtype,
                 )
                 for _ in range(2)
-            ]
+            )
         elif hidden1_in is None or hidden0_in is None:
             raise RuntimeError(
                 f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
diff --git a/test/conftest.py b/test/conftest.py
index f2648a18041..ba49735b98e 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import functools
 import os
 import sys
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index d407a2ac241..5f03e773591 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -6,10 +6,8 @@
 import random
 import string
-from typing import Dict, List, Optional

 import numpy as np
-
 import torch
 import torch.nn as nn
 from tensordict import tensorclass, TensorDict, TensorDictBase
@@ -133,7 +131,7 @@ def __init__(
     def maxstep(self):
         return 100

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         self.seed = seed
         self.counter = seed % 17  # make counter a small number
@@ -218,10 +216,10 @@ def __new__(
         return super().__new__(*args, **kwargs)

     def __init__(self, device="cpu"):
-        super(MockSerialEnv, self).__init__(device=device)
+        super().__init__(device=device)
         self.is_closed = False

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         assert seed >= 1
         self.seed = seed
         self.counter = seed % 17  # make counter a small number
@@ -259,7 +257,7 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
             device=self.device,
         )

-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         return self.step(tensordict)

@@ -338,12 +336,12 @@ def __new__(
         return super().__new__(cls, *args, **kwargs)

     def __init__(self, device="cpu", batch_size=None):
-        super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size)
+        super().__init__(device=device, batch_size=batch_size)
         self.counter = 0

     rand_step = MockSerialEnv.rand_step

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         assert seed >= 1
@@ -422,9 +420,7 @@ class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
     """

     def __init__(self, device="cpu", batch_size=None):
-        super(MockBatchedUnLockedEnv, self).__init__(
-            batch_size=batch_size, device=device
-        )
+        super().__init__(batch_size=batch_size, device=device)

     @classmethod
     def __new__(cls, *args, **kwargs):
@@ -510,7 +506,7 @@ def _step(
             device=tensordict.device,
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         ...

@@ -1113,7 +1109,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
             torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1206,7 +1202,7 @@ def __init__(
         self,
         n_agents: int,
         group_map: MarlGroupMapType
-        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
         max_steps: int = 5,
         start_val: int = 0,
         **kwargs,
@@ -1287,7 +1283,7 @@ def __init__(
             torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1600,7 +1596,7 @@ def __init__(
         elif start_val.numel() <= 1:
             self.start_val = start_val.expand_as(self.count)

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1816,7 +1812,7 @@ def _step(
         return td

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

@@ -2047,7 +2043,7 @@ def _step(
         assert td.batch_size == self.batch_size
         return td

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         torch.manual_seed(seed)

@@ -2084,7 +2080,7 @@ def _step(
             data.update(self._saved_full_reward_spec.zero())
         return data

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         return seed

@@ -2210,7 +2206,7 @@ def _step(
             reward = self.full_reward_spec.zero()
         return observation.update(done).update(reward)

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         self.manual_seed = seed
         return seed
@@ -2280,7 +2276,7 @@ def _step(
             ),
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         ...

@@ -2328,7 +2324,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
     def get_random_string(self):
         return get_random_string(self.min_size, self.max_size)

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         random.seed(seed)
         torch.manual_seed(0)
         return seed
@@ -2356,7 +2352,7 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
             .update(self.full_reward_spec.zero())
         )

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         ...
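> Note on the `mocking_classes.py` hunks: replacing `super(Class, self).__init__(...)` with the zero-argument Python 3 form is another pyupgrade rewrite, and the two are equivalent inside a method body. A small self-contained illustration with hypothetical classes:

```python
# The zero-argument super() resolves the class and instance from the
# enclosing method's __class__ cell, so naming the class again is redundant.
class BaseEnv:
    def __init__(self, device="cpu"):
        self.device = device


class MockEnv(BaseEnv):
    def __init__(self, device="cpu"):
        # before: super(MockEnv, self).__init__(device=device)
        super().__init__(device=device)
        self.is_closed = False


assert MockEnv("cuda:0").device == "cuda:0"
```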
""" -from __future__ import print_function # pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index a803707408c..d6133ed1a64 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -9,8 +9,6 @@ import pytest -from torchrl.envs.libs.gym import gym_backend - def test_dm_control(): import dm_control # noqa: F401 diff --git a/test/test_collector.py b/test/test_collector.py index f3935d1086b..0ee85798f7d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -10,9 +10,7 @@ import gc import os import subprocess - import sys -from typing import Optional from unittest.mock import patch import numpy as np @@ -32,8 +30,8 @@ TensorDictModuleBase, TensorDictSequential, ) - from torch import nn + from torchrl._utils import ( _make_ordinal_device, _replace_last, @@ -1820,7 +1818,7 @@ def _step( def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: return self.full_done_specs.zeros().update(self.observation_spec.zeros()) - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): return seed @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device") diff --git a/test/test_cost.py b/test/test_cost.py index 33f7cfb0e80..3fd1fad62da 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -100,7 +100,7 @@ TD3BCLoss, TD3Loss, ) -from torchrl.objectives.common import LossModule +from torchrl.objectives.common import add_random_module, LossModule from torchrl.objectives.deprecated import DoubleREDQLoss_deprecated, REDQLoss_deprecated from torchrl.objectives.redq import REDQLoss from torchrl.objectives.reinforce import ReinforceLoss @@ -5859,7 +5859,6 @@ def test_crossq_tensordict_keys(self, td_est): actor = self._create_mock_actor() qvalue = self._create_mock_qvalue() - value = None loss_fn = CrossQLoss( actor_network=actor, @@ -16163,6 +16162,15 @@ def _composite_log_prob(self): yield setter.unset() + def test_add_random_module(self): + class MyMod(nn.Module): + ... + + add_random_module(MyMod) + import torchrl.objectives.utils + + assert MyMod in torchrl.objectives.utils.RANDOM_MODULE_LIST + def test_standardization(self): t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6) std_t0 = _standardize(t, exclude_dims=(1, 3)) diff --git a/test/test_exploration.py b/test/test_exploration.py index 10a4938c6fd..847cde926d1 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -379,9 +379,8 @@ def test_nested( ) action_spec = env.action_spec - d_act = action_spec.shape[-1] + action_spec.shape[-1] - net = nn.LazyLinear(d_act).to(device) policy = TensorDictModule( CountingEnvCountModule(action_spec=action_spec), in_keys=[("data", "states") if nested_obs_action else "observation"], diff --git a/test/test_helpers.py b/test/test_helpers.py index cf1160f1bb2..3ba2326254d 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
diff --git a/test/test_helpers.py b/test/test_helpers.py
index cf1160f1bb2..3ba2326254d 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations

 import argparse
 import dataclasses
@@ -225,6 +226,7 @@ def test_timeit():
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")])
 def test_transformed_env_constructor_with_state_dict(from_pixels):
+
     config_fields = [
         (config_field.name, config_field.type, config_field)
         for config_cls in (
diff --git a/test/test_libs.py b/test/test_libs.py
index 57358b732a4..97f4c0f9fe7 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -195,7 +195,7 @@ def get_gym_pixel_wrapper():
         PixelObservationWrapper = gym_backend(
             "wrappers.pixel_observation"
         ).PixelObservationWrapper
-    except Exception as err:
+    except Exception:
         from torchrl.envs.libs.utils import (
             GymPixelObservationWrapper as PixelObservationWrapper,
         )
diff --git a/test/test_loggers.py b/test/test_loggers.py
index 6b659fe7245..87003250d8d 100644
--- a/test/test_loggers.py
+++ b/test/test_loggers.py
@@ -154,7 +154,7 @@ def test_log_scalar(self, steps, tmpdir):
                 step=steps[i] if steps else None,
             )

-        with open(os.path.join(tmpdir, exp_name, "scalars", "foo.csv"), "r") as file:
+        with open(os.path.join(tmpdir, exp_name, "scalars", "foo.csv")) as file:
             for i, row in enumerate(file.readlines()):
                 step = steps[i] if steps else i
                 assert row == f"{step},{values[i].item()}\n"
@@ -239,7 +239,7 @@ def test_log_config(self, tmpdir, config):
         logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name)
         logger.log_hparams(cfg=config)

-        with open(os.path.join(tmpdir, exp_name, "texts", "hparams0.txt"), "r") as file:
+        with open(os.path.join(tmpdir, exp_name, "texts", "hparams0.txt")) as file:
             txt = "\n".join([f"{k}: {val}" for k, val in sorted(config.items())])
             text = "".join(file.readlines())
             assert text == txt
diff --git a/test/test_modules.py b/test/test_modules.py
index f661fa6199d..63dda533a90 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -939,7 +939,7 @@ def test_multiagent_mlp_tdparams(
         else:
             return
         mlp = nn.Sequential(mlp)
-        mlp_device = mlp.to(device)
+        mlp.to(device)
         param_set = set(mlp.parameters())
         for p in mlp[0].params.values(True, True):
             assert p in param_set
diff --git a/test/test_rb.py b/test/test_rb.py
index 70f254a47aa..81b4ab3759f 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -1626,7 +1626,6 @@ def test_extend(self, rbtype, storage, size, prefetch):
         rb.extend(data)
         length = len(rb)
         for d in data[-length:]:
-            found_similar = False
             for b in rb._storage:
                 if isinstance(b, TensorDictBase):
                     keys = set(d.keys()).intersection(b.keys())
@@ -1657,7 +1656,6 @@ def test_sample(self, rbtype, storage, size, prefetch):
             new_data = new_data[0]

         for d in new_data:
-            found_similar = False
             for b in data:
                 if isinstance(b, TensorDictBase):
                     keys = set(d.keys()).intersection(b.keys())
@@ -2930,7 +2928,6 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
         index = rb.extend(data)
         rb.update_priority(index, data["priority"])
         found_traj_0 = False
-        found_traj_4_truncated_left = False
         found_traj_4_truncated_right = False
         for i, s in enumerate(rb):
             t = s["traj"].unique().tolist()
@@ -2942,7 +2939,7 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
             if s["step_count"][0] > 10:
                 found_traj_4_truncated_right = True
             if s["step_count"][0] == 0:
-                found_traj_4_truncated_left = True
+                pass
             if i == 1000:
                 break
         assert not rb._sampler.span[0]
diff --git a/test/test_specs.py b/test/test_specs.py
index f523dde54f7..22eb23e82ea 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -10,6 +10,7 @@
 import numpy as np
 import pytest
 import torch
+
 import torchrl.data.tensor_specs
 from scipy.stats import chisquare
 from tensordict import (
@@ -21,7 +22,6 @@
 )
 from tensordict.utils import _unravel_key_to_tuple, set_capture_non_tensor_stack
 from torchrl._utils import _make_ordinal_device
-
 from torchrl.data.tensor_specs import (
     _keys_to_empty_composite_spec,
     Binary,
@@ -3573,8 +3573,8 @@ def test_valid_indexing(spec_class):
     assert spec_3d[1:, range(3)].shape == torch.Size([4, 3, 4])
     assert spec_3d[[[[[0, 1]]]], [[0]]].shape == torch.Size([1, 1, 1, 2, 4])
     assert spec_3d[0, [[[[0, 1]]]]].shape == torch.Size([1, 1, 1, 2, 4])
-    assert spec_3d[0, ((((0, 1))))].shape == torch.Size([2, 4])
-    assert spec_3d[((((0, 1)))), [0, 2]].shape == torch.Size([2, 4])
+    assert spec_3d[0, ((0, 1))].shape == torch.Size([2, 4])
+    assert spec_3d[((0, 1)), [0, 2]].shape == torch.Size([2, 4])
     assert spec_4d[2:, [[[0, 1]]], :3].shape == torch.Size([3, 1, 1, 2, 3, 6])
     assert spec_5d[2:, [[[0, 1]]], [[0, 1]], :3].shape == torch.Size([3, 1, 1, 2, 3, 7])
     assert spec_5d[2:, [[[0, 1]]], 0, :3].shape == torch.Size([3, 1, 1, 2, 3, 7])
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 7ba45fb8587..9400f111ece 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -1639,8 +1639,6 @@ def test_batched_actor_exceptions(self):
         with pytest.raises(ValueError, match="Only a single init_key can be passed"):
             MultiStepActorWrapper(actor_base, n_steps=time_steps, init_key=["init_key"])

-        n_obs = 1
-        n_action = 1
         batch = 2

         # The second env has frequent resets, the first none
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 07b103ef996..d1c2947f7af 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -442,8 +442,6 @@ def test_transform_inverse(self):
 class TestClipTransform(TransformBase):
     @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
     def test_transform_rb(self, rbclass):
-        device = "cpu"
-        batch = [20]
         torch.manual_seed(0)
         rb = rbclass(storage=LazyTensorStorage(20))
@@ -1271,7 +1269,7 @@ def test_catframes_reset(self, device):
         buffer = getattr(cat_frames, f"_cat_buffers_{key1}")

         tdc = td.clone()
-        passed_back_td = cat_frames._reset(tdc, tdc)
+        cat_frames._reset(tdc, tdc)

         # assert tdc is passed_back_td
         # assert (buffer == 0).all()
@@ -4787,7 +4785,7 @@ def make_env():
             )
             return env

-        env = SerialEnv(2, make_env)
+        SerialEnv(2, make_env).check_env_specs()

     def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
         def make_env():
@@ -8544,7 +8542,6 @@ def test_transform_model(self):
         keys = [key1, key2]
         dim = -2
         d = 4
-        N = 3
         batch_size = (5,)
         extra_d = (3,) * (-dim - 1)
         device = "cpu"
@@ -8570,7 +8567,6 @@ def test_transform_rb(self, rbclass):
         keys = [key1, key2]
         dim = -2
         d = 4
-        N = 3
         batch_size = (5,)
         extra_d = (3,) * (-dim - 1)
         device = "cpu"
@@ -8594,7 +8590,6 @@ def test_tmp_reset(self, device):
         key1 = "first key"
         key2 = "second key"
-        N = 4
         keys = [key1, key2]
         key1_tensor = torch.randn(1, 1, 3, 3, device=device)
         key2_tensor = torch.randn(1, 1, 3, 3, device=device)
@@ -8606,7 +8601,7 @@ def test_tmp_reset(self, device):
         buffer = getattr(t, f"_maxpool_buffer_{key1}")

         tdc = td.clone()
-        passed_back_td = t._reset(tdc, tdc.empty())
+        t._reset(tdc, tdc.empty())

         # assert tdc is passed_back_td
         assert (buffer != 0).any()
@@ -11771,7 +11766,7 @@ def test_transform_env(self):

     def test_transform_model(self):
         t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
-        m = nn.Sequential(t)
+        nn.Sequential(t)
         assert t(TensorDict(device="cpu:0")).device == torch.device("cpu:1")

     @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@@ -11892,7 +11887,7 @@ def test_transform_env(self):
         # check error
         with pytest.raises(ValueError, match="Only tailing dims with negative"):
-            t = PermuteTransform((-1, -10))
+            PermuteTransform((-1, -10))

     def test_transform_model(self):
         batch = [2]
@@ -12217,7 +12212,7 @@ def test_transform_env(self):
             RuntimeError,
             match="BurnInTransform can only be appended to a ReplayBuffer.",
         ):
-            rollout = env.rollout(3)
+            env.rollout(3)

     @pytest.mark.parametrize("module", ["gru", "lstm"])
     @pytest.mark.parametrize("batch_size", [2, 4])
@@ -12545,7 +12540,7 @@ def test_trans_serial_env_check(self):
         with pytest.raises(
             RuntimeError, match="The environment passed to SerialEnv has empty specs"
         ):
-            env = TransformedEnv(SerialEnv(2, self.DummyEnv), RemoveEmptySpecs())
+            TransformedEnv(SerialEnv(2, self.DummyEnv), RemoveEmptySpecs())

     def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
         with pytest.raises(
diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index 42819ef1d8c..e09cec16b01 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -46,12 +46,6 @@
 )


-import torchrl.collectors
-import torchrl.data
-import torchrl.envs
-import torchrl.modules
-import torchrl.objectives
-import torchrl.trainers
 from torchrl._utils import (
     auto_unwrap_transformed_env,
     compile_with_warmup,
@@ -107,3 +101,11 @@ def _inv(self):


 ComposeTransform.inv = _inv
+
+__all__ = [
+    "auto_unwrap_transformed_env",
+    "compile_with_warmup",
+    "implement_for",
+    "set_auto_unwrap_transformed_env",
+    "timeit",
+]
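> Note on the `torchrl/__init__.py` hunk: eager submodule imports and `# noqa` re-export comments give way to an explicit `__all__`. Listing a name in `__all__` is what pyflakes (and therefore the new autoflake hook) treats as a use, so re-exports survive `--remove-all-unused-imports` without per-line suppressions. A module-level sketch of the pattern, for a hypothetical package:

```python
# Contents of a hypothetical mypkg/__init__.py following the same pattern.
# The import would look unused to autoflake, but its appearance in __all__
# marks it as an intentional re-export, so the hook leaves it alone.
from ._impl import Thing  # kept: listed in __all__

__all__ = ["Thing"]
```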
diff --git a/torchrl/_extension.py b/torchrl/_extension.py
index 61eedb46418..d84d73cca4a 100644
--- a/torchrl/_extension.py
+++ b/torchrl/_extension.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations

 import importlib.util
 import warnings
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 1d541750ec2..3258cb953eb 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -5,12 +5,9 @@
 from __future__ import annotations

 import collections
-
 import functools
 import inspect
-
 import logging
-
 import math
 import os
 import pickle
@@ -24,13 +21,12 @@
 from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
-from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union
+from typing import Any, Callable, cast, TypeVar

 import numpy as np
 import torch
 from packaging.version import parse
 from tensordict import unravel_key
-
 from tensordict.utils import NestedKey
 from torch import multiprocessing as mp, Tensor
@@ -345,7 +341,7 @@ class implement_for:

     def __init__(
         self,
-        module_name: Union[str, Callable],
+        module_name: str | Callable,
         from_version: str = None,
         to_version: str = None,
         *,
@@ -419,7 +415,7 @@ def module_set(self):
             setattr(cls, self.fn.__name__, self.fn)

     @classmethod
-    def import_module(cls, module_name: Union[Callable, str]) -> str:
+    def import_module(cls, module_name: Callable | str) -> str:
         """Imports module and returns its version."""
         if not callable(module_name):
             module = cls._cache_modules.get(module_name, None)
@@ -515,7 +511,7 @@ def unsupported(*args, **kwargs):
         return unsupported

     @classmethod
-    def reset(cls, setters_dict: Dict[str, implement_for] = None):
+    def reset(cls, setters_dict: dict[str, implement_for] = None):
         """Resets the setters in setter_dict.

         ``setter_dict`` is a copy of implementations. We just need to iterate through its
@@ -880,7 +876,7 @@ def set_mode(self, type: Any | None) -> None:

 def _standardize(
     input: Tensor,
-    exclude_dims: Tuple[int] = (),
+    exclude_dims: tuple[int] = (),
     mean: Tensor | None = None,
     std: Tensor | None = None,
     eps: float | None = None,
diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py
index d69d8c9e50c..2d40522bb07 100644
--- a/torchrl/collectors/__init__.py
+++ b/torchrl/collectors/__init__.py
@@ -12,3 +12,12 @@
     MultiSyncDataCollector,
     SyncDataCollector,
 )
+
+__all__ = [
+    "RandomPolicy",
+    "aSyncDataCollector",
+    "DataCollectorBase",
+    "MultiaSyncDataCollector",
+    "MultiSyncDataCollector",
+    "SyncDataCollector",
+]
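> Note, circling back to the first hunk of the patch: `shlex.join(argv[1:])` in the coverage helper is the stdlib inverse of `shlex.split` and is documented as equivalent to quoting each token individually (available since Python 3.8):

```python
# Equivalence check for the coverage_run_parallel.py change: shlex.join
# quotes each argument exactly like the old generator expression did.
import shlex

argv = ["pytest", "-k", "test foo", "--color=yes"]
assert shlex.join(argv) == " ".join(shlex.quote(a) for a in argv)
print(shlex.join(argv))  # pytest -k 'test foo' --color=yes
```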
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 2184ae9e19c..67dff40e9de 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -7,9 +7,7 @@
 import _pickle
 import abc
 import collections
-
 import contextlib
-
 import functools
 import os
 import queue
@@ -23,10 +21,9 @@
 from multiprocessing.managers import SyncManager
 from queue import Empty
 from textwrap import indent
-from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Iterator, Sequence

 import numpy as np
-
 import torch
 import torch.nn as nn
 from tensordict import (
@@ -124,8 +121,6 @@ class _InterruptorManager(SyncManager):
     between processes.
     """

-    pass
-

 _InterruptorManager.register("_Interruptor", _Interruptor)
@@ -162,7 +157,7 @@ def _get_policy_and_device(
         policy_device: Any = NO_DEFAULT,
         env_maker: Any | None = None,
         env_maker_kwargs: dict | None = None,
-    ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]:
+    ) -> tuple[TensorDictModule, None | Callable[[], dict]]:
         """Util method to get a policy and its device given the collector __init__ inputs.

         We want to copy the policy and then move the data there, not call policy.to(device).
@@ -245,7 +240,7 @@ def map_weight(
         return policy, get_original_weights

     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         """Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
@@ -513,15 +508,11 @@ class SyncDataCollector(DataCollectorBase):

     def __init__(
         self,
-        create_env_fn: Union[
-            EnvBase, "EnvCreator", Sequence[Callable[[], EnvBase]]  # noqa: F821
-        ],  # noqa: F821
-        policy: Optional[
-            Union[
-                TensorDictModule,
-                Callable[[TensorDictBase], TensorDictBase],
-            ]
-        ] = None,
+        create_env_fn: (
+            EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]]  # noqa: F821
+        ),  # noqa: F821
+        policy: None
+        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
         frames_per_batch: int,
         total_frames: int = -1,
@@ -543,8 +534,8 @@ def __init__(
         use_buffers: bool | None = None,
         replay_buffer: ReplayBuffer | None = None,
         trust_policy: bool = None,
-        compile_policy: bool | Dict[str, Any] | None = None,
-        cudagraph_policy: bool | Dict[str, Any] | None = None,
+        compile_policy: bool | dict[str, Any] | None = None,
+        cudagraph_policy: bool | dict[str, Any] | None = None,
         no_cuda_sync: bool = False,
         **kwargs,
     ):
@@ -990,7 +981,7 @@ def next(self):

     # for RPC
     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         super().update_policy_weights_(policy_weights)
@@ -1617,25 +1608,21 @@ class _MultiDataCollector(DataCollectorBase):

     def __init__(
         self,
         create_env_fn: Sequence[Callable[[], EnvBase]],
-        policy: Optional[
-            Union[
-                TensorDictModule,
-                Callable[[TensorDictBase], TensorDictBase],
-            ]
-        ] = None,
+        policy: None
+        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
         frames_per_batch: int,
-        total_frames: Optional[int] = -1,
+        total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
-        create_env_kwargs: Optional[Sequence[dict]] = None,
+        create_env_kwargs: Sequence[dict] | None = None,
         max_frames_per_traj: int | None = None,
         init_random_frames: int | None = None,
         reset_at_each_iter: bool = False,
-        postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        split_trajs: Optional[bool] = None,
+        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        split_trajs: bool | None = None,
         exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
         reset_when_done: bool = True,
         update_at_each_batch: bool = False,
@@ -1648,8 +1635,8 @@ def __init__(
         replay_buffer: ReplayBuffer | None = None,
         replay_buffer_chunk: bool = True,
         trust_policy: bool = None,
-        compile_policy: bool | Dict[str, Any] | None = None,
-        cudagraph_policy: bool | Dict[str, Any] | None = None,
+        compile_policy: bool | dict[str, Any] | None = None,
+        cudagraph_policy: bool | dict[str, Any] | None = None,
         no_cuda_sync: bool = False,
     ):
         self.closed = True
@@ -2086,7 +2073,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
         self.reset()
         return seed

-    def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None:
+    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
         """Resets the environments to a new initial state.

         Args:
@@ -2282,7 +2269,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:

     # for RPC
     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         super().update_policy_weights_(policy_weights)
@@ -2646,7 +2633,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:

     # for RPC
     def update_policy_weights_(
-        self, policy_weights: Optional[TensorDictBase] = None
+        self, policy_weights: TensorDictBase | None = None
     ) -> None:
         super().update_policy_weights_(policy_weights)
@@ -2654,7 +2641,7 @@ def update_policy_weights_(
     def frames_per_batch_worker(self):
         return self.requested_frames_per_batch

-    def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
+    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
         new_data, j = self.queue_out.get(timeout=timeout)
         use_buffers = self._use_buffers
         if self.replay_buffer is not None:
@@ -2745,7 +2732,7 @@ def _shutdown_main(self) -> None:
         del self.out_tensordicts
         return super()._shutdown_main()

-    def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None:
+    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
         super().reset(reset_idx)
         if self.queue_out.full():
             time.sleep(_TIMEOUT)  # wait until queue is empty
@@ -2900,25 +2887,20 @@ class aSyncDataCollector(MultiaSyncDataCollector):
     def __init__(
         self,
         create_env_fn: Callable[[], EnvBase],
-        policy: Optional[
-            Union[
-                TensorDictModule,
-                Callable[[TensorDictBase], TensorDictBase],
-            ]
-        ],
+        policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]),
         *,
         frames_per_batch: int,
-        total_frames: Optional[int] = -1,
+        total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
-        create_env_kwargs: Optional[Sequence[dict]] = None,
+        create_env_kwargs: Sequence[dict] | None = None,
         max_frames_per_traj: int | None = None,
         init_random_frames: int | None = None,
         reset_at_each_iter: bool = False,
-        postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        split_trajs: Optional[bool] = None,
+        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        split_trajs: bool | None = None,
         exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
         reset_when_done: bool = True,
         update_at_each_batch: bool = False,
@@ -2977,15 +2959,15 @@ def _main_async_collector(
     pipe_parent: connection.Connection,
     pipe_child: connection.Connection,
     queue_out: queues.Queue,
-    create_env_fn: Union[EnvBase, "EnvCreator", Callable[[], EnvBase]],  # noqa: F821
-    create_env_kwargs: Dict[str, Any],
+    create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
+    create_env_kwargs: dict[str, Any],
     policy: Callable[[TensorDictBase], TensorDictBase],
     max_frames_per_traj: int,
     frames_per_batch: int,
     reset_at_each_iter: bool,
-    storing_device: Optional[Union[torch.device, str, int]],
-    env_device: Optional[Union[torch.device, str, int]],
-    policy_device: Optional[Union[torch.device, str, int]],
+    storing_device: torch.device | str | int | None,
+    env_device: torch.device | str | int | None,
+    policy_device: torch.device | str | int | None,
     idx: int = 0,
     exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
     reset_when_done: bool = 
True, diff --git a/torchrl/collectors/distributed/__init__.py b/torchrl/collectors/distributed/__init__.py index 97932619a60..c28122c6c6a 100644 --- a/torchrl/collectors/distributed/__init__.py +++ b/torchrl/collectors/distributed/__init__.py @@ -8,3 +8,12 @@ from .rpc import RPCDataCollector from .sync import DistributedSyncDataCollector from .utils import submitit_delayed_launcher + +__all__ = [ + "DEFAULT_SLURM_CONF", + "DistributedDataCollector", + "RayCollector", + "RPCDataCollector", + "DistributedSyncDataCollector", + "submitit_delayed_launcher", +] diff --git a/torchrl/collectors/distributed/default_configs.py b/torchrl/collectors/distributed/default_configs.py index edcaf6d91e4..8da69010242 100644 --- a/torchrl/collectors/distributed/default_configs.py +++ b/torchrl/collectors/distributed/default_configs.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os TCP_PORT = os.environ.get("TCP_PORT", "10003") diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 3491c48138c..5ec55e23a16 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -11,7 +11,7 @@ import warnings from copy import copy, deepcopy from datetime import timedelta -from typing import Callable, List, OrderedDict, Type +from typing import Callable, OrderedDict import torch.cuda from tensordict import TensorDict @@ -410,17 +410,17 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable | None = None, split_trajs: bool = False, - exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - collector_class: Type = SyncDataCollector, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, # noqa + collector_class: type = SyncDataCollector, collector_kwargs: dict = None, num_workers_per_collector: int = 1, sync: bool = False, @@ -527,19 +527,19 @@ def __init__( self._make_container() @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter @@ -899,7 +899,7 @@ def update_policy_weights_(self, worker_rank=None) -> None: def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 - self._store.set(f"NODE_{rank}_in", f"seeding_{seed}".encode("utf-8")) + self._store.set(f"NODE_{rank}_in", f"seeding_{seed}".encode()) status = 
self._store.get(f"NODE_{rank}_out") if status != b"updated": raise RuntimeError(f"Expected 'seeded' but got status {status}.") diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 715e41f50fd..1716609026b 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -6,7 +6,7 @@ from __future__ import annotations import warnings -from typing import Callable, Dict, Iterator, List, OrderedDict, Union +from typing import Callable, Iterator, OrderedDict import torch import torch.nn as nn @@ -24,7 +24,6 @@ from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator - RAY_ERR = None try: import ray @@ -289,15 +288,15 @@ class RayCollector(DataCollectorBase): def __init__( self, - create_env_fn: Union[Callable, EnvBase, List[Callable], List[EnvBase]], + create_env_fn: Callable | EnvBase | list[Callable] | list[EnvBase], policy: Callable[[TensorDict], TensorDict], *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj=-1, init_random_frames=-1, reset_at_each_iter=False, @@ -305,11 +304,11 @@ def __init__( split_trajs=False, exploration_type=DEFAULT_EXPLORATION_TYPE, collector_class: Callable[[TensorDict], TensorDict] = SyncDataCollector, - collector_kwargs: Union[Dict, List[Dict]] = None, + collector_kwargs: dict | list[dict] = None, num_workers_per_collector: int = 1, sync: bool = False, - ray_init_config: Dict = None, - remote_configs: Union[Dict, List[Dict]] = None, + ray_init_config: dict = None, + remote_configs: dict | list[dict] = None, num_collectors: int = None, update_after_each_batch=False, max_weight_update_interval=-1, @@ -483,19 +482,19 @@ def num_workers(self): return self.num_collectors @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter @@ -713,13 +712,13 @@ def update_policy_weights_(self, worker_rank=None) -> None: ) self._batches_since_weight_update[worker_rank - 1] = 0 - def set_seed(self, seed: int, static_seed: bool = False) -> List[int]: + def set_seed(self, seed: int, static_seed: bool = False) -> list[int]: """Calls parent method for each remote collector iteratively and returns final seed.""" for collector in self.remote_collectors(): seed = ray.get(object_refs=collector.set_seed.remote(seed, static_seed)) return seed - def state_dict(self) -> List[OrderedDict]: + def state_dict(self) -> list[OrderedDict]: """Calls parent method for each remote collector and returns a list of results.""" futures = [ collector.state_dict.remote() for collector in self.remote_collectors() @@ -727,9 +726,7 @@ def state_dict(self) -> List[OrderedDict]: 
results = ray.get(object_refs=futures) return results - def load_state_dict( - self, state_dict: Union[OrderedDict, List[OrderedDict]] - ) -> None: + def load_state_dict(self, state_dict: OrderedDict | list[OrderedDict]) -> None: """Calls parent method for each remote collector.""" if isinstance(state_dict, OrderedDict): state_dicts = [state_dict] diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 94421ce8ca3..ee73cfdf4e7 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -12,10 +12,9 @@ import time import warnings from copy import copy, deepcopy -from typing import Callable, List, OrderedDict +from typing import Callable, OrderedDict from torchrl._utils import logger as torchrl_logger - from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -266,16 +265,16 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable | None = None, split_trajs: bool = False, - exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, # noqa collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -386,19 +385,19 @@ def __init__( self._init() @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 6aa66dfbdd2..0a2215e0abe 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -11,14 +11,13 @@ import warnings from copy import copy, deepcopy from datetime import timedelta -from typing import Callable, List, OrderedDict +from typing import Callable, OrderedDict import torch.cuda from tensordict import TensorDict from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE - from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -276,16 +275,16 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, - device: torch.device | List[torch.device] = None, - storing_device: torch.device | List[torch.device] = None, - env_device: torch.device | List[torch.device] = None, - policy_device: torch.device | List[torch.device] = None, + device: torch.device | list[torch.device] = None, + storing_device: torch.device | list[torch.device] = None, + 
env_device: torch.device | list[torch.device] = None, + policy_device: torch.device | list[torch.device] = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable | None = None, split_trajs: bool = False, - exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, # noqa collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -384,19 +383,19 @@ def __init__( self._make_container() @property - def device(self) -> List[torch.device]: + def device(self) -> list[torch.device]: return self._device @property - def storing_device(self) -> List[torch.device]: + def storing_device(self) -> list[torch.device]: return self._storing_device @property - def env_device(self) -> List[torch.device]: + def env_device(self) -> list[torch.device]: return self._env_device @property - def policy_device(self) -> List[torch.device]: + def policy_device(self) -> list[torch.device]: return self._policy_device @device.setter diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index 2dd6fcf6c93..bc72bda6a4a 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -1,3 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + import subprocess import time diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 33353b93b5f..9fc7b77a6b8 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -101,3 +101,98 @@ UnboundedDiscreteTensorSpec, ) from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec + +__all__ = [ + "BinaryToDecimal", + "HashToInt", + "MCTSForest", + "QueryModule", + "RandomProjectionHash", + "SipHash", + "TensorDictMap", + "TensorMap", + "Tree", + "MultiStep", + "Flat2TED", + "FlatStorageCheckpointer", + "H5Combine", + "H5Split", + "H5StorageCheckpointer", + "ImmutableDatasetWriter", + "LazyMemmapStorage", + "LazyStackStorage", + "LazyTensorStorage", + "ListStorage", + "ListStorageCheckpointer", + "Nested2TED", + "NestedStorageCheckpointer", + "PrioritizedReplayBuffer", + "PrioritizedSampler", + "PrioritizedSliceSampler", + "RandomSampler", + "RemoteTensorDictReplayBuffer", + "ReplayBuffer", + "ReplayBufferEnsemble", + "RoundRobinWriter", + "SamplerEnsemble", + "SamplerWithoutReplacement", + "SliceSampler", + "SliceSamplerWithoutReplacement", + "Storage", + "StorageCheckpointerBase", + "StorageEnsemble", + "StorageEnsembleCheckpointer", + "TED2Flat", + "TED2Nested", + "TensorDictMaxValueWriter", + "TensorDictPrioritizedReplayBuffer", + "TensorDictReplayBuffer", + "TensorDictRoundRobinWriter", + "TensorStorage", + "TensorStorageCheckpointer", + "Writer", + "WriterEnsemble", + "AdaptiveKLController", + "ConstantKLController", + "create_infinite_iterator", + "get_dataloader", + "PairwiseDataset", + "PromptData", + "PromptTensorDictTokenizer", + "RewardData", + "RolloutFromModel", + "TensorDictTokenizer", + "TokenizedDatasetLoader", + "Binary", + "BinaryDiscreteTensorSpec", + "Bounded", + "BoundedContinuous", + "BoundedTensorSpec", + "Categorical", + "Choice", + "Composite", + "CompositeSpec", + "DEVICE_TYPING", + "DiscreteTensorSpec", + "LazyStackedCompositeSpec", + "LazyStackedTensorSpec", + "MultiCategorical", 
"MultiDiscreteTensorSpec", + "MultiOneHot", + "MultiOneHotDiscreteTensorSpec", + "NonTensor", + "NonTensorSpec", + "OneHot", + "OneHotDiscreteTensorSpec", + "Stacked", + "StackedComposite", + "TensorSpec", + "Unbounded", + "UnboundedContinuous", + "UnboundedContinuousTensorSpec", + "UnboundedDiscrete", + "UnboundedDiscreteTensorSpec", + "check_no_exclusive_keys", + "consolidate_spec", + "contains_lazy_spec", +] diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index d099a3a1be5..7bec24cb17b 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -2,13 +2,3 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .atari_dqn import AtariDQNExperienceReplay -from .common import BaseDatasetExperienceReplay -from .d4rl import D4RLExperienceReplay -from .gen_dgrl import GenDGRLExperienceReplay -from .minari_data import MinariExperienceReplay -from .openml import OpenMLExperienceReplay -from .openx import OpenXExperienceReplay -from .roboset import RobosetExperienceReplay -from .vd4rl import VD4RLExperienceReplay diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 2a1ec90feb6..775751d73af 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -22,7 +22,6 @@ from torch import multiprocessing as mp from torchrl._utils import logger as torchrl_logger from torchrl.data.datasets.common import BaseDatasetExperienceReplay - from torchrl.data.replay_buffers.samplers import ( SamplerWithoutReplacement, SliceSampler, @@ -403,7 +402,7 @@ def __init__( download: bool | str = True, sampler=None, writer=None, - transform: "Transform" | None = None, # noqa: F821 + transform: Transform | None = None, # noqa: F821 num_procs: int = 0, num_slices: int | None = None, slice_len: int | None = None, @@ -493,7 +492,7 @@ def _is_downloaded(self): if os.path.exists(self.dataset_path / "meta.json"): return True if os.path.exists(self.dataset_path / "processed.json"): - with open(self.dataset_path / "processed.json", "r") as jsonfile: + with open(self.dataset_path / "processed.json") as jsonfile: return json.load(jsonfile).get("processed", False) == self._max_runs return False @@ -514,7 +513,7 @@ def _download_and_preproc(self): command, shell=True, capture_output=True ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) files = [ - file.decode("utf-8").replace("$", "\$") # noqa: W605 + file.decode("utf-8").replace("$", r"\$") # noqa: W605 for file in output.stdout.splitlines() if file.endswith(b".gz") ] @@ -819,7 +818,7 @@ def _load_split(self, path): def _proc_td(self, td, index): td_data = td.get("data") - obs_ = td_data.get(("observation"))[index + 1] + obs_ = td_data.get("observation")[index + 1] done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool() if done.ndim and done.any(): obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 4ed9dcabada..5ad01ebfb59 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -10,22 +10,17 @@ import tempfile import urllib import warnings - from pathlib import Path from typing import Callable import numpy as np - import torch - from tensordict import make_tensordict, PersistentTensorDict, TensorDict from torchrl._utils import logger as torchrl_logger - from torchrl.collectors.utils import split_trajectories from 
torchrl.data.datasets.common import BaseDatasetExperienceReplay from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS - from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage @@ -145,7 +140,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, from_env: bool = False, use_truncated_as_done: bool = True, @@ -459,7 +454,7 @@ def _download_dataset_from_url(dataset_url, dataset_path): torchrl_logger.info(f"Downloading dataset: {dataset_url} to {dataset_filepath}") urllib.request.urlretrieve(dataset_url, dataset_filepath) if not os.path.exists(dataset_filepath): - raise IOError("Failed to download dataset from %s" % dataset_url) + raise OSError("Failed to download dataset from %s" % dataset_url) return dataset_filepath diff --git a/torchrl/data/datasets/d4rl_infos.py b/torchrl/data/datasets/d4rl_infos.py index e9790ea04f9..c3e0a743f35 100644 --- a/torchrl/data/datasets/d4rl_infos.py +++ b/torchrl/data/datasets/d4rl_infos.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations D4RL_DATASETS = { "maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 53efeb54898..ebe9d032c0d 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -9,7 +9,6 @@ import os.path import shutil import tempfile - from collections import defaultdict from contextlib import nullcontext from dataclasses import asdict @@ -17,7 +16,6 @@ from typing import Callable import torch - from tensordict import PersistentTensorDict, TensorDict from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger from torchrl.data.datasets.common import BaseDatasetExperienceReplay @@ -167,7 +165,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, ): self.dataset_id = dataset_id @@ -381,7 +379,7 @@ def _load(self): return TensorDict.load_memmap(self.data_path) def _load_and_proc_metadata(self): - with open(self.metadata_path, "r") as file: + with open(self.metadata_path) as file: self.metadata = json.load(file) self.metadata["observation_space"] = _proc_spec( self.metadata["observation_space"] diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index c420eb93ad3..344cf43bf3f 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -11,7 +11,6 @@ import numpy as np from tensordict import TensorDict from torchrl.data.datasets.common import BaseDatasetExperienceReplay - from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import ( Sampler, @@ -67,7 +66,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 ): if sampler is None: diff --git 
a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 94ad82c7a88..607c68d84f9 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -11,10 +11,9 @@ import shutil import tempfile from pathlib import Path -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable import torch - from tensordict import make_tensordict, NonTensorData, pad, TensorDict from tensordict.utils import _is_non_tensor @@ -313,7 +312,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, strict_length: bool = True, ): @@ -656,7 +655,7 @@ def dumps(self, path): state_dict = self.state_dict() json.dump(state_dict, path / "state_dict.json") - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "repo": self.repo, "split": self.split, @@ -674,7 +673,7 @@ def loads(self, path): state_dict = json.load(path / "state_dict.json") self.load_state_dict(state_dict) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: for key, val in state_dict.items(): setattr(self, key, val) self._init() @@ -722,7 +721,7 @@ class _StreamingSampler(Sampler): def __init__(self): ... - def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]: return range(batch_size), {} def _empty(self): @@ -734,10 +733,10 @@ def dumps(self, path): def loads(self, path): ... - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 1a83c302860..9eccf07a286 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -8,13 +8,11 @@ import os.path import shutil import tempfile - from contextlib import nullcontext from pathlib import Path from typing import Callable import torch - from tensordict import PersistentTensorDict, TensorDict from torchrl._utils import ( KeyDependentDefaultDict, @@ -162,7 +160,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, **env_kwargs, ): diff --git a/torchrl/data/datasets/utils.py b/torchrl/data/datasets/utils.py index b88e3aee14e..975d69746ca 100644 --- a/torchrl/data/datasets/utils.py +++ b/torchrl/data/datasets/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import os diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index daf1c85e3d7..0c851e2408c 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -5,7 +5,6 @@ from __future__ import annotations import functools - import importlib import json import os @@ -14,21 +13,18 @@ import tempfile from collections import defaultdict from pathlib import Path -from typing import Callable, List +from typing import Callable import numpy as np - import torch from tensordict import PersistentTensorDict, TensorDict from torch import multiprocessing as mp - from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger from torchrl.data.datasets.common import BaseDatasetExperienceReplay from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer - from torchrl.envs.transforms import Compose, Resize, ToTensorImage from torchrl.envs.utils import _classproperty @@ -172,10 +168,10 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, totensor: bool = True, - image_size: int | List[int] | None = None, + image_size: int | list[int] | None = None, num_workers: int = 0, **env_kwargs, ): @@ -388,7 +384,7 @@ def _available_datasets(cls): return [str(path)[6:] for path in sibs] except Exception: # return the default datasets - with open(THIS_DIR / "vd4rl.json", "r") as file: + with open(THIS_DIR / "vd4rl.json") as file: return json.load(file) def _make_split(self): diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py index c9bc25477c2..6a28a3ff856 100644 --- a/torchrl/data/map/__init__.py +++ b/torchrl/data/map/__init__.py @@ -7,3 +7,15 @@ from .query import HashToInt, QueryModule from .tdstorage import TensorDictMap, TensorMap from .tree import MCTSForest, Tree + +__all__ = [ + "BinaryToDecimal", + "RandomProjectionHash", + "SipHash", + "HashToInt", + "QueryModule", + "TensorDictMap", + "TensorMap", + "MCTSForest", + "Tree", +] diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index a3ae9ec1ae9..446e71f56fb 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from typing import Callable, List +from typing import Callable import torch from torch.nn import Module @@ -103,7 +103,7 @@ def __init__(self, as_tensor: bool = True): super().__init__() self.as_tensor = as_tensor - def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]: + def forward(self, x: torch.Tensor) -> torch.Tensor | list[bytes]: hash_values = [] if x.dtype in (torch.bfloat16,): x = x.to(torch.float16) diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py index 3eca179cf56..d53e2087760 100644 --- a/torchrl/data/map/query.py +++ b/torchrl/data/map/query.py @@ -5,7 +5,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, TypeVar +from typing import Any, Callable, List, Mapping, TypeVar import torch import torch.nn as nn @@ -39,7 +39,7 @@ def __call__(self, key: torch.Tensor, extend: bool = False) -> torch.Tensor: ) return torch.tensor(result, device=key.device, dtype=key.dtype) - def state_dict(self) -> Dict[str, torch.Tensor]: + def state_dict(self) -> dict[str, torch.Tensor]: values = torch.tensor(self._index_to_index.values()) keys = torch.tensor(self._index_to_index.keys()) return {"keys": keys, "values": values} @@ -111,11 +111,11 @@ class QueryModule(TensorDictModuleBase): def __init__( self, - in_keys: List[NestedKey], + in_keys: list[NestedKey], index_key: NestedKey = "_index", hash_key: NestedKey = "_hash", *, - hash_module: Callable[[Any], int] | List[Callable[[Any], int]] | None = None, + hash_module: Callable[[Any], int] | list[Callable[[Any], int]] | None = None, hash_to_int: Callable[[int], int] | None = None, aggregator: Callable[[Any], int] = None, clone: bool = False, diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index 34d4bb8d0fa..f1464308144 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -7,11 +7,12 @@ import abc import functools from abc import abstractmethod -from typing import Any, Callable, Dict, Generic, List, TypeVar +from typing import Any, Callable, Generic, TypeVar import torch from tensordict import is_tensor_collection, NestedKey, TensorDictBase from tensordict.nn.common import TensorDictModuleBase + from torchrl.data.map.hash import RandomProjectionHash, SipHash from torchrl.data.map.query import QueryModule from torchrl.data.replay_buffers.storages import ( @@ -117,9 +118,9 @@ def __init__( self, *, query_module: QueryModule, - storage: Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]], + storage: dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]], collate_fn: Callable[[Any], Any] | None = None, - out_keys: List[NestedKey] | None = None, + out_keys: list[NestedKey] | None = None, write_fn: Callable[[Any, Any], Any] | None = None, ): super().__init__() @@ -143,7 +144,7 @@ def max_size(self): return self.storage.max_size @property - def out_keys(self) -> List[NestedKey]: + def out_keys(self) -> list[NestedKey]: out_keys = self.__dict__.get("_out_keys_and_lazy") if out_keys is not None: return out_keys[0] @@ -173,8 +174,8 @@ def from_tensordict_pair( cls, source, dest, - in_keys: List[NestedKey], - out_keys: List[NestedKey] | None = None, + in_keys: list[NestedKey], + out_keys: list[NestedKey] | None = None, max_size: int = 1000, storage_constructor: type | None = None, hash_module: Callable | None = None, diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 684c4f9901b..dfb87223435 100644 --- a/torchrl/data/map/tree.py +++ 
b/torchrl/data/map/tree.py @@ -6,8 +6,7 @@ import weakref from collections import deque - -from typing import Any, Callable, Dict, List, Literal, Tuple +from typing import Any, Callable, Literal import torch from tensordict import ( @@ -18,11 +17,11 @@ TensorDictBase, unravel_key, ) + from torchrl.data.map.tdstorage import TensorDictMap from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree from torchrl.data.replay_buffers.storages import ListStorage from torchrl.data.tensor_specs import Composite - from torchrl.envs.common import EnvBase @@ -88,10 +87,10 @@ class Tree(TensorClass["nocast"]): node_data: TensorDict | None = None # Stack of subtrees. A subtree is produced when an action is taken. - subtree: "Tree" = None + subtree: Tree = None # weakrefs to the parent(s) of the node - _parent: weakref.ref | List[weakref.ref] | None = None + _parent: weakref.ref | list[weakref.ref] | None = None # Specs: contains information such as action or observation keys and spaces. # If present, they should be structured like env specs are: @@ -389,7 +388,7 @@ def __contains__(self, other: Tree) -> bool: def vertices( self, *, key_type: Literal["id", "hash", "path"] = "hash" - ) -> Dict[int | Tuple[int], Tree]: + ) -> dict[int | tuple[int], Tree]: """Returns a map containing the vertices of the Tree. Keyword args: @@ -463,7 +462,7 @@ def num_vertices(self, *, count_repeat: bool = False) -> int: } ) - def edges(self) -> List[Tuple[int, int]]: + def edges(self) -> list[tuple[int, int]]: """Retrieves a list of edges in the tree. Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. @@ -530,7 +529,7 @@ def max_length(self): return lengths[0] return max(*lengths) - def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: + def rollout_from_path(self, path: tuple[int]) -> TensorDictBase | None: """Retrieves the rollout data along a given path in the tree. The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. @@ -557,7 +556,7 @@ def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: return torch.cat(rollouts, dim=-1) @staticmethod - def _label(info: List[str], tree: "Tree", root=False): + def _label(info: list[str], tree: Tree, root=False): labels = [] for key in info: if key == "hash": @@ -577,7 +576,7 @@ def plot( self: Tree, backend: str = "plotly", figure: str = "tree", - info: List[str] = None, + info: list[str] = None, make_labels: Callable[[Any, ...], Any] | None = None, ): """Plots a visualization of the tree using the specified backend and figure type. @@ -811,11 +810,11 @@ def __init__( data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, - done_keys: List[NestedKey] | None = None, - reward_keys: List[NestedKey] = None, - observation_keys: List[NestedKey] = None, - action_keys: List[NestedKey] = None, - excluded_keys: List[NestedKey] = None, + done_keys: list[NestedKey] | None = None, + reward_keys: list[NestedKey] = None, + observation_keys: list[NestedKey] = None, + action_keys: list[NestedKey] = None, + excluded_keys: list[NestedKey] = None, consolidated: bool | None = None, ): @@ -856,7 +855,7 @@ def __init__( self.consolidated = consolidated @property - def done_keys(self) -> List[NestedKey]: + def done_keys(self) -> list[NestedKey]: """Done Keys. Returns the keys used to indicate that an episode has ended. 
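The `vertices`, `edges` and `rollout_from_path` hunks above only trade `Dict`/`List`/`Tuple` for their builtin equivalents; runtime behavior is unchanged. For orientation, a short usage sketch (assuming a populated `forest: MCTSForest` and a `root` TensorDict supplied by user code, neither shown in this patch):

    tree = forest.get_tree(root)
    for path, node in tree.vertices(key_type="path").items():
        print(path, node.num_vertices())
    print(tree.edges())                        # [(parent_id, child_id), ...]
    rollout = tree.rollout_from_path((0, 0))   # branch steps, concatenated along dim=-1
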
@@ -877,7 +876,7 @@ def done_keys(self, value): self._done_keys = _make_list_of_nestedkeys(value, "done_keys") @property - def reward_keys(self) -> List[NestedKey]: + def reward_keys(self) -> list[NestedKey]: """Reward Keys. Returns the keys used to retrieve rewards from the environment's output. @@ -897,7 +896,7 @@ def reward_keys(self, value): self._reward_keys = _make_list_of_nestedkeys(value, "reward_keys") @property - def action_keys(self) -> List[NestedKey]: + def action_keys(self) -> list[NestedKey]: """Action Keys. Returns the keys used to retrieve actions from the environment's input. @@ -917,7 +916,7 @@ def action_keys(self, value): self._action_keys = _make_list_of_nestedkeys(value, "action_keys") @property - def observation_keys(self) -> List[NestedKey]: + def observation_keys(self) -> list[NestedKey]: """Observation Keys. Returns the keys used to retrieve observations from the environment's output. @@ -936,7 +935,7 @@ def observation_keys(self, value): self._observation_keys = _make_list_of_nestedkeys(value, "observation_keys") @property - def excluded_keys(self) -> List[NestedKey] | None: + def excluded_keys(self) -> list[NestedKey] | None: return self._excluded_keys @excluded_keys.setter @@ -1223,7 +1222,7 @@ def _make_local_tree( root: TensorDictBase, index: torch.Tensor | None = None, compact: bool = True, - ) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]: + ) -> tuple[Tree, torch.Tensor | None, torch.Tensor | None]: root = root.select(*self.node_map.in_keys) node_meta = None if root in self.node_map: @@ -1422,7 +1421,7 @@ def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict( return tree.to_string(node_format_fn) -def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]: +def _make_list_of_nestedkeys(obj: Any, attr: str) -> list[NestedKey]: if obj is None: return obj if isinstance(obj, (str, tuple)): diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index d9588d79905..43bd8ea2832 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -4,13 +4,13 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Callable, List +from typing import Callable from tensordict import NestedKey def _plot_plotly_tree( - tree: "Tree", make_labels: Callable[[Tree], str] | None = None # noqa: F821 + tree: Tree, make_labels: Callable[[Tree], str] | None = None # noqa: F821 ): import plotly.graph_objects as go from igraph import Graph @@ -78,7 +78,7 @@ def make_labels(tree, path, *args, **kwargs): fig.show() -def _plot_plotly_box(tree: "Tree", info: List[NestedKey] = None): # noqa: F821 +def _plot_plotly_box(tree: Tree, info: list[NestedKey] = None): # noqa: F821 import plotly.graph_objects as go if info is None: @@ -89,7 +89,7 @@ def _plot_plotly_box(tree: "Tree", info: List[NestedKey] = None): # noqa: F821 _tree = tree - def extend(tree: "Tree", parent): # noqa: F821 + def extend(tree: Tree, parent): # noqa: F821 children = tree.subtree if children is None: return diff --git a/torchrl/data/postprocs/__init__.py b/torchrl/data/postprocs/__init__.py index 707740f6946..afa0f73ecfb 100644 --- a/torchrl/data/postprocs/__init__.py +++ b/torchrl/data/postprocs/__init__.py @@ -4,3 +4,5 @@ # LICENSE file in the root directory of this source tree. 
from .postprocs import MultiStep + +__all__ = ["MultiStep"] diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 4f230f30701..d3e8f18cb00 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -48,3 +48,46 @@ Writer, WriterEnsemble, ) + +__all__ = [ + "FlatStorageCheckpointer", + "H5StorageCheckpointer", + "ListStorageCheckpointer", + "NestedStorageCheckpointer", + "StorageCheckpointerBase", + "StorageEnsembleCheckpointer", + "TensorStorageCheckpointer", + "PrioritizedReplayBuffer", + "RemoteTensorDictReplayBuffer", + "ReplayBuffer", + "ReplayBufferEnsemble", + "TensorDictPrioritizedReplayBuffer", + "TensorDictReplayBuffer", + "PrioritizedSampler", + "PrioritizedSliceSampler", + "RandomSampler", + "Sampler", + "SamplerEnsemble", + "SamplerWithoutReplacement", + "SliceSampler", + "SliceSamplerWithoutReplacement", + "LazyMemmapStorage", + "LazyStackStorage", + "LazyTensorStorage", + "ListStorage", + "Storage", + "StorageEnsemble", + "TensorStorage", + "Flat2TED", + "H5Combine", + "H5Split", + "Nested2TED", + "TED2Flat", + "TED2Nested", + "ImmutableDatasetWriter", + "RoundRobinWriter", + "TensorDictMaxValueWriter", + "TensorDictRoundRobinWriter", + "Writer", + "WriterEnsemble", +] diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 6b74834385e..b545fd92227 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -120,7 +120,7 @@ def dumps(self, storage, path): ) def loads(self, storage, path): - with open(path / "storage_metadata.json", "r") as file: + with open(path / "storage_metadata.json") as file: metadata = json.load(file) is_pytree = metadata["is_pytree"] _len = metadata["len"] diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 6e8e879b512..20e029fc535 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -13,10 +13,9 @@ import warnings from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Sequence import numpy as np - import torch try: @@ -222,10 +221,10 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 batch_size: int | None = None, dim_extend: int | None = None, - checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821 + checkpointer: StorageCheckpointerBase | None = None, # noqa: F821 generator: torch.Generator | None = None, shared: bool = False, compilable: bool = None, @@ -460,7 +459,7 @@ def __setitem__(self, index, value) -> None: self._storage[index] = value return - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_storage": self._storage.state_dict(), "_sampler": self._sampler.state_dict(), @@ -472,7 +471,7 @@ def state_dict(self) -> Dict[str, Any]: else None, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) self._sampler.load_state_dict(state_dict["_sampler"]) self._writer.load_state_dict(state_dict["_writer"]) @@ -564,7 +563,7 @@ def loads(self, path): # fall back on state_dict 
for transforms if (path / "transform.t").exists(): self._transform.load_state_dict(torch.load(path / "transform.t")) - with open(path / "buffer_metadata.json", "r") as file: + with open(path / "buffer_metadata.json") as file: metadata = json.load(file) self._batch_size = metadata["batch_size"] @@ -662,8 +661,8 @@ def extend(self, data: Sequence) -> torch.Tensor: def update_priority( self, - index: Union[int, torch.Tensor, Tuple[torch.Tensor]], - priority: Union[int, torch.Tensor], + index: int | torch.Tensor | tuple[torch.Tensor], + priority: int | torch.Tensor, ) -> None: if isinstance(index, tuple): index = torch.stack(index, -1) @@ -675,7 +674,7 @@ def update_priority( self._sampler.update_priority(index, priority, storage=self.storage) @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, dict]: + def _sample(self, batch_size: int) -> tuple[Any, dict]: with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index @@ -755,11 +754,11 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An return out, info return result[0] - def mark_update(self, index: Union[int, torch.Tensor]) -> None: + def mark_update(self, index: int | torch.Tensor) -> None: self._sampler.mark_update(index, storage=self._storage) def append_transform( - self, transform: "Transform", *, invert: bool = False # noqa-F821 + self, transform: Transform, *, invert: bool = False # noqa-F821 ) -> ReplayBuffer: # noqa: D417 """Appends transform at the end. @@ -796,7 +795,7 @@ def append_transform( def insert_transform( self, index: int, - transform: "Transform", # noqa-F821 + transform: Transform, # noqa-F821 *, invert: bool = False, ) -> ReplayBuffer: # noqa: D417 @@ -832,7 +831,7 @@ def __iter__(self): ): yield self.sample() - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() if self._rng is not None: rng_state = TensorDict( @@ -848,7 +847,7 @@ def __getstate__(self) -> Dict[str, Any]: state["_futures_lock_placeholder"] = None return state - def __setstate__(self, state: Dict[str, Any]): + def __setstate__(self, state: dict[str, Any]): rngstate = None if "_rng" in state: rngstate = state["_rng"] @@ -1008,14 +1007,14 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 batch_size: int | None = None, dim_extend: int | None = None, ) -> None: if storage is None: storage = ListStorage(max_size=1_000) sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype) - super(PrioritizedReplayBuffer, self).__init__( + super().__init__( storage=storage, sampler=sampler, collate_fn=collate_fn, @@ -1355,7 +1354,7 @@ def sample( return data @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, dict]: + def _sample(self, batch_size: int) -> tuple[Any, dict]: with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index @@ -1523,7 +1522,7 @@ def __init__( collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, - transform: "Transform" | None = None, # noqa-F821 + transform: Transform | None = None, # noqa-F821 reduction: str = "max", batch_size: int | None = None, dim_extend: int | None = None, @@ -1536,7 +1535,7 @@ def 
__init__( sampler = PrioritizedSampler( storage.max_size, alpha, beta, eps, reduction=reduction ) - super(TensorDictPrioritizedReplayBuffer, self).__init__( + super().__init__( priority_key=priority_key, storage=storage, sampler=sampler, @@ -1572,11 +1571,11 @@ def sample( def add(self, data: TensorDictBase) -> int: return super().add(data) - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + def extend(self, tensordicts: list | TensorDictBase) -> torch.Tensor: return super().extend(tensordicts) def update_priority( - self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] + self, index: int | torch.Tensor, priority: int | torch.Tensor ) -> None: return super().update_priority(index, priority) @@ -1593,7 +1592,7 @@ def __init__(self, device: DEVICE_TYPING | None = None): ) -def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: +def stack_tensors(list_of_tensor_iterators: list) -> tuple[torch.Tensor]: """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together. Args: @@ -1765,10 +1764,10 @@ def __init__( storages: StorageEnsemble | None = None, samplers: SamplerEnsemble | None = None, writers: WriterEnsemble | None = None, - transform: "Transform" | None = None, # noqa: F821 + transform: Transform | None = None, # noqa: F821 batch_size: int | None = None, collate_fn: Callable | None = None, - collate_fns: List[Callable] | None = None, + collate_fns: list[Callable] | None = None, p: Tensor = None, sample_from_all: bool = False, num_buffer_sampled: int | None = None, @@ -1849,7 +1848,7 @@ def _collate_fn(self, value): _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead." def __getitem__( - self, index: Union[int, torch.Tensor, Tuple, np.ndarray, List, slice, Ellipsis] + self, index: int | torch.Tensor | tuple | np.ndarray | list | slice | Ellipsis ) -> Any: # accepts inputs: # (int | 1d tensor | 1d list | 1d array | slice | ellipsis | range, int | tensor | list | array | slice | ellipsis | range) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 911280eb667..19db2fa9431 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -12,16 +12,13 @@ from copy import copy, deepcopy from multiprocessing.context import get_spawning_popen from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any import numpy as np import torch - from tensordict import MemoryMappedTensor, TensorDict from tensordict.utils import NestedKey - from torchrl._extension import EXTENSION_WARNING - from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index @@ -50,7 +47,7 @@ class Sampler(ABC): _rng: torch.Generator | None = None @abstractmethod - def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]: ... 
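Since `Sampler.sample` is abstract, the builtin-generics rewrite also spells out what a custom sampler must return: a `tuple` of indices and an info dict. A minimal, hypothetical subclass for illustration (assuming `_empty`, `dumps` and `loads` are the remaining abstract methods):

    import torch
    from torchrl.data.replay_buffers.samplers import Sampler
    from torchrl.data.replay_buffers.storages import Storage

    class OldestFirstSampler(Sampler):
        """Always returns the oldest `batch_size` entries (illustrative only)."""

        def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
            # Indices into the storage, plus an (empty) info dict.
            return torch.arange(min(batch_size, len(storage))), {}

        def _empty(self):
            ...

        def dumps(self, path):
            ...

        def loads(self, path):
            ...

        def state_dict(self) -> dict:
            return {}

        def load_state_dict(self, state_dict) -> None:
            ...
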
def add(self, index: int) -> None: @@ -61,8 +58,8 @@ def extend(self, index: torch.Tensor) -> None: def update_priority( self, - index: Union[int, torch.Tensor], - priority: Union[float, torch.Tensor], + index: int | torch.Tensor, + priority: float | torch.Tensor, *, storage: Storage | None = None, ) -> dict | None: @@ -72,7 +69,7 @@ def update_priority( return def mark_update( - self, index: Union[int, torch.Tensor], *, storage: Storage | None = None + self, index: int | torch.Tensor, *, storage: Storage | None = None ) -> None: return @@ -81,11 +78,11 @@ def default_priority(self) -> float: return 1.0 @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: ... @abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... @property @@ -123,7 +120,7 @@ class RandomSampler(Sampler): """ - def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]: if len(storage) == 0: raise RuntimeError(_EMPTY_STORAGE_ERROR) index = storage._rand_given_ndim(batch_size) @@ -140,10 +137,10 @@ def loads(self, path): # no op ... - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: return @@ -232,7 +229,7 @@ def _storage_len(self, storage): def sample( self, storage: Storage, batch_size: int - ) -> Tuple[Any, dict]: # noqa: F811 + ) -> tuple[Any, dict]: # noqa: F811 len_storage = self._storage_len(storage) if len_storage == 0: raise RuntimeError(_EMPTY_STORAGE_ERROR) @@ -269,7 +266,7 @@ def _empty(self): self.len_storage = 0 self._ran_out = False - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return OrderedDict( len_storage=self.len_storage, _sample_list=self._sample_list, @@ -277,7 +274,7 @@ def state_dict(self) -> Dict[str, Any]: _ran_out=self._ran_out, ) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.len_storage = state_dict["len_storage"] self._sample_list = state_dict["_sample_list"] self.drop_last = state_dict["drop_last"] @@ -542,8 +539,8 @@ def extend(self, index: torch.Tensor | tuple) -> None: @torch.no_grad() def update_priority( self, - index: Union[int, torch.Tensor], - priority: Union[float, torch.Tensor], + index: int | torch.Tensor, + priority: float | torch.Tensor, *, storage: TensorStorage | None = None, ) -> None: # noqa: D417 @@ -626,11 +623,11 @@ def update_priority( self._max_priority = (maxval, maxidx) def mark_update( - self, index: Union[int, torch.Tensor], *, storage: Storage | None = None + self, index: int | torch.Tensor, *, storage: Storage | None = None ) -> None: self.update_priority(index, self.default_priority, storage=storage) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_alpha": self._alpha, "_beta": self._beta, @@ -640,7 +637,7 @@ def state_dict(self) -> Dict[str, Any]: "_min_tree": deepcopy(self._min_tree), } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._alpha = state_dict["_alpha"] self._beta = state_dict["_beta"] self._eps = state_dict["_eps"] @@ -693,7 +690,7 @@ def dumps(self, path): def loads(self, path): path 
= Path(path).absolute() - with open(path / "sampler_metadata.json", "r") as file: + with open(path / "sampler_metadata.json") as file: metadata = json.load(file) self._alpha = metadata["_alpha"] self._beta = metadata["_beta"] @@ -992,7 +989,7 @@ def __init__( truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, compile: bool | dict = False, - span: bool | int | Tuple[bool | int, bool | int] = False, + span: bool | int | tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False, ): self.num_slices = num_slices @@ -1324,7 +1321,7 @@ def _adjusted_batch_size(self, batch_size): num_slices = batch_size // self.slice_len return seq_length, num_slices - def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]: if self._batch_size_multiplier is not None: batch_size = batch_size * self._batch_size_multiplier # pick up as many trajs as we need @@ -1361,7 +1358,7 @@ def _sample_slices( traj_idx: torch.Tensor | None = None, *, storage, - ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]: + ) -> tuple[tuple[torch.Tensor, ...], dict[str, Any]]: # start_idx and stop_idx are 2d tensors organized like a non-zero def get_traj_idx(maxval): @@ -1442,7 +1439,7 @@ def _get_index( traj_idx: torch.Tensor | None = None, *, storage, - ) -> Tuple[torch.Tensor, dict]: + ) -> tuple[torch.Tensor, dict]: # end_point is the last possible index for start last_indexable_start = lengths[traj_idx] - seq_length + 1 if not self.span[1]: @@ -1555,10 +1552,10 @@ def loads(self, path): # no op ... - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... 
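The `state_dict`/`load_state_dict` annotations now advertise plain `dict[str, Any]`, but the round-trip contract is the same. A sketch of that contract for `PrioritizedSampler` (argument values are arbitrary):

    from torchrl.data.replay_buffers.samplers import PrioritizedSampler

    sampler = PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.9)
    sd = sampler.state_dict()     # _alpha, _beta, _eps, _max_priority and both trees
    restored = PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.9)
    restored.load_state_dict(sd)  # restored now samples like `sampler`
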
@@ -1787,7 +1784,7 @@ def _storage_len(self, storage): def sample( self, storage: Storage, batch_size: int - ) -> Tuple[Tuple[torch.Tensor, ...], dict]: + ) -> tuple[tuple[torch.Tensor, ...], dict]: if self._batch_size_multiplier is not None: batch_size = batch_size * self._batch_size_multiplier start_idx, stop_idx, lengths = self._get_stop_and_length(storage) @@ -1827,10 +1824,10 @@ def tuple_to_tensor(traj_idx, lengths=lengths): ) return idx, info - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return SamplerWithoutReplacement.state_dict(self) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: return SamplerWithoutReplacement.load_state_dict(self, state_dict) @@ -1985,7 +1982,7 @@ def __init__( truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, compile: bool | dict = False, - span: bool | int | Tuple[bool | int, bool | int] = False, + span: bool | int | tuple[bool | int, bool | int] = False, max_priority_within_buffer: bool = False, ): SliceSampler.__init__( @@ -2045,7 +2042,7 @@ def __getstate__(self): return state def mark_update( - self, index: Union[int, torch.Tensor], *, storage: Storage | None = None + self, index: int | torch.Tensor, *, storage: Storage | None = None ) -> None: return PrioritizedSampler.mark_update(self, index, storage=storage) @@ -2111,7 +2108,7 @@ def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx): self._cache["preceding_stop_idx"] = preceding_stop_idx return preceding_stop_idx - def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]: # Sample `batch_size` indices representing the start of a slice. # The sampling is based on a weight vector. start_idx, stop_idx, lengths = self._get_stop_and_length(storage) @@ -2388,13 +2385,13 @@ def loads(self, path: Path): for i, sampler in enumerate(self._samplers): sampler.loads(path / str(i)) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: state_dict = OrderedDict() for i, sampler in enumerate(self._samplers): state_dict[str(i)] = sampler.state_dict() return state_dict - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: for i, sampler in enumerate(self._samplers): sampler.load_state_dict(state_dict[str(i)]) diff --git a/torchrl/data/replay_buffers/scheduler.py b/torchrl/data/replay_buffers/scheduler.py index 6829424c620..4c031cd6082 100644 --- a/torchrl/data/replay_buffers/scheduler.py +++ b/torchrl/data/replay_buffers/scheduler.py @@ -5,13 +5,10 @@ from __future__ import annotations from abc import ABC, abstractmethod - -from typing import Any, Callable, Dict +from typing import Any, Callable import numpy as np - import torch - from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler @@ -69,7 +66,7 @@ def state_dict(self): del sd["sampler"] return sd - def load_state_dict(self, state_dict: Dict[str, Any]): + def load_state_dict(self, state_dict: dict[str, Any]): """Load the scheduler's state. 
Args: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 89d7a4dbe43..1d6a4ac69e4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -5,7 +5,6 @@ from __future__ import annotations import abc - import logging import os import textwrap @@ -13,7 +12,7 @@ from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Sequence import numpy as np import tensordict @@ -29,6 +28,7 @@ from tensordict.utils import _zip_strict from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( ListStorageCheckpointer, @@ -86,7 +86,7 @@ def _is_full(self): return len(self) == self.max_size @property - def _attached_entities(self) -> List: + def _attached_entities(self) -> list: # RBs that use a given instance of Storage should add # themselves to this set. _attached_entities_list = getattr(self, "_attached_entities_list", None) @@ -142,11 +142,11 @@ def __len__(self): ... @abc.abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: ... @abc.abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... @abc.abstractmethod @@ -241,7 +241,7 @@ def __init__(self, max_size: int | None = None, compilable: bool = False): def set( self, - cursor: Union[int, Sequence[int], slice], + cursor: int | Sequence[int] | slice, data: Any, *, set_cursor: bool = True, @@ -294,7 +294,7 @@ def set( else: self._storage[cursor] = data - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: if isinstance(index, (INT_CLASSES, slice)): return self._storage[index] elif isinstance(index, tuple): @@ -311,7 +311,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: def __len__(self): return len(self._storage) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "_storage": [ elt if not hasattr(elt, "state_dict") else elt.state_dict() @@ -382,7 +382,7 @@ class LazyStackStorage(ListStorage): Keyword Args: compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at the cost of being executable in multiprocessed settings. - stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`. + stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`. 
Examples: >>> import torch @@ -416,12 +416,12 @@ def __init__( max_size: int | None = None, *, compilable: bool = False, - stack_dim: int = -1, + stack_dim: int = 0, ): super().__init__(max_size=max_size, compilable=compilable) self.stack_dim = stack_dim - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: out = super().get(index=index) if isinstance(out, list): stack_dim = self.stack_dim @@ -720,7 +720,7 @@ def __setstate__(self, state): state["_len_value"] = _len_value self.__dict__.update(state) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): pass @@ -794,8 +794,8 @@ def _get_new_len(self, data, cursor): @implement_for("torch", "2.0", None, compilable=True) def set( self, - cursor: Union[int, Sequence[int], slice], - data: Union[TensorDictBase, torch.Tensor], + cursor: int | Sequence[int] | slice, + data: TensorDictBase | torch.Tensor, *, set_cursor: bool = True, ): @@ -836,8 +836,8 @@ def set( @implement_for("torch", None, "2.0", compilable=True) def set( # noqa: F811 self, - cursor: Union[int, Sequence[int], slice], - data: Union[TensorDictBase, torch.Tensor], + cursor: int | Sequence[int] | slice, + data: TensorDictBase | torch.Tensor, *, set_cursor: bool = True, ): @@ -888,7 +888,7 @@ def set( # noqa: F811 ) self._storage[cursor] = data - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: _storage = self._storage is_tc = is_tensor_collection(_storage) if not self.initialized: @@ -1062,7 +1062,7 @@ def __init__( def _init( self, - data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 + data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821 ) -> None: if not self._compilable: # TODO: Investigate why this seems to have a performance impact with @@ -1225,7 +1225,7 @@ def __init__( ) self._len = 0 - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): _storage = _mem_map_tensor_as_tensor(_storage) @@ -1282,7 +1282,7 @@ def load_state_dict(self, state_dict): self.initialized = state_dict["initialized"] self._len = state_dict["_len"] - def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: + def _init(self, data: TensorDictBase | torch.Tensor) -> None: torchrl_logger.debug("Creating a MemmapStorage...") if self.device == "auto": self.device = data.device @@ -1324,7 +1324,7 @@ def max_size_along_dim0(data_shape): self._storage = out self.initialized = True - def get(self, index: Union[int, Sequence[int], slice]) -> Any: + def get(self, index: int | Sequence[int] | slice) -> Any: result = super().get(index) return result @@ -1357,7 +1357,7 @@ class StorageEnsemble(Storage): def __init__( self, *storages: Storage, - transforms: List["Transform"] = None, # noqa: F821 + transforms: list[Transform] = None, # noqa: F821 ): self._rng_private = None self._storages = storages @@ -1408,10 +1408,10 @@ def _convert_id(self, sub): def _get_storage(self, sub): return self._storages[sub] - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, 
got {} instead." diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 1e8985537f3..21c414e19bf 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -7,7 +7,6 @@ import contextlib import itertools - import math import operator import os @@ -93,7 +92,7 @@ def _pin_memory(output: Any) -> Any: def _reduce( tensor: torch.Tensor, reduction: str, dim: int | None = None -) -> Union[float, torch.Tensor]: +) -> float | torch.Tensor: """Reduces a tensor given the reduction method.""" if reduction == "max": result = tensor.max(dim=dim) @@ -977,15 +976,13 @@ def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None): # Copy-paste of unravel-index for PT 2.0 def _unravel_index( - indices: Tensor, shape: Union[int, typing.Sequence[int], torch.Size] -) -> typing.Tuple[Tensor, ...]: + indices: Tensor, shape: int | typing.Sequence[int] | torch.Size +) -> tuple[Tensor, ...]: res_tensor = _unravel_index_impl(indices, shape) return res_tensor.unbind(-1) -def _unravel_index_impl( - indices: Tensor, shape: Union[int, typing.Sequence[int]] -) -> Tensor: +def _unravel_index_impl(indices: Tensor, shape: int | typing.Sequence[int]) -> Tensor: if isinstance(shape, (int, torch.SymInt)): shape = torch.Size([shape]) else: diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index c7043f8829b..f7fd2a5eef2 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -11,11 +11,10 @@ from copy import copy from multiprocessing.context import get_spawning_popen from pathlib import Path -from typing import Any, Dict, Sequence +from typing import Any, Sequence import numpy as np import torch - from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase from tensordict.utils import _STRDTYPE2DTYPE, expand_as_right, is_tensorclass from torch import multiprocessing as mp @@ -70,11 +69,11 @@ def loads(self, path): ... @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: ... @abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... def _replicate_index(self, index): @@ -131,10 +130,10 @@ def dumps(self, path): def loads(self, path): ... 
- def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: return @@ -160,7 +159,7 @@ def dumps(self, path): def loads(self, path): path = Path(path).absolute() - with open(path / "metadata.json", "r") as file: + with open(path / "metadata.json") as file: metadata = json.load(file) self._cursor = metadata["cursor"] @@ -209,10 +208,10 @@ def extend(self, data: Sequence) -> torch.Tensor: ent.mark_update(index) return index - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"_cursor": self._cursor} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._cursor = state_dict["_cursor"] def _empty(self): @@ -611,7 +610,7 @@ def dumps(self, path): def loads(self, path): path = Path(path).absolute() - with open(path / "metadata.json", "r") as file: + with open(path / "metadata.json") as file: metadata = json.load(file) self._cursor = metadata["cursor"] self._rank_key = metadata["rank_key"] @@ -623,10 +622,10 @@ def loads(self, path): shape=shape, ).tolist() - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError def __repr__(self): @@ -731,8 +730,8 @@ def __repr__(self): writers = textwrap.indent(f"writers={self._writers}", " " * 4) return f"WriterEnsemble(\n{writers})" - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: raise NotImplementedError - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise NotImplementedError diff --git a/torchrl/data/rlhf/__init__.py b/torchrl/data/rlhf/__init__.py index f0db092f2d1..f9a961ec389 100644 --- a/torchrl/data/rlhf/__init__.py +++ b/torchrl/data/rlhf/__init__.py @@ -12,3 +12,17 @@ from .prompt import PromptData, PromptTensorDictTokenizer from .reward import PairwiseDataset, RewardData from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel + +__all__ = [ + "create_infinite_iterator", + "get_dataloader", + "TensorDictTokenizer", + "TokenizedDatasetLoader", + "PromptData", + "PromptTensorDictTokenizer", + "PairwiseDataset", + "RewardData", + "AdaptiveKLController", + "ConstantKLController", + "RolloutFromModel", +] diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index a0905c2d063..346a6046566 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -7,13 +7,10 @@ import importlib.util import os from pathlib import Path - -from typing import Sequence, Type +from typing import Sequence import torch - from tensordict import TensorDict, TensorDictBase - from tensordict.utils import NestedKey from torchrl._utils import logger as torchrl_logger from torchrl.data.replay_buffers import ( @@ -94,7 +91,7 @@ def __init__( split, max_length, dataset_name, - tokenizer_fn: Type[TensorDictTokenizer], + tokenizer_fn: type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, @@ -227,7 +224,7 @@ def _tokenize( @staticmethod def dataset_to_tensordict( - dataset: "datasets.Dataset" | TensorDict, # noqa: F821 + dataset: datasets.Dataset | TensorDict, # noqa: F821 data_dir: Path, 
prefix: NestedKey = None, features: Sequence[str] = None, @@ -320,7 +317,7 @@ def create_infinite_iterator(iterator): def get_dataloader( batch_size: int, block_size: int, - tensorclass_type: Type, + tensorclass_type: type, device: torch.device, dataset_name: str | None = None, infinite: bool = True, diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py index 6f41fe48698..8d6cfe54066 100644 --- a/torchrl/data/rlhf/prompt.py +++ b/torchrl/data/rlhf/prompt.py @@ -4,8 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Optional - import torch from tensordict import tensorclass, TensorDict @@ -21,9 +19,9 @@ class PromptData: input_ids: torch.Tensor attention_mask: torch.Tensor prompt_rindex: torch.Tensor - labels: Optional[torch.Tensor] = None - logits: Optional[torch.Tensor] = None - loss: Optional[torch.Tensor] = None + labels: torch.Tensor | None = None + logits: torch.Tensor | None = None + loss: torch.Tensor | None = None def mask_label(self, pad_token_id=50256): _, block_size = self.input_ids.shape diff --git a/torchrl/data/rlhf/reward.py b/torchrl/data/rlhf/reward.py index 98976984e27..12e2deef2b2 100644 --- a/torchrl/data/rlhf/reward.py +++ b/torchrl/data/rlhf/reward.py @@ -5,12 +5,9 @@ from __future__ import annotations import importlib -from typing import Optional import torch - from tensordict import tensorclass - from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader DEFAULT_DATASET = "CarperAI/openai_summarize_comparisons" @@ -24,8 +21,8 @@ class RewardData: input_ids: torch.Tensor attention_mask: torch.Tensor - rewards: Optional[torch.Tensor] = None - end_scores: Optional[torch.Tensor] = None + rewards: torch.Tensor | None = None + end_scores: torch.Tensor | None = None @tensorclass diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index bbde6761f4a..a816c984062 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -7,11 +7,9 @@ import abc import collections import importlib -from typing import List, Tuple import numpy as np import torch - from tensordict import TensorDict from torch import nn, Tensor from torch.nn import functional as F @@ -30,7 +28,7 @@ class KLControllerBase(abc.ABC): """ @abc.abstractmethod - def update(self, kl_values: List[float]) -> float: + def update(self, kl_values: list[float]) -> float: ... @@ -63,7 +61,7 @@ def __init__( if model is not None: self.model.kl_coef = self.coef - def update(self, kl_values: List[float] = None) -> float: + def update(self, kl_values: list[float] = None) -> float: if self.model is not None: self.model.kl_coef = self.coef return self.coef @@ -104,7 +102,7 @@ def __init__( if model is not None: self.model.kl_coef = self.coef - def update(self, kl_values: List[float]): + def update(self, kl_values: list[float]): """Update ``self.coef`` adaptively. 
Arguments: @@ -422,7 +420,7 @@ def _default_conf(self): ) def _get_scores( - self, scores: Tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None + self, scores: tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None ): scores = torch.stack(scores, 1) if scores.shape[1] != self.max_new_tokens: diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index d43cbd7810d..1d3777eb48d 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -10,9 +10,7 @@ import numpy as np import torch - from torch import Tensor - from torchrl.data.tensor_specs import ( Binary, Categorical, @@ -139,7 +137,7 @@ def consolidate_spec( return spec -def _empty_like_spec(specs: List[TensorSpec], shape): +def _empty_like_spec(specs: list[TensorSpec], shape): for spec in specs[1:]: if spec.__class__ != specs[0].__class__: raise ValueError( @@ -224,7 +222,7 @@ def contains_lazy_spec(spec: TensorSpec) -> bool: return False -class CloudpickleWrapper(object): +class CloudpickleWrapper: """A wrapper for functions that allow for serialization in multiprocessed settings.""" def __init__(self, fn: Callable, **kwargs): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d9753eafc08..19d0cdbae41 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -46,6 +46,8 @@ from .transforms import ( ActionDiscretizer, ActionMask, + as_nested_tensor, + as_padded_tensor, AutoResetEnv, AutoResetTransform, BatchSizeTransform, @@ -117,8 +119,144 @@ check_marl_grouping, exploration_type, ExplorationType, + get_available_libraries, make_composite_from_td, MarlGroupMapType, + RandomPolicy, set_exploration_type, step_mdp, + terminated_or_truncated, ) + +__all__ = [ + "ActionDiscretizer", + "ActionMask", + "AutoResetEnv", + "AutoResetTransform", + "BatchSizeTransform", + "BinarizeReward", + "BraxEnv", + "BraxWrapper", + "BurnInTransform", + "CatFrames", + "CatTensors", + "CenterCrop", + "ChessEnv", + "ClipTransform", + "Compose", + "ConditionalSkip", + "Crop", + "DMControlEnv", + "DMControlWrapper", + "DTypeCastTransform", + "DataLoadingPrimer", + "DeviceCastTransform", + "DiscreteActionProjection", + "DoubleToFloat", + "DreamerDecoder", + "DreamerEnv", + "EndOfLifeTransform", + "EnvBase", + "EnvCreator", + "EnvMetaData", + "ExcludeTransform", + "ExplorationType", + "FiniteTensorDictCheck", + "FlattenObservation", + "FrameSkipTransform", + "GrayScale", + "GymEnv", + "GymLikeEnv", + "GymWrapper", + "HabitatEnv", + "Hash", + "InitTracker", + "IsaacGymEnv", + "IsaacGymWrapper", + "JumanjiEnv", + "JumanjiWrapper", + "KLRewardTransform", + "LLMEnv", + "LLMHashingEnv", + "LineariseRewards", + "MOGymEnv", + "MOGymWrapper", + "MarlGroupMapType", + "MeltingpotEnv", + "MeltingpotWrapper", + "ModelBasedEnvBase", + "MultiAction", + "MultiStepTransform", + "MultiThreadedEnv", + "MultiThreadedEnvWrapper", + "NoopResetEnv", + "ObservationNorm", + "ObservationTransform", + "OpenMLEnv", + "OpenSpielEnv", + "OpenSpielWrapper", + "ParallelEnv", + "PendulumEnv", + "PermuteTransform", + "PettingZooEnv", + "PettingZooWrapper", + "PinMemoryTransform", + "R3MTransform", + "RandomCropTensorDict", + "RandomPolicy", + "RemoveEmptySpecs", + "RenameTransform", + "Resize", + "Reward2GoTransform", + "RewardClipping", + "RewardScaling", + "RewardSum", + "RoboHiveEnv", + "SMACv2Env", + "SMACv2Wrapper", + "SelectTransform", + "SerialEnv", + "SignTransform", + "SqueezeTransform", + "Stack", + "StepCounter", + "TargetReturn", + "TensorDictPrimer", + "TicTacToeEnv", + "TimeMaxPool", + "Timer", + "ToTensorImage", + 
"Tokenizer", + "TrajCounter", + "Transform", + "TransformedEnv", + "UnaryTransform", + "UnityMLAgentsEnv", + "UnityMLAgentsWrapper", + "UnsqueezeTransform", + "VC1Transform", + "VIPRewardTransform", + "VIPTransform", + "VecGymEnvTransform", + "VecNorm", + "VmasEnv", + "VmasWrapper", + "as_nested_tensor", + "as_padded_tensor", + "check_env_specs", + "check_marl_grouping", + "default_info_dict_reader", + "env_creator", + "exploration_type", + "gSDENoise", + "get_available_libraries", + "get_env_metadata", + "gym_backend", + "make_composite_from_td", + "make_tensordict", + "register_gym_spec_conversion", + "set_exploration_type", + "set_gym_backend", + "step_mdp", + "terminated_or_truncated", +] diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index d7c25928f30..ea68171fdec 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -6,9 +6,7 @@ from __future__ import annotations import functools - import gc - import os import time import weakref @@ -17,11 +15,10 @@ from functools import wraps from multiprocessing import connection from multiprocessing.synchronize import Lock as MpLock -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Sequence from warnings import warn import torch - from tensordict import ( is_tensor_collection, LazyStackedTensorDict, @@ -32,6 +29,7 @@ from tensordict.base import _is_leaf_nontensor from tensordict.utils import _zip_strict from torch import multiprocessing as mp + from torchrl._utils import ( _check_for_faulty_process, _make_ordinal_device, @@ -49,7 +47,6 @@ MultiThreadedEnv, MultiThreadedEnvWrapper, ) - from torchrl.envs.utils import ( _aggregate_end_of_traj, _sort_keys, @@ -100,7 +97,7 @@ def __iter__(self): class _dispatch_caller_serial: - def __init__(self, list_callable: List[Callable, Any]): + def __init__(self, list_callable: list[Callable, Any]): self.list_callable = list_callable def __call__(self, *args, **kwargs): @@ -296,15 +293,15 @@ class BatchedEnvBase(EnvBase): def __init__( self, num_workers: int, - create_env_fn: Union[Callable[[], EnvBase], Sequence[Callable[[], EnvBase]]], + create_env_fn: Callable[[], EnvBase] | Sequence[Callable[[], EnvBase]], *, - create_env_kwargs: Union[dict, Sequence[dict]] = None, + create_env_kwargs: dict | Sequence[dict] = None, pin_memory: bool = False, - share_individual_td: Optional[bool] = None, + share_individual_td: bool | None = None, shared_memory: bool = True, memmap: bool = False, - policy_proof: Optional[Callable] = None, - device: Optional[DEVICE_TYPING] = None, + policy_proof: Callable | None = None, + device: DEVICE_TYPING | None = None, allow_step_when_done: bool = False, num_threads: int = None, num_sub_threads: int = 1, @@ -492,7 +489,7 @@ def _has_dynamic_specs(self): return not self._use_buffers def _get_metadata( - self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] + self, create_env_fn: list[Callable], create_env_kwargs: list[dict] ): if self._single_task: # if EnvCreator, the metadata are already there @@ -514,7 +511,7 @@ def _get_metadata( self.share_individual_td = False else: n_tasks = len(create_env_fn) - self.meta_data: List[EnvMetaData] = [] + self.meta_data: list[EnvMetaData] = [] for i in range(n_tasks): self.meta_data.append( get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone() @@ -541,7 +538,7 @@ def _get_metadata( self._set_properties() - def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: + def update_kwargs(self, kwargs: dict | 
list[dict]) -> None: """Updates the kwargs of each environment given a dictionary or a list of dictionaries. Args: @@ -873,9 +870,8 @@ def close(self, *, raise_if_closed: bool = True) -> None: def _shutdown_workers(self) -> None: raise NotImplementedError - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): """This method is not used in batched envs.""" - pass @lazy def start(self) -> None: @@ -976,8 +972,8 @@ def _shutdown_workers(self) -> None: @_check_start def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: + self, seed: int | None = None, static_seed: bool = False + ) -> int | None: for env in self._envs: new_seed = env.set_seed(seed, static_seed=static_seed) seed = new_seed @@ -1538,7 +1534,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): @@ -1610,7 +1606,7 @@ def select_and_clone(x, y): @_check_start def step_and_maybe_reset( self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: if not self._use_buffers: # Simply dispatch the input to the workers # return self._step_and_maybe_reset_no_buffers(tensordict) @@ -1820,7 +1816,7 @@ def _wait_for_workers(self, workers_range): def _step_no_buffers( self, tensordict: TensorDictBase - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): @@ -2073,7 +2069,7 @@ def _reset_no_buffers( tensordict: TensorDictBase, reset_kwargs_list, needs_resetting, - ) -> Tuple[TensorDictBase, TensorDictBase]: + ) -> tuple[TensorDictBase, TensorDictBase]: if is_tensor_collection(tensordict): # tensordict = tensordict.consolidate(share_memory=True, num_threads=1) if self.consolidate: @@ -2266,8 +2262,8 @@ def _shutdown_workers(self) -> None: @_check_start def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: + self, seed: int | None = None, static_seed: bool = False + ) -> int | None: self._seeds = [] for channel in self.parent_channels: channel.send(("seed", (seed, static_seed))) @@ -2344,8 +2340,8 @@ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> Ordered def _run_worker_pipe_shared_mem( parent_pipe: connection.Connection, child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], + env_fun: EnvBase | Callable, + env_fun_kwargs: dict[str, Any], mp_event: mp.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -2596,8 +2592,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda): def _run_worker_pipe_direct( parent_pipe: connection.Connection, child_pipe: connection.Connection, - env_fun: Union[EnvBase, Callable], - env_fun_kwargs: Dict[str, Any], + env_fun: EnvBase | Callable, + env_fun_kwargs: dict[str, Any], mp_event: mp.Event = None, non_blocking: bool = False, has_lazy_inputs: bool = False, diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index bbd780aadd7..24ffee4b3f1 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -7,3 +7,5 @@ from .llm import 
LLMEnv, LLMHashingEnv
 from .pendulum import PendulumEnv
 from .tictactoeenv import TicTacToeEnv
+
+__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv", "LLMEnv"]
diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py
index ad8625a0418..6706242fd24 100644
--- a/torchrl/envs/custom/chess.py
+++ b/torchrl/envs/custom/chess.py
@@ -7,15 +7,19 @@
 import importlib.util
 import io
 import pathlib
-from typing import Dict, Optional
 
 import torch
 from tensordict import TensorDict, TensorDictBase
-from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded
-
+from torchrl.data.tensor_specs import (
+    Binary,
+    Bounded,
+    Categorical,
+    Composite,
+    NonTensor,
+    Unbounded,
+)
 from torchrl.envs import EnvBase
 from torchrl.envs.common import _EnvPostInit
-
 from torchrl.envs.utils import _classproperty
@@ -49,10 +53,8 @@ def maybe_add_keys(condition, in_key, out_key):
                 )
         elif include_hash_inv:
             raise ValueError(
-                (
-                    "'include_hash_inv=True' can only be set if"
-                    f"'include_hash=True', but got 'include_hash={include_hash}'."
-                )
+                "'include_hash_inv=True' can only be set if "
+                f"'include_hash=True', but got 'include_hash={include_hash}'."
             )
         if kwargs.get("mask_actions", True):
             from torchrl.envs import ActionMask
@@ -197,7 +199,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
 
     """
 
-    _hash_table: Dict[int, str] = {}
+    _hash_table: dict[int, str] = {}
     _PGN_RESTART = """[Event "?"]
 [Site "?"]
 [Date "????.??.??"]
@@ -231,7 +233,7 @@ def san_moves(cls):
     def _legal_moves_to_index(
         self,
         tensordict: TensorDictBase | None = None,
-        board: "chess.Board" | None = None,  # noqa: F821
+        board: chess.Board | None = None,  # noqa: F821
         return_mask: bool = False,
         pad: bool = False,
     ) -> torch.Tensor:
@@ -357,16 +359,12 @@ def __init__(
     def _is_done(self, board):
         return board.is_game_over() | board.is_fifty_moves()
 
-    def all_actions(
-        self, tensordict: Optional[TensorDictBase] = None
-    ) -> TensorDictBase:
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         if not self.mask_actions:
             raise RuntimeError(
-                (
-                    "Cannot generate legal actions since 'mask_actions=False' was "
-                    "set. If you really want to generate all actions, not just "
-                    "legal ones, call 'env.full_action_spec.enumerate()'."
-                )
+                "Cannot generate legal actions since 'mask_actions=False' was "
+                "set. If you really want to generate all actions, not just "
+                "legal ones, call 'env.full_action_spec.enumerate()'."
) return super().all_actions(tensordict) @@ -480,8 +478,8 @@ def _get_tensor_image(cls, board): @classmethod def _pgn_to_board( - cls, pgn_string: str, board: "chess.Board" | None = None # noqa: F821 - ) -> "chess.Board": # noqa: F821 + cls, pgn_string: str, board: chess.Board | None = None # noqa: F821 + ) -> chess.Board: # noqa: F821 pgn_io = io.StringIO(pgn_string) game = cls.lib.pgn.read_game(pgn_io) if board is None: @@ -493,7 +491,7 @@ def _pgn_to_board( return board @classmethod - def _add_move_to_pgn(cls, pgn_string: str, move: "chess.Move") -> str: # noqa: F821 + def _add_move_to_pgn(cls, pgn_string: str, move: chess.Move) -> str: # noqa: F821 pgn_io = io.StringIO(pgn_string) game = cls.lib.pgn.read_game(pgn_io) if game is None: @@ -502,7 +500,7 @@ def _add_move_to_pgn(cls, pgn_string: str, move: "chess.Move") -> str: # noqa: return str(game) @classmethod - def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821 + def _board_to_pgn(cls, board: chess.Board) -> str: # noqa: F821 game = cls.lib.pgn.Game.from_board(board) pgn_string = str(game) return pgn_string diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index f6dfc835e87..dd70a8c2598 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -11,12 +11,12 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _zip_strict from torch.utils.data import DataLoader -from torchrl.data import ( +from torchrl.data.map.hash import SipHash +from torchrl.data.tensor_specs import ( Bounded, Categorical as CategoricalSpec, Composite, NonTensor, - SipHash, TensorSpec, Unbounded, ) @@ -446,4 +446,3 @@ def _set_seed(self, *args): .. note:: This environment has no randomness, so this method does nothing. """ - pass diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index b530a01418e..bf99a768a61 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import numpy as np import torch diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py index 0a464a3e390..b5ec2cd9be0 100644 --- a/torchrl/envs/custom/tictactoeenv.py +++ b/torchrl/envs/custom/tictactoeenv.py @@ -4,11 +4,8 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from typing import Optional - import torch from tensordict import TensorDict, TensorDictBase - from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase @@ -279,7 +276,7 @@ def full(board: torch.Tensor) -> bool: def get_action_mask(): pass - def rand_action(self, tensordict: Optional[TensorDictBase] = None): + def rand_action(self, tensordict: TensorDictBase | None = None): mask = tensordict.get("mask") action_spec = self.action_spec if tensordict.ndim: diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index f4cb8e263a1..0d81a8f7705 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -7,13 +7,11 @@ from collections import OrderedDict from multiprocessing.sharedctypes import Synchronized -from typing import Callable, Dict, Optional, Union +from typing import Callable import torch from tensordict import TensorDictBase - from torchrl._utils import logger as torchrl_logger - from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase, EnvMetaData @@ -80,7 +78,7 @@ class EnvCreator: def __init__( self, create_env_fn: Callable[..., EnvBase], - create_env_kwargs: Optional[Dict] = None, + create_env_kwargs: dict | None = None, share_memory: bool = True, **kwargs, ) -> None: @@ -230,9 +228,7 @@ def env_creator(fun: Callable) -> EnvCreator: return EnvCreator(fun) -def get_env_metadata( - env_or_creator: Union[EnvBase, Callable], kwargs: Optional[Dict] = None -): +def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = None): """Retrieves a EnvMetaData object from an env.""" if isinstance(env_or_creator, (EnvBase,)): return EnvMetaData.metadata_from_env(env_or_creator) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 857b3b96b2f..b74030998ae 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -8,13 +8,13 @@ import abc import re import warnings -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Mapping, Sequence import numpy as np import torch from tensordict import NonTensorData, TensorDict, TensorDictBase -from torchrl._utils import logger as torchrl_logger +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper, _maybe_unlock, EnvBase @@ -24,13 +24,13 @@ class BaseInfoDictReader(metaclass=abc.ABCMeta): @abc.abstractmethod def __call__( - self, info_dict: Dict[str, Any], tensordict: TensorDictBase + self, info_dict: dict[str, Any], tensordict: TensorDictBase ) -> TensorDictBase: raise NotImplementedError @property @abc.abstractmethod - def info_spec(self) -> Dict[str, TensorSpec]: + def info_spec(self) -> dict[str, TensorSpec]: raise NotImplementedError @@ -67,8 +67,8 @@ class default_info_dict_reader(BaseInfoDictReader): def __init__( self, - keys: List[str] | None = None, - spec: Sequence[TensorSpec] | Dict[str, TensorSpec] | Composite | None = None, + keys: list[str] | None = None, + spec: Sequence[TensorSpec] | dict[str, TensorSpec] | Composite | None = None, ignore_private: bool = True, ): self.ignore_private = ignore_private @@ -98,7 +98,7 @@ def __init__( self._info_spec = _info_spec def __call__( - self, info_dict: Dict[str, Any], tensordict: TensorDictBase + self, info_dict: dict[str, Any], tensordict: TensorDictBase ) -> TensorDictBase: if not isinstance(info_dict, (dict, TensorDictBase)) 
and len(self.keys): warnings.warn( @@ -142,7 +142,7 @@ def reset(self): self._info_spec = None @property - def info_spec(self) -> Dict[str, TensorSpec]: + def info_spec(self) -> dict[str, TensorSpec]: return self._info_spec @@ -166,7 +166,7 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ - _info_dict_reader: List[BaseInfoDictReader] + _info_dict_reader: list[BaseInfoDictReader] @classmethod def __new__(cls, *args, **kwargs): @@ -191,7 +191,7 @@ def read_done( terminated: bool | None = None, truncated: bool | None = None, done: bool | None = None, - ) -> Tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]: + ) -> tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]: """Done state reader. In torchrl, a `"done"` signal means that a trajectory has reach its end, @@ -257,8 +257,8 @@ def read_reward(self, reward): return reward def read_obs( - self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] - ) -> Dict[str, Any]: + self, observations: dict[str, Any] | torch.Tensor | np.ndarray + ) -> dict[str, Any]: """Reads an observation from the environment and returns an observation compatible with the output TensorDict. Args: @@ -371,7 +371,7 @@ def validated(self, value): self.__dict__["_validated"] = value def _reset( - self, tensordict: Optional[TensorDictBase] = None, **kwargs + self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: obs, info = self._reset_output_transform(self._env.reset(**kwargs)) @@ -398,8 +398,8 @@ def _reset( @abc.abstractmethod def _output_transform( - self, step_outputs_tuple: Tuple - ) -> Tuple[ + self, step_outputs_tuple: tuple + ) -> tuple[ Any, float | np.ndarray, bool | np.ndarray | None, @@ -434,7 +434,7 @@ def _output_transform( ... @abc.abstractmethod - def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: + def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple: ... @_maybe_unlock diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 1cff97c1d49..8ae4695683c 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -26,3 +26,38 @@ from .smacv2 import SMACv2Env, SMACv2Wrapper from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper from .vmas import VmasEnv, VmasWrapper + +__all__ = [ + "BraxEnv", + "BraxWrapper", + "DMControlEnv", + "DMControlWrapper", + "MultiThreadedEnv", + "MultiThreadedEnvWrapper", + "gym_backend", + "GymEnv", + "GymWrapper", + "MOGymEnv", + "MOGymWrapper", + "register_gym_spec_conversion", + "set_gym_backend", + "HabitatEnv", + "IsaacGymEnv", + "IsaacGymWrapper", + "JumanjiEnv", + "JumanjiWrapper", + "MeltingpotEnv", + "MeltingpotWrapper", + "OpenMLEnv", + "OpenSpielEnv", + "OpenSpielWrapper", + "PettingZooEnv", + "PettingZooWrapper", + "RoboHiveEnv", + "SMACv2Env", + "SMACv2Wrapper", + "UnityMLAgentsEnv", + "UnityMLAgentsWrapper", + "VmasEnv", + "VmasWrapper", +] diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index d6e4db8b0e8..8785ed7597a 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -2,11 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import importlib.util import warnings -from typing import Dict, Optional, Union - import torch from packaging import version from tensordict import TensorDict, TensorDictBase @@ -209,7 +209,7 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): f"Setting a device in Brax wrapped environments is strongly recommended." ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): brax = self.lib if version.parse(brax.__version__) < version.parse("0.10.4"): raise ImportError("Brax v0.10.4 or greater is required.") @@ -223,12 +223,12 @@ def _check_kwargs(self, kwargs: Dict): def _build_env( self, env, - _seed: Optional[int] = None, + _seed: int | None = None, from_pixels: bool = False, - render_kwargs: Optional[dict] = None, + render_kwargs: dict | None = None, pixels_only: bool = False, requires_grad: bool = False, - camera_id: Union[int, str] = 0, + camera_id: int | str = 0, **kwargs, ): self.from_pixels = from_pixels @@ -241,7 +241,7 @@ def _build_env( ) return env - def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 + def _make_state_spec(self, env: brax.envs.env.Env): # noqa: F821 jax = self.jax key = jax.random.PRNGKey(0) @@ -250,7 +250,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 state_spec = _extract_spec(state_dict).expand(self.batch_size) return state_spec - def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 + def _make_specs(self, env: brax.envs.env.Env) -> None: # noqa: F821 self.action_spec = Bounded( low=-1, high=1, @@ -291,7 +291,7 @@ def _make_state_example(self): state = _tree_reshape(state, self.batch_size) return state - def _init_env(self) -> Optional[int]: + def _init_env(self) -> int | None: jax = self.jax self._key = None self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset)) @@ -551,7 +551,7 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "brax.envs.env.Env": # noqa: F821 + ) -> brax.envs.env.Env: # noqa: F821 if not _has_brax: raise ImportError( f"brax not found, unable to create {env_name}. 
" @@ -576,7 +576,7 @@ def _build_env( def env_name(self): return self._constructor_kwargs["env_name"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env_name" not in kwargs: raise TypeError("Expected 'env_name' to be part of kwargs") diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index ba1fdcfc9ae..19647f55841 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -5,16 +5,14 @@ from __future__ import annotations import collections - import importlib import os -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict import numpy as np import torch from torchrl._utils import logger as torchrl_logger, VERBOSE - from torchrl.data.tensor_specs import ( Bounded, Categorical, @@ -23,7 +21,6 @@ TensorSpec, Unbounded, ) - from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict from torchrl.envs.gym_like import GymLikeEnv from torchrl.envs.utils import _classproperty @@ -41,7 +38,7 @@ def _dmcontrol_to_torchrl_spec_transform( spec, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, device: DEVICE_TYPING = None, categorical_discrete_encoding: bool = False, ) -> TensorSpec: @@ -94,7 +91,7 @@ def _dmcontrol_to_torchrl_spec_transform( raise NotImplementedError(type(spec)) -def _get_envs(to_dict: bool = True) -> Dict[str, Any]: +def _get_envs(to_dict: bool = True) -> dict[str, Any]: if not _has_dm_control: raise ImportError("Cannot find dm_control in virtual environment.") from dm_control import suite @@ -111,7 +108,7 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]: return d.items() -def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: +def _robust_to_tensor(array: float | np.ndarray) -> torch.Tensor: if isinstance(array, np.ndarray): return torch.as_tensor(array.copy()) else: @@ -211,11 +208,11 @@ def __init__(self, env=None, **kwargs): def _build_env( self, env, - _seed: Optional[int] = None, + _seed: int | None = None, from_pixels: bool = False, - render_kwargs: Optional[dict] = None, + render_kwargs: dict | None = None, pixels_only: bool = False, - camera_id: Union[int, str] = 0, + camera_id: int | str = 0, **kwargs, ): self.from_pixels = from_pixels @@ -235,7 +232,7 @@ def _build_env( ) return env - def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 + def _make_specs(self, env: gym.Env) -> None: # noqa: F821 # specs are defined when first called self.observation_spec = _dmcontrol_to_torchrl_spec_transform( self._env.observation_spec(), device=self.device @@ -260,7 +257,7 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 self._env.action_spec(), device=self.device ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): dm_control = self.lib from dm_control.suite.wrappers import pixels @@ -286,11 +283,11 @@ def to(self, device: DEVICE_TYPING) -> DMControlEnv: self._set_egl_device(self.device) return self - def _init_env(self, seed: Optional[int] = None) -> Optional[int]: + def _init_env(self, seed: int | None = None) -> int | None: seed = self.set_seed(seed) return seed - def _set_seed(self, _seed: Optional[int]) -> Optional[int]: + def _set_seed(self, _seed: int | None) -> int | None: from dm_control.suite.wrappers import pixels if _seed is None: @@ -308,8 +305,8 @@ def _set_seed(self, _seed: Optional[int]) -> Optional[int]: return _seed def _output_transform( - self, timestep_tuple: Tuple["TimeStep"] # noqa: F821 - ) -> Tuple[np.ndarray, float, bool, bool, 
dict]:
+        self, timestep_tuple: tuple[TimeStep]  # noqa: F821
+    ) -> tuple[np.ndarray, float, bool, bool, dict]:
         from dm_env import StepType
 
         if type(timestep_tuple) is not tuple:
@@ -427,7 +424,7 @@ def _build_env(
         self,
         env_name: str,
         task_name: str,
-        _seed: Optional[int] = None,
+        _seed: int | None = None,
         **kwargs,
     ):
         from dm_control import suite
@@ -467,7 +464,7 @@ def rebuild_with_kwargs(self, **new_kwargs):
         self._env = self._build_env()
         self._make_specs(self._env)
 
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         if "env_name" in kwargs:
             env_name = kwargs["env_name"]
             if "task_name" in kwargs:
diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py
index a4339820b9f..b59c7101c29 100644
--- a/torchrl/envs/libs/envpool.py
+++ b/torchrl/envs/libs/envpool.py
@@ -6,11 +6,10 @@
 from __future__ import annotations
 
 import importlib
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any
 
 import numpy as np
 import torch
-
 from tensordict import TensorDict, TensorDictBase
 from torchrl._utils import logger as torchrl_logger
 from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
@@ -72,7 +71,7 @@ def lib(cls):
 
     def __init__(
         self,
-        env: Optional["envpool.python.envpool.EnvPoolMixin"] = None,  # noqa: F821
+        env: envpool.python.envpool.EnvPoolMixin | None = None,  # noqa: F821
         **kwargs,
     ):
         if not _has_envpool:
@@ -88,9 +87,9 @@ def __init__(
 
         # Buffer to keep the latest observation for each worker
         # It's a TensorDict when the observation consists of several variables, e.g. "position" and "velocity"
-        self.obs: Union[torch.tensor, TensorDict] = self.observation_spec.zero()
+        self.obs: torch.Tensor | TensorDict = self.observation_spec.zero()
 
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         if "env" not in kwargs:
             raise TypeError("Could not find environment key 'env' in kwargs.")
         env = kwargs["env"]
@@ -99,11 +98,11 @@ def __init__(
         if not isinstance(env, (envpool.python.envpool.EnvPoolMixin,)):
             raise TypeError("env is not of type 'envpool.python.envpool.EnvPoolMixin'.")
 
-    def _build_env(self, env: "envpool.python.envpool.EnvPoolMixin"):  # noqa: F821
+    def _build_env(self, env: envpool.python.envpool.EnvPoolMixin):  # noqa: F821
         return env
 
     def _make_specs(
-        self, env: "envpool.python.envpool.EnvPoolMixin"  # noqa: F821
+        self, env: envpool.python.envpool.EnvPoolMixin  # noqa: F821
     ) -> None:  # noqa: F821
         from torchrl.envs.libs.gym import set_gym_backend
@@ -114,7 +113,7 @@ def _make_specs(
         self.reward_spec = output_spec["full_reward_spec"]
         self.done_spec = output_spec["full_done_spec"]
 
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         pass
 
     def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -212,10 +211,8 @@ def __repr__(self) -> str:
 
     def _transform_reset_output(
         self,
-        envpool_output: Tuple[
-            Union["treevalue.TreeValue", np.ndarray], Any  # noqa: F821
-        ],
-        reset_workers: Optional[torch.Tensor],
+        envpool_output: tuple[treevalue.TreeValue | np.ndarray, Any],  # noqa: F821
+        reset_workers: torch.Tensor | None,
     ):
         """Process output of envpool env.reset."""
         import treevalue
@@ -243,7 +240,7 @@ def _transform_reset_output(
         return obs
 
     def _transform_step_output(
-        self, envpool_output: Tuple[Any, Any, Any, ...]
+        self, envpool_output: tuple[Any, Any, Any, ...]
) -> TensorDict: """Process output of envpool env.step.""" out = envpool_output @@ -272,8 +269,8 @@ def _transform_step_output( return tensordict_out def _treevalue_or_numpy_to_tensor_or_dict( - self, x: Union["treevalue.TreeValue", np.ndarray] # noqa: F821 - ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + self, x: treevalue.TreeValue | np.ndarray # noqa: F821 + ) -> torch.Tensor | dict[str, torch.Tensor]: """Converts observation returned by EnvPool. EnvPool step and reset return observation as a numpy array or a TreeValue of numpy arrays, which we convert @@ -291,8 +288,8 @@ def _treevalue_or_numpy_to_tensor_or_dict( return ret def _treevalue_to_dict( - self, tv: "treevalue.TreeValue" # noqa: F821 - ) -> Dict[str, Any]: + self, tv: treevalue.TreeValue # noqa: F821 + ) -> dict[str, Any]: """Converts TreeValue to a dictionary. Currently only supports depth 1 trees, but can easily be extended to arbitrary depth if necessary. @@ -301,7 +298,7 @@ def _treevalue_to_dict( return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)} - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): if seed is not None: torchrl_logger.info( "MultiThreadedEnvWrapper._set_seed ignored, as setting seed in an existing envorinment is not\ @@ -359,7 +356,7 @@ def __init__( num_workers: int, env_name: str, *, - create_env_kwargs: Optional[Dict[str, Any]] = None, + create_env_kwargs: dict[str, Any] | None = None, **kwargs, ): self.env_name = env_name.replace("ALE/", "") # Naming convention of EnvPool @@ -376,7 +373,7 @@ def _build_env( self, env_name: str, num_workers: int, - create_env_kwargs: Optional[Dict[str, Any]], + create_env_kwargs: dict[str, Any] | None, ) -> Any: import envpool @@ -390,7 +387,7 @@ def _build_env( ) return super()._build_env(env) - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): """Library EnvPool only supports setting a seed by recreating the environment.""" if seed is not None: torchrl_logger.debug("Recreating EnvPool environment to set seed.") @@ -401,7 +398,7 @@ def _set_seed(self, seed: Optional[int]): create_env_kwargs=self.create_env_kwargs, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): for arg in ["num_workers", "env_name", "create_env_kwargs"]: if arg not in kwargs: raise TypeError(f"Expected '{arg}' to be part of kwargs") diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index d5bc3e3f4e8..f5c8c160ff9 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -10,13 +10,12 @@ import warnings from copy import copy from types import ModuleType -from typing import Dict, List, Tuple +from typing import Dict from warnings import warn import numpy as np import torch from packaging import version - from tensordict import TensorDict, TensorDictBase from torch.utils._pytree import tree_map @@ -37,9 +36,7 @@ from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict from torchrl.envs.batched_envs import CloudpickleWrapper from torchrl.envs.common import _EnvPostInit - from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv - from torchrl.envs.utils import _classproperty try: @@ -697,7 +694,7 @@ def _torchrl_to_gym_spec_transform( ) -def _get_envs(to_dict=False) -> List: +def _get_envs(to_dict=False) -> list: if not _has_gym: raise ImportError("Gym(nasium) could not be found in your virtual environment.") envs = _get_gym_envs() @@ -1052,7 +1049,7 @@ def _get_batch_size(self, env): # noqa: F811 def 
_get_batch_size(self, env): # noqa: F811 raise ImportError(GYMNASIUM_1_ERROR) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -1064,7 +1061,7 @@ def _build_env( env, from_pixels: bool = False, pixels_only: bool = False, - ) -> "gym.core.Env": # noqa: F821 + ) -> gym.core.Env: # noqa: F821 self.batch_size = self._get_batch_size(env) env_from_pixels = _is_from_pixels(env) @@ -1230,7 +1227,7 @@ def _reward_space(self, env): # noqa: F811 rs = env.reward_space return rs - def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 + def _make_specs(self, env: gym.Env, batch_size=None) -> None: # noqa: F821 # If batch_size is provided, we se it to tell what batch size must be used # instead of self.batch_size cur_batch_size = self.batch_size if batch_size is None else torch.Size([]) @@ -1647,7 +1644,7 @@ def _build_env( self, env_name: str, **kwargs, - ) -> "gym.core.Env": # noqa: F821 + ) -> gym.core.Env: # noqa: F821 if not _has_gym: raise RuntimeError( f"gym not found, unable to create {env_name}. " @@ -1716,7 +1713,7 @@ def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 def env_name(self): return self._constructor_kwargs["env_name"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "env_name" not in kwargs: raise TypeError("Expected 'env_name' to be part of kwargs") @@ -1930,7 +1927,7 @@ def reset(self): self._final_validated = False -def _flip_info_tuple(info: Tuple[Dict]) -> Dict[str, tuple]: +def _flip_info_tuple(info: tuple[dict]) -> dict[str, tuple]: # In Gym < 0.24, batched envs returned tuples of dict, and not dict of tuples. # We patch this by flipping the tuple -> dict order. info_example = set(info[0]) diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 999277a2db8..380ffa3697e 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import functools import importlib.util diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index 0a64c395126..e8ae1c7bf79 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -5,16 +5,14 @@ from __future__ import annotations import importlib.util - import itertools import warnings -from typing import Any, Dict, Tuple, Union +from typing import Any import numpy as np import torch - from tensordict import TensorDictBase -from torchrl.data import Composite +from torchrl.data.tensor_specs import Composite from torchrl.envs.libs.gym import GymWrapper from torchrl.envs.utils import _classproperty, make_composite_from_td @@ -45,7 +43,7 @@ def lib(self): return isaacgym def __init__( - self, env: "isaacgymenvs.tasks.base.vec_task.Env", **kwargs # noqa: F821 + self, env: isaacgymenvs.tasks.base.vec_task.Env, **kwargs # noqa: F821 ): warnings.warn( "IsaacGym environment support is an experimental feature that may change in the future." 
@@ -57,7 +55,7 @@ def __init__( # by convention in IsaacGymEnvs self.task = env.__name__ - def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 + def _make_specs(self, env: gym.Env) -> None: # noqa: F821 super()._make_specs(env, batch_size=self.batch_size) self.full_done_spec = Composite( { @@ -133,7 +131,7 @@ def read_done( terminated: bool = None, truncated: bool | None = None, done: bool | None = None, - ) -> Tuple[bool, bool, bool]: + ) -> tuple[bool, bool, bool]: if terminated is not None: terminated = terminated.bool() if truncated is not None: @@ -146,8 +144,8 @@ def read_reward(self, total_reward): return total_reward def read_obs( - self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] - ) -> Dict[str, Any]: + self, observations: dict[str, Any] | torch.Tensor | np.ndarray + ) -> dict[str, Any]: """Reads an observation from the environment and returns an observation compatible with the output TensorDict. Args: diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 086533cb487..337e9e9a1a9 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -2,9 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import dataclasses import importlib.util -from typing import Union # import jax import numpy as np @@ -40,9 +41,7 @@ def _tree_flatten(x, batch_size: torch.Size): } -def _ndarray_to_tensor( - value: Union["jnp.ndarray", np.ndarray] # noqa: F821 -) -> torch.Tensor: +def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa: F821 from jax import dlpack as jax_dlpack, numpy as jnp # JAX arrays generated by jax.vmap would have Numpy dtypes. 
@@ -59,7 +58,7 @@ def _ndarray_to_tensor( return out.to(numpy_to_torch_dtype_dict[value.dtype]) -def _tensor_to_ndarray(value: torch.Tensor) -> "jnp.ndarray": # noqa: F821 +def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821 from jax import dlpack as jax_dlpack return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous())) @@ -148,7 +147,7 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size return type(object_example)(**t) -def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> TensorSpec: +def _extract_spec(data: torch.Tensor | TensorDictBase, key=None) -> TensorSpec: if isinstance(data, torch.Tensor): shape = data.shape if key in ("reward", "done"): diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 3bbaf7caa1c..e00b869e755 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -5,13 +5,11 @@ from __future__ import annotations import importlib.util -from typing import Dict, Optional, Tuple, Union import numpy as np import torch from packaging import version from tensordict import TensorDict, TensorDictBase - from torchrl.envs.common import _EnvPostInit from torchrl.envs.utils import _classproperty @@ -51,7 +49,7 @@ def _get_envs(): def _jumanji_to_torchrl_spec_transform( spec, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, device: DEVICE_TYPING = None, categorical_action_encoding: bool = True, ) -> TensorSpec: @@ -352,7 +350,7 @@ def lib(self): def __init__( self, - env: "jumanji.env.Environment" = None, # noqa: F821 + env: jumanji.env.Environment = None, # noqa: F821 categorical_action_encoding=True, jit: bool = True, **kwargs, @@ -388,11 +386,11 @@ def jit(self, value): def _build_env( self, env, - _seed: Optional[int] = None, + _seed: int | None = None, from_pixels: bool = False, - render_kwargs: Optional[dict] = None, + render_kwargs: dict | None = None, pixels_only: bool = False, - camera_id: Union[int, str] = 0, + camera_id: int | str = 0, **kwargs, ): self.from_pixels = from_pixels @@ -480,7 +478,7 @@ def _make_reward_spec(self, env) -> TensorSpec: reward_spec.shape = torch.Size([1]) return reward_spec.expand([*self.batch_size, *reward_spec.shape]) - def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 + def _make_specs(self, env: jumanji.env.Environment) -> None: # noqa: F821 # extract spec from jumanji definition self.action_spec = self._make_action_spec(env) @@ -495,7 +493,7 @@ def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 # build state example for data conversion self._state_example = self._make_state_example(env) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): jumanji = self.lib if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -674,7 +672,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict_out def _reset( - self, tensordict: Optional[TensorDictBase] = None, **kwargs + self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: import jax from jax import numpy as jnp @@ -736,10 +734,10 @@ def read_reward(self, reward): return reward - def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: + def _output_transform(self, step_outputs_tuple: tuple) -> tuple: ... - def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: + def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple: ... 
@@ -938,7 +936,7 @@ def _build_env(
         self,
         env_name: str,
         **kwargs,
-    ) -> "jumanji.env.Environment":  # noqa: F821
+    ) -> jumanji.env.Environment:  # noqa: F821
         if not _has_jumanji:
             raise ImportError(
                 f"jumanji not found, unable to create {env_name}. "
@@ -957,7 +955,7 @@ def _build_env(
     def env_name(self):
         return self._constructor_kwargs["env_name"]

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         if "env_name" not in kwargs:
             raise TypeError("Expected 'env_name' to be part of kwargs")
diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py
index ce6fd0bd179..ce8e8e193c8 100644
--- a/torchrl/envs/libs/meltingpot.py
+++ b/torchrl/envs/libs/meltingpot.py
@@ -5,14 +5,12 @@ from __future__ import annotations

 import importlib
-
-from typing import Dict, List, Mapping, Sequence
+from typing import Mapping, Sequence

 import torch
-
 from tensordict import TensorDict, TensorDictBase
-from torchrl.data import Categorical, Composite, TensorSpec
+from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec
 from torchrl.envs.common import _EnvWrapper
 from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform
 from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
@@ -31,7 +29,7 @@ def _get_envs():
     return list(substrate_configs.SUBSTRATES)


-def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict:  # noqa
+def _filter_global_state_from_dict(obs_dict: dict, world: bool) -> dict:  # noqa
     return {
         key: value
         for key, value in obs_dict.items()
@@ -40,8 +38,8 @@ def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict:  # noqa


 def _remove_world_observations_from_obs_spec(
-    observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]],  # noqa
-) -> Sequence[Mapping[str, "dm_env.specs.Array"]]:  # noqa
+    observation_spec: Sequence[Mapping[str, dm_env.specs.Array]],  # noqa
+) -> Sequence[Mapping[str, dm_env.specs.Array]]:  # noqa
     return [
         _filter_global_state_from_dict(agent_obs, world=False)
         for agent_obs in observation_spec
@@ -49,8 +47,8 @@ def _remove_world_observations_from_obs_spec(


 def _global_state_spec_from_obs_spec(
-    observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]]  # noqa
-) -> Mapping[str, "dm_env.specs.Array"]:  # noqa
+    observation_spec: Sequence[Mapping[str, dm_env.specs.Array]]  # noqa
+) -> Mapping[str, dm_env.specs.Array]:  # noqa
     # We only look at agent 0 since world entries are the same for all agents
     world_entries = _filter_global_state_from_dict(observation_spec[0], world=True)
     if len(world_entries) != 1 and _WORLD_PREFIX + "RGB" not in world_entries:
@@ -60,7 +58,7 @@ def _global_state_spec_from_obs_spec(
     return _remove_world_prefix(world_entries)


-def _remove_world_prefix(world_entries: Dict) -> Dict:
+def _remove_world_prefix(world_entries: dict) -> dict:
     return {key[len(_WORLD_PREFIX) :]: value for key, value in world_entries.items()}


@@ -181,10 +179,10 @@ def available_envs(cls):

     def __init__(
         self,
-        env: "meltingpot.utils.substrates.substrate.Substrate" = None,  # noqa
+        env: meltingpot.utils.substrates.substrate.Substrate = None,  # noqa
         categorical_actions: bool = True,
         group_map: MarlGroupMapType
-        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
         max_steps: int = None,
         **kwargs,
     ):
@@ -198,7 +196,7 @@ def __init__(

     def _build_env(
         self,
-        env: "meltingpot.utils.substrates.substrate.Substrate",  # noqa
+        env: meltingpot.utils.substrates.substrate.Substrate,  # noqa
     ):
         return env
@@ -208,7 +206,7 @@ def _make_group_map(self):
         check_marl_grouping(self.group_map, self.agent_names)

     def _make_specs(
-        self, env: "meltingpot.utils.substrates.substrate.Substrate"  # noqa
+        self, env: meltingpot.utils.substrates.substrate.Substrate  # noqa
     ) -> None:
         mp_obs_spec = self._env.observation_spec()  # List of dict of arrays
         mp_obs_spec_no_world = _remove_world_observations_from_obs_spec(
@@ -278,9 +276,9 @@ def _make_specs(
     def _make_group_specs(
         self,
         group: str,
-        torchrl_agent_obs_specs: List[TensorSpec],
-        torchrl_agent_act_specs: List[TensorSpec],
-        torchrl_rew_spec: List[TensorSpec],
+        torchrl_agent_obs_specs: list[TensorSpec],
+        torchrl_agent_act_specs: list[TensorSpec],
+        torchrl_rew_spec: list[TensorSpec],
     ):
         # Agent specs
         action_specs = []
@@ -327,7 +325,7 @@ def _make_group_specs(
             group_reward_spec,
         )

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         meltingpot = self.lib

         if "env" not in kwargs:
@@ -558,12 +556,12 @@ class MeltingpotEnv(MeltingpotWrapper):

     def __init__(
         self,
-        substrate: str | "ml_collections.config_dict.ConfigDict",  # noqa
+        substrate: str | ml_collections.config_dict.ConfigDict,  # noqa
         *,
         max_steps: int | None = None,
         categorical_actions: bool = True,
         group_map: MarlGroupMapType
-        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
         **kwargs,
     ):
         if not _has_meltingpot:
@@ -579,14 +577,14 @@ def __init__(
             **kwargs,
         )

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         if "substrate" not in kwargs:
             raise TypeError("Could not find environment key 'substrate' in kwargs.")

     def _build_env(
         self,
-        substrate: str | "ml_collections.config_dict.ConfigDict",  # noqa
-    ) -> "meltingpot.utils.substrates.substrate.Substrate":  # noqa
+        substrate: str | ml_collections.config_dict.ConfigDict,  # noqa
+    ) -> meltingpot.utils.substrates.substrate.Substrate:  # noqa
         from meltingpot import substrate as mp_substrate

         if isinstance(substrate, str):
diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py
index 831635f08cd..e4a2ed2c828 100644
--- a/torchrl/envs/libs/openml.py
+++ b/torchrl/envs/libs/openml.py
@@ -2,11 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib.util

 import torch
 from tensordict import TensorDict, TensorDictBase
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
 from torchrl.envs.common import EnvBase
diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py
index 3a2ab55cd13..7dbfeac07a8 100644
--- a/torchrl/envs/libs/openspiel.py
+++ b/torchrl/envs/libs/openspiel.py
@@ -6,7 +6,6 @@ from __future__ import annotations

 import importlib.util
-from typing import Dict, List

 import torch
 from tensordict import TensorDict, TensorDictBase
@@ -159,7 +158,7 @@ def __init__(
         env=None,
         *,
         group_map: MarlGroupMapType
-        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
         categorical_actions: bool = False,
         return_state: bool = False,
         **kwargs,
@@ -176,7 +175,7 @@ def __init__(
         # `reset` allows resetting to any state, including a terminal state
         self._allow_done_after_reset = True

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         pyspiel = self.lib
         if "env" not in kwargs:
             raise TypeError("Could not find environment key 'env' in kwargs.")
@@ -283,7 +282,7 @@ def _make_group_specs(
             group_reward_spec,
         )

-    def _make_specs(self, env: "pyspiel.State") -> None:  # noqa: F821
+    def _make_specs(self, env: pyspiel.State) -> None:  # noqa: F821
         self.agent_names = [f"player_{index}" for index in range(env.num_players())]
         self.agent_names_to_indices_map = {
             agent_name: i for i, agent_name in enumerate(self.agent_names)
@@ -604,7 +603,7 @@ def __init__(
         game_string,
         *,
         group_map: MarlGroupMapType
-        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
         categorical_actions=False,
         return_state: bool = False,
         **kwargs,
@@ -621,7 +620,7 @@ def _build_env(
         self,
         game_string: str,
         **kwargs,
-    ) -> "pyspiel.State":  # noqa: F821
+    ) -> pyspiel.State:  # noqa: F821
         if not _has_pyspiel:
             raise ImportError(
                 f"open_spiel not found, unable to create {game_string}. Consider "
Consider " @@ -647,7 +646,7 @@ def _build_env( def game_string(self): return self._constructor_kwargs["game_string"] - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): if "game_string" not in kwargs: raise TypeError("Expected 'game_string' to be part of kwargs") diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 616362df804..31fed35baa7 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -7,7 +7,7 @@ import copy import importlib import warnings -from typing import Dict, List, Tuple, Union +from typing import Dict import numpy as np import packaging @@ -36,7 +36,7 @@ def _get_envs(): return list(all_environments.keys()) -def _load_available_envs() -> Dict: +def _load_available_envs() -> dict: all_environments = {} try: from pettingzoo.mpe.all_modules import mpe_environments @@ -73,9 +73,7 @@ def _load_available_envs() -> Dict: return all_environments -def _extract_nested_with_index( - data: Union[np.ndarray, Dict[str, np.ndarray]], index: int -): +def _extract_nested_with_index(data: np.ndarray | dict[str, np.ndarray], index: int): if isinstance(data, np.ndarray): return data[index] elif isinstance(data, dict): @@ -208,12 +206,12 @@ def available_envs(cls): def __init__( self, - env: Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ] = None, + env: ( + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 + ) = None, return_state: bool = False, - group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + group_map: MarlGroupMapType | dict[str, list[str]] | None = None, use_mask: bool = False, categorical_actions: bool = True, seed: int | None = None, @@ -232,7 +230,7 @@ def __init__( super().__init__(**kwargs, allow_done_after_reset=True) - def _get_default_group_map(self, agent_names: List[str]): + def _get_default_group_map(self, agent_names: list[str]): # This function performs the default grouping in pettingzoo if not self.parallel: # In AEC envs we will have one group per agent by default @@ -273,10 +271,10 @@ def lib(self): def _build_env( self, - env: Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ], + env: ( + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 + ), ): import pettingzoo @@ -300,10 +298,10 @@ def _build_env( @set_gym_backend("gymnasium") def _make_specs( self, - env: Union[ - "pettingzoo.utils.env.ParallelEnv", # noqa: F821 - "pettingzoo.utils.env.AECEnv", # noqa: F821 - ], + env: ( + pettingzoo.utils.env.ParallelEnv # noqa: F821 + | pettingzoo.utils.env.AECEnv # noqa: F821 + ), ) -> None: # Set default for done on any or all if self.done_on_any is None: @@ -359,7 +357,7 @@ def _make_specs( self.reward_spec = reward_spec self.done_spec = done_spec - def _make_group_specs(self, group_name: str, agent_names: List[str]): + def _make_group_specs(self, group_name: str, agent_names: list[str]): n_agents = len(agent_names) action_specs = [] observation_specs = [] @@ -457,7 +455,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): group_done_spec, ) - def _check_kwargs(self, kwargs: Dict): + def _check_kwargs(self, kwargs: dict): import pettingzoo if "env" not in kwargs: @@ -588,7 +586,7 @@ def _reset( return tensordict_out - def _reset_aec(self, **kwargs) -> Tuple[Dict, Dict]: + def _reset_aec(self, **kwargs) -> tuple[dict, dict]: self._env.reset(**kwargs) 

         observation_dict = {
@@ -597,7 +595,7 @@ def _reset_aec(self, **kwargs) -> Tuple[Dict, Dict]:
         info_dict = self._env.infos
         return observation_dict, info_dict

-    def _reset_parallel(self, **kwargs) -> Tuple[Dict, Dict]:
+    def _reset_parallel(self, **kwargs) -> tuple[dict, dict]:
         return self._env.reset(**kwargs)

     def _step(
@@ -741,7 +739,7 @@ def _aggregate_done(self, tensordict_out, use_any):
     def _step_parallel(
         self,
         tensordict: TensorDictBase,
-    ) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
+    ) -> tuple[dict, dict, dict, dict, dict]:
         action_dict = {}
         for group, agents in self.group_map.items():
             group_action = tensordict.get((group, "action"))
@@ -758,7 +756,7 @@ def _step_parallel(
     def _step_aec(
         self,
         tensordict: TensorDictBase,
-    ) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
+    ) -> tuple[dict, dict, dict, dict, dict]:
         for group, agents in self.group_map.items():
             if self.agent_selection in agents:
                 agent_index = agents.index(self._env.agent_selection)
@@ -966,7 +964,7 @@ def __init__(
         task: str,
         parallel: bool,
         return_state: bool = False,
-        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
         use_mask: bool = False,
         categorical_actions: bool = True,
         seed: int | None = None,
@@ -989,7 +987,7 @@ def __init__(

         super().__init__(**kwargs)

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         if "task" not in kwargs:
             raise TypeError("Could not find environment key 'task' in kwargs.")
         if "parallel" not in kwargs:
@@ -1000,10 +998,10 @@ def _build_env(
         task: str,
         parallel: bool,
         **kwargs,
-    ) -> Union[
-        "pettingzoo.utils.env.ParallelEnv",  # noqa: F821
-        "pettingzoo.utils.env.AECEnv",  # noqa: F821
-    ]:
+    ) -> (
+        pettingzoo.utils.env.ParallelEnv  # noqa: F821
+        | pettingzoo.utils.env.AECEnv  # noqa: F821
+    ):
         self.task_name = task

         try:
diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py
index 2a4e04f7d71..fd02664a2a0 100644
--- a/torchrl/envs/libs/robohive.py
+++ b/torchrl/envs/libs/robohive.py
@@ -2,16 +2,18 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib
 import os
 import warnings
-
 from copy import copy
 from pathlib import Path

 import numpy as np
 import torch
 from tensordict import TensorDict
+
 from torchrl.data.tensor_specs import Unbounded
 from torchrl.envs.common import _maybe_unlock
 from torchrl.envs.libs.gym import (
@@ -32,7 +34,7 @@
 os.environ.setdefault("sim_backend", "MUJOCO")


-class set_directory(object):
+class set_directory:
     """Sets the cwd within the context.

     Args:
@@ -162,7 +164,7 @@ def _build_env(  # noqa: F811
         pixels_only: bool = False,
         from_depths: bool = False,
         **kwargs,
-    ) -> "gym.core.Env":  # noqa: F821
+    ) -> gym.core.Env:  # noqa: F821
         if from_pixels:
             if "cameras" not in kwargs:
                 warnings.warn(
@@ -219,7 +221,7 @@ def _build_env(  # noqa: F811
             self.set_info_dict_reader(self.read_info)
         return env

-    def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env, batch_size=None) -> None:  # noqa: F821
         out = super()._make_specs(env=env, batch_size=batch_size)
         self.env.reset()
         *_, info = self.env.step(self.env.action_space.sample())
diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py
index 397abbcc3c0..149c20606ac 100644
--- a/torchrl/envs/libs/unity_mlagents.py
+++ b/torchrl/envs/libs/unity_mlagents.py
@@ -6,7 +6,6 @@ from __future__ import annotations

 import importlib.util
-from typing import Dict, List, Optional

 import torch
 from tensordict import TensorDict, TensorDictBase
@@ -87,7 +86,6 @@ def lib(cls):
         if cls._lib is not None:
             return cls._lib

-        import mlagents_envs
         import mlagents_envs.environment

         cls._lib = mlagents_envs
@@ -97,7 +95,7 @@ def __init__(
         self,
         env=None,
         *,
-        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
         categorical_actions: bool = False,
         **kwargs,
     ):
@@ -108,7 +106,7 @@ def __init__(
         self.categorical_actions = categorical_actions
         super().__init__(**kwargs)

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         mlagents_envs = self.lib
         if "env" not in kwargs:
             raise TypeError("Could not find environment key 'env' in kwargs.")
@@ -177,7 +175,7 @@ def _make_group_map(self, group_map, agent_name_to_group_id_map):
         return group_map, agent_name_to_group_name_map

     def _make_specs(
-        self, env: "mlagents_envs.environment.UnityEnvironment"  # noqa: F821
+        self, env: mlagents_envs.environment.UnityEnvironment  # noqa: F821
     ) -> None:
         # NOTE: We need to reset here because mlagents only initializes the
         # agents and behaviors after reset. In order to build specs, we make the
@@ -288,17 +286,13 @@ def _set_seed(self, seed):
     def _check_agent_exists(self, agent_name, group_id):
         if agent_name not in self.agent_name_to_group_id_map:
             raise RuntimeError(
-                (
-                    "Unity environment added a new agent. This is not yet "
-                    "supported in torchrl."
-                )
+                "Unity environment added a new agent. This is not yet "
+                "supported in torchrl."
             )
         if self.agent_name_to_group_id_map[agent_name] != group_id:
             raise RuntimeError(
-                (
-                    "Unity environment changed the group of an agent. This "
-                    "is not yet supported in torchrl."
-                )
+                "Unity environment changed the group of an agent. This "
+                "is not yet supported in torchrl."
             )

     def _update_action_mask(self):
@@ -836,10 +830,10 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper):

     def __init__(
         self,
-        file_name: Optional[str] = None,
-        registered_name: Optional[str] = None,
+        file_name: str | None = None,
+        registered_name: str | None = None,
         *,
-        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
         categorical_actions=False,
         **kwargs,
     ):
@@ -853,10 +847,10 @@ def __init__(

     def _build_env(
         self,
-        file_name: Optional[str],
-        registered_name: Optional[str],
+        file_name: str | None,
+        registered_name: str | None,
         **kwargs,
-    ) -> "mlagents_envs.environment.UnityEnvironment":  # noqa: F821
+    ) -> mlagents_envs.environment.UnityEnvironment:  # noqa: F821
         if not _has_unity_mlagents:
             raise ImportError(
                 "mlagents_envs not found, unable to create environment. "
@@ -888,7 +882,7 @@ def file_name(self):
     def registered_name(self):
         return self._constructor_kwargs["registered_name"]

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         pass

     def __repr__(self) -> str:
diff --git a/torchrl/envs/libs/utils.py b/torchrl/envs/libs/utils.py
index d7d7dbf6bfe..a4124c8dbf6 100644
--- a/torchrl/envs/libs/utils.py
+++ b/torchrl/envs/libs/utils.py
@@ -2,10 +2,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+#
 # Copied from gym > 0.19 release
-
+#
 # this file should only be accessed when gym is installed
+from __future__ import annotations

 import collections
 import copy
diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py
index 1d786358ca9..772e79b497e 100644
--- a/torchrl/envs/libs/vmas.py
+++ b/torchrl/envs/libs/vmas.py
@@ -7,8 +7,6 @@
 import importlib.util
 import warnings

-from typing import Dict, List, Optional, Union
-
 import torch
 from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
@@ -237,9 +235,9 @@ def available_envs(cls):

     def __init__(
         self,
-        env: "vmas.simulator.environment.environment.Environment" = None,  # noqa
+        env: vmas.simulator.environment.environment.Environment = None,  # noqa
         categorical_actions: bool = True,
-        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
         **kwargs,
     ):
         if env is not None:
@@ -253,7 +251,7 @@ def __init__(

     def _build_env(
         self,
-        env: "vmas.simulator.environment.environment.Environment",  # noqa
+        env: vmas.simulator.environment.environment.Environment,  # noqa
         from_pixels: bool = False,
         pixels_only: bool = False,
     ):
@@ -281,7 +279,7 @@ def _build_env(

         return env

-    def _get_default_group_map(self, agent_names: List[str]):
+    def _get_default_group_map(self, agent_names: list[str]):
         # This function performs the default grouping in vmas.
         # Agents with names "<name>_<int>" will be grouped in group name "<name>".
         # If any of the agents does not follow the naming convention, we fall back
@@ -316,7 +314,7 @@ def _get_default_group_map(self, agent_names: List[str]):
         return group_map

     def _make_specs(
-        self, env: "vmas.simulator.environment.environment.Environment"  # noqa
+        self, env: vmas.simulator.environment.environment.Environment  # noqa
     ) -> None:
         # Create and check group map
         self.agent_names = [agent.name for agent in self.agents]
@@ -478,7 +476,7 @@ def _make_unbatched_group_specs(self, group: str):
             group_info_spec,
         )

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         vmas = self.lib

         if "env" not in kwargs:
@@ -489,14 +487,14 @@ def _check_kwargs(self, kwargs: Dict):
                 "env is not of type 'vmas.simulator.environment.Environment'."
             )

-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         pass

-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         self._env.seed(seed)

     def _reset(
-        self, tensordict: Optional[TensorDictBase] = None, **kwargs
+        self, tensordict: TensorDictBase | None = None, **kwargs
     ) -> TensorDictBase:
         if tensordict is not None and "_reset" in tensordict.keys():
             _reset = tensordict.get("_reset")
@@ -607,9 +605,7 @@ def _step(
         )
         return tensordict_out

-    def read_obs(
-        self, observations: Union[Dict, torch.Tensor]
-    ) -> Union[Dict, torch.Tensor]:
+    def read_obs(self, observations: dict | torch.Tensor) -> dict | torch.Tensor:
         if isinstance(observations, torch.Tensor):
             return _selective_unsqueeze(observations, batch_size=self.batch_size)
         return TensorDict(
@@ -617,7 +613,7 @@ def read_obs(
             batch_size=self.batch_size,
         )

-    def read_info(self, infos: Dict[str, torch.Tensor]) -> torch.Tensor:
+    def read_info(self, infos: dict[str, torch.Tensor]) -> torch.Tensor:
         if len(infos) == 0:
             return None
         infos = TensorDict(
@@ -777,14 +773,14 @@ class VmasEnv(VmasWrapper):

     def __init__(
         self,
-        scenario: Union[str, "vmas.simulator.scenario.BaseScenario"],  # noqa
+        scenario: str | vmas.simulator.scenario.BaseScenario,  # noqa
         *,
         num_envs: int,
         continuous_actions: bool = True,
-        max_steps: Optional[int] = None,
+        max_steps: int | None = None,
         categorical_actions: bool = True,
-        seed: Optional[int] = None,
-        group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+        seed: int | None = None,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
         **kwargs,
     ):
         if not _has_vmas:
@@ -803,7 +799,7 @@ def __init__(
             **kwargs,
         )

-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         if "scenario" not in kwargs:
             raise TypeError("Could not find environment key 'scenario' in kwargs.")
         if "num_envs" not in kwargs:
@@ -811,13 +807,13 @@ def _build_env(
         self,
-        scenario: Union[str, "vmas.simulator.scenario.BaseScenario"],  # noqa
+        scenario: str | vmas.simulator.scenario.BaseScenario,  # noqa
         num_envs: int,
         continuous_actions: bool,
-        max_steps: Optional[int],
-        seed: Optional[int],
+        max_steps: int | None,
+        seed: int | None,
         **scenario_kwargs,
-    ) -> "vmas.simulator.environment.environment.Environment":  # noqa
+    ) -> vmas.simulator.environment.environment.Environment:  # noqa
         vmas = self.lib

         self.scenario_name = scenario
diff --git a/torchrl/envs/model_based/__init__.py b/torchrl/envs/model_based/__init__.py
index 437146a4909..cb387af7ff8 100644
--- a/torchrl/envs/model_based/__init__.py
+++ b/torchrl/envs/model_based/__init__.py
@@ -5,3 +5,5 @@

 from .common import ModelBasedEnvBase
 from .dreamer import DreamerDecoder, DreamerEnv
+
+__all__ = ["ModelBasedEnvBase", "DreamerDecoder", "DreamerEnv"]
["ModelBasedEnvBase", "DreamerDecoder", "DreamerEnv"] diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index d2e7a6271e5..ae09391485f 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -2,15 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import abc import warnings -from typing import List, Optional import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule - from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -113,13 +112,13 @@ class ModelBasedEnvBase(EnvBase): def __init__( self, world_model: TensorDictModule, - params: Optional[List[torch.Tensor]] = None, - buffers: Optional[List[torch.Tensor]] = None, + params: list[torch.Tensor] | None = None, + buffers: list[torch.Tensor] | None = None, device: DEVICE_TYPING = "cpu", - batch_size: Optional[torch.Size] = None, + batch_size: torch.Size | None = None, run_type_checks: bool = False, ): - super(ModelBasedEnvBase, self).__init__( + super().__init__( device=device, batch_size=batch_size, run_type_checks=run_type_checks, @@ -174,6 +173,6 @@ def _step( def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: raise NotImplementedError - def _set_seed(self, seed: Optional[int]) -> int: + def _set_seed(self, seed: int | None) -> int: warnings.warn("Set seed isn't needed for model based environments") return seed diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index b69f206bb01..9228c39aa66 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -2,13 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-
-from typing import Optional, Tuple
+from __future__ import annotations

 import torch
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
-
 from torchrl.data.tensor_specs import Composite
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.common import EnvBase
@@ -22,15 +20,13 @@ class DreamerEnv(ModelBasedEnvBase):
     def __init__(
         self,
         world_model: TensorDictModule,
-        prior_shape: Tuple[int, ...],
-        belief_shape: Tuple[int, ...],
+        prior_shape: tuple[int, ...],
+        belief_shape: tuple[int, ...],
         obs_decoder: TensorDictModule = None,
         device: DEVICE_TYPING = "cpu",
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
     ):
-        super(DreamerEnv, self).__init__(
-            world_model, device=device, batch_size=batch_size
-        )
+        super().__init__(world_model, device=device, batch_size=batch_size)
         self.obs_decoder = obs_decoder
         self.prior_shape = prior_shape
         self.belief_shape = belief_shape
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 736bb7a2c9a..c6814f8745d 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -6,7 +6,12 @@
 from .gym_transforms import EndOfLifeTransform
 from .r3m import R3MTransform
 from .rb_transforms import MultiStepTransform
-from .rlhf import DataLoadingPrimer, KLRewardTransform
+from .rlhf import (
+    as_nested_tensor,
+    as_padded_tensor,
+    DataLoadingPrimer,
+    KLRewardTransform,
+)
 from .transforms import (
     ActionDiscretizer,
     ActionMask,
@@ -70,3 +75,75 @@
 )
 from .vc1 import VC1Transform
 from .vip import VIPRewardTransform, VIPTransform
+
+__all__ = [
+    "ActionDiscretizer",
+    "ActionMask",
+    "AutoResetEnv",
+    "AutoResetTransform",
+    "BatchSizeTransform",
+    "BinarizeReward",
+    "BurnInTransform",
+    "CatFrames",
+    "CatTensors",
+    "CenterCrop",
+    "ClipTransform",
+    "Compose",
+    "ConditionalSkip",
+    "Crop",
+    "DTypeCastTransform",
+    "DataLoadingPrimer",
+    "DeviceCastTransform",
+    "DiscreteActionProjection",
+    "DoubleToFloat",
+    "EndOfLifeTransform",
+    "ExcludeTransform",
+    "FiniteTensorDictCheck",
+    "FlattenObservation",
+    "FrameSkipTransform",
+    "GrayScale",
+    "Hash",
+    "InitTracker",
+    "KLRewardTransform",
+    "LineariseRewards",
+    "MultiAction",
+    "MultiStepTransform",
+    "NoopResetEnv",
+    "ObservationNorm",
+    "ObservationTransform",
+    "PermuteTransform",
+    "PinMemoryTransform",
+    "R3MTransform",
+    "RandomCropTensorDict",
+    "RemoveEmptySpecs",
+    "RenameTransform",
+    "Resize",
+    "Reward2GoTransform",
+    "RewardClipping",
+    "RewardScaling",
+    "RewardSum",
+    "SelectTransform",
+    "SignTransform",
+    "SqueezeTransform",
+    "Stack",
+    "StepCounter",
+    "TargetReturn",
+    "TensorDictPrimer",
+    "TimeMaxPool",
+    "Timer",
+    "ToTensorImage",
+    "Tokenizer",
+    "TrajCounter",
+    "Transform",
+    "TransformedEnv",
+    "UnaryTransform",
+    "UnsqueezeTransform",
+    "VC1Transform",
+    "VIPRewardTransform",
+    "VIPTransform",
+    "VecGymEnvTransform",
+    "VecNorm",
+    "as_nested_tensor",
+    "as_padded_tensor",
+    "gSDENoise",
+]
diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py
index 6ef23b11fd5..cd6c61c0502 100644
--- a/torchrl/envs/transforms/functional.py
+++ b/torchrl/envs/transforms/functional.py
@@ -2,8 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
-from typing import List
+from __future__ import annotations

 from torch import Tensor
@@ -15,10 +14,10 @@ def _get_image_num_channels(img: Tensor) -> int:
     elif img.ndim > 2:
         return img.shape[-3]

-    raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
+    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")


-def _assert_channels(img: Tensor, permitted: List[int]) -> None:
+def _assert_channels(img: Tensor, permitted: list[int]) -> None:
     c = _get_image_num_channels(img)
     if c not in permitted:
         raise TypeError(
diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py
index dea05ad175f..487329942c6 100644
--- a/torchrl/envs/transforms/gym_transforms.py
+++ b/torchrl/envs/transforms/gym_transforms.py
@@ -4,10 +4,12 @@
 # LICENSE file in the root directory of this source tree.

 """Gym-specific transforms."""
+
+from __future__ import annotations
+
 import warnings

 import torch
-import torchrl.objectives.common
 from tensordict import TensorDictBase
 from tensordict.utils import expand_as_right, NestedKey
 from torchrl.data.tensor_specs import Unbounded
@@ -186,7 +188,9 @@ def transform_observation_spec(self, observation_spec):
         )
         return observation_spec

-    def register_keys(self, loss_or_advantage: "torchrl.objectives.common.LossModule"):
+    def register_keys(
+        self, loss_or_advantage: torchrl.objectives.common.LossModule  # noqa
+    ):
         """Registers the end-of-life key at appropriate places within the loss.

         Args:
diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py
index dd150e8d94e..63b5ed44552 100644
--- a/torchrl/envs/transforms/r3m.py
+++ b/torchrl/envs/transforms/r3m.py
@@ -2,9 +2,9 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations

 import importlib.util
-from typing import List, Optional, Union

 import torch
 from tensordict import set_lazy_legacy, TensorDict, TensorDictBase
@@ -232,13 +232,13 @@ def __new__(cls, *args, **kwargs):
     def __init__(
         self,
         model_name: str,
-        in_keys: List[str],
-        out_keys: List[str] = None,
+        in_keys: list[str],
+        out_keys: list[str] = None,
         size: int = 244,
         stack_images: bool = True,
-        download: Union[bool, "WeightsEnum", str] = False,  # noqa: F821
-        download_path: Optional[str] = None,
-        tensor_pixels_keys: List[str] = None,
+        download: bool | WeightsEnum | str = False,  # noqa: F821
+        download_path: str | None = None,
+        tensor_pixels_keys: list[str] = None,
     ):
         super().__init__()
         self.in_keys = in_keys if in_keys is not None else ["pixels"]
@@ -356,7 +356,7 @@ def _init(self):
         if self._dtype is not None:
             self.to(self._dtype)

-    def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
+    def to(self, dest: DEVICE_TYPING | torch.dtype):
         if isinstance(dest, torch.dtype):
             self._dtype = dest
         else:
diff --git a/torchrl/envs/transforms/rb_transforms.py b/torchrl/envs/transforms/rb_transforms.py
index 76a8e6039f8..8507ce6d8f3 100644
--- a/torchrl/envs/transforms/rb_transforms.py
+++ b/torchrl/envs/transforms/rb_transforms.py
@@ -4,10 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

-from typing import List
-
 import torch
-
 from tensordict import NestedKey, TensorDictBase
 from torchrl.data.postprocs.postprocs import _multi_step_func
 from torchrl.envs.transforms.transforms import Transform
@@ -112,9 +109,9 @@ def __init__(
         n_steps,
         gamma,
         *,
-        reward_keys: List[NestedKey] | None = None,
+        reward_keys: list[NestedKey] | None = None,
         done_key: NestedKey | None = None,
-        done_keys: List[NestedKey] | None = None,
+        done_keys: list[NestedKey] | None = None,
         mask_key: NestedKey | None = None,
     ):
         super().__init__()
diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py
index 963002e8c05..752aaea573d 100644
--- a/torchrl/envs/transforms/rlhf.py
+++ b/torchrl/envs/transforms/rlhf.py
@@ -6,7 +6,7 @@

 from collections.abc import Mapping
 from copy import copy, deepcopy
-from typing import Any, Callable, Iterable, List, Literal
+from typing import Any, Callable, Iterable, Literal

 import torch
 from tensordict import (
@@ -19,6 +19,7 @@
 from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
 from torch import nn
+
 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
@@ -339,8 +340,8 @@ def __init__(
         self,
         dataloader: Iterable[Any],
         primers: Composite | None = None,
-        data_keys: List[NestedKey] | None = None,
-        data_specs: List[TensorSpec] | None = None,
+        data_keys: list[NestedKey] | None = None,
+        data_specs: list[TensorSpec] | None = None,
         example_data: Any = None,
         stack_method: Callable[[Any], Any]
         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 2ee42d19667..eff19ef1b61 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -87,7 +87,6 @@
     make_composite_from_td,
     step_mdp,
 )
-from torchrl.objectives.value.functional import reward2go

 _has_tv = importlib.util.find_spec("torchvision", None) is not None

@@ -318,7 +317,6 @@ def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:

     def init(self, tensordict) -> None:
         """Runs init steps for the transform."""
-        pass

     def _apply_transform(self, obs: torch.Tensor) -> None:
         """Applies the transform to a tensor or a leaf.
@@ -1161,7 +1159,6 @@ def set_seed(

     def _set_seed(self, seed: int | None):
         """This method is not used in transformed envs."""
-        pass

     def _reset(self, tensordict: TensorDictBase | None = None, **kwargs):
         if tensordict is not None:
@@ -8541,6 +8538,8 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
     def _inv_apply_transform(
         self, reward: torch.Tensor, done: torch.Tensor
     ) -> torch.Tensor:
+        from torchrl.objectives.value.functional import reward2go
+
         return reward2go(reward, done, self.gamma)

     def set_container(self, container):
diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py
index a1b30cb1aca..8ef96c04ce0 100644
--- a/torchrl/envs/transforms/utils.py
+++ b/torchrl/envs/transforms/utils.py
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+from __future__ import annotations

 import torch
 from torch import nn
diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py
index 76335bd8917..592237cbae3 100644
--- a/torchrl/envs/transforms/vc1.py
+++ b/torchrl/envs/transforms/vc1.py
@@ -3,11 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from __future__ import annotations
+
 import importlib
 import os
 import subprocess
 from functools import partial
-from typing import Union

 import torch
 from tensordict import TensorDictBase
@@ -212,7 +213,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec

         return observation_spec

-    def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
+    def to(self, dest: DEVICE_TYPING | torch.dtype):
         if isinstance(dest, torch.dtype):
             self._dtype = dest
         else:
diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py
index 7d64ada37c4..4bfcfc9b5ce 100644
--- a/torchrl/envs/transforms/vip.py
+++ b/torchrl/envs/transforms/vip.py
@@ -2,8 +2,9 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib.util
-from typing import List, Optional, Union

 import torch
 from tensordict import set_lazy_legacy, TensorDict, TensorDictBase
@@ -203,13 +204,13 @@ def __new__(cls, *args, **kwargs):
     def __init__(
         self,
         model_name: str,
-        in_keys: List[str] = None,
-        out_keys: List[str] = None,
+        in_keys: list[str] = None,
+        out_keys: list[str] = None,
         size: int = 244,
         stack_images: bool = True,
-        download: Union[bool, "WeightsEnum", str] = False,  # noqa: F821
-        download_path: Optional[str] = None,
-        tensor_pixels_keys: List[str] = None,
+        download: bool | WeightsEnum | str = False,  # noqa: F821
+        download_path: str | None = None,
+        tensor_pixels_keys: list[str] = None,
     ):
         super().__init__()
         self.in_keys = in_keys if in_keys is not None else ["pixels"]
@@ -325,7 +326,7 @@ def _init(self):
         if self._dtype is not None:
             self.to(self._dtype)

-    def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
+    def to(self, dest: DEVICE_TYPING | torch.dtype):
         if isinstance(dest, torch.dtype):
             self._dtype = dest
         else:
@@ -364,7 +365,7 @@ def _embed_goal(self, tensordict):
         tensordict_in = tensordict.select("goal_image").rename_key_(
             "goal_image", self.in_keys[0]
         )
-        tensordict_in = super(VIPRewardTransform, self).forward(tensordict_in)
+        tensordict_in = super().forward(tensordict_in)
         tensordict = tensordict.update(
             tensordict_in.rename_key_(self.out_keys[0], "goal_embedding")
         )
diff --git a/torchrl/envs/vec_envs.py b/torchrl/envs/vec_envs.py
index 73dd159751c..e1956ccd9f7 100644
--- a/torchrl/envs/vec_envs.py
+++ b/torchrl/envs/vec_envs.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import warnings

 warnings.warn("vec_env.py has moved to batch_envs.py.", category=DeprecationWarning)
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index edf90a4e85b..3225da8e437 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -93,3 +93,91 @@
 )
 from .utils import get_primers_from_module
 from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner  # usort:skip
+
+__all__ = [
+    "DistributionalDQNnet",
+    "Delta",
+    "distributions_maps",
+    "IndependentNormal",
+    "MaskedCategorical",
+    "MaskedOneHotCategorical",
+    "NormalParamExtractor",
+    "NormalParamWrapper",
+    "OneHotCategorical",
+    "OneHotOrdinal",
+    "Ordinal",
+    "ReparamGradientStrategy",
+    "TanhDelta",
+    "TanhNormal",
+    "TruncatedNormal",
+    "BatchRenorm1d",
+    "ConsistentDropout",
+    "ConsistentDropoutModule",
+    "Conv3dNet",
+    "ConvNet",
+    "DdpgCnnActor",
+    "DdpgCnnQNet",
+    "DdpgMlpActor",
+    "DdpgMlpQNet",
+    "DecisionTransformer",
+    "DreamerActor",
+    "DTActor",
+    "DuelingCnnDQNet",
+    "MLP",
+    "MultiAgentConvNet",
+    "MultiAgentMLP",
+    "MultiAgentNetBase",
+    "NoisyLazyLinear",
+    "NoisyLinear",
+    "ObsDecoder",
+    "ObsEncoder",
+    "OnlineDTActor",
+    "QMixer",
+    "reset_noise",
+    "RSSMPosterior",
+    "RSSMPrior",
+    "RSSMRollout",
+    "Squeeze2dLayer",
+    "SqueezeLayer",
+    "VDNMixer",
+    "Actor",
+    "ActorCriticOperator",
+    "ActorCriticWrapper",
+    "ActorValueOperator",
+    "AdditiveGaussianModule",
+    "AdditiveGaussianWrapper",
+    "DecisionTransformerInferenceWrapper",
+    "DistributionalQValueActor",
+    "DistributionalQValueHook",
+    "DistributionalQValueModule",
+    "EGreedyModule",
+    "EGreedyWrapper",
+    "GRU",
+    "GRUCell",
+    "GRUModule",
+    "LMHeadActorValueOperator",
+    "LSTM",
+    "LSTMCell",
+    "LSTMModule",
+    "MultiStepActorWrapper",
+    "OrnsteinUhlenbeckProcessModule",
+    "OrnsteinUhlenbeckProcessWrapper",
+    "ProbabilisticActor",
+    "QValueActor",
+    "QValueHook",
+    "QValueModule",
+    "recurrent_mode",
+    "SafeModule",
+    "SafeProbabilisticModule",
+    "SafeProbabilisticTensorDictSequential",
+    "SafeSequential",
+    "set_recurrent_mode",
+    "TanhModule",
+    "ValueOperator",
+    "VmapModule",
+    "WorldModelWrapper",
+    "get_primers_from_module",
+    "CEMPlanner",
+    "MPCPlannerBase",
+    "MPPIPlanner",
+]
diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py
index 8f1b7da49a5..dd800372c24 100644
--- a/torchrl/modules/distributions/__init__.py
+++ b/torchrl/modules/distributions/__init__.py
@@ -51,3 +51,20 @@
     torch_dist.Categorical: True,
     torch_dist.Normal: True,
 }
+
+__all__ = [
+    "NormalParamExtractor",
+    "distributions",
+    "Delta",
+    "IndependentNormal",
+    "NormalParamWrapper",
+    "TanhDelta",
+    "TanhNormal",
+    "TruncatedNormal",
+    "MaskedCategorical",
+    "MaskedOneHotCategorical",
+    "OneHotCategorical",
+    "OneHotOrdinal",
+    "Ordinal",
+    "ReparamGradientStrategy",
+]
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index 3dbb42e9cef..48461e21ed6 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -6,20 +6,18 @@

 import weakref
 from numbers import Number
-from typing import Dict, Optional, Sequence, Union
+from typing import Sequence

 import numpy as np
 import torch
 from packaging import version
 from torch import distributions as D, nn
-
 from torch.distributions import constraints
 from torch.distributions.transforms import _InverseTransform

 from torchrl.modules.distributions.truncated_normal import (
     TruncatedNormal as _TruncatedNormal,
 )
-
 from torchrl.modules.distributions.utils import (
     _cast_device,
     FasterTransformedDistribution,
@@ -179,9 +177,9 @@ def __init__(
         self,
         loc: torch.Tensor,
         scale: torch.Tensor,
-        upscale: Union[torch.Tensor, float] = 5.0,
-        low: Union[torch.Tensor, float] = -1.0,
-        high: Union[torch.Tensor, float] = 1.0,
+        upscale: torch.Tensor | float = 5.0,
+        low: torch.Tensor | float = -1.0,
+        high: torch.Tensor | float = 1.0,
         tanh_loc: bool = False,
     ):

@@ -345,9 +343,9 @@ def __init__(
         self,
         loc: torch.Tensor,
         scale: torch.Tensor,
-        upscale: Union[torch.Tensor, Number] = 5.0,
-        low: Union[torch.Tensor, Number] = -1.0,
-        high: Union[torch.Tensor, Number] = 1.0,
+        upscale: torch.Tensor | Number = 5.0,
+        low: torch.Tensor | Number = -1.0,
+        high: torch.Tensor | Number = 1.0,
         event_dims: int | None = None,
         tanh_loc: bool = False,
         safe_tanh: bool = True,
@@ -543,15 +541,15 @@ class Delta(D.Distribution):

     """

-    arg_constraints: Dict = {}
+    arg_constraints: dict = {}

     def __init__(
         self,
         param: torch.Tensor,
         atol: float = 1e-6,
         rtol: float = 1e-6,
-        batch_shape: Union[torch.Size, Sequence[int]] = None,
-        event_shape: Union[torch.Size, Sequence[int]] = None,
+        batch_shape: torch.Size | Sequence[int] = None,
+        event_shape: torch.Size | Sequence[int] = None,
     ):
         if batch_shape is None:
             batch_shape = torch.Size([])
@@ -640,8 +638,8 @@ class TanhDelta(FasterTransformedDistribution):
     def __init__(
         self,
         param: torch.Tensor,
-        low: Union[torch.Tensor, float] = -1.0,
-        high: Union[torch.Tensor, float] = 1.0,
+        low: torch.Tensor | float = -1.0,
+        high: torch.Tensor | float = 1.0,
         event_dims: int = 1,
         atol: float = 1e-6,
         rtol: float = 1e-6,
@@ -714,7 +712,7 @@ def max(self):
         self._warn_minmax()
         return self.high

-    def update(self, net_output: torch.Tensor) -> Optional[torch.Tensor]:
+    def update(self, net_output: torch.Tensor) -> torch.Tensor | None:
         loc = net_output
         if self.non_trivial:
             device = loc.device
diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py
index 9ba33806691..8e9cda99b3c 100644
--- a/torchrl/modules/distributions/discrete.py
+++ b/torchrl/modules/distributions/discrete.py
@@ -2,9 +2,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 from enum import Enum
 from functools import wraps
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Sequence

 import torch
 import torch.distributions as D
@@ -17,8 +19,8 @@


 def _treat_categorical_params(
-    params: Optional[torch.Tensor] = None,
-) -> Optional[torch.Tensor]:
+    params: torch.Tensor | None = None,
+) -> torch.Tensor | None:
     if params is None:
         return None
     if params.shape[-1] == 1:
@@ -94,8 +96,8 @@ def probs(self):

     def __init__(
         self,
-        logits: Optional[torch.Tensor] = None,
-        probs: Optional[torch.Tensor] = None,
+        logits: torch.Tensor | None = None,
+        probs: torch.Tensor | None = None,
         grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough,
         **kwargs,
     ) -> None:
@@ -126,12 +128,10 @@ def entropy(self):
         return -p_log_p.sum(-1)

     @_one_hot_wrapper(D.Categorical)
-    def sample(
-        self, sample_shape: Optional[Union[torch.Size, Sequence]] = None
-    ) -> torch.Tensor:
+    def sample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
         ...

-    def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Tensor:
+    def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:
         if sample_shape is None:
             sample_shape = torch.Size([])
         if hasattr(self, "logits") and self.logits is not None:
@@ -217,13 +217,13 @@ def probs(self):

     def __init__(
         self,
-        logits: Optional[torch.Tensor] = None,
-        probs: Optional[torch.Tensor] = None,
+        logits: torch.Tensor | None = None,
+        probs: torch.Tensor | None = None,
         *,
         mask: torch.Tensor = None,
         indices: torch.Tensor = None,
         neg_inf: float = float("-inf"),
-        padding_value: Optional[int] = None,
+        padding_value: int | None = None,
     ) -> None:
         if not ((mask is None) ^ (indices is None)):
             raise ValueError(
@@ -261,7 +261,7 @@ def __init__(
         self.num_samples = num_samples

     def sample(
-        self, sample_shape: Optional[Union[torch.Size, Sequence[int]]] = None
+        self, sample_shape: torch.Size | Sequence[int] | None = None
     ) -> torch.Tensor:
         if sample_shape is None:
             sample_shape = torch.Size()
@@ -298,10 +298,10 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
     @staticmethod
     def _mask_logits(
         logits: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
+        mask: torch.Tensor | None = None,
         neg_inf: float = float("-inf"),
         sparse_mask: bool = False,
-        padding_value: Optional[int] = None,
+        padding_value: int | None = None,
     ) -> torch.Tensor:
         if mask is None:
             return logits
@@ -401,12 +401,12 @@ def probs(self):

     def __init__(
         self,
-        logits: Optional[torch.Tensor] = None,
-        probs: Optional[torch.Tensor] = None,
+        logits: torch.Tensor | None = None,
+        probs: torch.Tensor | None = None,
         mask: torch.Tensor = None,
         indices: torch.Tensor = None,
         neg_inf: float = float("-inf"),
-        padding_value: Optional[int] = None,
+        padding_value: int | None = None,
         grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough,
     ) -> None:
         self.grad_method = grad_method
@@ -421,7 +421,7 @@ def __init__(

     @_one_hot_wrapper(MaskedCategorical)
     def sample(
-        self, sample_shape: Optional[Union[torch.Size, Sequence[int]]] = None
+        self, sample_shape: torch.Size | Sequence[int] | None = None
     ) -> torch.Tensor:
         ...

@@ -439,7 +439,7 @@ def mode(self) -> torch.Tensor:
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         return super().log_prob(value.argmax(dim=-1))

-    def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Tensor:
+    def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:
         if sample_shape is None:
             sample_shape = torch.Size([])
         if hasattr(self, "logits") and self.logits is not None:
diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py
index 1350aeb2bc3..f8d481265cb 100644
--- a/torchrl/modules/distributions/truncated_normal.py
+++ b/torchrl/modules/distributions/truncated_normal.py
@@ -2,9 +2,9 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
-
+#
 # from https://github.com/toshas/torch_truncnorm
+from __future__ import annotations

 import math
 from numbers import Number
@@ -41,9 +39,7 @@ def __init__(self, a, b, validate_args=None, device=None):
             batch_shape = torch.Size()
         else:
             batch_shape = self.a.size()
-        super(TruncatedStandardNormal, self).__init__(
-            batch_shape, validate_args=validate_args
-        )
+        super().__init__(batch_shape, validate_args=validate_args)
         if self.a.dtype != self.b.dtype:
             raise ValueError("Truncation bounds types are different")
         if any(
@@ -154,7 +152,7 @@ def __init__(self, loc, scale, a, b, validate_args=None, device=None):
         self._non_std_b = b
         a = (a - self.loc) / self.scale
         b = (b - self.loc) / self.scale
-        super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args)
+        super().__init__(a, b, validate_args=validate_args)
         self._log_scale = self.scale.log()
         self._mean = self._mean * self.scale + self.loc
         self._variance = self._variance * self.scale**2
@@ -167,7 +165,7 @@ def _from_std_rv(self, value):
         return value * self.scale + self.loc

     def cdf(self, value):
-        return super(TruncatedNormal, self).cdf(self._to_std_rv(value))
+        return super().cdf(self._to_std_rv(value))

     def icdf(self, value):
         sample = self._from_std_rv(super().icdf(value))
@@ -184,4 +182,4 @@ def icdf(self, value):

     def log_prob(self, value):
         value = self._to_std_rv(value)
-        return super(TruncatedNormal, self).log_prob(value) - self._log_scale
+        return super().log_prob(value) - self._log_scale
diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py
index 8c332c4efed..a64d55276c3 100644
--- a/torchrl/modules/distributions/utils.py
+++ b/torchrl/modules/distributions/utils.py
@@ -2,8 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
-from typing import Union
+from __future__ import annotations

 import torch
 from torch import autograd, distributions as d
@@ -15,7 +14,7 @@
 from torch._dynamo import is_compiling as is_dynamo_compiling


-def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]:
+def _cast_device(elt: torch.Tensor | float, device) -> torch.Tensor | float:
     if isinstance(elt, torch.Tensor):
         return elt.to(device)
     return elt
@@ -55,7 +54,7 @@ def __init__(self, base_distribution, transforms, validate_args=None):
                 raise ValueError("Make a ComposeTransform first.")
         else:
             raise ValueError(
-                "transforms must be a Transform or list, but was {}".format(transforms)
+                f"transforms must be a Transform or list, but was {transforms}"
             )
         transform = self.transforms[0]

        # Reshape base_distribution according to transforms.
diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py
index 35a060e8d69..2b540d324c6 100644
--- a/torchrl/modules/models/__init__.py
+++ b/torchrl/modules/models/__init__.py
@@ -46,3 +46,39 @@
     VDNMixer,
 )
 from .utils import Squeeze2dLayer, SqueezeLayer
+
+__all__ = [
+    "DistributionalDQNnet",
+    "BatchRenorm1d",
+    "DecisionTransformer",
+    "ConsistentDropout",
+    "ConsistentDropoutModule",
+    "NoisyLazyLinear",
+    "NoisyLinear",
+    "reset_noise",
+    "DreamerActor",
+    "ObsDecoder",
+    "ObsEncoder",
+    "RSSMPosterior",
+    "RSSMPrior",
+    "RSSMRollout",
+    "Conv2dNet",
+    "Conv3dNet",
+    "ConvNet",
+    "DdpgCnnActor",
+    "DdpgCnnQNet",
+    "DdpgMlpActor",
+    "DdpgMlpQNet",
+    "DTActor",
+    "DuelingCnnDQNet",
+    "DuelingMlpDQNet",
+    "MLP",
+    "OnlineDTActor",
+    "MultiAgentConvNet",
+    "MultiAgentMLP",
+    "MultiAgentNetBase",
+    "QMixer",
+    "VDNMixer",
+    "Squeeze2dLayer",
+    "SqueezeLayer",
+]
diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py
index c5534568af7..0f266e423b2 100644
--- a/torchrl/modules/models/batchrenorm.py
+++ b/torchrl/modules/models/batchrenorm.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import torch
 import torch.nn as nn
diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py
index d0c40b4fbc4..923bfb9662d 100644
--- a/torchrl/modules/models/decision_transformer.py
+++ b/torchrl/modules/models/decision_transformer.py
@@ -5,7 +5,6 @@
 from __future__ import annotations

 import dataclasses
-
 import importlib
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -113,7 +112,7 @@ def __init__(
                 f"Config of type {type(config)} is not supported."
             ) from err

-        super(DecisionTransformer, self).__init__()
+        super().__init__()

         with torch.device(device) if device is not None else nullcontext():
             gpt_config = transformers.GPT2Config(
diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py
index 571ace39bc7..3458d494029 100644
--- a/torchrl/modules/models/exploration.py
+++ b/torchrl/modules/models/exploration.py
@@ -7,10 +7,9 @@
 import functools
 import math
 import warnings
-from typing import List, Optional, Sequence, Union
+from typing import Sequence

 import torch
-
 from tensordict.nn import TensorDictModuleBase
 from tensordict.utils import NestedKey
 from torch import distributions as d, nn
@@ -18,6 +17,7 @@
 from torch.nn.modules.dropout import _DropoutNd
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
+
 from torchrl._utils import prod
 from torchrl.data.tensor_specs import Unbounded
 from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS
@@ -56,8 +56,8 @@ def __init__(
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device: Optional[DEVICE_TYPING] = None,
-        dtype: Optional[torch.dtype] = None,
+        device: DEVICE_TYPING | None = None,
+        dtype: torch.dtype | None = None,
         std_init: float = 0.1,
     ):
         nn.Module.__init__(self)
@@ -128,7 +128,7 @@ def reset_noise(self) -> None:
         if self.bias_mu is not None:
             self.bias_epsilon.copy_(epsilon_out)

-    def _scale_noise(self, size: Union[int, torch.Size, Sequence]) -> torch.Tensor:
+    def _scale_noise(self, size: int | torch.Size | Sequence) -> torch.Tensor:
         if isinstance(size, int):
             size = (size,)
         x = torch.randn(*size, device=self.weight_mu.device)
@@ -142,7 +142,7 @@ def weight(self) -> torch.Tensor:
             return self.weight_mu

     @property
-    def bias(self) -> Optional[torch.Tensor]:
+    def bias(self) -> torch.Tensor | None:
         if self.bias_mu is not None:
             if self.training:
                 return self.bias_mu + self.bias_sigma * self.bias_epsilon
@@ -177,8 +177,8 @@ def __init__(
         self,
         out_features: int,
         bias: bool = True,
-        device: Optional[DEVICE_TYPING] = None,
-        dtype: Optional[torch.dtype] = None,
+        device: DEVICE_TYPING | None = None,
+        dtype: torch.dtype | None = None,
         std_init: float = 0.1,
     ):
         super().__init__(0, 0, False, device=device)
@@ -323,8 +323,8 @@ def __init__(
         scale_min: float = 0.01,
         scale_max: float = 10.0,
         learn_sigma: bool = True,
-        transform: Optional[d.Transform] = None,
-        device: Optional[DEVICE_TYPING] = None,
+        transform: d.Transform | None = None,
+        device: DEVICE_TYPING | None = None,
     ) -> None:
         super().__init__()
         self.action_dim = action_dim
@@ -416,7 +416,7 @@ def forward(self, mu, state, _eps_gSDE):
             action = self.transform(action)
         return mu, sigma, action, _eps_gSDE

-    def to(self, device_or_dtype: Union[torch.dtype, DEVICE_TYPING]):
+    def to(self, device_or_dtype: torch.dtype | DEVICE_TYPING):
         if isinstance(device_or_dtype, DEVICE_TYPING_ARGS):
             self.transform = _cast_transform_device(self.transform, device_or_dtype)
         return super().to(device_or_dtype)
@@ -458,8 +458,8 @@ def __init__(
         scale_min: float = 0.01,
         scale_max: float = 10.0,
         learn_sigma: bool = True,
-        transform: Optional[d.Transform] = None,
-        device: Optional[DEVICE_TYPING] = None,
+        transform: d.Transform | None = None,
+        device: DEVICE_TYPING | None = None,
     ) -> None:
         super().__init__(
             0,
@@ -642,8 +642,8 @@ class ConsistentDropoutModule(TensorDictModuleBase):
     def __init__(
         self,
         p: float,
-        in_keys: NestedKey | List[NestedKey],
-        out_keys: NestedKey | List[NestedKey] | None = None,
+        in_keys: NestedKey | list[NestedKey],
+        out_keys: NestedKey | list[NestedKey] | None = None,
         input_shape: torch.Size = None,
         input_dtype: torch.dtype | None = None,
     ):
diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py
index 60d4dd020ef..976f57dd5b9 100644
--- a/torchrl/modules/models/model_based.py
+++ b/torchrl/modules/models/model_based.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import warnings

 import torch
diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py
index 36a11a508f5..c1ce2b96f2b 100644
--- a/torchrl/modules/models/models.py
+++ b/torchrl/modules/models/models.py
@@ -5,10 +5,9 @@
 from __future__ import annotations

 import dataclasses
-
 from copy import deepcopy
 from numbers import Number
-from typing import Callable, Dict, List, Sequence, Tuple, Type, Union
+from typing import Callable, Sequence

 import torch
 from torch import nn
@@ -165,14 +164,14 @@ def __init__(
         out_features: int | torch.Size = None,
         depth: int | None = None,
         num_cells: Sequence[int] | int | None = None,
-        activation_class: Type[nn.Module] | Callable = nn.Tanh,
-        activation_kwargs: dict | List[dict] | None = None,
-        norm_class: Type[nn.Module] | Callable | None = None,
-        norm_kwargs: dict | List[dict] | None = None,
+        activation_class: type[nn.Module] | Callable = nn.Tanh,
+        activation_kwargs: dict | list[dict] | None = None,
+        norm_class: type[nn.Module] | Callable | None = None,
+        norm_kwargs: dict | list[dict] | None = None,
         dropout: float | None = None,
         bias_last_layer: bool = True,
         single_bias_last_layer: bool = False,
-        layer_class: Type[nn.Module] | Callable = nn.Linear,
+        layer_class: type[nn.Module] | Callable = nn.Linear,
         layer_kwargs: dict | None = None,
         activate_last_layer: bool = False,
         device: DEVICE_TYPING | None = None,
@@ -244,7 +243,7 @@ def __init__(
         ]
         super().__init__(*layers)

-    def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]:
+    def _make_net(self, device: DEVICE_TYPING | None) -> list[nn.Module]:
         layers = []
         in_features = [self.in_features] + self.num_cells
         out_features = self.num_cells + [self._out_features_num]
@@ -293,7 +292,7 @@ def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]:

         return layers

-    def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
+    def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
         if len(inputs) > 1:
             inputs = (torch.cat([*inputs], -1),)

@@ -408,15 +407,15 @@ def __init__(
         in_features: int | None = None,
         depth: int | None = None,
         num_cells: Sequence[int] | int = None,
-        kernel_sizes: Union[Sequence[int], int] = 3,
+        kernel_sizes: Sequence[int] | int = 3,
         strides: Sequence[int] | int = 1,
         paddings: Sequence[int] | int = 0,
-        activation_class: Type[nn.Module] | Callable = nn.ELU,
-        activation_kwargs: dict | List[dict] | None = None,
-        norm_class: Type[nn.Module] | Callable | None = None,
-        norm_kwargs: dict | List[dict] | None = None,
+        activation_class: type[nn.Module] | Callable = nn.ELU,
+        activation_kwargs: dict | list[dict] | None = None,
+        norm_class: type[nn.Module] | Callable | None = None,
+        norm_kwargs: dict | list[dict] | None = None,
         bias_last_layer: bool = True,
-        aggregator_class: Type[nn.Module] | Callable | None = SquashDims,
+        aggregator_class: type[nn.Module] | Callable | None = SquashDims,
         aggregator_kwargs: dict | None = None,
         squeeze_output: bool = False,
         device: DEVICE_TYPING | None = None,
@@ -540,7 +539,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+539,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: *batch, C, L, W = inputs.shape if len(batch) > 1: inputs = inputs.flatten(0, len(batch) - 1) - out = super(ConvNet, self).forward(inputs) + out = super().forward(inputs) if len(batch) > 1: out = out.unflatten(0, batch) return out @@ -678,12 +677,12 @@ def __init__( kernel_sizes: Sequence[int] | int = 3, strides: Sequence[int] | int = 1, paddings: Sequence[int] | int = 0, - activation_class: Type[nn.Module] | Callable = nn.ELU, - activation_kwargs: dict | List[dict] | None = None, - norm_class: Type[nn.Module] | Callable | None = None, - norm_kwargs: dict | List[dict] | None = None, + activation_class: type[nn.Module] | Callable = nn.ELU, + activation_kwargs: dict | list[dict] | None = None, + norm_class: type[nn.Module] | Callable | None = None, + norm_kwargs: dict | list[dict] | None = None, bias_last_layer: bool = True, - aggregator_class: Type[nn.Module] | Callable | None = SquashDims, + aggregator_class: type[nn.Module] | Callable | None = SquashDims, aggregator_kwargs: dict | None = None, squeeze_output: bool = False, device: DEVICE_TYPING | None = None, @@ -1199,7 +1198,7 @@ def __init__( self.mlp = MLP(device=device, **mlp_net_default_kwargs) ddpg_init_last_layer(self.mlp, 6e-4, device=device) - def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, observation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden = self.convnet(observation) action = self.mlp(hidden) return action, hidden @@ -1478,7 +1477,7 @@ def __init__( "bias_last_layer": True, "activate_last_layer": True, } - mlp_net_kwargs_net1: Dict = ( + mlp_net_kwargs_net1: dict = ( mlp_net_kwargs_net1 if mlp_net_kwargs_net1 is not None else {} ) mlp1_net_default_kwargs.update(mlp_net_kwargs_net1) @@ -1539,7 +1538,7 @@ def __init__( self, state_dim: int, action_dim: int, - transformer_config: Dict | DecisionTransformer.DTConfig = None, + transformer_config: dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None, ): super().__init__() @@ -1577,7 +1576,7 @@ def forward( observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: hidden_state = self.transformer(observation, action, return_to_go) mu = self.action_layer_mean(hidden_state) log_std = self.action_layer_logstd(hidden_state) @@ -1638,7 +1637,7 @@ def __init__( self, state_dim: int, action_dim: int, - transformer_config: Dict | DecisionTransformer.DTConfig = None, + transformer_config: dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None, ): super().__init__() @@ -1690,7 +1689,7 @@ def default_config(cls): ) -def _iter_maybe_over_single(item: dict | List[dict] | None, n): +def _iter_maybe_over_single(item: dict | list[dict] | None, n): if item is None: return iter([{} for _ in range(n)]) elif isinstance(item, dict): @@ -1703,7 +1702,7 @@ class _ExecutableLayer(nn.Module): """A thin wrapper around a function to be executed as a module.""" def __init__(self, func): - super(_ExecutableLayer, self).__init__() + super().__init__() self.func = func def forward(self, *args, **kwargs): diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 71b5c254d0a..e4f923d34dd 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -7,16 +7,13 @@ import abc from copy import deepcopy from textwrap import indent -from typing import Optional, 
Sequence, Tuple, Type, Union +from typing import Sequence import numpy as np - import torch - from tensordict import TensorDict from torch import nn from torchrl.data.utils import DEVICE_TYPING - from torchrl.modules.models import ConvNet, MLP from torchrl.modules.models.utils import _reset_parameters_recursive @@ -129,7 +126,7 @@ def exec_module(params, *input): return torch.vmap(exec_module, *args, **kwargs) - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor: if len(inputs) > 1: inputs = torch.cat([*inputs], -1) else: @@ -418,10 +415,10 @@ def __init__( *, centralized: bool | None = None, share_params: bool | None = None, - device: Optional[DEVICE_TYPING] = None, - depth: Optional[int] = None, - num_cells: Optional[Union[Sequence, int]] = None, - activation_class: Optional[Type[nn.Module]] = nn.Tanh, + device: DEVICE_TYPING | None = None, + depth: int | None = None, + num_cells: Sequence | int | None = None, + activation_class: type[nn.Module] | None = nn.Tanh, use_td_params: bool = True, **kwargs, ): @@ -631,10 +628,10 @@ def __init__( in_features: int | None = None, device: DEVICE_TYPING | None = None, num_cells: Sequence[int] | None = None, - kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, - strides: Union[Sequence, int] = 2, - paddings: Union[Sequence, int] = 0, - activation_class: Type[nn.Module] = nn.ELU, + kernel_sizes: Sequence[int | Sequence[int]] | int = 5, + strides: Sequence | int = 2, + paddings: Sequence | int = 0, + activation_class: type[nn.Module] = nn.ELU, use_td_params: bool = True, **kwargs, ): @@ -789,7 +786,7 @@ def __init__( self, n_agents: int, needs_state: bool, - state_shape: Union[Tuple[int, ...], torch.Size], + state_shape: tuple[int, ...] | torch.Size, device: DEVICE_TYPING, ): super().__init__() @@ -799,7 +796,7 @@ def __init__( self.needs_state = needs_state self.state_shape = state_shape - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor: """Forward pass of the mixer. Args: @@ -1001,7 +998,7 @@ class QMixer(Mixer): def __init__( self, - state_shape: Union[Tuple[int, ...], torch.Size], + state_shape: tuple[int, ...] | torch.Size, mixing_embed_dim: int, n_agents: int, device: DEVICE_TYPING, diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index 5a59bc55fa1..7173a030c5c 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch import torch.nn as nn @@ -17,7 +18,7 @@ def __init__( self, num_ch, ): - super(_ResNetBlock, self).__init__() + super().__init__() resnet_block = [] resnet_block.append(nn.ReLU(inplace=True)) resnet_block.append( diff --git a/torchrl/modules/models/rlhf.py b/torchrl/modules/models/rlhf.py index 48953e43a4a..8b4f01a38c6 100644 --- a/torchrl/modules/models/rlhf.py +++ b/torchrl/modules/models/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import importlib from pathlib import Path diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 0c650087235..1ae6234a844 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -6,13 +6,11 @@ import inspect import warnings -from typing import Callable, Sequence, Type +from typing import Callable, Sequence import torch from torch import nn - from torchrl.data.utils import DEVICE_TYPING - from torchrl.modules.models.exploration import NoisyLazyLinear, NoisyLinear LazyMapping = { @@ -114,7 +112,7 @@ def _find_depth(depth: int | None, *list_or_ints: Sequence): def create_on_device( - module_class: Type[nn.Module] | Callable, + module_class: type[nn.Module] | Callable, device: DEVICE_TYPING | None, *args, **kwargs, diff --git a/torchrl/modules/planners/__init__.py b/torchrl/modules/planners/__init__.py index 56c0e48bc65..8ea9b2a3e01 100644 --- a/torchrl/modules/planners/__init__.py +++ b/torchrl/modules/planners/__init__.py @@ -6,3 +6,5 @@ from .cem import CEMPlanner from .common import MPCPlannerBase from .mppi import MPPIPlanner + +__all__ = ["CEMPlanner", "MPCPlannerBase", "MPPIPlanner"] diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index abc0e3d3f95..9739ce5e592 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch from tensordict import TensorDict, TensorDictBase diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 3d6a4961f50..35703e6cad7 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -2,8 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import abc -from typing import Optional import torch from tensordict import TensorDictBase @@ -52,7 +53,7 @@ def planning(self, td: TensorDictBase) -> torch.Tensor: def forward( self, tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, + tensordict_out: TensorDictBase | None = None, **kwargs, ) -> TensorDictBase: if "params" in kwargs or "vmap" in kwargs: diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index 31c95650d25..e4b33ced697 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import torch from tensordict import TensorDict, TensorDictBase diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 3fb1559833a..add36202bba 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .actors import ( +from torchrl.modules.tensordict_module.actors import ( Actor, ActorCriticOperator, ActorCriticWrapper, @@ -21,8 +21,8 @@ TanhModule, ValueOperator, ) -from .common import SafeModule, VmapModule -from .exploration import ( +from torchrl.modules.tensordict_module.common import SafeModule, VmapModule +from torchrl.modules.tensordict_module.exploration import ( AdditiveGaussianModule, AdditiveGaussianWrapper, EGreedyModule, @@ -30,11 +30,11 @@ OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ) -from .probabilistic import ( +from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import ( +from torchrl.modules.tensordict_module.rnn import ( GRU, GRUCell, GRUModule, @@ -44,5 +44,44 @@ recurrent_mode, set_recurrent_mode, ) -from .sequence import SafeSequential -from .world_models import WorldModelWrapper +from torchrl.modules.tensordict_module.sequence import SafeSequential +from torchrl.modules.tensordict_module.world_models import WorldModelWrapper + +__all__ = [ + "Actor", + "ActorCriticOperator", + "ActorCriticWrapper", + "ActorValueOperator", + "DecisionTransformerInferenceWrapper", + "DistributionalQValueActor", + "DistributionalQValueHook", + "DistributionalQValueModule", + "LMHeadActorValueOperator", + "MultiStepActorWrapper", + "ProbabilisticActor", + "QValueActor", + "QValueHook", + "QValueModule", + "TanhModule", + "ValueOperator", + "SafeModule", + "VmapModule", + "AdditiveGaussianModule", + "AdditiveGaussianWrapper", + "EGreedyModule", + "EGreedyWrapper", + "OrnsteinUhlenbeckProcessModule", + "OrnsteinUhlenbeckProcessWrapper", + "SafeProbabilisticModule", + "SafeProbabilisticTensorDictSequential", + "GRU", + "GRUCell", + "GRUModule", + "LSTM", + "LSTMCell", + "LSTMModule", + "recurrent_mode", + "set_recurrent_mode", + "SafeSequential", + "WorldModelWrapper", +] diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index ca76acc4160..e4b91c1a543 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -4,10 +4,9 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Sequence import torch - from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( CompositeDistribution, @@ -98,10 +97,10 @@ class Actor(SafeModule): def __init__( self, module: nn.Module, - in_keys: Optional[Sequence[NestedKey]] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, *, - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, **kwargs, ): if in_keys is None: @@ -360,10 +359,10 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): def __init__( self, module: TensorDictModule, - in_keys: Union[NestedKey, Sequence[NestedKey]], - out_keys: Optional[Sequence[NestedKey]] = None, + in_keys: NestedKey | Sequence[NestedKey], + out_keys: Sequence[NestedKey] | None = None, *, - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, **kwargs, ): distribution_class = kwargs.get("distribution_class") @@ -450,8 +449,8 @@ class ValueOperator(TensorDictModule): def __init__( self, module: nn.Module, - in_keys: Optional[Sequence[NestedKey]] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, ) -> None: if in_keys is None: in_keys = ["observation"] @@ -532,12 +531,12 @@ class QValueModule(TensorDictModuleBase): def __init__( self, - action_space: Optional[str] = None, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, - var_nums: Optional[int] = None, - spec: Optional[TensorSpec] = None, + action_space: str | None = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, + var_nums: int | None = None, + spec: TensorSpec | None = None, safe: bool = False, ): if isinstance(action_space, TensorSpec): @@ -748,12 +747,12 @@ class DistributionalQValueModule(QValueModule): def __init__( self, - action_space: Optional[str], + action_space: str | None, support: torch.Tensor, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, - var_nums: Optional[int] = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, + var_nums: int | None = None, spec: TensorSpec = None, safe: bool = False, ): @@ -911,10 +910,10 @@ class QValueHook: def __init__( self, action_space: str, - var_nums: Optional[int] = None, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + var_nums: int | None = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, ): if isinstance(action_space, TensorSpec): raise RuntimeError( @@ -938,7 +937,7 @@ def __init__( def __call__( self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: kwargs = {self.action_value_key: values} return self.qvalue_model(**kwargs) @@ -1007,10 +1006,10 @@ def __init__( self, action_space: str, support: torch.Tensor, - var_nums: Optional[int] = 
None, - action_value_key: Optional[NestedKey] = None, - action_mask_key: Optional[NestedKey] = None, - out_keys: Optional[Sequence[NestedKey]] = None, + var_nums: int | None = None, + action_value_key: NestedKey | None = None, + action_mask_key: NestedKey | None = None, + out_keys: Sequence[NestedKey] | None = None, ): if isinstance(action_space, TensorSpec): raise RuntimeError("Using specs in action_space is deprecated") @@ -1125,9 +1124,9 @@ def __init__( in_keys=None, spec=None, safe=False, - action_space: Optional[str] = None, + action_space: str | None = None, action_value_key=None, - action_mask_key: Optional[NestedKey] = None, + action_mask_key: NestedKey | None = None, ): if isinstance(action_space, TensorSpec): raise RuntimeError( @@ -1268,10 +1267,10 @@ def __init__( in_keys=None, spec=None, safe=False, - var_nums: Optional[int] = None, - action_space: Optional[str] = None, + var_nums: int | None = None, + action_space: str | None = None, action_value_key: str = "action_value", - action_mask_key: Optional[NestedKey] = None, + action_mask_key: NestedKey | None = None, make_log_softmax: bool = True, ): if isinstance(action_space, TensorSpec): @@ -1836,7 +1835,7 @@ def __init__( policy: TensorDictModule, *, inference_context: int = 5, - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, device: torch.device | None = None, ): super().__init__(policy) @@ -2066,7 +2065,7 @@ def __init__( high=None, clamp: bool = False, ): - super(TanhModule, self).__init__() + super().__init__() self.in_keys = in_keys if out_keys is None: out_keys = in_keys @@ -2291,8 +2290,8 @@ def __init__( actor: TensorDictModuleBase, n_steps: int, *, - action_keys: List[NestedKey] | None = None, - init_key: List[NestedKey] | None = None, + action_keys: list[NestedKey] | None = None, + init_key: list[NestedKey] | None = None, ): self.action_keys = action_keys self.init_key = init_key @@ -2387,7 +2386,7 @@ def forward( return tensordict @property - def action_keys(self) -> List[NestedKey]: + def action_keys(self) -> list[NestedKey]: action_keys = self.__dict__.get("_action_keys", None) if action_keys is None: @@ -2411,7 +2410,7 @@ def action_keys(self, value): self._action_keys = [unravel_key(key) for key in value] @property - def _actor_keys_map(self) -> Dict[NestedKey, NestedKey]: + def _actor_keys_map(self) -> dict[NestedKey, NestedKey]: val = self.__dict__.get("_actor_keys_map_values", None) if val is None: diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 7914f663a45..2bd09e81e81 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,20 +9,16 @@ import inspect import re import warnings -from typing import Iterable, List, Optional, Type, Union +from typing import Iterable import torch - from tensordict import TensorDictBase, unravel_key_list - from tensordict.nn import dispatch, TensorDictModule, TensorDictModuleBase from tensordict.utils import NestedKey - from torch import nn from torch.nn import functional as F from torchrl.data.tensor_specs import Composite, TensorSpec - from torchrl.data.utils import DEVICE_TYPING _has_functorch = importlib.util.find_spec("functorch") is not None @@ -194,12 +190,15 @@ class SafeModule(TensorDictModule): def __init__( self, - module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module - ], + module: ( + FunctionalModule + | FunctionalModuleWithBuffers + | TensorDictModule + | nn.Module + ), in_keys: 
Iterable[str], out_keys: Iterable[str], - spec: Optional[TensorSpec] = None, + spec: TensorSpec | None = None, safe: bool = False, ): super().__init__(module, in_keys, out_keys) @@ -282,14 +281,14 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: """See :obj:`TensorDictModule.random(...)`.""" return self.random(tensordict) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> TensorDictModule: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> TensorDictModule: if hasattr(self, "spec") and self.spec is not None: self.spec = self.spec.to(dest) out = super().to(dest) return out -def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): +def is_tensordict_compatible(module: TensorDictModule | nn.Module): """Returns `True` if a module can be used as a TensorDictModule, and False if it can't. If the signature is misleading an error is raised. @@ -356,13 +355,13 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): def ensure_tensordict_compatible( - module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module - ], - in_keys: Optional[List[NestedKey]] = None, - out_keys: Optional[List[NestedKey]] = None, + module: ( + FunctionalModule | FunctionalModuleWithBuffers | TensorDictModule | nn.Module + ), + in_keys: list[NestedKey] | None = None, + out_keys: list[NestedKey] | None = None, safe: bool = False, - wrapper_type: Optional[Type] = TensorDictModule, + wrapper_type: type | None = TensorDictModule, **kwargs, ): """Ensures module is compatible with TensorDictModule and, if not, it wraps it.""" diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 62ac9710cc5..050b8e4f27e 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -5,12 +5,10 @@ from __future__ import annotations import warnings -from typing import Optional, Union import numpy as np import torch from tensordict import TensorDictBase - from tensordict.nn import ( TensorDictModule, TensorDictModuleBase, @@ -95,8 +93,8 @@ def __init__( eps_end: float = 0.1, annealing_num_steps: int = 1000, *, - action_key: Optional[NestedKey] = "action", - action_mask_key: Optional[NestedKey] = None, + action_key: NestedKey | None = "action", + action_mask_key: NestedKey | None = None, device: torch.device | None = None, ): if not isinstance(eps_init, float): @@ -209,9 +207,9 @@ def __init__( eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, - action_key: Optional[NestedKey] = "action", - action_mask_key: Optional[NestedKey] = None, - spec: Optional[TensorSpec] = None, + action_key: NestedKey | None = "action", + action_mask_key: NestedKey | None = None, + spec: TensorSpec | None = None, ): raise RuntimeError( "This class has been deprecated in favor of torchrl.modules.EGreedyModule." 
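A note on the annotation changes in the hunks above: spellings like `NestedKey | None` (PEP 604) and `list[dict]` (PEP 585) only exist at runtime on Python 3.10/3.9 and later. They are safe here because each touched file gains `from __future__ import annotations` (PEP 563), which stores annotations as strings instead of evaluating them at import time. A minimal sketch of the mechanism, using an illustrative function that is not part of the patch:

    from __future__ import annotations

    def first_key(keys: list[str] | None = None) -> str | None:
        # Under the future import this annotation is never evaluated, so the
        # module still imports on Python 3.8; without it, `list[str] | None`
        # would raise a TypeError on that interpreter.
        if not keys:
            return None
        return keys[0]

The caveat is that anything which resolves annotations eagerly (e.g. `typing.get_type_hints`) still evaluates the new syntax, so runtime type inspection remains version-sensitive.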
@@ -230,9 +228,9 @@ def __init__( annealing_num_steps: int = 1000, mean: float = 0.0, std: float = 1.0, - action_key: Optional[NestedKey] = "action", - spec: Optional[TensorSpec] = None, - safe: Optional[bool] = True, + action_key: NestedKey | None = "action", + spec: TensorSpec | None = None, + safe: bool | None = True, device: torch.device | None = None, ): raise RuntimeError( @@ -287,7 +285,7 @@ def __init__( mean: float = 0.0, std: float = 1.0, *, - action_key: Optional[NestedKey] = "action", + action_key: NestedKey | None = "action", # safe is already implemented because we project in the noise addition safe: bool = False, device: torch.device | None = None, @@ -383,14 +381,14 @@ def __init__( mu: float = 0.0, sigma: float = 0.2, dt: float = 1e-2, - x0: Optional[Union[torch.Tensor, np.ndarray]] = None, - sigma_min: Optional[float] = None, + x0: torch.Tensor | np.ndarray | None = None, + sigma_min: float | None = None, n_steps_annealing: int = 1000, - action_key: Optional[NestedKey] = "action", - is_init_key: Optional[NestedKey] = "is_init", + action_key: NestedKey | None = "action", + is_init_key: NestedKey | None = "is_init", spec: TensorSpec = None, safe: bool = True, - key: Optional[NestedKey] = None, + key: NestedKey | None = None, device: torch.device | None = None, ): raise RuntimeError( @@ -611,11 +609,11 @@ def __init__( mu: float = 0.0, sigma: float = 0.2, dt: float = 1e-2, - x0: Optional[Union[torch.Tensor, np.ndarray]] = None, - sigma_min: Optional[float] = None, + x0: torch.Tensor | np.ndarray | None = None, + sigma_min: float | None = None, n_steps_annealing: int = 1000, - key: Optional[NestedKey] = "action", - is_init_key: Optional[NestedKey] = "is_init", + key: NestedKey | None = "action", + is_init_key: NestedKey | None = "is_init", device: torch.device | None = None, ): super().__init__() @@ -688,7 +686,7 @@ def add_sample( self, tensordict: TensorDictBase, eps: float = 1.0, - is_init: Optional[torch.Tensor] = None, + is_init: torch.Tensor | None = None, ) -> TensorDictBase: # Get the nested tensordict where the action lives diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 79b0d015823..89e56672623 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -5,12 +5,9 @@ from __future__ import annotations import warnings -from typing import Dict, List, Optional, Union import torch - from tensordict import TensorDictBase, unravel_key_list - from tensordict.nn import ( InteractionType, ProbabilisticTensorDictModule, @@ -186,16 +183,16 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): def __init__( self, - in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey], - out_keys: NestedKey | List[NestedKey] | None = None, - spec: Optional[TensorSpec] = None, + in_keys: NestedKey | list[NestedKey] | dict[str, NestedKey], + out_keys: NestedKey | list[NestedKey] | None = None, + spec: TensorSpec | None = None, *, safe: bool = False, default_interaction_type: InteractionType = InteractionType.DETERMINISTIC, distribution_class: type = Delta, distribution_kwargs: dict | None = None, return_log_prob: bool = False, - log_prob_keys: List[NestedKey] | None = None, + log_prob_keys: list[NestedKey] | None = None, log_prob_key: NestedKey | None = None, cache_dist: bool = False, n_empirical_estimate: int = 1000, @@ -315,7 +312,7 @@ class SafeProbabilisticTensorDictSequential( def __init__( self, - *modules: Union[TensorDictModule, 
ProbabilisticTensorDictModule], + *modules: TensorDictModule | ProbabilisticTensorDictModule, partial_tolerant: bool = False, ) -> None: super().__init__(*modules, partial_tolerant=partial_tolerant) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 815756c528a..598c8026578 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -6,27 +6,19 @@ import typing import warnings -from typing import Any, Optional, Tuple +from typing import Any import torch import torch.nn.functional as F from tensordict import TensorDictBase, unravel_key_list - from tensordict.base import NO_DEFAULT - from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase from tensordict.utils import expand_as_right, prod, set_lazy_legacy - from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded -from torchrl.objectives.value.functional import ( - _inv_pad_sequence, - _split_and_pad_sequence, -) -from torchrl.objectives.value.utils import _get_num_per_traj_init class LSTMCell(RNNCellBase): @@ -78,8 +70,8 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ) -> Tuple[Tensor, Tensor]: + self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None + ) -> tuple[Tensor, Tensor]: if input.dim() not in (1, 2): raise ValueError( f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -721,6 +713,11 @@ def set_recurrent_mode(self, mode: bool = True): @dispatch def forward(self, tensordict: TensorDictBase): + from torchrl.objectives.value.functional import ( + _inv_pad_sequence, + _split_and_pad_sequence, + ) + # we want to get an error if the value input is missing, but not the hidden states defaults = [NO_DEFAULT, None, None] shape = tensordict.shape @@ -745,6 +742,8 @@ def forward(self, tensordict: TensorDictBase): is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): + from torchrl.objectives.value.utils import _get_num_per_traj_init + # if we have consecutive trajectories, things get a little more complicated # we have a tensordict of shape [B, T] # we will split / pad things such that we get a tensordict of shape @@ -795,16 +794,16 @@ def _lstm( steps, device, dtype, - hidden0_in: Optional[torch.Tensor] = None, - hidden1_in: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not self.recurrent_mode and steps != 1: raise ValueError("Expected a single step") if hidden1_in is None and hidden0_in is None: shape = (batch, steps) - hidden0_in, hidden1_in = [ + hidden0_in, hidden1_in = ( torch.zeros( *shape, self.lstm.num_layers, @@ -813,7 +812,7 @@ def _lstm( dtype=dtype, ) for _ in range(2) - ] + ) elif hidden1_in is None or hidden0_in is None: raise RuntimeError( f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" @@ -887,7 +886,7 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, hx: 
Tensor | None = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -1536,6 +1535,11 @@ def set_recurrent_mode(self, mode: bool = True): @dispatch @set_lazy_legacy(False) def forward(self, tensordict: TensorDictBase): + from torchrl.objectives.value.functional import ( + _inv_pad_sequence, + _split_and_pad_sequence, + ) + # we want to get an error if the value input is missing, but not the hidden states defaults = [NO_DEFAULT, None] shape = tensordict.shape @@ -1560,6 +1564,8 @@ def forward(self, tensordict: TensorDictBase): is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): + from torchrl.objectives.value.utils import _get_num_per_traj_init + # if we have consecutive trajectories, things get a little more complicated # we have a tensordict of shape [B, T] # we will split / pad things such that we get a tensordict of shape @@ -1606,8 +1612,8 @@ def _gru( steps, device, dtype, - hidden_in: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_in: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not self.recurrent_mode and steps != 1: raise ValueError("Expected a single step") diff --git a/torchrl/modules/tensordict_module/world_models.py b/torchrl/modules/tensordict_module/world_models.py index 78384196926..ae3ea4d9a00 100644 --- a/torchrl/modules/tensordict_module/world_models.py +++ b/torchrl/modules/tensordict_module/world_models.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +from __future__ import annotations from tensordict.nn import TensorDictModule, TensorDictSequential diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index ae57de949bf..cfca29b0f5d 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -10,9 +10,9 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"): - from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta + from torch.nn.parameter import _ParameterMeta else: - from torch.nn.parameter import _disabled_torch_function_impl + pass # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): @@ -26,3 +26,13 @@ def __instancecheck__(self, instance): from .mappings import biased_softplus, inv_softplus, mappings from .utils import get_primers_from_module + +__all__ = [ + "OrderedDict", + "torch", + "version", + "biased_softplus", + "inv_softplus", + "mappings", + "get_primers_from_module", +] diff --git a/torchrl/modules/utils/mappings.py b/torchrl/modules/utils/mappings.py index a9e3ab189d5..ebcf776f605 100644 --- a/torchrl/modules/utils/mappings.py +++ b/torchrl/modules/utils/mappings.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from tensordict.nn.utils import biased_softplus, expln, inv_softplus, mappings diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py index 9a8914aab89..cb1e66a7f98 100644 --- a/torchrl/modules/utils/utils.py +++ b/torchrl/modules/utils/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations import warnings diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8f5636db95..fd7ac06048b 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -3,24 +3,28 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .a2c import A2CLoss -from .common import LossModule -from .cql import CQLLoss, DiscreteCQLLoss -from .crossq import CrossQLoss -from .ddpg import DDPGLoss -from .decision_transformer import DTLoss, OnlineDTLoss -from .dqn import DistributionalDQNLoss, DQNLoss -from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss -from .gail import GAILLoss -from .iql import DiscreteIQLLoss, IQLLoss -from .multiagent import QMixerLoss -from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss -from .redq import REDQLoss -from .reinforce import ReinforceLoss -from .sac import DiscreteSACLoss, SACLoss -from .td3 import TD3Loss -from .td3_bc import TD3BCLoss -from .utils import ( +from torchrl.objectives.a2c import A2CLoss +from torchrl.objectives.common import LossModule +from torchrl.objectives.cql import CQLLoss, DiscreteCQLLoss +from torchrl.objectives.crossq import CrossQLoss +from torchrl.objectives.ddpg import DDPGLoss +from torchrl.objectives.decision_transformer import DTLoss, OnlineDTLoss +from torchrl.objectives.dqn import DistributionalDQNLoss, DQNLoss +from torchrl.objectives.dreamer import ( + DreamerActorLoss, + DreamerModelLoss, + DreamerValueLoss, +) +from torchrl.objectives.gail import GAILLoss +from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss +from torchrl.objectives.multiagent import QMixerLoss +from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss +from torchrl.objectives.redq import REDQLoss +from torchrl.objectives.reinforce import ReinforceLoss +from torchrl.objectives.sac import DiscreteSACLoss, SACLoss +from torchrl.objectives.td3 import TD3Loss +from torchrl.objectives.td3_bc import TD3BCLoss +from torchrl.objectives.utils import ( default_value_kwargs, distance_loss, group_optimizers, @@ -32,3 +36,43 @@ TargetNetUpdater, ValueEstimators, ) + +__all__ = [ + "A2CLoss", + "CQLLoss", + "ClipPPOLoss", + "CrossQLoss", + "DDPGLoss", + "DQNLoss", + "DTLoss", + "DiscreteCQLLoss", + "DiscreteIQLLoss", + "DiscreteSACLoss", + "DistributionalDQNLoss", + "DreamerActorLoss", + "DreamerModelLoss", + "DreamerValueLoss", + "GAILLoss", + "HardUpdate", + "IQLLoss", + "KLPENPPOLoss", + "LossModule", + "OnlineDTLoss", + "PPOLoss", + "QMixerLoss", + "REDQLoss", + "ReinforceLoss", + "SACLoss", + "SoftUpdate", + "TD3BCLoss", + "TD3Loss", + "TargetNetUpdater", + "ValueEstimators", + "add_random_module", + "default_value_kwargs", + "distance_loss", + "group_optimizers", + "hold_out_net", + "hold_out_params", + "next_state_value", +] diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 65ef79d4606..2ebcd4120c7 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -7,7 +7,6 @@ import contextlib from copy import deepcopy from dataclasses import dataclass -from typing import Tuple import torch from tensordict import ( @@ -29,7 +28,6 @@ from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, @@ -437,7 +435,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: @set_composite_lp_aggregate(False) def _log_probs( self, 
tensordict: TensorDictBase - ) -> Tuple[torch.Tensor, d.Distribution]: + ) -> tuple[torch.Tensor, d.Distribution]: # current log_prob of actions tensordict_clone = tensordict.select( *self.actor_network.in_keys, strict=False @@ -466,7 +464,7 @@ def _log_probs( log_prob = log_prob.unsqueeze(-1) return log_prob, dist - def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: + def loss_critic(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, float]: """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``. Returns the loss and the clip-fraction. diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index e4782195f4d..50fc7ee7fba 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -10,20 +10,19 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple +from typing import Iterator import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase - from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams from tensordict.utils import Buffer from torch import nn from torch.nn import Parameter + from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import set_recurrent_mode - -from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators +from torchrl.modules.tensordict_module.rnn import set_recurrent_mode +from torchrl.objectives.utils import ValueEstimators from torchrl.objectives.value import ValueEstimatorBase try: @@ -126,8 +125,6 @@ class _AcceptedKeys: default values. """ - pass - tensor_keys: _AcceptedKeys _vmap_randomness = None default_value_estimator: ValueEstimators = None @@ -280,9 +277,9 @@ def convert_to_functional( self, module: TensorDictModule, module_name: str, - expand_dim: Optional[int] = None, + expand_dim: int | None = None, create_target_params: bool = False, - compare_against: Optional[List[Parameter]] = None, + compare_against: list[Parameter] | None = None, **kwargs, ) -> None: """Converts a module to functional to be used in the loss. 
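Several hunks in this patch replace module-level imports with imports inside the function body: the `forward` methods in rnn.py above, and `vmap_randomness` in common.py just below. Deferring an import to call time is a standard way to break a circular import, since both modules have finished initializing by the time the function runs. A two-file sketch with hypothetical module names:

    # losses.py -- pulls in `models` at call time, not at import time
    def count_models():
        import models  # deferred: breaks the losses <-> models import cycle
        return len(models.REGISTRY)

    # models.py -- can now safely import `losses` at module level
    import losses
    REGISTRY = ["mlp", "cnn"]

The same move has a second effect exploited below: `RANDOM_MODULE_LIST` is looked up as an attribute of `torchrl.objectives.utils` on every call rather than bound once at import, so later rebindings of that tuple are picked up immediately.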
@@ -486,7 +483,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: def named_parameters( self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, Parameter]]: + ) -> Iterator[tuple[str, Parameter]]: for name, param in super().named_parameters(prefix=prefix, recurse=recurse): if not name.startswith("_target"): yield name, param @@ -636,6 +633,8 @@ def vmap_randomness(self): """ if self._vmap_randomness is None: + import torchrl.objectives.utils + main_modules = list(self.__dict__.values()) + list(self.children()) modules = ( module @@ -644,7 +643,7 @@ def vmap_randomness(self): for module in main_module.modules() ) for val in modules: - if isinstance(val, RANDOM_MODULE_LIST): + if isinstance(val, torchrl.objectives.utils.RANDOM_MODULE_LIST): self._vmap_randomness = "different" break else: @@ -688,7 +687,10 @@ def __call__(self, x): return x -def add_ramdom_module(module): +def add_random_module(module): """Adds a random module to the list of modules that will be detected by :meth:`~torchrl.objectives.LossModule.vmap_randomness` as random.""" - global RANDOM_MODULE_LIST - RANDOM_MODULE_LIST = RANDOM_MODULE_LIST + (module,) + import torchrl.objectives.utils + + torchrl.objectives.utils.RANDOM_MODULE_LIST = ( + torchrl.objectives.utils.RANDOM_MODULE_LIST + (module,) + ) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0e3a2447650..d94728985a2 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -9,8 +9,6 @@ from copy import deepcopy from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - import numpy as np import torch import torch.nn as nn @@ -22,8 +20,7 @@ from torchrl.data.tensor_specs import Composite from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type - -from torchrl.modules import ProbabilisticActor, QValueActor +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -35,7 +32,6 @@ distance_loss, ValueEstimators, ) - from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -274,7 +270,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, loss_function: str = "smooth_l1", alpha_init: float = 1.0, @@ -282,7 +278,7 @@ def __init__( max_alpha: float = None, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", delay_actor: bool = False, delay_qvalue: bool = True, gamma: float = None, @@ -581,7 +577,7 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor: ) return bc_actor_loss, metadata - def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -705,7 +701,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def q_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, 
dict]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. target_value = self._get_value_v( tensordict.copy(), @@ -743,7 +739,7 @@ def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return loss_qval, metadata - def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def cql_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: pred_q1 = tensordict.get(self.tensor_keys.pred_q1) pred_q2 = tensordict.get(self.tensor_keys.pred_q2) @@ -1089,9 +1085,9 @@ class _AcceptedKeys: def __init__( self, - value_network: Union[QValueActor, nn.Module], + value_network: QValueActor | nn.Module, *, - loss_function: Optional[str] = "l2", + loss_function: str | None = "l2", delay_value: bool = True, gamma: float = None, action_space=None, @@ -1218,7 +1214,7 @@ def in_keys(self, values): def value_loss( self, tensordict: TensorDictBase, - ) -> Tuple[torch.Tensor, dict]: + ) -> tuple[torch.Tensor, dict]: td_copy = tensordict.clone(False) with self.value_network_params.to_module(self.value_network): self.value_network(td_copy) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 45976c3c48f..2f576f219b3 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -7,19 +7,17 @@ import math from dataclasses import dataclass from functools import wraps -from typing import Dict, List, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams - from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor + from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( _cache_values, _reduce, @@ -256,7 +254,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", @@ -265,7 +263,7 @@ def __init__( max_alpha: float = None, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", priority_key: str = None, separate_losses: bool = False, reduction: str = None, @@ -559,7 +557,7 @@ def _cached_detached_qvalue_params(self): def actor_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which @@ -601,7 +599,7 @@ def actor_loss( def qvalue_loss( self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> tuple[Tensor, dict[str, Tensor]]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. 
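The `add_random_module` helper above (renamed from the misspelled `add_ramdom_module`) rebinds `torchrl.objectives.utils.RANDOM_MODULE_LIST`, letting user code register custom stochastic layers so that `vmap_randomness` resolves to "different" for losses containing them. A hypothetical usage sketch: the `NoisyHead` class is illustrative, and the helper is imported from `torchrl.objectives.common`, where this patch defines it.

    import torch
    from torch import nn

    from torchrl.objectives.common import add_random_module

    class NoisyHead(nn.Module):
        """Toy layer that draws fresh noise on every forward pass."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + torch.randn_like(x)

    # Register the class, not an instance: vmap_randomness checks
    # isinstance(module, RANDOM_MODULE_LIST), which takes a tuple of types.
    add_random_module(NoisyHead)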
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 34a7aa72242..50973c7077f 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -7,13 +7,12 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Tuple import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule - from tensordict.utils import NestedKey, unravel_key + from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -340,7 +339,7 @@ def loss_actor( def loss_value( self, tensordict: TensorDictBase, - ) -> Tuple[torch.Tensor, dict]: + ) -> tuple[torch.Tensor, dict]: # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() with self.value_network_params.to_module(self.value_network): diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 05ade582d2a..1038a7151e8 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -6,16 +6,14 @@ import math from dataclasses import dataclass -from typing import Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey - from torch import distributions as d -from torchrl.modules import ProbabilisticActor +from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import _reduce, distance_loss @@ -85,7 +83,7 @@ def __init__( min_alpha: float = None, max_alpha: float = None, fixed_alpha: bool = False, - target_entropy: Union[str, float] = "auto", + target_entropy: str | float = "auto", samples_mc_entropy: int = 1, reduction: str = None, ) -> None: diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 5faadccfe93..7221f6b3be5 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -7,11 +7,9 @@ import math from dataclasses import dataclass from numbers import Number -from typing import List, Tuple, Union import numpy as np import torch - from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule from tensordict.utils import NestedKey @@ -149,7 +147,7 @@ def __post_init__(self): def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule | List[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, @@ -159,7 +157,7 @@ def __init__( max_alpha: float = 10.0, action_spec=None, fixed_alpha: bool = False, - target_entropy: Union[str, Number] = "auto", + target_entropy: str | Number = "auto", delay_qvalue: bool = True, gSDE: bool = False, gamma: float = None, @@ -362,7 +360,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detach_qvalue_network_params(self): return self.qvalue_network_params.detach() - def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def _actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, Tensor]: obs_keys = self.actor_network.in_keys tensordict_clone = tensordict.select(*obs_keys, strict=False) with set_exploration_type( diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 
21376bfd5b2..6a4373e1751 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -6,24 +6,21 @@ import warnings from dataclasses import dataclass -from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import nn -from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import _find_action_space - from torchrl.envs.utils import step_mdp from torchrl.modules.tensordict_module.actors import ( DistributionalQValueActor, QValueActor, ) from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible - from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_ERROR, @@ -175,13 +172,13 @@ class _AcceptedKeys: def __init__( self, - value_network: Union[QValueActor, nn.Module], + value_network: QValueActor | nn.Module, *, - loss_function: Optional[str] = "l2", + loss_function: str | None = "l2", delay_value: bool = True, double_dqn: bool = False, gamma: float = None, - action_space: Union[str, TensorSpec] = None, + action_space: str | TensorSpec = None, priority_key: str = None, reduction: str = None, ) -> None: @@ -454,7 +451,7 @@ class _AcceptedKeys: def __init__( self, - value_network: Union[DistributionalQValueActor, nn.Module], + value_network: DistributionalQValueActor | nn.Module, *, gamma: float, delay_value: bool = True, diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index a8c439288eb..0eb976da5d0 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -5,7 +5,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple import torch from tensordict import TensorDict @@ -20,10 +19,9 @@ _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, - # distance_loss, hold_out_net, ValueEstimators, -) +) # distance_loss, from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -103,8 +101,8 @@ def __init__( lambda_kl: float = 1.0, lambda_reco: float = 1.0, lambda_reward: float = 1.0, - reco_loss: Optional[str] = None, - reward_loss: Optional[str] = None, + reco_loss: str | None = None, + reward_loss: str | None = None, free_nats: int = 3, delayed_clamp: bool = False, global_average: bool = False, @@ -277,7 +275,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, ) - def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: + def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]: tensordict = tensordict.select("state", self.tensor_keys.belief).detach() with timeit("actor_loss/time-rollout"), hold_out_net( @@ -409,7 +407,7 @@ class _AcceptedKeys: def __init__( self, value_model: TensorDictModule, - value_loss: Optional[str] = None, + value_loss: str | None = None, discount_loss: bool = True, # for consistency with paper gamma: int = 0.99, ): diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index ca1efcc337b..58057636cf3 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -6,16 +6,15 @@ import warnings from dataclasses import dataclass -from typing import List, Optional, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils 
import NestedKey from torch import Tensor + from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import _find_action_space - from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -256,8 +255,8 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule | List[TensorDictModule], - value_network: Optional[TensorDictModule], + qvalue_network: TensorDictModule | list[TensorDictModule], + value_network: TensorDictModule | None, *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", @@ -410,7 +409,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) return td_out - def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -455,7 +454,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) @@ -478,7 +477,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys tensordict = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False @@ -769,9 +768,9 @@ def __init__( self, actor_network: ProbabilisticActor, qvalue_network: TensorDictModule, - value_network: Optional[TensorDictModule], + value_network: TensorDictModule | None, *, - action_space: Union[str, TensorSpec] = None, + action_space: str | TensorSpec = None, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", temperature: float = 1.0, @@ -809,7 +808,7 @@ def __init__( self.action_space = _find_action_space(action_space) self.reduction = reduction - def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -870,7 +869,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # Min Q value with torch.no_grad(): # Min Q value @@ -914,7 +913,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: ) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: + def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys next_td = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False diff --git a/torchrl/objectives/multiagent/__init__.py b/torchrl/objectives/multiagent/__init__.py index 7340cffd841..cec01e0ca0c 100644 --- a/torchrl/objectives/multiagent/__init__.py +++ b/torchrl/objectives/multiagent/__init__.py @@ -4,3 +4,5 @@ 
# LICENSE file in the root directory of this source tree. from .qmixer import QMixerLoss + +__all__ = ["QMixerLoss"] diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index f3572cef9df..ce5752e70b0 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -8,7 +8,6 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -17,15 +16,11 @@ from torch import nn from torchrl.data.tensor_specs import TensorSpec - from torchrl.data.utils import _find_action_space - from torchrl.modules import SafeSequential from torchrl.modules.tensordict_module.actors import QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible - from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, @@ -193,13 +188,13 @@ class _AcceptedKeys: def __init__( self, - local_value_network: Union[QValueActor, nn.Module], - mixer_network: Union[TensorDictModule, nn.Module], + local_value_network: QValueActor | nn.Module, + mixer_network: TensorDictModule | nn.Module, *, - loss_function: Optional[str] = "l2", + loss_function: str | None = "l2", delay_value: bool = True, gamma: float = None, - action_space: Union[str, TensorSpec] = None, + action_space: str | TensorSpec = None, priority_key: str = None, ) -> None: super().__init__() diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 71455c83e2d..db887cf0fba 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -6,10 +6,8 @@ import contextlib import warnings - from copy import deepcopy from dataclasses import dataclass -from typing import List, Tuple import torch from tensordict import ( @@ -32,7 +30,6 @@ from torchrl._utils import _standardize from torchrl.objectives.common import LossModule - from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, @@ -298,11 +295,11 @@ class _AcceptedKeys: advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" - sample_log_prob: NestedKey | List[NestedKey] | None = None - action: NestedKey | List[NestedKey] = "action" - reward: NestedKey | List[NestedKey] = "reward" - done: NestedKey | List[NestedKey] = "done" - terminated: NestedKey | List[NestedKey] = "terminated" + sample_log_prob: NestedKey | list[NestedKey] | None = None + action: NestedKey | list[NestedKey] = "action" + reward: NestedKey | list[NestedKey] = "reward" + done: NestedKey | list[NestedKey] = "done" + terminated: NestedKey | list[NestedKey] = "terminated" def __post_init__(self): if self.sample_log_prob is None: @@ -333,7 +330,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Tuple[int] = (), + normalize_advantage_exclude_dims: tuple[int] = (), gamma: float = None, separate_losses: bool = False, advantage_key: str = None, @@ -521,7 +518,7 @@ def _get_entropy( def _log_weight( self, tensordict: TensorDictBase, adv_shape: torch.Size - ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]: + ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]: with self.actor_network_params.to_module( self.actor_network @@ -891,7 +888,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, 
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index f3572cef9df..ce5752e70b0 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -8,7 +8,6 @@
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Optional, Union
 
 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
@@ -17,15 +16,11 @@
 from torch import nn
 
 from torchrl.data.tensor_specs import TensorSpec
-
 from torchrl.data.utils import _find_action_space
-
 from torchrl.modules import SafeSequential
 from torchrl.modules.tensordict_module.actors import QValueActor
 from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
-
 from torchrl.objectives.common import LossModule
-
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -193,13 +188,13 @@ class _AcceptedKeys:
 
     def __init__(
         self,
-        local_value_network: Union[QValueActor, nn.Module],
-        mixer_network: Union[TensorDictModule, nn.Module],
+        local_value_network: QValueActor | nn.Module,
+        mixer_network: TensorDictModule | nn.Module,
         *,
-        loss_function: Optional[str] = "l2",
+        loss_function: str | None = "l2",
         delay_value: bool = True,
         gamma: float = None,
-        action_space: Union[str, TensorSpec] = None,
+        action_space: str | TensorSpec = None,
         priority_key: str = None,
     ) -> None:
         super().__init__()
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 71455c83e2d..db887cf0fba 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -6,10 +6,8 @@
 
 import contextlib
 import warnings
-
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import List, Tuple
 
 import torch
 from tensordict import (
@@ -32,7 +30,6 @@
 from torchrl._utils import _standardize
 from torchrl.objectives.common import LossModule
-
 from torchrl.objectives.utils import (
     _cache_values,
     _clip_value_loss,
@@ -298,11 +295,11 @@ class _AcceptedKeys:
         advantage: NestedKey = "advantage"
         value_target: NestedKey = "value_target"
         value: NestedKey = "state_value"
-        sample_log_prob: NestedKey | List[NestedKey] | None = None
-        action: NestedKey | List[NestedKey] = "action"
-        reward: NestedKey | List[NestedKey] = "reward"
-        done: NestedKey | List[NestedKey] = "done"
-        terminated: NestedKey | List[NestedKey] = "terminated"
+        sample_log_prob: NestedKey | list[NestedKey] | None = None
+        action: NestedKey | list[NestedKey] = "action"
+        reward: NestedKey | list[NestedKey] = "reward"
+        done: NestedKey | list[NestedKey] = "done"
+        terminated: NestedKey | list[NestedKey] = "terminated"
 
     def __post_init__(self):
         if self.sample_log_prob is None:
@@ -333,7 +330,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        normalize_advantage_exclude_dims: Tuple[int] = (),
+        normalize_advantage_exclude_dims: tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         advantage_key: str = None,
@@ -521,7 +518,7 @@ def _get_entropy(
 
     def _log_weight(
         self, tensordict: TensorDictBase, adv_shape: torch.Size
-    ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:
+    ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
 
         with self.actor_network_params.to_module(
             self.actor_network
@@ -891,7 +888,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        normalize_advantage_exclude_dims: Tuple[int] = (),
+        normalize_advantage_exclude_dims: tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -902,7 +899,7 @@ def __init__(
 
         if isinstance(clip_value, bool):
             clip_value = clip_epsilon if clip_value else None
 
-        super(ClipPPOLoss, self).__init__(
+        super().__init__(
             actor_network,
             critic_network,
             entropy_bonus=entropy_bonus,
@@ -1162,14 +1159,14 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        normalize_advantage_exclude_dims: Tuple[int] = (),
+        normalize_advantage_exclude_dims: tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
         clip_value: float | None = None,
         **kwargs,
     ):
-        super(KLPENPPOLoss, self).__init__(
+        super().__init__(
            actor_network,
            critic_network,
            entropy_bonus=entropy_bonus,
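The recurring edit in these hunks swaps `typing.Union`/`Optional`/`List`/`Tuple` for PEP 604 unions (`X | Y`) and PEP 585 builtin generics (`list[...]`, `tuple[...]`). A minimal sketch of why the `from __future__ import annotations` line seen throughout this patch matters on Python 3.8/3.9 (the function and names below are illustrative):

    from __future__ import annotations  # annotations stay strings on 3.8/3.9

    def scale(values: list[float], factor: float | None = None) -> list[float]:
        """Hypothetical helper: multiply values by factor (defaults to 1.0)."""
        if factor is None:
            factor = 1.0
        return [v * factor for v in values]

    # Without the future import, evaluating "float | None" raises TypeError on
    # Python < 3.10 and "list[float]" raises on < 3.9; PEP 563 defers that
    # evaluation, so the new syntax is safe in annotations.
    print(scale([1.0, 2.0], 2.0))  # [2.0, 4.0]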
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index fd94404c2c1..6be5172f0fd 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -7,11 +7,9 @@
 import math
 from dataclasses import dataclass
 from numbers import Number
-from typing import List, Union
 
 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-
 from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule
 from tensordict.utils import NestedKey
 from torch import Tensor
@@ -19,7 +17,6 @@
 from torchrl.data.tensor_specs import Composite
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.objectives.common import LossModule
-
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -266,7 +263,7 @@ def __post_init__(self):
     def __init__(
         self,
         actor_network: TensorDictModule,
-        qvalue_network: TensorDictModule | List[TensorDictModule],
+        qvalue_network: TensorDictModule | list[TensorDictModule],
         *,
         num_qvalue_nets: int = 10,
         sub_sample_len: int = 2,
@@ -276,7 +273,7 @@ def __init__(
         max_alpha: float = 10.0,
         action_spec=None,
         fixed_alpha: bool = False,
-        target_entropy: Union[str, Number] = "auto",
+        target_entropy: str | Number = "auto",
         delay_qvalue: bool = True,
         gSDE: bool = False,
         gamma: float = None,
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 12e027c08f5..d34313e5d8e 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -9,12 +9,10 @@
 from dataclasses import dataclass
 from functools import wraps
 from numbers import Number
-from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-
 from tensordict.nn import (
     composite_lp_aggregate,
     CompositeDistribution,
@@ -24,13 +22,13 @@
 )
 from tensordict.utils import expand_right, NestedKey
 from torch import Tensor
+
 from torchrl.data.tensor_specs import Composite, TensorSpec
 from torchrl.data.utils import _find_action_space
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import ProbabilisticActor
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
-
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -317,8 +315,8 @@ def __post_init__(self):
     def __init__(
         self,
         actor_network: ProbabilisticActor,
-        qvalue_network: TensorDictModule | List[TensorDictModule],
-        value_network: Optional[TensorDictModule] = None,
+        qvalue_network: TensorDictModule | list[TensorDictModule],
+        value_network: TensorDictModule | None = None,
         *,
         num_qvalue_nets: int = 2,
         loss_function: str = "smooth_l1",
@@ -327,7 +325,7 @@ def __init__(
         max_alpha: float = None,
         action_spec=None,
         fixed_alpha: bool = False,
-        target_entropy: Union[str, float] = "auto",
+        target_entropy: str | float = "auto",
         delay_actor: bool = False,
         delay_qvalue: bool = True,
         delay_value: bool = True,
@@ -653,7 +651,7 @@ def _cached_detached_qvalue_params(self):
 
     def _actor_loss(
         self, tensordict: TensorDictBase
-    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
@@ -693,7 +691,7 @@ def _cached_target_params_actor_value(self):
 
     def _qvalue_v1_loss(
         self, tensordict: TensorDictBase
-    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         target_params = self._cached_target_params_actor_value
         with set_exploration_type(self.deterministic_sampling_mode):
             target_value = self.value_estimator.value_estimate(
@@ -808,7 +806,7 @@ def _compute_target_v2(self, tensordict) -> Tensor:
 
     def _qvalue_v2_loss(
         self, tensordict: TensorDictBase
-    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
         target_value = self._compute_target_v2(tensordict)
 
@@ -830,7 +828,7 @@ def _qvalue_v2_loss(
 
     def _value_loss(
         self, tensordict: TensorDictBase
-    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         # value loss
         td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
         with self.value_network_params.to_module(self.value_network):
@@ -1085,8 +1083,8 @@ def __init__(
         actor_network: ProbabilisticActor,
         qvalue_network: TensorDictModule,
         *,
-        action_space: Union[str, TensorSpec] = None,
-        num_actions: Optional[int] = None,
+        action_space: str | TensorSpec = None,
+        num_actions: int | None = None,
         num_qvalue_nets: int = 2,
         loss_function: str = "smooth_l1",
         alpha_init: float = 1.0,
@@ -1094,7 +1092,7 @@ def __init__(
         max_alpha: float = None,
         fixed_alpha: bool = False,
         target_entropy_weight: float = 0.98,
-        target_entropy: Union[str, Number] = "auto",
+        target_entropy: str | Number = "auto",
         delay_qvalue: bool = True,
         priority_key: str = None,
         separate_losses: bool = False,
@@ -1338,7 +1336,7 @@ def _compute_target(self, tensordict) -> Tensor:
 
     def _value_loss(
         self, tensordict: TensorDictBase
-    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         target_value = self._compute_target(tensordict)
         tensordict_expand = self._vmap_qnetworkN0(
             tensordict.select(*self.qvalue_network.in_keys, strict=False),
@@ -1376,7 +1374,7 @@ def _value_loss(
 
     def _actor_loss(
         self, tensordict: TensorDictBase
-    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+    ) -> tuple[Tensor, dict[str, Tensor]]:
         # get probs and log probs for actions
         with self.actor_network_params.to_module(self.actor_network):
             dist = self.actor_network.get_dist(tensordict.clone(False))
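The private loss methods above share a `(loss_tensor, metadata_dict)` return convention, now annotated with builtin generics. A standalone toy sketch of that convention (not the TorchRL implementation; assumes Python 3.9+, where `tuple[...]`/`dict[...]` are valid at runtime):

    import torch
    from torch import Tensor

    def toy_value_loss(pred: Tensor, target: Tensor) -> tuple[Tensor, dict[str, Tensor]]:
        # Return the scalar loss plus detached diagnostics, mirroring the
        # (loss, metadata) shape used by the SAC loss helpers above.
        loss = torch.nn.functional.mse_loss(pred, target)
        metadata = {"td_error": (pred - target).detach().abs()}
        return loss, metadata

    loss, info = toy_value_loss(torch.randn(4), torch.randn(4))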
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 124620ab040..40760ef95fb 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -5,18 +5,15 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
 
 import torch
-
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
 from tensordict.nn import dispatch, TensorDictModule
 from tensordict.utils import NestedKey
-from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
+from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
 from torchrl.envs.utils import step_mdp
 from torchrl.objectives.common import LossModule
-
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -226,10 +223,10 @@ class _AcceptedKeys:
     def __init__(
         self,
         actor_network: TensorDictModule,
-        qvalue_network: TensorDictModule | List[TensorDictModule],
+        qvalue_network: TensorDictModule | list[TensorDictModule],
         *,
         action_spec: TensorSpec = None,
-        bounds: Optional[Tuple[float]] = None,
+        bounds: tuple[float] | None = None,
         num_qvalue_nets: int = 2,
         policy_noise: float = 0.2,
         noise_clip: float = 0.5,
@@ -373,7 +370,7 @@ def _cached_stack_actor_params(self):
         [self.actor_network_params, self.target_actor_network_params], 0
     )
 
-    def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
+    def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
         tensordict_actor_grad = tensordict.select(
             *self.actor_network.in_keys, strict=False
         )
@@ -406,7 +403,7 @@ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
         )
         return loss_actor, metadata
 
-    def value_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
+    def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
         tensordict = tensordict.clone(False)
         act = tensordict.get(self.tensor_keys.action)
diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py
index 08c79bdffd4..b7292c4fdb2 100644
--- a/torchrl/objectives/td3_bc.py
+++ b/torchrl/objectives/td3_bc.py
@@ -5,18 +5,15 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
 
 import torch
-
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
 from tensordict.nn import dispatch, TensorDictModule
 from tensordict.utils import NestedKey
-from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
+from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
 from torchrl.envs.utils import step_mdp
 from torchrl.objectives.common import LossModule
-
 from torchrl.objectives.utils import (
     _cache_values,
     _reduce,
@@ -241,10 +238,10 @@ class _AcceptedKeys:
     def __init__(
         self,
         actor_network: TensorDictModule,
-        qvalue_network: TensorDictModule | List[TensorDictModule],
+        qvalue_network: TensorDictModule | list[TensorDictModule],
         *,
         action_spec: TensorSpec = None,
-        bounds: Optional[Tuple[float]] = None,
+        bounds: tuple[float] | None = None,
         num_qvalue_nets: int = 2,
         policy_noise: float = 0.2,
         noise_clip: float = 0.5,
@@ -387,7 +384,7 @@ def _cached_stack_actor_params(self):
         [self.actor_network_params, self.target_actor_network_params], 0
     )
 
-    def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
+    def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
         """Compute the actor loss.
 
         The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates.
@@ -441,7 +438,7 @@ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
         )
         return loss_actor, metadata
 
-    def qvalue_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
+    def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
         """Compute the q-value loss.
 
         The q-value loss should be computed before the :meth:`~.actor_loss`.
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 7ec4736862d..4a7d8466ddb 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -8,7 +8,7 @@
 import re
 import warnings
 from enum import Enum
-from typing import Iterable, List, Optional, Union
+from typing import Iterable
 
 import torch
 from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
@@ -159,7 +159,7 @@ class TargetNetUpdater:
 
     def __init__(
         self,
-        loss_module: "LossModule",  # noqa: F821
+        loss_module: LossModule,  # noqa: F821
    ):
         from torchrl.objectives.common import LossModule
@@ -284,7 +284,7 @@ def step(self) -> None:
                 f"initialized (`{self.__class__.__name__}.init_()`) before calling step()"
             )
         for key, param in self._sources.items():
-            target = self._targets.get("target_{}".format(key))
+            target = self._targets.get(f"target_{key}")
             if target.requires_grad:
                 raise RuntimeError("the target parameter is part of a graph.")
             self._step(param, target)
@@ -320,16 +320,16 @@ class SoftUpdate(TargetNetUpdater):
 
     def __init__(
         self,
-        loss_module: Union[
-            "DQNLoss",  # noqa: F821
-            "DDPGLoss",  # noqa: F821
-            "SACLoss",  # noqa: F821
-            "REDQLoss",  # noqa: F821
-            "TD3Loss",  # noqa: F821
-        ],
+        loss_module: (
+            DQNLoss  # noqa: F821
+            | DDPGLoss  # noqa: F821
+            | SACLoss  # noqa: F821
+            | REDQLoss  # noqa: F821
+            | TD3Loss  # noqa: F821
+        ),
         *,
         eps: float = None,
-        tau: Optional[float] = None,
+        tau: float | None = None,
     ):
         if eps is None and tau is None:
             raise RuntimeError(
@@ -350,7 +350,7 @@ def __init__(
             raise ValueError(
                 f"Got eps = {eps} when it was supposed to be between 0 and 1."
             )
-        super(SoftUpdate, self).__init__(loss_module)
+        super().__init__(loss_module)
         self.eps = eps
 
     def _step(
@@ -375,11 +375,11 @@ class HardUpdate(TargetNetUpdater):
 
     def __init__(
         self,
-        loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss", "TD3Loss"],  # noqa: F821
+        loss_module: DQNLoss | DDPGLoss | SACLoss | TD3Loss,  # noqa: F821
         *,
         value_network_update_interval: float = 1000,
     ):
-        super(HardUpdate, self).__init__(loss_module)
+        super().__init__(loss_module)
         self.value_network_update_interval = value_network_update_interval
         self.counter = 0
@@ -441,10 +441,10 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
 
 @torch.no_grad()
 def next_state_value(
     tensordict: TensorDictBase,
-    operator: Optional[TensorDictModule] = None,
+    operator: TensorDictModule | None = None,
     next_val_key: str = "state_action_value",
     gamma: float = 0.99,
-    pred_next_val: Optional[Tensor] = None,
+    pred_next_val: Tensor | None = None,
     **kwargs,
 ) -> torch.Tensor:
     """Computes the next state value (without gradient) to compute a target value.
@@ -550,7 +550,7 @@ def decorated_module(*module_args_params):
         ) from err
 
 
-def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
+def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor:
     """Reduces a tensor given the reduction method."""
     if reduction == "none":
         result = tensor
@@ -632,8 +632,8 @@ def _maybe_get_or_select(td, key_or_keys, target_shape=None):
 
 
 def _maybe_add_or_extend_key(
-    tensor_keys: List[NestedKey],
-    key_or_list_of_keys: NestedKey | List[NestedKey],
+    tensor_keys: list[NestedKey],
+    key_or_list_of_keys: NestedKey | list[NestedKey],
     prefix: NestedKey = None,
 ):
     if prefix is not None:
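The `SoftUpdate` constructor above accepts either `eps` or `tau` (in TorchRL, `eps = 1 - tau`). A minimal usage sketch; `value_network` stands in for a Q-value actor built elsewhere:

    from torchrl.objectives import DQNLoss, SoftUpdate

    loss_module = DQNLoss(value_network, delay_value=True)  # value_network assumed

    # Equivalent ways to pass the Polyak coefficient:
    updater = SoftUpdate(loss_module, eps=0.99)
    # updater = SoftUpdate(loss_module, tau=0.01)

    # After each optimization step on loss_module; note that some versions
    # require updater.init_() first, per the RuntimeError in the hunk above.
    updater.step()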
diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py
index 51496986153..4c8a29d6da3 100644
--- a/torchrl/objectives/value/__init__.py
+++ b/torchrl/objectives/value/__init__.py
@@ -14,3 +14,15 @@
     ValueEstimatorBase,
     VTrace,
 )
+
+__all__ = [
+    "GAE",
+    "TD0Estimate",
+    "TD0Estimator",
+    "TD1Estimate",
+    "TD1Estimator",
+    "TDLambdaEstimate",
+    "TDLambdaEstimator",
+    "ValueEstimatorBase",
+    "VTrace",
+]
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index a52b6e40d97..8194b72ee3e 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -10,7 +10,7 @@
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from functools import wraps
-from typing import Callable, List, Union
+from typing import Callable
 
 import torch
 from tensordict import is_tensor_collection, TensorDictBase
@@ -29,13 +29,7 @@
 
 from torchrl._utils import RL_WARNINGS
 from torchrl.envs.utils import step_mdp
-
-from torchrl.objectives.utils import (
-    _maybe_get_or_select,
-    _vmap_func,
-    hold_out_net,
-    RANDOM_MODULE_LIST,
-)
+from torchrl.objectives.utils import _maybe_get_or_select, _vmap_func, hold_out_net
 from torchrl.objectives.value.functional import (
     generalized_advantage_estimate,
     td0_return_estimate,
@@ -153,7 +147,7 @@ def __post_init__(self):
     default_keys = _AcceptedKeys
     tensor_keys: _AcceptedKeys
-    value_network: Union[TensorDictModule, Callable]
+    value_network: TensorDictModule | Callable
     _vmap_randomness = None
 
     @property
@@ -290,7 +284,6 @@ def in_keys(self):
         except AttributeError:
             # value network does not have an `in_keys` attribute
             in_keys = []
-            pass
         return in_keys
 
     @property
@@ -390,8 +383,12 @@ def vmap_randomness(self):
             do_break = False
             for val in self.__dict__.values():
                 if isinstance(val, torch.nn.Module):
+                    import torchrl.objectives.utils
+
                     for module in val.modules():
-                        if isinstance(module, RANDOM_MODULE_LIST):
+                        if isinstance(
+                            module, torchrl.objectives.utils.RANDOM_MODULE_LIST
+                        ):
                             self._vmap_randomness = "different"
                             do_break = True
                             break
@@ -1038,8 +1035,8 @@ def forward(
         self,
         tensordict: TensorDictBase,
         *,
-        params: List[Tensor] | None = None,
-        target_params: List[Tensor] | None = None,
+        params: list[Tensor] | None = None,
+        target_params: list[Tensor] | None = None,
     ) -> TensorDictBase:
         r"""Computes the TD(:math:`\lambda`) advantage given the data in tensordict.
@@ -1307,8 +1304,8 @@ def forward(
         self,
         tensordict: TensorDictBase,
         *,
-        params: List[Tensor] | None = None,
-        target_params: List[Tensor] | None = None,
+        params: list[Tensor] | None = None,
+        target_params: list[Tensor] | None = None,
         time_dim: int | None = None,
     ) -> TensorDictBase:
         """Computes the GAE given the data in tensordict.
@@ -1646,8 +1643,8 @@ def forward(
         self,
         tensordict: TensorDictBase,
         *,
-        params: List[Tensor] | None = None,
-        target_params: List[Tensor] | None = None,
+        params: list[Tensor] | None = None,
+        target_params: list[Tensor] | None = None,
         time_dim: int | None = None,
     ) -> TensorDictBase:
         """Computes the V-Trace correction given the data in tensordict.
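The `vmap_randomness` hunk moves the `RANDOM_MODULE_LIST` lookup from a module-level import to an import inside the method, which typically breaks an import cycle between `advantages.py` and `objectives/utils.py`. A generic sketch of that deferred-import pattern (module names here are illustrative):

    # pkg/a.py -- hypothetical module that must not import pkg.b at import time
    def uses_b(x):
        # Deferred import: pkg.b is only resolved when the function runs,
        # by which point both modules are fully initialized.
        import pkg.b

        return pkg.b.helper(x)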
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py
index 15e5d56d6bf..8484a025835 100644
--- a/torchrl/objectives/value/functional.py
+++ b/torchrl/objectives/value/functional.py
@@ -5,10 +5,8 @@
 from __future__ import annotations
 
 import math
-
 import warnings
 from functools import wraps
-from typing import Optional, Tuple, Union
 
 import torch
 
@@ -129,7 +127,7 @@ def generalized_advantage_estimate(
     terminated: torch.Tensor | None = None,
     *,
     time_dim: int = -2,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Generalized advantage estimate of a trajectory.
 
     Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
@@ -271,8 +269,8 @@ def _fast_vec_gae(
 
 @_transpose_time
 def vec_generalized_advantage_estimate(
-    gamma: Union[float, torch.Tensor],
-    lmbda: Union[float, torch.Tensor],
+    gamma: float | torch.Tensor,
+    lmbda: float | torch.Tensor,
     state_value: torch.Tensor,
     next_state_value: torch.Tensor,
     reward: torch.Tensor,
@@ -280,7 +278,7 @@ def vec_generalized_advantage_estimate(
     terminated: torch.Tensor | None = None,
     *,
     time_dim: int = -2,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Vectorized Generalized advantage estimate of a trajectory.
 
     Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
@@ -382,7 +380,7 @@ def td0_advantage_estimate(
     reward: torch.Tensor,
     done: torch.Tensor,
     terminated: torch.Tensor | None = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """TD(0) advantage estimate of a trajectory.
 
     Also known as bootstrapped Temporal Difference or one-step return.
@@ -422,7 +420,7 @@ def td0_return_estimate(
     terminated: torch.Tensor | None = None,
     *,
     done: torch.Tensor | None = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     # noqa: D417
     """TD(0) discounted return estimate of a trajectory.
 
@@ -645,7 +643,7 @@ def vec_td1_return_estimate(
     reward,
     done: torch.Tensor,
     terminated: torch.Tensor | None = None,
-    rolling_gamma: Optional[bool] = None,
+    rolling_gamma: bool | None = None,
     time_dim: int = -2,
 ):
     """Vectorized TD(1) return estimate.
@@ -970,7 +968,7 @@ def td_lambda_advantage_estimate(
 
 
 def _fast_td_lambda_return_estimate(
-    gamma: Union[torch.Tensor, float],
+    gamma: torch.Tensor | float,
     lmbda: float,
     next_state_value: torch.Tensor,
     reward: torch.Tensor,
@@ -1035,7 +1033,7 @@ def vec_td_lambda_return_estimate(
     reward,
     done,
     terminated: torch.Tensor | None = None,
-    rolling_gamma: Optional[bool] = None,
+    rolling_gamma: bool | None = None,
     *,
     time_dim: int = -2,
 ):
@@ -1277,11 +1275,11 @@ def vtrace_advantage_estimate(
     reward: torch.Tensor,
     done: torch.Tensor,
     terminated: torch.Tensor | None = None,
-    rho_thresh: Union[float, torch.Tensor] = 1.0,
-    c_thresh: Union[float, torch.Tensor] = 1.0,
+    rho_thresh: float | torch.Tensor = 1.0,
+    c_thresh: float | torch.Tensor = 1.0,
     # not a kwarg because used directly
     time_dim: int = -2,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Computes V-Trace off-policy actor critic targets.
 
     Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
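As a quick orientation for the signatures above, here is a hedged sketch of calling `generalized_advantage_estimate` on dummy data (shapes follow the `[*batch, time, 1]` convention implied by `time_dim=-2`; the values are made up):

    import torch
    from torchrl.objectives.value.functional import generalized_advantage_estimate

    T = 5  # time steps
    state_value = torch.randn(T, 1)
    next_state_value = torch.randn(T, 1)
    reward = torch.randn(T, 1)
    done = torch.zeros(T, 1, dtype=torch.bool)
    terminated = done.clone()

    # Returns (advantage, value_target), matching the
    # "-> tuple[torch.Tensor, torch.Tensor]" annotations in this file.
    adv, value_target = generalized_advantage_estimate(
        gamma=0.99,
        lmbda=0.95,
        state_value=state_value,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=terminated,
    )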
Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" diff --git a/torchrl/objectives/value/pg.py b/torchrl/objectives/value/pg.py deleted file mode 100644 index d62fe90a685..00000000000 --- a/torchrl/objectives/value/pg.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# implements a function that takes a sequence of returns and multiply its by the policy log_prob to get a differentiable objective diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index c28f19e3062..4d3c4c3b552 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -2,8 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from typing import Union +from __future__ import annotations import torch @@ -212,7 +211,7 @@ def _get_num_per_traj(done): def _split_and_pad_sequence( - tensor: Union[torch.Tensor, TensorDictBase], + tensor: torch.Tensor | TensorDictBase, splits: torch.Tensor, return_mask=False, time_dim=-1, @@ -318,7 +317,7 @@ def _fill_tensor(tensor): def _inv_pad_sequence( - tensor: Union[torch.Tensor, TensorDictBase], + tensor: torch.Tensor | TensorDictBase, splits: torch.Tensor, mask: torch.Tensor = None, ): diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py index f6c9bcdefbb..149fb9e7a04 100644 --- a/torchrl/record/__init__.py +++ b/torchrl/record/__init__.py @@ -5,3 +5,13 @@ from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder + +__all__ = [ + "CSVLogger", + "MLFlowLogger", + "TensorboardLogger", + "WandbLogger", + "PixelRenderTransform", + "TensorDictRecorder", + "VideoRecorder", +] diff --git a/torchrl/record/loggers/__init__.py b/torchrl/record/loggers/__init__.py index 92714675046..48aa8d9175e 100644 --- a/torchrl/record/loggers/__init__.py +++ b/torchrl/record/loggers/__init__.py @@ -11,3 +11,13 @@ from .utils import generate_exp_name, get_logger from .wandb import WandbLogger + +__all__ = [ + "Logger", + "CSVLogger", + "MLFlowLogger", + "TensorboardLogger", + "generate_exp_name", + "get_logger", + "WandbLogger", +] diff --git a/torchrl/record/loggers/common.py b/torchrl/record/loggers/common.py index b8325763166..e6db65eb816 100644 --- a/torchrl/record/loggers/common.py +++ b/torchrl/record/loggers/common.py @@ -2,9 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import abc -from typing import Dict, Sequence, Union +from typing import Sequence from torch import Tensor @@ -21,7 +22,7 @@ def __init__(self, exp_name: str, log_dir: str) -> None: self.experiment = self._create_experiment() @abc.abstractmethod - def _create_experiment(self) -> "Experiment": # noqa: F821 + def _create_experiment(self) -> Experiment: # noqa: F821 ... @abc.abstractmethod @@ -33,7 +34,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non ... @abc.abstractmethod - def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None: # noqa: F821 + def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 ... 
diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py
index 745e0feabb1..ae5124fb11c 100644
--- a/torchrl/record/loggers/csv.py
+++ b/torchrl/record/loggers/csv.py
@@ -7,11 +7,10 @@
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Union
+from typing import Sequence
 
 import tensordict.utils
 import torch
-
 from tensordict import MemoryMappedTensor
 from torch import Tensor
 
@@ -35,7 +34,7 @@ def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30):
 
         self.files = {}
 
-    def add_scalar(self, name: str, value: float, global_step: Optional[int] = None):
+    def add_scalar(self, name: str, value: float, global_step: int | None = None):
         if global_step is None:
             global_step = len(self.scalars[name])
         value = float(value)
@@ -50,7 +49,7 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None)
             fd.write(",".join([str(global_step), str(value)]) + "\n")
             fd.flush()
 
-    def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs):
+    def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
         """Writes a video on a file on disk.
 
         The video format can be one of
@@ -106,7 +105,7 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs
                 f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
             )
 
-    def add_text(self, tag, text, global_step: Optional[int] = None):
+    def add_text(self, tag, text, global_step: int | None = None):
         if global_step is None:
             global_step = self.videos_counter[tag]
             self.videos_counter[tag] += 1
@@ -161,7 +160,7 @@ def __init__(
         super().__init__(exp_name=exp_name, log_dir=log_dir)
         self._has_imported_moviepy = False
 
-    def _create_experiment(self) -> "CSVExperiment":
+    def _create_experiment(self) -> CSVExperiment:
         """Creates a CSV experiment."""
         log_dir = str(os.path.join(self.log_dir, self.exp_name))
         return CSVExperiment(
@@ -205,7 +204,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> None:
             **kwargs,
         )
 
-    def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None:  # noqa: F821
+    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
         """Logs the hyperparameters of the experiment.
 
         Args:
diff --git a/torchrl/record/loggers/mlflow.py b/torchrl/record/loggers/mlflow.py
index 548d8213279..e2df9f30f42 100644
--- a/torchrl/record/loggers/mlflow.py
+++ b/torchrl/record/loggers/mlflow.py
@@ -2,11 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib.util
 import os
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Sequence
 
 from torch import Tensor
 
@@ -34,7 +36,7 @@ def __init__(
         self,
         exp_name: str,
         tracking_uri: str,
-        tags: Optional[Dict[str, Any]] = None,
+        tags: dict[str, Any] | None = None,
         *,
         video_fps: int = 30,
         **kwargs,
@@ -51,7 +53,7 @@ def __init__(
         self.video_log_counter = 0
         self.video_fps = video_fps
 
-    def _create_experiment(self) -> "mlflow.ActiveRun":  # noqa
+    def _create_experiment(self) -> mlflow.ActiveRun:  # noqa
         import mlflow
 
         """Creates an mlflow experiment.
@@ -70,7 +72,7 @@ def _create_experiment(self) -> mlflow.ActiveRun:  # noqa
         self.id = experiment.experiment_id
         return mlflow.start_run(experiment_id=self.id)
 
-    def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
+    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
         """Logs a scalar value to mlflow.
 
         Args:
@@ -118,7 +120,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
             torchvision.io.write_video(filename=f.name, video_array=video, fps=fps)
             mlflow.log_artifact(f.name, "videos")
 
-    def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None:  # noqa: F821
+    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
         """Logs the hyperparameters of the experiment.
 
         Args:
diff --git a/torchrl/record/loggers/tensorboard.py b/torchrl/record/loggers/tensorboard.py
index 5ecc9742614..39518807046 100644
--- a/torchrl/record/loggers/tensorboard.py
+++ b/torchrl/record/loggers/tensorboard.py
@@ -2,10 +2,12 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib.util
 import os
-from typing import Dict, Sequence, Union
+from typing import Sequence
 
 from torch import Tensor
 
@@ -31,7 +33,7 @@ def __init__(self, exp_name: str, log_dir: str = "tb_logs") -> None:
 
         self._has_imported_moviepy = False
 
-    def _create_experiment(self) -> "SummaryWriter":  # noqa
+    def _create_experiment(self) -> SummaryWriter:  # noqa
         """Creates a tensorboard experiment.
 
         Args:
@@ -91,7 +93,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> None:
             **kwargs,
         )
 
-    def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None:  # noqa: F821
+    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
         """Logs the hyperparameters of the experiment.
 
         Args:
diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py
index 226135f333f..5fe443db301 100644
--- a/torchrl/record/loggers/utils.py
+++ b/torchrl/record/loggers/utils.py
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+from __future__ import annotations
 
 import os
 import pathlib
diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py
index c015c2b0214..3d23a485458 100644
--- a/torchrl/record/loggers/wandb.py
+++ b/torchrl/record/loggers/wandb.py
@@ -2,11 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib.util
 import os
 import warnings
-from typing import Dict, Optional, Sequence, Union
+from typing import Sequence
 
 from torch import Tensor
 
@@ -93,7 +95,7 @@ def __init__(
 
         self.video_log_counter = 0
 
-    def _create_experiment(self) -> "WandbLogger":
+    def _create_experiment(self) -> WandbLogger:
         """Creates a wandb experiment.
 
         Args:
@@ -111,7 +113,7 @@ def _create_experiment(self) -> WandbLogger:
 
         return wandb.init(**self._wandb_kwargs)
 
-    def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
+    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
         """Logs a scalar value to wandb.
 
         Args:
@@ -173,7 +175,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
             **kwargs,
         )
 
-    def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None:  # noqa: F821
+    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
         """Logs the hyperparameters of the experiment.
 
         Args:
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index 1181d6e2d0d..c2cf93dd119 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -7,15 +7,12 @@
 import importlib.util
 import math
 from copy import copy
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Callable, Sequence
 
 import numpy as np
 import torch
-
 from tensordict import NonTensorData, TensorDictBase
-
 from tensordict.utils import NestedKey
-
 from torchrl._utils import _can_be_pickled
 from torchrl.data import TensorSpec
 from torchrl.data.tensor_specs import NonTensor, Unbounded
@@ -108,11 +105,11 @@ def __init__(
         self,
         logger: Logger,
         tag: str,
-        in_keys: Optional[Sequence[NestedKey]] = None,
+        in_keys: Sequence[NestedKey] | None = None,
         skip: int | None = None,
-        center_crop: Optional[int] = None,
+        center_crop: int | None = None,
         make_grid: bool | None = None,
-        out_keys: Optional[Sequence[NestedKey]] = None,
+        out_keys: Sequence[NestedKey] | None = None,
         fps: int | None = None,
         **kwargs,
     ) -> None:
@@ -239,7 +236,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         return self._call(tensordict)
 
-    def dump(self, suffix: Optional[str] = None) -> None:
+    def dump(self, suffix: str | None = None) -> None:
         """Writes the video to the ``self.logger`` attribute.
 
         Calling ``dump`` when no image has been stored in a no-op.
@@ -296,7 +293,7 @@ def __init__(
         out_file_base: str,
         skip_reset: bool = True,
         skip: int = 4,
-        in_keys: Optional[Sequence[str]] = None,
+        in_keys: Sequence[str] | None = None,
     ) -> None:
         if in_keys is None:
             in_keys = []
@@ -318,7 +315,7 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
             self.td.append(_td)
         return next_tensordict
 
-    def dump(self, suffix: Optional[str] = None) -> None:
+    def dump(self, suffix: str | None = None) -> None:
         if suffix is None:
             tag = self.tag
         else:
@@ -430,7 +427,7 @@ class PixelRenderTransform(Transform):
 
     def __init__(
         self,
-        out_keys: List[NestedKey] = None,
+        out_keys: list[NestedKey] = None,
         preproc: Callable[
             [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
         ] = None,
@@ -544,7 +541,7 @@ def enabled(self) -> bool:
         """Whether the recorder is enabled."""
         return self._enabled
 
-    def set_container(self, container: Union[Transform, EnvBase]) -> None:
+    def set_container(self, container: Transform | EnvBase) -> None:
         out = super().set_container(container)
         if isinstance(self.parent, EnvBase):
             # Start the env if needed
diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py
index 9d593d64f17..0136da42761 100644
--- a/torchrl/trainers/__init__.py
+++ b/torchrl/trainers/__init__.py
@@ -20,3 +20,21 @@
     TrainerHookBase,
     UpdateWeights,
 )
+
+__all__ = [
+    "BatchSubSampler",
+    "ClearCudaCache",
+    "CountFramesLog",
+    "LogReward",
+    "LogScalar",
+    "LogValidationReward",
+    "mask_batch",
+    "OptimizerHook",
+    "Recorder",
+    "ReplayBufferTrainer",
+    "RewardNormalizer",
+    "SelectKeys",
+    "Trainer",
+    "TrainerHookBase",
+    "UpdateWeights",
+]
diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py
index b09becdc15a..90c4f91aaa2 100644
--- a/torchrl/trainers/helpers/__init__.py
+++ b/torchrl/trainers/helpers/__init__.py
@@ -20,3 +20,21 @@
 from .models import make_dqn_actor, make_dreamer
 from .replay_buffer import make_replay_buffer
 from .trainers import make_trainer
+
+__all__ = [
+    "make_collector_offpolicy",
+    "make_collector_onpolicy",
+    "sync_async_collector",
+    "sync_sync_collector",
+    "correct_for_frame_skip",
+    "get_stats_random_rollout",
+    "parallel_env_constructor",
+    "transformed_env_constructor",
+    "LoggerConfig",
+    "make_dqn_loss",
+    "make_target_updater",
+    "make_dqn_actor",
+    "make_dreamer",
+    "make_replay_buffer",
+    "make_trainer",
+]
diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py
index d60773a04c2..4f13597a8e2 100644
--- a/torchrl/trainers/helpers/collectors.py
+++ b/torchrl/trainers/helpers/collectors.py
@@ -2,9 +2,10 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable
 
 from tensordict import TensorDictBase
 
@@ -22,10 +23,10 @@
 
 def sync_async_collector(
-    env_fns: Union[Callable, List[Callable]],
-    env_kwargs: Optional[Union[dict, List[dict]]],
-    num_env_per_collector: Optional[int] = None,
-    num_collectors: Optional[int] = None,
+    env_fns: Callable | list[Callable],
+    env_kwargs: dict | list[dict] | None,
+    num_env_per_collector: int | None = None,
+    num_collectors: int | None = None,
     **kwargs,
 ) -> MultiaSyncDataCollector:
     """Runs asynchronous collectors, each running synchronous environments.
@@ -82,12 +83,12 @@ def sync_async_collector(
 
 def sync_sync_collector(
-    env_fns: Union[Callable, List[Callable]],
-    env_kwargs: Optional[Union[dict, List[dict]]],
-    num_env_per_collector: Optional[int] = None,
-    num_collectors: Optional[int] = None,
+    env_fns: Callable | list[Callable],
+    env_kwargs: dict | list[dict] | None,
+    num_env_per_collector: int | None = None,
+    num_collectors: int | None = None,
     **kwargs,
-) -> Union[SyncDataCollector, MultiSyncDataCollector]:
+) -> SyncDataCollector | MultiSyncDataCollector:
     """Runs synchronous collectors, each running synchronous environments.
 
     E.g.
@@ -164,16 +165,16 @@ def sync_sync_collector(
 
 def _make_collector(
-    collector_class: Type,
-    env_fns: Union[Callable, List[Callable]],
-    env_kwargs: Optional[Union[dict, List[dict]]],
+    collector_class: type,
+    env_fns: Callable | list[Callable],
+    env_kwargs: dict | list[dict] | None,
     policy: Callable[[TensorDictBase], TensorDictBase],
     max_frames_per_traj: int = -1,
     frames_per_batch: int = 200,
-    total_frames: Optional[int] = None,
-    postproc: Optional[Callable] = None,
-    num_env_per_collector: Optional[int] = None,
-    num_collectors: Optional[int] = None,
+    total_frames: int | None = None,
+    postproc: Callable | None = None,
+    num_env_per_collector: int | None = None,
+    num_collectors: int | None = None,
     **kwargs,
 ) -> DataCollectorBase:
     if env_kwargs is None:
@@ -249,11 +250,9 @@ def _make_collector(
 
 def make_collector_offpolicy(
     make_env: Callable[[], EnvBase],
-    actor_model_explore: Union[
-        TensorDictModuleWrapper, ProbabilisticTensorDictSequential
-    ],
-    cfg: "DictConfig",  # noqa: F821
-    make_env_kwargs: Optional[Dict] = None,
+    actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential),
+    cfg: DictConfig,  # noqa: F821
+    make_env_kwargs: dict | None = None,
 ) -> DataCollectorBase:
     """Returns a data collector for off-policy sota-implementations.
@@ -313,11 +312,9 @@ def make_collector_offpolicy(
 
 def make_collector_onpolicy(
     make_env: Callable[[], EnvBase],
-    actor_model_explore: Union[
-        TensorDictModuleWrapper, ProbabilisticTensorDictSequential
-    ],
-    cfg: "DictConfig",  # noqa: F821
-    make_env_kwargs: Optional[Dict] = None,
+    actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential),
+    cfg: DictConfig,  # noqa: F821
+    make_env_kwargs: dict | None = None,
 ) -> DataCollectorBase:
     """Makes a collector in on-policy settings.
diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py
index e236b61c8e5..32965742d66 100644
--- a/torchrl/trainers/helpers/envs.py
+++ b/torchrl/trainers/helpers/envs.py
@@ -2,12 +2,17 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
+# This makes omegaconf unhappy with typing.Any
+# Therefore we need Optional and Union
+# from __future__ import annotations
+
+import importlib.util
 from copy import copy
 from dataclasses import dataclass, field as dataclass_field
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 import torch
-
 from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs import ParallelEnv
 from torchrl.envs.common import EnvBase
@@ -43,8 +48,16 @@
     "dm_control": DMControlEnv,
 }
 
+_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
+if _has_omegaconf:
+    from omegaconf import DictConfig
+else:
+
+    class DictConfig:  # noqa
+        ...
+
 
-def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig":  # noqa: F821
+def correct_for_frame_skip(cfg: DictConfig) -> DictConfig:  # noqa: F821
     """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.
 
     This is aimed at avoiding unknowingly over-sampling from the environment, i.e.
     targeting a total number of frames
@@ -208,9 +221,9 @@ def get_norm_state_dict(env):
 
 def transformed_env_constructor(
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     video_tag: str = "",
-    logger: Optional[Logger] = None,
+    logger: Optional[Logger] = None,  # noqa
     stats: Optional[dict] = None,
     norm_obs_only: bool = False,
     use_env_creator: bool = False,
@@ -326,7 +339,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
 
 def parallel_env_constructor(
-    cfg: "DictConfig", **kwargs  # noqa: F821
+    cfg: DictConfig, **kwargs  # noqa: F821
 ) -> Union[ParallelEnv, EnvCreator]:
     """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
@@ -370,7 +383,7 @@ def parallel_env_constructor(
 
 @torch.no_grad()
 def get_stats_random_rollout(
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     proof_environment: EnvBase = None,
     key: Optional[str] = None,
 ):
@@ -450,7 +463,7 @@ def get_stats_random_rollout(
 def initialize_observation_norm_transforms(
     proof_environment: EnvBase,
     num_iter: int = 1000,
-    key: Union[str, Tuple[str, ...]] = None,
+    key: Union[str, tuple[str, ...]] = None,
 ):
     """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
@@ -530,7 +543,7 @@ class EnvConfig:
     # maximum steps per trajectory, frames per batch or any other factor in the algorithm,
     # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4
     # the actual number of frames retrieved will be 200e6. Default=1.
-    reward_scaling: Optional[float] = None
+    reward_scaling: Any = None  # noqa
     # scale of the reward.
     reward_loc: float = 0.0
     # location of the reward.
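The `envs.py` hunk above also shows a common optional-dependency pattern: probe for `omegaconf` with `importlib.util.find_spec` and fall back to a stub class, so `DictConfig` still resolves in annotations when the package is absent. A generic sketch of the same pattern (the package and stub here are illustrative):

    import importlib.util

    # Probe without importing: find_spec returns None when the package is absent.
    if importlib.util.find_spec("rich") is not None:
        from rich.console import Console
    else:

        class Console:  # minimal stand-in so annotations and defaults still resolve
            def print(self, *args, **kwargs):
                print(*args)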
diff --git a/torchrl/trainers/helpers/logger.py b/torchrl/trainers/helpers/logger.py
index b0b37533519..b06c3593557 100644
--- a/torchrl/trainers/helpers/logger.py
+++ b/torchrl/trainers/helpers/logger.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
 from dataclasses import dataclass, field
 from typing import Any
diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py
index 152d7e2891f..91c8f5f8675 100644
--- a/torchrl/trainers/helpers/losses.py
+++ b/torchrl/trainers/helpers/losses.py
@@ -2,9 +2,10 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Optional, Tuple
+from typing import Any
 
 from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate
 from torchrl.objectives.common import LossModule
@@ -12,8 +13,8 @@
 
 def make_target_updater(
-    cfg: "DictConfig", loss_module: LossModule  # noqa: F821
-) -> Optional[TargetNetUpdater]:
+    cfg: DictConfig, loss_module: LossModule  # noqa: F821
+) -> TargetNetUpdater | None:
     """Builds a target network weight update object."""
     if cfg.loss == "double":
         if not cfg.hard_update:
@@ -35,7 +36,7 @@ def make_target_updater(
     return target_net_updater
 
 
-def make_dqn_loss(model, cfg) -> Tuple[DQNLoss, Optional[TargetNetUpdater]]:
+def make_dqn_loss(model, cfg) -> tuple[DQNLoss, TargetNetUpdater | None]:
     """Builds the DQN loss module."""
     loss_kwargs = {}
     if cfg.distributional:
diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py
index 6e74386ed63..543bf940031 100644
--- a/torchrl/trainers/helpers/models.py
+++ b/torchrl/trainers/helpers/models.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import itertools
 from dataclasses import dataclass
 
@@ -60,7 +62,7 @@
 
 def make_dqn_actor(
-    proof_environment: EnvBase, cfg: "DictConfig", device: torch.device  # noqa: F821
+    proof_environment: EnvBase, cfg: DictConfig, device: torch.device  # noqa: F821
 ) -> Actor:
     """DQN constructor helper function.
@@ -194,7 +196,7 @@ def make_dqn_actor(
 
 @set_lazy_legacy(False)
 def make_dreamer(
-    cfg: "DictConfig",  # noqa: F821
+    cfg: DictConfig,  # noqa: F821
     proof_environment: EnvBase = None,
     device: DEVICE_TYPING = "cpu",
     action_key: str = "action",
diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py
index 6ccbb15a291..d0da6a02964 100644
--- a/torchrl/trainers/helpers/replay_buffer.py
+++ b/torchrl/trainers/helpers/replay_buffer.py
@@ -2,8 +2,9 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 from torchrl._utils import _make_ordinal_device
@@ -18,7 +19,7 @@
 
 def make_replay_buffer(
-    device: DEVICE_TYPING, cfg: "DictConfig"  # noqa: F821
+    device: DEVICE_TYPING, cfg: DictConfig  # noqa: F821
 ) -> ReplayBuffer:  # noqa: F821
     """Builds a replay buffer using the config built from ReplayArgsConfig."""
     device = _make_ordinal_device(torch.device(device))
@@ -52,7 +53,7 @@ class ReplayArgsConfig:
     # buffer size, in number of frames stored. Default=1e6
     prb: bool = False
     # whether a Prioritized replay buffer should be used instead of a more basic circular one.
-    buffer_scratch_dir: Optional[str] = None
+    buffer_scratch_dir: str | None = None
     # directory where the buffer data should be stored. If none is passed, they will be placed in /tmp/
     buffer_prefetch: int = 10
     # prefetching queue length for the replay buffer
diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py
index 4819d9e07e8..4a1e35e0e4a 100644
--- a/torchrl/trainers/helpers/trainers.py
+++ b/torchrl/trainers/helpers/trainers.py
@@ -2,9 +2,9 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import List, Optional, Union
 from warnings import warn
 
 import torch
@@ -51,7 +51,7 @@ class TrainerConfig:
     # Optimizer to be used.
     lr_scheduler: str = "cosine"
     # LR scheduler.
-    selected_keys: Optional[List] = None
+    selected_keys: list | None = None
     # a list of strings that indicate the data that should be kept from the data collector. Since storing and
     # retrieving information from the replay buffer does not come for free, limiting the amount of data
     # passed to it can improve the algorithm performance.
@@ -80,14 +80,12 @@ class TrainerConfig:
 def make_trainer(
     collector: DataCollectorBase,
     loss_module: LossModule,
-    recorder: Optional[EnvBase] = None,
-    target_net_updater: Optional[TargetNetUpdater] = None,
-    policy_exploration: Optional[
-        Union[TensorDictModuleWrapper, TensorDictModule]
-    ] = None,
-    replay_buffer: Optional[ReplayBuffer] = None,
-    logger: Optional[Logger] = None,
-    cfg: "DictConfig" = None,  # noqa: F821
+    recorder: EnvBase | None = None,
+    target_net_updater: TargetNetUpdater | None = None,
+    policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None,
+    replay_buffer: ReplayBuffer | None = None,
+    logger: Logger | None = None,
+    cfg: DictConfig = None,  # noqa: F821
 ) -> Trainer:
     """Creates a Trainer instance given its constituents.
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 65be247cd33..d70a7358eb0 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -11,7 +11,7 @@
 from collections import defaultdict, OrderedDict
 from copy import deepcopy
 from textwrap import indent
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Sequence, Tuple
 
 import numpy as np
 import torch.nn
@@ -70,11 +70,11 @@ class TrainerHookBase:
     """An abstract hooking class for torchrl Trainer class."""
 
     @abc.abstractmethod
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         raise NotImplementedError
 
     @abc.abstractmethod
@@ -143,7 +143,7 @@ def __new__(cls, *args, **kwargs):
         # trackers
         cls._optim_count: int = 0
         cls._collected_frames: int = 0
-        cls._last_log: Dict[str, Any] = {}
+        cls._last_log: dict[str, Any] = {}
         cls._last_save: int = 0
         cls.collected_frames = 0
         cls._app_state = None
@@ -156,16 +156,16 @@ def __init__(
         total_frames: int,
         frame_skip: int,
         optim_steps_per_batch: int,
-        loss_module: Union[LossModule, Callable[[TensorDictBase], TensorDictBase]],
-        optimizer: Optional[optim.Optimizer] = None,
-        logger: Optional[Logger] = None,
+        loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase],
+        optimizer: optim.Optimizer | None = None,
+        logger: Logger | None = None,
         clip_grad_norm: bool = True,
         clip_norm: float = None,
         progress_bar: bool = True,
         seed: int = None,
         save_trainer_interval: int = 10000,
         log_interval: int = 10000,
-        save_trainer_file: Optional[Union[str, pathlib.Path]] = None,
+        save_trainer_file: str | pathlib.Path | None = None,
     ) -> None:
 
         # objects
@@ -248,7 +248,7 @@ def app_state(self):
         }
         return self._app_state
 
-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         state = self._get_state()
         state_dict = OrderedDict(
             collector=self.collector.state_dict(),
@@ -258,7 +258,7 @@ def state_dict(self) -> Dict:
         )
         return state_dict
 
-    def load_state_dict(self, state_dict: Dict) -> None:
+    def load_state_dict(self, state_dict: dict) -> None:
         model_state_dict = state_dict["loss_module"]
         collector_state_dict = state_dict["collector"]
 
@@ -296,7 +296,7 @@ def save_trainer(self, force_save: bool = False) -> None:
         if _save and self.save_trainer_file:
             self._save_trainer()
 
-    def load_from_file(self, file: Union[str, pathlib.Path], **kwargs) -> Trainer:
+    def load_from_file(self, file: str | pathlib.Path, **kwargs) -> Trainer:
         """Loads a file and its state-dict in the trainer.
 
         Keyword arguments are passed to the :func:`~torch.load` function.
@@ -617,10 +617,10 @@ def __init__(self, keys: Sequence[str]):
     def __call__(self, batch: TensorDictBase) -> TensorDictBase:
         return batch.select(*self.keys)
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {}
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         pass
 
     def register(self, trainer, name="select_keys") -> None:
@@ -664,11 +664,11 @@ class ReplayBufferTrainer(TrainerHookBase):
     def __init__(
         self,
         replay_buffer: TensorDictReplayBuffer,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         memmap: bool = False,
         device: DEVICE_TYPING | None = None,
         flatten_tensordicts: bool = False,
-        max_dims: Optional[Sequence[int]] = None,
+        max_dims: Sequence[int] | None = None,
     ) -> None:
         self.replay_buffer = replay_buffer
         self.batch_size = batch_size
@@ -704,7 +704,7 @@ def sample(self, batch: TensorDictBase) -> TensorDictBase:
     def update_priority(self, batch: TensorDictBase) -> None:
         self.replay_buffer.update_tensordict_priority(batch)
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {
             "replay_buffer": self.replay_buffer.state_dict(),
         }
@@ -738,7 +738,7 @@ class OptimizerHook(TrainerHookBase):
     def __init__(
         self,
         optimizer: optim.Optimizer,
-        loss_components: Optional[Sequence[str]] = None,
+        loss_components: Sequence[str] | None = None,
     ):
         if loss_components is not None and not loss_components:
             raise ValueError(
@@ -788,10 +788,10 @@ def __call__(
 
         return losses_td
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {}
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         pass
 
     def register(self, trainer, name="optimizer") -> None:
@@ -838,7 +838,7 @@ def __init__(
         self,
         logname="r_training",
         log_pbar: bool = False,
-        reward_key: Union[str, tuple] = None,
+        reward_key: str | tuple = None,
     ):
         self.logname = logname
         self.log_pbar = log_pbar
@@ -846,7 +846,7 @@ def __init__(
             reward_key = REWARD_KEY
         self.reward_key = reward_key
 
-    def __call__(self, batch: TensorDictBase) -> Dict:
+    def __call__(self, batch: TensorDictBase) -> dict:
         if ("collector", "mask") in batch.keys(True):
             return {
                 self.logname: batch.get(self.reward_key)[
@@ -873,7 +873,7 @@ def __init__(
         self,
         logname="r_training",
         log_pbar: bool = False,
-        reward_key: Union[str, tuple] = None,
+        reward_key: str | tuple = None,
     ):
         warnings.warn(
             "The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.",
@@ -971,7 +971,7 @@ def normalize_reward(self, tensordict: TensorDictBase) -> TensorDictBase:
         self._normalize_has_been_called = True
         return tensordict
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {
             "_reward_stats": deepcopy(self._reward_stats),
             "scale": self.scale,
@@ -979,7 +979,7 @@ def state_dict(self) -> Dict[str, Any]:
             "_update_has_been_called": self._update_has_been_called,
         }
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         for key, value in state_dict.items():
             setattr(self, key, value)
 
@@ -1126,10 +1126,10 @@ def __call__(self, batch: TensorDictBase) -> TensorDictBase:
             raise RuntimeError("Sampled invalid steps")
         return td
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {}
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         pass
 
     def register(self, trainer: Trainer, name: str = "batch_subsampler"):
@@ -1194,9 +1194,9 @@ def __init__(
         policy_exploration: TensorDictModule,
         environment: EnvBase = None,
         exploration_type: ExplorationType = ExplorationType.RANDOM,
-        log_keys: Optional[List[Union[str, Tuple[str]]]] = None,
-        out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None,
-        suffix: Optional[str] = None,
+        log_keys: list[str | tuple[str]] | None = None,
+        out_keys: dict[str | tuple[str], str] | None = None,
+        suffix: str | None = None,
         log_pbar: bool = False,
         recorder: EnvBase = None,
     ) -> None:
@@ -1223,7 +1223,7 @@ def __init__(
         self.log_pbar = log_pbar
 
     @torch.inference_mode()
-    def __call__(self, batch: TensorDictBase) -> Dict:
+    def __call__(self, batch: TensorDictBase) -> dict:
         out = None
         if self._count % self.record_interval == 0:
             with set_exploration_type(self.exploration_type):
@@ -1259,13 +1259,13 @@ def __call__(self, batch: TensorDictBase) -> Dict:
             self.environment.close()
         return out
 
-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         return {
             "_count": self._count,
             "recorder_state_dict": self.environment.state_dict(),
         }
 
-    def load_state_dict(self, state_dict: Dict) -> None:
+    def load_state_dict(self, state_dict: dict) -> None:
         self._count = state_dict["_count"]
         self.environment.load_state_dict(state_dict["recorder_state_dict"])
 
@@ -1289,9 +1289,9 @@ def __init__(
         policy_exploration: TensorDictModule,
         environment: EnvBase = None,
         exploration_type: ExplorationType = ExplorationType.RANDOM,
-        log_keys: Optional[List[Union[str, Tuple[str]]]] = None,
-        out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None,
-        suffix: Optional[str] = None,
+        log_keys: list[str | tuple[str]] | None = None,
+        out_keys: dict[str | tuple[str], str] | None = None,
+        suffix: str | None = None,
         log_pbar: bool = False,
         recorder: EnvBase = None,
     ) -> None:
@@ -1352,7 +1352,7 @@ def register(self, trainer: Trainer, name: str = "update_weights"):
             self,
         )
 
-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         return {}
 
     def load_state_dict(self, state_dict) -> None:
@@ -1385,7 +1385,7 @@ def __init__(self, frame_skip: int, log_pbar: bool = False):
         self.frame_skip = frame_skip
         self.log_pbar = log_pbar
 
-    def __call__(self, batch: TensorDictBase) -> Dict:
+    def __call__(self, batch: TensorDictBase) -> dict:
         if ("collector", "mask") in batch.keys(True):
             current_frames = (
                 batch.get(("collector", "mask")).sum().item() * self.frame_skip
@@ -1402,7 +1402,7 @@ def register(self, trainer: Trainer, name: str = "count_frames_log"):
             self,
         )
 
-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         return {"frame_count": self.frame_count}
 
     def load_state_dict(self, state_dict) -> None:
@@ -1410,7 +1410,7 @@ def load_state_dict(self, state_dict) -> None:
 
 def _check_input_output_typehint(
-    func: Callable, input: Type | List[Type], output: Type
+    func: Callable, input: type | list[type], output: type
 ):
     # Placeholder for a function that checks the types input / output against expectations
     return
diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py
index fc360114377..06d89ad63ad 100644
--- a/tutorials/sphinx-tutorials/coding_dqn.py
+++ b/tutorials/sphinx-tutorials/coding_dqn.py
@@ -108,7 +108,6 @@
     # If we can't set the method globally we can still run the parallel env with "fork"
     # This will fail on windows! Use "spawn" and put the script within `if __name__ == "__main__"`
     mp_context = "fork"
-    pass
 
 # sphinx_gallery_end_ignore
 
 import os
@@ -743,7 +742,7 @@ def print_csv_files_in_folder(folder_path):
                 csv_files.append(os.path.join(dirpath, file))
     for csv_file in csv_files:
         output_str += f"File: {csv_file}\n"
-        with open(csv_file, "r") as f:
+        with open(csv_file) as f:
             for i, line in enumerate(f):
                 if i == 10:
                     break