Skip to content

Commit 2b8aea6

Browse files
author
Daniel Weber
committed
#51 added DDPG example
1 parent cb678e0 commit 2b8aea6

File tree

1 file changed

+5
-43
lines changed

1 file changed

+5
-43
lines changed

experiments/issue51_new/stable_baselinesDDPG.py

Lines changed: 5 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
from os import makedirs
33
from typing import List
44

5-
import torch as th
6-
import torch.nn as nn
7-
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
8-
95
import gym
106
import numpy as np
117
from stable_baselines3 import DDPG
@@ -114,47 +110,13 @@ def _on_step(self) -> bool:
114110
n_actions = env.action_space.shape[-1]
115111
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
116112

117-
118-
class CustomMPL(BaseFeaturesExtractor):
119-
120-
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
121-
super(CustomMPL, self).__init__(observation_space, features_dim)
122-
# We assume CxHxW images (channels first)
123-
# Re-ordering will be done by pre-preprocessing or wrapper
124-
n_input_channels = observation_space.shape[0]
125-
self.cnn = nn.Sequential(
126-
nn.Linear(n_input_channels, 32),
127-
nn.ReLU(),
128-
nn.Linear(32, 64),
129-
nn.ReLU(),
130-
)
131-
132-
# Compute shape by doing one forward pass
133-
with th.no_grad():
134-
n_flatten = self.cnn(
135-
th.as_tensor(observation_space.sample()[None]).float()
136-
).shape[1]
137-
138-
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
139-
140-
def forward(self, observations: th.Tensor) -> th.Tensor:
141-
return self.linear(self.cnn(observations))
142-
143-
144-
policy_kwargs = dict(
145-
features_extractor_class=CustomMPL,
146-
features_extractor_kwargs=dict(features_dim=128, net_arch=[32, 32]),
147-
)
148-
149-
# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10]))
150-
# policy_kwargs = dict( activation_fn=th.nn.LeakyReLU, net_arch=[32, 32])
151-
model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs)
152-
checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/')
113+
model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/')
114+
checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/')
153115
record_env = RecordEnvCallback()
154-
plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env)
155-
model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback])
116+
plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env)
117+
model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback])
156118

157-
model.save('ddpg_CC2')
119+
model.save('ddpg_CC')
158120

159121
del model # remove to demonstrate saving and loading
160122

0 commit comments

Comments
 (0)