@@ -2,6 +2,10 @@
 from os import makedirs
 from typing import List
 
+import torch as th
+import torch.nn as nn
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
 import gym
 import numpy as np
 from stable_baselines3 import DDPG
@@ -110,13 +114,50 @@ def _on_step(self) -> bool:
 n_actions = env.action_space.shape[-1]
 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
 
-model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/')
-checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/')
+
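+# Custom SB3 features extractor: a small fully connected network that maps the
+# flat observation vector to a feature vector of size `features_dim`.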
+class CustomMLP(BaseFeaturesExtractor):
+
+    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
+        super().__init__(observation_space, features_dim)
+        # We assume a flat Box observation (a 1-D feature vector),
+        # so the single shape entry gives the input size
+        n_input_features = observation_space.shape[0]
+        self.mlp = nn.Sequential(
+            nn.Linear(n_input_features, 32),
+            nn.ReLU(),
+            nn.Linear(32, 64),
+            nn.ReLU(),
+        )
+
+        # Compute the output size by doing one forward pass
+        with th.no_grad():
+            n_flatten = self.mlp(
+                th.as_tensor(observation_space.sample()[None]).float()
+            ).shape[1]
+
+        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+    def forward(self, observations: th.Tensor) -> th.Tensor:
+        return self.linear(self.mlp(observations))
+
+
+policy_kwargs = dict(
+    features_extractor_class=CustomMLP,
+    features_extractor_kwargs=dict(features_dim=128),
+    net_arch=[32, 32],
+)
+
+# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10]))
+# policy_kwargs = dict(activation_fn=th.nn.LeakyReLU, net_arch=[32, 32])
+model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs)
+checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/')
 record_env = RecordEnvCallback()
-plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env)
-model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback])
+plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env)
+model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback])
 
-model.save('ddpg_CC')
+model.save('ddpg_CC2')
 
 del model # remove to demonstrate saving and loading
 
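For completeness, a minimal sketch of reloading the agent saved above and rolling out the learned policy, assuming the same `env` is constructed as in the script (`DDPG.load` and `model.predict` are standard Stable-Baselines3 API; the filename matches the `model.save('ddpg_CC2')` call in this commit):

# Reload the saved agent from ddpg_CC2.zip
model = DDPG.load('ddpg_CC2')

# Roll out one episode with the deterministic policy (classic gym API)
obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)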