|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +"""Train on training data. This will open up a file in ../trainingdata/ and get |
| 4 | +the sweet, sweet instances out of it. And then train on them and produce a |
| 5 | +maxent classifier for a given target language and source word.""" |
| 6 | + |
| 7 | +import nltk |
| 8 | +import pickle |
| 9 | +import sys |
| 10 | +import argparse |
| 11 | +from operator import itemgetter |
| 12 | + |
| 13 | +from nltk.classify.maxent import MaxentClassifier |
| 14 | + |
| 15 | +import util_run_experiment |
| 16 | +from wsd_problem import WSDProblem |
| 17 | +from parse_corpus import extract_wsd_problems |
| 18 | +import read_gold |
| 19 | +import features |
| 20 | +import stanford |
| 21 | +from co_occur import Occurrence |
| 22 | +import train_from_extracted |
| 23 | + |
| 24 | + |
def get_four_friends(target):
    """Return the set of the other four supported languages, i.e. all of
    {es, fr, nl, de, it} except *target*."""
    friends = {'es', 'fr', 'nl', 'de', 'it'} - {target}
    print(friends)
    return friends
| 30 | + |
def extend_features(features, more_features, frd1, frd2, frd3, frd4):
    """Add the four friend-language level-1 answers to *features* in place.

    For each friend language and its corresponding answer in
    *more_features* (same order as frd1..frd4), a boolean feature named
    ``friend_<lang>(<answer>)`` is set to True. Returns the mutated dict.
    """
    for lang, answer in zip((frd1, frd2, frd3, frd4), more_features):
        features["friend_{}({})".format(lang, answer)] = True
    return features
| 38 | + |
def get_training_data_from_extracted(sourceword, targetlang):
    """Return a list of (featureset, label) for training.

    Reads ../trainingdata/<sourceword>.<targetlang>.train, which stores
    instances as repeating 3-line records: context sentence, integer head
    index, then a comma-separated list of gold labels. Each label yields one
    training instance. Level-1 features are extracted per problem and then
    extended with the predictions of the four other languages' level-1
    classifiers ("friends"), producing the level-2 feature sets.
    """


    frd1,frd2,frd3,frd4 = sorted(list(get_four_friends(targetlang))) ##Get other four languages.
    classfrd1,classfrd2,classfrd3,classfrd4 = get_level1_classifiers(frd1,frd2,frd3,frd4,sourceword)

    ##Get the intersection of four training sentences.
    # NOTE(review): tool_class is never used below -- presumably a leftover
    # from an earlier "intersection" approach (see commented-out code).
    # Its constructor may still have side effects; confirm before removing.
    tool_class = Occurrence(sourceword,frd1,frd2)

    out = []
    problems = []
    fn = "../trainingdata/{0}.{1}.train".format(sourceword, targetlang)

    with open(fn) as infile:
        lines = infile.readlines()
        lines = [line.strip() for line in lines]
        # Records are interleaved in 3-line groups; slice them apart.
        contexts = [line for line in lines[0::3]]
        indices = [int(line) for line in lines[1::3]]
        labelss = [line.split(",") for line in lines[2::3]]
        assert len(contexts) == len(labelss) == len(indices)

    print("the length of them...",len(contexts),len(indices),len(labelss))
    #input()
    answers = []
    for context, index, labels in zip(contexts, indices, labelss):
        problem = WSDProblem(sourceword, context,
                             testset=False, head_index=index)

        #print(more_featuress)
        # One (problem, label) pair per gold label; empty labels come from
        # trailing commas in the label line and are skipped.
        for label in labels:
            if label == '': continue
            problems.append(problem)
            #more_features = intersection[context]
            answers.append(label)

    for problem,answer in zip(problems, answers):
        level1_features = features.extract(problem)
        # Ask each friend language's level-1 classifier for its answer and
        # fold those answers in as extra (level-2) features.
        answer_frd1 = classfrd1.classify(level1_features)
        answer_frd2 = classfrd2.classify(level1_features)
        answer_frd3 = classfrd3.classify(level1_features)
        answer_frd4 = classfrd4.classify(level1_features)
        level2_features = extend_features(level1_features,(answer_frd1,answer_frd2,answer_frd3,answer_frd4),frd1,frd2,frd3,frd4)
        label = answer
        assert(type(label) is str)
        #print("=================@@@@features {}\n@@@@label{}\n".format(featureset,label))
        out.append((level2_features, label))
    print("###Length of the output should be the same{}\n".format(len(out)))
    return out
| 88 | + |
def get_maxent_classifier(sourceword, target):
    """Build the level-2 training set for (sourceword, target) and train a
    megam-backed maxent classifier on it."""
    training = get_training_data_from_extracted(sourceword, target)
    # Singleton-label instances are dropped before training.
    training = train_from_extracted.remove_onecount_instances(training)
    print("got {0} training instances!!".format(len(training)))
    print("... training ...")
    maxent = MaxentClassifier.train(training,
                                    trace=0,
                                    max_iter=20,
                                    algorithm='megam')
    print("LABELS", maxent.labels())
    return maxent
| 100 | + |
def get_level1_answers(classifier_frd1,classifier_frd2,classifier_frd3,classifier_frd4,featureset):
    """Classify *featureset* with each of the four friend-language level-1
    classifiers and return the four predicted labels, in argument order.

    The original body referenced undefined names (``target``,
    ``classifier``), called ``classify()`` with no argument, and returned
    nothing -- it would raise NameError if ever called. This implements the
    contract its signature (and the level-2 pipeline) implies.
    """
    return (classifier_frd1.classify(featureset),
            classifier_frd2.classify(featureset),
            classifier_frd3.classify(featureset),
            classifier_frd4.classify(featureset))
| 104 | + |
def get_level1_classifiers(frd1, frd2, frd3, frd4, sourceword):
    """Load and return the pickled level-1 classifiers for the four friend
    languages, in the order the languages were passed."""
    return tuple(
        util_run_experiment.get_pickled_classifier(sourceword, lang, 'level1')
        for lang in (frd1, frd2, frd3, frd4)
    )
| 113 | + |
| 114 | + |
def train_l2_classifiers():
    """Train a level-2 classifier for every final test word and pickle it.

    The single target language is read from ``sys.argv[1]``; the word list
    comes from ``util_run_experiment.final_test_words``. Each classifier is
    written to ``../L2pickle/<word>.<target>.level2.pickle``.
    """
    all_languages = [sys.argv[1]]
    path = "../L2pickle"
    all_words = util_run_experiment.final_test_words
    # megam must be installed at this path for MaxentClassifier.train to work.
    nltk.classify.megam.config_megam(bin='/usr/local/bin/megam')

    for sourceword in all_words:
        for target in all_languages:
            level2_classifier = get_maxent_classifier(sourceword, target)
            # Fix: the original passed open(...) directly to pickle.dump and
            # leaked the file handle; 'with' guarantees flush and close.
            picklefile = "{}/{}.{}.level2.pickle".format(path, sourceword, target)
            with open(picklefile, 'wb') as outfile:
                pickle.dump(level2_classifier, outfile)
| 129 | + |
| 130 | + |
| 131 | +if __name__ == "__main__": |
| 132 | + stanford.taggerhome = '/home/liucan/stanford-postagger-2012-11-11' |
| 133 | + train_l2_classifiers() |
| 134 | + #test_level2('bank','es') |
| 135 | + #lan sys.argv[1] |
| 136 | + #get_four_friends(lan) |
| 137 | + #get_training_data_from_extracted('bank','de') |
0 commit comments