From cf4da20c7b6078dcf307946dab968ba0224afe9e Mon Sep 17 00:00:00 2001 From: Ari Crellin-Quick Date: Mon, 29 Jan 2018 16:39:08 -0800 Subject: [PATCH] Add OOB score and feature importance chart to displayed model metrics --- cesium_app/handlers/model.py | 21 ++++++++++++--- cesium_app/models.py | 2 +- package.json | 2 ++ static/js/components/FeatureImportances.jsx | 29 +++++++++++++++++++++ static/js/components/Models.jsx | 16 +++++++++--- 5 files changed, 61 insertions(+), 9 deletions(-) create mode 100644 static/js/components/FeatureImportances.jsx diff --git a/cesium_app/handlers/model.py b/cesium_app/handlers/model.py index 5aa4fe1..d924842 100644 --- a/cesium_app/handlers/model.py +++ b/cesium_app/handlers/model.py @@ -63,11 +63,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params, if params_to_optimize: model = GridSearchCV(model, params_to_optimize) model.fit(fset, data['labels']) - score = model.score(fset, data['labels']) + + metrics = {} + metrics['train_score'] = model.score(fset, data['labels']) + best_params = model.best_params_ if params_to_optimize else {} joblib.dump(model, model_path) - return score, best_params + if model_type == 'RandomForestClassifier': + if params_to_optimize: + model = model.best_estimator_ + if hasattr(model, 'oob_score_'): + metrics['oob_score'] = model.oob_score_ + if hasattr(model, 'feature_importances_'): + metrics['feature_importances'] = dict(zip( + fset.columns.get_level_values(0).tolist(), + model.feature_importances_.tolist())) + + return metrics, best_params class ModelHandler(BaseHandler): @@ -84,12 +97,12 @@ def get(self, model_id=None): @auth_or_token async def _await_model_statistics(self, model_stats_future, model): try: - score, best_params = await model_stats_future + model_metrics, best_params = await model_stats_future model = DBSession().merge(model) model.task_id = None model.finished = datetime.datetime.now() - model.train_score = score + model.metrics = model_metrics model.params.update(best_params) DBSession().commit() diff --git a/cesium_app/models.py b/cesium_app/models.py index f538d17..8a26f4f 100644 --- a/cesium_app/models.py +++ b/cesium_app/models.py @@ -89,7 +89,7 @@ class Model(Base): file_uri = sa.Column(sa.String(), nullable=True, index=True) task_id = sa.Column(sa.String()) finished = sa.Column(sa.DateTime) - train_score = sa.Column(sa.Float) + metrics = sa.Column(sa.JSON, nullable=True) featureset = relationship('Featureset') project = relationship('Project') diff --git a/package.json b/package.json index b5ce18e..95d27a9 100644 --- a/package.json +++ b/package.json @@ -9,12 +9,14 @@ "bokehjs": "^0.12.5", "bootstrap": "^3.3.7", "bootstrap-css": "^3.0.0", + "chart.js": "^2.7.1", "css-loader": "^0.26.2", "exports-loader": "^0.6.4", "imports-loader": "^0.7.1", "jquery": "^3.1.1", "prop-types": "^15.5.10", "react": "^15.1.0", + "react-chartjs-2": "^2.7.0", "react-dom": "^15.1.0", "react-redux": "^5.0.3", "react-tabs": "^0.8.2", diff --git a/static/js/components/FeatureImportances.jsx b/static/js/components/FeatureImportances.jsx new file mode 100644 index 0000000..e678cea --- /dev/null +++ b/static/js/components/FeatureImportances.jsx @@ -0,0 +1,29 @@ +import React from 'react'; +import { HorizontalBar } from 'react-chartjs-2'; + + +const FeatureImportancesBarchart = props => { + const sorted_features = Object.keys(props.data).sort( + (a, b) => props.data[b] - props.data[a]).slice(0, 15); + const values = sorted_features.map( + feature => props.data[feature].toFixed(3)); + const data = { + labels: sorted_features, + datasets: [ + { + label: 'Feature Importance', + backgroundColor: '#2222ff', + hoverBackgroundColor: '#5555ff', + data: values + } + ] + }; + + return ( +
+ +
+ ); +}; + +export default FeatureImportancesBarchart; diff --git a/static/js/components/Models.jsx b/static/js/components/Models.jsx index 1422cb8..d7734dd 100644 --- a/static/js/components/Models.jsx +++ b/static/js/components/Models.jsx @@ -11,6 +11,7 @@ import Expand from './Expand'; import Delete from './Delete'; import { $try, reformatDatetime } from '../utils'; import FoldableRow from './FoldableRow'; +import FeatureImportances from './FeatureImportances'; const ModelsTab = props => ( @@ -169,7 +170,7 @@ let ModelInfo = props => ( Model Type Hyperparameters - Training Data Score + {Object.keys(props.model.metrics).map(metric => {metric})} @@ -191,9 +192,16 @@ let ModelInfo = props => ( - - {props.model.train_score} - + { + Object.keys(props.model.metrics).map(metric => ( + + { + metric == 'feature_importances' ? + : + props.model.metrics[metric] + } + )) + }