diff --git a/experiments/issue51_new/stable_baselinesDDPG.py b/experiments/issue51_new/stable_baselinesDDPG.py
index 5713cfd2..9c8b47ef 100644
--- a/experiments/issue51_new/stable_baselinesDDPG.py
+++ b/experiments/issue51_new/stable_baselinesDDPG.py
@@ -2,6 +2,10 @@
 from os import makedirs
 from typing import List
 
+import torch as th
+import torch.nn as nn
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
 import gym
 import numpy as np
 from stable_baselines3 import DDPG
@@ -110,13 +114,59 @@ def _on_step(self) -> bool:
 n_actions = env.action_space.shape[-1]
 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
 
-model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/')
-checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/')
+
+class CustomMPL(BaseFeaturesExtractor):
+    """Fully connected features extractor for flat observations.
+
+    Two hidden ``Linear`` + ``ReLU`` layers (32 and 64 units) are followed by
+    a ``Linear`` + ``ReLU`` projection to ``features_dim``.
+
+    :param observation_space: Box space, assumed to be flat (1-D);
+        ``shape[0]`` is used as the number of input features.
+    :param features_dim: Size of the feature vector returned by ``forward``.
+    """
+
+    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
+        super().__init__(observation_space, features_dim)
+        n_input_features = observation_space.shape[0]
+        # NOTE: despite the attribute name, this is a fully connected stack.
+        self.cnn = nn.Sequential(
+            nn.Linear(n_input_features, 32),
+            nn.ReLU(),
+            nn.Linear(32, 64),
+            nn.ReLU(),
+        )
+
+        # Infer the output width of the hidden stack with one forward pass
+        with th.no_grad():
+            n_flatten = self.cnn(
+                th.as_tensor(observation_space.sample()[None]).float()
+            ).shape[1]
+
+        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+    def forward(self, observations: th.Tensor) -> th.Tensor:
+        """Map a batch of observations to ``features_dim``-sized features."""
+        return self.linear(self.cnn(observations))
+
+
+# ``features_extractor_kwargs`` is unpacked into ``CustomMPL.__init__``, which
+# only accepts ``features_dim``; ``net_arch`` is a policy-level argument.
+policy_kwargs = dict(
+    features_extractor_class=CustomMPL,
+    features_extractor_kwargs=dict(features_dim=128),
+    net_arch=[32, 32],
+)
+
+# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10]))
+# policy_kwargs = dict( activation_fn=th.nn.LeakyReLU, net_arch=[32, 32])
+model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs)
+checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/')
 record_env = RecordEnvCallback()
-plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env)
-model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback])
+plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env)
+model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback])
 
-model.save('ddpg_CC')
+model.save('ddpg_CC2')
 
 del model # remove to demonstrate saving and loading
 