diff --git a/nagisa/model.py b/nagisa/model.py index 611f7c8..1bd30bf 100644 --- a/nagisa/model.py +++ b/nagisa/model.py @@ -150,38 +150,38 @@ def score_sentence(self, observations, tags): score = score + dy.pick(self.trans[self.sp_e], tags[-1]) return score - - def viterbi_decoding(self, observations): - backpointers = [] - init_vvars = [-1e10] * self.dim_output - init_vvars[self.sp_s] = 0 - for_expr = dy.inputVector(init_vvars) - trans_exprs = [self.trans[idx] for idx in range(self.dim_output)] - for obs in observations: - bptrs_t = [] - vvars_t = [] - for next_tag in range(self.dim_output): - next_tag_expr = for_expr + trans_exprs[next_tag] - next_tag_arr = next_tag_expr.npvalue() - best_tag_id = np.argmax(next_tag_arr) - bptrs_t.append(best_tag_id) - vvars_t.append(dy.pick(next_tag_expr, best_tag_id)) - for_expr = dy.concatenate(vvars_t) + obs - backpointers.append(bptrs_t) - terminal_expr = for_expr + trans_exprs[self.sp_e] - terminal_arr = terminal_expr.npvalue() - best_tag_id = np.argmax(terminal_arr) - path_score = dy.pick(terminal_expr, best_tag_id) - best_path = [best_tag_id] - for bptrs_t in reversed(backpointers): - best_tag_id = bptrs_t[best_tag_id] - best_path.append(best_tag_id) - start = best_path.pop() - best_path.reverse() - if not start == self.sp_s: - raise AssertionError("start != self.sp_s") - return best_path, path_score - + # Nagisa does not use this method. + # + # def viterbi_decoding(self, observations): + # backpointers = [] + # init_vvars = [-1e10] * self.dim_output + # init_vvars[self.sp_s] = 0 + # for_expr = dy.inputVector(init_vvars) + # trans_exprs = [self.trans[idx] for idx in range(self.dim_output)] + # for obs in observations: + # bptrs_t = [] + # vvars_t = [] + # for next_tag in range(self.dim_output): + # next_tag_expr = for_expr + trans_exprs[next_tag] + # next_tag_arr = next_tag_expr.npvalue() + # best_tag_id = np.argmax(next_tag_arr) + # bptrs_t.append(best_tag_id) + # vvars_t.append(dy.pick(next_tag_expr, best_tag_id)) + # for_expr = dy.concatenate(vvars_t) + obs + # backpointers.append(bptrs_t) + # terminal_expr = for_expr + trans_exprs[self.sp_e] + # terminal_arr = terminal_expr.npvalue() + # best_tag_id = np.argmax(terminal_arr) + # path_score = dy.pick(terminal_expr, best_tag_id) + # best_path = [best_tag_id] + # for bptrs_t in reversed(backpointers): + # best_tag_id = bptrs_t[best_tag_id] + # best_path.append(best_tag_id) + # start = best_path.pop() + # best_path.reverse() + # if not start == self.sp_s: + # raise AssertionError("start != self.sp_s") + # return best_path, path_score def encode_pt(self, X, train=False): dy.renew_cg()