diff --git a/folktexts/__init__.py b/folktexts/__init__.py
index 434a445..681d5a3 100644
--- a/folktexts/__init__.py
+++ b/folktexts/__init__.py
@@ -1,4 +1,5 @@
 from ._version import __version__, __version_info__
-from .acs import ACSDataset, ACSTaskMetadata
+from .task import TaskMetadata
 from .benchmark import BenchmarkConfig, CalibrationBenchmark
 from .classifier import LLMClassifier
+from .acs import ACSDataset, ACSTaskMetadata
diff --git a/folktexts/acs/acs_columns.py b/folktexts/acs/acs_columns.py
index c57411b..c6f83d7 100644
--- a/folktexts/acs/acs_columns.py
+++ b/folktexts/acs/acs_columns.py
@@ -97,7 +97,7 @@
     value_map=partial(
         parse_pums_code,
         file=ACS_OCCP_FILE,
-        postprocess=lambda x: x[4:].lower().strip(),
+        postprocess=lambda x: x[4:].strip().capitalize(),
     ),
 )
 
@@ -108,7 +108,7 @@
     value_map=partial(
         parse_pums_code,
         file=ACS_POBP_FILE,
-        postprocess=lambda x: x[: x.find("/")].strip(),
+        postprocess=lambda x: (x[: x.find("/")] if "/" in x else x).strip(),
     ),
 )
 
@@ -524,7 +524,7 @@
     "PUMA",
     short_description="Public Use Microdata Area (PUMA) code",
     use_value_map_only=True,
-    value_map=lambda x: f"PUMA code: {int(x)}",
+    value_map=lambda x: f"PUMA code: {int(x)}.",
     # missing_value_fill="N/A (less than 16 years old)",
 )
 
@@ -533,7 +533,7 @@
     "POWPUMA",
     short_description="place of work PUMA",
     use_value_map_only=True,
-    value_map=lambda x: f"Place of work PUMA code: {int(x)}",
+    value_map=lambda x: f"Place of work PUMA code: {int(x)}.",
     # missing_value_fill="N/A (not a worker, or worker who worked at home)",
 )
diff --git a/folktexts/acs/acs_dataset.py b/folktexts/acs/acs_dataset.py
index dfd3bc9..0215555 100644
--- a/folktexts/acs/acs_dataset.py
+++ b/folktexts/acs/acs_dataset.py
@@ -5,6 +5,7 @@
 import logging
 from pathlib import Path
 
+import pandas as pd
 from folktables import ACSDataSource
 from folktables.load_acs import state_list
 
@@ -12,6 +13,8 @@
 from .acs_tasks import ACSTaskMetadata
 
 DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
+DEFAULT_TEST_SIZE = 0.1
+DEFAULT_VAL_SIZE = None
 DEFAULT_SEED = 42
 
 DEFAULT_SURVEY_YEAR = "2018"
@@ -24,6 +27,27 @@ class ACSDataset(Dataset):
 
     def __init__(
         self,
+        data: pd.DataFrame,
+        full_acs_data: pd.DataFrame,
+        task: ACSTaskMetadata,
+        test_size: float = DEFAULT_TEST_SIZE,
+        val_size: float = DEFAULT_VAL_SIZE,
+        subsampling: float = None,
+        seed: int = DEFAULT_SEED,
+    ):
+        self._full_acs_data = full_acs_data
+        super().__init__(
+            data=data,
+            task=task,
+            test_size=test_size,
+            val_size=val_size,
+            subsampling=subsampling,
+            seed=seed,
+        )
+
+    @classmethod
+    def make_from_task(
+        cls,
         task: str | ACSTaskMetadata,
         cache_dir: str | Path = None,
         survey_year: str = DEFAULT_SURVEY_YEAR,
@@ -32,7 +56,7 @@ def __init__(
         seed: int = DEFAULT_SEED,
         **kwargs,
     ):
-        """Construct an ACSDataset object.
+        """Construct an ACSDataset object using ACS survey parameters.
 
         Parameters
         ----------
@@ -49,6 +73,8 @@ def __init__(
             The name of the survey unit to load, by default DEFAULT_SURVEY_UNIT.
         seed : int, optional
            The random seed, by default DEFAULT_SEED.
+        **kwargs
+            Extra keyword arguments passed to the Dataset constructor.
         """
         # Create "folktables" sub-folder under the given cache dir
         cache_dir = Path(cache_dir or DEFAULT_DATA_DIR).expanduser().resolve() / "folktables"
@@ -56,6 +82,9 @@ def __init__(
             logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
             cache_dir.mkdir(exist_ok=True, parents=False)
 
+        # Parse task if given a string
+        task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
+
         # Load ACS data source
         print("Loading ACS data...")
         data_source = ACSDataSource(
@@ -63,28 +92,64 @@ def __init__(
             root_dir=cache_dir.as_posix(),
         )
 
-        # Get ACS data in a pandas DF
-        data = data_source.get_data(
-            states=state_list, download=True, random_seed=seed,
+        # Get full ACS dataset
+        full_acs_data = data_source.get_data(
+            states=state_list, download=True, random_seed=seed)
+
+        # Parse data for this task
+        parsed_data = cls._parse_task_data(full_acs_data, task_obj)
+
+        return cls(
+            data=parsed_data,
+            full_acs_data=full_acs_data,
+            task=task_obj,
+            seed=seed,
+            **kwargs,
         )
 
-        # Get information on this ACS/folktables task
-        task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
+    @property
+    def task(self) -> ACSTaskMetadata:
+        return self._task
+
+    @task.setter
+    def task(self, new_task: ACSTaskMetadata):
+        # Re-parse data rows for the new ACS task
+        self._data = self._parse_task_data(self._full_acs_data, new_task)
 
-        # Keep only rows used in this task
-        if isinstance(task_obj, ACSTaskMetadata) and task_obj.folktables_obj is not None:
-            data = task_obj.folktables_obj._preprocess(data)
+        # Check if task columns are in the data
+        if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
+            raise ValueError(
+                f"Task columns not found in dataset: "
+                f"features={new_task.features}, target={new_task.get_target()}")
+
+        self._task = new_task
+
+    @classmethod
+    def _parse_task_data(cls, full_df: pd.DataFrame, task: ACSTaskMetadata) -> pd.DataFrame:
+        """Parse a DataFrame for compatibility with the given task object.
+
+        Parameters
+        ----------
+        full_df : pd.DataFrame
+            Full DataFrame. Some rows and/or columns may be discarded for
+            each task.
+        task : ACSTaskMetadata
+            The task object used to parse the given data.
+
+        Returns
+        -------
+        parsed_df : pd.DataFrame
+            Parsed DataFrame in accordance with the given task.
+        """
+        if not isinstance(task, ACSTaskMetadata):
+            logging.error(f"Expected task of type `ACSTaskMetadata`, got {type(task)}.")
+            return full_df
+
+        # Keep only rows used in this task
+        parsed_df = task.folktables_obj._preprocess(full_df)
 
         # Threshold the target column if necessary
-        # > use standardized ACS naming convention
-        if task_obj.target_threshold is not None:
-            thresholded_target = task_obj.get_target()
-            if thresholded_target not in data.columns:
-                data[thresholded_target] = task_obj.target_threshold.apply_to_column_data(data[task_obj.target])
+        if task.target_threshold is not None and task.get_target() not in parsed_df.columns:
+            parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target])
 
-        super().__init__(
-            data=data,
-            task=task_obj,
-            seed=seed,
-            **kwargs,
-        )
+        return parsed_df
+ """ + if not isinstance(task, ACSTaskMetadata): + logging.error(f"Expected task of type `ACSTaskMetadata` for {type(task)}") + return full_df + + # Parse data + parsed_df = task.folktables_obj._preprocess(full_df) # Threshold the target column if necessary - # > use standardized ACS naming convention - if task_obj.target_threshold is not None: - thresholded_target = task_obj.get_target() - if thresholded_target not in data.columns: - data[thresholded_target] = task_obj.target_threshold.apply_to_column_data(data[task_obj.target]) + if task.target_threshold is not None and task.get_target() not in parsed_df.columns: + parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target]) - super().__init__( - data=data, - task=task_obj, - seed=seed, - **kwargs, - ) + return parsed_df diff --git a/folktexts/benchmark.py b/folktexts/benchmark.py index 4ea1c6d..dcdde0f 100755 --- a/folktexts/benchmark.py +++ b/folktexts/benchmark.py @@ -383,7 +383,7 @@ def make_acs_benchmark( # Fetch ACS task and dataset acs_task = ACSTaskMetadata.get_task(task_name) - acs_dataset = ACSDataset( + acs_dataset = ACSDataset.make_from_task( task_obj=acs_task, cache_dir=data_dir, **acs_dataset_configs) diff --git a/folktexts/dataset.py b/folktexts/dataset.py index 81014ba..a8c4c9d 100755 --- a/folktexts/dataset.py +++ b/folktexts/dataset.py @@ -9,30 +9,29 @@ """ from __future__ import annotations -import copy import logging -import warnings from abc import ABC import numpy as np import pandas as pd -from ._utils import hash_dict, is_valid_number, suppress_logging +from ._utils import hash_dict, is_valid_number from .task import TaskMetadata DEFAULT_TEST_SIZE = 0.1 DEFAULT_VAL_SIZE = None +DEFAULT_SEED = 42 class Dataset(ABC): def __init__( self, data: pd.DataFrame, - task: TaskMetadata, # TODO: remove this from the Dataset + task: TaskMetadata, test_size: float = DEFAULT_TEST_SIZE, val_size: float = DEFAULT_VAL_SIZE, subsampling: float = None, - seed: int = 42, + seed: int = DEFAULT_SEED, ): """Construct a Dataset object. 
@@ -92,15 +91,14 @@ def task(self) -> TaskMetadata:
         return self._task
 
     @task.setter
-    def task(self, task: TaskMetadata):
-        logging.info(f"Updating dataset's task from '{self.task.name}' to '{task.name}'.")
+    def task(self, new_task: TaskMetadata):
         # Check if task columns are in the data
-        if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
+        if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
             raise ValueError(
                 f"Task columns not found in dataset: "
-                f"features={task.features}, target={task.get_target()}")
+                f"features={new_task.features}, target={new_task.get_target()}")
 
-        self._task = task
+        self._task = new_task
 
     @property
     def train_size(self) -> float:
@@ -130,22 +128,6 @@ def name(self) -> str:
         hash_str = f"hash-{hash(self)}"
         return f"{self.task.name}_{subsampling_str}_{seed_str}_{hash_str}"
 
-    def __copy__(self) -> "Dataset":
-        dataset = Dataset(
-            data=self.data,
-            task=self.task,
-            test_size=self.test_size,
-            val_size=self.val_size,
-            subsampling=self.subsampling,
-            seed=self.seed,
-        )
-        dataset._train_indices = self._train_indices.copy()
-        dataset._test_indices = self._test_indices.copy()
-        dataset._val_indices = self._val_indices.copy() if self._val_indices is not None else None
-        dataset._rng = copy.deepcopy(self._rng)
-
-        return dataset
-
     def _subsample_inplace(self, subsampling: float) -> "Dataset":
         """Subsample the dataset in-place."""
@@ -177,11 +159,9 @@
 
         return self
 
     def subsample(self, subsampling: float) -> "Dataset":
-        """Create a new dataset whose samples are a fraction of this dataset."""
-        with suppress_logging(logging.WARNING):
-            self_copy = copy.copy(self)
-            return self_copy._subsample_inplace(subsampling)
+        """Subsample this dataset in-place and return it."""
+        return self._subsample_inplace(subsampling)
 
     def _filter_inplace(
         self,
@@ -216,9 +196,9 @@
 
         return self
 
     def filter(self, population_feature_values: dict) -> "Dataset":
-        """Create a new dataset whose samples are a subset of this dataset."""
-        return copy.copy(self)._filter_inplace(population_feature_values)
+        """Filter dataset rows in-place and return this dataset."""
+        return self._filter_inplace(population_feature_values)
 
     def get_features_data(self) -> pd.DataFrame:
         return self.data[self.task.features]
diff --git a/folktexts/evaluation.py b/folktexts/evaluation.py
index 0d716e2..67bb2dc 100644
--- a/folktexts/evaluation.py
+++ b/folktexts/evaluation.py
@@ -115,9 +115,9 @@ def evaluate_binary_predictions_fairness(
     def group_metric_name(metric_name, group_name):
         return f"{metric_name}_group={group_name}"
 
-    assert (
-        len(unique_groups) > 1
-    ), f"Found a single unique sensitive attribute: {unique_groups}"
+    if len(unique_groups) <= 1:
+        logging.error(f"Found a single unique sensitive attribute value: {unique_groups}")
+        return {}
 
     for s_value in unique_groups:
         # Indices of samples that belong to the current group
diff --git a/folktexts/qa_interface.py b/folktexts/qa_interface.py
index adf47f7..46bf942 100644
--- a/folktexts/qa_interface.py
+++ b/folktexts/qa_interface.py
@@ -243,7 +243,7 @@ def create_answer_keys_permutations(cls, question: "MultipleChoiceQA") -> Iterat
             yield dataclasses.replace(question, choices=perm)
 
     @property
-    def answer_keys(self) -> list[str]:
+    def answer_keys(self) -> tuple[str, ...]:
         return self._answer_keys_source[:len(self.choices)]
 
     @property
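
With __copy__ removed, subsample and filter above now mutate the dataset and return it, rather than returning an independent copy. A sketch of the new call pattern, continuing the dataset built earlier (the "SEX" filter value is an illustrative assumption):

    # Both calls mutate `dataset` in-place; the return value is the same
    # object, which allows chaining.
    dataset.subsample(0.1).filter({"SEX": 1})

    # Callers that previously relied on subsample()/filter() returning a
    # copy must now rebuild the dataset for each variant:
    variants = [
        ACSDataset.make_from_task(task="ACSIncome", cache_dir="~/data").subsample(frac)
        for frac in (0.01, 0.1)
    ]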