From a7e8bc9889e045866fafa1244f14408a9857d4d6 Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Fri, 14 Apr 2017 11:02:38 -0700 Subject: [PATCH] Update calls to save/load featureset to newer cesium API --- cesium_app/handlers/prediction.py | 19 +++++++++++-------- cesium_app/models.py | 26 ++++++++++++++------------ cesium_app/tests/fixtures.py | 7 ++++--- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/cesium_app/handlers/prediction.py b/cesium_app/handlers/prediction.py index a36b3af..aa20b07 100644 --- a/cesium_app/handlers/prediction.py +++ b/cesium_app/handlers/prediction.py @@ -115,15 +115,15 @@ def post(self): model_or_gridcv) preds = executor.submit(lambda fset, model: model.predict(fset), imputed_fset, model_data) - pred_probs = executor.submit(lambda fset, model: model.predict_proba(fset) + pred_probs = executor.submit(lambda fset, model: + pd.DataFrame(model.predict_proba(fset), + index=fset.index, + columns=model.classes_) if hasattr(model, 'predict_proba') else [], imputed_fset, model_data) - all_classes = executor.submit(lambda model: model.classes_ - if hasattr(model, 'classes_') else [], - model_data) future = executor.submit(featurize.save_featureset, imputed_fset, pred_path, labels=all_labels, preds=preds, - pred_probs=pred_probs, all_classes=all_classes) + pred_probs=pred_probs) prediction.task_id = future.key prediction.save() @@ -182,9 +182,12 @@ def post(self): features_to_use=features_to_use, meta_features=meta_feats) fset = featurize.impute_featureset(fset, **impute_kwargs) - data = {'preds': model_data.predict(fset), - 'all_classes': model_data.classes_} + data = {'preds': model_data.predict(fset)} if hasattr(model_data, 'predict_proba'): - data['pred_probs'] = model_data.predict_proba(fset) + data['pred_probs'] = pd.DataFrame(model_data.predict_proba(fset), + index=fset.index, + columns=model_data.classes_) + else: + data['pred_probs'] = [] pred_info = Prediction.format_pred_data(fset, data) return self.success(pred_info) diff --git a/cesium_app/models.py b/cesium_app/models.py index 52111ff..09851de 100644 --- a/cesium_app/models.py +++ b/cesium_app/models.py @@ -3,7 +3,7 @@ import os import sys import time -import numpy as np +import pandas as pd import peewee as pw from playhouse.postgres_ext import ArrayField, BinaryJSONField @@ -73,6 +73,7 @@ class File(BaseModel): name = pw.CharField(null=True) created = pw.DateTimeField(default=datetime.datetime.now) + @signals.post_delete(sender=File) def remove_file_after_delete(sender, instance): try: @@ -135,6 +136,7 @@ class Meta: (('dataset', 'file'), True), ) + @signals.pre_delete(sender=Dataset) def remove_related_files(sender, instance): for f in instance.files: @@ -148,7 +150,7 @@ class Featureset(BaseModel): name = pw.CharField() created = pw.DateTimeField(default=datetime.datetime.now) features_list = ArrayField(pw.CharField) - custom_features_script = pw.CharField(null=True) # move to fset file? + custom_features_script = pw.CharField(null=True) # move to fset file? file = pw.ForeignKeyField(File, on_delete='CASCADE') task_id = pw.CharField(null=True) finished = pw.DateTimeField(null=True) @@ -194,16 +196,15 @@ def is_owned_by(self, username): def format_pred_data(fset, data): fset.columns = fset.columns.droplevel('channel') fset.index = fset.index.astype(str) # can't use ints as JSON keys - result = {} - for i, name in enumerate(fset.index): - result[name] = {'features': fset.loc[name].to_dict()} - if 'labels' in data: - result[name]['label'] = data['labels'][i] - if len(data['pred_probs']) > 0: - result[name]['prediction'] = dict(zip(data['all_classes'], - data['pred_probs'][i])) - else: - result[name]['prediction'] = data['preds'][i] + labels = pd.Series(data.get('labels'), index=fset.index) + if len(data.get('pred_probs', [])) > 0: + preds = pd.DataFrame(data.get('pred_probs', []), + index=fset.index).to_dict(orient='index') + else: + preds = pd.Series(data['preds'], index=fset.index).to_dict() + result = {name: {'features': feats, 'label': labels.loc[name], + 'prediction': preds[name]} + for name, feats in fset.to_dict(orient='index').items()} return result def display_info(self): @@ -238,6 +239,7 @@ def create_tables(retry=5): print('Could not connect to database...sleeping 5') time.sleep(5) + def drop_tables(): db.drop_tables(models, safe=True, cascade=True) diff --git a/cesium_app/tests/fixtures.py b/cesium_app/tests/fixtures.py index 493ac29..fd18d7b 100644 --- a/cesium_app/tests/fixtures.py +++ b/cesium_app/tests/fixtures.py @@ -14,6 +14,7 @@ import peewee import datetime import joblib +import pandas as pd @contextmanager @@ -160,14 +161,14 @@ def create_test_prediction(dataset, model): if hasattr(model_data, 'best_estimator_'): model_data = model_data.best_estimator_ preds = model_data.predict(fset) - pred_probs = (model_data.predict_proba(fset) + pred_probs = (pd.DataFrame(model_data.predict_proba(fset), + index=fset.index, columns=model_data.classes_) if hasattr(model_data, 'predict_proba') else []) all_classes = model_data.classes_ if hasattr(model_data, 'classes_') else [] pred_path = pjoin(cfg['paths']['predictions_folder'], '{}.npz'.format(str(uuid.uuid4()))) featurize.save_featureset(fset, pred_path, labels=data['labels'], - preds=preds, pred_probs=pred_probs, - all_classes=all_classes) + preds=preds, pred_probs=pred_probs) f, created = m.File.get_or_create(uri=pred_path) pred = m.Prediction.create(file=f, dataset=dataset, project=dataset.project, model=model, finished=datetime.datetime.now())