Merge pull request #107 from HDI-Project/GH_105_remove_flake8_ignores
Gh 105 remove flake8 ignores
micahjsmith authored Sep 7, 2018
2 parents 9dfe5d9 + d9dd7fe commit 799468f
Showing 16 changed files with 143 additions and 106 deletions.
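Nearly every hunk below applies the same flake8-driven refactor: relative wildcard imports are replaced with explicit absolute imports, so the F401/F403/F405 ignores can come out of the lint configuration. A minimal sketch of the pattern, using names that appear in the diffs below:

# Before: flake8 flags the star import (F403), and every name it supplies
# may be undefined as far as the linter can tell (F405).
from .constants import *

# After: explicit absolute imports that flake8 can check name by name.
from atm.constants import METHODS, METRICS, SELECTORS, TUNERS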
20 changes: 11 additions & 9 deletions atm/config.py
@@ -9,9 +9,11 @@
from builtins import map, object, str

import yaml

-from .constants import *
-from .utilities import ensure_directory
+from atm.constants import (BUDGET_TYPES, CUSTOM_CLASS_REGEX, DATA_TEST_PATH,
+                           JSON_REGEX, LOG_LEVELS, METHODS, METRICS,
+                           SCORE_TARGETS, SELECTORS, SQL_DIALECTS, TIME_FMT,
+                           TUNERS)
+from atm.utilities import ensure_directory


class Config(object):
@@ -344,8 +346,8 @@ def add_arguments_datarun(parser):
# Config file
parser.add_argument('--run-config', help='path to yaml datarun config file')

-## Dataset Arguments #####################################################
-############################################################################
+# Dataset Arguments #####################################################
+# ##########################################################################
parser.add_argument('--dataset-id', type=int,
help="ID of dataset, if it's already in the database")

@@ -355,8 +357,8 @@ def add_arguments_datarun(parser):
parser.add_argument('--data-description', help='Description of dataset')
parser.add_argument('--class-column', help='Name of the class column in the input data')

-## Datarun Arguments #####################################################
-############################################################################
+# Datarun Arguments #####################################################
+# ##########################################################################
# Notes:
# - Support vector machines (svm) can take a long time to train. It's not an
# error, it's just part of what happens when the method happens to explore
Expand Down Expand Up @@ -428,8 +430,8 @@ def add_arguments_datarun(parser):
'performance on a test dataset, and "mu_sigma" will use '
'the lower confidence bound on the CV performance.')

-## AutoML Arguments ######################################################
-############################################################################
+# AutoML Arguments ######################################################
+# ##########################################################################
# hyperparameter selection strategy
# How should ATM sample hyperparameters from a given hyperpartition?
# uniform - pick randomly! (baseline)
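The "uniform" strategy described in the comment above corresponds to BTB's Uniform selector. A rough sketch of how such a selector is driven, assuming the Selector API of this era of BTB (the score history is made up):

from btb.selection import Uniform

# Map each hyperpartition ID to the scores of classifiers already trained on it.
choice_scores = {1: [0.72, 0.74], 2: [0.81]}  # hypothetical history

selector = Uniform(choices=list(choice_scores.keys()))
next_choice = selector.select(choice_scores)  # uniform: picks a choice at random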
4 changes: 2 additions & 2 deletions atm/constants.py
@@ -4,15 +4,15 @@
import os
from builtins import object

-from . import PROJECT_ROOT

from btb.selection import (UCB1, BestKReward, BestKVelocity,
HierarchicalByAlgorithm, PureBestKVelocity,
RecentKReward, RecentKVelocity)
from btb.selection import Uniform as UniformSelector
from btb.tuning import GP, GPEi, GPEiVelocity
from btb.tuning import Uniform as UniformTuner

+from atm import PROJECT_ROOT

# A bunch of constants which are used throughout the project, mostly for config.
# TODO: convert these lists and classes to something more elegant, like enums
SQL_DIALECTS = ['sqlite', 'mysql']
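The reordering above follows isort-style grouping: standard-library imports first, third-party packages (like btb) next, first-party atm imports last, each group separated by a blank line, with relative imports rewritten as absolute ones. Schematically:

import os                     # standard library

from btb.tuning import GP     # third-party

from atm import PROJECT_ROOT  # first-party, absolute rather than relative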
31 changes: 17 additions & 14 deletions atm/database.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, unicode_literals

import json
+import os
import pickle
from builtins import object
from datetime import datetime
@@ -15,8 +16,10 @@
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.orm.properties import ColumnProperty

-from .constants import *
-from .utilities import *
+from atm.constants import (BUDGET_TYPES, CLASSIFIER_STATUS, DATARUN_STATUS,
+                           METRICS, PARTITION_STATUS, SCORE_TARGETS,
+                           ClassifierStatus, PartitionStatus, RunStatus)
+from atm.utilities import base_64_to_object, object_to_base_64

# The maximum number of errors allowed in a single hyperpartition. If more than
# this many classifiers using a hyperpartition error, the hyperpartition will be
@@ -296,9 +299,9 @@ def __repr__(self):

Base.metadata.create_all(bind=self.engine)

-###########################################################################
-## Save/load the database ###############################################
-###########################################################################
+# ##########################################################################
+# # Save/load the database ###############################################
+# ##########################################################################

@try_with_session()
def to_csv(self, path):
@@ -343,9 +346,9 @@ def from_csv(self, path):
create_func = getattr(self, 'create_%s' % table)
create_func(**r)

-###########################################################################
-## Standard query methods ###############################################
-###########################################################################
+# ##########################################################################
+# # Standard query methods ###############################################
+# ##########################################################################

@try_with_session()
def get_dataset(self, dataset_id):
@@ -453,9 +456,9 @@ def get_classifiers(self, dataset_id=None, datarun_id=None, method=None,

return query.all()

-###########################################################################
-## Special-purpose queries ##############################################
-###########################################################################
+# ##########################################################################
+# # Special-purpose queries ##############################################
+# ##########################################################################

@try_with_session()
def is_datatun_gridding_done(self, datarun_id):
@@ -539,9 +542,9 @@ def load_metrics(self, classifier_id):
with open(clf.metrics_location, 'r') as f:
return json.load(f)

-###########################################################################
-## Methods to update the database #######################################
-###########################################################################
+# ##########################################################################
+# # Methods to update the database #######################################
+# ##########################################################################

@try_with_session(commit=True)
def create_dataset(self, **kwargs):
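The divider rewrites in this file track pycodestyle's block-comment rules: E265 ("block comment should start with '# '") and E266 ("too many leading '#' for block comment"). The banners survive, but everything past a single '# ' prefix becomes ordinary comment text:

## Two leading hashes before comment text trip E266.
# # After the rewrite: one hash plus a space satisfies the rule, and the
# # remaining hashes, as in the dividers above, are plain comment text.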
13 changes: 6 additions & 7 deletions atm/enter_data.py
@@ -7,12 +7,11 @@

from past.utils import old_div

-from .config import *
-from .constants import *
-from .database import Database
-from .encoder import MetaData
-from .method import Method
-from .utilities import download_data
+from atm.constants import TIME_FMT, PartitionStatus
+from atm.database import Database
+from atm.encoder import MetaData
+from atm.method import Method
+from atm.utilities import download_data

# load the library-wide logger
logger = logging.getLogger('atm')
@@ -72,7 +71,7 @@ def create_datarun(db, dataset, run_config):
# TODO: why not walltime and classifiers budget simultaneously?
run_config.budget_type = 'walltime'
elif run_config.budget_type == 'walltime':
-deadline = datetime.now() + timedelta(minutes=budget)
+deadline = datetime.now() + timedelta(minutes=run_config.budget)

target = run_config.score_target + '_judgment_metric'
datarun = db.create_datarun(dataset_id=dataset.id,
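The one behavioral fix in this file is the walltime deadline: the old code read a bare name budget, while the value lives on the run config object. A self-contained sketch of the corrected logic, with SimpleNamespace standing in for ATM's RunConfig and hypothetical values:

from datetime import datetime, timedelta
from types import SimpleNamespace

run_config = SimpleNamespace(budget_type='walltime', budget=60)  # hypothetical

if run_config.budget_type == 'walltime':
    # The fix: read the budget off run_config, not a bare local name.
    deadline = datetime.now() + timedelta(minutes=run_config.budget)
    print(deadline.strftime('%Y-%m-%d %H:%M'))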
7 changes: 4 additions & 3 deletions atm/method.py
@@ -5,10 +5,10 @@
from builtins import str as newstr
from os.path import join

-from .constants import METHOD_PATH, METHODS_MAP

import btb

+from atm.constants import METHOD_PATH, METHODS_MAP


class HyperParameter(object):
@property
@@ -113,7 +113,8 @@ def __repr__(self):
cons = '[%s]' % ', '.join(['%s=%s' % c for c in self.constants])
if self.tunables:
tuns = '[%s]' % ', '.join(['%s' % t for t, _ in self.tunables])
-return '<HyperPartition: categoricals: %s; constants: %s; tunables: %s>' % (cats, cons, tuns)
+return ('<HyperPartition: categoricals: %s; constants: %s; tunables: %s>'
+        % (cats, cons, tuns))


HYPERPARAMETER_TYPES = {
3 changes: 2 additions & 1 deletion atm/metrics.py
@@ -10,7 +10,8 @@
precision_recall_curve, roc_auc_score, roc_curve)
from sklearn.model_selection import StratifiedKFold

-from .constants import *
+from atm.constants import (METRICS_BINARY, METRICS_MULTICLASS, N_FOLDS_DEFAULT,
+                           Metrics)


def rank_n_accuracy(y_true, y_prob_mat, n=0.33):
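Only the import line of metrics.py changes, but the visible signature rank_n_accuracy(y_true, y_prob_mat, n=0.33) is enough to sketch what such a metric computes. A hypothetical re-implementation, inferred from the signature alone (ATM's actual version may differ in its details):

import numpy as np

def rank_n_accuracy_sketch(y_true, y_prob_mat, n=0.33):
    # Count a sample as correct if its true class is among the top
    # ceil(n * n_classes) classes when ranked by predicted probability.
    top_k = max(1, int(np.ceil(n * y_prob_mat.shape[1])))
    ranked = np.argsort(y_prob_mat, axis=1)[:, ::-1][:, :top_k]
    return float(np.mean([y in row for y, row in zip(y_true, ranked)]))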
10 changes: 5 additions & 5 deletions atm/model.py
@@ -23,10 +23,10 @@
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler

-from .constants import *
-from .encoder import DataEncoder, MetaData
-from .method import Method
-from .metrics import cross_validate_pipeline, test_pipeline
+from atm.constants import Metrics
+from atm.encoder import DataEncoder, MetaData
+from atm.method import Method
+from atm.metrics import cross_validate_pipeline, test_pipeline

# load the library-wide logger
logger = logging.getLogger('atm')
@@ -268,7 +268,7 @@ def special_conversions(self, params):
# sort the list by index
params[lname] = [val for idx, val in sorted(items)]

-## Gaussian process classifier
+# Gaussian process classifier
if self.method == "gp":
if params["kernel"] == "constant":
params["kernel"] = ConstantKernel()
7 changes: 3 additions & 4 deletions atm/tests/unit_tests/test_enter_data.py
@@ -1,12 +1,11 @@
-import json
import os

import pytest

-from atm import PROJECT_ROOT, constants
+from atm import PROJECT_ROOT
from atm.config import RunConfig, SQLConfig
from atm.database import Database, db_session
-from atm.enter_data import create_datarun, create_dataset, enter_data
+from atm.enter_data import create_dataset, enter_data
from atm.utilities import get_local_data_path

DB_PATH = '/tmp/atm.db'
@@ -126,4 +125,4 @@ def test_run_per_partition(dataset):
runs.append(run)

assert len(runs) == METHOD_HYPERPARTS['logreg']
-assert all([len(run.hyperpartitions) == 1 for run in runs])
+assert all([len(r.hyperpartitions) == 1 for r in runs])
7 changes: 2 additions & 5 deletions atm/tests/unit_tests/test_method.py
@@ -1,8 +1,5 @@
-#!/usr/bin/python2.7
import json

-import pytest

from atm.method import Method


@@ -34,5 +31,5 @@ def test_enumerate():
assert len(hps) == 12
assert all('a' in list(zip(*hp.categoricals))[0] for hp in hps)
assert all(('f', 0.5) in hp.constants for hp in hps)
-assert len([hp for hp in hps if hp.tunables
-            and 'b' in list(zip(*hp.tunables))[0]]) == 1
+assert len([hp for hp in hps if hp.tunables and
+            'b' in list(zip(*hp.tunables))[0]]) == 1
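The assert rewrite above is purely a line-break-style fix: pycodestyle's W503 warns on a line break before a binary operator, so the and moves to the end of the first line instead. Schematically:

condition_one, condition_two = True, False

# Before: the break comes before the operator, which W503 flags.
ok = (condition_one
      and condition_two)

# After: the operator ends the line, matching the rewritten assert.
ok = (condition_one and
      condition_two)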
49 changes: 49 additions & 0 deletions atm/tests/unit_tests/test_utilities.py
@@ -0,0 +1,49 @@
import socket

import pytest
from mock import patch

from atm import utilities


@patch('atm.utilities.requests')
def test_public_ip_existing(requests_mock):
# Set-up
utilities.public_ip = '1.2.3.4'

# run
ip = utilities.get_public_ip()

# asserts
assert ip == utilities.public_ip
requests_mock.get.assert_not_called()


def test_public_ip_success():
# Set-up
utilities.public_ip = None

# run
ip = utilities.get_public_ip()

# asserts
assert ip == utilities.public_ip
try:
socket.inet_aton(ip)
except socket.error:
pytest.fail("Invalid IP address")


@patch('atm.utilities.requests')
def test_public_ip_fail(requests_mock):
# Set-up
utilities.public_ip = None
requests_mock.get.side_effect = Exception # Force fail

# run
ip = utilities.get_public_ip()

# asserts
assert ip == utilities.public_ip
assert ip == 'localhost'
requests_mock.get.assert_called_once_with(utilities.PUBLIC_IP_URL)
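These three tests pin down atm.utilities.get_public_ip fairly tightly: a module-level public_ip cache, a single requests.get against PUBLIC_IP_URL, and a 'localhost' fallback when the request fails. A sketch consistent with the tests (the URL value is a placeholder; only the names are taken from the tests):

import requests

PUBLIC_IP_URL = 'https://example.com/ip'  # placeholder endpoint
public_ip = None  # module-level cache that the tests reset

def get_public_ip():
    global public_ip
    if public_ip is None:
        try:
            public_ip = requests.get(PUBLIC_IP_URL).text.strip()
        except Exception:
            public_ip = 'localhost'  # the fallback test_public_ip_fail expects
    return public_ip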
12 changes: 5 additions & 7 deletions atm/tests/unit_tests/test_worker.py
@@ -2,24 +2,21 @@
import os
import random

-import mock
import numpy as np
import pytest
+from btb.selection import BestKVelocity, Selector
+from btb.tuning import GP, Tuner
from mock import ANY, Mock, patch

-import atm
from atm import PROJECT_ROOT
from atm.config import LogConfig, RunConfig, SQLConfig
from atm.constants import METRICS_BINARY, TIME_FMT
-from atm.database import ClassifierStatus, Database, db_session
+from atm.database import Database, db_session
from atm.enter_data import enter_data
from atm.model import Model
from atm.utilities import download_data, load_metrics, load_model
from atm.worker import ClassifierError, Worker

-from btb.selection import BestKReward, BestKVelocity, Selector
-from btb.tuning import GP, GPEi, Tuner

DB_CACHE_PATH = os.path.join(PROJECT_ROOT, 'data/modelhub/test/')
DB_PATH = '/tmp/atm.db'
METRIC_DIR = '/tmp/metrics/'
@@ -31,6 +28,7 @@
DT_PARAMS = {'criterion': 'gini', 'max_features': 0.5, 'max_depth': 3,
'min_samples_split': 2, 'min_samples_leaf': 1}


# helper class to allow fuzzy arg matching
class StringWith(object):
def __init__(self, match):
@@ -157,7 +155,7 @@ def test_tune_hyperparameters(worker, hyperpartition):
worker.Tuner = Mock(return_value=mock_tuner)

with patch('atm.worker.vector_to_params') as vtp_mock:
-params = worker.tune_hyperparameters(hyperpartition)
+worker.tune_hyperparameters(hyperpartition)
vtp_mock.assert_called()

approximate_tunables = [(k, ObjWithAttrs(range=v.range))
