-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathmodel.py
562 lines (468 loc) · 26.1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
"""
YOLOv2 model definition, Darknet pretrained-weight loading, and the loss function described in the YOLOv2 paper.
"""
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import Reshape, Activation, Conv2D, Input, MaxPooling2D, BatchNormalization, Flatten, Dense, Lambda
from tensorflow.keras.optimizers import SGD, Adam, RMSprop
def space_to_depth_x2(x):
    """Move each 2x2 spatial block of ``x`` into the channel dimension.

    Thin wrapper around ``tf.nn.space_to_depth`` with ``block_size=2``. It is
    wrapped in a ``Lambda`` layer for the YOLOv2 passthrough (skip) branch.
    """
    rearranged = tf.nn.space_to_depth(x, block_size=2)
    return rearranged
def ConvBatchLReLu(x, filters, kernel_size, index, trainable):
    """Apply Conv2D -> BatchNormalization -> LeakyReLU(alpha=0.1) to ``x``.

    The convolution is named ``conv_<index>`` and the batch norm
    ``norm_<index>``; ``set_pretrained_weight`` looks layers up by exactly
    these names. The convolution uses stride 1, 'same' padding and no bias
    (batch normalization follows it). ``trainable`` is forwarded to both the
    conv and the batch-norm layer.

    Returns the output tensor of the LeakyReLU activation.
    """
    conv = Conv2D(filters, kernel_size,
                  strides=(1, 1),
                  padding='same',
                  name='conv_{}'.format(index),
                  use_bias=False,
                  trainable=trainable)
    norm = BatchNormalization(name='norm_{}'.format(index), trainable=trainable)
    activation = tf.keras.layers.LeakyReLU(alpha=0.1)
    return activation(norm(conv(x)))
def ConvBatchLReLu_loop(x, index, convstack, trainable):
    """Chain one ``ConvBatchLReLu`` block per entry of ``convstack``.

    Each entry is a dict with ``"filters"`` and ``"kernel_size"`` keys. Layer
    indices are assigned consecutively starting at ``index``, so the created
    layers are ``conv_<index>``, ``conv_<index + 1>``, ... (and likewise
    ``norm_...``).

    Returns the output tensor of the last block (``x`` itself if
    ``convstack`` is empty).
    """
    for layer_index, spec in enumerate(convstack, start=index):
        x = ConvBatchLReLu(x, spec["filters"], spec["kernel_size"],
                           layer_index, trainable)
    return x
def define_YOLOv2(IMAGE_H,IMAGE_W,GRID_H,GRID_W,TRUE_BOX_BUFFER,BOX,CLASS, trainable=False):
    """Build the YOLOv2 Keras model.

    Parameters
    ----------
    IMAGE_H, IMAGE_W : int
        Height and width of the 3-channel input image.
    GRID_H, GRID_W : int
        Output grid size used by the final Reshape. The Reshape requires
        GRID_H * GRID_W to equal the spatial size left after the five 2x2
        max-poolings (e.g. 13x13 for a 416x416 input, as the pooling layer
        names suggest) -- TODO confirm callers always pass consistent values.
    TRUE_BOX_BUFFER : int
        Maximum number of ground-truth boxes carried by the ``input_hack`` input.
    BOX : int
        Number of anchor boxes per grid cell.
    CLASS : int
        Number of object classes.
    trainable : bool, default False
        Forwarded to the conv and batch-norm layers 1-22. The final
        ``conv_23`` layer is created without it and so stays trainable
        (Keras default).

    Returns
    -------
    (model, true_boxes)
        ``model`` takes ``[input_image, true_boxes]`` and outputs a tensor of
        shape (batch, GRID_H, GRID_W, BOX, 4 + 1 + CLASS). ``true_boxes`` is
        the Keras ``Input`` tensor of shape (batch, 1, 1, 1, TRUE_BOX_BUFFER, 4).
    """
    # Specs for the repeated conv blocks. The trailing comment on each entry
    # is the layer index it becomes (layers conv_<n> / norm_<n>).
    convstack3to5 = [{"filters":128, "kernel_size":(3,3)}, # 3
                     {"filters":64, "kernel_size":(1,1)}, # 4
                     {"filters":128, "kernel_size":(3,3)}] # 5
    convstack6to8 = [{"filters":256, "kernel_size":(3,3)}, # 6
                     {"filters":128, "kernel_size":(1,1)}, # 7
                     {"filters":256, "kernel_size":(3,3)}] # 8
    convstack9to13 = [{"filters":512, "kernel_size":(3,3)}, # 9
                      {"filters":256, "kernel_size":(1,1)}, # 10
                      {"filters":512, "kernel_size":(3,3)}, # 11
                      {"filters":256, "kernel_size":(1,1)}, # 12
                      {"filters":512, "kernel_size":(3,3)}] # 13
    convstack14to20 = [{"filters":1024, "kernel_size":(3,3)}, # 14
                       {"filters":512, "kernel_size":(1,1)}, # 15
                       {"filters":1024, "kernel_size":(3,3)}, # 16
                       {"filters":512, "kernel_size":(1,1)}, # 17
                       {"filters":1024, "kernel_size":(3,3)}, # 18
                       {"filters":1024, "kernel_size":(3,3)}, # 19
                       {"filters":1024, "kernel_size":(3,3)}] # 20
    input_image = Input(shape=(IMAGE_H, IMAGE_W, 3),name="input_image")
    # Ground-truth boxes enter as a second model input so the loss can use
    # them; the last axis holds (x, y, w, h) for up to TRUE_BOX_BUFFER boxes.
    true_boxes = Input(shape=(1, 1, 1, TRUE_BOX_BUFFER , 4),name="input_hack")
    # MaxPooling2D strides default to pool_size, so each 2x2 pooling halves
    # H and W. The layer names record the sizes for a 416x416 input.
    # Layer 1
    x = ConvBatchLReLu(input_image,filters=32,kernel_size=(3,3),index=1,trainable=trainable)
    x = MaxPooling2D(pool_size=(2, 2),name="maxpool1_416to208")(x)
    # Layer 2
    x = ConvBatchLReLu(x,filters=64,kernel_size=(3,3),index=2,trainable=trainable)
    x = MaxPooling2D(pool_size=(2, 2),name="maxpool1_208to104")(x)
    # Layer 3 - 5
    x = ConvBatchLReLu_loop(x,3,convstack3to5,trainable)
    x = MaxPooling2D(pool_size=(2, 2),name="maxpool1_104to52")(x)
    # Layer 6 - 8
    x = ConvBatchLReLu_loop(x,6,convstack6to8,trainable)
    x = MaxPooling2D(pool_size=(2, 2),name="maxpool1_52to26")(x)
    # Layer 9 - 13
    x = ConvBatchLReLu_loop(x,9,convstack9to13,trainable)
    # Keep the layer-13 features (one pooling before the final grid) for the
    # passthrough branch below.
    skip_connection = x
    x = MaxPooling2D(pool_size=(2, 2),name="maxpool1_26to13")(x)
    # Layer 14 - 20
    x = ConvBatchLReLu_loop(x,14,convstack14to20,trainable)
    # Layer 21: passthrough branch. A 1x1 conv reduces it to 64 channels, then
    # space_to_depth (block 2) halves H and W and multiplies channels by 4 so
    # it matches the spatial size of x and can be concatenated on channels.
    skip_connection = ConvBatchLReLu(skip_connection,filters=64,
                                     kernel_size=(1,1),index=21,trainable=trainable)
    skip_connection = Lambda(space_to_depth_x2)(skip_connection)
    x = tf.keras.layers.concatenate([skip_connection, x])
    # Layer 22
    x = ConvBatchLReLu(x,filters=1024,kernel_size=(3,3),index=22,trainable=trainable)
    # Layer 23: linear 1x1 conv (with bias, no batch norm, no activation)
    # producing BOX * (4 + 1 + CLASS) channels per grid cell.
    x = Conv2D(BOX * (4 + 1 + CLASS), (1,1), strides=(1,1), padding='same', name='conv_23')(x)
    output = Reshape((GRID_H, GRID_W, BOX, 4 + 1 + CLASS),name="final_output")(x)
    # small hack to allow true_boxes to be registered when Keras build the model
    # for more information: https://github.com/fchollet/keras/issues/2790
    # The Lambda returns `output` unchanged; true_boxes is only listed so it
    # becomes part of the model graph.
    output = Lambda(lambda args: args[0],name="hack_layer")([output, true_boxes])
    model = keras.models.Model([input_image, true_boxes], output)
    return(model, true_boxes)
class WeightReader:
    """Sequential reader over a Darknet-style ``.weights`` file.

    The whole file is loaded into memory as one flat float32 array. Reading
    starts after the first ``HEADER_SIZE`` values (assumed to be the file
    header). Each ``read_bytes`` call returns the next block of values and
    advances an internal cursor.
    """

    # Number of float32-sized values skipped at the start of the file.
    HEADER_SIZE = 4

    def __init__(self, weight_file):
        # weight_file: anything accepted by ``np.fromfile`` (path or file object).
        self.all_weights = np.fromfile(weight_file, dtype='float32')
        self.reset()

    def read_bytes(self, size):
        """Return the next ``size`` float32 values and advance the cursor.

        Despite the name, ``size`` counts float32 elements, not bytes.

        Raises
        ------
        ValueError
            If fewer than ``size`` values remain. Failing here names the real
            problem (wrong file or layer layout). A silently truncated slice
            would only fail later with an unrelated-looking reshape or
            ``set_weights`` error.
        """
        size = int(size)  # callers pass numpy integers from np.prod
        start = self.offset
        end = start + size
        remaining = len(self.all_weights) - start
        if end > len(self.all_weights):
            raise ValueError(
                "weight file exhausted: requested {} values at offset {}, "
                "but only {} remain".format(size, start, remaining))
        self.offset = end
        return self.all_weights[start:end]

    def reset(self):
        """Rewind the cursor to the first value after the header."""
        self.offset = self.HEADER_SIZE
def _read_darknet_kernel(weight_reader, keras_shape):
    """Read one conv kernel from ``weight_reader`` and return it in Keras layout."""
    # The file stores kernels output-channel-major, (out, in, h, w), which is
    # the layout the transpose below assumes. Keras wants (h, w, in, out).
    # Reshaping to the explicit (out, in, h, w) shape (rather than the reversed
    # Keras shape) keeps this correct for non-square kernels too.
    kh, kw, c_in, c_out = keras_shape
    kernel = weight_reader.read_bytes(np.prod(keras_shape))
    return kernel.reshape((c_out, c_in, kh, kw)).transpose([2, 3, 1, 0])


def set_pretrained_weight(model, nb_conv, path_to_weight):
    """Load pretrained weights from ``path_to_weight`` into ``model``.

    Layers are looked up by name: ``conv_1`` ... ``conv_<nb_conv>``, plus
    ``norm_1`` ... ``norm_<nb_conv - 1>`` (the last conv has no batch norm).
    For each index, values are consumed from the file in this order:

    1. batch-norm beta, gamma, moving mean, moving variance (all but the last index);
    2. the conv bias, if the conv layer has one;
    3. the conv kernel.

    Returns the same ``model`` object, whose weights have been set in place.
    """
    weight_reader = WeightReader(path_to_weight)
    weight_reader.reset()
    for i in range(1, nb_conv + 1):
        conv_layer = model.get_layer('conv_' + str(i))
        if i < nb_conv:
            norm_layer = model.get_layer('norm_' + str(i))
            size = np.prod(norm_layer.get_weights()[0].shape)
            beta = weight_reader.read_bytes(size)
            gamma = weight_reader.read_bytes(size)
            mean = weight_reader.read_bytes(size)
            var = weight_reader.read_bytes(size)
            # Keras BatchNormalization weight order is [gamma, beta, mean, var].
            norm_layer.set_weights([gamma, beta, mean, var])
        conv_weights = conv_layer.get_weights()
        if len(conv_weights) > 1:  # conv with bias: bias is stored before the kernel
            bias = weight_reader.read_bytes(np.prod(conv_weights[1].shape))
            kernel = _read_darknet_kernel(weight_reader, conv_weights[0].shape)
            conv_layer.set_weights([kernel, bias])
        else:  # conv without bias
            kernel = _read_darknet_kernel(weight_reader, conv_weights[0].shape)
            conv_layer.set_weights([kernel])
    return model
def initialize_weight(layer, sd):
    """Overwrite ``layer``'s kernel and bias with normal noise (mean 0, std ``sd``).

    ``layer.get_weights()`` must return ``[kernel, bias]``. Samples come from
    NumPy's global random generator, with the kernel drawn before the bias.
    """
    current = layer.get_weights()
    fresh = [np.random.normal(size=current[slot].shape, scale=sd) for slot in (0, 1)]
    layer.set_weights(fresh)
def get_cell_grid(GRID_W, GRID_H, BATCH_SIZE, BOX):
    '''
    Build the (x, y) cell offsets that place predicted box centres on the grid.

    == output ==
    float32 tensor of shape (BATCH_SIZE, GRID_H, GRID_W, BOX, 2) with
        cell_grid[n, i, j, a, 0] = j   (column index -> x offset)
        cell_grid[n, i, j, a, 1] = i   (row index    -> y offset)
    for every batch index n and anchor a. For example, with 3 anchors:
        cell_grid[n, 5, 3, :, :] = [[3., 5.],
                                    [3., 5.],
                                    [3., 5.]]

    The x and y offsets are built independently, so non-square grids
    (GRID_W != GRID_H) are supported. Deriving y by transposing x only
    works when GRID_W == GRID_H.
    '''
    # With the default 'xy' indexing both outputs have shape (GRID_H, GRID_W):
    # col_idx[i, j] = j and row_idx[i, j] = i.
    col_idx, row_idx = tf.meshgrid(tf.range(GRID_W), tf.range(GRID_H))
    cell_xy = tf.cast(tf.stack([col_idx, row_idx], axis=-1), tf.float32)
    # (GRID_H, GRID_W, 2) -> (1, GRID_H, GRID_W, 1, 2), then repeat over batch and anchors.
    cell_xy = tf.reshape(cell_xy, (1, GRID_H, GRID_W, 1, 2))
    return tf.tile(cell_xy, [BATCH_SIZE, 1, 1, BOX, 1])
def adjust_scale_prediction(y_pred, cell_grid, ANCHORS):
    """
    Decode raw network outputs into box centres, sizes, confidences and class logits.

    == input ==
    y_pred    : tensor (N batch, N grid h, N grid w, N anchor, 4 + 1 + N class)
                of unconstrained real values from the final conv layer
    cell_grid : tensor broadcastable to y_pred[..., :2] holding each cell's
                (x, y) = (column, row) offset, as built by get_cell_grid
    ANCHORS   : flat sequence [w1, h1, w2, h2, ...]; its length must be
                2 * N anchor. Presumably in grid-cell units, the same scale as
                the ground-truth w/h -- verify against the anchor source.

    == output ==
    pred_box_xy    : sigmoid(y_pred[..., 0:2]) + cell_grid. Channel 0 is x,
                     channel 1 is y, so each centre lies inside its own cell.
    pred_box_wh    : exp(y_pred[..., 2:4]) * anchor (w, h). Strictly positive
                     and unbounded above.
    pred_box_conf  : sigmoid(y_pred[..., 4]), in (0, 1). The last axis is dropped.
    pred_box_class : y_pred[..., 5:] unchanged (class logits; no softmax is
                     applied here because the class loss works on logits).
    """
    n_anchor = int(len(ANCHORS) / 2)

    # Centre: sigmoid keeps the offset inside [0, 1) of the cell, then shift by the cell index.
    offset_in_cell = tf.sigmoid(y_pred[..., :2])
    pred_box_xy = offset_in_cell + cell_grid

    # Size: exp keeps it positive; each anchor scales its own slot.
    anchor_wh = np.reshape(ANCHORS, [1, 1, 1, n_anchor, 2])
    pred_box_wh = tf.exp(y_pred[..., 2:4]) * anchor_wh

    # Objectness probability and raw class scores.
    pred_box_conf = tf.sigmoid(y_pred[..., 4])
    pred_box_class = y_pred[..., 5:]

    return pred_box_xy, pred_box_wh, pred_box_conf, pred_box_class
def print_min_max(vec, title):
    """Print ``title`` with the minimum and maximum of ``vec`` (5.2f format)."""
    lowest = np.min(vec)
    highest = np.max(vec)
    message = "{} MIN={:5.2f}, MAX={:5.2f}".format(title, lowest, highest)
    print(message)
def extract_ground_truth(y_true):
    """Split ``y_true`` into box centre, box size, objectness and class index.

    Returns ``(true_box_xy, true_box_wh, true_box_conf, true_box_class)``:
      * true_box_xy    = y_true[..., 0:2]  (centre x, y in grid-cell scale)
      * true_box_wh    = y_true[..., 2:4]  (width, height in grid cells)
      * true_box_conf  = y_true[..., 4]    (last axis dropped)
      * true_box_class = argmax over the one-hot y_true[..., 5:]
    """
    box_xy = y_true[..., 0:2]
    box_wh = y_true[..., 2:4]
    box_conf = y_true[..., 4]
    one_hot_classes = y_true[..., 5:]
    box_class = tf.argmax(one_hot_classes, -1)
    return box_xy, box_wh, box_conf, box_class
def calc_loss_xywh(true_box_conf,
                   COORD_SCALE,
                   true_box_xy, pred_box_xy, true_box_wh, pred_box_wh):
    '''
    Coordinate loss: masked squared error on the box centre and on the box size.

    == input ==
    true_box_conf            : (N batch, N grid h, N grid w, N anchor); nonzero
                               only where an object is assigned to the slot
    COORD_SCALE              : scalar weight lambda_{coord}
    true_box_xy, pred_box_xy : (N batch, N grid h, N grid w, N anchor, 2)
    true_box_wh, pred_box_wh : (N batch, N grid h, N grid w, N anchor, 2)

    == output ==
    (loss, coord_mask)
      coord_mask : lambda_{coord} * L_{i,j}^{obj}, shape (..., N anchor, 1)
      loss       : [sum(mask * (xy_t - xy_p)^2) + sum(mask * (wh_t - wh_p)^2)]
                   / (n + 1e-6) / 2, where n is the number of slots with
                   mask > 0 (the 1e-6 avoids dividing by zero when no object
                   is present)
    '''
    coord_mask = tf.expand_dims(true_box_conf, axis=-1) * COORD_SCALE
    n_coord = tf.reduce_sum(tf.cast(coord_mask > 0.0, tf.float32))
    denom = n_coord + 1e-6

    xy_sq_err = tf.square(true_box_xy - pred_box_xy)
    wh_sq_err = tf.square(true_box_wh - pred_box_wh)
    xy_term = tf.reduce_sum(xy_sq_err * coord_mask) / denom / 2.
    wh_term = tf.reduce_sum(wh_sq_err * coord_mask) / denom / 2.

    return xy_term + wh_term, coord_mask
def calc_loss_class(true_box_conf, CLASS_SCALE, true_box_class, pred_box_class):
    '''
    Classification loss: softmax cross-entropy averaged over slots that hold an object.

    == input ==
    true_box_conf  : (N batch, N grid h, N grid w, N anchor); nonzero where an
                     object is assigned to the (grid cell, anchor) slot
    CLASS_SCALE    : scalar weight lambda_{class} (e.g. 1.0)
    true_box_class : (N batch, N grid h, N grid w, N anchor) integer class indices
    pred_box_class : (N batch, N grid h, N grid w, N anchor, N class) class logits

    == output ==
    scalar = sum(class_mask * CE) / (n + 1e-6), where
      class_mask = true_box_conf * CLASS_SCALE  (L_{i,j}^{obj} * lambda_{class}),
      CE is the per-slot sparse softmax cross-entropy, and
      n is the number of slots with class_mask > 0.
    '''
    class_mask = true_box_conf * CLASS_SCALE
    n_class_box = tf.reduce_sum(tf.cast(class_mask > 0.0, tf.float32))
    per_slot_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=true_box_class, logits=pred_box_class)
    weighted_total = tf.reduce_sum(per_slot_ce * class_mask)
    return weighted_total / (n_class_box + 1e-6)
# Example usage of tf.gather (TF1 session-style snippet, kept for reference):
#indices = np.array([[0,0],
# [1,0],
# [0,1]])
#arr = tf.constant(indices)
#temp = tf.gather(np.array([100,-20]), arr)
#with tf.Session() as sess:
# t = sess.run(temp)
#print(t)
#[[100 100]
# [-20 100]
# [100 -20]]
def get_intersect_area(true_xy, true_wh,
                       pred_xy, pred_wh):
    '''
    Intersection-over-union (IoU) of axis-aligned boxes given as centre + size.

    Despite its name, this returns IoU = intersection area / union area, not
    the raw intersection area.

    == input ==
    true_xy, pred_xy : tensors whose last axis is the (x, y) box centre
    true_wh, pred_wh : tensors whose last axis is the (w, h) box size
        All four must have the same rank and be broadcast-compatible. For
        example, calc_IOU_pred_true_best compares every prediction with every
        true box through size-1 axes.

    == output ==
    iou_scores : tensor of the broadcast shape without the last axis.
        Where the union area is 0 (both boxes degenerate), the IoU is 0
        instead of NaN.

    Each box spans [centre - size/2, centre + size/2] on each axis. The overlap
    on an axis is max(0, min(upper edges) - max(lower edges)).
    '''
    true_wh_half = true_wh / 2.
    true_mins = true_xy - true_wh_half
    true_maxes = true_xy + true_wh_half
    pred_wh_half = pred_wh / 2.
    pred_mins = pred_xy - pred_wh_half
    pred_maxes = pred_xy + pred_wh_half

    # Overlap rectangle; clamping at 0 gives disjoint boxes zero width/height.
    intersect_mins = tf.maximum(pred_mins, true_mins)
    intersect_maxes = tf.minimum(pred_maxes, true_maxes)
    intersect_wh = tf.maximum(intersect_maxes - intersect_mins, 0.)
    intersect_areas = intersect_wh[..., 0] * intersect_wh[..., 1]

    true_areas = true_wh[..., 0] * true_wh[..., 1]
    pred_areas = pred_wh[..., 0] * pred_wh[..., 1]
    union_areas = pred_areas + true_areas - intersect_areas

    # Slots without a true box have true_wh == 0. If pred_wh underflows to 0
    # as well (exp of a very negative logit), a plain division gives 0/0 = NaN.
    # NaN survives the later multiplication by a 0 confidence and poisons the
    # loss. divide_no_nan maps 0/0 to 0 and is identical everywhere else.
    iou_scores = tf.math.divide_no_nan(intersect_areas, union_areas)
    return iou_scores
def calc_IOU_pred_true_assigned(true_box_conf,
                                true_box_xy, true_box_wh,
                                pred_box_xy, pred_box_wh):
    '''
    IoU between each predicted box and the ground-truth box assigned to the
    same (grid cell, anchor) slot, masked by the ground-truth confidence.

    == input ==
    true_box_conf : (N batch, N grid h, N grid w, N anchor)
    true_box_xy   : (N batch, N grid h, N grid w, N anchor, 2)
    true_box_wh   : (N batch, N grid h, N grid w, N anchor, 2)
    pred_box_xy   : (N batch, N grid h, N grid w, N anchor, 2)
    pred_box_wh   : (N batch, N grid h, N grid w, N anchor, 2)

    == output ==
    (N batch, N grid h, N grid w, N anchor) tensor equal to
    IoU(true, pred) * true_box_conf. With a 0/1 confidence this is the IoU
    where an object is assigned to the slot, and 0 elsewhere.
    '''
    slot_iou = get_intersect_area(true_box_xy, true_box_wh,
                                  pred_box_xy, pred_box_wh)
    return slot_iou * true_box_conf
def calc_IOU_pred_true_best(pred_box_xy, pred_box_wh, true_boxes):
    '''
    For every predicted box, the highest IoU it reaches with any ground-truth box.

    == input ==
    pred_box_xy : (N batch, N grid h, N grid w, N anchor, 2)
    pred_box_wh : (N batch, N grid h, N grid w, N anchor, 2)
    true_boxes  : (N batch, 1, 1, 1, TRUE_BOX_BUFFER, 4), last axis (x, y, w, h)

    == output ==
    best_ious : (N batch, N grid h, N grid w, N anchor).
        best_ious[iframe, igridy, igridx, ianchor] is the largest IoU between
        that slot's predicted box and any of the frame's true boxes, regardless
        of which slot each true box is assigned to. The same true box may be
        the best match for several slots, so best_ious cannot tell how many
        distinct objects are covered.
    '''
    # Add a box axis at position 4 so each prediction broadcasts against all
    # TRUE_BOX_BUFFER true boxes:
    #   (N batch, N grid h, N grid w, N anchor, 1, 2) vs (N batch, 1, 1, 1, TRUE_BOX_BUFFER, 2)
    gt_xy = true_boxes[..., 0:2]
    gt_wh = true_boxes[..., 2:4]
    cand_xy = tf.expand_dims(pred_box_xy, 4)
    cand_wh = tf.expand_dims(pred_box_wh, 4)

    # (N batch, N grid h, N grid w, N anchor, TRUE_BOX_BUFFER) -> max over the box axis
    pairwise_iou = get_intersect_area(gt_xy, gt_wh, cand_xy, cand_wh)
    return tf.reduce_max(pairwise_iou, axis=4)
def get_conf_mask(best_ious, true_box_conf, true_box_conf_IOU, LAMBDA_NO_OBJECT, LAMBDA_OBJECT,
                  iou_threshold=0.6):
    '''
    Per-(grid cell, anchor) weights for the confidence loss.

    == input ==
    best_ious         : (Nbatch, N grid h, N grid w, N anchor) best IoU of each
                        predicted box with any true box (calc_IOU_pred_true_best)
    true_box_conf     : same shape; 1 where an object is assigned to the slot, else 0
    true_box_conf_IOU : same shape; IoU with the assigned true box, 0 where none
    LAMBDA_NO_OBJECT  : weight of the no-object term (e.g. 1.0)
    LAMBDA_OBJECT     : weight of the object term (e.g. 5.0)
    iou_threshold     : predictions whose best IoU is >= this value are not
                        penalised as background (default 0.6)

    == output ==
    conf_mask = LAMBDA_NO_OBJECT * [best_ious < iou_threshold] * (1 - true_box_conf)
              + LAMBDA_OBJECT * true_box_conf_IOU
    i.e. for conf_mask[iframe, igridy, igridx, ianchor]:
      * no object assigned, best IoU <  iou_threshold : LAMBDA_NO_OBJECT
        (prediction overlaps no real object -> pushed towards 0 confidence)
      * no object assigned, best IoU >= iou_threshold : 0
        (prediction already covers some real object -> not penalised)
      * object assigned : LAMBDA_OBJECT * IoU(assigned true box, prediction)
        (so an assigned slot whose prediction does not overlap at all gets weight 0)
    '''
    # No-object term: background slots whose prediction matches no true box well.
    conf_mask = tf.cast(best_ious < iou_threshold, tf.float32) * (1 - true_box_conf) * LAMBDA_NO_OBJECT
    # Object term: penalize the confidence of the boxes that are responsible for a ground-truth box.
    conf_mask = conf_mask + true_box_conf_IOU * LAMBDA_OBJECT
    return conf_mask
def calc_loss_conf(conf_mask, true_box_conf_IOU, pred_box_conf):
    '''
    Confidence loss: weighted squared error between the IoU target and the
    predicted confidence.

    == input ==
    conf_mask         : (Nbatch, N grid h, N grid w, N anchor) weights from get_conf_mask
    true_box_conf_IOU : (Nbatch, N grid h, N grid w, N anchor) target confidence
    pred_box_conf     : (Nbatch, N grid h, N grid w, N anchor) predicted confidence

    == output ==
    scalar = sum(conf_mask * (target - pred)^2) / (n + 1e-6) / 2, where n is
    the number of slots with a nonzero weight (slots with an assigned object
    plus background slots that are penalised).
    '''
    n_weighted = tf.reduce_sum(tf.cast(conf_mask > 0.0, tf.float32))
    sq_err = tf.square(true_box_conf_IOU - pred_box_conf)
    return tf.reduce_sum(sq_err * conf_mask) / (n_weighted + 1e-6) / 2.
def custom_loss(y_true, y_pred):
    '''
    YOLOv2 loss with the Keras ``loss(y_true, y_pred)`` signature.

    Delegates to custom_loss_core, which documents the y_true layout. The
    remaining inputs are read from module-level globals at call time and must
    be defined (e.g. by the training script) before this runs:
    true_boxes, GRID_W, GRID_H, BATCH_SIZE, ANCHORS, LAMBDA_COORD,
    LAMBDA_CLASS, LAMBDA_NO_OBJECT and LAMBDA_OBJECT.

    This version creates no tf.Variable. Creating one on every call fails once
    the loss is retraced by tf.function, and the value was never used anyway.

    NOTE: custom_loss is defined a second time later in this module, and that
    later definition is the one bound after import.
    '''
    return custom_loss_core(
        y_true,
        y_pred,
        true_boxes,
        GRID_W,
        GRID_H,
        BATCH_SIZE,
        ANCHORS,
        LAMBDA_COORD,
        LAMBDA_CLASS,
        LAMBDA_NO_OBJECT,
        LAMBDA_OBJECT)
def custom_loss_core(y_true,
                     y_pred,
                     true_boxes,
                     GRID_W,
                     GRID_H,
                     BATCH_SIZE,
                     ANCHORS,
                     LAMBDA_COORD,
                     LAMBDA_CLASS,
                     LAMBDA_NO_OBJECT,
                     LAMBDA_OBJECT):
    '''
    YOLOv2 loss = coordinate loss + confidence loss + class loss.

    == y_true layout ==
    y_true : (N batch, N grid h, N grid w, N anchor, 4 + 1 + N classes)
      [..., 0] center_x : box centre x, rescaled to [0, GRID_W) (e.g. [0, 13))
      [..., 1] center_y : box centre y, rescaled to [0, GRID_H) (e.g. [0, 13))
      [..., 2] w        : box width in grid cells, between 0 and GRID_W
      [..., 3] h        : box height in grid cells, between 0 and GRID_H
      [..., 4]          : ground-truth confidence, 1 if an object is assigned
                          to this (grid cell, anchor) pair, else 0
      [..., 5 + iclass] : 1 if the object belongs to class <iclass>, else 0

    == other inputs ==
    y_pred     : raw model output, same shape as y_true
    true_boxes : tensor connected to the YOLO model's "hack" input,
                 (N batch, 1, 1, 1, TRUE_BOX_BUFFER, 4) with (x, y, w, h)
    GRID_W, GRID_H, BATCH_SIZE : grid size and (fixed) batch size
    ANCHORS    : flat [w1, h1, w2, h2, ...]; N anchor = len(ANCHORS) / 2
    LAMBDA_*   : loss-term weights

    == example training configuration ==
    GRID_W = 13
    GRID_H = 13
    BATCH_SIZE = 34
    ANCHORS = np.array([1.07709888, 1.78171903,     # anchor box 1, width, height
                        2.71054693, 5.12469308,     # anchor box 2, width, height
                        10.47181473, 10.09646365,   # anchor box 3, width, height
                        5.48531347, 8.11011331])    # anchor box 4, width, height
    LAMBDA_NO_OBJECT = 1.0
    LAMBDA_OBJECT = 5.0
    LAMBDA_COORD = 1.0
    LAMBDA_CLASS = 1.0
    '''
    n_anchor = int(len(ANCHORS) / 2)

    # Decode predictions into box centres/sizes, confidence and class logits.
    grid = get_cell_grid(GRID_W, GRID_H, BATCH_SIZE, n_anchor)
    p_xy, p_wh, p_conf, p_class = adjust_scale_prediction(y_pred, grid, ANCHORS)

    # Split the ground truth into the matching components.
    t_xy, t_wh, t_conf, t_class = extract_ground_truth(y_true)

    # Box-coordinate and class terms only count slots with an assigned object.
    xywh_term, _ = calc_loss_xywh(t_conf, LAMBDA_COORD,
                                  t_xy, p_xy, t_wh, p_wh)
    class_term = calc_loss_class(t_conf, LAMBDA_CLASS, t_class, p_class)

    # Confidence target is the IoU with the slot's assigned box. Background
    # slots are weighted by how well they match any true box at all.
    assigned_iou = calc_IOU_pred_true_assigned(t_conf, t_xy, t_wh, p_xy, p_wh)
    best_iou = calc_IOU_pred_true_best(p_xy, p_wh, true_boxes)
    conf_weights = get_conf_mask(best_iou, t_conf, assigned_iou,
                                 LAMBDA_NO_OBJECT, LAMBDA_OBJECT)
    conf_term = calc_loss_conf(conf_weights, assigned_iou, p_conf)

    return xywh_term + conf_term + class_term
def custom_loss(y_true, y_pred):
    """Keras-compatible ``loss(y_true, y_pred)`` wrapper around ``custom_loss_core``.

    The remaining inputs are read from module-level globals at call time and
    must be defined (e.g. by the training script) before this runs:
    ``true_boxes``, ``GRID_W``, ``GRID_H``, ``BATCH_SIZE``, ``ANCHORS`` and the
    ``LAMBDA_COORD`` / ``LAMBDA_CLASS`` / ``LAMBDA_NO_OBJECT`` /
    ``LAMBDA_OBJECT`` weights.
    """
    return custom_loss_core(
        y_true,
        y_pred,
        true_boxes=true_boxes,
        GRID_W=GRID_W,
        GRID_H=GRID_H,
        BATCH_SIZE=BATCH_SIZE,
        ANCHORS=ANCHORS,
        LAMBDA_COORD=LAMBDA_COORD,
        LAMBDA_CLASS=LAMBDA_CLASS,
        LAMBDA_NO_OBJECT=LAMBDA_NO_OBJECT,
        LAMBDA_OBJECT=LAMBDA_OBJECT,
    )