[FIX] Passing checks (#298)
* Initial fix for all tests passing locally py=3.8

* fix bug in tests

* fix bug in test for data

* debugging error in dummy forward pass

* debug try -2

* catch runtime error in ci

* catch runtime error in ci

* add better debug test setup

* debug some more

* run this test only

* remove sum backward

* remove inplace in inception block

* undo silly change

* Enable all tests

* fix flake

* fix bug in test setup

* remove anomaly detection

* minor changes to comments

* Apply suggestions from code review

Co-authored-by: nabenabe0928 <[email protected]>

* Address comments from Shuhei

* revert change leading to bug

* fix flake

* change comment position in feature validator

* Add documentation for _is_datasets_consistent

* address comments from arlind

* case when all nans in test

Co-authored-by: nabenabe0928 <[email protected]>
ravinkohli and nabenabe0928 committed Dec 8, 2021
1 parent a16cbbb commit 87ee242
Showing 18 changed files with 119 additions and 120 deletions.
5 changes: 3 additions & 2 deletions autoPyTorch/api/base_task.py
@@ -1359,7 +1359,7 @@ def fit_ensemble(
Args:
optimize_metric (str): name of the metric that is used to
evaluate a pipeline. if not specified, value passed to search will be used
precision (int), (default=32): Numeric precision used when loading
precision (Optional[int]): Numeric precision used when loading
ensemble data. Can be either 16, 32 or 64.
ensemble_nbest (Optional[int]):
only consider the ensemble_nbest models to build the ensemble.
@@ -1402,6 +1402,7 @@ def fit_ensemble(
"Please call the `search()` method of {} prior to "
"fit_ensemble().".format(self.__class__.__name__))

precision = precision if precision is not None else self.precision
if precision not in [16, 32, 64]:
raise ValueError("precision must be one of 16, 32, 64 but got {}".format(precision))

@@ -1452,7 +1453,7 @@ def fit_ensemble(
manager = self._init_ensemble_builder(
time_left_for_ensembles=time_left_for_ensemble,
optimize_metric=self.opt_metric if optimize_metric is None else optimize_metric,
precision=self.precision if precision is None else precision,
precision=precision,
ensemble_size=ensemble_size,
ensemble_nbest=ensemble_nbest,
)
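The hunks above move the `self.precision` fallback to the top of `fit_ensemble`, so the value is resolved and validated once before any call site uses it. A minimal sketch of the pattern, with a hypothetical `Task` class standing in for `BaseTask`:

```python
from typing import Optional


class Task:  # hypothetical stand-in for BaseTask
    def __init__(self, precision: int = 32) -> None:
        self.precision = precision  # default chosen at construction time

    def fit_ensemble(self, precision: Optional[int] = None) -> None:
        # Resolve the fallback once, then validate the single resolved value.
        precision = precision if precision is not None else self.precision
        if precision not in (16, 32, 64):
            raise ValueError(
                "precision must be one of 16, 32, 64 but got {}".format(precision))
        # Downstream calls can now pass `precision` directly instead of
        # repeating `self.precision if precision is None else precision`.
        print("building ensemble with precision", precision)


Task(precision=64).fit_ensemble()  # falls back to the instance default
```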
51 changes: 36 additions & 15 deletions autoPyTorch/data/tabular_feature_validator.py
@@ -1,5 +1,5 @@
import functools
from typing import Dict, List, Optional, Tuple, cast
from typing import Dict, List, Optional, Tuple, Union, cast

import numpy as np

@@ -124,6 +124,7 @@ def _comparator(cmp1: str, cmp2: str) -> int:
if cmp1 not in choices or cmp2 not in choices:
raise ValueError('The comparator for the column order only accepts {}, '
'but got {} and {}'.format(choices, cmp1, cmp2))

idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
return idx1 - idx2
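The comparator orders feature-type names by their position in a fixed `choices` list, so it can drive a sort via `functools.cmp_to_key`. A self-contained sketch, assuming `choices = ['categorical', 'numerical']` (so categorical columns sort first), which matches the error message above but is not confirmed by this diff:

```python
import functools


def _comparator(cmp1: str, cmp2: str) -> int:
    # Assumed ordering: categorical before numerical.
    choices = ['categorical', 'numerical']
    if cmp1 not in choices or cmp2 not in choices:
        raise ValueError('The comparator for the column order only accepts {}, '
                         'but got {} and {}'.format(choices, cmp1, cmp2))
    idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
    return idx1 - idx2


feat_types = ['numerical', 'categorical', 'numerical']
print(sorted(feat_types, key=functools.cmp_to_key(_comparator)))
# -> ['categorical', 'numerical', 'numerical']
```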

@@ -271,13 +272,12 @@ def transform(
# having a value for a categorical column.
# We need to convert the column in test data to
# object otherwise the test column is interpreted as float
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')

if self.column_transformer is not None:
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')
X = self.column_transformer.transform(X)

# Sparse related transformations
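The guard above (now also checking `self.column_transformer is not None`) exists because pandas infers an entirely-NaN column as float, which would clash with a transformer fitted on object-typed categoricals. A small illustration of the dtype problem and the cast:

```python
import numpy as np
import pandas as pd

# An all-NaN categorical column in the test split is inferred as float64,
# while the column transformer was fitted on an 'object' column.
X_test = pd.DataFrame({'cat': [np.nan, np.nan]})
print(X_test['cat'].dtype)  # float64

# The fix casts such columns back to object before transforming.
if X_test['cat'].isna().all():
    X_test['cat'] = X_test['cat'].astype('object')
print(X_test['cat'].dtype)  # object
```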
@@ -362,16 +362,10 @@ def _check_data(

dtypes = [dtype.name for dtype in X.dtypes]

dtypes_diff = [s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
if len(self.dtypes) == 0:
self.dtypes = dtypes
elif (
any(dtypes_diff) # the dtypes of some columns are different in train and test dataset
and self.all_nan_columns is not None # Ignore all_nan_columns is None
and len(set(X.columns[dtypes_diff]).difference(self.all_nan_columns)) != 0
):
# The dtypes can be different if and only if the column belongs
# to all_nan_columns as these columns would be imputed.
elif not self._is_datasets_consistent(diff_cols, X):
raise ValueError("The dtype of the features must not be changed after fit(), but"
" the dtypes of some columns are different between training ({}) and"
" test ({}) datasets.".format(self.dtypes, dtypes))
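`diff_cols` is built by boolean-masking the column index with the element-wise dtype comparison; the consistency decision itself now lives in `_is_datasets_consistent` below. A sketch of the masking pattern on a toy frame (hypothetical values):

```python
import pandas as pd

# dtypes recorded at fit() time versus dtypes of the incoming frame.
fitted_dtypes = ['int64', 'object']
X = pd.DataFrame({'a': [1, 2], 'b': [1.0, 2.0]})

dtypes = [dtype.name for dtype in X.dtypes]
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(fitted_dtypes, dtypes)]]
print(list(diff_cols))  # ['b'] -- the only column whose dtype changed
```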
@@ -539,6 +533,33 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:

return X

def _is_datasets_consistent(self, diff_cols: List[Union[int, str]], X: pd.DataFrame) -> bool:
"""
Check the consistency of dtypes between the training and test datasets.
The dtypes may differ if a column belongs to `self.all_nan_columns`
(the names of columns that were entirely NaN in the training data) or if
the column is entirely NaN in the test data, as such columns are imputed.
Args:
diff_cols (List[Union[int, str]]):
The labels of the columns whose dtypes differ between the datasets.
X (pd.DataFrame):
A validation or test dataset to be compared with the training dataset.
Returns:
_ (bool): Whether the training and test datasets are consistent.
"""
if self.all_nan_columns is None:
if len(diff_cols) == 0:
return True
else:
return all(X[diff_cols].isna().all())

# dtype differs ==> the column must be all NaN in at least one of the train or test datasets
# inconsistent <==> dtype differs and the column is not all NaN in either train or test
inconsistent_cols = list(set(diff_cols) - self.all_nan_columns)

return len(inconsistent_cols) == 0 or all(X[inconsistent_cols].isna().all())


def has_object_columns(
feature_types: pd.Series,
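To make the rule in `_is_datasets_consistent` concrete: a dtype mismatch is tolerated only for columns that were all NaN during training (`all_nan_columns`) or are all NaN in the incoming frame, since those columns get imputed anyway. A standalone sketch of the same logic with hypothetical data:

```python
import numpy as np
import pandas as pd

all_nan_columns = {'b'}  # columns that were entirely NaN at fit() time
X_test = pd.DataFrame({'a': [1.0, 2.0], 'b': [np.nan, np.nan]})
diff_cols = ['b']  # columns whose dtype changed between fit() and transform()

inconsistent_cols = list(set(diff_cols) - all_nan_columns)
consistent = len(inconsistent_cols) == 0 or all(X_test[inconsistent_cols].isna().all())
print(consistent)  # True: the only differing column was all NaN in training
```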
@@ -40,7 +40,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
(Dict[str, Any]): the updated 'X' dictionary
"""
X.update({'encoder': self.preprocessor})
# X.update({'encoder': self.preprocessor})
return X

@staticmethod
@@ -43,7 +43,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
np.ndarray: Transformed features
"""
X.update({'scaler': self.preprocessor})
# X.update({'scaler': self.preprocessor})
return X

@staticmethod
@@ -21,7 +21,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:

self.embedding = self.build_embedding(
num_input_features=num_input_features,
num_numerical_features=num_numerical_columns)
num_numerical_features=num_numerical_columns) # type: ignore[arg-type]
return self

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
@@ -109,11 +109,7 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torc
loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
loss.backward()
self.optimizer.step()
if self.scheduler:
if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__:
self.scheduler.step(loss)
else:
self.scheduler.step()

# only passing the original outputs since we do not care about
# the adversarial performance.
return loss.item(), original_outputs
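The deleted block stepped the LR scheduler inside `train_step`, special-casing `ReduceLROnPlateau`, which consumes the monitored loss, while step-based schedulers take no argument; presumably the base trainer already performs this step, so keeping it here would step the scheduler twice per batch. The removed dispatch, shown in isolation:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

loss = torch.tensor(0.5)
if scheduler is not None:
    if 'ReduceLROnPlateau' in scheduler.__class__.__name__:
        scheduler.step(loss)  # plateau schedulers need the monitored metric
    else:
        scheduler.step()      # step-based schedulers take no argument
```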
@@ -282,6 +282,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
y=y,
**kwargs
)

# Add snapshots to base network to enable
# predicting with snapshot ensemble
self.choice: autoPyTorchComponent = cast(autoPyTorchComponent, self.choice)
@@ -59,7 +59,7 @@ def get_search_space_updates():
value_range=['shake-shake'],
default_value='shake-shake')
updates.append(node_name='network_backbone',
hyperparameter='ResNetBackbone:shake_shake_method',
hyperparameter='ResNetBackbone:shake_shake_update_func',
value_range=['M3'],
default_value='M3'
)
74 changes: 25 additions & 49 deletions test/test_data/test_feature_validator.py
@@ -204,7 +204,6 @@ def test_featurevalidator_supported_types(input_data_featuretest):
assert sparse.issparse(transformed_X)
else:
assert isinstance(transformed_X, np.ndarray)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
assert np.issubdtype(transformed_X.dtype, np.number)
assert validator._is_fitted

@@ -237,11 +236,10 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
validator.fit(input_data_featuretest)
transformed_X = validator.transform(input_data_featuretest)
assert any(pd.isna(input_data_featuretest))
categories_ = validator.column_transformer.named_transformers_['categorical_pipeline'].\
named_steps['ordinalencoder'].categories_
categories_ = validator.column_transformer.\
named_transformers_['categorical_pipeline'].named_steps['onehotencoder'].categories_
assert any(('0' in categories) or (0 in categories) or ('missing_value' in categories) for categories in
categories_)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
assert np.issubdtype(transformed_X.dtype, np.number)
assert validator._is_fitted
assert isinstance(transformed_X, np.ndarray)
@@ -294,7 +292,6 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
else:
raise ValueError(type(input_data_featuretest))
transformed_X = validator.transform(complementary_type)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
assert np.issubdtype(transformed_X.dtype, np.number)
assert validator._is_fitted

@@ -314,12 +311,6 @@ def test_featurevalidator_get_columns_to_encode():
for col in df.columns:
df[col] = df[col].astype(col)

<<<<<<< HEAD
transformed_columns, feature_types = validator._get_columns_to_encode(df)

assert transformed_columns == ['category', 'bool']
assert feature_types == ['numerical', 'numerical', 'categorical', 'categorical']
=======
validator.fit(df)

categorical_columns, numerical_columns, feat_type = validator._get_columns_info(df)
@@ -435,7 +426,6 @@ def test_feature_validator_remove_nan_catcolumns():
)
ans_test = np.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float64)
feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)
>>>>>>> Bug fixes (#249)


def test_features_unsupported_calls_are_raised():
@@ -445,36 +435,29 @@ def test_features_unsupported_calls_are_raised():
expected
"""
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"AutoPyTorch does not support time"):
with pytest.raises(TypeError, match=r".*?Convert the time information to a numerical value"):
validator.fit(
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
)
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
validator.fit({'input1': 1, 'input2': 2})
with pytest.raises(ValueError, match=r"has unsupported dtype string"):
validator = TabularFeatureValidator()
with pytest.raises(TypeError, match=r".*?but input column A has an invalid type `string`.*"):
validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string'))
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"):
validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]),
X_test=np.array([[1, 2, 3, 4], [4, 5, 6, 7]]),
)
validator = TabularFeatureValidator()
with pytest.raises(ValueError, match=r"Cannot call transform on a validator that is not fit"):
validator.transform(np.array([[1, 2, 3], [4, 5, 6]]))


@pytest.mark.parametrize(
'input_data_featuretest',
(
'numpy_numericalonly_nonan',
'numpy_numericalonly_nan',
'pandas_numericalonly_nonan',
'pandas_numericalonly_nan',
'list_numericalonly_nonan',
'list_numericalonly_nan',
# Category in numpy is handled via feat_type
'numpy_categoricalonly_nonan',
'numpy_mixed_nonan',
'numpy_categoricalonly_nan',
'numpy_mixed_nan',
'sparse_bsr_nonan',
'sparse_bsr_nan',
'sparse_coo_nonan',
@@ -512,7 +495,7 @@
)
def test_column_transformer_created(input_data_featuretest):
"""
This test ensures an encoder is created if categorical data is provided
This test ensures a column transformer is created if categorical data is provided
"""
validator = TabularFeatureValidator()
validator.fit(input_data_featuretest)
@@ -521,7 +504,7 @@ def test_column_transformer_created(input_data_featuretest):

# Make sure that the encoded features are actually encoded. Categorical columns are at
# the start after transformation. In our fixtures, this is also honored prior to encoding.
transformed_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest)
cat_columns, _, feature_types = validator._get_columns_info(input_data_featuretest)

# At least one categorical
assert 'categorical' in validator.feat_type
@@ -530,20 +513,13 @@ def test_column_transformer_created(input_data_featuretest):
if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col]
) for col in input_data_featuretest.columns]):
assert 'numerical' in validator.feat_type
for i, feat_type in enumerate(feature_types):
if 'numerical' in feat_type:
np.testing.assert_array_equal(
transformed_X[:, i],
input_data_featuretest[input_data_featuretest.columns[i]].to_numpy()
)
elif 'categorical' in feat_type:
np.testing.assert_array_equal(
transformed_X[:, i],
# Expect always 0, 1... because we use a ordinal encoder
np.array([0, 1])
)
else:
raise ValueError(feat_type)
# we expect this input to be the fixture 'pandas_mixed_nan'
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., -1.], [0., 1., 1.]]))
else:
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]]))

if not all([feat_type in ['numerical', 'categorical'] for feat_type in feature_types]):
raise ValueError("Expected only numerical and categorical feature types")


def test_no_new_category_after_fit():
@@ -575,13 +551,12 @@ def test_unknown_encode_value():
x['c'].cat.add_categories(['NA'], inplace=True)
x.loc[0, 'c'] = 'NA' # unknown value
x_t = validator.transform(x)
# The first row should have a -1 as we added a new categorical there
expected_row = [-1, -41, -3, -987.2]
# The first row should start with 0, 0 because we added a new
# category there and the one-hot encoder marks an unknown
# category as all zeros in the transformed columns
expected_row = [0.0, 0.0, -0.5584294383572701, 0.5000000000000004, -1.5136598016833485]
assert expected_row == x_t[0].tolist()

# Notice how there is only one column 'c' to encode
assert validator.categories == [list(range(2)) for i in range(1)]
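The new expectation follows from switching the validator from an ordinal to a one-hot encoder: with unknown-category handling enabled (presumably how the validator's encoder is configured), a value unseen at fit time encodes to all zeros across that column's one-hot block, while the remaining entries of `expected_row` look like scaled numerical features. A minimal scikit-learn illustration:

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(np.array([['a'], ['b']]))

# A category never seen during fit() maps to all zeros.
print(enc.transform(np.array([['NA']])).toarray())  # [[0. 0.]]
```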


# Actual checks for the features
@pytest.mark.parametrize(
@@ -633,19 +608,20 @@ def test_feature_validator_new_data_after_fit(
assert sparse.issparse(transformed_X)
else:
assert isinstance(transformed_X, np.ndarray)
assert np.shape(X_test) == np.shape(transformed_X)

# And then check proper error messages
if train_data_type == 'pandas':
old_dtypes = copy.deepcopy(validator.dtypes)
validator.dtypes = ['dummy' for dtype in X_train.dtypes]
with pytest.raises(ValueError, match=r"Changing the dtype of the features after fit"):
with pytest.raises(ValueError,
match=r"The dtype of the features must not be changed after fit"):
transformed_X = validator.transform(X_test)
validator.dtypes = old_dtypes
if test_data_type == 'pandas':
columns = X_test.columns.tolist()
X_test = X_test[reversed(columns)]
with pytest.raises(ValueError, match=r"Changing the column order of the features"):
with pytest.raises(ValueError,
match=r"The column order of the features must not be changed after fit"):
transformed_X = validator.transform(X_test)


20 changes: 0 additions & 20 deletions test/test_data/test_validation.py
@@ -1,7 +1,5 @@
import numpy as np

import pandas as pd

import pytest

from scipy import sparse
@@ -32,14 +30,6 @@ def test_data_validation_for_classification(openmlid, as_frame):

validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
X_train_t, y_train_t = validator.transform(X_train, y_train)
assert np.shape(X_train) == np.shape(X_train_t)

# Leave columns that are complete NaN
# The sklearn pipeline will handle that
if as_frame and np.any(pd.isnull(X_train).values.all(axis=0)):
assert np.any(pd.isnull(X_train_t).values.all(axis=0))
elif not as_frame and np.any(pd.isnull(X_train).all(axis=0)):
assert np.any(pd.isnull(X_train_t).all(axis=0))

# make sure everything was encoded to number
assert np.issubdtype(X_train_t.dtype, np.number)
@@ -74,14 +64,6 @@ def test_data_validation_for_regression(openmlid, as_frame):
validator.fit(X_train=X_train, y_train=y_train)

X_train_t, y_train_t = validator.transform(X_train, y_train)
assert np.shape(X_train) == np.shape(X_train_t)

# Leave columns that are complete NaN
# The sklearn pipeline will handle that
if as_frame and np.any(pd.isnull(X_train).values.all(axis=0)):
assert np.any(pd.isnull(X_train_t).values.all(axis=0))
elif not as_frame and np.any(pd.isnull(X_train).all(axis=0)):
assert np.any(pd.isnull(X_train_t).all(axis=0))

# make sure everything was encoded to number
assert np.issubdtype(X_train_t.dtype, np.number)
@@ -103,8 +85,6 @@ def test_sparse_data_validation_for_regression():
validator.fit(X_train=X_sp, y_train=y)

X_t, y_t = validator.transform(X, y)
assert np.shape(X) == np.shape(X_t)

# make sure everything was encoded to number
assert np.issubdtype(X_t.dtype, np.number)
assert np.issubdtype(y_t.dtype, np.number)