[test] [fix] Rebase, modify tests, and increase the coverage
nabenabe0928 committed Feb 23, 2022
1 parent 2eea80f commit e48fd14
Showing 9 changed files with 224 additions and 466 deletions.
244 changes: 107 additions & 137 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -167,42 +167,82 @@ class FixedPipelineParams(NamedTuple):
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
An object used to fine tune the hyperparameter search space of the pipeline
"""
def __init__(self, backend: Backend,
queue: Queue,
metric: autoPyTorchMetric,
budget: float,
configuration: Union[int, str, Configuration],
budget_type: str = None,
pipeline_config: Optional[Dict[str, Any]] = None,
seed: int = 1,
output_y_hat_optimization: bool = True,
num_run: Optional[int] = None,
include: Optional[Dict[str, Any]] = None,
exclude: Optional[Dict[str, Any]] = None,
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
all_supported_metrics: bool = True,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
) -> None:

self.starttime = time.time()

self.configuration = configuration
self.backend: Backend = backend
self.queue = queue

self.include = include
self.exclude = exclude
self.search_space_updates = search_space_updates

self.metric = metric


self._init_datamanager_info()

# Flag to save target for ensemble
self.output_y_hat_optimization = output_y_hat_optimization
backend: Backend
seed: int
metric: autoPyTorchMetric
budget_type: str # Literal['epochs', 'runtime']
pipeline_config: Dict[str, Any]
save_y_opt: bool = True
include: Optional[Dict[str, Any]] = None
exclude: Optional[Dict[str, Any]] = None
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None
logger_port: Optional[int] = None
all_supported_metrics: bool = True
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None

@classmethod
def with_default_pipeline_config(
cls,
pipeline_config: Optional[Dict[str, Any]] = None,
choice: str = 'default',
**kwargs: Any
) -> 'FixedPipelineParams':

if 'budget_type' in kwargs:
raise TypeError(
f'{cls.__name__}.with_default_pipeline_config() got multiple values for argument `budget_type`'
)

budget_type_choices = ('epochs', 'runtime')
if pipeline_config is None:
pipeline_config = get_default_pipeline_config(choice=choice)
if 'budget_type' not in pipeline_config:
raise ValueError('pipeline_config must have `budget_type`')

budget_type = pipeline_config['budget_type']
if pipeline_config['budget_type'] not in budget_type_choices:
raise ValueError(f"budget_type must be in {budget_type_choices}, but got {budget_type}")

kwargs.update(pipeline_config=pipeline_config, budget_type=budget_type)
return cls(**kwargs)
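
For illustration, the factory pattern introduced above (an immutable NamedTuple whose classmethod fills in a default pipeline_config and derives budget_type from it) can be sketched with a self-contained stand-in. The class below and its default config values are hypothetical, not the real autoPyTorch objects:

from typing import Any, Dict, NamedTuple, Optional

_BUDGET_TYPE_CHOICES = ('epochs', 'runtime')

class _ParamsSketch(NamedTuple):
    seed: int
    budget_type: str
    pipeline_config: Dict[str, Any]

    @classmethod
    def with_default_pipeline_config(
        cls,
        pipeline_config: Optional[Dict[str, Any]] = None,
        **kwargs: Any
    ) -> '_ParamsSketch':
        if 'budget_type' in kwargs:
            # budget_type is always derived from pipeline_config, mirroring the check above
            raise TypeError('got multiple values for argument `budget_type`')
        if pipeline_config is None:
            pipeline_config = {'budget_type': 'epochs', 'epochs': 50}  # hypothetical defaults
        budget_type = pipeline_config['budget_type']
        if budget_type not in _BUDGET_TYPE_CHOICES:
            raise ValueError(f'budget_type must be in {_BUDGET_TYPE_CHOICES}, but got {budget_type}')
        return cls(pipeline_config=pipeline_config, budget_type=budget_type, **kwargs)

params = _ParamsSketch.with_default_pipeline_config(seed=1)
assert params.budget_type == 'epochs'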


class EvaluatorParams(NamedTuple):
"""
Attributes:
configuration (Union[int, str, Configuration]):
Determines the pipeline to be constructed. A dummy estimator is created for
integer configurations, a traditional machine learning pipeline is created
for string based configuration, and NAS is performed when a configuration
object is passed.
num_run (Optional[int]):
An identifier of the current configuration being fit. This number is unique per
configuration.
init_params (Optional[Dict[str, Any]]):
Optional argument that is passed to each pipeline step. It is the equivalent of
kwargs for the pipeline steps.
"""
budget: float
configuration: Union[int, str, Configuration]
num_run: Optional[int] = None
init_params: Optional[Dict[str, Any]] = None

@classmethod
def with_default_budget(
cls,
budget: float = 0,
choice: str = 'default',
**kwargs: Any
) -> 'EvaluatorParams':
budget = get_default_budget(choice=choice) if budget == 0 else budget
kwargs.update(budget=budget)
return cls(**kwargs)
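
As a rough illustration of the dispatch described in the docstring above (an integer configuration gives a dummy estimator, a string gives a traditional ML pipeline, and a Configuration object triggers a full pipeline fit), a minimal self-contained sketch; the function name and the 'random_forest' string are made up for the example:

from typing import Any

def _select_pipeline_kind(configuration: Any) -> str:
    # Mirrors the type-based dispatch: int -> dummy, str -> traditional, else NAS
    if isinstance(configuration, int):
        return 'dummy'
    if isinstance(configuration, str):
        return 'traditional'
    return 'nas'  # e.g. a ConfigSpace Configuration object

assert _select_pipeline_kind(1) == 'dummy'
assert _select_pipeline_kind('random_forest') == 'traditional'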


class AbstractEvaluator(object):
"""
This class defines the interface that pipeline evaluators should follow when
interacting with SMAC through TargetAlgorithmQuery.
An evaluator is an object that:
+ constructs a pipeline (i.e. a classification or regression estimator) for a given
@@ -244,21 +284,33 @@ def _init_miscellaneous(self) -> None:
DisableFileOutputParameters.check_compatibility(disable_file_output)
self.disable_file_output = disable_file_output
else:
if isinstance(self.configuration, int):
self.pipeline_class = DummyClassificationPipeline
elif isinstance(self.configuration, str):
if self.task_type in TABULAR_TASKS:
self.pipeline_class = MyTraditionalTabularClassificationPipeline
else:
raise ValueError("Only tabular tasks are currently supported with traditional methods")
elif isinstance(self.configuration, Configuration):
if self.task_type in TABULAR_TASKS:
self.pipeline_class = autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline
elif self.task_type in IMAGE_TASKS:
self.pipeline_class = autoPyTorch.pipeline.image_classification.ImageClassificationPipeline
else:
raise ValueError('task {} not available'.format(self.task_type))
self.predict_function = self._predict_proba
self.disable_file_output = []

if self.num_folds == 1:  # do not save the cv model when we perform holdout
self.disable_file_output.append('cv_model')

def _init_dataset_properties(self) -> None:
datamanager: BaseDataset = self.fixed_pipeline_params.backend.load_datamanager()
if datamanager.task_type is None:
raise ValueError(f"Expected dataset {datamanager.__class__.__name__} to have task_type got None")
if datamanager.splits is None:
raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")

self.splits = datamanager.splits
self.num_folds: int = len(self.splits)
# Since cv might not finish in time, we take self.pipelines as None by default
self.pipelines: List[Optional[BaseEstimator]] = [None] * self.num_folds
self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
self.num_classes = getattr(datamanager, 'num_classes', 1)
self.output_type = datamanager.output_type

search_space_updates = self.fixed_pipeline_params.search_space_updates
self.dataset_properties = datamanager.get_dataset_properties(
get_dataset_requirements(info=datamanager.get_required_dataset_info(),
include=self.fixed_pipeline_params.include,
exclude=self.fixed_pipeline_params.exclude,
search_space_updates=search_space_updates
))

self.X_train, self.y_train = datamanager.train_tensors
self.unique_train_labels = [
@@ -271,6 +323,8 @@ def _init_miscellaneous(self) -> None:
if datamanager.test_tensors is not None:
self.X_test, self.y_test = datamanager.test_tensors

del datamanager # Delete datamanager to release the memory

def _init_additional_metrics(self) -> None:
all_supported_metrics = self.fixed_pipeline_params.all_supported_metrics
metric = self.fixed_pipeline_params.metric
@@ -282,59 +336,7 @@ def _init_additional_metrics(self) -> None:
all_supported_metrics=all_supported_metrics)
self.metrics_dict = {'additional_metrics': [m.name for m in [metric] + self.additional_metrics]}

def _init_datamanager_info(
self,
) -> None:
"""
Initialises instance attributes that come from the datamanager.
For example,
X_train, y_train, etc.
"""

datamanager: BaseDataset = self.backend.load_datamanager()

assert datamanager.task_type is not None, \
"Expected dataset {} to have task_type got None".format(datamanager.__class__.__name__)
self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
self.output_type = STRING_TO_OUTPUT_TYPES[datamanager.output_type]
self.issparse = datamanager.issparse

self.X_train, self.y_train = datamanager.train_tensors

if datamanager.val_tensors is not None:
self.X_valid, self.y_valid = datamanager.val_tensors
else:
self.X_valid, self.y_valid = None, None

if datamanager.test_tensors is not None:
self.X_test, self.y_test = datamanager.test_tensors
else:
self.X_test, self.y_test = None, None

self.resampling_strategy = datamanager.resampling_strategy

self.num_classes: Optional[int] = getattr(datamanager, "num_classes", None)

self.dataset_properties = datamanager.get_dataset_properties(
get_dataset_requirements(info=datamanager.get_required_dataset_info(),
include=self.include,
exclude=self.exclude,
search_space_updates=self.search_space_updates
))
self.splits = datamanager.splits
if self.splits is None:
raise AttributeError(f"create_splits on {datamanager.__class__.__name__} must be called "
f"before the instantiation of {self.__class__.__name__}")

# delete datamanager from memory
del datamanager

def _init_fit_dictionary(
self,
logger_port: int,
pipeline_config: Dict[str, Any],
metrics_dict: Optional[Dict[str, List[str]]] = None,
) -> None:
def _init_fit_dictionary(self) -> None:
"""
Initialises the fit dictionary
@@ -617,36 +619,4 @@ def _is_output_possible(
if y is not None and not np.all(np.isfinite(y)):
return False # Model predictions contains NaNs

Args:
prediction (np.ndarray):
The un-formatted predictions of a pipeline
Y_train (np.ndarray):
The labels from the dataset to give an intuition of the expected
predictions dimensionality
Returns:
(np.ndarray):
The formatted prediction
"""
assert self.num_classes is not None, "Called function on wrong task"

if self.output_type == MULTICLASS and \
prediction.shape[1] < self.num_classes:
if Y_train is None:
raise ValueError('Y_train must not be None!')
classes = list(np.unique(Y_train))

mapping = dict()
for class_number in range(self.num_classes):
if class_number in classes:
index = classes.index(class_number)
mapping[index] = class_number
new_predictions = np.zeros((prediction.shape[0], self.num_classes),
dtype=np.float32)

for index in mapping:
class_index = mapping[index]
new_predictions[:, class_index] = prediction[:, index]

return new_predictions

return prediction
return True
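
The removed block above padded multiclass probability outputs when the training split did not contain every class. A minimal NumPy sketch of that idea (illustrative only, assuming integer class labels in y_train; the helper name is made up):

import numpy as np

def expand_to_all_classes(prediction: np.ndarray, y_train: np.ndarray, num_classes: int) -> np.ndarray:
    # Map each predicted column (one per class seen in y_train) back to its true class index.
    if prediction.shape[1] >= num_classes:
        return prediction
    seen_classes = list(np.unique(y_train))
    expanded = np.zeros((prediction.shape[0], num_classes), dtype=np.float32)
    for index, class_number in enumerate(seen_classes):
        expanded[:, class_number] = prediction[:, index]
    return expanded

# Two columns predicted, but the task has three classes (class 1 never appears in training).
pred = np.array([[0.7, 0.3], [0.2, 0.8]])
print(expand_to_all_classes(pred, y_train=np.array([0, 2, 2, 0]), num_classes=3))
# [[0.7 0.  0.3]
#  [0.2 0.  0.8]]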
15 changes: 8 additions & 7 deletions autoPyTorch/evaluation/evaluator.py
@@ -111,14 +111,15 @@ def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, eva
self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
self.cross_validation = isinstance(resampling_strategy, CrossValTypes)

if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
raise ValueError(
f'resampling_strategy for TrainEvaluator must be in '
f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}'
)
if not self.train and fixed_pipeline_params.save_y_opt:
# TODO: Add a test to cover this case
# The no-resampling strategy cannot be used for building ensembles; save_y_opt=False ensures this
fixed_pipeline_params = fixed_pipeline_params._replace(save_y_opt=False)

super().__init__(queue=queue, fixed_pipeline_params=fixed_pipeline_params, evaluator_params=evaluator_params)

self.num_folds: int = len(self.splits)
self.logger.debug("Search space updates :{}".format(self.search_space_updates))
if self.train:
self.logger.debug("Search space updates :{}".format(self.fixed_pipeline_params.search_space_updates))

def _evaluate_on_split(self, split_id: int) -> EvaluationResults:
"""