Skip to content

Commit 33a573e

Browse files
author
Can Liu
committed
train l2 on predicted translations
1 parent 264ee03 commit 33a573e

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

src/l2_train_on_predict.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#!/usr/bin/env python3
2+
3+
"""Train on training data. This will open up a file in ../trainingdata/ and get
4+
the sweet, sweet instances out of it. And then train on them and produce a
5+
maxent classifier for a given target language and source word."""
6+
7+
import nltk
8+
import pickle
9+
import sys
10+
import argparse
11+
from operator import itemgetter
12+
13+
from nltk.classify.maxent import MaxentClassifier
14+
15+
import util_run_experiment
16+
from wsd_problem import WSDProblem
17+
from parse_corpus import extract_wsd_problems
18+
import read_gold
19+
import features
20+
import stanford
21+
from co_occur import Occurrence
22+
import train_from_extracted
23+
24+
25+
def get_four_friends(target):
    """Return the set of the four languages other than *target*.

    The full pool is {es, fr, nl, de, it}; the result is printed for
    debugging before being returned.
    """
    others = {'es', 'fr', 'nl', 'de', 'it'} - {target}
    print(others)
    return others
30+
31+
def extend_features(features, more_features, frd1, frd2, frd3, frd4):
    """Add the four level-1 friend predictions as boolean features.

    ``more_features[i]`` is the label predicted for the i-th friend
    language (frd1..frd4, in order).  Mutates *features* in place and
    returns it.
    """
    for friend, predicted in zip((frd1, frd2, frd3, frd4), more_features):
        features["friend_{}({})".format(friend, predicted)] = True

    return features
38+
39+
def get_training_data_from_extracted(sourceword, targetlang):
    """Return a list of (featureset, label) for training.

    Reads ../trainingdata/{sourceword}.{targetlang}.train, which stores one
    instance per three lines (context sentence, head-word index, comma-joined
    gold labels).  Each instance's level-1 features are extended with the
    predictions of the four other-language ("friend") level-1 classifiers,
    producing the level-2 training featuresets.
    """

    frd1,frd2,frd3,frd4 = sorted(list(get_four_friends(targetlang))) ##Get other four languages.
    classfrd1,classfrd2,classfrd3,classfrd4 = get_level1_classifiers(frd1,frd2,frd3,frd4,sourceword)

    ##Get the intersection of four training sentences.
    # NOTE(review): tool_class is never used below — looks like leftover
    # scaffolding; confirm Occurrence() has no needed constructor side effects
    # before removing.
    tool_class = Occurrence(sourceword,frd1,frd2)

    out = []
    problems = []
    fn = "../trainingdata/{0}.{1}.train".format(sourceword, targetlang)

    with open(fn) as infile:
        lines = infile.readlines()
    lines = [line.strip() for line in lines]
    # Instances are stored in groups of three lines: context / index / labels.
    contexts = [line for line in lines[0::3]]
    indices = [int(line) for line in lines[1::3]]
    labelss = [line.split(",") for line in lines[2::3]]
    assert len(contexts) == len(labelss) == len(indices)

    print("the length of them...",len(contexts),len(indices),len(labelss))
    #input()
    answers = []
    for context, index, labels in zip(contexts, indices, labelss):
        problem = WSDProblem(sourceword, context,
                             testset=False, head_index=index)

        #print(more_featuress)
        # One instance may carry several gold labels; emit the problem once
        # per non-empty label so problems[] and answers[] stay parallel.
        for label in labels:
            if label == '': continue
            problems.append(problem)
            #more_features = intersection[context]
            answers.append(label)

    for problem,answer in zip(problems, answers):
        level1_features = features.extract(problem)
        # Predict with each friend language's level-1 classifier.
        answer_frd1 = classfrd1.classify(level1_features)
        answer_frd2 = classfrd2.classify(level1_features)
        answer_frd3 = classfrd3.classify(level1_features)
        answer_frd4 = classfrd4.classify(level1_features)
        # extend_features mutates level1_features in place and returns it,
        # so level2_features aliases the same dict.
        level2_features = extend_features(level1_features,(answer_frd1,answer_frd2,answer_frd3,answer_frd4),frd1,frd2,frd3,frd4)
        label = answer
        assert(type(label) is str)
        #print("=================@@@@features {}\n@@@@label{}\n".format(featureset,label))
        out.append((level2_features, label))
    print("###Length of the output should be the same{}\n".format(len(out)))
    return out
88+
89+
def get_maxent_classifier(sourceword, target):
    """Train and return the level-2 maxent classifier for this word/language.

    Builds the level-2 training set from the extracted training data, filters
    it with the shared train_from_extracted helper, and trains with the megam
    backend (20 iterations, no trace output).
    """
    training = get_training_data_from_extracted(sourceword, target)
    training = train_from_extracted.remove_onecount_instances(training)

    print("got {0} training instances!!".format(len(training)))
    print("... training ...")

    maxent = MaxentClassifier.train(training,
                                    trace=0,
                                    max_iter=20,
                                    algorithm='megam')
    print("LABELS", maxent.labels())
    return maxent
100+
101+
def get_level1_answers(classifier_frd1, classifier_frd2, classifier_frd3,
                       classifier_frd4, featureset):
    """Classify *featureset* with each of the four level-1 classifiers.

    Returns a 4-tuple of predicted labels, one per classifier, in argument
    order.  The previous body was broken: it referenced undefined names
    (``target``, ``classifier``), called ``classify()`` with no argument,
    ignored its own parameters, and returned nothing.
    """
    return (classifier_frd1.classify(featureset),
            classifier_frd2.classify(featureset),
            classifier_frd3.classify(featureset),
            classifier_frd4.classify(featureset))
104+
105+
def get_level1_classifiers(frd1, frd2, frd3, frd4, sourceword):
    """Load the pickled level-1 classifiers for the four friend languages.

    Returns a 4-tuple of classifiers in (frd1, frd2, frd3, frd4) order.
    """
    c1, c2, c3, c4 = [
        util_run_experiment.get_pickled_classifier(sourceword, lang, 'level1')
        for lang in (frd1, frd2, frd3, frd4)
    ]
    return c1, c2, c3, c4
113+
114+
115+
def train_l2_classifiers():
    """Train and pickle a level-2 classifier for every test word.

    The single target language is taken from ``sys.argv[1]``; one classifier
    per (sourceword, target) pair is written to ../L2pickle/.
    """
    all_languages = [sys.argv[1]]
    path = "../L2pickle"
    all_words = util_run_experiment.final_test_words
    # MaxentClassifier.train(algorithm='megam') needs the megam binary.
    nltk.classify.megam.config_megam(bin='/usr/local/bin/megam')

    for sourceword in all_words:
        for target in all_languages:
            level2_classifier = get_maxent_classifier(sourceword, target)
            # Fix: the original passed open(...) directly to pickle.dump and
            # never closed the handle; use a with-block so the file is flushed
            # and closed even on error.
            outfn = "{}/{}.{}.level2.pickle".format(path, sourceword, target)
            with open(outfn, 'wb') as outfile:
                pickle.dump(level2_classifier, outfile)
129+
130+
131+
if __name__ == "__main__":
    # Machine-specific tagger location; must be set before any feature
    # extraction that shells out to the Stanford POS tagger.
    stanford.taggerhome = '/home/liucan/stanford-postagger-2012-11-11'
    train_l2_classifiers()

0 commit comments

Comments
 (0)