-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain_A2C.py
83 lines (62 loc) · 2.79 KB
/
train_A2C.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import sys
import getopt
import time
import gym
# Presumably imported for its side effect of registering the
# "reflected-xss-v0" environment used by gym.make() in main() — confirm.
import gym_reflected_xss
import uuid
from stable_baselines.common.vec_env import DummyVecEnv
# NOTE(review): this is a stable_baselines3 (PyTorch) import in a script that
# otherwise uses stable_baselines (TF1). MlpPolicy / CnnPolicy are not
# referenced anywhere in this file, nor are DummyVecEnv, make_vec_env or `th`
# outside an unused local — confirm and consider removing these imports.
from stable_baselines3.a2c.policies import MlpPolicy, CnnPolicy
from stable_baselines import A2C
from stable_baselines.common import make_vec_env
import torch as th
# remove tensorflow warning messages
import warnings
# Ignore FutureWarning from here on. NOTE(review): the stable_baselines
# imports above presumably already import TensorFlow, so warnings raised at
# that import are not covered; move this above them if that is the intent.
warnings.simplefilter(action='ignore', category=FutureWarning)
import tensorflow as tf
# Only let TensorFlow's v1 logger report ERROR and above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from stable_baselines.common.policies import FeedForwardPolicy, LstmPolicy
class CustomPolicy(FeedForwardPolicy):
    """Feed-forward actor-critic policy with two separate 3 x 128 MLP heads.

    The policy (``pi``) and value (``vf``) networks each get three layers of
    128 units, and features are extracted with an MLP. Every other positional
    and keyword argument is forwarded unchanged to ``FeedForwardPolicy``.
    """

    def __init__(self, *args, **kwargs):
        layer_sizes = [128, 128, 128]
        # Fresh list copies per head so pi and vf never share a list object.
        heads = dict(pi=list(layer_sizes), vf=list(layer_sizes))
        super().__init__(*args, **kwargs, net_arch=[heads], feature_extraction="mlp")
class CustomLSTMPolicy(LstmPolicy):
    """Recurrent actor-critic policy built on ``LstmPolicy``.

    Uses ``net_arch = [128, 'lstm', {pi: 3 x 128, vf: 3 x 128}]`` with MLP
    feature extraction and ``layer_norm=True``. ``n_lstm`` defaults to 128
    here. Any extra keyword arguments are passed through to ``LstmPolicy``.
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=128, reuse=False, **_kwargs):
        head_layers = dict(pi=[128] * 3, vf=[128] * 3)
        architecture = [128, 'lstm', head_layers]
        super().__init__(
            sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
            net_arch=architecture,
            layer_norm=True,
            feature_extraction="mlp",
            **_kwargs
        )
def main(argv):
    """Train an A2C agent on the ``reflected-xss-v0`` environment and save it.

    Command-line options (read from ``argv[1:]``):
        -o URL   start URL handed to the environment (required);
                 ``-u URL`` is accepted as an alias.
        -t N     total training timesteps, a positive integer
                 (default 4000000).

    TensorBoard logs go to ``./tensorboard_log/``. The trained model is saved
    as ``models/<log title>-<uuid4>-model.pkl``, and the ``models``
    directory is created if needed.

    Args:
        argv: Full argument vector including the program name at index 0,
            e.g. ``sys.argv``.

    Raises:
        SystemExit: with status 2 if an option is unknown, ``-t`` is not a
            positive integer, or no start URL was given.
    """
    import os  # local import: only needed to create the model directory

    usage = "Usage: train_A2C.py -o <start_url> [-t <timesteps>]"

    def _fail(message):
        # Report a command-line error plus usage on stderr and exit with
        # status 2, the status the script already used for option errors.
        print(message, file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(2)

    start_url = ""
    # No option sets this yet; it only becomes part of the log/model name.
    test_suite_name = ""
    timesteps = 4000000

    try:
        opts, _ = getopt.getopt(argv[1:], "o:u:t:")
    except getopt.GetoptError as err:
        _fail(str(err))

    for opt, arg in opts:
        # Use a real tuple / equality check: `opt in ("-u")` is a substring
        # test on a plain string and never matched "-o".
        if opt in ("-o", "-u"):
            start_url = arg
        elif opt == "-t":
            try:
                timesteps = int(arg)
            except ValueError:
                _fail("Invalid value for -t: %r" % (arg,))

    if timesteps <= 0:
        _fail("Timesteps must be a positive integer, got %d" % timesteps)
    if not start_url:
        _fail("Missing required option -o <start_url>")

    env = gym.make("reflected-xss-v0", start_url=start_url, mode=0,
                   log_file_name="train_log.txt")
    try:
        # create learning agent
        print("[*] Creating A2C model ...")
        learning_rate = 0.0005
        gamma = 0.95
        model = A2C(CustomPolicy, env, verbose=1,
                    tensorboard_log="./tensorboard_log/",
                    learning_rate=learning_rate, gamma=gamma)

        print("[*] Start Agent learning ...")
        log_title = (time.strftime('%Y.%m.%d') + "-" + test_suite_name
                     + "-" + str(timesteps) + "-A2C-learning-"
                     + str(learning_rate) + "-gamma-" + str(gamma) + "O")
        model.learn(total_timesteps=timesteps, tb_log_name=log_title)

        # Ensure the target directory exists so a long training run is not
        # lost to a failed save.
        os.makedirs("models", exist_ok=True)
        model_name = "models/" + log_title + "-" + str(uuid.uuid4()) + "-model.pkl"
        # save trained model
        model.save(model_name)
        print("[*] Model saved to " + model_name)
    finally:
        # Release environment resources even if training fails.
        env.close()
if __name__ == "__main__":
    # Script entry point: hand over the complete argv; main() itself skips
    # the program name at index 0.
    command_line = sys.argv
    main(command_line)