feat: support discrete environment #286

Open
wants to merge 6 commits into base: main
18 changes: 18 additions & 0 deletions docs/source/envs/discrete_env.rst
@@ -0,0 +1,18 @@
OmniSafe Discrete Environment
=============================

.. currentmodule:: omnisafe.envs.discrete_env

Discrete Environment Interface
------------------------------

.. card::
:class-header: sd-bg-success sd-text-white
:class-card: sd-outline-success sd-rounded-1

Documentation
^^^

.. autoclass:: DiscreteEnv
:members:
:private-members:
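
For context, a minimal training sketch for one of the newly supported discrete environments might look as follows; this assumes the PR's ``DiscreteEnv`` registers ``CartPole-v1`` and that the standard ``omnisafe.Agent`` entry point is used unchanged.

```python
# Minimal sketch, assuming CartPole-v1 is exposed through DiscreteEnv by this PR.
import omnisafe

if __name__ == '__main__':
    agent = omnisafe.Agent('CPO', 'CartPole-v1')
    agent.learn()
```
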
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -461,6 +461,7 @@ this project, don't hesitate to ask your question on `the GitHub issue page <htt
envs/wrapper
envs/safety_gymnasium
envs/mujoco_env
envs/discrete_env
envs/adapter


14 changes: 14 additions & 0 deletions docs/source/model/actor.rst
@@ -109,3 +109,17 @@ VAE Actor
.. autoclass:: VAE
:members:
:private-members:

Categorical Actor
-----------------

.. card::
:class-header: sd-bg-success sd-text-white
:class-card: sd-outline-success sd-rounded-1

Documentation
^^^

.. autoclass:: CategoricalActor
:members:
:private-members:
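
As a rough illustration of what a categorical actor does (not the actual ``CategoricalActor`` implementation), an MLP maps observations to logits and actions are sampled from a ``torch.distributions.Categorical``:

```python
# Illustrative sketch only; class name and layer sizes are assumptions.
import torch
from torch import nn
from torch.distributions import Categorical


class TinyCategoricalActor(nn.Module):
    """Map observations to a categorical distribution over discrete actions."""

    def __init__(self, obs_dim: int, act_num: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, act_num))

    def forward(self, obs: torch.Tensor) -> Categorical:
        return Categorical(logits=self.net(obs))


actor = TinyCategoricalActor(obs_dim=4, act_num=2)
dist = actor(torch.zeros(1, 4))
action = dist.sample()            # integer action index, shape (1,)
log_prob = dist.log_prob(action)  # used for the policy-gradient ratio
```
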
2 changes: 1 addition & 1 deletion docs/source/saferl/lag.rst
@@ -233,7 +233,7 @@ Policy update

.. math::

\min _\lambda(1-\lambda) [J^R(\pi)-\lambda J^C(\pi)] \\
\min _{\lambda} (J^R(\pi)-\lambda (J^C(\pi)-d)) \\
\text { s.t. } \lambda \geq 0


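In practice, the multiplier in this constrained problem is usually handled by a projected gradient step on the constraint violation: :math:`\lambda` grows while the cost return exceeds the threshold :math:`d` and is clipped at zero otherwise. A standard sketch of that update, where :math:`\eta_{\lambda}` is a multiplier learning rate introduced here only for illustration:

.. math::

    \lambda \leftarrow \max\left(0,\ \lambda + \eta_{\lambda}\left(J^C(\pi) - d\right)\right)
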
14 changes: 10 additions & 4 deletions omnisafe/adapter/online_adapter.py
@@ -109,8 +109,13 @@ def _wrapper(
cost_normalize (bool, optional): Whether to normalize the cost. Defaults to True.
"""
if self._env.need_time_limit_wrapper:
self._env = TimeLimit(self._env, time_limit=1000, device=self._device)
self._eval_env = TimeLimit(self._eval_env, time_limit=1000, device=self._device)
time_limit = (
self._cfgs.train_cfgs.time_limit
if hasattr(self._cfgs.train_cfgs, 'time_limit')
else 1000
)
self._env = TimeLimit(self._env, time_limit=time_limit, device=self._device)
self._eval_env = TimeLimit(self._eval_env, time_limit=time_limit, device=self._device)
if self._env.need_auto_reset_wrapper:
self._env = AutoReset(self._env, device=self._device)
self._eval_env = AutoReset(self._eval_env, device=self._device)
@@ -121,8 +126,9 @@
self._env = RewardNormalize(self._env, device=self._device)
if cost_normalize:
self._env = CostNormalize(self._env, device=self._device)
self._env = ActionScale(self._env, low=-1.0, high=1.0, device=self._device)
self._eval_env = ActionScale(self._eval_env, low=-1.0, high=1.0, device=self._device)
if self._env.need_action_scale_wrapper:
self._env = ActionScale(self._env, low=-1.0, high=1.0, device=self._device)
self._eval_env = ActionScale(self._eval_env, low=-1.0, high=1.0, device=self._device)
if self._env.num_envs == 1:
self._env = Unsqueeze(self._env, device=self._device)
self._eval_env = Unsqueeze(self._eval_env, device=self._device)
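
The change above replaces the hard-coded 1000-step limit with an optional ``train_cfgs.time_limit``. A standalone sketch of the same fallback pattern (the config object below is a stand-in, and ``getattr`` with a default is an equivalent alternative to the ``hasattr`` check in the diff):

```python
# Sketch of the optional time_limit fallback; SimpleNamespace stands in for the real config.
from types import SimpleNamespace

cfgs = SimpleNamespace(train_cfgs=SimpleNamespace(time_limit=500))
time_limit = getattr(cfgs.train_cfgs, 'time_limit', 1000)  # falls back to 1000 if absent
assert time_limit == 500
```
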
5 changes: 3 additions & 2 deletions omnisafe/algorithms/algo_wrapper.py
@@ -24,7 +24,7 @@

from omnisafe.algorithms import ALGORITHM2TYPE, ALGORITHMS, registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.envs import support_envs
from omnisafe.envs import ENVIRONMNET2TYPE, support_envs
from omnisafe.evaluator import Evaluator
from omnisafe.utils import distributed
from omnisafe.utils.config import Config, check_all_configs, get_default_kwargs_yaml
@@ -88,6 +88,7 @@ def _init_config(self) -> Config:
self.algo in ALGORITHMS['all']
), f"{self.algo} doesn't exist. Please choose from {ALGORITHMS['all']}."
self.algo_type = ALGORITHM2TYPE.get(self.algo, '')
self.env_type = ENVIRONMNET2TYPE.get(self.env_id, '')
if self.train_terminal_cfgs is not None:
if self.algo_type in ['model-based', 'offline']:
assert (
@@ -146,7 +147,7 @@ def _init_checks(self) -> None:

def _init_algo(self) -> None:
"""Initialize the algorithm."""
check_all_configs(self.cfgs, self.algo_type)
check_all_configs(self.cfgs, self.algo_type, self.env_type)
if distributed.fork(
self.cfgs.train_cfgs.parallel,
device=self.cfgs.train_cfgs.device,
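
The new ``env_type`` lookup lets ``check_all_configs`` validate environment-specific settings. A hypothetical sketch of how such a mapping can gate a check (the mapping contents and helper below are illustrative assumptions, not the actual registry):

```python
# Hypothetical example of gating a config check on an env-id -> type mapping.
ENV2TYPE = {
    'CartPole-v1': 'discrete',
    'Taxi-v3': 'discrete',
}


def check_actor_matches_env(env_id: str, actor_type: str) -> None:
    env_type = ENV2TYPE.get(env_id, '')
    if env_type == 'discrete':
        assert actor_type == 'discrete', 'discrete environments need a discrete actor'


check_actor_matches_env('CartPole-v1', 'discrete')
```
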
4 changes: 2 additions & 2 deletions omnisafe/algorithms/on_policy/base/policy_gradient.py
@@ -555,15 +555,15 @@ def _loss_pi(
"""
distribution = self._actor_critic.actor(obs)
logp_ = self._actor_critic.actor.log_prob(act)
std = self._actor_critic.actor.std
ratio = torch.exp(logp_ - logp)
loss = -(ratio * adv).mean()
entropy = distribution.entropy().mean().item()
if self._cfgs.model_cfgs.actor_type == 'gaussian_learning':
self._logger.store({'Train/PolicyStd': self._actor_critic.actor.std})
self._logger.store(
{
'Train/Entropy': entropy,
'Train/PolicyRatio': ratio,
'Train/PolicyStd': std,
'Loss/Loss_pi': loss.mean().item(),
},
)
4 changes: 2 additions & 2 deletions omnisafe/algorithms/on_policy/base/ppo.py
@@ -65,7 +65,6 @@ def _loss_pi(
"""
distribution = self._actor_critic.actor(obs)
logp_ = self._actor_critic.actor.log_prob(act)
std = self._actor_critic.actor.std
ratio = torch.exp(logp_ - logp)
ratio_cliped = torch.clamp(
ratio,
@@ -76,11 +75,12 @@
loss -= self._cfgs.algo_cfgs.entropy_coef * distribution.entropy().mean()
# useful extra info
entropy = distribution.entropy().mean().item()
if self._cfgs.model_cfgs.actor_type == 'gaussian_learning':
self._logger.store({'Train/PolicyStd': self._actor_critic.actor.std})
self._logger.store(
{
'Train/Entropy': entropy,
'Train/PolicyRatio': ratio,
'Train/PolicyStd': std,
'Loss/Loss_pi': loss.mean().item(),
},
)
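
Both ``policy_gradient.py`` and ``ppo.py`` now log ``Train/PolicyStd`` only for learned Gaussian actors, since a categorical policy has no standard deviation. A reduced sketch of the guard (the logger and actor objects are stand-ins):

```python
# Only Gaussian actors expose a meaningful std; skip the metric otherwise.
def store_policy_stats(logger, actor, actor_type: str, entropy: float) -> None:
    if actor_type == 'gaussian_learning':
        logger.store({'Train/PolicyStd': actor.std})
    logger.store({'Train/Entropy': entropy})
```
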
28 changes: 20 additions & 8 deletions omnisafe/common/buffer/base.py
@@ -18,8 +18,9 @@

from abc import ABC, abstractmethod

import numpy as np
import torch
from gymnasium.spaces import Box
from gymnasium.spaces import Box, Discrete

from omnisafe.typing import DEVICE_CPU, OmnisafeSpace

@@ -28,7 +29,7 @@ class BaseBuffer(ABC):
r"""Abstract base class for buffer.

.. warning::
The buffer only supports Box spaces.
The buffer only supports ``Box`` and ``Discrete`` spaces.

In base buffer, we store the following data:

@@ -57,8 +58,8 @@ class BaseBuffer(ABC):
data (dict[str, torch.Tensor]): The data of the buffer.

Raises:
NotImplementedError: If the observation space or the action space is not Box.
NotImplementedError: If the action space or the action space is not Box.
NotImplementedError: If the observation space is neither ``Box`` nor ``Discrete``.
NotImplementedError: If the action space is neither ``Box`` nor ``Discrete``.
"""

def __init__(
@@ -70,12 +71,23 @@ def __init__(
) -> None:
"""Initialize an instance of :class:`BaseBuffer`."""
self._device: torch.device = device
if isinstance(obs_space, Box):
obs_buf = torch.zeros((size, *obs_space.shape), dtype=torch.float32, device=device)

if isinstance(obs_space, (Box, Discrete)):
obs_buf = torch.zeros(
(size, int(np.array(obs_space.shape).prod())),
dtype=torch.float32,
device=device,
)
else:
raise NotImplementedError
if isinstance(act_space, Box):
act_buf = torch.zeros((size, *act_space.shape), dtype=torch.float32, device=device)

if isinstance(act_space, (Box, Discrete)):
act_buf = torch.zeros(
(size, int(np.array(act_space.shape).prod())),
dtype=torch.float32,
device=device,
)

else:
raise NotImplementedError

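
The buffer now flattens both space types into a fixed per-step width: ``prod(shape)`` for a ``Box``, and 1 for a ``Discrete`` space (whose shape is ``()``, so the product of the empty array is 1). A standalone sketch of that sizing rule:

```python
# Buffer-width rule used above: Box -> prod(shape), Discrete -> 1 (shape is ()).
import numpy as np
import torch
from gymnasium.spaces import Box, Discrete


def flat_dim(space) -> int:
    return int(np.array(space.shape).prod())


assert flat_dim(Box(low=-1.0, high=1.0, shape=(4,))) == 4
assert flat_dim(Discrete(6)) == 1
obs_buf = torch.zeros((8, flat_dim(Discrete(6))), dtype=torch.float32)
```
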
2 changes: 1 addition & 1 deletion omnisafe/common/buffer/onpolicy_buffer.py
@@ -31,7 +31,7 @@ class OnPolicyBuffer(BaseBuffer):  # pylint: disable=too-many-instance-attribute
state-action pairs, ranging from ``GAE``, ``GAE-RTG`` , ``V-trace`` to ``Plain`` method.

.. warning::
The buffer only supports Box spaces.
The buffer only supports ``Box`` and ``Discrete`` spaces.

Compared to the base buffer, the on-policy buffer stores extra data:

2 changes: 1 addition & 1 deletion omnisafe/common/buffer/vector_onpolicy_buffer.py
@@ -30,7 +30,7 @@ class VectorOnPolicyBuffer(OnPolicyBuffer):
stored in a list of on-policy buffers, each of which corresponds to one environment.

.. warning::
The buffer only supports Box spaces.
The buffer only supports ``Box`` and ``Discrete`` spaces.

Args:
obs_space (OmnisafeSpace): Observation space.
36 changes: 36 additions & 0 deletions omnisafe/configs/on-policy/CPO.yaml
@@ -126,3 +126,39 @@ defaults:
activation: tanh
# learning rate
lr: 0.001

CartPole-v1:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 500
# total number of steps to train
total_steps: 1000000
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"

Taxi-v3:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 200
# total number of steps to train
total_steps: 1000000
# algorithm configurations
algo_cfgs:
# normalize observation
obs_normalize: False
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"
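
These per-environment blocks override the defaults when the matching environment id is used. The same fields can also be overridden programmatically through ``custom_cfgs`` (a sketch; the exact values are illustrative):

```python
# Sketch: programmatic overrides mirroring the YAML keys above.
import omnisafe

custom_cfgs = {
    'train_cfgs': {'total_steps': 200_000},
    'logger_cfgs': {'save_model_freq': 10},
}
agent = omnisafe.Agent('CPO', 'CartPole-v1', custom_cfgs=custom_cfgs)
agent.learn()
```
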
38 changes: 38 additions & 0 deletions omnisafe/configs/on-policy/CPPOPID.yaml
@@ -142,3 +142,41 @@ defaults:
penalty_max: 100.0
# Initial value of lagrangian multiplier
lagrangian_multiplier_init: 0.001

CartPole-v1:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 500
# total number of steps to train
total_steps: 1000000
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"

Taxi-v3:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 200
# total number of steps to train
total_steps: 1000000
# algorithm configurations
algo_cfgs:
# normalize observation
obs_normalize: False
# entropy coefficient
entropy_coef: 0.01
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"
38 changes: 38 additions & 0 deletions omnisafe/configs/on-policy/IPO.yaml
@@ -134,3 +134,41 @@ defaults:
lambda_lr: 0.035
# Type of lagrangian optimizer
lambda_optimizer: "Adam"

CartPole-v1:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 500
# total number of steps to train
total_steps: 1000000
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"

Taxi-v3:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 200
# total number of steps to train
total_steps: 1000000
# algorithm configurations
algo_cfgs:
# normalize observation
obs_normalize: False
# entropy coefficient
entropy_coef: 0.01
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"
36 changes: 36 additions & 0 deletions omnisafe/configs/on-policy/NaturalPG.yaml
@@ -126,3 +126,39 @@ defaults:
activation: tanh
# learning rate
lr: 0.001

CartPole-v1:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 500
# total number of steps to train
total_steps: 1000000
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"

Taxi-v3:
# logger configurations
logger_cfgs:
# save model frequency
save_model_freq: 10
# training configurations
train_cfgs:
# max time-step for each episode
time_limit: 200
# total number of steps to train
total_steps: 1000000
# algorithm configurations
algo_cfgs:
# normalize observation
obs_normalize: False
# model configurations
model_cfgs:
# actor type, options: gaussian, gaussian_learning, discrete
actor_type: "discrete"