forked from agramfort/datacamp-assignment1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path sklearn_questions.py
36 lines (31 loc) · 912 Bytes
/
sklearn_questions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# noqa: D100
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_is_fitted
from sklearn.utils.validation import check_array
class OneNearestNeighbor(BaseEstimator, ClassifierMixin):
    """One-nearest-neighbor classifier based on Euclidean distance.

    Each query sample receives the label of the single closest training
    sample. When several training samples are equally close, the one that
    appears first in the ``X`` passed to :meth:`fit` wins (``np.argmin``
    returns the first minimal index).

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,)
        Sorted unique labels seen during :meth:`fit`.
    n_features_in_ : int
        Number of features seen during :meth:`fit`.
    X_train_ : ndarray of shape (n_samples, n_features)
        Training samples, as returned by ``check_X_y``.
    y_train_ : ndarray of shape (n_samples,)
        Training labels, as returned by ``check_X_y``.
    """

    def __init__(self):
        """Create the classifier; it has no hyper-parameters."""

    def fit(self, X, y):
        """Memorize the training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like of shape (n_samples,)
            Target labels.

        Returns
        -------
        self : OneNearestNeighbor
            The fitted estimator.
        """
        X, y = check_X_y(X, y)
        self.classes_ = np.unique(y)
        self.n_features_in_ = X.shape[1]
        # A 1-NN model learns no parameters: the training set itself is the
        # model, so keep it for predict().
        self.X_train_ = X
        self.y_train_ = y
        return self

    def predict(self, X):
        """Predict each sample's label as that of its nearest training sample.

        Parameters
        ----------
        X : array-like of shape (n_queries, n_features)
            Samples to classify.

        Returns
        -------
        y_pred : ndarray of shape (n_queries,)
            Label of the closest training sample for each query.

        Raises
        ------
        ValueError
            If ``X`` does not have the number of features seen in :meth:`fit`.
        """
        check_is_fitted(self)
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but OneNearestNeighbor "
                f"is expecting {self.n_features_in_} features as input."
            )
        # Squared Euclidean distance has the same argmin as the true distance
        # and skips the sqrt. Handling one query at a time keeps the
        # temporary at (n_samples, n_features) rather than materializing a
        # (n_queries, n_samples, n_features) broadcast.
        nearest = np.array(
            [np.argmin(((self.X_train_ - x) ** 2).sum(axis=1)) for x in X],
            dtype=np.intp,
        )
        # Index the stored labels directly so the output keeps y's dtype;
        # pre-filling with np.full(fill_value=classes_[0]) would size a
        # string dtype to the first class and truncate longer labels.
        return self.y_train_[nearest]

    def score(self, X, y):
        """Return the mean accuracy on the given samples and labels.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples,)
            True labels for ``X``.

        Returns
        -------
        score : float
            Fraction of samples in ``X`` whose predicted label equals ``y``.
        """
        X, y = check_X_y(X, y)
        y_pred = self.predict(X)
        return np.mean(y_pred == y)