Skip to content

Commit

Permalink
minor bug fix to subsampling after task changed
Browse files (browse the repository at this point in the history)
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent 6978b7f commit 30cc3d7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
12 changes: 5 additions & 7 deletions folktexts/acs/acs_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,7 @@

DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
DEFAULT_TEST_SIZE = 0.1
DEFAULT_VAL_SIZE = None
DEFAULT_VAL_SIZE = 0.1
DEFAULT_SEED = 42

DEFAULT_SURVEY_YEAR = "2018"
Expand Down Expand Up @@ -116,17 +116,15 @@ def task(self, new_task: ACSTaskMetadata):
# Parse data rows for new ACS task
self._data = self._parse_task_data(self._full_acs_data, new_task)

# Re-Make train/test/val split
# Re-make train/test/val split
self._train_indices, self._test_indices, self._val_indices = (
self._make_train_test_val_split(
self._data, self.test_size, self.val_size, self._rng)
)

# Check if task columns are in the data
if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={new_task.features}, target={new_task.get_target()}")
# Check if sub-sampling is necessary (it's applied only to train/test/val indices)
if self.subsampling is not None:
self._subsample_train_test_val_indices(self.subsampling)

self._task = new_task

Expand Down
33 changes: 23 additions & 10 deletions folktexts/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,7 @@
from .task import TaskMetadata

DEFAULT_TEST_SIZE = 0.1
DEFAULT_VAL_SIZE = None
DEFAULT_VAL_SIZE = 0.1
DEFAULT_SEED = 42


Expand Down Expand Up @@ -55,11 +55,18 @@ def __init__(
"""
self._data = data
self._task = task

# Validate task
if not isinstance(self._task, TaskMetadata):
raise ValueError(
f"Invalid `task` type: {type(self._task)}. "
f"Expected `TaskMetadata`.")

if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={task.features}, target={task.get_target()}")

self._test_size = test_size
self._val_size = val_size or 0
self._train_size = 1 - self._test_size - self._val_size
Expand All @@ -77,7 +84,7 @@ def __init__(
# Subsample the train/test/val data (if requested)
self._subsampling = None
if subsampling is not None:
self._subsample_inplace(subsampling)
self.subsample(subsampling)

@property
def data(self) -> pd.DataFrame:
Expand Down Expand Up @@ -155,7 +162,7 @@ def _make_train_test_val_split(
val_indices,
)

def _subsample_inplace(self, subsampling: float) -> Dataset:
def _subsample_train_test_val_indices(self, subsampling: float) -> Dataset:
"""Subsample the dataset in-place."""

# Check argument is valid
Expand All @@ -168,27 +175,33 @@ def _subsample_inplace(self, subsampling: float) -> Dataset:

self._train_indices = self._train_indices[: new_train_size]
self._test_indices = self._test_indices[: new_test_size]

if self._val_indices is not None:
new_val_size = int(len(self._val_indices) * subsampling + 0.5)
self._val_indices = self._val_indices[: new_val_size]

# Update subsampling factor
self._subsampling = (getattr(self, "_subsampling", None) or 1) * subsampling

# Log new dataset size
msg = (
f"Subsampled dataset to {self.subsampling:.1%} of the original size. "
f"Train size: {len(self._train_indices)}, "
f"Test size: {len(self._test_indices)}, "
f"Val size: {len(self._val_indices) if self._val_indices is not None else 0};"
f"Test size: {len(self._test_indices)}, "
f"Val size: {len(self._val_indices) if self._val_indices is not None else 0};"
)
logging.info(msg)

return self

def subsample(self, subsampling: float):
"""Subsamples this dataset in-place."""
return self._subsample_inplace(subsampling)
if subsampling is None:
logging.warning(f"Got `subsampling={subsampling}`, skipping...")
return self

# Update train/test/val indices
self._subsample_train_test_val_indices(subsampling)

# Update subsampling factor
self._subsampling = (self._subsampling or 1) * subsampling
return self

def _filter_inplace(
self,
Expand Down

0 comments on commit 30cc3d7

Please sign in to comment.