diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index cc1e249cd..4433144d9 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -203,6 +203,9 @@ def __init__(
         self._logger: Optional[PicklableClientLogger] = None
         self.dataset_name: Optional[str] = None
         self.cv_models_: Dict = {}
+        self.precision: Optional[int] = None
+        self.opt_metric: Optional[str] = None
+        self.dataset: Optional[BaseDataset] = None
 
         self._results_manager = ResultsManager()
 
@@ -420,6 +423,7 @@ def _clean_logger(self) -> None:
         self.logging_server.join(timeout=5)
         self.logging_server.terminate()
         del self.stop_logging_server
+        self._logger = None
 
     def _create_dask_client(self) -> None:
         """
@@ -435,7 +439,7 @@ def _create_dask_client(self) -> None:
             dask.distributed.LocalCluster(
                 n_workers=self.n_jobs,
                 processes=True,
-                threads_per_worker=1,
+                threads_per_worker=self.n_threads,
                 # We use the temporal directory to save the
                 # dask workers, because deleting workers takes
                 # more time than deleting backend directories
@@ -516,6 +520,23 @@ def _load_models(self) -> bool:
 
         return True
 
+    def _cleanup(self) -> None:
+        """
+        Closes the different servers created during the API search.
+        Returns:
+            None
+        """
+        if hasattr(self, '_logger') and self._logger is not None:
+            self._logger.info("Closing the dask infrastructure")
+            self._close_dask_client()
+            self._logger.info("Finished closing the dask infrastructure")
+
+            # Clean up the logger
+            self._logger.info("Starting to clean up the logger")
+            self._clean_logger()
+        else:
+            self._close_dask_client()
+
     def _load_best_individual_model(self) -> SingleBest:
         """
         In case of failure during ensemble building,
@@ -754,6 +775,38 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
                                            save_external=True)
         return
 
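The `run_traditional_ml` helper added below budgets its runtime from the task stopwatch: it spends whatever is left of `runtime_limit` after the current task's elapsed wall time. A minimal, self-contained sketch of that bookkeeping; the `Stopwatch` class here is a hypothetical stand-in for autoPyTorch's internal `_stopwatch`, not its real implementation:

```python
import time


class Stopwatch:
    """Hypothetical stand-in for the internal _stopwatch object."""
    def __init__(self) -> None:
        self._starts: dict = {}

    def start_task(self, name: str) -> None:
        self._starts[name] = time.time()

    def wall_elapsed(self, name: str) -> float:
        return time.time() - self._starts[name]


stopwatch = Stopwatch()
stopwatch.start_task('runSearch')
time.sleep(0.1)  # stand-in for work already done in the current task

runtime_limit = 60  # seconds granted to traditional models
# Same arithmetic as run_traditional_ml: spend only what is left over.
time_for_traditional = int(runtime_limit - stopwatch.wall_elapsed('runSearch'))
print(time_for_traditional)
```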
Skipping...") + else: + traditional_task_name = 'runTraditional' + self._stopwatch.start_task(traditional_task_name) + elapsed_time = self._stopwatch.wall_elapsed(current_task_name) + time_for_traditional = int(runtime_limit - elapsed_time) + self._do_traditional_prediction( + func_eval_time_limit_secs=func_eval_time_limit_secs, + time_left=time_for_traditional, + ) + self._stopwatch.stop_task(traditional_task_name) + def _search( self, optimize_metric: str, @@ -905,8 +958,10 @@ def _search( """ if self.task_type != dataset.task_type: raise ValueError("Incompatible dataset entered for current task," - "expected dataset to have task type :{} got " + "expected dataset to have task type :{} but got " ":{}".format(self.task_type, dataset.task_type)) + if precision not in [16, 32, 64]: + raise ValueError("precision must be one of 16, 32, 64 but got {}".format(precision)) # Initialise information needed for the experiment experiment_task_name: str = 'runSearch' @@ -1001,28 +1056,25 @@ def _search( ) # ============> Run dummy predictions - dummy_task_name = 'runDummy' - self._stopwatch.start_task(dummy_task_name) - self._do_dummy_prediction() - self._stopwatch.stop_task(dummy_task_name) + # We only want to run dummy predictions in case we want to build an ensemble + if self.ensemble_size > 0: + dummy_task_name = 'runDummy' + self._stopwatch.start_task(dummy_task_name) + self._do_dummy_prediction() + self._stopwatch.stop_task(dummy_task_name) # ============> Run traditional ml - - if enable_traditional_pipeline: - traditional_task_name = 'runTraditional' - self._stopwatch.start_task(traditional_task_name) - elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) - # We want time for at least 1 Neural network in SMAC - time_for_traditional = int( - self._time_for_task - elapsed_time - func_eval_time_limit_secs - ) - self._do_traditional_prediction( - func_eval_time_limit_secs=func_eval_time_limit_secs, - time_left=time_for_traditional, - ) - self._stopwatch.stop_task(traditional_task_name) + # We only want to run traditional predictions in case we want to build an ensemble + # We want time for at least 1 Neural network in SMAC + if enable_traditional_pipeline and self.ensemble_size > 0: + traditional_runtime_limit = int(self._time_for_task - func_eval_time_limit_secs) + self.run_traditional_ml(current_task_name=self.dataset_name, + runtime_limit=traditional_runtime_limit, + func_eval_time_limit_secs=func_eval_time_limit_secs) # ============> Starting ensemble + self.precision = precision + self.opt_metric = optimize_metric elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time) proc_ensemble = None @@ -1039,27 +1091,12 @@ def _search( self._logger.info("Starting ensemble") ensemble_task_name = 'ensemble' self._stopwatch.start_task(ensemble_task_name) - proc_ensemble = EnsembleBuilderManager( - start_time=time.time(), - time_left_for_ensembles=time_left_for_ensembles, - backend=copy.deepcopy(self._backend), - dataset_name=str(dataset.dataset_name), - output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type], - task_type=STRING_TO_TASK_TYPES[self.task_type], - metrics=[self._metric], - opt_metric=optimize_metric, - ensemble_size=self.ensemble_size, - ensemble_nbest=self.ensemble_nbest, - max_models_on_disc=self.max_models_on_disc, - seed=self.seed, - max_iterations=None, - read_at_most=sys.maxsize, - ensemble_memory_limit=self._memory_limit, - random_state=self.seed, - precision=precision, - 
-            proc_ensemble = EnsembleBuilderManager(
-                start_time=time.time(),
-                time_left_for_ensembles=time_left_for_ensembles,
-                backend=copy.deepcopy(self._backend),
-                dataset_name=str(dataset.dataset_name),
-                output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
-                task_type=STRING_TO_TASK_TYPES[self.task_type],
-                metrics=[self._metric],
-                opt_metric=optimize_metric,
-                ensemble_size=self.ensemble_size,
-                ensemble_nbest=self.ensemble_nbest,
-                max_models_on_disc=self.max_models_on_disc,
-                seed=self.seed,
-                max_iterations=None,
-                read_at_most=sys.maxsize,
-                ensemble_memory_limit=self._memory_limit,
-                random_state=self.seed,
-                precision=precision,
-                logger_port=self._logger_port,
-                pynisher_context=self._multiprocessing_context,
-            )
+            proc_ensemble = self._init_ensemble_builder(time_left_for_ensembles=time_left_for_ensembles,
+                                                        ensemble_size=self.ensemble_size,
+                                                        ensemble_nbest=self.ensemble_nbest,
+                                                        precision=precision,
+                                                        optimize_metric=self.opt_metric
+                                                        )
             self._stopwatch.stop_task(ensemble_task_name)
 
         # ==> Run SMAC
@@ -1145,18 +1182,12 @@ def _search(
             pd.DataFrame(self.ensemble_performance_history).to_json(
                 os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
 
-        self._logger.info("Closing the dask infrastructure")
-        self._close_dask_client()
-        self._logger.info("Finished closing the dask infrastructure")
-
         if load_models:
             self._logger.info("Loading models...")
             self._load_models()
             self._logger.info("Finished loading models...")
 
-        # Clean up the logger
-        self._logger.info("Starting to clean up the logger")
-        self._clean_logger()
+        self._cleanup()
 
         return self
 
@@ -1290,7 +1321,7 @@ def fit(self,
                                                exclude=self.exclude_components,
                                                search_space_updates=self.search_space_updates)
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-        self._backend.save_datamanager(dataset)
+        self._backend.replace_datamanager(dataset)
 
         # build pipeline
         pipeline = self.build_pipeline(dataset_properties)
@@ -1309,6 +1340,207 @@ def fit(self,
 
         return pipeline
 
+    def fit_ensemble(
+            self,
+            optimize_metric: Optional[str] = None,
+            precision: Optional[int] = None,
+            ensemble_nbest: int = 50,
+            ensemble_size: int = 50,
+            load_models: bool = True,
+            time_for_task: int = 100,
+            func_eval_time_limit_secs: Optional[int] = None,
+            enable_traditional_pipeline: bool = True,
+    ) -> 'BaseTask':
+        """
+        Enables post-hoc fitting of the ensemble after the `search()`
+        method is finished. This method creates an ensemble using all
+        the models stored on disk during the smbo run.
+
+        Args:
+            optimize_metric (Optional[str]): name of the metric that is used to
+                evaluate a pipeline. If not specified, the value passed to search() is used.
+            precision (Optional[int]), (default=None): Numeric precision used when loading
+                ensemble data. Can be either 16, 32 or 64. If not specified,
+                the precision used during search() is used.
+            ensemble_nbest (Optional[int]):
+                only consider the ensemble_nbest models to build the ensemble.
+                If None, uses the value stored in class attribute `ensemble_nbest`.
+            ensemble_size (int) (default=50):
+                Number of models added to the ensemble built by
+                Ensemble selection from libraries of models.
+                Models are drawn with replacement.
+            enable_traditional_pipeline (bool), (default=True):
+                We fit traditional machine learning algorithms
+                (LightGBM, CatBoost, RandomForest, ExtraTrees, KNN, SVM)
+                prior to building PyTorch Neural Networks. You can disable this
+                feature by turning this flag to False. All machine learning
+                algorithms that are fitted during search() are considered for
+                ensemble building.
+            load_models (bool), (default=True): Whether to load the
+                models after fitting AutoPyTorch.
+            time_for_task (int), (default=100): Time limit
+                in seconds for fitting the ensemble.
+                By increasing this value, autopytorch has a higher
+                chance of finding better models.
+            func_eval_time_limit_secs (Optional[int]), (default=None): Time limit
+                for a single call to the machine learning model.
+                Model fitting will be terminated if the machine
+                learning algorithm runs over the time limit. Set
+                this value high enough so that typical machine
+                learning algorithms can be fit on the training
+                data.
+                When set to None, this time will automatically be set to
+                time_for_task // 2 to allow enough time to fit
+                at least 2 individual machine learning algorithms.
+                Set to np.inf in case no time limit is desired.
+
+        Returns:
+            self
+        """
+        # Make sure that input is valid
+        if self.dataset is None or self.opt_metric is None:
+            raise ValueError("fit_ensemble() can only be called after `search()`. "
+                             "Please call the `search()` method of {} prior to "
+                             "fit_ensemble().".format(self.__class__.__name__))
+
+        if precision is not None and precision not in (16, 32, 64):
+            raise ValueError("precision must be one of 16, 32, 64 but got {}".format(precision))
+
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset.dataset_name)
+
+        # Create a client if needed
+        if self._dask_client is None:
+            self._create_dask_client()
+        else:
+            self._is_dask_client_internally_created = False
+
+        ensemble_fit_task_name = 'EnsembleFit'
+        self._stopwatch.start_task(ensemble_fit_task_name)
+        if enable_traditional_pipeline:
+            if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_for_task:
+                self._logger.warning(
+                    'Time limit for a single run is higher than total time '
+                    'limit. Capping the limit for a single run to the total '
+                    'time given to Ensemble fit (%f)' % time_for_task
+                )
+                func_eval_time_limit_secs = time_for_task
+
+            # Make sure that at least 2 models are created for the ensemble process
+            num_models = time_for_task // func_eval_time_limit_secs
+            if num_models < 2:
+                func_eval_time_limit_secs = time_for_task // 2
+                self._logger.warning(
+                    "Capping the func_eval_time_limit_secs to {} to have "
+                    "time for at least 2 models to ensemble.".format(
+                        func_eval_time_limit_secs
+                    )
+                )
+        # ============> Run Dummy predictions
+        dummy_task_name = 'runDummy'
+        self._stopwatch.start_task(dummy_task_name)
+        self._do_dummy_prediction()
+        self._stopwatch.stop_task(dummy_task_name)
+
+        # ============> Run traditional ml
+        if enable_traditional_pipeline:
+            self.run_traditional_ml(current_task_name=ensemble_fit_task_name,
+                                    runtime_limit=time_for_task,
+                                    func_eval_time_limit_secs=func_eval_time_limit_secs)
+
+        elapsed_time = self._stopwatch.wall_elapsed(ensemble_fit_task_name)
+        time_left_for_ensemble = int(time_for_task - elapsed_time)
+        manager = self._init_ensemble_builder(
+            time_left_for_ensembles=time_left_for_ensemble,
+            optimize_metric=self.opt_metric if optimize_metric is None else optimize_metric,
+            precision=self.precision if precision is None else precision,
+            ensemble_size=ensemble_size,
+            ensemble_nbest=ensemble_nbest,
+        )
+
+        manager.build_ensemble(self._dask_client)
+        future = manager.futures.pop()
+        result = future.result()
+        if result is None:
+            raise ValueError("Errors occurred while building the ensemble - please"
+                             " check the log file and command line output for error messages.")
+        self.ensemble_performance_history, _, _, _ = result
+
+        if load_models:
+            self._load_models()
+
+        self._stopwatch.stop_task(ensemble_fit_task_name)
+
+        self._cleanup()
+
+        return self
+
+    def _init_ensemble_builder(
+            self,
+            time_left_for_ensembles: float,
+            optimize_metric: str,
+            ensemble_nbest: int,
+            ensemble_size: int,
+            precision: int = 32,
+    ) -> EnsembleBuilderManager:
+        """
+        Initializes an `EnsembleBuilderManager`.
+        Args:
+            time_left_for_ensembles (float):
+                Time (in seconds) allocated to building the ensemble
+            optimize_metric (str):
+                Name of the metric to optimize the ensemble.
+            ensemble_nbest (int):
+                only consider the ensemble_nbest models to build the ensemble.
+            ensemble_size (int):
+                Number of models added to the ensemble built by
+                Ensemble selection from libraries of models.
+                Models are drawn with replacement.
+            precision (int), (default=32): Numeric precision used when loading
+                ensemble data. Can be either 16, 32 or 64.
+
+        Returns:
+            EnsembleBuilderManager
+        """
+        if self._logger is None:
+            raise ValueError("logger should be initialized to fit ensemble")
+        if self.dataset is None:
+            raise ValueError("ensemble can only be initialised after or during `search()`. "
+                             "Please call the `search()` method of {}.".format(self.__class__.__name__))
+
+        self._logger.info("Starting ensemble")
+        ensemble_task_name = 'ensemble'
+        self._stopwatch.start_task(ensemble_task_name)
+
+        # Use the current thread to start the ensemble builder process
+        # The function ensemble_builder_process will internally create an ensemble
+        # builder in the provided dask client
+        required_dataset_properties = {'task_type': self.task_type,
+                                       'output_type': self.dataset.output_type}
+        proc_ensemble = EnsembleBuilderManager(
+            start_time=time.time(),
+            time_left_for_ensembles=time_left_for_ensembles,
+            backend=copy.deepcopy(self._backend),
+            dataset_name=str(self.dataset.dataset_name),
+            output_type=STRING_TO_OUTPUT_TYPES[self.dataset.output_type],
+            task_type=STRING_TO_TASK_TYPES[self.task_type],
+            metrics=[self._metric] if self._metric is not None else get_metrics(
+                dataset_properties=required_dataset_properties, names=[optimize_metric]),
+            opt_metric=optimize_metric,
+            ensemble_size=ensemble_size,
+            ensemble_nbest=ensemble_nbest,
+            max_models_on_disc=self.max_models_on_disc,
+            seed=self.seed,
+            max_iterations=None,
+            read_at_most=sys.maxsize,
+            ensemble_memory_limit=self._memory_limit,
+            random_state=self.seed,
+            precision=precision,
+            logger_port=self._logger_port,
+        )
+        self._stopwatch.stop_task(ensemble_task_name)
+
+        return proc_ensemble
+
     def predict(
         self,
         X_test: np.ndarray,
@@ -1359,7 +1591,7 @@ def predict(
 
             predictions = self.ensemble_.predict(all_predictions)
 
-        self._clean_logger()
+        self._cleanup()
 
         return predictions
 
@@ -1399,10 +1631,7 @@ def __getstate__(self) -> Dict[str, Any]:
         return self.__dict__
 
     def __del__(self) -> None:
-        # Clean up the logger
-        self._clean_logger()
-
-        self._close_dask_client()
+        self._cleanup()
 
         # When a multiprocessing work is done, the
         # objects are deleted. We don't want to delete run areas
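Taken together, the `fit_ensemble` / `_init_ensemble_builder` pair above allows ensembling to be deferred until after the search. A rough usage sketch under stated assumptions: the data is random placeholder data, and passing `ensemble_size=0` to skip ensembling during search follows from the `self.ensemble_size > 0` guards added in `_search`; exact argument names beyond those in this diff follow the public API.

```python
import numpy as np
from autoPyTorch.api.tabular_classification import TabularClassificationTask

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 4))
y_train = rng.integers(0, 2, size=100)

api = TabularClassificationTask(ensemble_size=0)   # skip ensembling during search
api.search(X_train=X_train, y_train=y_train,
           optimize_metric='accuracy',
           total_walltime_limit=300,
           func_eval_time_limit_secs=50)

# Post-hoc: build the ensemble from the models the search left on disk.
api.fit_ensemble(ensemble_size=50, ensemble_nbest=50, precision=32)
predictions = api.predict(rng.normal(size=(10, 4)))
```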
diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py
index 635786d83..de36cdb11 100644
--- a/autoPyTorch/api/tabular_classification.py
+++ b/autoPyTorch/api/tabular_classification.py
@@ -304,6 +304,8 @@ def search(
         )
 
+        if self.dataset is None:
+            raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
         return self._search(
             dataset=self.dataset,
             optimize_metric=optimize_metric,
diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py
index 8187e7e6a..a82e4b64d 100644
--- a/autoPyTorch/api/tabular_regression.py
+++ b/autoPyTorch/api/tabular_regression.py
@@ -66,7 +66,6 @@ class TabularRegressionTask(BaseTask):
         search space updates that can be used to modify the search
         space of particular components or choice modules of the pipeline
     """
-
     def __init__(
         self,
         seed: int = 1,
@@ -303,6 +302,8 @@ def search(
         )
 
+        if self.dataset is None:
+            raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
         return self._search(
             dataset=self.dataset,
             optimize_metric=optimize_metric,
""" def __init__( self, logger: Optional[Union[PicklableClientLogger, logging.Logger]] = None, - ): + ) -> None: # Register types to detect unsupported data format changes self.feat_type: Optional[List[str]] = None self.data_type: Optional[type] = None @@ -52,7 +50,6 @@ def __init__( self.column_order: List[str] = [] self.column_transformer: Optional[BaseEstimator] = None - self.transformed_columns: List[str] = [] self.logger: Union[ PicklableClientLogger, logging.Logger @@ -64,6 +61,8 @@ def __init__( self.categorical_columns: List[int] = [] self.numerical_columns: List[int] = [] + self.all_nan_columns: Optional[Set[Union[int, str]]] = None + self._is_fitted = False def fit( @@ -86,7 +85,7 @@ def fit( # If a list was provided, it will be converted to pandas if isinstance(X_train, list): - X_train, X_test = self.list_to_dataframe(X_train, X_test) + X_train, X_test = self.list_to_pandas(X_train, X_test) self._check_data(X_train) @@ -120,6 +119,7 @@ def _fit( self: The fitted base estimator """ + raise NotImplementedError() def _check_data( @@ -129,11 +129,12 @@ def _check_data( """ Feature dimensionality and data type checks - Arguments: + Args: X (SUPPORTED_FEAT_TYPES): A set of features that are going to be validated (type and dimensionality checks) and a encoder fitted in the case the data needs encoding """ + raise NotImplementedError() def transform( @@ -150,4 +151,30 @@ def transform( np.ndarray: The transformed array """ + + raise NotImplementedError() + + def list_to_pandas( + self, + X_train: SUPPORTED_FEAT_TYPES, + X_test: Optional[SUPPORTED_FEAT_TYPES] = None, + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + """ + Converts a list to a pandas DataFrame. In this process, column types are inferred. + + If test data is provided, we proactively match it to train data + + Args: + X_train (SUPPORTED_FEAT_TYPES): + A set of features that are going to be validated (type and dimensionality + checks) and a encoder fitted in the case the data needs encoding + X_test (Optional[SUPPORTED_FEAT_TYPES]): + A hold out set of data used for checking + Returns: + pd.DataFrame: + transformed train data from list to pandas DataFrame + pd.DataFrame: + transformed test data from list to pandas DataFrame + """ + raise NotImplementedError() diff --git a/autoPyTorch/data/base_target_validator.py b/autoPyTorch/data/base_target_validator.py index 393f3d85b..03f155f61 100644 --- a/autoPyTorch/data/base_target_validator.py +++ b/autoPyTorch/data/base_target_validator.py @@ -48,7 +48,7 @@ def __init__(self, logging.Logger ] ] = None, - ): + ) -> None: self.is_classification = is_classification self.data_type: Optional[type] = None @@ -98,6 +98,7 @@ def fit( np.shape(y_test) )) if isinstance(y_train, pd.DataFrame): + y_train = cast(pd.DataFrame, y_train) y_test = cast(pd.DataFrame, y_test) if y_train.columns.tolist() != y_test.columns.tolist(): raise ValueError( @@ -143,7 +144,7 @@ def _fit( def transform( self, - y: Union[SUPPORTED_TARGET_TYPES], + y: SUPPORTED_TARGET_TYPES, ) -> np.ndarray: """ Args: diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index f81629217..fee7bc49b 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -9,18 +9,19 @@ import scipy.sparse import sklearn.utils -from sklearn import preprocessing from sklearn.base import BaseEstimator from sklearn.compose import ColumnTransformer from sklearn.exceptions import NotFittedError from sklearn.impute import SimpleImputer from 
diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index f81629217..fee7bc49b 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -9,18 +9,19 @@
 import scipy.sparse
 
 import sklearn.utils
-from sklearn import preprocessing
 from sklearn.base import BaseEstimator
 from sklearn.compose import ColumnTransformer
 from sklearn.exceptions import NotFittedError
 from sklearn.impute import SimpleImputer
 from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import OneHotEncoder, StandardScaler
 
 from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES
 
 
 def _create_column_transformer(
     preprocessors: Dict[str, List[BaseEstimator]],
+    numerical_columns: List[str],
     categorical_columns: List[str],
 ) -> ColumnTransformer:
     """
@@ -31,6 +32,8 @@ def _create_column_transformer(
     Args:
         preprocessors (Dict[str, List[BaseEstimator]]):
             Dictionary containing list of numerical and categorical preprocessors.
+        numerical_columns (List[str]):
+            List of names of numerical columns
         categorical_columns (List[str]):
             List of names of categorical columns
 
@@ -38,11 +41,17 @@ def _create_column_transformer(
         ColumnTransformer
     """
 
-    categorical_pipeline = make_pipeline(*preprocessors['categorical'])
+    numerical_pipeline = 'drop'
+    categorical_pipeline = 'drop'
+    if len(numerical_columns) > 0:
+        numerical_pipeline = make_pipeline(*preprocessors['numerical'])
+    if len(categorical_columns) > 0:
+        categorical_pipeline = make_pipeline(*preprocessors['categorical'])
 
     return ColumnTransformer([
-        ('categorical_pipeline', categorical_pipeline, categorical_columns)],
-        remainder='passthrough'
+        ('categorical_pipeline', categorical_pipeline, categorical_columns),
+        ('numerical_pipeline', numerical_pipeline, numerical_columns)],
+        remainder='drop'
     )
 
 
@@ -57,11 +66,15 @@ def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
     preprocessors: Dict[str, List[BaseEstimator]] = dict()
 
     # Categorical Preprocessors
-    onehot_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
-                                                  unknown_value=-1)
+    onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
     categorical_imputer = SimpleImputer(strategy='constant', copy=False)
 
+    # Numerical Preprocessors
+    numerical_imputer = SimpleImputer(strategy='median', copy=False)
+    standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)
+
     preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
+    preprocessors['numerical'] = [numerical_imputer, standard_scaler]
 
     return preprocessors
 
@@ -111,7 +124,6 @@ def _comparator(cmp1: str, cmp2: str) -> int:
         if cmp1 not in choices or cmp2 not in choices:
             raise ValueError('The comparator for the column order only accepts {}, '
                              'but got {} and {}'.format(choices, cmp1, cmp2))
-
         idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
         return idx1 - idx2
 
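To make the new default preprocessing concrete, here is a small, self-contained sketch that mirrors `get_tabular_preprocessors` and `_create_column_transformer` on a toy frame. The column names and data are invented for illustration, and `sparse=False` matches this diff (newer scikit-learn versions spell it `sparse_output=False`):

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

X = pd.DataFrame({
    'color': pd.Series(['red', None, 'blue'], dtype='category'),  # categorical
    'size': [1.0, np.nan, 3.0],                                   # numerical
})

ct = ColumnTransformer([
    ('categorical_pipeline',
     make_pipeline(SimpleImputer(strategy='constant', copy=False),
                   OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')),
     ['color']),
    ('numerical_pipeline',
     make_pipeline(SimpleImputer(strategy='median', copy=False),
                   StandardScaler(with_mean=True, with_std=True, copy=False)),
     ['size']),
], remainder='drop')

# One-hot block for the imputed categorical column, then the scaled numerical column.
print(ct.fit_transform(X))
```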
@@ -137,64 +149,45 @@ def _fit(
         # The final output of a validator is a numpy array. But pandas
         # gives us information about the column dtype
         if isinstance(X, np.ndarray):
-            X = self.numpy_array_to_pandas(X)
+
+            X = self.numpy_to_pandas(X)
+            # Replace the data type from the previously saved type.
+            self.data_type = type(X)
+            # save all the information about the column order and data types
+            self._check_data(X)
 
         if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
+            X = cast(pd.DataFrame, X)
 
-            # Treat a column with all instances a NaN as numerical
-            # This will prevent doing encoding to a categorical column made completely
-            # out of nan values -- which will trigger a fail, as encoding is not supported
-            # with nan values.
-            # Columns that are completely made of NaN values are provided to the pipeline
-            # so that later stages decide how to handle them
-
-            # Clear whatever null column markers we had previously
-            self.null_columns.clear()
-            if np.any(pd.isnull(X)):
-                for column in X.columns:
-                    if X[column].isna().all():
-                        self.null_columns.add(column)
-                        X[column] = pd.to_numeric(X[column])
-                        # Also note this change in self.dtypes
-                        if len(self.dtypes) != 0:
-                            self.dtypes[list(X.columns).index(column)] = X[column].dtype
-
-            if not X.select_dtypes(include='object').empty:
-                X = self.infer_objects(X)
-
-            self._check_data(X)
-            self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
+            self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])
 
-            if len(self.transformed_columns) > 0:
+            categorical_columns, numerical_columns, feat_type = self._get_columns_info(X)
 
-                preprocessors = get_tabular_preprocessors()
-                self.column_transformer = _create_column_transformer(
-                    preprocessors=preprocessors,
-                    categorical_columns=self.transformed_columns,
-                )
+            self.enc_columns = categorical_columns
 
-                # Mypy redefinition
-                assert self.column_transformer is not None
-                self.column_transformer.fit(X)
+            preprocessors = get_tabular_preprocessors()
+            self.column_transformer = _create_column_transformer(
+                preprocessors=preprocessors,
+                numerical_columns=numerical_columns,
+                categorical_columns=categorical_columns,
+            )
 
-                # The column transformer reorders the feature types
-                # therefore, we need to change the order of columns as well
-                # This means categorical columns are shifted to the left
-                self.feat_type = sorted(
-                    self.feat_type,
-                    key=functools.cmp_to_key(self._comparator)
-                )
+            # Mypy redefinition
+            assert self.column_transformer is not None
+            self.column_transformer.fit(X)
 
-                encoded_categories = self.column_transformer.\
-                    named_transformers_['categorical_pipeline'].\
-                    named_steps['ordinalencoder'].categories_
-                self.categories = [
-                    # We fit an ordinal encoder, where all categorical
-                    # columns are shifted to the left
-                    list(range(len(cat)))
-                    for cat in encoded_categories
-                ]
+            # The column transformer reorders the feature types
+            # therefore, we need to change the order of columns as well
+            # This means categorical columns are shifted to the left
+
+            self.feat_type = sorted(
+                feat_type,
+                key=functools.cmp_to_key(self._comparator)
+            )
 
+            # differently to categorical_columns and numerical_columns,
+            # this saves the index of the column.
             for i, type_ in enumerate(self.feat_type):
                 if 'numerical' in type_:
                     self.numerical_columns.append(i)
@@ -203,6 +196,7 @@ def _fit(
 
         # Lastly, store the number of features
         self.num_features = np.shape(X)[1]
+
         return self
 
     def transform(
@@ -221,49 +215,70 @@ def transform(
 
         Return:
             np.ndarray:
                 The transformed array
+
+        Note:
+            The default transform performs the following:
+                * simple imputation for both numerical and categorical columns
+                * scaling for numerical columns
+                * one-hot encoding for categorical columns
+            For example, here is a simple case
+            in which all the columns are categorical.
+                data = [
+                    {'A': 1, 'B': np.nan, 'C': np.nan},
+                    {'A': np.nan, 'B': 3, 'C': np.nan},
+                    {'A': 2, 'B': np.nan, 'C': np.nan}
+                ]
+            Since all the columns are categorical,
+                * `A` in {np.nan, 1, 2}
+                * `B` in {np.nan, 3}
+                * `C` in {np.nan} <=== it will be dropped.
+
+            So in column A,
+                * np.nan ==> [1, 0, 0] (always the index 0)
+                * 1      ==> [0, 1, 0]
+                * 2      ==> [0, 0, 1]
+            and in column B,
+                * np.nan ==> [1, 0]
+                * 3      ==> [0, 1]
+            Therefore, by concatenating,
+                * {'A': 1, 'B': np.nan, 'C': np.nan} ==> [0, 1, 0, 1, 0]
+                * {'A': np.nan, 'B': 3, 'C': np.nan} ==> [1, 0, 0, 0, 1]
+                * {'A': 2, 'B': np.nan, 'C': np.nan} ==> [0, 0, 1, 1, 0]
+            ==> [
+                [0, 1, 0, 1, 0],
+                [1, 0, 0, 0, 1],
+                [0, 0, 1, 1, 0]
+            ]
        """
        if not self._is_fitted:
            raise NotFittedError("Cannot call transform on a validator that is not fitted")
 
        # If a list was provided, it will be converted to pandas
        if isinstance(X, list):
-            X, _ = self.list_to_dataframe(X)
+            X, _ = self.list_to_pandas(X)
 
        if isinstance(X, np.ndarray):
-            X = self.numpy_array_to_pandas(X)
+            X = self.numpy_to_pandas(X)
 
        if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
-            X = typing.cast(pd.DataFrame, X)
-            # If we had null columns in our fit call and we made them numeric, then:
-            # - If the columns are null even in transform, apply the same procedure.
-            # - Otherwise, substitute the values with np.NaN and then make the columns numeric.
-            # If the column is null here, but it was not in fit, it does not matter.
-            for column in self.null_columns:
-                # The column is not null, make it null since it was null in fit.
-                if not X[column].isna().all():
-                    X[column] = np.NaN
-                X[column] = pd.to_numeric(X[column])
-
-            # for the test set, if we have columns with only null values
-            # they will probably have a numeric type. If these columns were not
-            # with only null values in the train set, they should be converted
-            # to the type that they had during fitting.
-            for column in X.columns:
-                if X[column].isna().all():
-                    X[column] = X[column].astype(self.dtypes[list(X.columns).index(column)])
-
-            # Also remove the object dtype for new data
-            if not X.select_dtypes(include='object').empty:
-                X = self.infer_objects(X)
+            X = cast(pd.DataFrame, X)
 
        # Check the data here so we catch problems on new test data
        self._check_data(X)
 
-        # We also need to fillna on the transformation
-        # in case test data is provided
-        X = self.impute_nan_in_categories(X)
-        if self.encoder is not None:
-            X = self.encoder.transform(X)
+        # In case the test data is all NaN in a categorical column for which
+        # the train data had a value, we need to convert the column in the
+        # test data to object; otherwise the test column is interpreted as float
+        if self.column_transformer is not None:
+            if len(self.categorical_columns) > 0:
+                categorical_columns = self.column_transformer.transformers_[0][-1]
+                for column in categorical_columns:
+                    if X[column].isna().all():
+                        X[column] = X[column].astype('object')
+
+            X = self.column_transformer.transform(X)
 
        # Sparse related transformations
        # Not all sparse format support index sorting
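The `_check_data` changes below relax the strict dtype equality between fit and transform: a dtype change is now tolerated only for columns recorded in `all_nan_columns` (those are imputed anyway). A minimal illustration of that rule, with invented column names and dtypes:

```python
import pandas as pd

fit_dtypes = ['int64', 'category']       # dtypes recorded during fit()
new_dtypes = ['float64', 'category']     # dtypes seen at transform() time
dtypes_diff = [a != b for a, b in zip(fit_dtypes, new_dtypes)]

all_nan_columns = {'col_0'}              # assumed: col_0 was all-NaN during fit
columns = pd.Index(['col_0', 'col_1'])

# Only a dtype change outside of the all-NaN columns is an error.
offending = set(columns[dtypes_diff]).difference(all_nan_columns)
print(offending)  # set() -> tolerated, no ValueError raised
```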
Fit() method was called with "
-                                 "{} whereas the new features have {} as type".format(self.column_order,
-                                                                                      column_order,)
-                             )
+                raise ValueError("The column order of the features must not be changed after fit(), but"
+                                 " the column order is different between training ({}) and"
+                                 " test ({}) datasets.".format(self.column_order, column_order))
        else:
            self.column_order = column_order
 
        dtypes = [dtype.name for dtype in X.dtypes]
-        if len(self.dtypes) > 0:
-            if self.dtypes != dtypes:
-                raise ValueError("Changing the dtype of the features after fit() is "
-                                 "not supported. Fit() method was called with "
-                                 "{} whereas the new features have {} as type".format(self.dtypes,
-                                                                                      dtypes,
-                                                                                      )
-                                 )
-        else:
-            self.dtypes = dtypes
-
-    def _get_columns_to_encode(
+        dtypes_diff = [s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]
+        if len(self.dtypes) == 0:
+            self.dtypes = dtypes
+        elif (
+            any(dtypes_diff)  # the dtypes of some columns are different in train and test dataset
+            and self.all_nan_columns is not None  # ignore the check when all_nan_columns is None
+            and len(set(X.columns[dtypes_diff]).difference(self.all_nan_columns)) != 0
+        ):
+            # The dtypes can be different if and only if the column belongs
+            # to all_nan_columns as these columns would be imputed.
+            raise ValueError("The dtype of the features must not be changed after fit(), but"
+                             " the dtypes of some columns are different between training ({}) and"
+                             " test ({}) datasets.".format(self.dtypes, dtypes))
+
+    def _get_columns_info(
        self,
        X: pd.DataFrame,
-    ) -> Tuple[List[str], List[str]]:
+    ) -> Tuple[List[str], List[str], List[str]]:
        """
-        Return the columns to be encoded from a pandas dataframe
+        Return the column information from a pandas dataframe
 
        Args:
            X (pd.DataFrame):
                A set of features that are going to be validated (type and dimensionality
                checks) and an encoder fitted in the case the data needs encoding
 
        Returns:
-            transformed_columns (List[str]):
-                Columns to encode, if any
-            feat_type:
+            categorical_columns (List[str])
+                List of the names of categorical columns.
+            numerical_columns (List[str])
+                List of the names of numerical columns.
+            feat_type (List[str])
                Type of each column numerical/categorical
        """
@@ -385,60 +401,54 @@
            return self.transformed_columns, self.feat_type
 
        # Register if a column needs encoding
-        transformed_columns = []
-
+        numerical_columns = []
+        categorical_columns = []
        # Also, register the feature types for the estimator
        feat_type = []
 
        # Make sure each column is a valid type
-        for column in X.columns:
-            if X[column].dtype.name in ['category', 'bool']:
-
-                transformed_columns.append(column)
+        for i, column in enumerate(X.columns):
+            if self.all_nan_columns is not None and column in self.all_nan_columns:
+                continue
+            column_dtype = self.dtypes[i]
+            err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
+                      "but input column {} has an invalid type `{}`.".format(column, column_dtype)
+            if column_dtype in ['category', 'bool']:
+                categorical_columns.append(column)
                feat_type.append('categorical')
            # Move away from np.issubdtype as it causes
            # TypeError: data type not understood in certain pandas types
-            elif not is_numeric_dtype(X[column]):
-                if X[column].dtype.name == 'object':
-                    raise ValueError(
-                        "Input Column {} has invalid type object. "
-                        "Cast it to a valid dtype before using it in AutoPyTorch. "
-                        "Valid types are numerical, categorical or boolean. "
-                        "You can cast it to a valid dtype using "
-                        "pandas.Series.astype ."
- "If working with string objects, the following " - "tutorial illustrates how to work with text data: " - "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( - # noqa: E501 - column, - ) - ) - elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype( - X[column].dtype - ): - raise ValueError( - "AutoPyTorch does not support time and/or date datatype as given " - "in column {}. Please convert the time information to a numerical value " - "first. One example on how to do this can be found on " - "https://stats.stackexchange.com/questions/311494/".format( - column, - ) - ) - else: - raise ValueError( - "Input Column {} has unsupported dtype {}. " - "Supported column types are categorical/bool/numerical dtypes. " - "Make sure your data is formatted in a correct way, " - "before feeding it to AutoPyTorch.".format( - column, - X[column].dtype.name, - ) + elif is_numeric_dtype(column_dtype): + feat_type.append('numerical') + numerical_columns.append(column) + elif column_dtype == 'object': + # TODO verify how would this happen when we always convert the object dtypes to category + raise TypeError( + "{} Cast it to a valid dtype before feeding it to AutoPyTorch. " + "You can cast it to a valid dtype using pandas.Series.astype." + "If you are working with string objects, the following " + "tutorial illustrates how to work with text data: " + "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( + # noqa: E501 + err_msg, ) + ) + elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(column_dtype): + raise TypeError( + "{} Convert the time information to a numerical value" + " before feeding it to AutoPyTorch. " + "One example of the conversion can be found on " + "https://stats.stackexchange.com/questions/311494/".format(err_msg) + ) else: - feat_type.append('numerical') - return transformed_columns, feat_type + raise TypeError( + "{} Make sure your data is formatted in a correct way" + "before feeding it to AutoPyTorch.".format(err_msg) + ) + + return categorical_columns, numerical_columns, feat_type - def list_to_dataframe( + def list_to_pandas( self, X_train: SUPPORTED_FEAT_TYPES, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, @@ -463,7 +473,7 @@ def list_to_dataframe( """ # If a list was provided, it will be converted to pandas - X_train = pd.DataFrame(data=X_train).infer_objects() + X_train = pd.DataFrame(data=X_train).convert_dtypes() self.logger.warning("The provided feature types to AutoPyTorch are of type list." "Features have been interpreted as: {}".format([(col, t) for col, t in zip(X_train.columns, X_train.dtypes)])) @@ -472,11 +482,12 @@ def list_to_dataframe( self.logger.warning("Train features are a list while the provided test data" "is {}. 
X_test will be cast as DataFrame.".format(type(X_test))
                                )
-            X_test = pd.DataFrame(data=X_test).infer_objects()
+            X_test = pd.DataFrame(data=X_test).convert_dtypes()
+
        return X_train, X_test
 
-    def numpy_array_to_pandas(
-        self,
+    @staticmethod
+    def numpy_to_pandas(
        X: np.ndarray,
    ) -> pd.DataFrame:
        """
@@ -515,10 +526,33 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
                    self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
                    pass
        else:
+            # Calling for the first time to infer the categories
            X = X.infer_objects()
-            for column in X.columns:
-                if not is_numeric_dtype(X[column]):
+            for column, data_type in zip(X.columns, X.dtypes):
+                if not is_numeric_dtype(data_type):
                    X[column] = X[column].astype('category')
-        self.object_dtype_mapping = {column: X[column].dtype for column in X.columns}
+
+            # only numerical attributes and categories
+            self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}
+
+        self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}")
+
        return X
+
+
+def has_object_columns(
+    feature_types: pd.Series,
+) -> bool:
+    """
+    Indicate whether on a Series of dtypes for a Pandas DataFrame
+    there exists one or more object columns.
+
+    Args:
+        feature_types (pd.Series): The feature types for a DataFrame.
+
+    Returns:
+        bool:
+            True if the DataFrame dtypes contain an object column, False
+            otherwise.
+    """
+    return np.dtype('O') in feature_types
diff --git a/autoPyTorch/data/tabular_target_validator.py b/autoPyTorch/data/tabular_target_validator.py
index c37dc81c3..3ac72df0e 100644
--- a/autoPyTorch/data/tabular_target_validator.py
+++ b/autoPyTorch/data/tabular_target_validator.py
@@ -96,7 +96,7 @@ def _fit(
 
    def transform(
        self,
-        y: Union[SUPPORTED_TARGET_TYPES],
+        y: SUPPORTED_TARGET_TYPES,
    ) -> np.ndarray:
        """
        Validates and fits a categorical encoder (if needed) to the features.
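The `get_dataset_for_training` change below relies on a simple indexing convention: `splits[split_id]` is a `(train_indices, test_indices)` tuple, so subset `0` is train and subset `1` is test. A toy illustration of that arithmetic (indices invented):

```python
splits = [(list(range(80)), list(range(80, 100)))]  # one fold: (train, test)

train = True
subset = int(not train)          # 0 -> train indices, 1 -> test indices
indices = splits[0][subset]
print(len(indices))              # 80 training indices for split_id=0
```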
diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 617b2e4d0..0c48ac06d 100644
--- a/autoPyTorch/datasets/base_dataset.py
+++ b/autoPyTorch/datasets/base_dataset.py
@@ -333,13 +333,22 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
        to provide training data to fit a pipeline
 
        Args:
-            split (int): The desired subset of the dataset to split and use
+            split_id (int): which split id to get from the splits
+            train (bool): whether the dataset is required for training or evaluating.
 
        Returns:
+            Dataset: the train or test subset of the given split, to be used by the pipeline
 
        """
        # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
-        return TransformSubset(self, self.splits[split_id][0], train=train)
+        if split_id >= len(self.splits):
+            raise IndexError("split_id out of range, got split_id={}"
+                             " (>= num_splits={})".format(split_id, len(self.splits)))
+        subset = int(not train)
+        indices = self.splits[split_id][subset]
+        if indices is None:
+            raise ValueError("Specified fold (or subset) does not exist")
+        return TransformSubset(self, indices, train=train)
 
    def replace_data(self, X_train: BaseDatasetInputType,
                     X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
diff --git a/autoPyTorch/ensemble/singlebest_ensemble.py b/autoPyTorch/ensemble/singlebest_ensemble.py
index 881ae5fd2..5fe6170af 100644
--- a/autoPyTorch/ensemble/singlebest_ensemble.py
+++ b/autoPyTorch/ensemble/singlebest_ensemble.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from smac.runhistory.runhistory import RunHistory
+from smac.runhistory.runhistory import RunHistory, StatusType
 
 from autoPyTorch.automl_common.common.utils.backend import Backend
 from autoPyTorch.ensemble.abstract_ensemble import AbstractEnsemble
@@ -49,6 +49,9 @@ def get_identifiers_from_run_history(self) -> List[Tuple[int, int, float]]:
        for run_key in self.run_history.data.keys():
            run_value = self.run_history.data[run_key]
 
+            if run_value.status == StatusType.CRASHED:
+                continue
+
            score = self.metric._optimum - (self.metric._sign * run_value.cost)
 
            if (score > best_model_score and self.metric._sign > 0) \
diff --git a/autoPyTorch/pipeline/base_pipeline.py b/autoPyTorch/pipeline/base_pipeline.py
index 4e1fda3d9..fc15d9fed 100644
--- a/autoPyTorch/pipeline/base_pipeline.py
+++ b/autoPyTorch/pipeline/base_pipeline.py
@@ -22,7 +22,9 @@
    get_match_array
 )
 from autoPyTorch.utils.common import FitRequirement
-from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
+from autoPyTorch.utils.hyperparameter_search_space_update import (
+    HyperparameterSearchSpaceUpdates
+)
 
 
 PipelineStepType = Union[autoPyTorchComponent, autoPyTorchChoice]
@@ -405,6 +407,7 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                raise ValueError("Unknown node name. Expected update node name to be in {} "
                                 "got {}".format(self.named_steps.keys(), update.node_name))
            node = self.named_steps[update.node_name]
+            node_name = node.__class__.__name__
            # if node is a choice module
            if hasattr(node, 'get_components'):
                split_hyperparameter = update.hyperparameter.split(':')
@@ -432,18 +435,18 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                        if choice in exclude[update.node_name]:
                            raise ValueError("Found {} in exclude".format(choice))
                        if choice not in components.keys():
-                            raise ValueError("Unknown hyperparameter for choice {}. "
+                            raise ValueError("Unknown component choice for node {}. "
                                             "Expected update hyperparameter "
-                                             "to be in {} got {}".format(node.__class__.__name__,
-                                                                         components.keys(), choice))
+                                             "to be in {}, but got {}".format(node_name,
+                                                                              components.keys(), choice))
                # check if the component whose hyperparameter
                # needs to be updated is in components of the
                # choice module
                elif split_hyperparameter[0] not in components.keys():
-                    raise ValueError("Unknown hyperparameter for choice {}. "
-                                     "Expected update hyperparameter "
-                                     "to be in {} got {}".format(node.__class__.__name__,
-                                                                 components.keys(), split_hyperparameter[0]))
+                    raise ValueError("Unknown component choice for node {}. 
" + "Expected update component " + "to be in {}, but got {}".format(node_name, + components.keys(), split_hyperparameter[0])) else: # check if hyperparameter is in the search space of the component component = components[split_hyperparameter[0]] @@ -456,14 +459,16 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]], component.get_hyperparameter_search_space( dataset_properties=self.dataset_properties).get_hyperparameter_names()]): continue - raise ValueError("Unknown hyperparameter for component {}. " - "Expected update hyperparameter " - "to be in {} got {}".format(node.__class__.__name__, - component. - get_hyperparameter_search_space( - dataset_properties=self.dataset_properties). - get_hyperparameter_names(), - split_hyperparameter[1])) + component_hyperparameters = component.get_hyperparameter_search_space( + dataset_properties=self.dataset_properties).get_hyperparameter_names() + raise ValueError("Unknown hyperparameter for component {} of node {}." + " Expected update hyperparameter " + "to be in {}, but got {}.".format(component.__name__, + node_name, + component_hyperparameters, + split_hyperparameter[1] + ) + ) else: if update.hyperparameter not in node.get_hyperparameter_search_space( dataset_properties=self.dataset_properties): @@ -471,13 +476,13 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]], node.get_hyperparameter_search_space( dataset_properties=self.dataset_properties).get_hyperparameter_names()]): continue - raise ValueError("Unknown hyperparameter for component {}. " + node_hyperparameters = node.get_hyperparameter_search_space( + dataset_properties=self.dataset_properties).get_hyperparameter_names() + raise ValueError("Unknown hyperparameter for node {}. " "Expected update hyperparameter " - "to be in {} got {}".format(node.__class__.__name__, - node. - get_hyperparameter_search_space( - dataset_properties=self.dataset_properties). 
- get_hyperparameter_names(), update.hyperparameter)) + "to be in {}, but got {}".format(node_name, + node_hyperparameters, + update.hyperparameter)) def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] ) -> List[Tuple[str, PipelineStepType]]: diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py index ea47e33b9..e8f95ab57 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py @@ -3,14 +3,14 @@ import numpy as np from sklearn.compose import ColumnTransformer -from sklearn.pipeline import make_pipeline +# from sklearn.pipeline import make_pipeline import torch from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import ( autoPyTorchTabularPreprocessingComponent ) -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers +# from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers from autoPyTorch.utils.common import FitRequirement, subsampler @@ -47,15 +47,16 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer": Returns: "TabularColumnTransformer": an instance of self """ + self.check_requirements(X, y) - numerical_pipeline = 'drop' - categorical_pipeline = 'drop' + numerical_pipeline = 'passthrough' + categorical_pipeline = 'passthrough' - preprocessors = get_tabular_preprocessers(X) - if len(X['dataset_properties']['numerical_columns']): - numerical_pipeline = make_pipeline(*preprocessors['numerical']) - if len(X['dataset_properties']['categorical_columns']): - categorical_pipeline = make_pipeline(*preprocessors['categorical']) + # preprocessors = get_tabular_preprocessers(X) + # if len(X['dataset_properties']['numerical_columns']): + # numerical_pipeline = make_pipeline(*preprocessors['numerical']) + # if len(X['dataset_properties']['categorical_columns']): + # categorical_pipeline = make_pipeline(*preprocessors['categorical']) self.preprocessor = ColumnTransformer([ ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']), @@ -71,6 +72,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer": X_train = X['backend'].load_datamanager().train_tensors[0] self.preprocessor.fit(X_train) + return self def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py index eadc0a188..9829cadcd 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py @@ -28,5 +28,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: raise ValueError("cant call transform on {} without fitting first." 
.format(self.__class__.__name__)) - X.update({'encoder': self.preprocessor}) + # X.update({'encoder': self.preprocessor}) return X diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py index b65f3c229..ac0648481 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py @@ -29,5 +29,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: raise ValueError("cant call transform on {} without fitting first." .format(self.__class__.__name__)) - X.update({'imputer': self.preprocessor}) + # X.update({'imputer': self.preprocessor}) return X diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py index 39834dd2b..270fac246 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py @@ -28,5 +28,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: raise ValueError("cant call transform on {} without fitting first." .format(self.__class__.__name__)) - X.update({'scaler': self.preprocessor}) + # X.update({'scaler': self.preprocessor}) return X diff --git a/autoPyTorch/pipeline/components/setup/network/base_network.py b/autoPyTorch/pipeline/components/setup/network/base_network.py index 8b75ab66a..02782e7a2 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network.py @@ -134,13 +134,15 @@ def _predict(self, network: torch.nn.Module, loader: torch.utils.data.DataLoader # Batch prediction Y_batch_preds = list() - for i, (X_batch, Y_batch) in enumerate(loader): - # Predict on batch - X_batch = X_batch.float().to(self.device) - Y_batch_pred = network(X_batch) - if self.final_activation is not None: - Y_batch_pred = self.final_activation(Y_batch_pred) - Y_batch_preds.append(Y_batch_pred.detach().cpu()) + # `torch.no_grad` reduces memory usage even after `model.eval()` + with torch.no_grad(): + for i, (X_batch, Y_batch) in enumerate(loader): + # Predict on batch + X_batch = X_batch.float().to(self.device) + Y_batch_pred = network(X_batch) + if self.final_activation is not None: + Y_batch_pred = self.final_activation(Y_batch_pred) + Y_batch_preds.append(Y_batch_pred.detach().cpu()) return torch.cat(Y_batch_preds, 0) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py index 5388bfcc4..5f71825be 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py @@ -140,6 +140,14 @@ def get_hyperparameter_search_space( value_range=(True, False), default_value=True, ), + shake_shake_update_func: HyperparameterSearchSpace = HyperparameterSearchSpace( + hyperparameter="shake_shake_update_func", + value_range=('shake-shake', + 'shake-even', + 'even-even', + 'M3'), + 
default_value='shake-shake', + ), use_shake_drop: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="use_shake_drop", value_range=(True, False), default_value=True, @@ -181,16 +189,25 @@ def get_hyperparameter_search_space( if skip_connection_flag: + shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter) cs.add_hyperparameter(mb_choice) cs.add_condition(CS.EqualsCondition(mb_choice, use_sc, True)) + shake_shake_update_func_conditional: List[str] = list() if shake_drop_prob_flag: shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter) cs.add_hyperparameter(shake_drop_prob) cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop")) + shake_shake_update_func_conditional.append('shake-drop') + if shake_shake_flag: + shake_shake_update_func_conditional.append('shake-shake') + if len(shake_shake_update_func_conditional) > 0: + method = get_hyperparameter(shake_shake_update_func, CategoricalHyperparameter) + cs.add_hyperparameter(method) + cs.add_condition(CS.InCondition(method, mb_choice, shake_shake_update_func_conditional)) # It is the upper bound of the nr of groups, # since the configuration will actually be sampled. @@ -259,7 +276,7 @@ def __init__( # if in != out the shortcut needs a linear layer to match the result dimensions # if the shortcut needs a layer we apply batchnorm and activation to the shortcut # as well (start_norm) - if in_features != out_features: + if in_features != out_features and self.config["use_skip_connection"]: self.shortcut = nn.Linear(in_features, out_features) initial_normalization = list() if self.config['use_batch_norm']: @@ -289,13 +306,6 @@ def _build_block(self, in_features: int, out_features: int) -> nn.Module: if self.config['use_batch_norm']: layers.append(nn.BatchNorm1d(in_features)) layers.append(self.activation()) - elif not self.config['use_skip_connection']: - # if start norm is not None and skip connection is False - # we will never apply the start_norm for the first layer in the block, - # which is why we should account for this case. - if self.config['use_batch_norm']: - layers.append(nn.BatchNorm1d(in_features)) - layers.append(self.activation()) layers.append(nn.Linear(in_features, out_features)) @@ -311,9 +321,6 @@ def _build_block(self, in_features: int, out_features: int) -> nn.Module: def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - if self.config["use_skip_connection"]: - residual = x - # if shortcut is not none we need a layer such that x matches the output dimension if self.shortcut is not None and self.start_norm is not None: # in this case self.start_norm is also != none @@ -321,38 +328,42 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # in front of shortcut and layers. Note that in this case layers # does not start with batchnorm+activation but with the first linear layer # (see _build_block). 
As a result if in_features == out_features
-            #  -> result = x + W(~D(A(BN(W(A(BN(x))))))
+            #  -> result = x + W_2(~D(A(BN(W_1(A(BN(x)))))))
            # if in_features != out_features
            #  -> result = W_shortcut(A(BN(x))) + W_2(~D(A(BN(W_1(A(BN(x)))))))
            x = self.start_norm(x)
-            if self.config["use_skip_connection"]:
-                residual = self.shortcut(x)
-
-        # TODO make the below code better
-        if self.config["use_skip_connection"]:
-            if self.config["multi_branch_choice"] == 'shake-shake':
-                x1 = self.layers(x)
-                x2 = self.shake_shake_layers(x)
-                alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
-                x = shake_shake(x1, x2, alpha, beta)
-            else:
-                x = self.layers(x)
+            residual = self.shortcut(x)
+        elif self.config["use_skip_connection"]:
+            # We use a skip connection but we do not need to match dimensions
+            residual = x
+        else:  # Early-return because no need of skip connection
+            return self.layers(x)
+
+        if self.config["multi_branch_choice"] == 'shake-shake':
+            x1 = self.layers(x)
+            x2 = self.shake_shake_layers(x)
+            alpha, beta = shake_get_alpha_beta(
+                is_training=self.training,
+                is_cuda=x.is_cuda,
+                method=self.config['shake_shake_update_func'],
+            )
+            x = shake_shake(x1, x2, alpha, beta)
+        elif self.config["multi_branch_choice"] == 'shake-drop':
+            x = self.layers(x)
+            alpha, beta = shake_get_alpha_beta(
+                is_training=self.training,
+                is_cuda=x.is_cuda,
+                method=self.config['shake_shake_update_func'],
+            )
+            bl = shake_drop_get_bl(
+                self.block_index,
+                1 - self.config["max_shake_drop_probability"],
+                self.num_blocks,
+                self.training,
+                x.is_cuda,
+            )
+            x = shake_drop(x, alpha, beta, bl)
        else:
            x = self.layers(x)
 
-        if self.config["use_skip_connection"]:
-            if self.config["multi_branch_choice"] == 'shake-drop':
-                alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
-                bl = shake_drop_get_bl(
-                    self.block_index,
-                    1 - self.config["max_shake_drop_probability"],
-                    self.num_blocks,
-                    self.training,
-                    x.is_cuda,
-                )
-                x = shake_drop(x, alpha, beta, bl)
-        else:
-            x = self.layers(x)
-
-        if self.config["use_skip_connection"]:
-            x = x + residual
-
-        return x
+        return x + residual
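On the shake-shake branch above: during the forward pass the two branch outputs are mixed with `alpha`, while the custom autograd function substitutes `beta` on the backward pass. A toy illustration of the forward behaviour only, not the library's `shake_shake` implementation itself:

```python
import torch

x1 = torch.tensor([1.0, 1.0])   # output of self.layers
x2 = torch.tensor([0.0, 2.0])   # output of self.shake_shake_layers
alpha = torch.rand(1)           # forward-pass mixing weight

# Forward behaviour of shake_shake: a random convex combination of branches.
mixed = alpha * x1 + (1 - alpha) * x2
print(mixed)
```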
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
index bb1a93ac1..57304a6e4 100644
--- a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
+++ b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
@@ -28,6 +28,7 @@ def get_output_shape(network: torch.nn.Module, input_shape: Tuple[int, ...]
     placeholder = torch.randn((2, *input_shape), dtype=torch.float)
     with torch.no_grad():
         output = network(placeholder)
+
     return tuple(output.shape[1:])
@@ -82,7 +83,11 @@ class ShakeDropFunction(Function):
    Github URL: https://github.com/owruby/shake-drop_pytorch/blob/master/models/shakedrop.py
     """
     @staticmethod
-    def forward(ctx: Any,
+    def forward(ctx: typing.Any,
                 x: torch.Tensor,
                 alpha: torch.Tensor,
                 beta: torch.Tensor,
@@ -109,15 +114,31 @@ def backward(ctx: Any,
 shake_drop = ShakeDropFunction.apply

-def shake_get_alpha_beta(is_training: bool, is_cuda: bool
-                         ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    The methods used in this function have been introduced in 'ShakeShake Regularisation'
-    Currently, this function supports `shake-shake`.
+def shake_get_alpha_beta(
+    is_training: bool,
+    is_cuda: bool,
+    method: str
+) -> typing.Tuple[torch.Tensor, torch.Tensor]:
+    """
+    The methods used in this function have been introduced in 'ShakeShake Regularisation'
+    Each method name is available in the referred paper.
+    Currently, this function supports `even-even`, `shake-even`, `shake-shake` and `M3`.

     Args:
         is_training (bool): Whether the computation for the training
         is_cuda (bool): Whether the tensor is on CUDA
+        method (str): The shake method either `even-even`, `shake-even`, `shake-shake` or `M3`

     Returns:
         alpha, beta (Tuple[float, float]):
@@ -129,17 +150,34 @@ def shake_get_alpha_beta(is_training: bool, is_cuda: bool
         Author: Xavier Gastaldi
         URL: https://arxiv.org/abs/1705.07485
-    Note: The names have been taken from the paper as well.
-          Currently, this function supports `shake-shake`.
+    Note:
+        The names have been taken from the paper as well.
+        Currently, this function supports `even-even`, `shake-even`, `shake-shake` and `M3`.
     """
     if not is_training:
         result = (torch.FloatTensor([0.5]), torch.FloatTensor([0.5]))
         return result if not is_cuda else (result[0].cuda(), result[1].cuda())

-    # TODO implement other update methods
-    alpha = torch.rand(1)
-    beta = torch.rand(1)
+    # alpha is the weight ratio for the forward pass and beta is that for the backward pass
+    alpha = torch.FloatTensor([0.5]) if method.startswith('even') else torch.rand(1)
+    if method.endswith('even'):
+        beta = torch.FloatTensor([0.5])
+    elif method.endswith('shake'):
+        beta = torch.rand(1)
+    elif method == 'M3':
+        # Table 4 in the paper `Shake-Shake regularization`
+        rnd = torch.rand(1)
+        beta = torch.FloatTensor(
+            [rnd * (0.5 - alpha) + alpha if alpha < 0.5 else rnd * (alpha - 0.5) + 0.5]
+        )
+    else:
+        raise ValueError(f"Unknown method `{method}` for ShakeShakeRegularisation in NetworkBackbone")

     if is_cuda:
         alpha = alpha.cuda()
@@ -149,16 +187,27 @@ def shake_drop_get_bl(
-    block_index: int, min_prob_no_shake: float, num_blocks: int, is_training: bool, is_cuda: bool
+    block_index: int,
+    min_prob_no_shake: float,
+    num_blocks: int,
+    is_training: bool,
+    is_cuda: bool
 ) -> torch.Tensor:
     """
     The sampling of Bernoulli random variable
     based on Eq. (4) in the paper
     Args:
         block_index (int): The index of the block from the input layer
         min_prob_no_shake (float): The initial shake probability
     Returns:
         bl (torch.Tensor): a Bernoulli random variable in {0, 1}
     Reference:
         ShakeDrop Regularization for Deep Residual Learning
         Yoshihiro Yamada et. al. (2020)
         paper: https://arxiv.org/pdf/1802.02375.pdf
         implementation: https://github.com/imenurok/ShakeDrop
     """
     pl = 1 - ((block_index + 1) / num_blocks) * (1 - min_prob_no_shake)

     if is_training:
         # Move to torch.rand(1) for reproducibility
         bl = torch.as_tensor(1.0) if torch.rand(1) <= pl else torch.as_tensor(0.0)
     else:
         bl = torch.as_tensor(pl)
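The four sampling strategies can be sanity-checked in isolation. A quick sketch, assuming the function is imported from the module this hunk patches (output values vary with the seed):

import torch
from autoPyTorch.pipeline.components.setup.network_backbone.utils import shake_get_alpha_beta

torch.manual_seed(0)
for method in ('shake-shake', 'shake-even', 'even-even', 'M3'):
    alpha, beta = shake_get_alpha_beta(is_training=True, is_cuda=False, method=method)
    print(method, float(alpha), float(beta))
# at evaluation time both factors collapse to 0.5 regardless of the method:
print(shake_get_alpha_beta(is_training=False, is_cuda=False, method='shake-shake'))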
diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py
index 5ae2880ed..6feac0fba 100644
--- a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py
+++ b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py
@@ -1,4 +1,4 @@
-import copy
+# import copy
 from typing import Any, Dict, Optional, Tuple

 import numpy as np
@@ -31,21 +31,22 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
     def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module:
         raise NotImplementedError

-    def _get_args(self, X: Dict[str, Any]) -> Tuple[int, np.ndarray]:
+    def _get_args(self, X: Dict[str, Any]) -> Tuple[None, None]:  # Tuple[int, np.ndarray]:
         # Feature preprocessors can alter numerical columns
-        if len(X['dataset_properties']['numerical_columns']) == 0:
-            num_numerical_columns = 0
-        else:
-            X_train = copy.deepcopy(X['backend'].load_datamanager().train_tensors[0][:2])
-
-            numerical_column_transformer = X['tabular_transformer'].preprocessor. \
-                named_transformers_['numerical_pipeline']
-            num_numerical_columns = numerical_column_transformer.transform(
-                X_train[:, X['dataset_properties']['numerical_columns']]).shape[1]
-        num_input_features = np.zeros((num_numerical_columns + len(X['dataset_properties']['categorical_columns'])),
-                                      dtype=int)
-        categories = X['dataset_properties']['categories']
-
-        for i, category in enumerate(categories):
-            num_input_features[num_numerical_columns + i, ] = len(category)
-        return num_numerical_columns, num_input_features
+        # if len(X['dataset_properties']['numerical_columns']) == 0:
+        #     num_numerical_columns = 0
+        # else:
+        #     X_train = copy.deepcopy(X['backend'].load_datamanager().train_tensors[0][:2])
+        #
+        #     numerical_column_transformer = X['tabular_transformer'].preprocessor. \
+        #         named_transformers_['numerical_pipeline']
+        #     num_numerical_columns = numerical_column_transformer.transform(
+        #         X_train[:, X['dataset_properties']['numerical_columns']]).shape[1]
+        # num_input_features = np.zeros((num_numerical_columns + len(X['dataset_properties']['categorical_columns'])),
+        #                               dtype=int)
+        # categories = X['dataset_properties']['categories']
+        #
+        # for i, category in enumerate(categories):
+        #     num_input_features[num_numerical_columns + i, ] = len(category)
+        # return num_numerical_columns, num_input_features
+        return None, None
diff --git a/autoPyTorch/pipeline/components/setup/network_head/no_head.py b/autoPyTorch/pipeline/components/setup/network_head/no_head.py
index 870f680fb..0e711f06c 100644
--- a/autoPyTorch/pipeline/components/setup/network_head/no_head.py
+++ b/autoPyTorch/pipeline/components/setup/network_head/no_head.py
@@ -23,7 +23,6 @@ def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]
         layers = []
         in_features = np.prod(input_shape).item()
         out_features = np.prod(output_shape).item()
-        layers.append(_activations[self.config["activation"]]())
         layers.append(nn.Linear(in_features=in_features, out_features=out_features))
         return nn.Sequential(*layers)
diff --git a/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py
index 3ae84a9e0..348fb4925 100644
--- a/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py
+++ b/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py
@@ -97,9 +97,9 @@ def get_hyperparameter_search_space(
                                                                      default_value=True,
                                                                      ),
        weight_decay: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="weight_decay",
-                                                                           value_range=(1E-7, 0.1),
+                                                                           value_range=(1E-5, 0.1),
                                                                            default_value=1E-4,
-                                                                           log=True),
+                                                                           log=False),
    ) -> ConfigurationSpace:
        cs = ConfigurationSpace()
diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
index e4d334d01..3509515b2 100644
--- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
+++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
@@ -115,7 +115,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> torch.utils.data.DataLoader:
             shuffle=True,
             num_workers=X.get('num_workers', 0),
             pin_memory=X.get('pin_memory', True),
-            drop_last=X.get('drop_last', True),
+            drop_last=X.get('drop_last', False),
             collate_fn=custom_collate_fn,
         )
@@ -154,6 +154,8 @@ def get_loader(self, X: np.ndarray, y: Optional[np.ndarray] = None, batch_size:
             train_tensors=(X, y),
+            seed=self.random_state.get_state()[1][0],
             # This dataset is used for loading test data in a batched format
+            shuffle=False,
             train_transforms=self.test_transform,
             val_transforms=self.test_transform,
         )
diff --git a/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py b/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py
index 157924ed0..7f5385382 100644
--- a/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py
@@ -31,7 +31,7 @@ class AdversarialTrainer(BaseTrainerComponent):
     def __init__(
             self,
             epsilon: float,
-            weighted_loss: bool = False,
+            weighted_loss: int = 0,
             random_state: Optional[np.random.RandomState] = None,
             use_stochastic_weight_averaging: bool = False,
             use_snapshot_ensemble: bool = False,
@@ -157,7 +157,7 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None
             'shortname': 'AdversarialTrainer',
             'name': 'AdversarialTrainer',
             'handles_tabular': True,
-            'handles_image': False,
+            'handles_image': True,
             'handles_time_series': False,
         }
@@ -166,8 +166,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=(True, False),
-            default_value=True),
+            value_range=(1, ),
+            default_value=1),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -192,16 +192,21 @@ def get_hyperparameter_search_space(
             default_value=True),
         se_lastk: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="se_lastk",
-            value_range=(3,),
+            value_range=(3, ),
             default_value=3),
         epsilon: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="epsilon",
-            value_range=(0.05, 0.2),
-            default_value=0.2),
+            value_range=(0.001, 0.15),
+            default_value=0.007,
+            log=True),
     ) -> ConfigurationSpace:
         cs = ConfigurationSpace()
+        # NOTE: epsilon is pinned to a single value here; this deliberately overrides
+        # the searchable range declared in the signature above.
+        epsilon = HyperparameterSearchSpace(hyperparameter="epsilon",
+                                            value_range=(0.007, 0.007),
+                                            default_value=0.007)
         add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)
+
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
         snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
@@ -229,9 +234,17 @@ def get_hyperparameter_search_space(
                 parent_hyperparameter=parent_hyperparameter
             )

+        """
         # TODO, decouple the weighted loss from the trainer
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
                 add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
+        """
+        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
+        # remove the code below. Also update the method signature, so the weighted loss
+        # is not a constant.
+        if dataset_properties is not None:
+            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
+                add_hyperparameter(cs, weighted_loss, Constant)

         return cs
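For orientation, epsilon bounds the max-norm of the adversarial perturbation applied to a training batch. A rough, hand-rolled FGSM-style illustration of the role it plays (this is not the trainer's actual implementation):

import torch

def fgsm_perturb(X: torch.Tensor, grad: torch.Tensor, epsilon: float = 0.007) -> torch.Tensor:
    # move each input a step of size epsilon in the direction that increases the loss
    return X + epsilon * grad.sign()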
diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py
index e36faf121..3f7866f3c 100644
--- a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py
@@ -1,4 +1,4 @@
-import typing
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
@@ -11,7 +11,7 @@ class RowCutMixTrainer(MixUp, BaseTrainerComponent):

     def data_preparation(self, X: np.ndarray, y: np.ndarray,
-                         ) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
+                         ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
         """
         Depending on the trainer choice, data fed to the network might be pre-processed
         on a different way. That is, in standard training we provide the data to the
@@ -26,39 +26,38 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        alpha, beta = 1.0, 1.0
-        lam = self.random_state.beta(alpha, beta)
-        batch_size = X.shape[0]
-        device = torch.device('cuda' if X.is_cuda else 'cpu')
-        permed_indices = torch.randperm(batch_size).to(device)
+        beta = 1.0
+        lam = self.random_state.beta(beta, beta)
+        batch_size, n_columns = np.shape(X)
+        # shuffled_indices: Shuffled version of torch.arange(batch_size)
+        shuffled_indices = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)

         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[permed_indices], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[shuffled_indices], 'lam': 1}

-        # batch_size (permutation of rows), col_size = X.shape
-        col_size = X.shape[1]
-        col_indices = torch.tensor(
+        cut_column_indices = torch.as_tensor(
             self.random_state.choice(
-                range(col_size),
-                max(1, int(col_size * lam)),
-                replace=False
-            )
+                range(n_columns),
+                max(1, np.int32(n_columns * lam)),
+                replace=False,
+            ),
         )

-        # Replace selected columns with columns from another data point
-        X[:, col_indices] = X[permed_indices, :][:, col_indices]
+        # Replace the values in the `cut_column_indices` columns with
+        # the values from the `shuffled_indices` rows
+        X[:, cut_column_indices] = X[shuffled_indices, :][:, cut_column_indices]

-        # Adjust lam
-        lam = 1 - len(col_indices) / X.shape[1]
+        # Since we cannot cut exactly `lam x 100 %` of the columns, we need to adjust the `lam`
+        lam = 1 - (len(cut_column_indices) / n_columns)

-        y_a, y_b = y, y[permed_indices]
+        y_a, y_b = y, y[shuffled_indices]

         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

     @staticmethod
-    def get_properties(dataset_properties: typing.Optional[typing.Dict[str, typing.Any]] = None
-                       ) -> typing.Dict[str, typing.Union[str, bool]]:
+    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None
+                       ) -> Dict[str, Union[str, bool]]:
         return {
             'shortname': 'RowCutMixTrainer',
             'name': 'MixUp Regularized with Cutoff Tabular Trainer',
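The row-wise CutMix above is easiest to see on a toy batch. A small illustrative sketch, outside the trainer (plain torch/numpy, random seed not fixed):

import numpy as np
import torch

X = torch.arange(12.).reshape(4, 3)           # batch of 4 rows, 3 columns
shuffled = torch.randperm(X.shape[0])
cut_cols = torch.as_tensor(np.random.choice(range(3), 1, replace=False))
X[:, cut_cols] = X[shuffled, :][:, cut_cols]  # donor rows overwrite the chosen columns
lam = 1 - len(cut_cols) / X.shape[1]          # fraction of untouched columns, used to mix the loss
print(X, lam)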
diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py
index 5e1b9a1a3..4578082cb 100644
--- a/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py
@@ -1,9 +1,7 @@
-import typing
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np

-import torch
-
 from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent
 from autoPyTorch.pipeline.components.training.trainer.cutout_utils import CutOut
@@ -17,13 +15,8 @@ class RowCutOutTrainer(CutOut, BaseTrainerComponent):
    Github URL: https://github.com/hysts/pytorch_cutout/blob/master/dataloader.py#L36-L68
     """

-    # 0 is non-informative in image data
-    NUMERICAL_VALUE = 0
-    # -1 is the conceptually equivalent to 0 in a image, i.e. 0-pad
-    CATEGORICAL_VALUE = -1
-
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
-                         ) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
+                         ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
         """
         Depending on the trainer choice, data fed to the network might be pre-processed
         on a different way. That is, in standard training we provide the data to the
@@ -36,9 +29,8 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,

         Returns:
             np.ndarray: that processes data
-            typing.Dict[str, np.ndarray]: arguments to the criterion function
+            Dict[str, np.ndarray]: arguments to the criterion function
         """
-
         r = self.random_state.rand(1)
         if r > self.cutout_prob:
             y_a = y
@@ -46,30 +38,23 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             lam = 1
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

-        # (batch_size (permutation of rows), col_size) = X.shape
-        col_size = X.shape[1]
-        col_indices = self.random_state.choice(range(col_size), max(1, int(col_size * self.patch_ratio)),
-                                               replace=False)
-
-        if not isinstance(self.numerical_columns, typing.Iterable):
-            raise ValueError("numerical_columns in {} must be iterable, "
-                             "but got {}.".format(self.__class__.__name__,
-                                                  self.numerical_columns))
-
-        numerical_indices = torch.tensor(self.numerical_columns)
-        categorical_indices = torch.tensor([idx for idx in col_indices if idx not in self.numerical_columns])
-
-        X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
-        X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
+        size: int = np.shape(X)[1]
+        cut_column_indices = self.random_state.choice(
+            range(size),
+            max(1, np.int32(size * self.patch_ratio)),
+            replace=False,
+        )
+        # Mask the selected features with 0
+        X[:, cut_column_indices] = 0

         lam = 1
         y_a = y
         y_b = y
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

     @staticmethod
-    def get_properties(dataset_properties: typing.Optional[typing.Dict[str, typing.Any]] = None
-                       ) -> typing.Dict[str, typing.Union[str, bool]]:
+    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None
+                       ) -> Dict[str, Union[str, bool]]:
         return {
             'shortname': 'RowCutOutTrainer',
             'name': 'RowCutOutTrainer',
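The simplified cutout now treats every feature uniformly: instead of imputing numerical and categorical columns with different sentinel values, it zeroes a random subset of columns for the whole batch. A toy illustration outside the trainer:

import numpy as np

X = np.ones((4, 5))
patch_ratio = 0.4
cut_cols = np.random.choice(range(X.shape[1]), max(1, int(X.shape[1] * patch_ratio)), replace=False)
X[:, cut_cols] = 0   # masked features are simply set to 0 for every row
print(X)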
diff --git a/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py b/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py
index 9e44399fd..fc5cc3e3a 100644
--- a/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py
@@ -14,7 +14,7 @@ class StandardTrainer(BaseTrainerComponent):
     def __init__(self,
-                 weighted_loss: bool = False,
+                 weighted_loss: int = 0,
                  use_stochastic_weight_averaging: bool = False,
                  use_snapshot_ensemble: bool = False,
                  se_lastk: int = 3,
@@ -25,7 +25,7 @@
         This class handles the training of a network for a single given epoch.

         Args:
-            weighted_loss (bool): whether to use weighted loss
+            weighted_loss (int): whether to use a weighted loss (0: off, 1: on)

         """
         super().__init__(random_state=random_state,
diff --git a/autoPyTorch/pipeline/components/training/trainer/__init__.py b/autoPyTorch/pipeline/components/training/trainer/__init__.py
index cca87a18b..ce35518a9 100755
--- a/autoPyTorch/pipeline/components/training/trainer/__init__.py
+++ b/autoPyTorch/pipeline/components/training/trainer/__init__.py
@@ -287,7 +287,6 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
         self.choice: autoPyTorchComponent = cast(autoPyTorchComponent, self.choice)
         if self.choice.use_snapshot_ensemble:
             X['network_snapshots'].extend(self.choice.model_snapshots)
-
         return self.choice

     def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
@@ -412,14 +411,15 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
             torch.cuda.empty_cache()

         if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
+            # update batch norm statistics
             swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double())
             # change model
             update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
             if self.choice.use_snapshot_ensemble:
-                for model in self.choice.model_snapshots:
-                    swa_utils.update_bn(loader=X['train_data_loader'], model=model.double())
+                # we only update the last snapshot, which corresponds to the
+                # stochastic weight averaging model
+                swa_utils.update_bn(X['train_data_loader'], self.choice.model_snapshots[-1].double())

         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index e6982be09..b17db9d13 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -196,7 +196,7 @@ class BaseTrainerComponent(autoPyTorchTrainingComponent):
     """
     Base class for training
     Args:
-        weighted_loss (bool, default=False): In case for classification, whether to weight
+        weighted_loss (int, default=0): In the case of classification, whether to weight
             the loss function according to the distribution of classes in the target
         use_stochastic_weight_averaging (bool, default=True): whether to use stochastic weight averaging.
            Stochastic weight averaging is a simple average of
@@ -211,7 +211,7 @@ class BaseTrainerComponent(autoPyTorchTrainingComponent):
         random_state:
         **lookahead_config:
     """
-    def __init__(self, weighted_loss: bool = False,
+    def __init__(self, weighted_loss: int = 0,
                  use_stochastic_weight_averaging: bool = True,
                  use_snapshot_ensemble: bool = True,
                  se_lastk: int = 3,
@@ -342,8 +342,12 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         if self.use_snapshot_ensemble:
             assert self.model_snapshots is not None, "model snapshots container can't be " \
                                                      "none when snapshot ensembling is enabled"
-            model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \
-                else deepcopy(self.model)
+            is_last_epoch = (epoch == self.budget_tracker.max_epochs)
+            if is_last_epoch and self.use_stochastic_weight_averaging:
+                model_copy = deepcopy(self.swa_model)
+            else:
+                model_copy = deepcopy(self.model)
+
             assert model_copy is not None
             model_copy.cpu()
             self.model_snapshots.append(model_copy)
@@ -579,8 +583,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=(True, False),
-            default_value=True),
+            value_range=(1, ),
+            default_value=1),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -605,7 +609,7 @@ def get_hyperparameter_search_space(
             default_value=True),
         se_lastk: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="se_lastk",
-            value_range=(3,),
+            value_range=(3, ),
             default_value=3),
     ) -> ConfigurationSpace:
         cs = ConfigurationSpace()
@@ -637,9 +641,17 @@ def get_hyperparameter_search_space(
                 parent_hyperparameter=parent_hyperparameter
             )

+        """
         # TODO, decouple the weighted loss from the trainer
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
                 add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
+        """
+        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
+        # remove the code below. Also update the method signature, so the weighted loss
+        # is not a constant.
+        if dataset_properties is not None:
+            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
+                add_hyperparameter(cs, weighted_loss, Constant)

         return cs
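Registering weighted_loss as a Constant means every sampled configuration carries the same value. A minimal sketch of what this boils down to in ConfigSpace:

from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import Constant

cs = ConfigurationSpace()
cs.add_hyperparameter(Constant("weighted_loss", 1))
print(cs.sample_configuration()["weighted_loss"])  # always 1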
diff --git a/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py b/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
index 1b987d599..4feedf5cb 100644
--- a/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
+++ b/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
@@ -20,7 +20,7 @@ class CutOut:
     def __init__(self, patch_ratio: float,
                  cutout_prob: float,
-                 weighted_loss: bool = False,
+                 weighted_loss: int = 0,
                  random_state: Optional[np.random.RandomState] = None,
                  use_stochastic_weight_averaging: bool = False,
                  use_snapshot_ensemble: bool = False,
@@ -63,8 +63,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=(True, False),
-            default_value=True),
+            value_range=(1, ),
+            default_value=1),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -136,9 +136,17 @@ def get_hyperparameter_search_space(
                 parent_hyperparameter=parent_hyperparameter
             )

+        """
         # TODO, decouple the weighted loss from the trainer
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
                 add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
+        """
+        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
+        # remove the code below. Also update the method signature, so the weighted loss
+        # is not a constant.
+        if dataset_properties is not None:
+            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
+                add_hyperparameter(cs, weighted_loss, Constant)

         return cs
diff --git a/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py b/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
index e33011bf5..e2ea25148 100644
--- a/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
+++ b/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
@@ -26,7 +26,7 @@ class MixUp:
    Github URL: https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py#L119-L138
     """
     def __init__(self, alpha: float,
-                 weighted_loss: bool = False,
+                 weighted_loss: int = 0,
                  random_state: Optional[np.random.RandomState] = None,
                  use_stochastic_weight_averaging: bool = False,
                  use_snapshot_ensemble: bool = False,
@@ -68,8 +68,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=(True, False),
-            default_value=True),
+            value_range=(1, ),
+            default_value=1),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -94,7 +94,7 @@ def get_hyperparameter_search_space(
             default_value=True),
         se_lastk: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="se_lastk",
-            value_range=(3,),
+            value_range=(3, ),
             default_value=3),
         alpha: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="alpha",
@@ -134,9 +134,18 @@ def get_hyperparameter_search_space(
                 la_config_space,
                 parent_hyperparameter=parent_hyperparameter
             )

+        """
         # TODO, decouple the weighted loss from the trainer
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
                 add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
+        """
+        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
+        # remove the code below. Also update the method signature, so the weighted loss
+        # is not a constant.
+        if dataset_properties is not None:
+            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
+                add_hyperparameter(cs, weighted_loss, Constant)

         return cs
diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py
index b95de512e..db2708dcb 100644
--- a/autoPyTorch/pipeline/tabular_classification.py
+++ b/autoPyTorch/pipeline/tabular_classification.py
@@ -137,10 +137,23 @@ def _predict_proba(self, X: np.ndarray) -> np.ndarray:
             # The final layer is always softmax now (`pred` already gives pseudo proba)
             return pred
         else:
-            raise ValueError("Expected output_shape to be integer, got {},"
-                             "Tabular Classification only supports 'binary' and 'multiclass' outputs"
-                             "got {}".format(type(self.dataset_properties['output_shape']),
-                                             self.dataset_properties['output_type']))
+            all_proba = []
+
+            for k in range(len(self.dataset_properties['output_shape'])):
+                proba_k = pred[:, k, :self.dataset_properties['output_shape'][k]]
+                normalizer = proba_k.sum(axis=1)[:, np.newaxis]
+                normalizer[normalizer == 0.0] = 1.0
+                proba_k /= normalizer
+                all_proba.append(proba_k)
+
+            return np.array(all_proba)

     def predict_proba(self, X: np.ndarray, batch_size: Optional[int] = None) -> np.ndarray:
         """predict_proba.
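A toy check of the per-output normalisation above (one output with three classes; the second row is deliberately all-zero):

import numpy as np

proba_k = np.array([[0.2, 0.2, 0.6], [0.0, 0.0, 0.0]])
normalizer = proba_k.sum(axis=1)[:, np.newaxis]
normalizer[normalizer == 0.0] = 1.0   # avoid division by zero for all-zero rows
print(proba_k / normalizer)           # rows sum to 1 (or stay all-zero)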
diff --git a/autoPyTorch/utils/backend.py b/autoPyTorch/utils/backend.py
new file mode 100644
index 000000000..5348bd11c
--- /dev/null
+++ b/autoPyTorch/utils/backend.py
@@ -0,0 +1,575 @@
+import glob
+import os
+import pickle
+import re
+import shutil
+import tempfile
+import time
+import uuid
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import lockfile
+
+import numpy as np
+
+from autoPyTorch.datasets.base_dataset import BaseDataset
+from autoPyTorch.ensemble.abstract_ensemble import AbstractEnsemble
+from autoPyTorch.pipeline.base_pipeline import BasePipeline
+from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
+
+__all__ = [
+    'Backend'
+]
+
+
+def create(
+    temporary_directory: Optional[str],
+    output_directory: Optional[str],
+    delete_tmp_folder_after_terminate: bool = True,
+    delete_output_folder_after_terminate: bool = True,
+) -> 'Backend':
+    """
+    Creates a backend object that manages disk related transactions
+
+    Args:
+        temporary_directory (str): where all temporary data is to be dumped
+        output_directory (str): where all predictions are to be output
+        delete_tmp_folder_after_terminate (bool): whether to delete the
+            temporary directory when the run completes
+        delete_output_folder_after_terminate (bool): whether to delete
+            the output directory when the run completes
+
+    Returns:
+        Backend object
+    """
+    context = BackendContext(temporary_directory, output_directory,
+                             delete_tmp_folder_after_terminate,
+                             delete_output_folder_after_terminate,
+                             )
+    backend = Backend(context)
+
+    return backend
+
+
+def get_randomized_directory_name(temporary_directory: Optional[str] = None) -> str:
+    uuid_str = str(uuid.uuid1(clock_seq=os.getpid()))
+
+    temporary_directory = (
+        temporary_directory
+        if temporary_directory
+        else os.path.join(
+            tempfile.gettempdir(),
+            "autoPyTorch_tmp_{}".format(
+                uuid_str,
+            ),
+        )
+    )
+
+    return temporary_directory
+
+
+class BackendContext(object):
+
+    def __init__(self,
+                 temporary_directory: Optional[str],
+                 output_directory: Optional[str],
+                 delete_tmp_folder_after_terminate: bool,
+                 delete_output_folder_after_terminate: bool,
+                 ):
+
+        # Check that the names of tmp_dir and output_dir are not the same.
+        if temporary_directory == output_directory and temporary_directory is not None:
+            raise ValueError("The temporary and the output directory "
+                             "must be different.")
+
+        self.delete_tmp_folder_after_terminate = delete_tmp_folder_after_terminate
+        self.delete_output_folder_after_terminate = delete_output_folder_after_terminate
+        # attributes to check that directories were created by autoPyTorch
+        self._tmp_dir_created = False
+        self._output_dir_created = False
+
+        self._temporary_directory = (
+            get_randomized_directory_name(
+                temporary_directory=temporary_directory,
+            )
+        )
+        self._output_directory = output_directory
+        self.create_directories()
+        self._logger = None  # type: Optional[PicklableClientLogger]
+
+    @property
+    def output_directory(self) -> Optional[str]:
+        if self._output_directory is not None:
+            # make sure that tilde does not appear on the path.
+            return os.path.expanduser(os.path.expandvars(self._output_directory))
+        else:
+            return None
+
+    @property
+    def temporary_directory(self) -> str:
+        # make sure that tilde does not appear on the path.
+        return os.path.expanduser(os.path.expandvars(self._temporary_directory))
+
+    def create_directories(self) -> None:
+        # Exception is raised if self.temporary_directory already exists.
+        os.makedirs(self.temporary_directory)
+        self._tmp_dir_created = True
+
+        # Exception is raised if self.output_directory already exists.
+        if self.output_directory is not None:
+            os.makedirs(self.output_directory)
+            self._output_dir_created = True
+
+    def delete_directories(self, force: bool = True) -> None:
+        if self.output_directory and (self.delete_output_folder_after_terminate or force):
+            if self._output_dir_created is False:
+                raise ValueError("Failed to delete output dir: %s because autoPyTorch did not "
+                                 "create it. Please make sure that the specified output dir does "
+                                 "not exist when instantiating autoPyTorch."
+                                 % self.output_directory)
+            try:
+                shutil.rmtree(self.output_directory)
+            except Exception:
+                try:
+                    if self._logger is not None:
+                        self._logger.warning("Could not delete output dir: %s" %
+                                             self.output_directory)
+                    else:
+                        warnings.warn("Could not delete output dir: %s" % self.output_directory)
+                except Exception:
+                    warnings.warn("Could not delete output dir: %s" % self.output_directory)
+
+        if self.delete_tmp_folder_after_terminate or force:
+            if self._tmp_dir_created is False:
+                raise ValueError("Failed to delete tmp dir: %s because autoPyTorch did not "
+                                 "create it. Please make sure that the specified tmp dir does not "
+                                 "exist when instantiating autoPyTorch."
+                                 % self.temporary_directory)
+            try:
+                shutil.rmtree(self.temporary_directory)
+            except Exception:
+                try:
+                    if self._logger is not None:
+                        self._logger.warning(
+                            "Could not delete tmp dir: %s" % self.temporary_directory)
+                    else:
+                        warnings.warn("Could not delete tmp dir: %s" % self.temporary_directory)
+                except Exception:
+                    warnings.warn("Could not delete tmp dir: %s" % self.temporary_directory)
+
+
+class Backend(object):
+    """Utility class to load and save all objects to be persisted.
+ These are: + * start time of auto-pytorch + * true targets of the ensemble + """ + + def __init__(self, context: BackendContext): + self._logger = None # type: Optional[PicklableClientLogger] + self.context = context + + # Track the number of configurations launched + # num_run == 1 means a dummy estimator run + self.active_num_run = 1 + + # Create the temporary directory if it does not yet exist + try: + os.makedirs(self.temporary_directory) + except Exception: + pass + # This does not have to exist or be specified + if self.output_directory is not None: + if not os.path.exists(self.output_directory): + raise ValueError("Output directory %s does not exist." % self.output_directory) + + self.internals_directory = os.path.join(self.temporary_directory, ".autoPyTorch") + self._make_internals_directory() + + def setup_logger(self, name: str, port: int) -> None: + self._logger = get_named_client_logger( + name=name, + port=port, + ) + self.context._logger = self._logger + return + + @property + def output_directory(self) -> Optional[str]: + return self.context.output_directory + + @property + def temporary_directory(self) -> str: + return self.context.temporary_directory + + def _make_internals_directory(self) -> None: + # TODO: make exist_ok a function argument + try: + os.makedirs(self.internals_directory, exist_ok=True) + except Exception as e: + if self._logger is not None: + self._logger.debug("_make_internals_directory: %s" % e) + try: + os.makedirs(self.get_runs_directory(), exist_ok=True) + except Exception as e: + if self._logger is not None: + self._logger.debug("_make_internals_directory: %s" % e) + + def _get_start_time_filename(self, seed: Union[str, int]) -> str: + if isinstance(seed, str): + seed = int(seed) + return os.path.join(self.internals_directory, "start_time_%d" % seed) + + def save_start_time(self, seed: str) -> str: + self._make_internals_directory() + start_time = time.time() + + filepath = self._get_start_time_filename(seed) + + if not isinstance(start_time, float): + raise ValueError("Start time must be a float, but is %s." % type(start_time)) + + if os.path.exists(filepath): + raise ValueError( + "{filepath} already exist. Different seeds should be provided for different jobs." 
+            )
+
+        with tempfile.NamedTemporaryFile('w', dir=os.path.dirname(filepath), delete=False) as fh:
+            fh.write(str(start_time))
+            tempname = fh.name
+        os.rename(tempname, filepath)
+
+        return filepath
+
+    def load_start_time(self, seed: int) -> float:
+        with open(self._get_start_time_filename(seed), 'r') as fh:
+            start_time = float(fh.read())
+        return start_time
+
+    def get_smac_output_directory(self) -> str:
+        return os.path.join(self.temporary_directory, 'smac3-output')
+
+    def get_smac_output_directory_for_run(self, seed: int) -> str:
+        return os.path.join(
+            self.temporary_directory,
+            'smac3-output',
+            'run_%d' % seed
+        )
+
+    def _get_targets_ensemble_filename(self) -> str:
+        return os.path.join(self.internals_directory,
+                            "true_targets_ensemble.npy")
+
+    def save_targets_ensemble(self, targets: np.ndarray) -> str:
+        self._make_internals_directory()
+        if not isinstance(targets, np.ndarray):
+            raise ValueError('Targets must be of type np.ndarray, but is %s' %
+                             type(targets))
+
+        filepath = self._get_targets_ensemble_filename()
+
+        # Try to open the file without locking it, this will reduce the
+        # number of times where we erroneously keep a lock on the ensemble
+        # targets file although the process already was killed
+        try:
+            existing_targets = np.load(filepath, allow_pickle=True)
+            if existing_targets.shape[0] > targets.shape[0] or (
+                    existing_targets.shape == targets.shape and np.allclose(existing_targets, targets)):
+                return filepath
+        except Exception:
+            pass
+
+        with lockfile.LockFile(filepath):
+            if os.path.exists(filepath):
+                with open(filepath, 'rb') as fh:
+                    existing_targets = np.load(fh, allow_pickle=True)
+                    if existing_targets.shape[0] > targets.shape[0] or (
+                            existing_targets.shape == targets.shape and np.allclose(existing_targets, targets)):
+                        return filepath
+
+            with tempfile.NamedTemporaryFile('wb', dir=os.path.dirname(
+                    filepath), delete=False) as fh_w:
+                np.save(fh_w, targets.astype(np.float32))
+                tempname = fh_w.name
+
+            os.rename(tempname, filepath)
+
+        return filepath
+
+    def load_targets_ensemble(self) -> np.ndarray:
+        filepath = self._get_targets_ensemble_filename()
+
+        with lockfile.LockFile(filepath):
+            with open(filepath, 'rb') as fh:
+                targets = np.load(fh, allow_pickle=True)
+
+        return targets
+
+    def _get_datamanager_pickle_filename(self) -> str:
+        return os.path.join(self.internals_directory, 'datamanager.pkl')
+
+    def save_datamanager(self, datamanager: BaseDataset) -> str:
+        self._make_internals_directory()
+        filepath = self._get_datamanager_pickle_filename()
+
+        with lockfile.LockFile(filepath):
+            if not os.path.exists(filepath):
+                with tempfile.NamedTemporaryFile('wb', dir=os.path.dirname(
+                        filepath), delete=False) as fh:
+                    pickle.dump(datamanager, fh, -1)
+                    tempname = fh.name
+                os.rename(tempname, filepath)
+
+        return filepath
+
+    def load_datamanager(self) -> BaseDataset:
+        filepath = self._get_datamanager_pickle_filename()
+        with lockfile.LockFile(filepath):
+            with open(filepath, 'rb') as fh:
+                return pickle.load(fh)
+
+    def replace_datamanager(self, datamanager: BaseDataset) -> None:
+        """
+        This function is called to replace the old datamanager with a new one,
+        when it is required.
+
+        Args:
+            datamanager (BaseDataset): the new datamanager to replace the old.
+ """ + warnings.warn("Original dataset will be overwritten with the provided dataset") + datamanager_pickle_file = self._get_datamanager_pickle_filename() + if os.path.exists(datamanager_pickle_file): + os.remove(datamanager_pickle_file) + self.save_datamanager(datamanager=datamanager) + + def get_runs_directory(self) -> str: + return os.path.join(self.internals_directory, 'runs') + + def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str: + return os.path.join(self.internals_directory, 'runs', '%d_%d_%s' % (seed, num_run, budget)) + + def get_next_num_run(self, peek: bool = False) -> int: + """ + Every pipeline that is fitted by the estimator is stored with an + identifier called num_run. A dummy classifier will always have a num_run + equal to 1, and all other new configurations that are explored will + have a sequentially increasing identifier. + + This method returns the next num_run a configuration should take. + + Parameters + ---------- + peek: bool + By default, the next num_rum will be returned, i.e. self.active_num_run + 1 + Yet, if this bool parameter is equal to True, the value of the current + num_run is provided, i.e, self.active_num_run. + In other words, peek allows to get the current maximum identifier + of a configuration. + + Returns + ------- + num_run: int + An unique identifier for a configuration + """ + + # If there are other num_runs, their name would be runs/__ + other_num_runs = [int(os.path.basename(run_dir).split('_')[1]) + for run_dir in glob.glob(os.path.join(self.internals_directory, 'runs', '*')) + if re.match(r"\d+_\d+_\d+", os.path.basename(run_dir))] + if len(other_num_runs) > 0: + # We track the number of runs from two forefronts: + # The physically available num_runs (which might be deleted or a crash could happen) + # From a internally kept attribute. 
The later should be sufficient, but we + # want to be robust against multiple backend copies on different workers + self.active_num_run = max([self.active_num_run] + other_num_runs) + + # We are interested in the next run id + if not peek: + self.active_num_run += 1 + return self.active_num_run + + def get_model_filename(self, seed: int, idx: int, budget: float) -> str: + return '%s.%s.%s.model' % (seed, idx, budget) + + def get_cv_model_filename(self, seed: int, idx: int, budget: float) -> str: + return '%s.%s.%s.cv_model' % (seed, idx, budget) + + def list_all_models(self, seed: int) -> List[str]: + runs_directory = self.get_runs_directory() + model_files = glob.glob( + os.path.join(glob.escape(runs_directory), '%d_*' % seed, '%s.*.*.model' % seed) + ) + return model_files + + def load_models_by_identifiers(self, identifiers: List[Tuple[int, int, float]] + ) -> Dict: + models = dict() + + for identifier in identifiers: + seed, idx, budget = identifier + models[identifier] = self.load_model_by_seed_and_id_and_budget( + seed, idx, budget) + + return models + + def load_model_by_seed_and_id_and_budget(self, seed: int, + idx: int, + budget: float + ) -> BasePipeline: + model_directory = self.get_numrun_directory(seed, idx, budget) + + model_file_name = '%s.%s.%s.model' % (seed, idx, budget) + model_file_path = os.path.join(model_directory, model_file_name) + with open(model_file_path, 'rb') as fh: + return pickle.load(fh) + + def load_cv_models_by_identifiers(self, identifiers: List[Tuple[int, int, float]] + ) -> Dict: + models = dict() + + for identifier in identifiers: + seed, idx, budget = identifier + models[identifier] = self.load_cv_model_by_seed_and_id_and_budget( + seed, idx, budget) + + return models + + def load_cv_model_by_seed_and_id_and_budget(self, + seed: int, + idx: int, + budget: float + ) -> BasePipeline: + model_directory = self.get_numrun_directory(seed, idx, budget) + + model_file_name = '%s.%s.%s.cv_model' % (seed, idx, budget) + model_file_path = os.path.join(model_directory, model_file_name) + with open(model_file_path, 'rb') as fh: + return pickle.load(fh) + + def save_numrun_to_dir( + self, seed: int, idx: int, budget: float, model: Optional[BasePipeline], + cv_model: Optional[BasePipeline], ensemble_predictions: Optional[np.ndarray], + valid_predictions: Optional[np.ndarray], test_predictions: Optional[np.ndarray], + ) -> None: + assert self._logger is not None + runs_directory = self.get_runs_directory() + tmpdir = tempfile.mkdtemp(dir=runs_directory) + if model is not None: + file_path = os.path.join(tmpdir, self.get_model_filename(seed, idx, budget)) + with open(file_path, 'wb') as fh: + pickle.dump(model, fh, -1) + + if cv_model is not None: + file_path = os.path.join(tmpdir, self.get_cv_model_filename(seed, idx, budget)) + with open(file_path, 'wb') as fh: + pickle.dump(cv_model, fh, -1) + + for preds, subset in ( + (ensemble_predictions, 'ensemble'), + (valid_predictions, 'valid'), + (test_predictions, 'test') + ): + if preds is not None: + file_path = os.path.join( + tmpdir, + self.get_prediction_filename(subset, seed, idx, budget) + ) + with open(file_path, 'wb') as fh: + pickle.dump(preds.astype(np.float32), fh, -1) + try: + self._logger.debug("Renaming {} to {}".format(tmpdir, + self.get_numrun_directory(seed, idx, budget))) + os.rename(tmpdir, self.get_numrun_directory(seed, idx, budget)) + except OSError: + if os.path.exists(self.get_numrun_directory(seed, idx, budget)): + os.rename(self.get_numrun_directory(seed, idx, budget), + 
+                          os.path.join(runs_directory, tmpdir + '.old'))
+                os.rename(tmpdir, self.get_numrun_directory(seed, idx, budget))
+                shutil.rmtree(os.path.join(runs_directory, tmpdir + '.old'))
+
+    def get_ensemble_dir(self) -> str:
+        return os.path.join(self.internals_directory, 'ensembles')
+
+    def load_ensemble(self, seed: int) -> Optional[AbstractEnsemble]:
+        ensemble_dir = self.get_ensemble_dir()
+
+        if not os.path.exists(ensemble_dir):
+            if self._logger is not None:
+                self._logger.warning('Directory %s does not exist' % ensemble_dir)
+            else:
+                warnings.warn('Directory %s does not exist' % ensemble_dir)
+            return None
+
+        if seed >= 0:
+            indices_files = glob.glob(
+                os.path.join(glob.escape(ensemble_dir), '%s.*.ensemble' % seed)
+            )
+            indices_files.sort()
+        else:
+            indices_files = os.listdir(ensemble_dir)
+            indices_files = [os.path.join(ensemble_dir, f) for f in indices_files]
+            indices_files.sort(key=lambda f: time.ctime(os.path.getmtime(f)))
+
+        with open(indices_files[-1], 'rb') as fh:
+            ensemble_members_run_numbers = pickle.load(fh)
+
+        return ensemble_members_run_numbers
+
+    def save_ensemble(self, ensemble: AbstractEnsemble, idx: int, seed: int) -> None:
+        try:
+            os.makedirs(self.get_ensemble_dir())
+        except Exception:
+            pass
+
+        filepath = os.path.join(
+            self.get_ensemble_dir(),
+            '%s.%s.ensemble' % (str(seed), str(idx).zfill(10))
+        )
+        with tempfile.NamedTemporaryFile('wb', dir=os.path.dirname(
+                filepath), delete=False) as fh:
+            pickle.dump(ensemble, fh)
+            tempname = fh.name
+        os.rename(tempname, filepath)
+
+    def get_prediction_filename(self, subset: str,
+                                automl_seed: Union[str, int],
+                                idx: int,
+                                budget: float
+                                ) -> str:
+        return 'predictions_%s_%s_%s_%s.npy' % (subset, automl_seed, idx, budget)
+
+    def save_predictions_as_txt(self,
+                                predictions: np.ndarray,
+                                subset: str,
+                                idx: int, precision: int,
+                                prefix: Optional[str] = None) -> None:
+        if not self.output_directory:
+            return
+        # Write prediction scores in prescribed format
+        filepath = os.path.join(
+            self.output_directory,
+            ('%s_' % prefix if prefix else '') + '%s_%s.predict' % (subset, str(idx)),
+        )
+
+        format_string = '{:.%dg} ' % precision
+        with tempfile.NamedTemporaryFile('w', dir=os.path.dirname(
+                filepath), delete=False) as output_file:
+            for row in predictions:
+                if not isinstance(row, np.ndarray) and not isinstance(row, list):
+                    row = [row]
+                for val in row:
+                    output_file.write(format_string.format(float(val)))
+                output_file.write('\n')
+            tempname = output_file.name
+        os.rename(tempname, filepath)
+
+    def write_txt_file(self, filepath: str, data: str, name: str) -> None:
+        with lockfile.LockFile(filepath):
+            with tempfile.NamedTemporaryFile('w', dir=os.path.dirname(
+                    filepath), delete=False) as fh:
+                fh.write(data)
+                tempname = fh.name
+            os.rename(tempname, filepath)
+            if self._logger is not None:
+                self._logger.debug('Created %s file %s' % (name, filepath))
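As a quick orientation for this new module, the backend can be exercised on its own. A minimal sketch (directory names are randomized at creation time; the printed values are indicative):

from autoPyTorch.utils.backend import create

backend = create(temporary_directory=None, output_directory=None)
print(backend.temporary_directory)           # randomized autoPyTorch_tmp_* dir under the system tmp dir
print(backend.get_next_num_run(peek=True))   # 1: num_run 1 is reserved for the dummy run
print(backend.get_next_num_run())            # 2: the first real configuration
backend.context.delete_directories()         # clean up the created directories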
The functions makes sure that the prefix is a proper config + prefix by checking if it ends with ":", if not it appends ":" to the prefix. + + :param config: config dictionary where the prefixed of the keys should be replaced + :param prefix: prefix to be replaced in each key + :param replace: the string to replace the prefix with + :return: updated config dictionary + """ + # make sure that prefix ends with the config separator ":" + if not prefix.endswith(":"): + prefix = prefix + ":" + # only replace first occurrence of the prefix + return {k.replace(prefix, replace, 1): v + for k, v in config.items() if + k.startswith(prefix)} + + +>>>>>>> Bug fixes (#249) def custom_collate_fn(batch: List) -> List[Optional[torch.Tensor]]: """ In the case of not providing a y tensor, in a diff --git a/examples/40_advanced/40_advanced/example_custom_configuration_space.py b/examples/40_advanced/40_advanced/example_custom_configuration_space.py new file mode 100644 index 000000000..b95ceeaa5 --- /dev/null +++ b/examples/40_advanced/40_advanced/example_custom_configuration_space.py @@ -0,0 +1,141 @@ +""" +====================== +Tabular Classification with Custom Configuration Space +====================== + +The following example shows how adjust the configuration space of +the search. Currently, there are two changes that can be made to the space:- +1. Adjust individual hyperparameters in the pipeline +2. Include or exclude components: + a) include: Dictionary containing components to include. Key is the node + name and Value is an Iterable of the names of the components + to include. Only these components will be present in the + search space. + b) exclude: Dictionary containing components to exclude. Key is the node + name and Value is an Iterable of the names of the components + to exclude. All except these components will be present in + the search space. 
+""" +import os +import tempfile as tmp +import warnings + +os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir() +os.environ['OMP_NUM_THREADS'] = '1' +os.environ['OPENBLAS_NUM_THREADS'] = '1' +os.environ['MKL_NUM_THREADS'] = '1' + +warnings.simplefilter(action='ignore', category=UserWarning) +warnings.simplefilter(action='ignore', category=FutureWarning) + +import sklearn.datasets +import sklearn.model_selection + +from autoPyTorch.api.tabular_classification import TabularClassificationTask +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates + + +def get_search_space_updates(): + """ + Search space updates to the task can be added using HyperparameterSearchSpaceUpdates + Returns: + HyperparameterSearchSpaceUpdates + """ + updates = HyperparameterSearchSpaceUpdates() + updates.append(node_name="data_loader", + hyperparameter="batch_size", + value_range=[16, 512], + default_value=32) + updates.append(node_name="lr_scheduler", + hyperparameter="CosineAnnealingLR:T_max", + value_range=[50, 60], + default_value=55) + updates.append(node_name='network_backbone', + hyperparameter='ResNetBackbone:dropout', + value_range=[0, 0.5], + default_value=0.2) + updates.append(node_name='network_backbone', + hyperparameter='ResNetBackbone:multi_branch_choice', + value_range=['shake-shake'], + default_value='shake-shake') + updates.append(node_name='network_backbone', + hyperparameter='ResNetBackbone:shake_shake_method', + value_range=['M3'], + default_value='M3' + ) + return updates + + +if __name__ == '__main__': + + ############################################################################ + # Data Loading + # ============ + X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X, + y, + random_state=1, + ) + + ############################################################################ + # Build and fit a classifier with include components + # ================================================== + api = TabularClassificationTask( + search_space_updates=get_search_space_updates(), + include_components={'network_backbone': ['ResNetBackbone'], + 'encoder': ['OneHotEncoder']} + ) + + ############################################################################ + # Search for an ensemble of machine learning algorithms + # ===================================================== + api.search( + X_train=X_train.copy(), + y_train=y_train.copy(), + X_test=X_test.copy(), + y_test=y_test.copy(), + optimize_metric='accuracy', + total_walltime_limit=300, + func_eval_time_limit_secs=50 + ) + + ############################################################################ + # Print the final ensemble performance + # ==================================== + print(api.run_history, api.trajectory) + y_pred = api.predict(X_test) + score = api.score(y_pred, y_test) + print(score) + print(api.show_models()) + + ############################################################################ + # Build and fit a classifier with exclude components + # ================================================== + api = TabularClassificationTask( + search_space_updates=get_search_space_updates(), + exclude_components={'network_backbone': ['MLPBackbone'], + 'encoder': ['OneHotEncoder']} + ) + + ############################################################################ + # Search for an ensemble of machine learning algorithms + # ===================================================== + api.search( 
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test.copy(),
+        y_test=y_test.copy(),
+        optimize_metric='accuracy',
+        total_walltime_limit=300,
+        func_eval_time_limit_secs=50
+    )
+
+    ############################################################################
+    # Print the final ensemble performance
+    # ====================================
+    print(api.run_history, api.trajectory)
+    y_pred = api.predict(X_test)
+    score = api.score(y_pred, y_test)
+    print(score)
+    print(api.show_models())
diff --git a/examples/40_advanced/example_posthoc_ensemble_fit.py b/examples/40_advanced/example_posthoc_ensemble_fit.py
new file mode 100644
index 000000000..b9383b2a6
--- /dev/null
+++ b/examples/40_advanced/example_posthoc_ensemble_fit.py
@@ -0,0 +1,81 @@
+"""
+=====================================================
+Tabular Classification with Post-Hoc Ensemble Fitting
+=====================================================
+
+The following example shows how to fit a sample classification model
+and create an ensemble post-hoc with AutoPyTorch
+"""
+import os
+import tempfile as tmp
+import warnings
+
+os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
+os.environ['OMP_NUM_THREADS'] = '1'
+os.environ['OPENBLAS_NUM_THREADS'] = '1'
+os.environ['MKL_NUM_THREADS'] = '1'
+
+warnings.simplefilter(action='ignore', category=UserWarning)
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+import sklearn.datasets
+import sklearn.model_selection
+
+from autoPyTorch.api.tabular_classification import TabularClassificationTask
+
+
+if __name__ == '__main__':
+
+    ############################################################################
+    # Data Loading
+    # ============
+    X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
+    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+        X,
+        y,
+        random_state=42,
+    )
+
+    ############################################################################
+    # Build and fit a classifier
+    # ==========================
+    api = TabularClassificationTask(
+        ensemble_size=0,
+        seed=42,
+    )
+
+    ############################################################################
+    # Search for the best neural network
+    # ==================================
+    api.search(
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test.copy(),
+        y_test=y_test.copy(),
+        optimize_metric='accuracy',
+        total_walltime_limit=250,
+        func_eval_time_limit_secs=50
+    )
+
+    ############################################################################
+    # Print the final performance of the incumbent neural network
+    # ===========================================================
+    print(api.run_history, api.trajectory)
+    y_pred = api.predict(X_test)
+    score = api.score(y_pred, y_test)
+    print(score)
+
+    ############################################################################
+    # Fit an ensemble with the neural networks fitted during the search
+    # =================================================================
+
+    api.fit_ensemble(ensemble_size=5,
+                     # Set the enable_traditional_pipeline=True
+                     # to also include traditional models
+                     # in the ensemble
+                     enable_traditional_pipeline=False)
+
+    # Print the final ensemble built by AutoPyTorch
+    y_pred = api.predict(X_test)
+    score = api.score(y_pred, y_test)
+    print(score)
+    print(api.show_models())
diff --git a/requirements.txt b/requirements.txt
index 6f81bfcb7..4cf518ad0 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,7 @@ imgaug>=0.4.0
diff --git a/requirements.txt b/requirements.txt
index 6f81bfcb7..4cf518ad0 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,7 @@ imgaug>=0.4.0
 ConfigSpace>=0.4.14,<0.5
 pynisher>=0.6.3
 pyrfr>=0.7,<0.9
-smac==0.14.0
+smac
 dask
 distributed>=2.2.0
 catboost
diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py
index 7d4f47b09..4447d09f2 100644
--- a/test/test_data/test_feature_validator.py
+++ b/test/test_data/test_feature_validator.py
@@ -286,9 +286,9 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
     if isinstance(input_data_featuretest, pd.DataFrame):
         pytest.skip("Column order change in pandas is not supported")
     elif isinstance(input_data_featuretest, np.ndarray):
-        complementary_type = validator.numpy_array_to_pandas(input_data_featuretest)
+        complementary_type = validator.numpy_to_pandas(input_data_featuretest)
     elif isinstance(input_data_featuretest, list):
-        complementary_type, _ = validator.list_to_dataframe(input_data_featuretest)
+        complementary_type, _ = validator.list_to_pandas(input_data_featuretest)
     elif sparse.issparse(input_data_featuretest):
         complementary_type = sparse.csr_matrix(input_data_featuretest.todense())
     else:
@@ -314,10 +314,121 @@ def test_featurevalidator_get_columns_to_encode():
     for col in df.columns:
         df[col] = df[col].astype(col)
 
-    transformed_columns, feature_types = validator._get_columns_to_encode(df)
-
-    assert transformed_columns == ['category', 'bool']
-    assert feature_types == ['numerical', 'numerical', 'categorical', 'categorical']
+    validator.fit(df)
+
+    categorical_columns, numerical_columns, feat_type = validator._get_columns_info(df)
+
+    assert numerical_columns == ['int', 'float']
+    assert categorical_columns == ['category', 'bool']
+    assert feat_type == ['numerical', 'numerical', 'categorical', 'categorical']
+
+
+def feature_validator_remove_nan_catcolumns(df_train: pd.DataFrame, df_test: pd.DataFrame,
+                                            ans_train: np.ndarray, ans_test: np.ndarray) -> None:
+    validator = TabularFeatureValidator()
+    validator.fit(df_train)
+    transformed_df_train = validator.transform(df_train)
+    transformed_df_test = validator.transform(df_test)
+
+    assert np.array_equal(transformed_df_train, ans_train)
+    assert np.array_equal(transformed_df_test, ans_test)
+
+
+def test_feature_validator_remove_nan_catcolumns():
+    """
+    Make sure categorical columns that have only nan values are removed.
+    Transform performs the following:
+        * simple imputation for both
+        * scaling for numerical
+        * one-hot encoding for categorical
+    For example,
+    data = [
+        {'A': 1, 'B': np.nan, 'C': np.nan},
+        {'A': np.nan, 'B': 3, 'C': np.nan},
+        {'A': 2, 'B': np.nan, 'C': np.nan}
+    ]
+    and suppose all the columns are categorical,
+    then
+        * `A` in {np.nan, 1, 2}
+        * `B` in {np.nan, 3}
+        * `C` in {np.nan} <=== it will be dropped.
+
+    So in the column A,
+        * np.nan ==> [1, 0, 0]
+        * 1      ==> [0, 1, 0]
+        * 2      ==> [0, 0, 1]
+    in the column B,
+        * np.nan ==> [1, 0]
+        * 3      ==> [0, 1]
+    Therefore, by concatenating,
+        * {'A': 1, 'B': np.nan, 'C': np.nan} ==> [0, 1, 0, 1, 0]
+        * {'A': np.nan, 'B': 3, 'C': np.nan} ==> [1, 0, 0, 0, 1]
+        * {'A': 2, 'B': np.nan, 'C': np.nan} ==> [0, 0, 1, 1, 0]
+    """
+    # First case: columns B and C are all nan in the train set,
+    # while column C is not all nan in the test set.
+
+    df_train = pd.DataFrame(
+        [
+            {'A': 1, 'B': np.nan, 'C': np.nan},
+            {'A': np.nan, 'C': np.nan},
+            {'A': 1}
+        ],
+        dtype='category',
+    )
+    ans_train = np.array([[0, 1], [1, 0], [0, 1]], dtype=np.float64)
+    df_test = pd.DataFrame(
+        [
+            {'A': np.nan, 'B': np.nan, 'C': 5},
+            {'A': np.nan, 'C': np.nan},
+            {'A': 1}
+        ],
+        dtype='category',
+    )
+    ans_test = np.array([[1, 0], [1, 0], [0, 1]], dtype=np.float64)
+    feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)
+
+    # Second case: columns B and C are all nan in the training set,
+    # and the same columns are all nan in the test set.
+    df_train = pd.DataFrame(
+        [
+            {'A': 1, 'B': np.nan, 'C': np.nan},
+            {'A': np.nan, 'C': np.nan},
+            {'A': 1}
+        ],
+        dtype='category',
+    )
+    ans_train = np.array([[0, 1], [1, 0], [0, 1]], dtype=np.float64)
+    df_test = pd.DataFrame(
+        [
+            {'A': np.nan, 'B': np.nan, 'C': np.nan},
+            {'A': np.nan, 'C': np.nan},
+            {'A': 1}
+        ],
+        dtype='category',
+    )
+    ans_test = np.array([[1, 0], [1, 0], [0, 1]], dtype=np.float64)
+    feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)
+
+    # Third case: no all-nan columns in the training set, but every
+    # column is all nan in the test set.
+    df_train = pd.DataFrame(
+        [
+            {'A': 1, 'B': 1},
+            {'A': 2, 'B': 2}
+        ],
+        dtype='category',
+    )
+    ans_train = np.array([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=np.float64)
+    df_test = pd.DataFrame(
+        [
+            {'A': np.nan, 'B': np.nan},
+            {'A': np.nan, 'B': np.nan}
+        ],
+        dtype='category',
+    )
+    ans_test = np.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float64)
+    feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)
 
 
 def test_features_unsupported_calls_are_raised():
@@ -582,15 +693,16 @@ def test_feature_validator_imbalanced_data():
     validator.fit(X_train)
 
     train_feature_types = copy.deepcopy(validator.feat_type)
-    assert train_feature_types == ['numerical', 'numerical', 'numerical', 'numerical']
+    assert train_feature_types == ['numerical']
 
     # validator will throw an error if the column types are not the same
    transformed_X_test = validator.transform(X_test)
     transformed_X_test = pd.DataFrame(transformed_X_test)
-    null_columns = []
-    for column in transformed_X_test.columns:
-        if transformed_X_test[column].isna().all():
-            null_columns.append(column)
-    assert null_columns == [0, 2, 3]
+    assert sorted(validator.all_nan_columns) == sorted(['A', 'C', 'D'])
+    # As there are no categorical columns, only the all-nan columns are
+    # expected to be dropped, so the remaining column counts must match.
+    total_all_nan_columns = len(validator.all_nan_columns)
+    total_columns = len(validator.column_order)
+    assert total_columns - total_all_nan_columns == len(transformed_X_test.columns)
     # Columns with not all null values in the train split and
     # completely null on the test split.
@@ -609,14 +721,33 @@
     X_test = pd.DataFrame.from_dict(test_features)
     validator = TabularFeatureValidator()
     validator.fit(X_train)
+
     train_feature_types = copy.deepcopy(validator.feat_type)
     assert train_feature_types == ['categorical', 'numerical', 'numerical']
 
     transformed_X_test = validator.transform(X_test)
     transformed_X_test = pd.DataFrame(transformed_X_test)
-    null_columns = []
-    for column in transformed_X_test.columns:
-        if transformed_X_test[column].isna().all():
-            null_columns.append(column)
+    assert not len(validator.all_nan_columns)
+
+
+def test_comparator():
+    numerical = 'numerical'
+    categorical = 'categorical'
+
+    validator = TabularFeatureValidator
+
+    feat_type = [numerical, categorical] * 10
+    ans = [categorical] * 10 + [numerical] * 10
+    feat_type = sorted(
+        feat_type,
+        key=functools.cmp_to_key(validator._comparator)
+    )
+    assert ans == feat_type
 
-    assert null_columns == [1]
+    feat_type = [numerical] * 10 + [categorical] * 10
+    ans = [categorical] * 10 + [numerical] * 10
+    feat_type = sorted(
+        feat_type,
+        key=functools.cmp_to_key(validator._comparator)
+    )
+    assert ans == feat_type
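For reference, the contract that the new test_comparator test encodes is
simply that categorical feature types sort ahead of numerical ones under
TabularFeatureValidator._comparator. A standalone sketch of a comparator
with that ordering (an illustration of the expected behaviour, not the
validator's actual implementation):

    import functools

    def comparator(a: str, b: str) -> int:
        # Categorical entries sort before numerical ones; equal types tie.
        order = {'categorical': 0, 'numerical': 1}
        return order[a] - order[b]

    feat_type = ['numerical', 'categorical'] * 2
    assert sorted(feat_type, key=functools.cmp_to_key(comparator)) == [
        'categorical', 'categorical', 'numerical', 'numerical'
    ]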