Skip to content

Commit

Permalink
Merge remote-tracking branch 'cesium-ml/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelTamaki committed Apr 3, 2017
2 parents d5e0d9b + 1436f95 commit 14c3d7c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
17 changes: 14 additions & 3 deletions cesium_app/handlers/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def _get_prediction(self, prediction_id):
try:
d = Prediction.get(Prediction.id == prediction_id)
except Prediction.DoesNotExist:
raise AccessError('No such dataset')
raise AccessError('No such prediction')

if not d.is_owned_by(self.get_username()):
raise AccessError('No such dataset')
raise AccessError('No such prediction')

return d

Expand Down Expand Up @@ -67,6 +67,9 @@ def post(self):

dataset_id = data['datasetID']
model_id = data['modelID']
# If only a subset of specified dataset is to be used, a list of the
# corresponding time series file names can be provided
ts_names = data.get('ts_names')

dataset = Dataset.get(Dataset.id == data["datasetID"])
model = Model.get(Model.id == data["modelID"])
Expand All @@ -88,7 +91,15 @@ def post(self):

executor = yield self._get_executor()

all_time_series = executor.map(time_series.load, dataset.uris)
# If only a subset of the dataset is to be used, get specified files
if ts_names:
ts_uris = [f.uri for f in dataset.files if os.path.basename(f.name)
in ts_names or os.path.basename(f.name).split('.npz')[0]
in ts_names]
else:
ts_uris = dataset.uris

all_time_series = executor.map(time_series.load, ts_uris)
all_labels = executor.map(lambda ts: ts.label, all_time_series)
all_features = executor.map(featurize.featurize_single_ts,
all_time_series,
Expand Down
31 changes: 31 additions & 0 deletions cesium_app/tests/frontend/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from os.path import join as pjoin
import numpy as np
import numpy.testing as npt
from cesium_app.config import cfg
import json
import requests
from cesium_app.tests.fixtures import (create_test_project, create_test_dataset,
create_test_featureset, create_test_model,
create_test_prediction)
Expand Down Expand Up @@ -204,3 +207,31 @@ def test_download_prediction_csv_regr(driver):
[4, 3.1, 3.1]])
finally:
os.remove('/tmp/cesium_prediction_results.csv')


def test_predict_specific_ts_name():
with create_test_project() as p, create_test_dataset(p) as ds,\
create_test_featureset(p) as fs, create_test_model(fs) as m:
ts_data = [[1, 2, 3, 4], [32.2, 53.3, 32.3, 32.52], [0.2, 0.3, 0.6, 0.3]]
impute_kwargs = {'strategy': 'constant', 'value': None}
data = {'datasetID': ds.id,
'ts_names': ['217801'],
'modelID': m.id}
response = requests.post('{}/predictions'.format(cfg['server']['url']),
data=json.dumps(data)).json()
assert response['status'] == 'success'

n_secs = 0
while n_secs < 5:
pred_info = requests.get('{}/predictions/{}'.format(
cfg['server']['url'], response['data']['id'])).json()
if pred_info['status'] == 'success' and pred_info['data']['finished']:
assert isinstance(pred_info['data']['results']['217801']
['features']['total_time'],
float)
assert 'Mira' in pred_info['data']['results']['217801']['prediction']
break
n_secs += 1
time.sleep(1)
else:
raise Exception('test_predict_specific_ts_name timed out')
1 change: 0 additions & 1 deletion tools/watch_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def logs_from_config(supervisor_conf):
with nostdout():
from cesium_app.config import cfg

watched.append(cfg['paths']['err_log_path'])
watched.append('log/error.log')
watched.append('log/nginx-error.log')

Expand Down

0 comments on commit 14c3d7c

Please sign in to comment.