Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/cesium-ml/cesium_web into…
Browse files Browse the repository at this point in the history
… add_project
  • Loading branch information
MichaelTamaki committed Apr 25, 2017
2 parents 3a5ec67 + 993cb7e commit 220c1a9
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 121 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ ghostdriver.log
*.swo
__pycache__/
node_modules/
*.pyc
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ webpack = ./node_modules/.bin/webpack


dependencies:
@./tools/silent_monitor.py ./tools/install_deps.py requirements.txt
@./tools/silent_monitor.py pip install -r requirements.txt
@./tools/silent_monitor.py ./tools/check_js_deps.sh

db_init:
Expand Down
23 changes: 6 additions & 17 deletions cesium_app/handlers/plot_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,9 @@


class PlotFeaturesHandler(BaseHandler):
def _get_featureset(self, featureset_id):
try:
f = Featureset.get(Featureset.id == featureset_id)
except Featureset.DoesNotExist:
raise AccessError('No such feature set')

if not f.is_owned_by(self.get_username()):
raise AccessError('No such feature set')

return f

def get(self, featureset_id=None):
fset = self._get_featureset(featureset_id)
features_to_plot = sorted(fset.features_list)[0:4]
data, layout = plot.feature_scatterplot(fset.file.uri, features_to_plot)

self.success({'data': data, 'layout': layout})
def get(self, featureset_id):
fset = Featureset.get_if_owned(featureset_id, self.get_username())
features_to_plot = sorted(fset.features_list)[0:4] # TODO from form
docs_json, render_items = plot.feature_scatterplot(fset.file.uri,
features_to_plot)
self.success({'docs_json': docs_json, 'render_items': render_items})
4 changes: 2 additions & 2 deletions cesium_app/handlers/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def get(self, prediction_id=None, action=None):
'label': data['labels'],
'prediction': data['preds']},
columns=['ts_name', 'label', 'prediction'])
if data.get('pred_probs'):
result['probability'] = np.max(data['pred_probs'], axis=1)
if len(data.get('pred_probs', [])) > 0:
result['probability'] = data['pred_probs'].max(axis=1).values
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
self.set_header("Content-Disposition", "attachment; "
"filename=cesium_prediction_results.csv")
Expand Down
12 changes: 12 additions & 0 deletions cesium_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ class Featureset(BaseModel):
def is_owned_by(self, username):
return self.project.is_owned_by(username)

@staticmethod
def get_if_owned(fset_id, username):
try:
f = Featureset.get(Featureset.id == fset_id)
except Featureset.DoesNotExist:
raise AccessError('No such feature set')

if not f.is_owned_by(username):
raise AccessError('No such feature set')

return f


class Model(BaseModel):
"""ORM model of the Model table"""
Expand Down
83 changes: 40 additions & 43 deletions cesium_app/plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from itertools import cycle
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import plotly
import plotly.offline as py
from plotly.tools import FigureFactory as FF

from cesium import featurize
from .config import cfg
from bokeh.plotting import figure
from bokeh.layouts import gridplot
from bokeh.palettes import Viridis as palette
from bokeh.core.json_encoder import serialize_json
from bokeh.document import Document
from bokeh.util.serialization import make_id


def feature_scatterplot(fset_path, features_to_plot):
Expand All @@ -21,42 +21,39 @@ def feature_scatterplot(fset_path, features_to_plot):
Returns
-------
(fig.data, fig.layout)
Returns (fig.data, fig.layout) where `fig` is an instance of
`plotly.tools.FigureFactory`.
(str, str)
Returns (docs_json, render_items) json for the desired plot.
"""
fset, data = featurize.load_featureset(fset_path)
fset = fset[features_to_plot]

if 'label' in data:
fset['label'] = data['label']
index = 'label'
else:
index = None

# TODO replace 'trace {i}' with class labels
fig = FF.create_scatterplotmatrix(fset, diag='box', index=index,
height=800, width=800)

py.plot(fig, auto_open=False, output_type='div')

return fig.data, fig.layout


#def prediction_heatmap(pred_path):
# with xr.open_dataset(pred_path) as pset:
# pred_df = pd.DataFrame(pset.prediction.values, index=pset.name,
# columns=pset.class_label.values)
# pred_labels = pred_df.idxmax(axis=1)
# C = confusion_matrix(pset.label, pred_labels)
# row_sums = C.sum(axis=1)
# C = C / row_sums[:, np.newaxis]
# fig = FF.create_annotated_heatmap(C, x=[str(el) for el in
# pset.class_label.values],
# y=[str(el) for el in
# pset.class_label.values],
# colorscale='Viridis')
#
# py.plot(fig, auto_open=False, output_type='div')
#
# return fig.data, fig.layout
colors = cycle(palette[5])
plots = np.array([[figure(width=300, height=200)
for j in range(len(features_to_plot))]
for i in range(len(features_to_plot))])

for (j, i), p in np.ndenumerate(plots):
if (j == i == 0):
p.title.text = "Scatterplot matrix"
p.circle(fset.values[:,i], fset.values[:,j], color=next(colors))
p.xaxis.minor_tick_line_color = None
p.yaxis.minor_tick_line_color = None
p.ygrid[0].ticker.desired_num_ticks = 2
p.xgrid[0].ticker.desired_num_ticks = 4
p.outline_line_color = None
p.axis.visible = None

plot = gridplot(plots.tolist(), ncol=len(features_to_plot), mergetools=True, responsive=True, title="Test")

# Convert plot to json objects necessary for rendering with bokeh on the
# frontend
render_items = [{'docid': plot._id, 'elementid': make_id()}]

doc = Document()
doc.add_root(plot)
docs_json_inner = doc.to_json()
docs_json = {render_items[0]['docid']: docs_json_inner}

docs_json = serialize_json(docs_json)
render_items = serialize_json(render_items)

return docs_json, render_items
2 changes: 1 addition & 1 deletion cesium_app/tests/frontend/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_plot_features(driver):
driver.find_element_by_xpath("//b[contains(text(),'Please wait while we load your plotting data...')]")

driver.implicitly_wait(3)
driver.find_element_by_css_selector("[class=svg-container]")
driver.find_element_by_css_selector("[class=bk-plotdiv]")


def test_delete_featureset(driver):
Expand Down
21 changes: 21 additions & 0 deletions cesium_app/tests/frontend/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from os.path import join as pjoin
import numpy as np
import numpy.testing as npt
import pandas as pd
from cesium_app.config import cfg
import json
import requests
Expand Down Expand Up @@ -185,6 +186,26 @@ def test_download_prediction_csv_class(driver):
os.remove('/tmp/cesium_prediction_results.csv')


def test_download_prediction_csv_class_prob(driver):
driver.get('/')
with create_test_project() as p, create_test_dataset(p) as ds,\
create_test_featureset(p) as fs,\
create_test_model(fs, model_type='RandomForestClassifier') as m,\
create_test_prediction(ds, m):
_click_download(p.id, driver)
assert os.path.exists('/tmp/cesium_prediction_results.csv')
try:
result = pd.read_csv('/tmp/cesium_prediction_results.csv')
npt.assert_array_equal(result.ts_name, np.arange(5))
npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
'Mira', 'Classical_Cepheid',
'Mira'])
npt.assert_array_equal(result.label, result.prediction)
assert (result.probability >= 0.0).all()
finally:
os.remove('/tmp/cesium_prediction_results.csv')


def test_download_prediction_csv_regr(driver):
driver.get('/')
with create_test_project() as p, create_test_dataset(p, label_type='regr') as ds,\
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
"test": "eslint -c .eslintrc --ext .jsx,.js public/scripts/ && make test"
},
"dependencies": {
"bokehjs": "^0.12.5",
"bootstrap": "^3.3.7",
"bootstrap-css": "^3.0.0",
"css-loader": "^0.26.2",
"exports-loader": "^0.6.4",
"imports-loader": "^0.7.1",
"jquery": "^3.1.1",
"plotly.js": "^1.23.1",
"react": "^15.1.0",
"react-dom": "^15.1.0",
"react-redux": "^5.0.3",
Expand All @@ -23,6 +23,7 @@
"redux-logger": "^2.8.1",
"redux-thunk": "^2.2.0",
"style-loader": "^0.13.2",
"typescript": "^2.2.2",
"webpack": "^2.2.1",
"webpack-dev-server": "^2.4.1",
"whatwg-fetch": "^2.0.2"
Expand Down
27 changes: 22 additions & 5 deletions public/scripts/Plot.jsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
import React, { Component } from 'react';
import { connect } from 'react-redux';
import Plotly from './custom-plotly';
import { showNotification } from './Notifications';
import "../../node_modules/bokehjs/build/js/bokeh.js";
import "../../node_modules/bokehjs/build/css/bokeh.css";

function bokeh_render_plot(node, docs_json, render_items) {
// Create bokeh div element
var bokeh_div = document.createElement("div");
var inner_div = document.createElement("div");
bokeh_div.setAttribute("class", "bk-root" );
inner_div.setAttribute("class", "bk-plotdiv");
inner_div.setAttribute("id", render_items[0].elementid);
bokeh_div.appendChild(inner_div);
node.appendChild(bokeh_div);

// Generate plot
Bokeh.safely(function() {
Bokeh.embed.embed_items(docs_json, render_items);
});
}

class Plot extends Component {
constructor(props) {
Expand Down Expand Up @@ -32,16 +48,17 @@ class Plot extends Component {
if (!plotData) {
return <b>Please wait while we load your plotting data...</b>;
}

let { data, layout } = plotData;
var docs_json = JSON.parse(plotData.docs_json);
var render_items = JSON.parse(plotData.render_items);

return (
plotData &&
<div
ref={
(node) => {
node && Plotly.plot(node, data, layout);
}}
node && bokeh_render_plot(node, docs_json, render_items)
}
}
/>
);
}
Expand Down
13 changes: 0 additions & 13 deletions public/scripts/custom-plotly.js

This file was deleted.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ pyyaml
tornado
pyzmq
pyjwt
plotly>=2.0.5
simplejson
distributed>=1.14.3
selenium
pytest
joblib>=0.11
bokeh==0.12.5
37 changes: 0 additions & 37 deletions tools/install_deps.py

This file was deleted.

0 comments on commit 220c1a9

Please sign in to comment.