Skip to content

Commit aab419c

Browse files
committed
neural net
1 parent b316211 commit aab419c

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

linear_reg.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,20 @@ def predict(x, y):
1414
else:
1515
file = 'model'
1616
pickle.dump(regr, open(file, 'wb'))
17-
# model = pickle.load(open('model', 'rb'))
18-
# y_pred = model.predict(x)
19-
# print("Predictions")
20-
# print(y_pred)
17+
2118

2219
def read_file(arg):
2320

2421
args = len(arg)
2522

26-
if (args > 1):
23+
if args > 1:
2724
file_name = sys.argv[1]
2825
else:
2926
file_name = 'data.csv'
3027

3128
data = pd.read_csv(file_name).squeeze()
3229

33-
if (args > 2):
30+
if args > 2:
3431
feature_count = int(sys.argv[2])
3532
else:
3633
feature_count = len(data.columns)
@@ -48,6 +45,5 @@ def main():
4845
predict(x, y)
4946

5047

51-
5248
if __name__ == '__main__':
5349
main()

neural_net.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pandas as pd
2+
import numpy as np
3+
from sklearn import linear_model
4+
import sys
5+
import pickle
6+
from sklearn.neural_network import MLPClassifier
7+
8+
9+
def nn(x, y):
10+
clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes = (5, 2), random_state = 1)
11+
clf.fit(x, y)
12+
if len(sys.argv) > 3:
13+
file = sys.argv[3]
14+
else:
15+
file = 'model'
16+
pickle.dump(clf, open(file, 'wb'))
17+
18+
19+
def read_file(arg):
20+
21+
args = len(arg)
22+
23+
if args > 1:
24+
file_name = sys.argv[1]
25+
else:
26+
file_name = 'data.csv'
27+
28+
data = pd.read_csv(file_name).squeeze()
29+
30+
if args > 2:
31+
feature_count = int(sys.argv[2])
32+
else:
33+
feature_count = len(data.columns)
34+
35+
x = data[data.columns[0:feature_count]]
36+
y = data[data.columns[feature_count:]]
37+
38+
return x, y
39+
40+
41+
def main():
42+
43+
x, y = read_file(sys.argv)
44+
45+
nn(x, y)
46+
47+
48+
if __name__ == '__main__':
49+
main()

0 commit comments

Comments
 (0)