Skip to content

Commit

Permalink
Add OOB score and feature importance chart to displayed model metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
acrellin committed Jan 30, 2018
1 parent 96fb017 commit cf4da20
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 9 deletions.
21 changes: 17 additions & 4 deletions cesium_app/handlers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
if params_to_optimize:
model = GridSearchCV(model, params_to_optimize)
model.fit(fset, data['labels'])
score = model.score(fset, data['labels'])

metrics = {}
metrics['train_score'] = model.score(fset, data['labels'])

best_params = model.best_params_ if params_to_optimize else {}
joblib.dump(model, model_path)

return score, best_params
if model_type == 'RandomForestClassifier':
if params_to_optimize:
model = model.best_estimator_
if hasattr(model, 'oob_score_'):
metrics['oob_score'] = model.oob_score_
if hasattr(model, 'feature_importances_'):
metrics['feature_importances'] = dict(zip(
fset.columns.get_level_values(0).tolist(),
model.feature_importances_.tolist()))

return metrics, best_params


class ModelHandler(BaseHandler):
Expand All @@ -84,12 +97,12 @@ def get(self, model_id=None):
@auth_or_token
async def _await_model_statistics(self, model_stats_future, model):
try:
score, best_params = await model_stats_future
model_metrics, best_params = await model_stats_future

model = DBSession().merge(model)
model.task_id = None
model.finished = datetime.datetime.now()
model.train_score = score
model.metrics = model_metrics
model.params.update(best_params)
DBSession().commit()

Expand Down
2 changes: 1 addition & 1 deletion cesium_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class Model(Base):
file_uri = sa.Column(sa.String(), nullable=True, index=True)
task_id = sa.Column(sa.String())
finished = sa.Column(sa.DateTime)
train_score = sa.Column(sa.Float)
metrics = sa.Column(sa.JSON, nullable=True)

featureset = relationship('Featureset')
project = relationship('Project')
Expand Down
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
"bokehjs": "^0.12.5",
"bootstrap": "^3.3.7",
"bootstrap-css": "^3.0.0",
"chart.js": "^2.7.1",
"css-loader": "^0.26.2",
"exports-loader": "^0.6.4",
"imports-loader": "^0.7.1",
"jquery": "^3.1.1",
"prop-types": "^15.5.10",
"react": "^15.1.0",
"react-chartjs-2": "^2.7.0",
"react-dom": "^15.1.0",
"react-redux": "^5.0.3",
"react-tabs": "^0.8.2",
Expand Down
29 changes: 29 additions & 0 deletions static/js/components/FeatureImportances.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import React from 'react';
import { HorizontalBar } from 'react-chartjs-2';


const FeatureImportancesBarchart = props => {
const sorted_features = Object.keys(props.data).sort(
(a, b) => props.data[b] - props.data[a]).slice(0, 15);
const values = sorted_features.map(
feature => props.data[feature].toFixed(3));
const data = {
labels: sorted_features,
datasets: [
{
label: 'Feature Importance',
backgroundColor: '#2222ff',
hoverBackgroundColor: '#5555ff',
data: values
}
]
};

return (
<div style={{ height: 300, width: 600 }}>
<HorizontalBar data={data} />
</div>
);
};

export default FeatureImportancesBarchart;
16 changes: 12 additions & 4 deletions static/js/components/Models.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Expand from './Expand';
import Delete from './Delete';
import { $try, reformatDatetime } from '../utils';
import FoldableRow from './FoldableRow';
import FeatureImportances from './FeatureImportances';


const ModelsTab = props => (
Expand Down Expand Up @@ -169,7 +170,7 @@ let ModelInfo = props => (
<tr>
<th>Model Type</th>
<th>Hyperparameters</th>
<th>Training Data Score</th>
{Object.keys(props.model.metrics).map(metric => <th>{metric}</th>)}
</tr>
</thead>
<tbody>
Expand All @@ -191,9 +192,16 @@ let ModelInfo = props => (
</tbody>
</table>
</td>
<td>
{props.model.train_score}
</td>
{
Object.keys(props.model.metrics).map(metric => (
<td>
{
metric == 'feature_importances' ?
<FeatureImportances data={props.model.metrics[metric]} /> :
props.model.metrics[metric]
}
</td>))
}
</tr>
</tbody>
</table>
Expand Down

0 comments on commit cf4da20

Please sign in to comment.