
Commit 3a34eeb

Initial commit
0 parents  commit 3a34eeb

11 files changed: +810 -0 lines changed

.gitattributes

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

.idea/REINFORCE.iml

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default.

.idea/misc.xml

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default.

.idea/modules.xml

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default.

.idea/workspace.xml

Lines changed: 349 additions & 0 deletions
Some generated files are not rendered by default.

NormalizedActions.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@

#-*- coding: UTF-8 -*-
"""
filename: NormalizedActions.py
function: gym action wrapper that rescales agent actions from [-1, 1]
          to the environment's action range
date: 2017/8/7
author:
[ASCII-art signature]
"""
import gym


class NormalizedActions(gym.ActionWrapper):

    def _action(self, action):
        action = (action + 1) / 2  # [-1, 1] => [0, 1]
        action *= (self.action_space.high - self.action_space.low)
        action += self.action_space.low
        return action

    def _reverse_action(self, action):
        action -= self.action_space.low
        action /= (self.action_space.high - self.action_space.low)
        action = action * 2 - 1
        return action

"""
░░░░░░░░░▄░░░░░░░░░░░░░░▄░░░░
░░░░░░░░▌▒█░░░░░░░░░░░▄▀▒▌░░░
░░░░░░░░▌▒▒█░░░░░░░░▄▀▒▒▒▐░░░
░░░░░░░▐▄▀▒▒▀▀▀▀▄▄▄▀▒▒▒▒▒▐░░░
░░░░░▄▄▀▒░▒▒▒▒▒▒▒▒▒█▒▒▄█▒▐░░░
░░░▄▀▒▒▒░░░▒▒▒░░░▒▒▒▀██▀▒▌░░░
░░▐▒▒▒▄▄▒▒▒▒░░░▒▒▒▒▒▒▒▀▄▒▒▌░░
░░▌░░▌█▀▒▒▒▒▒▄▀█▄▒▒▒▒▒▒▒█▒▐░░
░▐░░░▒▒▒▒▒▒▒▒▌██▀▒▒░░░▒▒▒▀▄▌░
░▌░▒▄██▄▒▒▒▒▒▒▒▒▒░░░░░░▒▒▒▒▌░
▀▒▀▐▄█▄█▌▄░▀▒▒░░░░░░░░░░▒▒▒▐░
▐▒▒▐▀▐▀▒░▄▄▒▄▒▒▒▒▒▒░▒░▒░▒▒▒▒▌
▐▒▒▒▀▀▄▄▒▒▒▄▒▒▒▒▒▒▒▒░▒░▒░▒▒▐░
░▌▒▒▒▒▒▒▀▀▀▒▒▒▒▒▒░▒░▒░▒░▒▒▒▌░
░▐▒▒▒▒▒▒▒▒▒▒▒▒▒▒░▒░▒░▒▒▄▒▒▐░░
░░▀▄▒▒▒▒▒▒▒▒▒▒▒░▒░▒░▒▄▒▒▒▒▌░░
░░░░▀▄▒▒▒▒▒▒▒▒▒▒▄▄▄▀▒▒▒▒▄▀░░░
░░░░░░▀▄▄▄▄▄▄▀▀▀▒▒▒▒▒▄▄▀░░░░░
░░░░░░░░░▒▒▒▒▒▒▒▒▒▒▀▀░░░░░░░░
"""

NormalizedActions.pyc

1.86 KB
Binary file not shown.

README.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# REINFORCE_mx

main.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@

#-*- coding: UTF-8 -*-
"""
filename: main.py
function: the code implementing the REINFORCE algorithm in MXNet Gluon
date: 2018/3/13
author:
[ASCII-art signature]
"""
from __future__ import print_function
import numpy as np
import mxnet as mx
from mxnet import nd, autograd, gluon
import argparse, math, os
import gym
from gym import wrappers
from NormalizedActions import NormalizedActions

# argument parser
parser = argparse.ArgumentParser(description='MXNet Gluon REINFORCE example')
# parser.add_argument('--env_name', type=str, default='CartPole-v0')
parser.add_argument('--env_name', type=str, default='InvertedPendulum-v1')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--exploration_end', type=int, default=100, metavar='N',
                    help='number of episodes with noise (default: 100)')
parser.add_argument('--seed', type=int, default=123, metavar='N',
                    help='random seed (default: 123)')
parser.add_argument('--num_steps', type=int, default=1000, metavar='N',
                    help='max episode length (default: 1000)')
parser.add_argument('--num_episodes', type=int, default=2000, metavar='N',
                    help='number of episodes (default: 2000)')
parser.add_argument('--hidden_size', type=int, default=128, metavar='N',
                    help='hidden layer size (default: 128)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--ckpt_freq', type=int, default=100,
                    help='model saving frequency')
parser.add_argument('--display', type=bool, default=False,
                    help='display or not')
args = parser.parse_args()

# global variables
env_name = args.env_name
env = gym.make(env_name)
if type(env.action_space) != gym.spaces.discrete.Discrete:
    from reinforce_continuous import REINFORCE
    env = NormalizedActions(gym.make(env_name))
else:
    # from reinforce_discrete import REINFORCE
    raise NotImplementedError()

if args.display:
    env = wrappers.Monitor(env, '/tmp/{}-experiment'.format(env_name), force=True)

env.seed(args.seed)
mx.random.seed(args.seed)
np.random.seed(args.seed)

agent = REINFORCE(args.hidden_size, env.observation_space.shape[0], env.action_space)

dir = 'ckpt_' + env_name
if not os.path.exists(dir):
    os.mkdir(dir)

for i_episode in range(args.num_episodes):
    # state = torch.Tensor([env.reset()])
    state = nd.array([env.reset()])
    entropies = []
    log_probs = []
    rewards = []
    # roll out one episode
    for t in range(args.num_steps):
        action, log_prob, entropy = agent.select_action(state)

        # NDArray has no .numpy(); convert with .asnumpy() before stepping the env
        next_state, reward, done, _ = env.step(action.asnumpy()[0])

        entropies.append(entropy)
        log_probs.append(log_prob)
        rewards.append(reward)
        state = nd.array([next_state])

        if done:
            break

    agent.update_parameters(rewards, log_probs, entropies, args.gamma)

    # if i_episode % args.ckpt_freq == 0:
    #     torch.save(agent.model.state_dict(), os.path.join(dir, 'reinforce-' + str(i_episode) + '.pkl'))

    print("Episode: {}, reward: {}".format(i_episode, np.sum(rewards)))

env.close()

"""
░░░░░░░░░▄░░░░░░░░░░░░░░▄░░░░
░░░░░░░░▌▒█░░░░░░░░░░░▄▀▒▌░░░
░░░░░░░░▌▒▒█░░░░░░░░▄▀▒▒▒▐░░░
░░░░░░░▐▄▀▒▒▀▀▀▀▄▄▄▀▒▒▒▒▒▐░░░
░░░░░▄▄▀▒░▒▒▒▒▒▒▒▒▒█▒▒▄█▒▐░░░
░░░▄▀▒▒▒░░░▒▒▒░░░▒▒▒▀██▀▒▌░░░
░░▐▒▒▒▄▄▒▒▒▒░░░▒▒▒▒▒▒▒▀▄▒▒▌░░
░░▌░░▌█▀▒▒▒▒▒▄▀█▄▒▒▒▒▒▒▒█▒▐░░
░▐░░░▒▒▒▒▒▒▒▒▌██▀▒▒░░░▒▒▒▀▄▌░
░▌░▒▄██▄▒▒▒▒▒▒▒▒▒░░░░░░▒▒▒▒▌░
▀▒▀▐▄█▄█▌▄░▀▒▒░░░░░░░░░░▒▒▒▐░
▐▒▒▐▀▐▀▒░▄▄▒▄▒▒▒▒▒▒░▒░▒░▒▒▒▒▌
▐▒▒▒▀▀▄▄▒▒▒▄▒▒▒▒▒▒▒▒░▒░▒░▒▒▐░
░▌▒▒▒▒▒▒▀▀▀▒▒▒▒▒▒░▒░▒░▒░▒▒▒▌░
░▐▒▒▒▒▒▒▒▒▒▒▒▒▒▒░▒░▒░▒▒▄▒▒▐░░
░░▀▄▒▒▒▒▒▒▒▒▒▒▒░▒░▒░▒▄▒▒▒▒▌░░
░░░░▀▄▒▒▒▒▒▒▒▒▒▒▄▄▄▀▒▒▒▒▄▀░░░
░░░░░░▀▄▄▄▄▄▄▀▀▀▒▒▒▒▒▄▄▀░░░░░
░░░░░░░░░▒▒▒▒▒▒▒▒▒▒▀▀░░░░░░░░
"""

reinforce_continuous.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@

#-*- coding: UTF-8 -*-
"""
filename: reinforce_continuous.py
function: the REINFORCE algorithm for continuous action spaces
date: 2017/8/7
author:
[ASCII-art signature]
"""
from __future__ import print_function
import numpy as np
import mxnet as mx
from mxnet import nd, autograd, gluon

import math

# set ctx
data_ctx = mx.cpu()
model_ctx = mx.cpu()


def normal(x, mu, sigma_sq):
    # density of N(mu, sigma_sq) evaluated at x
    a = nd.exp(-1 * nd.power(x - mu, 2) / (2 * sigma_sq))
    b = 1 / nd.sqrt(2 * sigma_sq * math.pi)
    return a * b


class Policy(gluon.Block):
    def __init__(self, hidden_size, num_inputs, action_space):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]
        with self.name_scope():
            self.dense0 = gluon.nn.Dense(hidden_size)
            self.dense1 = gluon.nn.Dense(num_outputs)
            self.dense2 = gluon.nn.Dense(num_outputs)

    def forward(self, inputs):
        x = inputs
        x = nd.relu(self.dense0(x))
        mu = self.dense1(x)
        sigma_sq = self.dense2(x)

        return mu, sigma_sq


class REINFORCE:
    def __init__(self, hidden_size, num_inputs, action_space):
        self.action_space = action_space
        self.model = Policy(hidden_size, num_inputs, action_space)
        self.model.collect_params().initialize(mx.init.Normal(sigma=0.01), ctx=model_ctx)
        self.optimizer = gluon.Trainer(self.model.collect_params(), 'sgd', {'learning_rate': 0.01})

    def select_action(self, state):
        with autograd.record():
            mu, sigma_sq = self.model(state.as_in_context(model_ctx))
            # sigma_sq = nd.softrelu(sigma_sq)
            # softplus keeps the variance positive
            sigma_sq = nd.log(1 + nd.exp(sigma_sq))

            eps = nd.random.normal(0, 1, mu.shape, dtype=np.float32)
            # reparameterized sample and its probability under N(mu, sigma_sq)
            action = mu + nd.sqrt(sigma_sq) * eps
            prob = normal(action, mu, sigma_sq)

            # differential entropy of the Gaussian policy: 0.5 * (log(2*pi*sigma_sq) + 1)
            entropy = 0.5 * (nd.log(2 * math.pi * sigma_sq) + 1)
            log_prob = nd.log(prob)

        return action, log_prob, entropy

    def update_parameters(self, rewards, log_probs, entropies, gamma):
        # loss = myloss(rewards, log_probs, entropies, gamma, sample_weight=None)
        with autograd.record():
            R = nd.zeros((1, 1))
            loss = 0
            for i in reversed(range(len(rewards))):
                R = gamma * R + rewards[i]
                loss = loss - (log_probs[i] * R).sum() - (0.0001 * entropies[i]).sum()
        self.model.collect_params().zero_grad()
        loss.backward()
        grads = [i.grad(data_ctx) for i in self.model.collect_params().values()]
        # Gradient clipping. Note that the gradient here is accumulated over the whole batch,
        # so the clipping norm is effectively scaled by num_steps and batch_size.
        gluon.utils.clip_global_norm(grads, 40)
        self.optimizer.step(batch_size=len(rewards))


"""
░░░░░░░░░▄░░░░░░░░░░░░░░▄░░░░
░░░░░░░░▌▒█░░░░░░░░░░░▄▀▒▌░░░
░░░░░░░░▌▒▒█░░░░░░░░▄▀▒▒▒▐░░░
░░░░░░░▐▄▀▒▒▀▀▀▀▄▄▄▀▒▒▒▒▒▐░░░
░░░░░▄▄▀▒░▒▒▒▒▒▒▒▒▒█▒▒▄█▒▐░░░
░░░▄▀▒▒▒░░░▒▒▒░░░▒▒▒▀██▀▒▌░░░
░░▐▒▒▒▄▄▒▒▒▒░░░▒▒▒▒▒▒▒▀▄▒▒▌░░
░░▌░░▌█▀▒▒▒▒▒▄▀█▄▒▒▒▒▒▒▒█▒▐░░
░▐░░░▒▒▒▒▒▒▒▒▌██▀▒▒░░░▒▒▒▀▄▌░
░▌░▒▄██▄▒▒▒▒▒▒▒▒▒░░░░░░▒▒▒▒▌░
▀▒▀▐▄█▄█▌▄░▀▒▒░░░░░░░░░░▒▒▒▐░
▐▒▒▐▀▐▀▒░▄▄▒▄▒▒▒▒▒▒░▒░▒░▒▒▒▒▌
▐▒▒▒▀▀▄▄▒▒▒▄▒▒▒▒▒▒▒▒░▒░▒░▒▒▐░
░▌▒▒▒▒▒▒▀▀▀▒▒▒▒▒▒░▒░▒░▒░▒▒▒▌░
░▐▒▒▒▒▒▒▒▒▒▒▒▒▒▒░▒░▒░▒▒▄▒▒▐░░
░░▀▄▒▒▒▒▒▒▒▒▒▒▒░▒░▒░▒▄▒▒▒▒▌░░
░░░░▀▄▒▒▒▒▒▒▒▒▒▒▄▄▄▀▒▒▒▒▄▀░░░
░░░░░░▀▄▄▄▄▄▄▀▀▀▒▒▒▒▒▄▄▀░░░░░
░░░░░░░░░▒▒▒▒▒▒▒▒▒▒▀▀░░░░░░░░
"""
