Skip to content

Commit

Permalink
add type casting to ATM hyperparameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Bennett Cyphers committed Jan 17, 2018
1 parent 4891291 commit a50a2e2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
16 changes: 16 additions & 0 deletions atm/method.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from builtins import object, str as newstr

import json
from os.path import join

Expand Down Expand Up @@ -33,6 +35,20 @@ class Categorical(HyperParameter):
def __init__(self, name, type, values):
self.name = name
self.type = type
for i, val in enumerate(values):
if val is None:
# the value None is allowed for every parameter type
continue
if self.type == 'int_cat':
values[i] = int(val)
elif self.type == 'float_cat':
values[i] = float(val)
elif self.type == 'string':
# this is necessary to avoid a bug in sklearn, which won't be
# fixed until 0.20
values[i] = str(newstr(val))
elif self.type == 'bool':
values[i] = bool(val)
self.values = values

@property
Expand Down
2 changes: 1 addition & 1 deletion methods/stochastic_gradient_descent.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"range": [0.0, 1.0]
},
"fit_intercept": {
"type": "int_cat",
"type": "int",
"range": [0, 1]
},
"n_iter": {
Expand Down
2 changes: 1 addition & 1 deletion methods/support_vector_machine.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
},
"class_weight": {
"type": "string",
"range": ["balanced"]
"values": ["balanced"]
},
"_scale": {
"type": "bool",
Expand Down

0 comments on commit a50a2e2

Please sign in to comment.