Skip to content

Commit e0fed61

Browse files
author
Genevieve Patterson
committed
sped up classifier creation by simplifying sqlalchemy queries and subsampling features
1 parent 8854d8e commit e0fed61

File tree

4 files changed

+223
-118
lines changed

4 files changed

+223
-118
lines changed

app/classifier_views.py

Lines changed: 142 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,57 @@
11
from app import app, db
22
from flask import render_template, redirect, url_for, jsonify, send_file
3-
from app.forms import ActiveQueryForm, ClassifierForm, BlobForm, DetectForm, ClassifierEvaluateForm
4-
from app.models import User, Classifier, PatchQuery, PatchResponse, HitResponse, Estimator, Detection, Round, Dataset, ValPrediction, Blob, Patch, dataset_x_blob
3+
from app.forms import (
4+
ActiveQueryForm,
5+
ClassifierForm,
6+
BlobForm,
7+
DetectForm,
8+
ClassifierEvaluateForm,
9+
)
10+
from app.models import (
11+
User,
12+
Classifier,
13+
PatchQuery,
14+
PatchResponse,
15+
HitResponse,
16+
Estimator,
17+
Detection,
18+
Round,
19+
Dataset,
20+
ValPrediction,
21+
Blob,
22+
Patch,
23+
dataset_x_blob,
24+
)
525

626

727
import tasks
828

929
import time, itertools
1030

11-
@app.route('/classifier/')
12-
def classifier_top():
13-
classifiers = Classifier.query.filter_by(is_ready=True).order_by(Classifier.id.desc()).limit(50)
14-
return render_template('classifier_top.html',
15-
title='Classifier Library',
16-
classifiers=classifiers,
17-
next=next)
1831

19-
@app.route('/classifier/<int:id>')
@app.route("/classifier/")
def classifier_top():
    """Render the classifier library page.

    Selects ``Classifier`` rows with ``is_ready=True``, ordered by id
    descending and limited to 50, and passes that query to the
    ``classifier_top.html`` template.
    """
    ready_classifiers = (
        Classifier.query.filter_by(is_ready=True)
        .order_by(Classifier.id.desc())
        .limit(50)
    )
    # NOTE(review): no ``next`` is defined in this view or in the visible
    # module header, so this appears to hand the Python builtin ``next`` to
    # the template -- confirm whether the template actually relies on it.
    return render_template(
        "classifier_top.html",
        title="Classifier Library",
        classifiers=ready_classifiers,
        next=next,
    )
45+
46+
47+
@app.route("/classifier/<int:id>")
2048
def classifier(id, r_id=None):
2149
classifier = Classifier.query.get_or_404(id)
2250

2351
if r_id is not None:
24-
round = Round.query.filter(Round.classifier==classifier, Round.number==r_id).first()
52+
round = Round.query.filter(
53+
Round.classifier == classifier, Round.number == r_id
54+
).first()
2555
else:
2656
round = classifier.latest_round
2757

@@ -30,49 +60,59 @@ def classifier(id, r_id=None):
3060
form = ActiveQueryForm()
3161
form.classifier.data = classifier
3262
form.round.data = round
33-
classifier_title = classifier.keyword.name if classifier.keyword else "Export Classifier"
63+
classifier_title = (
64+
classifier.keyword.name if classifier.keyword else "Export Classifier"
65+
)
3466

35-
return render_template('classifier.html', classifier=classifier, round=round,
36-
title='%s %d' % (classifier_title,
37-
classifier.rounds.count()),
38-
form=form, hits = hits)
67+
return render_template(
68+
"classifier.html",
69+
classifier=classifier,
70+
round=round,
71+
title="%s %d" % (classifier_title, classifier.rounds.count()),
72+
form=form,
73+
hits=hits,
74+
)
3975

@app.route("/classifier/<int:id>/<int:i>/")
def classifier_round(id, i):
    """Show classifier ``id`` at round number ``i``.

    Thin wrapper: delegates to the ``classifier`` view with ``r_id=i`` and
    returns whatever that view returns.
    """
    return classifier(id, r_id=i)

    # return render_template( 'classifier_round.html', classifier=classifier,
    #                         title='%s %d' % (classifier_title, i),
    #                         form=form, round = round, hits = hits)
4784

@app.route("/classifier/<int:id>/download")
def classifier_dl(id):
    """Send the file recorded for the latest round of classifier ``id``.

    Responds with 404 (via ``get_or_404``) when no classifier has that id.
    """
    clf = Classifier.query.get_or_404(id)
    latest = clf.latest_round
    # NOTE(review): there is no guard for a classifier that has no latest
    # round yet -- confirm ``latest_round`` can never be None at this point.
    return send_file(latest.location)
5391

54-
@app.route('/classifier/update/<int:id>', methods = ['POST'])
92+
93+
@app.route("/classifier/update/<int:id>", methods=["POST"])
5594
def classifier_update(id):
56-
form = ActiveQueryForm();
95+
form = ActiveQueryForm()
5796

5897
if form.validate_on_submit():
59-
print('form.user.data: '+form.user.data)
60-
user = None#current_user
98+
print("form.user.data: " + form.user.data)
99+
user = None # current_user
61100

62-
if True:#user.is_anonymous and form.user.data:
101+
if True: # user.is_anonymous and form.user.data:
63102
# TODO fix when authentication is fixed
64-
user = User.query.get(1)#User.find(form.user.data)
103+
user = User.query.get(1) # User.find(form.user.data)
65104
# if user:
66105
# assert user.password == None
67106
# else:
68107
# user = User(username=form.user.data, password=None)
69108
# db.session.add(user)
70109
# user.location = form.location.data
71110
# user.nationality = form.nationality.data
72-
print('update classifier %s from %s.....' % (id, user))
111+
print("update classifier %s from %s....." % (id, user))
73112

74-
hit_resp = HitResponse(time = form.time.data,
75-
user = None if user.is_anonymous else user)
113+
hit_resp = HitResponse(
114+
time=form.time.data, user=None if user.is_anonymous else user
115+
)
76116
db.session.add(hit_resp)
77117
db.session.commit()
78118
print(str(time.time()) + ": " + str(hit_resp))
@@ -82,7 +122,7 @@ def classifier_update(id):
82122

83123
for patch, v in itertools.chain(pos_response, neg_response):
84124
pq = PatchQuery.query.filter_by(patch=patch, round=form.round.data).one()
85-
pr = PatchResponse(value=v, hitresponse=hit_resp, patchquery = pq)
125+
pr = PatchResponse(value=v, hitresponse=hit_resp, patchquery=pq)
86126
db.session.add(pr)
87127

88128
print(str(time.time()) + ": done")
@@ -96,26 +136,32 @@ def classifier_update(id):
96136
print(field.name)
97137
print(field.errors)
98138
print(field.data)
99-
return jsonify(results='success')
139+
return jsonify(results="success")
100140

# Cap passed to ``tasks.if_classifier`` for a newly created classifier; per
# the original note, a new classifier randomly samples at most this many
# features. TODO: consider exposing this through the form or app config.
NEW_CLASSIFIER_FEATURE_LIMIT = 100000


@app.route("/classifier/new", methods=["POST"])
def classifier_new():
    """Create a classifier from the submitted ``ClassifierForm``.

    When the form validates, a ``Classifier`` built from the form's dataset,
    keyword and estimator is added and committed, handed to
    ``tasks.if_classifier`` together with the feature cap, and the client is
    redirected to the new classifier's URL. Otherwise the dataset field
    errors are printed and the client is redirected to the classifier
    library.
    """
    form = ClassifierForm()

    if not form.validate_on_submit():
        print("did not validate")
        print(form.dataset.errors)
        return redirect(url_for("classifier_top"))

    new_classifier = Classifier(
        dataset=form.dataset.data,
        keyword=form.keyword.data,
        estimator=form.estimator.data,
    )
    # Commit before handing the classifier to the task, as the original did.
    db.session.add(new_classifier)
    db.session.commit()

    tasks.if_classifier(
        new_classifier,
        limited_number_of_features_to_evaluate=NEW_CLASSIFIER_FEATURE_LIMIT,
    )
    return redirect(new_classifier.url)
162+
117163

118-
@app.route('/classifier/<int:classifier_id>/<int:round_id>/evaluate/<int:dataset_id>')
164+
@app.route("/classifier/<int:classifier_id>/<int:round_id>/evaluate/<int:dataset_id>")
119165
def classifier_evaluate(classifier_id, round_id, dataset_id):
120166
classifier = Classifier.query.get_or_404(classifier_id)
121167
round = Round.query.get_or_404(round_id)
@@ -127,33 +173,44 @@ def classifier_evaluate(classifier_id, round_id, dataset_id):
127173
form.dataset.data = dataset
128174

129175
if dataset.is_train:
130-
return 'not validation set'
176+
return "not validation set"
131177

132-
#maybe instead I can check the notes in the round
133-
predicts = db.engine.execute('SELECT * from val_prediction where patch_id in\
178+
# maybe instead I can check the notes in the round
179+
predicts = db.engine.execute(
180+
"SELECT * from val_prediction where patch_id in\
134181
(select id from patch where blob_id in \
135182
(select blob_id from dataset_x_blob where dataset_id=%d)) \
136183
and round_id=%d\
137-
ORDER BY value DESC;'\
138-
% (dataset.id, round.id)).fetchmany(100)
184+
ORDER BY value DESC;"
185+
% (dataset.id, round.id)
186+
).fetchmany(100)
139187
if predicts:
140188
first_last_patches = None
141189
form.note.data = None
142190
if round.notes:
143191
rn = eval(round.notes)
144192
for note_id in rn:
145-
user_id = -1 #if current_user.is_anonymous else current_user.id
146-
if rn[note_id]["dataset"] == dataset.id and rn[note_id]["user"] == user_id:
193+
user_id = -1 # if current_user.is_anonymous else current_user.id
194+
if (
195+
rn[note_id]["dataset"] == dataset.id
196+
and rn[note_id]["user"] == user_id
197+
):
147198
first_last_patches = rn[note_id]["first_last_patches"]
148199
form.note.data = note_id
149-
return render_template('classifier_evaluate.html', predicts=predicts, form=form, eval=first_last_patches)
150-
return redirect(url_for('evaluate_top'))
200+
return render_template(
201+
"classifier_evaluate.html",
202+
predicts=predicts,
203+
form=form,
204+
eval=first_last_patches,
205+
)
206+
return redirect(url_for("evaluate_top"))
151207

152-
@app.route("/classifier/eval_range", methods=['POST'])
208+
209+
@app.route("/classifier/eval_range", methods=["POST"])
153210
def classifier_eval_range():
154211
form = ClassifierEvaluateForm()
155212
if form.validate_on_submit():
156-
user = None#current_user
213+
user = None # current_user
157214
round = form.round.data
158215
round = Round.query.get(round.id)
159216

@@ -166,28 +223,39 @@ def classifier_eval_range():
166223

167224
if note:
168225
note = int(note)
169-
tops_bottoms = {"eval_score": abs(first_incorrect.id - last_correct.id),
170-
"first_last_patches": (first_incorrect.id, last_correct.id)}
226+
tops_bottoms = {
227+
"eval_score": abs(first_incorrect.id - last_correct.id),
228+
"first_last_patches": (first_incorrect.id, last_correct.id),
229+
}
171230
new_note = {**prev_notes[note], **tops_bottoms}
172-
round.notes = str({**prev_notes, **{note:new_note}})
231+
round.notes = str({**prev_notes, **{note: new_note}})
173232
else:
174233
# pass
175234
insert_id = -1 if user.is_anonymous else user.id
176-
note = max(prev_notes)+1
177-
tops_bottoms = {note : {"user":insert_id, "dataset":dataset.id,
178-
"eval_score": abs(first_incorrect.id - last_correct.id),
179-
"first_last_patches": (first_incorrect.id, last_correct.id)}}
180-
round.notes = str({**prev_notes, **tops_bottoms}) if round.notes else str(tops_bottoms)
235+
note = max(prev_notes) + 1
236+
tops_bottoms = {
237+
note: {
238+
"user": insert_id,
239+
"dataset": dataset.id,
240+
"eval_score": abs(first_incorrect.id - last_correct.id),
241+
"first_last_patches": (first_incorrect.id, last_correct.id),
242+
}
243+
}
244+
round.notes = (
245+
str({**prev_notes, **tops_bottoms})
246+
if round.notes
247+
else str(tops_bottoms)
248+
)
181249
db.session.commit()
182250
return jsonify({"note_id": note})
183251
else:
184-
return redirect(url_for('evaluate_top'))
252+
return redirect(url_for("evaluate_top"))
185253

186254

187255
####################
188256
### Detect Views ###
189257
####################
190-
@app.route('/detect/', methods = ['GET', 'POST'])
258+
@app.route("/detect/", methods=["GET", "POST"])
191259
def detect_top():
192260
form = BlobForm()
193261
detect_form = DetectForm()
@@ -202,19 +270,27 @@ def detect_top():
202270
for d in detects:
203271
tasks.detect.delay(d.id)
204272

205-
return render_template('detect.html', title='Detect', form=form,
206-
detect_form=detect_form, detects=Detection.query.all())
273+
return render_template(
274+
"detect.html",
275+
title="Detect",
276+
form=form,
277+
detect_form=detect_form,
278+
detects=Detection.query.all(),
279+
)
280+
207281

@app.route("/detect/<int:id>/")
def detect(id):
    """Render the page for detection ``id``.

    Responds with 404 (via ``get_or_404``) when no detection has that id.
    """
    # Named ``detection`` so the local does not shadow this view function.
    detection = Detection.query.get_or_404(id)

    return render_template(
        "detection.html",
        detect=detection,
        title="Detect %d" % (detection.id),
    )
289+
214290

def make_hit(examples, patch_queries):
    """Summarise labelled examples and queries as lists of patch ids.

    Args:
        examples: iterable of objects exposing ``.value`` and ``.patch.id``.
        patch_queries: iterable of objects exposing ``.patch.id``.

    Returns:
        A dict with three lists of patch ids: ``"positives"`` (examples whose
        ``value`` is truthy), ``"negatives"`` (examples whose ``value`` is
        falsy) and ``"queries"`` (one id per patch query), each in input
        order.
    """
    positives, negatives = [], []
    # Single pass, so a one-shot iterable (generator, lazy query result) is
    # split correctly instead of being exhausted by the first scan.
    for example in examples:
        bucket = positives if example.value else negatives
        bucket.append(example.patch.id)

    return {
        "positives": positives,
        "negatives": negatives,
        "queries": [query.patch.id for query in patch_queries],
    }

0 commit comments

Comments
 (0)