# tabzilla_preprocessor_utils.py
import functools

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

from tabzilla_datasets import TabularDataset

cv_n_folds = 10  # number of cross-validation folds to use when generating splits
def dataset_preprocessor(
    preprocessor_dict,
    dataset_name,
    target_encode=None,
    cat_feature_encode=True,
    generate_split=True,
):
    """Decorator that registers a dataset pre-processing function.

    The decorated function is added to the dictionary of pre-processors, after
    which it can be called as preprocessor_dict[dataset_name]() to build the
    corresponding TabularDataset.

    Args:
        preprocessor_dict: Dictionary mapping dataset names to pre-processors.
        dataset_name: Name of the dataset.
        target_encode: Whether to label-encode the target. If None, the target
            is encoded for all non-regression datasets.
        cat_feature_encode: Whether to encode categorical features.
        generate_split: Whether to generate cross-validation split indices.
    """

    def dataset_preprocessor_decorator(func):
        @functools.wraps(func)
        def wrapper_preprocessor(*args, **kwargs):
            dataset_kwargs = func(*args, **kwargs)
            if generate_split:
                dataset_kwargs["split_indeces"] = split_dataset(dataset_kwargs)
                dataset_kwargs["split_source"] = "random_init"
            dataset = TabularDataset(dataset_name, **dataset_kwargs)
            # Infer target_encode based on the target type
            is_regression = dataset.target_type == "regression"
            if (target_encode is None and not is_regression) or target_encode:
                dataset.target_encode()
            if cat_feature_encode:
                dataset.cat_feature_encode()
            return dataset

        if dataset_name in preprocessor_dict:
            raise RuntimeError(f"Duplicate dataset names not allowed: {dataset_name}")
        preprocessor_dict[dataset_name] = wrapper_preprocessor
        return wrapper_preprocessor

    return dataset_preprocessor_decorator
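
# A minimal usage sketch (the dataset name, loader, and feature data below are
# hypothetical and not part of this module; TabularDataset may require
# additional keyword arguments beyond those shown). Decorating a loader
# registers it in preprocessor_dict, and calling
# preprocessor_dict["example_dataset"]() returns a split and encoded dataset:
#
#   preprocessor_dict = {}
#
#   @dataset_preprocessor(preprocessor_dict, "example_dataset")
#   def preprocess_example():
#       rng = np.random.default_rng(0)
#       return {
#           "X": rng.normal(size=(100, 5)),
#           "y": rng.integers(0, 2, size=100),
#           "target_type": "binary",
#           # ... any additional kwargs required by TabularDataset
#       }
#
#   dataset = preprocessor_dict["example_dataset"]()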
def split_dataset(dataset_kwargs, num_splits=cv_n_folds, shuffle=True, seed=0):
    """Generate cross-validation splits for the dataset.

    Returns a list of num_splits dicts, each with "train", "val", and "test"
    index arrays. Splits are stratified for classification targets.
    """
    target_type = dataset_kwargs["target_type"]
    if target_type == "regression":
        kf = KFold(n_splits=num_splits, shuffle=shuffle, random_state=seed)
    elif target_type in ("classification", "binary"):
        kf = StratifiedKFold(n_splits=num_splits, shuffle=shuffle, random_state=seed)
    else:
        raise NotImplementedError(f"Objective {target_type} is not yet implemented.")

    splits = kf.split(dataset_kwargs["X"], dataset_kwargs["y"])
    split_indeces = []
    for train_indices, test_indices in splits:
        split_indeces.append({"train": train_indices, "test": test_indices, "val": []})

    # Build each fold's validation set from the (n+1)-th fold's test set, then
    # remove those indices from the fold's training set.
    for split_idx in range(num_splits):
        split_indeces[split_idx]["val"] = split_indeces[(split_idx + 1) % num_splits][
            "test"
        ].copy()
        split_indeces[split_idx]["train"] = np.setdiff1d(
            split_indeces[split_idx]["train"],
            split_indeces[split_idx]["val"],
            assume_unique=True,
        )
    return split_indeces
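
# A minimal sketch of calling split_dataset directly on synthetic data
# (illustrative only; the assertion follows from the val-building loop above):
#
#   kwargs = {
#       "X": np.arange(100).reshape(50, 2),
#       "y": np.array([0, 1] * 25),
#       "target_type": "binary",
#   }
#   folds = split_dataset(kwargs)
#   # Fold k's "val" indices are fold (k + 1) % num_splits's "test" indices,
#   # removed from fold k's "train" indices.
#   assert set(folds[0]["val"]) == set(folds[1]["test"])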