-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from Toni-SM/develop
Develop
- Loading branch information
Showing
20 changed files
with
400 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
Hugging Face integration | ||
======================== | ||
|
||
.. raw:: html | ||
|
||
<hr> | ||
|
||
Download model from Hugging Face Hub | ||
------------------------------------ | ||
|
||
Several skrl-trained models (agent checkpoints) for different environments/tasks of Gym/Gymnasium, Isaac Gym, Omniverse Isaac Gym, etc. are available in the Hugging Face Hub | ||
|
||
These models can be used as comparison benchmarks, for collecting environment transitions in memory (for offline reinforcement learning, e.g.) or for pre-initialization of agents for performing similar tasks, among others | ||
|
||
Visit the `skrl organization on the Hugging Face Hub <https://huggingface.co/skrl>`_ to access publicly available models! | ||
|
||
API | ||
""" | ||
|
||
.. autofunction:: skrl.utils.huggingface.download_model_from_huggingface |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ Utilities | |
========= | ||
|
||
.. contents:: Table of Contents | ||
:depth: 2 | ||
:depth: 1 | ||
:local: | ||
:backlinks: none | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from skrl import logger, __version__ | ||
|
||
|
||
def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str: | ||
"""Download a model from Hugging Face Hub | ||
:param repo_id: Hugging Face user or organization name and a repo name separated by a ``/`` | ||
:type repo_id: str | ||
:param filename: The name of the model file in the repo (default: ``"agent.pt"``) | ||
:type filename: str, optional | ||
:raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed | ||
:raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub | ||
:return: Local path of file or if networking is off, last version of file cached on disk | ||
:rtype: str | ||
Example:: | ||
# download trained agent from the skrl organization (https://huggingface.co/skrl) | ||
>>> from skrl.utils.huggingface import download_model_from_huggingface | ||
>>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO") | ||
'/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt' | ||
# download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1") | ||
>>> from skrl.utils.huggingface import download_model_from_huggingface | ||
>>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth") | ||
'/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth' | ||
""" | ||
try: | ||
import huggingface_hub | ||
except ImportError: | ||
logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it") | ||
huggingface_hub = None | ||
|
||
if huggingface_hub is None: | ||
raise ImportError("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it") | ||
|
||
# download and cache the model from Hugging Face Hub | ||
downloaded_model_file = huggingface_hub.hf_hub_download(repo_id=repo_id, | ||
filename=filename, | ||
library_name="skrl", | ||
library_version=__version__) | ||
|
||
return downloaded_model_file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.9.0 | ||
0.9.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pytest | ||
import warnings | ||
import hypothesis | ||
import hypothesis.strategies as st | ||
|
||
import torch | ||
|
||
from skrl.agents.torch import Agent | ||
|
||
from skrl.agents.torch.a2c import A2C | ||
from skrl.agents.torch.amp import AMP | ||
from skrl.agents.torch.cem import CEM | ||
from skrl.agents.torch.ddpg import DDPG | ||
from skrl.agents.torch.dqn import DQN, DDQN | ||
from skrl.agents.torch.ppo import PPO | ||
from skrl.agents.torch.q_learning import Q_LEARNING | ||
from skrl.agents.torch.sac import SAC | ||
from skrl.agents.torch.sarsa import SARSA | ||
from skrl.agents.torch.td3 import TD3 | ||
from skrl.agents.torch.trpo import TRPO | ||
|
||
from .utils import DummyModel | ||
|
||
|
||
@pytest.fixture | ||
def classes_and_kwargs(): | ||
return [(A2C, {"models": {"policy": DummyModel()}}), | ||
(AMP, {"models": {"policy": DummyModel()}}), | ||
(CEM, {"models": {"policy": DummyModel()}}), | ||
(DDPG, {"models": {"policy": DummyModel()}}), | ||
(DQN, {"models": {"policy": DummyModel()}}), | ||
(DDQN, {"models": {"policy": DummyModel()}}), | ||
(PPO, {"models": {"policy": DummyModel()}}), | ||
(Q_LEARNING, {"models": {"policy": DummyModel()}}), | ||
(SAC, {"models": {"policy": DummyModel()}}), | ||
(SARSA, {"models": {"policy": DummyModel()}}), | ||
(TD3, {"models": {"policy": DummyModel()}}), | ||
(TRPO, {"models": {"policy": DummyModel()}})] | ||
|
||
|
||
def test_agent(capsys, classes_and_kwargs): | ||
for klass, kwargs in classes_and_kwargs: | ||
cfg = {"learning_starts": 1, | ||
"experiment": {"write_interval": 0}} | ||
agent: Agent = klass(cfg=cfg, **kwargs) | ||
|
||
agent.init() | ||
agent.pre_interaction(timestep=0, timesteps=1) | ||
# agent.act(None, timestep=0, timestesps=1) | ||
agent.record_transition(states=torch.tensor([]), | ||
actions=torch.tensor([]), | ||
rewards=torch.tensor([]), | ||
next_states=torch.tensor([]), | ||
terminated=torch.tensor([]), | ||
truncated=torch.tensor([]), | ||
infos={}, | ||
timestep=0, | ||
timesteps=1) | ||
agent.post_interaction(timestep=0, timesteps=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import pytest | ||
import warnings | ||
import hypothesis | ||
import hypothesis.strategies as st | ||
|
||
import torch | ||
|
||
from skrl.envs.torch import Wrapper | ||
|
||
from skrl.envs.torch import wrap_env | ||
|
||
from .utils import DummyEnv | ||
|
||
|
||
@pytest.fixture | ||
def classes_and_kwargs(): | ||
return [] | ||
|
||
|
||
@pytest.mark.parametrize("wrapper", ["gym", "gymnasium", "dm", "robosuite", \ | ||
"isaacgym-preview2", "isaacgym-preview3", "isaacgym-preview4", "omniverse-isaacgym"]) | ||
def test_wrap_env(capsys, classes_and_kwargs, wrapper): | ||
env = DummyEnv(num_envs=1) | ||
|
||
try: | ||
env: Wrapper = wrap_env(env=env, wrapper=wrapper) | ||
except ValueError as e: | ||
warnings.warn(f"{e}. This test will be skipped for '{wrapper}'") | ||
except ModuleNotFoundError as e: | ||
warnings.warn(f"{e}. The '{wrapper}' wrapper module is not found. This test will be skipped") | ||
|
||
env.observation_space | ||
env.action_space | ||
env.state_space | ||
env.num_envs | ||
env.device |
Oops, something went wrong.