# Standard library
import sys
import time
# Gymnasium (the maintained successor to OpenAI Gym)
import gymnasium as gym
from gymnasium import spaces
# Custom C++ fire environment bindings
import build.fire_environment
# NumPy, for the array representation of the observation space
import numpy as np
# Stable Baselines3
from stable_baselines3 import DQN
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# PyTorch, for building the custom CNN
import torch as th
import torch.nn as nn

# Module-level timing trackers (WildfireEnv keeps its own copies as instance attributes)
total_time = 0
leaving_time, enter_time, external_time = 0, 0, 0

class WildfireEnv(gym.Env):
    '''
    Sets up the basic FireEnvironment object from C++, an observation space, and an action space.
    '''
    def __init__(self):
        # Construct the C++ fire environment
        self.fire_env = build.fire_environment.FireEnvironment(20)
        # Set the observation space from the shape of the current state
        sep = self.fire_env.getState()
        self.observation_space = spaces.Box(low=0, high=200, shape=sep.shape, dtype=np.float64)
        # Set the action space
        actions = self.fire_env.getActions()
        # Flatten each (populated_area, path) pair into a single discrete index;
        # no prefix-sum scheme is needed for an action space this small.
        self.ind_to_pair = [[populated_area, path] for populated_area, path_count in enumerate(actions) for path in range(path_count)]
        self.action_space = spaces.Discrete(n=len(self.ind_to_pair), start=0)
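        # For illustration (hypothetical values): if getActions() returned
        # [2, 1], then ind_to_pair would be [[0, 0], [0, 1], [1, 0]], giving
        # an action space of Discrete(3).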
        # Initialize timing trackers
        self.total_time = 0
        self.leaving_time, self.enter_time, self.external_time = 0, 0, 0

    def reset(self, seed=None, options=None):
        '''
        Reset the entire environment by creating a new environment.
        '''
        # Gymnasium's reset takes optional seed and options arguments
        super().reset(seed=seed)
        self.fire_env = build.fire_environment.FireEnvironment(20)
        sep = self.fire_env.getState()
        return sep, {}

    def step(self, action):
        '''
        Take a step and advance the environment after taking an action.
        '''
        # Record the entry time; accumulate time spent outside the environment
        self.enter_time = time.time()
        if self.leaving_time != 0:
            self.external_time += self.enter_time - self.leaving_time
        start = time.time()
        # Call the C++ function to take the action
        actionTuple = self.ind_to_pair[action]
        rewards = self.fire_env.inputAction(actionTuple[0] - 1, actionTuple[1])
        # Gather the observation, reward, terminated, and truncated values
        observations = self.fire_env.getState()
        terminated = self.fire_env.getTerminated()
        truncated = False
        # Record the exit time
        end = time.time()
        self.total_time += (end - start)
        self.leaving_time = time.time()
        # Return the 5-tuple required by the Gymnasium step API
        return observations, rewards, terminated, truncated, {}

    def print_environment(self):
        '''
        Useful in debugging and seeing the progression of the environment over time.
        '''
        self.fire_env.printData()
        print("")
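
# Optional sanity check of the custom environment against the Gymnasium API,
# using the check_env helper imported above (a sketch; assumes the compiled
# build.fire_environment module is importable). Uncomment to run:
# check_env(WildfireEnv(), warn=True)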

class CustomCNN(BaseFeaturesExtractor):
    '''
    Defines our own custom class for a CNN.
    '''
    def __init__(self, observation_space: spaces.Box, features_dim=512, normalized_image=False):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=4, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Compute the flattened size by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim * 2), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(features_dim * 2, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear2(self.linear(self.cnn(observations)))
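
# A minimal sketch of how to sanity-check the extractor's output dimension,
# assuming the environment's state is a channels-first float array.
# Uncomment to run:
# _env = WildfireEnv()
# _extractor = CustomCNN(_env.observation_space, features_dim=256)
# _dummy = th.as_tensor(_env.observation_space.sample()[None]).float()
# assert _extractor(_dummy).shape == (1, 256)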

def run_simulation(train):
    # Set up the basic environment
    env = WildfireEnv()
    # Set up the DQN model
    model = None
    if train:
        # Create the policy keyword arguments and build the model
        policy_kwargs = dict(
            features_extractor_class=CustomCNN,
            features_extractor_kwargs=dict(features_dim=256),
            normalize_images=False
        )
        model = DQN("CnnPolicy", env, verbose=1, policy_kwargs=policy_kwargs)
        # Without using the custom network:
        #model = DQN("CnnPolicy", env, verbose=1, policy_kwargs=dict(normalize_images=False))
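        # Illustrative, untuned values for a few standard SB3 DQN keyword
        # arguments, in case you want to experiment (the kwargs are part of
        # the SB3 DQN API; the values here are assumptions, not recommendations):
        #model = DQN("CnnPolicy", env, verbose=1, policy_kwargs=policy_kwargs,
        #            learning_rate=1e-4, buffer_size=50000, exploration_fraction=0.2)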
        print("begin training")
        model.learn(total_timesteps=110000, log_interval=100)
        print(env.total_time)
        print(env.external_time)
        model.save("Trained Policy")
    else:
        model = DQN.load("policies/Trained Policy")
    # Evaluate the learned policy for 10 episodes
    for _ in range(10):
        # Define the initial observation and cumulative reward
        obs = env.reset()[0]
        total_reward = 0
        # Run for 99 timesteps
        for _ in range(99):
            # Get the next action
            action, _states = model.predict(obs, deterministic=True)
            print("Action: " + str(action))
            # Take the action and advance the state (obs must be updated here,
            # or the agent keeps predicting from the initial observation)
            obs, reward, terminated, truncated, info = env.step(action)
            env.print_environment()
            print("Current Reward: " + str(reward))
            # Add to the total reward
            total_reward += reward
        # Print out the total accumulated reward
        print("Accumulated Reward: " + str(total_reward))

def main(train):
    train = int(train)
    run_simulation(train)

# Call with CLI argument "1" if you want to train, "0" if you don't
if __name__ == "__main__":
    main(sys.argv[1])
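
# Example invocations (assuming the C++ extension has been built into build/):
#   python main.py 1   # train, save, then evaluate the policy
#   python main.py 0   # load a previously saved policy and evaluate it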