Skip to content

Commit

Permalink
Merge pull request #120 from Toni-SM/docs-patch
Browse files Browse the repository at this point in the history
Docs patch
  • Loading branch information
Toni-SM authored Aug 18, 2023
2 parents 8a7995c + 6062fa2 commit a7eb5c5
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ body:
attributes:
value: |
**Your help in making skrl better is greatly appreciated!**
* Please ensure that:
* The issue hasn't already been reported by using the [issue search](https://github.com/Toni-SM/skrl/search?q=is%3Aissue&type=issues).
* The issue (and its solution) is not listed in the skrl documentation [troubleshooting](https://skrl.readthedocs.io/en/latest/intro/installation.html#known-issues-and-troubleshooting) section.
Expand All @@ -18,7 +18,7 @@ body:
description: A clear and concise description of the bug/issue. Try to provide a minimal example to reproduce it (error/log messages are also helpful).
placeholder: |
Markdown formatting might be applied to the text.
```python
# use triple backticks for code blocks or error/log messages
```
Expand Down
4 changes: 3 additions & 1 deletion docs/source/api/models/categorical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ skrl provides a Python mixin (:literal:`CategoricalMixin`) to assist in the crea

* The :ref:`Model <models_base_class>` base class constructor must be invoked before the mixins constructor.

.. note::
.. warning::

For models in JAX/Flax it is imperative to define all parameters (except ``observation_space``, ``action_space`` and ``device``) with default values to avoid errors (``TypeError: __init__() missing N required positional argument``) during initialization.

In addition, it is necessary to initialize the model's ``state_dict`` (via the ``init_state_dict`` method) after its instantiation to avoid errors (``AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'``) during its use.

.. tabs::

.. group-tab:: |_4| |pytorch| |_4|
Expand Down
4 changes: 3 additions & 1 deletion docs/source/api/models/deterministic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ skrl provides a Python mixin (:literal:`DeterministicMixin`) to assist in the cr

* The :ref:`Model <models_base_class>` base class constructor must be invoked before the mixins constructor.

.. note::
.. warning::

For models in JAX/Flax it is imperative to define all parameters (except ``observation_space``, ``action_space`` and ``device``) with default values to avoid errors (``TypeError: __init__() missing N required positional argument``) during initialization.

In addition, it is necessary to initialize the model's ``state_dict`` (via the ``init_state_dict`` method) after its instantiation to avoid errors (``AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'``) during its use.

.. tabs::

.. group-tab:: |_4| |pytorch| |_4|
Expand Down
4 changes: 3 additions & 1 deletion docs/source/api/models/gaussian.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ skrl provides a Python mixin (:literal:`GaussianMixin`) to assist in the creatio

* The :ref:`Model <models_base_class>` base class constructor must be invoked before the mixins constructor.

.. note::
.. warning::

For models in JAX/Flax it is imperative to define all parameters (except ``observation_space``, ``action_space`` and ``device``) with default values to avoid errors (``TypeError: __init__() missing N required positional argument``) during initialization.

In addition, it is necessary to initialize the model's ``state_dict`` (via the ``init_state_dict`` method) after its instantiation to avoid errors (``AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'``) during its use.

.. tabs::

.. group-tab:: |_4| |pytorch| |_4|
Expand Down
36 changes: 30 additions & 6 deletions docs/source/intro/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ To install **skrl** with pip, execute:

.. warning::

JAX installs its CPU version if not specified. For GPU/TPU versions see the JAX `installation <https://github.com/google/jax#installation>`_ page before proceeding with the steps described below.
JAX installs its CPU version if not specified. For GPU/TPU versions visit the JAX
`installation <https://github.com/google/jax#installation>`_ page before proceeding with the steps described below.

.. code-block:: bash
Expand Down Expand Up @@ -108,7 +109,8 @@ Clone or download the library from its GitHub repository (https://github.com/Ton

.. warning::

JAX installs its CPU version if not specified. For GPU/TPU versions see the JAX `installation <https://github.com/google/jax#installation>`_ page before proceeding with the steps described below.
JAX installs its CPU version if not specified. For GPU/TPU versions visit the JAX
`installation <https://github.com/google/jax#installation>`_ page before proceeding with the steps described below.

.. code-block:: bash
Expand Down Expand Up @@ -140,7 +142,8 @@ Clone or download the library from its GitHub repository (https://github.com/Ton

.. warning::

JAX installs its CPU version if not specified. For GPU/TPU versions see the JAX `installation <https://github.com/google/jax#installation>`_ page before proceeding with the steps described below.
JAX installs its CPU version if not specified. For GPU/TPU versions visit the JAX
`installation <https://github.com/google/jax#installation>`_ page before proceeding with the steps described below.

.. code-block:: bash
Expand Down Expand Up @@ -188,7 +191,26 @@ Bug detection and/or correction, feature requests and everything else are more t
AttributeError: 'Adam' object has no attribute '_warned_capturable_if_run_uncaptured'
2. When training/evaluating using JAX in Python 3.7 (e.g. OmniIsaacGymEnvs on Isaac Sim 2022.2.1 and earlier).
2. When installing the JAX version in Python 3.7 (e.g. OmniIsaacGymEnvs or Isaac Orbit on Isaac Sim 2022.2.1 and earlier).

.. code-block:: text
ERROR: Ignored the following versions that require a different python version: 0.4.0 Requires-Python >=3.8; ...
ERROR: Could not find a version that satisfies the requirement jax>=0.4.3; extra == "jax" (from skrl[jax]) (from versions: 0.0, ..., 0.3.25)
ERROR: No matching distribution found for jax>=0.4.3; extra == "jax"
JAX support for Python 3.7 is up to version 0.3.25, while skrl requires ``jax>=0.4.3``.
Furthermore, ``jaxlib<=0.3.25`` builds are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.

However, it is possible to use **skrl** under these circumstances, subject to the following points:

* Install JAX, Flax and Optax manually using ``pip install jax flax optax`` and ignore the installation errors for skrl.

* The ``jax.Device = jax.xla.Device`` statement is required by skrl to support ``jax<0.4.3``.

* Overload models ``__hash__`` method to avoid :literal:`"TypeError: Failed to hash Flax Module"`.

3. When training/evaluating using JAX in Python 3.7 (e.g. OmniIsaacGymEnvs or Isaac Orbit on Isaac Sim 2022.2.1 and earlier).

.. code-block:: text
Expand All @@ -201,14 +223,16 @@ Bug detection and/or correction, feature requests and everything else are more t
def __hash__(self):
return id(self)
3. When training/evaluating using JAX with the NVIDIA Isaac Gym Preview, Isaac Orbit or Omniverse Isaac Gym environments.
4. When training/evaluating using JAX with the NVIDIA Isaac Gym Preview, Isaac Orbit or Omniverse Isaac Gym environments.

.. code-block:: text
PxgCudaDeviceMemoryAllocator fail to allocate memory XXXXXX bytes!! Result = 2
RuntimeError: CUDA error: an illegal memory access was encountered
NVIDIA environments use PyTorch as a backend, and both PyTorch (for CUDA kernels, among others) and JAX preallocate GPU memory, which can lead to out-of-memory (OOM) problems. Reduce or disable GPU memory preallocation as indicated in JAX `GPU memory allocation <https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html>`_ to avoid this issue. For example:
NVIDIA environments use PyTorch as a backend, and both PyTorch (for CUDA kernels, among others) and JAX preallocate GPU memory,
which can lead to out-of-memory (OOM) problems. Reduce or disable GPU memory preallocation as indicated in JAX
`GPU memory allocation <https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html>`_ to avoid this issue. For example:

.. code-block:: bash
Expand Down
6 changes: 6 additions & 0 deletions docs/source/snippets/categorical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def __call__(self, inputs, role):
action_space=env.action_space,
device=env.device,
unnormalized_log_prob=True)

# initialize model's state dict
policy.init_state_dict("policy")
# [end-mlp-setup-jax]

# [start-mlp-compact-jax]
Expand Down Expand Up @@ -138,6 +141,9 @@ def __call__(self, inputs, role):
action_space=env.action_space,
device=env.device,
unnormalized_log_prob=True)

# initialize model's state dict
policy.init_state_dict("policy")
# [end-mlp-compact-jax]

# =============================================================================
Expand Down
10 changes: 8 additions & 2 deletions docs/source/snippets/deterministic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,13 @@ def __call__(self, inputs, role):


# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
critic = MLP(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=False)

# initialize model's state dict
critic.init_state_dict("critic")
# [end-mlp-setup-jax]

# [start-mlp-compact-jax]
Expand All @@ -136,10 +139,13 @@ def __call__(self, inputs, role):


# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
critic = MLP(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=False)

# initialize model's state dict
critic.init_state_dict("critic")
# [end-mlp-compact-jax]

# =============================================================================
Expand Down
6 changes: 6 additions & 0 deletions docs/source/snippets/gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __call__(self, inputs, role):
min_log_std=-20,
max_log_std=2,
reduction="sum")

# initialize model's state dict
policy.init_state_dict("policy")
# [end-mlp-setup-jax]

# [start-mlp-compact-jax]
Expand Down Expand Up @@ -171,6 +174,9 @@ def __call__(self, inputs, role):
min_log_std=-20,
max_log_std=2,
reduction="sum")

# initialize model's state dict
policy.init_state_dict("policy")
# [end-mlp-compact-jax]

# =============================================================================
Expand Down

0 comments on commit a7eb5c5

Please sign in to comment.