-
Notifications
You must be signed in to change notification settings - Fork 9
/
rank_based.py
188 lines (166 loc) · 7.01 KB
/
rank_based.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: [email protected]
# description: rank-based prioritized experience replay buffer
import sys
import math
import random
import numpy as np
import binary_heap
class Experience(object):
    """
    Rank-based prioritized experience replay buffer.

    Experiences live in a dict keyed by a 1-based slot id; their priorities
    are tracked by a ``binary_heap.BinaryHeap``. Sampling distributions
    P(i) = rank(i)^(-alpha) / sum(rank^(-alpha)) are precomputed once for a
    series of growing buffer sizes ("partitions") and split into
    ``batch_size`` strata of roughly equal probability mass.
    """

    def __init__(self, conf):
        """
        :param conf: dict of settings. ``'size'`` is required. Optional keys
            (default): ``replace_old`` (True), ``priority_size`` (size),
            ``alpha`` (0.7), ``beta_zero`` (0.5), ``batch_size`` (32),
            ``learn_start`` (1000), ``steps`` (100000),
            ``partition_num`` (100)
        """
        self.size = conf['size']
        self.replace_flag = conf.get('replace_old', True)
        self.priority_size = conf.get('priority_size', self.size)

        self.alpha = conf.get('alpha', 0.7)
        self.beta_zero = conf.get('beta_zero', 0.5)
        self.batch_size = conf.get('batch_size', 32)
        self.learn_start = conf.get('learn_start', 1000)
        self.total_steps = conf.get('steps', 100000)
        # partition number N, split total size to N part
        self.partition_num = conf.get('partition_num', 100)

        # last slot written (0 means nothing has been written yet)
        self.index = 0
        # number of stored records, capped at self.size
        self.record_size = 0
        self.isFull = False

        self._experience = {}
        self.priority_queue = binary_heap.BinaryHeap(self.priority_size)
        self.distributions = self.build_distributions()

        # per-step increment that moves beta from beta_zero towards 1;
        # float() keeps this a true division even with integer settings
        self.beta_grad = (1 - self.beta_zero) / float(self.total_steps - self.learn_start)

    def build_distributions(self):
        """
        preprocess pow of rank
        (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))

        One distribution is built for every multiple of the partition size
        up to ``size``. Each holds ``'pdf'`` (list, position 0 is rank 1)
        and ``'strata_ends'`` (dict: stratum k covers ranks
        ``strata_ends[k] + 1 .. strata_ends[k + 1]``).

        :return: distributions, dict keyed by 1-based partition number
        :raises ValueError: if ``partition_num`` is larger than ``size``
        """
        res = {}
        # each part size
        partition_size = int(math.floor(1.0 * self.size / self.partition_num))
        if partition_size < 1:
            raise ValueError(
                'partition_num (%s) must not exceed size (%s)'
                % (self.partition_num, self.size))

        # every stratum should carry a total probability of 1 / batch_size
        stratum_mass = 1.0 / self.batch_size
        partition_ends = range(partition_size, self.size + 1, partition_size)
        for partition_num, n in enumerate(partition_ends, 1):
            # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
            weights = [math.pow(rank, -self.alpha) for rank in range(1, n + 1)]
            weight_sum = math.fsum(weights)
            pdf = [weight / weight_sum for weight in weights]

            # split to k segments, then uniform sample in each segment
            # strata_ends keeps each segment's boundary position
            cdf = np.cumsum(pdf)
            strata_ends = {1: 0, self.batch_size + 1: n}
            last = n - 1
            # the scan starts at position 1, but must stay inside cdf when
            # the partition holds a single rank
            index = min(1, last)
            step = stratum_mass
            for s in range(2, self.batch_size + 1):
                # bounded so rounding in cdf can never push index past the end
                while index < last and cdf[index] < step:
                    index += 1
                strata_ends[s] = index
                step += stratum_mass

            res[partition_num] = {'pdf': pdf, 'strata_ends': strata_ends}

        return res

    def fix_index(self):
        """
        get next insert index

        Slots are handed out cyclically from 1 to ``size``. When the buffer
        is full and ``replace_old`` is False, nothing is changed and -1 is
        returned.

        :return: index, int (1-based slot id, or -1 on failure)
        """
        if self.index % self.size == 0:
            # index is 0 (fresh buffer) or size (wrap-around point)
            self.isFull = len(self._experience) == self.size
            if self.isFull and not self.replace_flag:
                sys.stderr.write('Experience replay buff is full and replace is set to FALSE!\n')
                return -1
            self.index = 1
        else:
            self.index += 1

        # only count a record once a slot has actually been handed out
        if self.record_size < self.size:
            self.record_size += 1
        return self.index

    def store(self, experience):
        """
        store experience, suggest that experience is a tuple of (s1, a, r, s2, t)

        The new entry is registered in the priority queue with the value
        returned by ``priority_queue.get_max_priority()``.

        :param experience: maybe a tuple, or list
        :return: bool, indicate insert status
        """
        insert_index = self.fix_index()
        if insert_index <= 0:
            sys.stderr.write('Insert failed\n')
            return False

        # plain assignment overwrites whatever occupied the slot before
        self._experience[insert_index] = experience
        # add to priority queue
        priority = self.priority_queue.get_max_priority()
        self.priority_queue.update(priority, insert_index)
        return True

    def retrieve(self, indices):
        """
        get experience from indices

        :param indices: list of experience id
        :return: experience replay sample, list in the order of indices
        """
        return [self._experience[v] for v in indices]

    def rebalance(self):
        """
        rebalance priority queue (delegates to ``priority_queue.balance_tree()``)

        :return: None
        """
        self.priority_queue.balance_tree()

    def update_priority(self, indices, delta):
        """
        update priority according indices and deltas

        The absolute value of each delta is used as the new priority.

        :param indices: list of experience id
        :param delta: list of delta, order correspond to indices
        :return: None
        """
        for e_id, d in zip(indices, delta):
            self.priority_queue.update(math.fabs(d), e_id)

    def sample(self, global_step):
        """
        sample a mini batch from experience replay

        On failure (too few records, or no precomputed distribution for the
        current record count) a message is written to stderr and
        ``(False, False, False)`` is returned.

        :param global_step: now training step
        :return: experience, list, samples
        :return: w, numpy array, weights normalised by their maximum
        :return: rank_e_id, list, samples id, used for update priority
        """
        if self.record_size < self.learn_start:
            sys.stderr.write('Record size less than learn start! Sample failed\n')
            return False, False, False

        dist_index = int(math.floor(1.0 * self.record_size / self.size * self.partition_num))
        distribution = self.distributions.get(dist_index)
        if distribution is None:
            # fewer records than one partition: nothing was precomputed
            sys.stderr.write('No distribution for current record size! Sample failed\n')
            return False, False, False

        # N in the weight formula is the size of the active partition
        # (issue 1 by @camigord)
        partition_size = int(math.floor(1.0 * self.size / self.partition_num))
        partition_max = dist_index * partition_size

        # sample one rank from each of the k segments
        strata_ends = distribution['strata_ends']
        rank_list = []
        for n in range(1, self.batch_size + 1):
            low, high = strata_ends[n] + 1, strata_ends[n + 1]
            if low > high:
                # a stratum collapsed onto its neighbour; order the bounds
                low, high = high, low
            # ranks are 1-based, never draw rank 0
            rank_list.append(random.randint(max(low, 1), high))

        # beta, increase by global_step, max 1
        beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
        # find all alpha pow, notice that pdf is a list, start from 0
        alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
        # w = (N * P(i)) ^ (-beta) / max w
        w = np.power(np.array(alpha_pow) * partition_max, -beta)
        w = w / w.max()

        # rank list is priority id
        # convert to experience id
        rank_e_id = self.priority_queue.priority_to_experience(rank_list)
        # get experience id according rank_e_id
        experience = self.retrieve(rank_e_id)
        return experience, w, rank_e_id