Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reg cocktails pytorch embedding roc #490

Open
wants to merge 23 commits into
base: reg_cocktails
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0bd7d5f
update eta for experiments
ravinkohli Mar 10, 2022
76dae54
add check if True is in value range
ravinkohli Mar 11, 2022
d26c611
Reg cocktails common paper modifications 2 (#417)
ravinkohli Mar 14, 2022
bc60e31
have working embedding from pytroch
ravinkohli Mar 23, 2022
13fad76
divide columns to encode and embed based on threshold
ravinkohli Mar 31, 2022
cf4fd98
cleanup unwanted changes
ravinkohli Mar 31, 2022
af41dd7
use shape after preprocessing in base network backbone
ravinkohli Mar 31, 2022
9706875
remove redundant call to load datamanager
ravinkohli Apr 5, 2022
def144c
add init file for column splitting
ravinkohli Apr 11, 2022
926a757
fix tests
ravinkohli Jun 14, 2022
7567d26
fix precommit and add test changes
ravinkohli Jun 14, 2022
09fdc0d
[ADD] Calculate memory of dataset after one hot encoding (pytorch emb…
ravinkohli Jul 16, 2022
3aef02e
suggestions from review
ravinkohli Jul 18, 2022
8e3dbef
add preprocessed_dtype to determine double or float
ravinkohli Aug 9, 2022
52427bc
test fix in progress
ravinkohli Aug 16, 2022
90512ee
TODO: fix errors after rebase
ravinkohli Aug 17, 2022
895b904
Reg cocktails apt1.0+reg cocktails pytorch embedding reduced (#454)
ravinkohli Aug 17, 2022
033bca7
fix embeddings after rebase
ravinkohli Aug 17, 2022
d4cd8b4
fix error with pytorch embeddings
ravinkohli Aug 18, 2022
a5807cb
fix redundant code
ravinkohli Aug 18, 2022
960e1ef
change userdefined to False
ravinkohli Aug 18, 2022
1be80d5
remove using categorical columns
ravinkohli Aug 19, 2022
a616ecb
Add fix for ROC
ravinkohli Feb 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(
self.input_validator: Optional[BaseInputValidator] = None

self.search_space_updates = search_space_updates

if search_space_updates is not None:
if not isinstance(self.search_space_updates,
HyperparameterSearchSpaceUpdates):
Expand Down
2 changes: 2 additions & 0 deletions autoPyTorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@

# To avoid that we get a sequence that is too long to be fed to a network
MAX_WINDOW_SIZE_BASE = 500

MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7
2 changes: 1 addition & 1 deletion autoPyTorch/data/base_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def __init__(

# Required for dataset properties
self.num_features: Optional[int] = None
self.categories: List[List[int]] = []
self.categorical_columns: List[int] = []
self.numerical_columns: List[int] = []
self.encode_columns: List[str] = []

self.num_categories_per_col: Optional[List[int]] = []
self.all_nan_columns: Optional[Set[Union[int, str]]] = None

self._is_fitted = False
Expand Down
14 changes: 5 additions & 9 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ class TabularFeatureValidator(BaseFeatureValidator):
transformer.

Attributes:
categories (List[List[str]]):
List for which an element at each index is a
list containing the categories for the respective
categorical column.
num_categories_per_col (List[int]):
List for which an element at each index is the number
of categories for the respective categorical column.
transformed_columns (List[str])
List of columns that were transformed.
column_transformer (Optional[BaseEstimator])
Expand Down Expand Up @@ -202,10 +201,8 @@ def _fit(
encoded_categories = self.column_transformer.\
named_transformers_['categorical_pipeline'].\
named_steps['ordinalencoder'].categories_
self.categories = [
list(range(len(cat)))
for cat in encoded_categories
]

self.num_categories_per_col = [len(cat) for cat in encoded_categories]

# differently to categorical_columns and numerical_columns,
# this saves the index of the column.
Expand Down Expand Up @@ -283,7 +280,6 @@ def transform(
X = self.numpy_to_pandas(X)

if ispandas(X) and not issparse(X):

if self.all_nan_columns is None:
raise ValueError('_fit must be called before calling transform')

Expand Down
2 changes: 2 additions & 0 deletions autoPyTorch/data/tabular_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def _compress_dataset(
y=y,
is_classification=self.is_classification,
random_state=self.seed,
categorical_columns=self.feature_validator.categorical_columns,
n_categories_per_cat_column=self.feature_validator.num_categories_per_col,
**self.dataset_compression # type: ignore [arg-type]
)
self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype
Expand Down
53 changes: 46 additions & 7 deletions autoPyTorch/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sklearn.utils import _approximate_mode, check_random_state
from sklearn.utils.validation import _num_samples, check_array

from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX
from autoPyTorch.data.base_target_validator import SupportedTargetTypes
from autoPyTorch.utils.common import ispandas

Expand Down Expand Up @@ -459,8 +460,8 @@ def _subsample_by_indices(
return X, y


def megabytes(arr: DatasetCompressionInputType) -> float:

def get_raw_memory_usage(arr: DatasetCompressionInputType) -> float:
memory_in_bytes: float
if isinstance(arr, np.ndarray):
memory_in_bytes = arr.nbytes
elif issparse(arr):
Expand All @@ -470,19 +471,57 @@ def megabytes(arr: DatasetCompressionInputType) -> float:
else:
raise ValueError(f"Unrecognised data type of X, expected data type to "
f"be in (np.ndarray, spmatrix, pd.DataFrame) but got :{type(arr)}")
return memory_in_bytes


def get_approximate_mem_usage_in_mb(
arr: DatasetCompressionInputType,
categorical_columns: List,
n_categories_per_cat_column: Optional[List[int]] = None
) -> float:

err_msg = "Value number of categories per categorical is required when the data has categorical columns"
if ispandas(arr):
arr_dtypes = arr.dtypes.to_dict()
multipliers = [dtype.itemsize for col, dtype in arr_dtypes.items() if col not in categorical_columns]
if len(categorical_columns) > 0:
if n_categories_per_cat_column is None:
raise ValueError(err_msg)
for col, num_cat in zip(categorical_columns, n_categories_per_cat_column):
if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX:
multipliers.append(num_cat * arr_dtypes[col].itemsize)
else:
multipliers.append(arr_dtypes[col].itemsize)
size_one_row = sum(multipliers)

elif isinstance(arr, (np.ndarray, spmatrix)):
n_cols = arr.shape[-1] - len(categorical_columns)
multiplier = arr.dtype.itemsize
if len(categorical_columns) > 0:
if n_categories_per_cat_column is None:
raise ValueError(err_msg)
# multiply num categories with the size of the column to capture memory after one hot encoding
n_cols += sum(num_cat if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX else 1 for num_cat in n_categories_per_cat_column)
size_one_row = n_cols * multiplier
else:
raise ValueError(f"Unrecognised data type of X, expected data type to "
f"be in (np.ndarray, spmatrix, pd.DataFrame), but got :{type(arr)}")

return float(memory_in_bytes / (2**20))
return float(arr.shape[0] * size_one_row / (2**20))


def reduce_dataset_size_if_too_large(
X: DatasetCompressionInputType,
memory_allocation: Union[int, float],
is_classification: bool,
random_state: Union[int, np.random.RandomState],
categorical_columns: List,
n_categories_per_cat_column: Optional[List[int]] = None,
y: Optional[SupportedTargetTypes] = None,
methods: List[str] = ['precision', 'subsample'],
) -> DatasetCompressionInputType:
f""" Reduces the size of the dataset if it's too close to the memory limit.
f"""
Reduces the size of the dataset if it's too close to the memory limit.

Follows the order of the operations passed in and retains the type of its
input.
Expand Down Expand Up @@ -513,7 +552,6 @@ def reduce_dataset_size_if_too_large(
Reduce the amount of samples of the dataset such that it fits into the allocated
memory. Ensures stratification and that unique labels are present


memory_allocation (Union[int, float]):
The amount of memory to allocate to the dataset. It should specify an
absolute amount.
Expand All @@ -524,7 +562,7 @@ def reduce_dataset_size_if_too_large(
"""

for method in methods:
if megabytes(X) <= memory_allocation:
if get_approximate_mem_usage_in_mb(X, categorical_columns, n_categories_per_cat_column) <= memory_allocation:
break

if method == 'precision':
Expand All @@ -540,7 +578,8 @@ def reduce_dataset_size_if_too_large(
# into the allocated memory, we subsample it so that it does

n_samples_before = X.shape[0]
sample_percentage = memory_allocation / megabytes(X)
sample_percentage = memory_allocation / get_approximate_mem_usage_in_mb(
X, categorical_columns, n_categories_per_cat_column)

# NOTE: type ignore
#
Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/datasets/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self,
self.categorical_columns = validator.feature_validator.categorical_columns
self.numerical_columns = validator.feature_validator.numerical_columns
self.num_features = validator.feature_validator.num_features
self.categories = validator.feature_validator.categories
self.num_categories_per_col = validator.feature_validator.num_categories_per_col

super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle,
resampling_strategy=resampling_strategy,
Expand Down
4 changes: 2 additions & 2 deletions autoPyTorch/datasets/time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def __init__(self,
self.num_features: int = self.validator.feature_validator.num_features # type: ignore[assignment]
self.num_targets: int = self.validator.target_validator.out_dimensionality # type: ignore[assignment]

self.categories = self.validator.feature_validator.categories
self.num_categories_per_col = self.validator.feature_validator.num_categories_per_col

self.feature_shapes = self.validator.feature_shapes
self.feature_names = tuple(self.validator.feature_names)
Expand Down Expand Up @@ -1072,7 +1072,7 @@ def get_required_dataset_info(self) -> Dict[str, Any]:
'categorical_features': self.categorical_features,
'numerical_columns': self.numerical_columns,
'categorical_columns': self.categorical_columns,
'categories': self.categories,
'num_categories_per_col': self.num_categories_per_col,
})
return info

Expand Down
27 changes: 27 additions & 0 deletions autoPyTorch/evaluation/train_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import os
from multiprocessing.queues import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

Expand All @@ -20,7 +22,9 @@
fit_and_suppress_warnings
)
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.utils.common import dict_repr, subsampler
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

Expand Down Expand Up @@ -193,6 +197,8 @@ def fit_predict_and_loss(self) -> None:
additional_run_info = pipeline.get_additional_run_info() if hasattr(
pipeline, 'get_additional_run_info') else {}

# self._write_run_summary(pipeline)

status = StatusType.SUCCESS

self.logger.debug("In train evaluator.fit_predict_and_loss, num_run: {} loss:{},"
Expand Down Expand Up @@ -348,6 +354,27 @@ def fit_predict_and_loss(self) -> None:
status=status,
)

def _write_run_summary(self, pipeline: BasePipeline) -> None:
# add learning curve of configurations to additional_run_info
if isinstance(pipeline, TabularClassificationPipeline):
assert isinstance(self.configuration, Configuration)
if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
run_summary = pipeline.named_steps['trainer'].run_summary
split_types = ['train', 'val', 'test']
run_summary_dict = dict(
run_summary={},
budget=self.budget,
seed=self.seed,
config_id=self.configuration.config_id,
num_run=self.num_run)
for split_type in split_types:
run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(
f'{split_type}_loss', None)
run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(
f'{split_type}_metrics', None)
with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
file.write(f"{json.dumps(run_summary_dict)}\n")

def _fit_and_predict(self, pipeline: BaseEstimator, fold: int, train_indices: Union[np.ndarray, List],
test_indices: Union[np.ndarray, List],
add_pipeline_to_self: bool
Expand Down
53 changes: 14 additions & 39 deletions autoPyTorch/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ def _get_hyperparameter_search_space(self,
def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace:
"""
Add forbidden conditions to ensure valid configurations.
Currently, Learned Entity Embedding is only valid when encoder is one hot encoder
and CyclicLR is disabled when using stochastic weight averaging and snapshot
Currently, CyclicLR is disabled when using stochastic weight averaging and snapshot
ensembling.

Args:
Expand All @@ -314,33 +313,6 @@ def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpac

"""

# Learned Entity Embedding is only valid when encoder is one hot encoder
if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
if 'LearnedEntityEmbedding' in embeddings:
encoders = cs.get_hyperparameter('encoder:__choice__').choices
possible_default_embeddings = copy(list(embeddings))
del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]

for encoder in encoders:
if encoder == 'OneHotEncoder':
continue
while True:
try:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
))
break
except ValueError:
# change the default and try again
try:
default = possible_default_embeddings.pop()
except IndexError:
raise ValueError("Cannot find a legal default configuration")
cs.get_hyperparameter('network_embedding:__choice__').default_value = default

# Disable CyclicLR until todo is completed.
if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
trainers = cs.get_hyperparameter('trainer:__choice__').choices
Expand All @@ -350,16 +322,19 @@ def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpac
cyclic_lr_name = 'CyclicLR'
if cyclic_lr_name in available_schedulers:
# disable snapshot ensembles and stochastic weight averaging
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
f'trainer:{trainer}:use_snapshot_ensemble'), True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(cs.get_hyperparameter(
f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
snapshot_ensemble_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_snapshot_ensemble')
if hasattr(snapshot_ensemble_hyperparameter, 'choices') and \
True in snapshot_ensemble_hyperparameter.choices:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(snapshot_ensemble_hyperparameter, True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
swa_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_stochastic_weight_averaging')
if hasattr(swa_hyperparameter, 'choices') and True in swa_hyperparameter.choices:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(swa_hyperparameter, True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
return cs

def __repr__(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
self.preprocessor: Optional[ColumnTransformer] = None
self.add_fit_requirements([
FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)])
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('encode_columns', (List,), user_defined=False, dataset_property=False),
FitRequirement('embed_columns', (List,), user_defined=False, dataset_property=False)])


def get_column_transformer(self) -> ColumnTransformer:
"""
Expand Down Expand Up @@ -52,17 +55,31 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
self.check_requirements(X, y)

preprocessors = get_tabular_preprocessers(X)

column_transformers: List[Tuple[str, BaseEstimator, List[int]]] = []

numerical_pipeline = 'passthrough'
encode_pipeline = 'passthrough'

if len(preprocessors['numerical']) > 0:
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
column_transformers.append(
('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns'])
)
if len(preprocessors['categorical']) > 0:
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
column_transformers.append(
('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns'])
)

column_transformers.append(
('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns'])
)

if len(preprocessors['encode']) > 0:
encode_pipeline = make_pipeline(*preprocessors['encode'])

column_transformers.append(
('encode_pipeline', encode_pipeline, X['encode_columns'])
)

# if len(preprocessors['categorical']) > 0:
# categorical_pipeline = make_pipeline(*preprocessors['categorical'])
# column_transformers.append(
# ('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns'])
# )

# in case the preprocessing steps are disabled
# i.e, NoEncoder for categorical, we want to
Expand Down
Loading