Skip to content

Commit f1d6d60

Browse files
committed
make release-tag: Merge branch 'master' into stable
2 parents c98f492 + bfce1e8 commit f1d6d60

File tree

11 files changed

+361
-40
lines changed

11 files changed

+361
-40
lines changed

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# History
22

3+
## 0.2.2 (2019-07-30)
4+
5+
### New Features
6+
7+
* Curate dependencies - [Issue #152](https://github.com/HDI-Project/ATM/issues/152) by @csala
8+
* POST request blocked by CORS policy - [Issue #151](https://github.com/HDI-Project/ATM/issues/151) by @pvk-developer
9+
310
## 0.2.1 (2019-06-24)
411

512
### New Features

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
<i>An open source project from Data to AI Lab at MIT.</i>
44
</p>
55

6+
# ATM - Auto Tune Models
67

78

89
[![CircleCI](https://circleci.com/gh/HDI-Project/ATM.svg?style=shield)](https://circleci.com/gh/HDI-Project/ATM)
@@ -12,7 +13,7 @@
1213
[![Downloads](https://pepy.tech/badge/atm)](https://pepy.tech/project/atm)
1314

1415

15-
# ATM - Auto Tune Models
16+
1617

1718
- License: MIT
1819
- Documentation: https://HDI-Project.github.io/ATM/
@@ -143,7 +144,7 @@ For this demo we will be using the pollution csv from the atm-data bucket, which
143144
[from here](https://atm-data.s3.amazonaws.com/pollution_1.csv), or using the following command:
144145

145146
```bash
146-
wget https://atm-data.s3.amazonaws.com/pollution_1.csv
147+
atm download_demo pollution_1.csv
147148
```
148149

149150
## 2. Create an ATM instance

atm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
__author__ = """MIT Data To AI Lab"""
1212
__email__ = '[email protected]'
13-
__version__ = '0.2.1'
13+
__version__ = '0.2.2-dev'
1414

1515
# this defines which modules will be imported by "from atm import *"
1616
__all__ = ['ATM', 'Model', 'config', 'constants', 'data', 'database',

atm/api/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def create_app(atm, debug=False):
1818
# Allow the CORS header
1919
@app.after_request
2020
def add_cors_headers(response):
21+
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
2122
response.headers['Access-Control-Allow-Origin'] = '*'
2223
response.headers['Access-Control-Allow-Credentials'] = 'true'
2324
return response
@@ -28,7 +29,9 @@ def atm_run():
2829
data = request.json
2930
run_conf = RunConfig(data)
3031

31-
dataruns = atm.create_dataruns(run_conf)
32+
dataruns = atm.add_datarun(**run_conf.to_dict())
33+
if not isinstance(dataruns, list):
34+
dataruns = [dataruns]
3235

3336
response = {
3437
'status': 200,

atm/cli.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from atm.api import create_app
1616
from atm.config import AWSConfig, DatasetConfig, LogConfig, RunConfig, SQLConfig
1717
from atm.core import ATM
18-
from atm.data import copy_files, get_demos
18+
from atm.data import copy_files, download_demo, get_demos
1919

2020
LOGGER = logging.getLogger(__name__)
2121

@@ -25,7 +25,13 @@ def _get_atm(args):
2525
aws_conf = AWSConfig(args)
2626
log_conf = LogConfig(args)
2727

28-
return ATM(**sql_conf.to_dict(), **aws_conf.to_dict(), **log_conf.to_dict())
28+
# Build params dictionary to pass to ATM.
29+
# Needed because Python 2.7 does not support multiple star operators in a single statement.
30+
atm_args = sql_conf.to_dict()
31+
atm_args.update(aws_conf.to_dict())
32+
atm_args.update(log_conf.to_dict())
33+
34+
return ATM(**atm_args)
2935

3036

3137
def _work(args, wait=False):
@@ -209,7 +215,19 @@ def _make_config(args):
209215

210216

211217
def _get_demos(args):
212-
get_demos()
218+
datasets = get_demos()
219+
for dataset in datasets:
220+
print(dataset)
221+
222+
223+
def _download_demo(args):
224+
paths = download_demo(args.dataset, args.path)
225+
if isinstance(paths, list):
226+
for path in paths:
227+
print('Dataset has been saved to {}'.format(path))
228+
229+
else:
230+
print('Dataset has been saved to {}'.format(paths))
213231

214232

215233
def _get_parser():
@@ -330,8 +348,13 @@ def _get_parser():
330348

331349
# Get Demos
332350
get_demos = subparsers.add_parser('get_demos', parents=[logging_args],
333-
help='Create a demos folder and put the demo CSVs inside.')
351+
help='Print a list with the available demo datasets.')
334352
get_demos.set_defaults(action=_get_demos)
353+
download_demo = subparsers.add_parser('download_demo', parents=[logging_args],
354+
help='Downloads a demo dataset from AWS3.')
355+
download_demo.set_defaults(action=_download_demo)
356+
download_demo.add_argument('dataset', nargs='+', help='Name of the dataset to be downloaded.')
357+
download_demo.add_argument('--path', help='Directory to be used to store the dataset.')
335358

336359
return parser
337360

atm/data.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import boto3
77
import pandas as pd
88
import requests
9+
from botocore import UNSIGNED
10+
from botocore.client import Config
911
from botocore.exceptions import ClientError
1012

1113
LOGGER = logging.getLogger('atm')
@@ -66,9 +68,40 @@ def copy_files(extension, source, target=None):
6668
return file_paths
6769

6870

69-
def get_demos():
70-
"""Copy the demo CSV files to the ``{cwd}/demos`` folder."""
71-
return copy_files('csv', 'demos')
71+
def download_demo(datasets, path=None):
72+
73+
if not isinstance(datasets, list):
74+
datasets = [datasets]
75+
76+
if path is None:
77+
path = os.path.join(os.getcwd(), 'demos')
78+
79+
if not os.path.exists(path):
80+
os.makedirs(path)
81+
82+
client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
83+
84+
paths = list()
85+
86+
for dataset in datasets:
87+
save_path = os.path.join(path, dataset)
88+
89+
try:
90+
LOGGER.info('Downloading {}'.format(dataset))
91+
client.download_file('atm-data', dataset, save_path)
92+
paths.append(save_path)
93+
94+
except ClientError as e:
95+
LOGGER.error('An error occurred trying to download from AWS3.'
96+
'The following error has been returned: {}'.format(e))
97+
98+
return paths[0] if len(paths) == 1 else paths
99+
100+
101+
def get_demos(args=None):
102+
client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
103+
available_datasets = [obj['Key'] for obj in client.list_objects(Bucket='atm-data')['Contents']]
104+
return available_datasets
72105

73106

74107
def _download_from_s3(path, local_path, aws_access_key=None, aws_secret_key=None, **kwargs):

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.2.1
2+
current_version = 0.2.2-dev
33
commit = True
44
tag = True
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+))?

setup.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,25 @@
1010
history = history_file.read()
1111

1212
install_requires = [
13-
'baytune==0.2.5',
14-
'boto3>=1.9.146',
15-
'future>=0.16.0',
16-
'joblib>=0.11',
17-
'pymysql>=0.9.3',
18-
'cryptography>=2.6.1',
19-
'numpy>=1.13.1',
20-
'pandas>=0.22.0',
21-
'psutil>=5.6.1',
22-
'python-daemon>=2.2.3',
23-
'pyyaml>=3.12',
24-
'requests>=2.18.4',
25-
'scikit-learn>=0.18.2',
26-
'scipy>=0.19.1',
27-
'sklearn-pandas>=1.5.0',
28-
'sqlalchemy>=1.1.14',
29-
'flask>=1.0.2',
30-
'flask-restless>=0.17.0',
31-
'flask-sqlalchemy>=2.3.2',
32-
'flask-restless-swagger-2>=0.0.3',
33-
'simplejson>=3.16.0',
34-
'tqdm>=4.31.1',
13+
'baytune>=0.2.5,<0.3',
14+
'boto3>=1.9.146,<2',
15+
'future>=0.16.0,<0.18',
16+
'pymysql>=0.9.3,<0.10',
17+
'numpy>=1.13.1,<1.17',
18+
'pandas>=0.22.0,<0.25',
19+
'psutil>=5.6.1,<6',
20+
'python-daemon>=2.2.3,<3',
21+
'requests>=2.18.4,<3',
22+
'scikit-learn>=0.18.2,<0.22',
23+
'scipy>=0.19.1,<1.4',
24+
'sqlalchemy>=1.1.14,<1.4',
25+
'flask>=1.0.2,<2',
26+
'flask-restless>=0.17.0,<0.18',
27+
'flask-sqlalchemy>=2.3.2,<2.5',
28+
'flask-restless-swagger-2==0.0.3',
29+
'simplejson>=3.16.0,<4',
30+
'tqdm>=4.31.1,<5',
31+
'docutils>=0.10,<0.15',
3532
]
3633

3734
setup_requires = [
@@ -113,6 +110,6 @@
113110
test_suite='tests',
114111
tests_require=tests_require,
115112
url='https://github.com/HDI-project/ATM',
116-
version='0.2.1',
113+
version='0.2.2-dev',
117114
zip_safe=False,
118115
)

tests/api/test___init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,38 +54,51 @@ def test_create_app_debug(atm):
5454
assert app.config['DEBUG']
5555

5656

57-
def test_home(client):
57+
def test_get_home(client):
5858
res = client.get('/', follow_redirects=False)
5959

6060
assert res.status == '302 FOUND'
6161
assert res.location == 'http://localhost/static/swagger/swagger-ui/index.html'
6262

6363

64-
def test_dataset(client):
64+
def test_get_dataset(client):
6565
res = client.get('api/datasets')
6666
data = json.loads(res.data.decode('utf-8'))
6767

6868
assert res.status == '200 OK'
6969
assert data.get('num_results') == 1
7070

7171

72-
def test_datarun(client):
72+
def test_options_dataset(client):
73+
res = client.options('api/datasets')
74+
75+
expected_headers = [
76+
('Content-Type', 'text/html; charset=utf-8'),
77+
('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
78+
('Access-Control-Allow-Origin', '*'),
79+
('Access-Control-Allow-Credentials', 'true'),
80+
]
81+
82+
assert set(expected_headers).issubset(set(res.headers.to_list()))
83+
84+
85+
def test_get_datarun(client):
7386
res = client.get('api/dataruns')
7487
data = json.loads(res.data.decode('utf-8'))
7588

7689
assert res.status == '200 OK'
7790
assert data.get('num_results') == 2
7891

7992

80-
def test_hyperpartition(client):
93+
def test_get_hyperpartition(client):
8194
res = client.get('api/hyperpartitions')
8295
data = json.loads(res.data.decode('utf-8'))
8396

8497
assert res.status == '200 OK'
8598
assert data.get('num_results') == 40
8699

87100

88-
def test_classifier(client):
101+
def test_get_classifier(client):
89102
res = client.get('api/classifiers')
90103
data = json.loads(res.data.decode('utf-8'))
91104

tests/test_cli.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from mock import Mock, patch
2+
3+
from atm import cli
4+
5+
6+
@patch('atm.cli.get_demos')
7+
def test__get_demos(mock_get_demos):
8+
"""Test that the method get_demos is being called properly."""
9+
10+
# run
11+
cli._get_demos(None) # Args are not being used.
12+
13+
# assert
14+
mock_get_demos.assert_called_once_with()
15+
16+
17+
@patch('atm.cli.download_demo')
18+
def test__download_demo(mock_download_demo):
19+
"""Test that the method _download_demo is being called properly with a single dataset."""
20+
21+
# setup
22+
args_mock = Mock(dataset='test.csv', path=None)
23+
24+
# run
25+
cli._download_demo(args_mock)
26+
27+
# assert
28+
mock_download_demo.assert_called_once_with('test.csv', None)
29+
30+
31+
@patch('atm.cli.download_demo')
32+
def test__download_demo_array(mock_download_demo):
33+
"""Test that the method _download_demo is being called properly with a two datasets."""
34+
35+
# setup
36+
args_mock = Mock(dataset=['test.csv', 'test2.csv'], path=None)
37+
mock_download_demo.return_value = ['test.csv', 'test2.csv']
38+
39+
# run
40+
cli._download_demo(args_mock)
41+
42+
# assert
43+
mock_download_demo.assert_called_once_with(['test.csv', 'test2.csv'], None)
44+
45+
46+
@patch('atm.cli.download_demo')
47+
def test__download_demo_path(mock_download_demo):
48+
"""Test that the method _download_demo is being called properly with a given path."""
49+
50+
# setup
51+
args_mock = Mock(dataset=['test.csv', 'test2.csv'], path='my_test_path')
52+
mock_download_demo.return_value = ['test.csv', 'test2.csv']
53+
54+
# run
55+
cli._download_demo(args_mock)
56+
57+
# assert
58+
mock_download_demo.assert_called_once_with(['test.csv', 'test2.csv'], 'my_test_path')
59+
60+
61+
@patch('atm.cli._get_atm')
62+
def test__work(mock__get_atm):
63+
# setup
64+
args_mock = Mock(dataruns=[1], total_time=[1], save_files=False, cloud_mode=False)
65+
66+
# run
67+
cli._work(args_mock)
68+
69+
# assert
70+
mock__get_atm.assert_called_once_with(args_mock)
71+
72+
mock__get_atm.return_value.work.assert_called_once_with(
73+
datarun_ids=[1],
74+
choose_randomly=False,
75+
save_files=False,
76+
cloud_mode=False,
77+
total_time=[1],
78+
wait=False
79+
)
80+
81+
82+
@patch('atm.cli.create_app')
83+
@patch('atm.cli._get_atm')
84+
def test__serve(mock__get_atm, mock_create_app):
85+
# setup
86+
args_mock = Mock(debug=False, host='1.2.3', port='456')
87+
88+
# run
89+
cli._serve(args_mock)
90+
91+
# assert
92+
mock__get_atm.assert_called_once_with(args_mock)
93+
mock_create_app.assert_called_once_with(mock__get_atm.return_value, False)
94+
mock_create_app.return_value.run.assert_called_once_with(host='1.2.3', port='456')

0 commit comments

Comments
 (0)