Skip to content

Commit

Permalink
changing ACSDataset api
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 24, 2024
1 parent 217a045 commit 3967801
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 56 deletions.
4 changes: 2 additions & 2 deletions folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
value_map=partial(
parse_pums_code,
file=ACS_OCCP_FILE,
postprocess=lambda x: x[4:].lower().strip(),
postprocess=lambda x: x[4:].lower().capitalize().strip(),
),
)

Expand All @@ -108,7 +108,7 @@
value_map=partial(
parse_pums_code,
file=ACS_POBP_FILE,
postprocess=lambda x: x[: x.find("/")].strip(),
postprocess=lambda x: (x[: x.find("/")] if "/" in x else x).strip(),
),
)

Expand Down
98 changes: 78 additions & 20 deletions folktexts/acs/acs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import logging
from pathlib import Path

import pandas as pd
from folktables import ACSDataSource
from folktables.load_acs import state_list

from ..dataset import Dataset
from .acs_tasks import ACSTaskMetadata

DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
DEFAULT_TEST_SIZE = 0.1
DEFAULT_VAL_SIZE = None
DEFAULT_SEED = 42

DEFAULT_SURVEY_YEAR = "2018"
Expand All @@ -24,6 +27,27 @@ class ACSDataset(Dataset):

def __init__(
self,
data: pd.DataFrame,
full_acs_data: pd.DataFrame,
task: ACSTaskMetadata,
test_size: float = DEFAULT_TEST_SIZE,
val_size: float = DEFAULT_VAL_SIZE,
subsampling: float = None,
seed: int = 42,
):
self._full_acs_data = full_acs_data
super().__init__(
data=data,
task=task,
test_size=test_size,
val_size=val_size,
subsampling=subsampling,
seed=seed,
)

@classmethod
def make_acs_dataset(
cls,
task: str | ACSTaskMetadata,
cache_dir: str | Path = None,
survey_year: str = DEFAULT_SURVEY_YEAR,
Expand All @@ -32,7 +56,7 @@ def __init__(
seed: int = DEFAULT_SEED,
**kwargs,
):
"""Construct an ACSDataset object.
"""Construct an ACSDataset object using ACS survey parameters.
Parameters
----------
Expand All @@ -49,42 +73,76 @@ def __init__(
The name of the survey unit to load, by default DEFAULT_SURVEY_UNIT.
seed : int, optional
The random seed, by default DEFAULT_SEED.
**kwargs
Extra key-word arguments to be passed to the Dataset constructor.
"""
# Create "folktables" sub-folder under the given cache dir
cache_dir = Path(cache_dir or DEFAULT_DATA_DIR).expanduser().resolve() / "folktables"
if not cache_dir.exists():
logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
cache_dir.mkdir(exist_ok=True, parents=False)

# Parse task if given a string
task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task

# Load ACS data source
print("Loading ACS data...")
data_source = ACSDataSource(
survey_year=survey_year, horizon=horizon, survey=survey,
root_dir=cache_dir.as_posix(),
)

# Get ACS data in a pandas DF
data = data_source.get_data(
states=state_list, download=True, random_seed=seed,
# Get full ACS dataset
full_acs_data = data_source.get_data(
states=state_list, download=True, random_seed=seed)

# Parse data for this task
parsed_data = cls._parse_task_data(full_acs_data, task_obj)

return cls(
data=parsed_data,
full_acs_data=full_acs_data,
task=task,
seed=seed,
**kwargs,
)

# Get information on this ACS/folktables task
task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
@property
def task(self) -> ACSTaskMetadata:
return self._task

# Keep only rows used in this task
if isinstance(task_obj, ACSTaskMetadata) and task_obj.folktables_obj is not None:
data = task_obj.folktables_obj._preprocess(data)
@task.setter
def task(self, new_task: ACSTaskMetadata):
import ipdb; ipdb.set_trace()
self._data = self._parse_task_data(self._full_data, new_task)
super().task = new_task

@classmethod
def _parse_task_data(cls, full_df: pd.DataFrame, task: ACSTaskMetadata) -> pd.DataFrame:
"""Parse a DataFrame for compatibility with the given task object.
Parameters
----------
full_df : pd.DataFrame
Full DataFrame. Some rows and/or columns may be discarded for each
task.
task : ACSTaskMetadata
The task object used to parse the given data.
Returns
-------
parsed_df : pd.DataFrame
Parsed DataFrame in accordance with the given task.
"""
if not isinstance(task, ACSTaskMetadata):
logging.error(f"Expected task of type `ACSTaskMetadata` for {type(task)}")
return full_df

# Parse data
parsed_df = task.folktables_obj._preprocess(full_df)

# Threshold the target column if necessary
# > use standardized ACS naming convention
if task_obj.target_threshold is not None:
thresholded_target = task_obj.get_target()
if thresholded_target not in data.columns:
data[thresholded_target] = task_obj.target_threshold.apply_to_column_data(data[task_obj.target])
if task.target_threshold is not None and task.get_target() not in parsed_df.columns:
parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target])

super().__init__(
data=data,
task=task_obj,
seed=seed,
**kwargs,
)
return parsed_df
2 changes: 1 addition & 1 deletion folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def make_acs_benchmark(

# Fetch ACS task and dataset
acs_task = ACSTaskMetadata.get_task(task_name)
acs_dataset = ACSDataset(
acs_dataset = ACSDataset.make_acs_dataset(
task_obj=acs_task,
cache_dir=data_dir,
**acs_dataset_configs)
Expand Down
47 changes: 14 additions & 33 deletions folktexts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
"""
from __future__ import annotations

import copy
import logging
import warnings
from abc import ABC

import numpy as np
Expand All @@ -22,17 +20,18 @@

DEFAULT_TEST_SIZE = 0.1
DEFAULT_VAL_SIZE = None
DEFAULT_SEED = 42


class Dataset(ABC):
def __init__(
self,
data: pd.DataFrame,
task: TaskMetadata, # TODO: remove this from the Dataset
task: TaskMetadata,
test_size: float = DEFAULT_TEST_SIZE,
val_size: float = DEFAULT_VAL_SIZE,
subsampling: float = None,
seed: int = 42,
seed: int = DEFAULT_SEED,
):
"""Construct a Dataset object.
Expand Down Expand Up @@ -92,15 +91,15 @@ def task(self) -> TaskMetadata:
return self._task

@task.setter
def task(self, task: TaskMetadata):
logging.info(f"Updating dataset's task from '{self.task.name}' to '{task.name}'.")
def task(self, new_task: TaskMetadata):
logging.info(f"Updating dataset's task from '{self.task.name}' to '{new_task.name}'.")
# Check if task columns are in the data
if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={task.features}, target={task.get_target()}")
f"features={new_task.features}, target={new_task.get_target()}")

self._task = task
self._task = new_task

@property
def train_size(self) -> float:
Expand Down Expand Up @@ -130,22 +129,6 @@ def name(self) -> str:
hash_str = f"hash-{hash(self)}"
return f"{self.task.name}_{subsampling_str}_{seed_str}_{hash_str}"

def __copy__(self) -> "Dataset":
dataset = Dataset(
data=self.data,
task=self.task,
test_size=self.test_size,
val_size=self.val_size,
subsampling=self.subsampling,
seed=self.seed,
)
dataset._train_indices = self._train_indices.copy()
dataset._test_indices = self._test_indices.copy()
dataset._val_indices = self._val_indices.copy() if self._val_indices is not None else None
dataset._rng = copy.deepcopy(self._rng)

return dataset

def _subsample_inplace(self, subsampling: float) -> "Dataset":
"""Subsample the dataset in-place."""

Expand Down Expand Up @@ -177,11 +160,9 @@ def _subsample_inplace(self, subsampling: float) -> "Dataset":

return self

def subsample(self, subsampling: float) -> "Dataset":
"""Create a new dataset whose samples are a fraction of this dataset."""
with suppress_logging(logging.WARNING):
self_copy = copy.copy(self)
return self_copy._subsample_inplace(subsampling)
def subsample(self, subsampling: float):
"""Subsamples this dataset in-place."""
return self._subsample_inplace(subsampling)

def _filter_inplace(
self,
Expand Down Expand Up @@ -216,9 +197,9 @@ def _filter_inplace(

return self

def filter(self, population_feature_values: dict) -> "Dataset":
"""Create a new dataset whose samples are a subset of this dataset."""
return copy.copy(self)._filter_inplace(population_feature_values)
def filter(self, population_feature_values: dict):
"""Filter dataset rows in-place."""
self._filter_inplace(population_feature_values)

def get_features_data(self) -> pd.DataFrame:
return self.data[self.task.features]
Expand Down

0 comments on commit 3967801

Please sign in to comment.