fixed ACSDataset assignment of new task
AndreFCruz committed Jun 24, 2024
1 parent ba22888 commit 9ffbc7b
Showing 7 changed files with 110 additions and 64 deletions.
folktexts/__init__.py: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 from ._version import __version__, __version_info__
+from .acs import ACSDataset, ACSTaskMetadata
+from .task import TaskMetadata
 from .benchmark import BenchmarkConfig, CalibrationBenchmark
 from .classifier import LLMClassifier
-from .acs import ACSDataset, ACSTaskMetadata
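After this change, all of the package's public names resolve from the package root; a minimal import sketch (assuming the package is installed at this commit):

    from folktexts import ACSDataset, ACSTaskMetadata, TaskMetadata
    from folktexts import BenchmarkConfig, CalibrationBenchmark, LLMClassifier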
folktexts/acs/acs_columns.py: 4 additions & 4 deletions
@@ -97,7 +97,7 @@
     value_map=partial(
         parse_pums_code,
         file=ACS_OCCP_FILE,
-        postprocess=lambda x: x[4:].lower().strip(),
+        postprocess=lambda x: x[4:].lower().capitalize().strip(),
     ),
 )

@@ -108,7 +108,7 @@
     value_map=partial(
         parse_pums_code,
         file=ACS_POBP_FILE,
-        postprocess=lambda x: x[: x.find("/")].strip(),
+        postprocess=lambda x: (x[: x.find("/")] if "/" in x else x).strip(),
     ),
 )
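The POBP change above fixes a subtle truncation bug: str.find returns -1 when "/" is absent, so the old slice silently dropped the last character of codes without a separator. A small illustration with hypothetical inputs (not taken from the ACS code files):

    old_postprocess = lambda x: x[: x.find("/")].strip()
    new_postprocess = lambda x: (x[: x.find("/")] if "/" in x else x).strip()

    old_postprocess("Germany/West Germany")  # "Germany" -- both versions agree
    old_postprocess("Portugal")              # "Portuga" -- find() returns -1, slicing off the last char
    new_postprocess("Portugal")              # "Portugal" -- the guard keeps the full string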

@@ -524,7 +524,7 @@
     "PUMA",
     short_description="Public Use Microdata Area (PUMA) code",
     use_value_map_only=True,
-    value_map=lambda x: f"PUMA code: {int(x)}",
+    value_map=lambda x: f"PUMA code: {int(x)}.",
     # missing_value_fill="N/A (less than 16 years old)",
 )

@@ -533,7 +533,7 @@
     "POWPUMA",
     short_description="place of work PUMA",
     use_value_map_only=True,
-    value_map=lambda x: f"Place of work PUMA code: {int(x)}",
+    value_map=lambda x: f"Place of work PUMA code: {int(x)}.",
     # missing_value_fill="N/A (not a worker, or worker who worked at home)",
 )

folktexts/acs/acs_dataset.py: 85 additions & 20 deletions
@@ -5,13 +5,16 @@
 import logging
 from pathlib import Path
 
+import pandas as pd
 from folktables import ACSDataSource
 from folktables.load_acs import state_list
 
 from ..dataset import Dataset
 from .acs_tasks import ACSTaskMetadata
 
 DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
+DEFAULT_TEST_SIZE = 0.1
+DEFAULT_VAL_SIZE = None
 DEFAULT_SEED = 42
 
 DEFAULT_SURVEY_YEAR = "2018"
@@ -24,6 +27,27 @@ class ACSDataset(Dataset):
 
     def __init__(
         self,
+        data: pd.DataFrame,
+        full_acs_data: pd.DataFrame,
+        task: ACSTaskMetadata,
+        test_size: float = DEFAULT_TEST_SIZE,
+        val_size: float = DEFAULT_VAL_SIZE,
+        subsampling: float = None,
+        seed: int = 42,
+    ):
+        self._full_acs_data = full_acs_data
+        super().__init__(
+            data=data,
+            task=task,
+            test_size=test_size,
+            val_size=val_size,
+            subsampling=subsampling,
+            seed=seed,
+        )
+
+    @classmethod
+    def make_from_task(
+        cls,
         task: str | ACSTaskMetadata,
         cache_dir: str | Path = None,
         survey_year: str = DEFAULT_SURVEY_YEAR,
@@ -32,7 +56,7 @@ def __init__(
         seed: int = DEFAULT_SEED,
         **kwargs,
     ):
-        """Construct an ACSDataset object.
+        """Construct an ACSDataset object using ACS survey parameters.
 
         Parameters
         ----------
@@ -49,42 +73,83 @@
             The name of the survey unit to load, by default DEFAULT_SURVEY_UNIT.
         seed : int, optional
             The random seed, by default DEFAULT_SEED.
+        **kwargs
+            Extra key-word arguments to be passed to the Dataset constructor.
         """
         # Create "folktables" sub-folder under the given cache dir
         cache_dir = Path(cache_dir or DEFAULT_DATA_DIR).expanduser().resolve() / "folktables"
         if not cache_dir.exists():
             logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
             cache_dir.mkdir(exist_ok=True, parents=False)
 
+        # Parse task if given a string
+        task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
+
         # Load ACS data source
         print("Loading ACS data...")
         data_source = ACSDataSource(
             survey_year=survey_year, horizon=horizon, survey=survey,
             root_dir=cache_dir.as_posix(),
         )
 
-        # Get ACS data in a pandas DF
-        data = data_source.get_data(
-            states=state_list, download=True, random_seed=seed,
+        # Get full ACS dataset
+        full_acs_data = data_source.get_data(
+            states=state_list, download=True, random_seed=seed)
+
+        # Parse data for this task
+        parsed_data = cls._parse_task_data(full_acs_data, task_obj)
+
+        return cls(
+            data=parsed_data,
+            full_acs_data=full_acs_data,
+            task=task,
+            seed=seed,
+            **kwargs,
         )
 
-        # Get information on this ACS/folktables task
-        task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
+    @property
+    def task(self) -> ACSTaskMetadata:
+        return self._task
 
+    @task.setter
+    def task(self, new_task: ACSTaskMetadata):
+        # Parse data rows for new ACS task
+        self._data = self._parse_task_data(self._full_acs_data, new_task)
+
-        # Keep only rows used in this task
-        if isinstance(task_obj, ACSTaskMetadata) and task_obj.folktables_obj is not None:
-            data = task_obj.folktables_obj._preprocess(data)
+        # Check if task columns are in the data
+        if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
+            raise ValueError(
+                f"Task columns not found in dataset: "
+                f"features={new_task.features}, target={new_task.get_target()}")
+
+        self._task = new_task
+
+    @classmethod
+    def _parse_task_data(cls, full_df: pd.DataFrame, task: ACSTaskMetadata) -> pd.DataFrame:
+        """Parse a DataFrame for compatibility with the given task object.
+
+        Parameters
+        ----------
+        full_df : pd.DataFrame
+            Full DataFrame. Some rows and/or columns may be discarded for each
+            task.
+        task : ACSTaskMetadata
+            The task object used to parse the given data.
+
+        Returns
+        -------
+        parsed_df : pd.DataFrame
+            Parsed DataFrame in accordance with the given task.
+        """
+        if not isinstance(task, ACSTaskMetadata):
+            logging.error(f"Expected task of type `ACSTaskMetadata` for {type(task)}")
+            return full_df
+
+        # Parse data
+        parsed_df = task.folktables_obj._preprocess(full_df)
 
         # Threshold the target column if necessary
         # > use standardized ACS naming convention
-        if task_obj.target_threshold is not None:
-            thresholded_target = task_obj.get_target()
-            if thresholded_target not in data.columns:
-                data[thresholded_target] = task_obj.target_threshold.apply_to_column_data(data[task_obj.target])
+        if task.target_threshold is not None and task.get_target() not in parsed_df.columns:
+            parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target])
 
-        super().__init__(
-            data=data,
-            task=task_obj,
-            seed=seed,
-            **kwargs,
-        )
+        return parsed_df
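Taken together, construction now goes through make_from_task, which keeps the full ACS DataFrame on the instance so that assigning a new task re-parses the stored data before the column check, rather than failing against the previous task's columns. A usage sketch; the task names "ACSIncome" and "ACSEmployment" are assumed examples, not confirmed by this diff:

    from folktexts import ACSDataset, ACSTaskMetadata

    # Build a dataset for one task (downloads and caches ACS data as needed).
    income_task = ACSTaskMetadata.get_task("ACSIncome")
    dataset = ACSDataset.make_from_task(task=income_task, cache_dir="~/data")

    # Reassigning the task re-parses the cached full ACS data for the new task,
    # then validates that the new task's feature and target columns are present.
    dataset.task = ACSTaskMetadata.get_task("ACSEmployment")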
folktexts/benchmark.py: 1 addition & 1 deletion
@@ -383,7 +383,7 @@ def make_acs_benchmark(
 
     # Fetch ACS task and dataset
     acs_task = ACSTaskMetadata.get_task(task_name)
-    acs_dataset = ACSDataset(
+    acs_dataset = ACSDataset.make_from_task(
         task_obj=acs_task,
         cache_dir=data_dir,
         **acs_dataset_configs)
folktexts/dataset.py: 14 additions & 34 deletions
@@ -9,30 +9,29 @@
 """
 from __future__ import annotations
 
-import copy
 import logging
-import warnings
 from abc import ABC
 
 import numpy as np
 import pandas as pd
 
-from ._utils import hash_dict, is_valid_number, suppress_logging
+from ._utils import hash_dict, is_valid_number
 from .task import TaskMetadata
 
 DEFAULT_TEST_SIZE = 0.1
 DEFAULT_VAL_SIZE = None
+DEFAULT_SEED = 42
 
 
 class Dataset(ABC):
     def __init__(
         self,
         data: pd.DataFrame,
-        task: TaskMetadata,  # TODO: remove this from the Dataset
+        task: TaskMetadata,
         test_size: float = DEFAULT_TEST_SIZE,
         val_size: float = DEFAULT_VAL_SIZE,
         subsampling: float = None,
-        seed: int = 42,
+        seed: int = DEFAULT_SEED,
     ):
         """Construct a Dataset object.
 
@@ -92,15 +91,14 @@ def task(self) -> TaskMetadata:
         return self._task
 
     @task.setter
-    def task(self, task: TaskMetadata):
-        logging.info(f"Updating dataset's task from '{self.task.name}' to '{task.name}'.")
+    def task(self, new_task: TaskMetadata):
         # Check if task columns are in the data
-        if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
+        if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
             raise ValueError(
                 f"Task columns not found in dataset: "
-                f"features={task.features}, target={task.get_target()}")
+                f"features={new_task.features}, target={new_task.get_target()}")
 
-        self._task = task
+        self._task = new_task
 
     @property
     def train_size(self) -> float:
@@ -130,22 +128,6 @@ def name(self) -> str:
         hash_str = f"hash-{hash(self)}"
         return f"{self.task.name}_{subsampling_str}_{seed_str}_{hash_str}"
 
-    def __copy__(self) -> "Dataset":
-        dataset = Dataset(
-            data=self.data,
-            task=self.task,
-            test_size=self.test_size,
-            val_size=self.val_size,
-            subsampling=self.subsampling,
-            seed=self.seed,
-        )
-        dataset._train_indices = self._train_indices.copy()
-        dataset._test_indices = self._test_indices.copy()
-        dataset._val_indices = self._val_indices.copy() if self._val_indices is not None else None
-        dataset._rng = copy.deepcopy(self._rng)
-
-        return dataset
-
     def _subsample_inplace(self, subsampling: float) -> "Dataset":
         """Subsample the dataset in-place."""

@@ -177,11 +159,9 @@ def _subsample_inplace(self, subsampling: float) -> "Dataset":
 
         return self
 
-    def subsample(self, subsampling: float) -> "Dataset":
-        """Create a new dataset whose samples are a fraction of this dataset."""
-        with suppress_logging(logging.WARNING):
-            self_copy = copy.copy(self)
-        return self_copy._subsample_inplace(subsampling)
+    def subsample(self, subsampling: float):
+        """Subsamples this dataset in-place."""
+        return self._subsample_inplace(subsampling)
 
     def _filter_inplace(
         self,
@@ -216,9 +196,9 @@ def _filter_inplace(
 
         return self
 
-    def filter(self, population_feature_values: dict) -> "Dataset":
-        """Create a new dataset whose samples are a subset of this dataset."""
-        return copy.copy(self)._filter_inplace(population_feature_values)
+    def filter(self, population_feature_values: dict):
+        """Filter dataset rows in-place."""
+        self._filter_inplace(population_feature_values)
 
     def get_features_data(self) -> pd.DataFrame:
         return self.data[self.task.features]
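Note the semantic change in dataset.py: subsample and filter previously returned modified copies (built via the now-removed __copy__), but both now mutate the dataset in place. A sketch of the new calling pattern; the column name and value are assumed for illustration:

    dataset.subsample(0.1)      # keeps roughly 10% of the rows, in-place
    dataset.filter({"SEX": 1})  # keeps only rows with the given feature value, in-place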
folktexts/evaluation.py: 3 additions & 3 deletions
@@ -115,9 +115,9 @@ def evaluate_binary_predictions_fairness(
     def group_metric_name(metric_name, group_name):
         return f"{metric_name}_group={group_name}"
 
-    assert (
-        len(unique_groups) > 1
-    ), f"Found a single unique sensitive attribute: {unique_groups}"
+    if len(unique_groups) <= 1:
+        logging.error(f"Found a single unique sensitive attribute: {unique_groups}")
+        return {}
 
     for s_value in unique_groups:
         # Indices of samples that belong to the current group
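With the assert gone, callers of evaluate_binary_predictions_fairness get an empty dict instead of an AssertionError when only one sensitive group is present, so results should be checked before use. A hedged sketch of the calling pattern (argument names are assumed, not taken from this diff):

    fairness_metrics = evaluate_binary_predictions_fairness(y_true, y_pred, sensitive_attribute)
    if not fairness_metrics:
        print("Skipping fairness metrics: only one sensitive group present.")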
folktexts/qa_interface.py: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def create_answer_keys_permutations(cls, question: "MultipleChoiceQA") -> Iterat
         yield dataclasses.replace(question, choices=perm)
 
     @property
-    def answer_keys(self) -> list[str]:
+    def answer_keys(self) -> tuple[str]:
         return self._answer_keys_source[:len(self.choices)]
 
     @property
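A typing note on the new annotation: tuple[str] denotes a 1-tuple holding exactly one str, while a variable-length tuple of strings, which a slice of answer keys presumably yields, is written tuple[str, ...]:

    one_key: tuple[str] = ("A",)                  # exactly one element
    many_keys: tuple[str, ...] = ("A", "B", "C")  # any number of str elements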
