Skip to content

Commit b55b500

Browse files
author
Daniel Weber
committed
#51 started adjusting DDPG
1 parent 94fa340 commit b55b500

File tree

1 file changed

+43
-5
lines changed

1 file changed

+43
-5
lines changed

experiments/issue51_new/stable_baselinesDDPG.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from os import makedirs
33
from typing import List
44

5+
import torch as th
6+
import torch.nn as nn
7+
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
8+
59
import gym
610
import numpy as np
711
from stable_baselines3 import DDPG
@@ -110,13 +114,47 @@ def _on_step(self) -> bool:
110114
n_actions = env.action_space.shape[-1]
111115
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
112116

113-
model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/')
114-
checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/')
117+
118+
class CustomMPL(BaseFeaturesExtractor):
119+
120+
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
121+
super(CustomMPL, self).__init__(observation_space, features_dim)
122+
# We assume CxHxW images (channels first)
123+
# Re-ordering will be done by pre-preprocessing or wrapper
124+
n_input_channels = observation_space.shape[0]
125+
self.cnn = nn.Sequential(
126+
nn.Linear(n_input_channels, 32),
127+
nn.ReLU(),
128+
nn.Linear(32, 64),
129+
nn.ReLU(),
130+
)
131+
132+
# Compute shape by doing one forward pass
133+
with th.no_grad():
134+
n_flatten = self.cnn(
135+
th.as_tensor(observation_space.sample()[None]).float()
136+
).shape[1]
137+
138+
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
139+
140+
def forward(self, observations: th.Tensor) -> th.Tensor:
141+
return self.linear(self.cnn(observations))
142+
143+
144+
policy_kwargs = dict(
145+
features_extractor_class=CustomMPL,
146+
features_extractor_kwargs=dict(features_dim=128, net_arch=[32, 32]),
147+
)
148+
149+
# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10]))
150+
# policy_kwargs = dict( activation_fn=th.nn.LeakyReLU, net_arch=[32, 32])
151+
model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs)
152+
checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/')
115153
record_env = RecordEnvCallback()
116-
plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env)
117-
model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback])
154+
plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env)
155+
model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback])
118156

119-
model.save('ddpg_CC')
157+
model.save('ddpg_CC2')
120158

121159
del model # remove to demonstrate saving and loading
122160

0 commit comments

Comments
 (0)