diff --git a/cesium_app/handlers/prediction.py b/cesium_app/handlers/prediction.py
index 3220ae5..a36b3af 100644
--- a/cesium_app/handlers/prediction.py
+++ b/cesium_app/handlers/prediction.py
@@ -25,10 +25,10 @@ def _get_prediction(self, prediction_id):
         try:
             d = Prediction.get(Prediction.id == prediction_id)
         except Prediction.DoesNotExist:
-            raise AccessError('No such dataset')
+            raise AccessError('No such prediction')
 
         if not d.is_owned_by(self.get_username()):
-            raise AccessError('No such dataset')
+            raise AccessError('No such prediction')
 
         return d
 
@@ -67,6 +67,9 @@ def post(self):
         dataset_id = data['datasetID']
         model_id = data['modelID']
+        # If only a subset of the specified dataset is to be used, a list of
+        # the corresponding time series file names can be provided
+        ts_names = data.get('ts_names')
 
         dataset = Dataset.get(Dataset.id == data["datasetID"])
         model = Model.get(Model.id == data["modelID"])
 
@@ -88,7 +91,15 @@ def post(self):
 
         executor = yield self._get_executor()
 
-        all_time_series = executor.map(time_series.load, dataset.uris)
+        # If only a subset of the dataset is to be used, get the specified files
+        if ts_names:
+            ts_uris = [f.uri for f in dataset.files if os.path.basename(f.name)
+                       in ts_names or os.path.basename(f.name).split('.npz')[0]
+                       in ts_names]
+        else:
+            ts_uris = dataset.uris
+
+        all_time_series = executor.map(time_series.load, ts_uris)
         all_labels = executor.map(lambda ts: ts.label, all_time_series)
         all_features = executor.map(featurize.featurize_single_ts,
                                     all_time_series,
diff --git a/cesium_app/tests/frontend/test_predict.py b/cesium_app/tests/frontend/test_predict.py
index cddfc0a..05f3f3e 100644
--- a/cesium_app/tests/frontend/test_predict.py
+++ b/cesium_app/tests/frontend/test_predict.py
@@ -7,6 +7,9 @@
 from os.path import join as pjoin
 import numpy as np
 import numpy.testing as npt
+from cesium_app.config import cfg
+import json
+import requests
 from cesium_app.tests.fixtures import (create_test_project, create_test_dataset,
                                        create_test_featureset, create_test_model,
                                        create_test_prediction)
@@ -204,3 +207,31 @@ def test_download_prediction_csv_regr(driver):
                             [4, 3.1, 3.1]])
     finally:
         os.remove('/tmp/cesium_prediction_results.csv')
+
+
+def test_predict_specific_ts_name():
+    with create_test_project() as p, create_test_dataset(p) as ds,\
+         create_test_featureset(p) as fs, create_test_model(fs) as m:
+        ts_data = [[1, 2, 3, 4], [32.2, 53.3, 32.3, 32.52], [0.2, 0.3, 0.6, 0.3]]
+        impute_kwargs = {'strategy': 'constant', 'value': None}
+        data = {'datasetID': ds.id,
+                'ts_names': ['217801'],
+                'modelID': m.id}
+        response = requests.post('{}/predictions'.format(cfg['server']['url']),
+                                 data=json.dumps(data)).json()
+        assert response['status'] == 'success'
+
+        n_secs = 0
+        while n_secs < 5:
+            pred_info = requests.get('{}/predictions/{}'.format(
+                cfg['server']['url'], response['data']['id'])).json()
+            if pred_info['status'] == 'success' and pred_info['data']['finished']:
+                assert isinstance(pred_info['data']['results']['217801']
+                                  ['features']['total_time'],
+                                  float)
+                assert 'Mira' in pred_info['data']['results']['217801']['prediction']
+                break
+            n_secs += 1
+            time.sleep(1)
+        else:
+            raise Exception('test_predict_specific_ts_name timed out')
diff --git a/tools/watch_logs.py b/tools/watch_logs.py
index 8689a4a..8c38d82 100755
--- a/tools/watch_logs.py
+++ b/tools/watch_logs.py
@@ -96,7 +96,6 @@ def logs_from_config(supervisor_conf):
 
 with nostdout():
     from cesium_app.config import cfg
 
-watched.append(cfg['paths']['err_log_path'])
 watched.append('log/error.log')
 watched.append('log/nginx-error.log')