Skip to content

Commit 136953e

Browse files
authored
Merge branch 'main' into categorical_bug_fix
2 parents 0610bd3 + caa3ea1 commit 136953e

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class DataConfig:
9696
handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797
unknown
9898
99+
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100+
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
101+
99102
"""
100103

101104
target: Optional[List[str]] = field(
@@ -176,6 +179,11 @@ class DataConfig:
176179
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
177180
)
178181

182+
dataloader_kwargs: Dict[str, Any] = field(
183+
default_factory=dict,
184+
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
185+
)
186+
179187
def __post_init__(self):
180188
assert (
181189
len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
809809
num_workers=self.config.num_workers,
810810
sampler=self.train_sampler,
811811
pin_memory=self.config.pin_memory,
812+
**self.config.dataloader_kwargs,
812813
)
813814

814815
def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -827,6 +828,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
827828
shuffle=False,
828829
num_workers=self.config.num_workers,
829830
pin_memory=self.config.pin_memory,
831+
**self.config.dataloader_kwargs,
830832
)
831833

832834
def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -869,6 +871,7 @@ def prepare_inference_dataloader(
869871
batch_size or self.batch_size,
870872
shuffle=False,
871873
num_workers=self.config.num_workers,
874+
**self.config.dataloader_kwargs,
872875
)
873876

874877
def save_dataloader(self, path: Union[str, Path]) -> None:

0 commit comments

Comments
 (0)