Skip to content

Commit c2b1ae4

Browse files
authored
Merge branch 'main' into contributors-readme-action-Qcb7xawFeP
2 parents b96d293 + cf1454a commit c2b1ae4

File tree

2 files changed

+6 -0 lines changed

2 files changed

+6 -0 lines changed

src/pytorch_tabular/categorical_encoders.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@ def transform(self, X):
62 62
not X[self.cols].isnull().any().any()
63 63
), "`handle_missing` = `error` and missing values found in columns to encode."
64 64
X_encoded = X.copy(deep=True)
65 +
category_cols = X_encoded.select_dtypes(include="category").columns
66 +
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
65 67
for col, mapping in self._mapping.items():
66 68
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
67 69

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
301 301
else:
302 302
raise ValueError(f"{config.task} is an unsupported task.")
303 303
if self.train is not None:
304 +
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
305 +
self.train[category_cols] = self.train[category_cols].astype("object")
304 306
categorical_cardinality = [
305 307
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
306 308
]
307 309
else:
310 +
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
311 +
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
308 312
categorical_cardinality = [
309 313
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
310 314
]

0 commit comments

Comments (0)