Skip to content

Commit

Permalink
Fixed beamsearch
Browse files (browse the repository at this point in the history)
  • Loading branch information
Div99 committed May 19, 2018
1 parent 03eccb9 commit e1fe6a6
Show file tree
Hide file tree
Showing 2 changed files with 894 additions and 40 deletions.
22 changes: 10 additions & 12 deletions CapGenerator/eval_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pickle import load
from numpy import argmax
import numpy as np
from keras.preprocessing.sequence import pad_sequences
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
Expand Down Expand Up @@ -53,29 +53,24 @@ def generate_desc(model, tokenizer, photo, index_word, max_length, beam_size=10)
# pad input
sequence = pad_sequences([sequence], maxlen=max_length)
# predict next words
yhat = model.predict([photo,sequence], verbose=0)
y_pred = model.predict([photo,sequence], verbose=0)[0]
# pick the indices of the beam_size most probable next words
yhats = np.argsort(yhat)[-beam_size:]
yhats = np.argsort(y_pred)[-beam_size:]

for j in yhats:
# map integer to word
word = index_word.get(yhat)
word = index_word.get(j)
# stop if we cannot map the word
if word is None:
continue
# Add word to caption, and generate negative log prob
caption = [sentence + ' ' + word, score - np.log(yhat[j])]
# Add word to caption, and generate log prob
caption = [sentence + ' ' + word, score + np.log(y_pred[j])]
all_caps.append(caption)

# order all candidates by score
ordered = sorted(all_caps, key=lambda tup:tup[1], reverse=True)
captions = ordered[:beam_size]

# append as input for generating the next word
in_text += ' ' + word
# stop if we predict the end of the sequence

break
return captions

# evaluate the skill of the model
Expand Down Expand Up @@ -125,7 +120,10 @@ def evaluate_model(model, descriptions, photos, tokenizer, index_word, max_lengt
# generate description
captions = generate_desc(model, tokenizer, photo, index_word, max_length)
for cap in captions:
print('{}. log_prob: {}'.format(cap[0],cap[1]))
# remove start and end tokens
seq = cap[0].split()[1:-1]
desc = ' '.join(seq)
print('{} [log prob: {:1.2f}]'.format(desc,cap[1]))
else:
# load test set
test_features, test_descriptions = ld.prepare_dataset('test')[1]
Expand Down
Loading

0 comments on commit e1fe6a6

Please sign in to comment.