Skip to content

Commit

Permalink
Removing unused viterbi_decoding in model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
taishi-i committed Jul 6, 2020
1 parent 21f2413 commit 8fdc1a7
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions nagisa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,38 +150,38 @@ def score_sentence(self, observations, tags):
score = score + dy.pick(self.trans[self.sp_e], tags[-1])
return score


def viterbi_decoding(self, observations):
backpointers = []
init_vvars = [-1e10] * self.dim_output
init_vvars[self.sp_s] = 0
for_expr = dy.inputVector(init_vvars)
trans_exprs = [self.trans[idx] for idx in range(self.dim_output)]
for obs in observations:
bptrs_t = []
vvars_t = []
for next_tag in range(self.dim_output):
next_tag_expr = for_expr + trans_exprs[next_tag]
next_tag_arr = next_tag_expr.npvalue()
best_tag_id = np.argmax(next_tag_arr)
bptrs_t.append(best_tag_id)
vvars_t.append(dy.pick(next_tag_expr, best_tag_id))
for_expr = dy.concatenate(vvars_t) + obs
backpointers.append(bptrs_t)
terminal_expr = for_expr + trans_exprs[self.sp_e]
terminal_arr = terminal_expr.npvalue()
best_tag_id = np.argmax(terminal_arr)
path_score = dy.pick(terminal_expr, best_tag_id)
best_path = [best_tag_id]
for bptrs_t in reversed(backpointers):
best_tag_id = bptrs_t[best_tag_id]
best_path.append(best_tag_id)
start = best_path.pop()
best_path.reverse()
if not start == self.sp_s:
raise AssertionError("start != self.sp_s")
return best_path, path_score

# Nagisa does not use this method.
#
# def viterbi_decoding(self, observations):
# backpointers = []
# init_vvars = [-1e10] * self.dim_output
# init_vvars[self.sp_s] = 0
# for_expr = dy.inputVector(init_vvars)
# trans_exprs = [self.trans[idx] for idx in range(self.dim_output)]
# for obs in observations:
# bptrs_t = []
# vvars_t = []
# for next_tag in range(self.dim_output):
# next_tag_expr = for_expr + trans_exprs[next_tag]
# next_tag_arr = next_tag_expr.npvalue()
# best_tag_id = np.argmax(next_tag_arr)
# bptrs_t.append(best_tag_id)
# vvars_t.append(dy.pick(next_tag_expr, best_tag_id))
# for_expr = dy.concatenate(vvars_t) + obs
# backpointers.append(bptrs_t)
# terminal_expr = for_expr + trans_exprs[self.sp_e]
# terminal_arr = terminal_expr.npvalue()
# best_tag_id = np.argmax(terminal_arr)
# path_score = dy.pick(terminal_expr, best_tag_id)
# best_path = [best_tag_id]
# for bptrs_t in reversed(backpointers):
# best_tag_id = bptrs_t[best_tag_id]
# best_path.append(best_tag_id)
# start = best_path.pop()
# best_path.reverse()
# if not start == self.sp_s:
# raise AssertionError("start != self.sp_s")
# return best_path, path_score

def encode_pt(self, X, train=False):
dy.renew_cg()
Expand Down

0 comments on commit 8fdc1a7

Please sign in to comment.