-
Notifications
You must be signed in to change notification settings - Fork 1
/
number_classifier.py
59 lines (44 loc) · 2.18 KB
/
number_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Iterable

import numpy as np
from sklearn.base import TransformerMixin

from help_functions import data_retriever
from help_functions.validate_classifier import validate_model
class FeatureExtractor(TransformerMixin):
    """
    Extracts features from the bitmap.

    Skeleton transformer: both ``fit`` and ``transform`` currently raise
    ``NotImplementedError`` and must be filled in before the pipeline in
    this module can run.
    """

    def fit(self, bitmaps: Iterable, *others) -> "FeatureExtractor":
        """
        Use the data to prepare for any transformation.
        :param bitmaps: A list (or other iterable) of bitmaps.
        :param others: Stuff other modules might need.
        :return: The Transformer itself. This allows for method-chaining.
        :raises NotImplementedError: Until the method is implemented.
        """
        raise NotImplementedError('FeatureExtractor: You may store some relevant information about the data set here.')

    def transform(self, bitmaps: Iterable, *others):
        """
        Transform the bitmaps to wanted representation.
        :param bitmaps: A list (or other iterable) of bitmaps.
        :param others: Stuff other modules might need.
        :return: The extracted features.
        :raises NotImplementedError: Until the method is implemented.
        """
        raise NotImplementedError('FeatureExtractor: Implement transform method.')
def split_and_shuffle_data_set(data: np.ndarray, labels: np.ndarray, train_proportion: float = 0.8,
                               random_state=None):
    """
    Shuffle the samples and split them into a training and a test set.

    Data and labels are shuffled with the same permutation, so row ``i`` of a
    returned data array still belongs to element ``i`` of the matching labels.

    :param data: Samples; the first axis indexes samples.
    :param labels: One label per sample (same length as ``data``).
    :param train_proportion: Fraction of samples placed in the training set,
        in the closed interval [0, 1].
    :param random_state: Optional seed (or ``np.random.Generator``) making the
        shuffle reproducible. ``None`` gives a fresh random shuffle.
    :return: ``(training_data, test_data, training_labels, test_labels)``.
    :raises ValueError: If the lengths differ or ``train_proportion`` is out of range.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if len(data) != len(labels):
        raise ValueError(
            f"split_and_shuffle_data_set: data has {len(data)} samples but labels has {len(labels)}")
    if not 0.0 <= train_proportion <= 1.0:
        raise ValueError(
            f"split_and_shuffle_data_set: train_proportion must be in [0, 1], got {train_proportion}")

    rng = np.random.default_rng(random_state)
    order = rng.permutation(len(data))
    # round() guards against float artifacts such as 0.29 * 100 == 28.999...
    n_train = int(round(len(data) * train_proportion))
    train_idx, test_idx = order[:n_train], order[n_train:]
    return data[train_idx], data[test_idx], labels[train_idx], labels[test_idx]
def train_classifier(training_features, training_labels):
    """
    Build an estimator, fit it on the training features and return it.

    Placeholder: the estimator has not been chosen yet, so every call raises.

    :param training_features: Features produced by the feature extractor.
    :param training_labels: Labels matching ``training_features``.
    :raises NotImplementedError: Always, until an estimator is chosen.
    """
    todo_message = 'train_classifier: Choose estimator (or create your own), fit it and return it.'
    raise NotImplementedError(todo_message)
def run_number_classifier(rows: int = -1):
    """
    Run the number-classification pipeline end to end.

    Stages, each announced with a progress message: load the MNIST data via
    ``data_retriever.load_mnist``, split it with ``split_and_shuffle_data_set``,
    fit ``FeatureExtractor`` on the training data and transform both splits,
    train with ``train_classifier``, then call ``validate_model`` on the test set.

    :param rows: Row count forwarded to ``data_retriever.load_mnist``.
        -1 means retrieving the complete set. When testing, pass a lower
        value (e.g. 5000) for faster training instead of editing the source.
    """
    print('-- Executing number classification')

    print('Loading data...')
    data, labels = data_retriever.load_mnist(rows)

    print('Splitting data...')
    training_data, test_data, training_labels, test_labels = split_and_shuffle_data_set(data, labels)

    print('Extracting features...')
    extractor = FeatureExtractor()
    # Fit only on training data so nothing about the test set leaks into the features.
    extractor.fit(training_data)
    training_features = extractor.transform(training_data)
    test_features = extractor.transform(test_data)

    print('Training classifier...')
    classifier = train_classifier(training_features, training_labels)

    print('Testing classifier...')
    validate_model(classifier, test_data, test_features, test_labels, bitmap=True)