Skip to content

Commit

Permalink
Merge pull request #188 from bnaul/featureset_df_save
Browse files Browse the repository at this point in the history
Update calls to save/load featureset to newer cesium API
  • Loading branch information
bnaul committed Apr 17, 2017
2 parents 1436f95 + a7e8bc9 commit 69e4ba2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 23 deletions.
19 changes: 11 additions & 8 deletions cesium_app/handlers/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ def post(self):
model_or_gridcv)
preds = executor.submit(lambda fset, model: model.predict(fset),
imputed_fset, model_data)
pred_probs = executor.submit(lambda fset, model: model.predict_proba(fset)
pred_probs = executor.submit(lambda fset, model:
pd.DataFrame(model.predict_proba(fset),
index=fset.index,
columns=model.classes_)
if hasattr(model, 'predict_proba') else [],
imputed_fset, model_data)
all_classes = executor.submit(lambda model: model.classes_
if hasattr(model, 'classes_') else [],
model_data)
future = executor.submit(featurize.save_featureset, imputed_fset,
pred_path, labels=all_labels, preds=preds,
pred_probs=pred_probs, all_classes=all_classes)
pred_probs=pred_probs)

prediction.task_id = future.key
prediction.save()
Expand Down Expand Up @@ -182,9 +182,12 @@ def post(self):
features_to_use=features_to_use,
meta_features=meta_feats)
fset = featurize.impute_featureset(fset, **impute_kwargs)
data = {'preds': model_data.predict(fset),
'all_classes': model_data.classes_}
data = {'preds': model_data.predict(fset)}
if hasattr(model_data, 'predict_proba'):
data['pred_probs'] = model_data.predict_proba(fset)
data['pred_probs'] = pd.DataFrame(model_data.predict_proba(fset),
index=fset.index,
columns=model_data.classes_)
else:
data['pred_probs'] = []
pred_info = Prediction.format_pred_data(fset, data)
return self.success(pred_info)
26 changes: 14 additions & 12 deletions cesium_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
import time
import numpy as np
import pandas as pd

import peewee as pw
from playhouse.postgres_ext import ArrayField, BinaryJSONField
Expand Down Expand Up @@ -73,6 +73,7 @@ class File(BaseModel):
name = pw.CharField(null=True)
created = pw.DateTimeField(default=datetime.datetime.now)


@signals.post_delete(sender=File)
def remove_file_after_delete(sender, instance):
try:
Expand Down Expand Up @@ -135,6 +136,7 @@ class Meta:
(('dataset', 'file'), True),
)


@signals.pre_delete(sender=Dataset)
def remove_related_files(sender, instance):
for f in instance.files:
Expand All @@ -148,7 +150,7 @@ class Featureset(BaseModel):
name = pw.CharField()
created = pw.DateTimeField(default=datetime.datetime.now)
features_list = ArrayField(pw.CharField)
custom_features_script = pw.CharField(null=True) # move to fset file?
custom_features_script = pw.CharField(null=True) # move to fset file?
file = pw.ForeignKeyField(File, on_delete='CASCADE')
task_id = pw.CharField(null=True)
finished = pw.DateTimeField(null=True)
Expand Down Expand Up @@ -194,16 +196,15 @@ def is_owned_by(self, username):
def format_pred_data(fset, data):
fset.columns = fset.columns.droplevel('channel')
fset.index = fset.index.astype(str) # can't use ints as JSON keys
result = {}
for i, name in enumerate(fset.index):
result[name] = {'features': fset.loc[name].to_dict()}
if 'labels' in data:
result[name]['label'] = data['labels'][i]
if len(data['pred_probs']) > 0:
result[name]['prediction'] = dict(zip(data['all_classes'],
data['pred_probs'][i]))
else:
result[name]['prediction'] = data['preds'][i]
labels = pd.Series(data.get('labels'), index=fset.index)
if len(data.get('pred_probs', [])) > 0:
preds = pd.DataFrame(data.get('pred_probs', []),
index=fset.index).to_dict(orient='index')
else:
preds = pd.Series(data['preds'], index=fset.index).to_dict()
result = {name: {'features': feats, 'label': labels.loc[name],
'prediction': preds[name]}
for name, feats in fset.to_dict(orient='index').items()}
return result

def display_info(self):
Expand Down Expand Up @@ -238,6 +239,7 @@ def create_tables(retry=5):
print('Could not connect to database...sleeping 5')
time.sleep(5)


def drop_tables():
db.drop_tables(models, safe=True, cascade=True)

Expand Down
7 changes: 4 additions & 3 deletions cesium_app/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import peewee
import datetime
import joblib
import pandas as pd


@contextmanager
Expand Down Expand Up @@ -160,14 +161,14 @@ def create_test_prediction(dataset, model):
if hasattr(model_data, 'best_estimator_'):
model_data = model_data.best_estimator_
preds = model_data.predict(fset)
pred_probs = (model_data.predict_proba(fset)
pred_probs = (pd.DataFrame(model_data.predict_proba(fset),
index=fset.index, columns=model_data.classes_)
if hasattr(model_data, 'predict_proba') else [])
all_classes = model_data.classes_ if hasattr(model_data, 'classes_') else []
pred_path = pjoin(cfg['paths']['predictions_folder'],
'{}.npz'.format(str(uuid.uuid4())))
featurize.save_featureset(fset, pred_path, labels=data['labels'],
preds=preds, pred_probs=pred_probs,
all_classes=all_classes)
preds=preds, pred_probs=pred_probs)
f, created = m.File.get_or_create(uri=pred_path)
pred = m.Prediction.create(file=f, dataset=dataset, project=dataset.project,
model=model, finished=datetime.datetime.now())
Expand Down

0 comments on commit 69e4ba2

Please sign in to comment.