Skip to content

Commit

Permalink
fixes after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Feb 28, 2022
1 parent 9c28f3a commit 484ead4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 34 deletions.
13 changes: 4 additions & 9 deletions autoPyTorch/api/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,8 @@ def search(
y_test=y_test,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
<<<<<<< HEAD
dataset_name=dataset_name,
dataset_compression=self._dataset_compression)
=======
dataset_name=dataset_name
)
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)

return self._search(
dataset=self.dataset,
Expand Down Expand Up @@ -465,23 +460,23 @@ def predict(
raise ValueError("predict() is only supported after calling search. Kindly call first "
"the estimator search() method.")

X_test = self.InputValidator.feature_validator.transform(X_test)
X_test = self.input_validator.feature_validator.transform(X_test)
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
n_jobs=n_jobs)

if self.InputValidator.target_validator.is_single_column_target():
if self.input_validator.target_validator.is_single_column_target():
predicted_indexes = np.argmax(predicted_probabilities, axis=1)
else:
predicted_indexes = (predicted_probabilities > 0.5).astype(int)

# Allow to predict in the original domain -- that is, the user is not interested
# in our encoded values
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
return self.input_validator.target_validator.inverse_transform(predicted_indexes)

def predict_proba(self,
X_test: Union[np.ndarray, pd.DataFrame, List],
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
if self.InputValidator is None or not self.InputValidator._is_fitted:
if self.input_validator is None or not self.input_validator._is_fitted:
raise ValueError("predict() is only supported after calling search. Kindly call first "
"the estimator search() method.")
X_test = self.input_validator.feature_validator.transform(X_test)
Expand Down
11 changes: 3 additions & 8 deletions autoPyTorch/api/tabular_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,8 @@ def search(
y_test=y_test,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
<<<<<<< HEAD
dataset_name=dataset_name,
dataset_compression=self._dataset_compression)
=======
dataset_name=dataset_name
)
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)

return self._search(
dataset=self.dataset,
Expand All @@ -452,14 +447,14 @@ def predict(
batch_size: Optional[int] = None,
n_jobs: int = 1
) -> np.ndarray:
if self.InputValidator is None or not self.InputValidator._is_fitted:
if self.input_validator is None or not self.input_validator._is_fitted:
raise ValueError("predict() is only supported after calling search. Kindly call first "
"the estimator search() method.")

X_test = self.InputValidator.feature_validator.transform(X_test)
X_test = self.input_validator.feature_validator.transform(X_test)
predicted_values = super().predict(X_test, batch_size=batch_size,
n_jobs=n_jobs)

# Allow to predict in the original domain -- that is, the user is not interested
# in our encoded values
return self.InputValidator.target_validator.inverse_transform(predicted_values)
return self.input_validator.target_validator.inverse_transform(predicted_values)
12 changes: 6 additions & 6 deletions autoPyTorch/data/base_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ def _fit(

def _check_data(
self,
X: SUPPORTED_FEAT_TYPES,
X: SupportedFeatTypes,
) -> None:
"""
Feature dimensionality and data type checks
Args:
X (SUPPORTED_FEAT_TYPES):
X (SupportedFeatTypes):
A set of features that are going to be validated (type and dimensionality
checks) and a encoder fitted in the case the data needs encoding
"""
Expand All @@ -144,19 +144,19 @@ def transform(

def list_to_pandas(
self,
X_train: SUPPORTED_FEAT_TYPES,
X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
X_train: SupportedFeatTypes,
X_test: Optional[SupportedFeatTypes] = None,
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
"""
Converts a list to a pandas DataFrame. In this process, column types are inferred.
If test data is provided, we proactively match it to train data
Args:
X_train (SUPPORTED_FEAT_TYPES):
X_train (SupportedFeatTypes):
A set of features that are going to be validated (type and dimensionality
checks) and a encoder fitted in the case the data needs encoding
X_test (Optional[SUPPORTED_FEAT_TYPES]):
X_test (Optional[SupportedFeatTypes]):
A hold out set of data used for checking
Returns:
pd.DataFrame:
Expand Down
11 changes: 1 addition & 10 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import functools
<<<<<<< HEAD
from logging import Logger
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast
=======
from typing import Dict, List, Optional, Tuple, Type, Union, cast
>>>>>>> [FIX] Tests after rebase of `reg_cocktails` (#359)
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union, cast

import numpy as np

Expand Down Expand Up @@ -283,13 +279,8 @@ def transform(
if isinstance(X, np.ndarray):
X = self.numpy_to_pandas(X)

<<<<<<< HEAD
if hasattr(X, "iloc") and not issparse(X):
X = cast(pd.DataFrame, X)
=======
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
X = cast(Type[pd.DataFrame], X)
>>>>>>> [FIX] Tests after rebase of `reg_cocktails` (#359)

if self.all_nan_columns is None:
raise ValueError('_fit must be called before calling transform')
Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/data/tabular_target_validator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, cast
from typing import List, Optional, Union, cast

import numpy as np

Expand Down

0 comments on commit 484ead4

Please sign in to comment.