Skip to content

Commit 566ee25

Browse files
committed
documentation + error handling
1 parent 0e6f1f5 commit 566ee25

File tree

9 files changed

+310
-14
lines changed

9 files changed

+310
-14
lines changed

ada_model.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,19 @@
77

88

99
class AdaModel:
10+
"""
11+
This class represents a model based on
12+
adaboost.
13+
"""
1014
def __init__(self, train_file="./in/train.dat", test_file="./in/test.dat",
1115
out_file="./out/ensemble.ab"):
16+
"""
17+
Initialize the model.
1218
19+
:param train_file: training data
20+
:param test_file: test data
21+
:param out_file: output data
22+
"""
1323
files = (train_file, test_file)
1424
lines = parse(files)
1525

@@ -19,6 +29,11 @@ def __init__(self, train_file="./in/train.dat", test_file="./in/test.dat",
1929
self.tree = None
2030

2131
def train(self, ensemble_size=5):
32+
"""
33+
Learns an ensemble using adaboost and
34+
saves the model to a file.
35+
"""
36+
2237
examples = self.data["train"]
2338
features = set(examples[0].features.keys())
2439
sample = WeightedSample(examples)
@@ -50,11 +65,16 @@ def train(self, ensemble_size=5):
5065
pickle.dump(self, f)
5166
f.close()
5267

53-
def test(self, h_file=None):
68+
def test(self, test_file=None):
69+
"""
70+
Tests the model.
71+
72+
:param test_file: test data
73+
"""
5474
if not self.ensemble:
5575
self.train()
5676

57-
examples = parse([h_file])[0] if h_file else self.data["test"]
77+
examples = parse([test_file])[0] if test_file else self.data["test"]
5878
result = []
5979

6080
for ex in examples:
@@ -65,6 +85,13 @@ def test(self, h_file=None):
6585
evaluate(result, examples)
6686

6787
def vote(self, instance):
88+
"""
89+
Classifies an instance by collecting
90+
votes from the ensemble
91+
92+
:param instance: instance to classify
93+
:return: classification
94+
"""
6895
count = {}
6996
max_count = 0
7097
winner = None
@@ -84,12 +111,18 @@ def vote(self, instance):
84111
return winner
85112

86113

87-
def evaluate(result, examples):
114+
def evaluate(results, examples):
115+
"""
116+
Evaluates results from a model.
117+
118+
:param results: list of results
119+
:param examples: test data
120+
"""
88121
correct = 0
89122

90123
print()
91124

92-
for res in result:
125+
for res in results:
93126
if res["result"] == res["goal"]:
94127
correct += 1
95128
else:
@@ -102,6 +135,9 @@ def evaluate(result, examples):
102135

103136

104137
def main():
138+
"""
139+
Main function. (Test)
140+
"""
105141
model = AdaModel()
106142
model.train(5)
107143
model.test()

classify.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66

77

88
def train(examples, out_file, learner):
9+
"""
10+
Train a learner on some examples and saves
11+
the resulting model to a file.
12+
13+
:param examples: training data
14+
:param out_file: output file
15+
:param learner: "ada" or "dt"
16+
17+
:return:
18+
"""
919
if learner == "dt":
1020
model = DecisionModel(train_file=examples, out_file=out_file)
1121
else:
@@ -15,17 +25,44 @@ def train(examples, out_file, learner):
1525

1626

1727
def predict(h_file, test_file):
28+
"""
29+
Loads a hypothesis (model) from h_file and
30+
uses it to predict the results of instances
31+
in a test file.
32+
33+
:param h_file: model file
34+
:param test_file: test file
35+
"""
1836
h_file = open(h_file, "rb")
1937
model = pickle.load(h_file)
2038

2139
h_file.close()
2240
model.test(test_file)
2341

2442

43+
def usage(train_msg=True, predict_msg=True):
44+
if train_msg:
45+
print("Usage: python3 classify.py train <examples> <hypothesisOut> <learning-type>")
46+
47+
if predict_msg:
48+
print("Usage: python3 classify.py predict <hypothesis> <file>")
49+
50+
exit(1)
51+
52+
2553
def main():
54+
"""
55+
Main function. Accepts user input.
56+
"""
57+
if len(sys.argv) < 2:
58+
usage()
59+
2660
action = sys.argv[1]
2761

2862
if action == "train":
63+
if len(sys.argv) < 5:
64+
usage(predict_msg=False)
65+
2966
examples = sys.argv[2]
3067
out_file = sys.argv[3]
3168
learner = sys.argv[4]
@@ -35,6 +72,9 @@ def main():
3572
print("Done.")
3673

3774
elif action == "predict":
75+
if len(sys.argv) < 4:
76+
usage(train_msg=False)
77+
3878
h_file = sys.argv[2]
3979
test_file = sys.argv[3]
4080

d_model.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,20 @@
55

66

77
class DecisionModel:
8+
"""
9+
This class represents a model based on
10+
decision trees.
11+
"""
12+
813
def __init__(self, train_file="./in/train.dat", test_file="./in/test.dat",
914
out_file="./out/tree.dt"):
15+
"""
16+
Initialize the model
17+
18+
:param train_file: training data
19+
:param test_file: test data
20+
:param out_file: output data
21+
"""
1022
files = (train_file, test_file)
1123
lines = parse(files)
1224

@@ -15,6 +27,10 @@ def __init__(self, train_file="./in/train.dat", test_file="./in/test.dat",
1527
self.tree = None
1628

1729
def train(self):
30+
"""
31+
Learns a decision tree and saves
32+
the model to a file.
33+
"""
1834
examples = self.data["train"]
1935
features = set(examples[0].features.keys())
2036

@@ -24,11 +40,16 @@ def train(self):
2440
pickle.dump(self, f)
2541
f.close()
2642

27-
def test(self, h_file=None):
43+
def test(self, test_file=None):
44+
"""
45+
Tests the model.
46+
47+
:param test_file: test data
48+
"""
2849
if not self.tree:
2950
self.train()
3051

31-
examples = parse([h_file])[0] if h_file else self.data["test"]
52+
examples = parse([test_file])[0] if test_file else self.data["test"]
3253
result = []
3354

3455
for ex in examples:
@@ -39,12 +60,18 @@ def test(self, h_file=None):
3960
evaluate(result, examples)
4061

4162

42-
def evaluate(result, examples):
63+
def evaluate(results, examples):
64+
"""
65+
Evaluates results from a model.
66+
67+
:param results: list of results
68+
:param examples: test data
69+
"""
4370
correct = 0
4471

4572
print()
4673

47-
for res in result:
74+
for res in results:
4875
if res["result"] == res["goal"]:
4976
correct += 1
5077
else:
@@ -57,6 +84,9 @@ def evaluate(result, examples):
5784

5885

5986
def main():
87+
"""
88+
Main function. (Test)
89+
"""
6090
model = DecisionModel()
6191
model.train()
6292
model.test()

0 commit comments

Comments
 (0)