Skip to content

Commit

Permalink
Add flatten layer and update dependencies (#18)
Browse files Browse the repository at this point in the history
* Add flatten layer and update dependencies

* Reformat
  • Loading branch information
araffin authored Nov 6, 2023
1 parent f662613 commit 9bd4bca
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 23 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
# # Install Atari Roms
# pip install autorom
# wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
# base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
# AutoROM --accept-license --source-file Roms.tar.gz
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install .[tests]
# Use headless version
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ lint:

format:
# Sort imports
isort ${LINT_PATHS}
ruff --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
ruff --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ max-complexity = 15
[tool.black]
line-length = 127

[tool.isort]
profile = "black"
line_length = 127
src_paths = ["sbx"]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
Expand Down
12 changes: 12 additions & 0 deletions sbx/common/policies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# import copy
from typing import Dict, Optional, Tuple, Union, no_type_check

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation


class Flatten(nn.Module):
"""
Equivalent to PyTorch nn.Flatten() layer.
"""

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return x.reshape((x.shape[0], -1))


class BaseJaxPolicy(BasePolicy):
def __init__(self, *args, **kwargs):
super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion sbx/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState


Expand All @@ -18,6 +18,7 @@ class QNetwork(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = nn.Dense(self.n_units)(x)
x = nn.relu(x)
x = nn.Dense(self.n_units)(x)
Expand Down
4 changes: 3 additions & 1 deletion sbx/ppo/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy
from sbx.common.policies import BaseJaxPolicy, Flatten

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions
Expand All @@ -24,6 +24,7 @@ class Critic(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_units)(x)
Expand All @@ -45,6 +46,7 @@ def get_std(self):

@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_units)(x)
Expand Down
4 changes: 3 additions & 1 deletion sbx/sac/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
from sbx.common.policies import BaseJaxPolicy
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState

tfp = tensorflow_probability.substrates.jax
Expand All @@ -25,6 +25,7 @@ class Critic(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
Expand Down Expand Up @@ -75,6 +76,7 @@ def get_std(self):

@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = nn.relu(x)
Expand Down
4 changes: 3 additions & 1 deletion sbx/td3/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState


Expand All @@ -19,6 +19,7 @@ class Critic(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
Expand Down Expand Up @@ -63,6 +64,7 @@ class Actor(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore[name-defined]
x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = nn.relu(x)
Expand Down
4 changes: 3 additions & 1 deletion sbx/tqc/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
from sbx.common.policies import BaseJaxPolicy
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState

tfp = tensorflow_probability.substrates.jax
Expand All @@ -26,6 +26,7 @@ class Critic(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, a], -1)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
Expand All @@ -50,6 +51,7 @@ def get_std(self):

@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = nn.relu(x)
Expand Down
2 changes: 1 addition & 1 deletion sbx/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.0
0.9.0
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
packages=[package for package in find_packages() if package.startswith("sbx")],
package_data={"sbx": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.1.0",
"stable_baselines3>=2.2.0a9",
"jax",
"jaxlib",
"flax",
Expand All @@ -59,8 +59,6 @@
"mypy",
# Lint code
"ruff",
# Sort imports
"isort>=5.0",
# Reformat
"black",
],
Expand Down
31 changes: 31 additions & 0 deletions tests/test_flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import Dict, Optional

import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces

from sbx import DQN, PPO, SAC, TD3, TQC


@dataclass
class DummyEnv(gym.Env):
observation_space: spaces.Space
action_space: spaces.Space

def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}


@pytest.mark.parametrize("model_class", [DQN, PPO, SAC, TD3, TQC])
def test_flatten(model_class) -> None:
action_space = spaces.Discrete(15) if model_class == DQN else spaces.Box(-1, 1, shape=(2,), dtype=np.float32)
env = DummyEnv(spaces.Box(-1, 1, shape=(2, 1), dtype=np.float32), action_space)

model_class("MlpPolicy", env).learn(150)

0 comments on commit 9bd4bca

Please sign in to comment.