@@ -2,10 +2,6 @@
 from os import makedirs
 from typing import List

-import torch as th
-import torch.nn as nn
-from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
-
 import gym
 import numpy as np
 from stable_baselines3 import DDPG
@@ -114,47 +110,13 @@ def _on_step(self) -> bool:
 n_actions = env.action_space.shape[-1]
 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

-
-class CustomMPL(BaseFeaturesExtractor):
-
-    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
-        super(CustomMPL, self).__init__(observation_space, features_dim)
-        # We assume CxHxW images (channels first)
-        # Re-ordering will be done by pre-preprocessing or wrapper
-        n_input_channels = observation_space.shape[0]
-        self.cnn = nn.Sequential(
-            nn.Linear(n_input_channels, 32),
-            nn.ReLU(),
-            nn.Linear(32, 64),
-            nn.ReLU(),
-        )
-
-        # Compute shape by doing one forward pass
-        with th.no_grad():
-            n_flatten = self.cnn(
-                th.as_tensor(observation_space.sample()[None]).float()
-            ).shape[1]
-
-        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
-
-    def forward(self, observations: th.Tensor) -> th.Tensor:
-        return self.linear(self.cnn(observations))
-
-
-policy_kwargs = dict(
-    features_extractor_class=CustomMPL,
-    features_extractor_kwargs=dict(features_dim=128, net_arch=[32, 32]),
-)
-
-# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10]))
-# policy_kwargs = dict( activation_fn=th.nn.LeakyReLU, net_arch=[32, 32])
-model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs)
-checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/')
+model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/')
+checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/')
 record_env = RecordEnvCallback()
-plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env)
-model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback])
+plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env)
+model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback])

-model.save('ddpg_CC2')
+model.save('ddpg_CC')

 del model  # remove to demonstrate saving and loading

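A minimal sketch (not part of the commit) of how the removed custom features extractor could be wired up if it were kept: the original CustomMPL passed net_arch inside features_extractor_kwargs, but its __init__ only accepts features_dim, so constructing the policy would raise a TypeError. A working variant keeps net_arch at the top level of policy_kwargs and builds a plain MLP; the names CustomMLP and env below are illustrative, not from the commit.

# Hedged sketch, assuming a flat Box observation space and the env defined earlier in the script.
import gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomMLP(BaseFeaturesExtractor):
    """Small fully connected extractor for flat Box observations."""

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        n_inputs = observation_space.shape[0]
        self.mlp = nn.Sequential(
            nn.Linear(n_inputs, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.mlp(observations)


policy_kwargs = dict(
    features_extractor_class=CustomMLP,
    features_extractor_kwargs=dict(features_dim=128),  # only kwargs CustomMLP.__init__ accepts
    net_arch=[32, 32],  # sizes of the actor/critic heads, separate from the extractor
)
# DDPG('MlpPolicy', env, verbose=1, policy_kwargs=policy_kwargs) would then use the extractor.

Passed this way, net_arch still controls the actor/critic layers built on top of the extractor's output, while features_dim fixes the size of that output.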