diff --git a/README.md b/README.md index c226f61..f5e640e 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Package documentation can be found [here](https://socialfoundations.github.io/fo **Table of contents:** - [Installing](#installing) - [Basic setup](#basic-setup) -- [Usage](#usage) +- [Example usage](#example-usage) - [Benchmark options](#benchmark-options) - [License and terms of use](#license-and-terms-of-use) @@ -63,43 +63,43 @@ mkdir data 4. Download transformers model and tokenizer ``` -python -m folktexts.cli.download_models --model "google/gemma-2b" --save-dir models +download_models --model "google/gemma-2b" --save-dir models ``` 5. Run benchmark on a given task ``` -python -m folktexts.cli.run_acs_benchmark --results-dir results --data-dir data --task-name "ACSIncome" --model models/google--gemma-2b +run_acs_benchmark --results-dir results --data-dir data --task-name "ACSIncome" --model models/google--gemma-2b ``` -Run `python -m folktexts.cli.run_acs_benchmark --help` to get a list of all -available benchmark flags. +Run `run_acs_benchmark --help` to get a list of all available benchmark flags. -## Usage - -To use one of the pre-defined survey prediction tasks, simply use the following -code snippet: +## Example usage ```py -from folktexts.acs import ACSDataset, ACSTaskMetadata +from folktexts.llm_utils import load_model_tokenizer +model, tokenizer = load_model_tokenizer("gpt2") # using tiny model as an example + +from folktexts.acs import ACSDataset acs_task_name = "ACSIncome" # Create an object that classifies data using an LLM +from folktexts import LLMClassifier clf = LLMClassifier( model=model, tokenizer=tokenizer, - task=ACSTaskMetadata.get_task(acs_task_name), + task=acs_task_name, ) # Use a dataset or feed in your own data -dataset = ACSDataset(acs_task_name) +dataset = ACSDataset(acs_task_name) # use `.subsample(0.01)` to get faster approximate results # Get risk score predictions out of the model y_scores = clf.predict_proba(dataset) # Optionally, you can fit the threshold based on a small portion of the data -clf.fit(dataset[0:100]) +clf.fit(*dataset[0:100]) # ...in order to get more accurate binary predictions clf.predict(dataset) @@ -109,12 +109,15 @@ from folktexts.benchmark import CalibrationBenchmark benchmark_results = CalibrationBenchmark(clf, dataset).run(results_root_dir=".") ``` + + ## Benchmark options ``` -usage: run_acs_benchmark.py [-h] --model MODEL --task-name TASK_NAME --results-dir RESULTS_DIR --data-dir DATA_DIR [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD] - [--subsampling SUBSAMPLING] [--seed SEED] [--dont-correct-order-bias] [--chat-prompt] [--direct-risk-prompting] [--reuse-few-shot-examples] [--use-feature-subset [USE_FEATURE_SUBSET ...]] - [--use-population-filter [USE_POPULATION_FILTER ...]] [--logger-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] +usage: run_acs_benchmark [-h] --model MODEL --task-name TASK_NAME --results-dir RESULTS_DIR --data-dir DATA_DIR [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD] [--subsampling SUBSAMPLING] [--seed SEED] [--dont-correct-order-bias] [--chat-prompt] [--direct-risk-prompting] [--reuse-few-shot-examples] [--use-feature-subset [USE_FEATURE_SUBSET ...]] + [--use-population-filter [USE_POPULATION_FILTER ...]] [--logger-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] Run an LLM as a classifier experiment. diff --git a/folktexts/_utils.py b/folktexts/_utils.py index 39d324e..50a41bb 100644 --- a/folktexts/_utils.py +++ b/folktexts/_utils.py @@ -89,3 +89,9 @@ def hash_function(func, length: int = 8) -> str: def standardize_path(path: str | Path) -> str: """Represents a posix path as a standardized string.""" return Path(path).expanduser().resolve().as_posix() + + +def get_thresholded_column_name(column_name: str, threshold: float | int) -> str: + """Standardizes naming of thresholded columns.""" + threshold_str = f"{threshold:.2f}".replace(".", "_") if isinstance(threshold, float) else str(threshold) + return f"{column_name}_binary_{threshold_str}" diff --git a/folktexts/acs/_utils.py b/folktexts/acs/_utils.py index de11227..ad0e9d3 100644 --- a/folktexts/acs/_utils.py +++ b/folktexts/acs/_utils.py @@ -7,12 +7,6 @@ from typing import Callable -def get_thresholded_column_name(column_name: str, threshold: float | int) -> str: - """Standardizes naming of thresholded columns.""" - threshold_str = f"{threshold:.2f}".replace(".", "_") if isinstance(threshold, float) else str(threshold) - return f"{column_name}_binary_{threshold_str}" - - def parse_pums_code( value: int, file: str | Path, diff --git a/folktexts/acs/acs_columns.py b/folktexts/acs/acs_columns.py index f534d6e..2290fc7 100644 --- a/folktexts/acs/acs_columns.py +++ b/folktexts/acs/acs_columns.py @@ -5,7 +5,8 @@ from ..col_to_text import ColumnToText from ..qa_interface import Choice, DirectNumericQA, MultipleChoiceQA -from ._utils import get_thresholded_column_name, parse_pums_code +from .._utils import get_thresholded_column_name +from ._utils import parse_pums_code # Path to ACS codebook files ACS_OCCP_FILE = Path(__file__).parent / "data" / "OCCP-codes-acs.txt" diff --git a/folktexts/acs/acs_dataset.py b/folktexts/acs/acs_dataset.py index 69931b6..4015614 100644 --- a/folktexts/acs/acs_dataset.py +++ b/folktexts/acs/acs_dataset.py @@ -9,7 +9,7 @@ from folktables.load_acs import state_list from ..dataset import Dataset -from ._utils import get_thresholded_column_name +from .._utils import get_thresholded_column_name from .acs_tasks import ACSTaskMetadata # noqa # load ACS tasks DEFAULT_ACS_DATA_DIR = Path("~/data").expanduser().resolve() @@ -60,9 +60,10 @@ def __init__( # Threshold the target column if necessary # > use standardized ACS naming convention if task.target_threshold is not None: - thresholded_target = get_thresholded_column_name(task.target, task.target_threshold) - data[thresholded_target] = (data[task.target] >= task.target_threshold).astype(int) - task.target = thresholded_target + # TODO: the target should be thresholded in the task definition not here! + thresholded_target = task.get_target() + if thresholded_target not in data.columns: + data[thresholded_target] = (data[task.target] >= task.target_threshold).astype(int) super().__init__( data=data, diff --git a/folktexts/benchmark.py b/folktexts/benchmark.py index bad06a0..17402a5 100755 --- a/folktexts/benchmark.py +++ b/folktexts/benchmark.py @@ -93,7 +93,7 @@ def __init__( self, llm_clf: LLMClassifier, dataset: Dataset | str, - config: BenchmarkConfig, + config: BenchmarkConfig = BenchmarkConfig.default_config(), ): self.llm_clf = llm_clf self.dataset = dataset @@ -457,12 +457,12 @@ def make_benchmark( # Load the QA interface to be used for risk-score prompting if config.direct_risk_prompting: logging.warning(f"Untested feature: direct_risk_prompting={config.direct_risk_prompting}") # TODO! - question = acs_numeric_qa_map[task.target] + question = acs_numeric_qa_map[task.get_target()] else: - question = acs_multiple_choice_qa_map[task.target] + question = acs_multiple_choice_qa_map[task.get_target()] # Set the task's target question - task.cols_to_text[task.target]._question = question + task.cols_to_text[task.get_target()]._question = question # Construct the LLMClassifier object llm_inference_kwargs = {"correct_order_bias": config.correct_order_bias} diff --git a/folktexts/classifier.py b/folktexts/classifier.py index 61c41b9..a964426 100755 --- a/folktexts/classifier.py +++ b/folktexts/classifier.py @@ -445,7 +445,6 @@ def _compute_risk_estimates_for_dataset( results = { data_type: self._compute_risk_estimates_for_dataframe( df=df, - dataset=dataset, batch_size=batch_size, context_size=context_size, ) diff --git a/folktexts/cli/download_models.py b/folktexts/cli/download_models.py index 95d367b..6f1476e 100755 --- a/folktexts/cli/download_models.py +++ b/folktexts/cli/download_models.py @@ -74,7 +74,7 @@ def is_bf16_compatible() -> bool: return torch.cuda.is_available() and torch.cuda.is_bf16_supported() -if __name__ == "__main__": +def main(): # Parse command-line arguments parser = setup_arg_parser() args = parser.parse_args() @@ -114,3 +114,7 @@ def is_bf16_compatible() -> bool: # Empty VRAM if GPU is available if torch.cuda.is_available(): torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/folktexts/dataset.py b/folktexts/dataset.py index 57228f9..ac4f731 100755 --- a/folktexts/dataset.py +++ b/folktexts/dataset.py @@ -94,10 +94,10 @@ def task(self) -> TaskMetadata: def task(self, task: TaskMetadata): logging.info(f"Updating dataset's task from '{self.task.name}' to '{task.name}'.") # Check if task columns are in the data - if not all(col in self.data.columns for col in (task.features + [task.target])): + if not all(col in self.data.columns for col in (task.features + [task.get_target()])): raise ValueError( f"Task columns not found in dataset: " - f"features={task.features}, target={task.target}") + f"features={task.features}, target={task.get_target()}") self._task = task @@ -221,7 +221,7 @@ def get_features_data(self) -> pd.DataFrame: return self.data[self.task.features] def get_target_data(self) -> pd.Series: - return self.data[self.task.target] + return self.data[self.task.get_target()] def get_sensitive_attribute_data(self) -> pd.Series: if self.task.sensitive_attribute is not None: @@ -240,7 +240,7 @@ def get_data_split(self, split: str) -> tuple[pd.DataFrame, pd.Series]: def get_train(self): train_data = self.data.iloc[self._train_indices] - return train_data[self.task.features], train_data[self.task.target] + return train_data[self.task.features], train_data[self.task.get_target()] def sample_n_train_examples( self, @@ -270,24 +270,24 @@ def sample_n_train_examples( return ( self.data.iloc[example_indices][self.task.features], - self.data.iloc[example_indices][self.task.target], + self.data.iloc[example_indices][self.task.get_target()], ) def get_test(self): test_data = self.data.iloc[self._test_indices] - return test_data[self.task.features], test_data[self.task.target] + return test_data[self.task.features], test_data[self.task.get_target()] def get_val(self): if self._val_indices is None: return None val_data = self.data.iloc[self._val_indices] - return val_data[self.task.features], val_data[self.task.target] + return val_data[self.task.features], val_data[self.task.get_target()] def __getitem__(self, i) -> tuple[pd.DataFrame, pd.Series]: """Returns the i-th training sample.""" curr_indices = self._train_indices[i] curr_data = self.data.iloc[curr_indices] - return curr_data[self.task.features], curr_data[self.task.target] + return curr_data[self.task.features], curr_data[self.task.get_target()] def __iter__(self): """Iterates over the training data.""" diff --git a/folktexts/plotting.py b/folktexts/plotting.py index 20bb5d6..35f85bb 100755 --- a/folktexts/plotting.py +++ b/folktexts/plotting.py @@ -238,7 +238,7 @@ def show_or_save(fig, fig_name: str): # If the group is too small of a fraction, skip (curve will be too erratic) if len(group_indices) / len(sensitive_attribute) < group_size_threshold: - logging.warning(f"Skipping group {s_value} plot as it's too small.") + logging.warning(f"Skipping group {group_value_map(s_value)} plot as it's too small.") continue # Plot global calibration curve diff --git a/folktexts/task.py b/folktexts/task.py index 38bab56..d7fbd73 100755 --- a/folktexts/task.py +++ b/folktexts/task.py @@ -9,7 +9,7 @@ import pandas as pd -from ._utils import hash_dict +from ._utils import hash_dict, get_thresholded_column_name from .col_to_text import ColumnToText @@ -60,6 +60,13 @@ def __hash__(self) -> int: hashable_params["question_hash"] = hash(self.question) return int(hash_dict(hashable_params), 16) + def get_target(self) -> str: + """Resolves the name of the target column depending on `self.target_threshold`.""" + if self.target_threshold is None: + return self.target + else: + return get_thresholded_column_name(self.target, self.target_threshold) + @classmethod def get_task(cls, name: str) -> "TaskMetadata": if name not in cls._tasks: @@ -68,9 +75,9 @@ def get_task(cls, name: str) -> "TaskMetadata": @property def question(self): - if self.cols_to_text[self.target]._question is None: - raise ValueError(f"No question provided for the target column '{self.target}'.") - return self.cols_to_text[self.target].question + if self.cols_to_text[self.get_target()]._question is None: + raise ValueError(f"No question provided for the target column '{self.get_target()}'.") + return self.cols_to_text[self.get_target()].question def get_row_description(self, row: pd.Series) -> str: """Encode a description of a given data row in textual form.""" diff --git a/pyproject.toml b/pyproject.toml index b25da87..a3dc372 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] -version = "0.0.10" +version = "0.0.11" requires-python = ">=3.8" dynamic = [ "readme", @@ -63,6 +63,7 @@ documentation = "https://socialfoundations.github.io/folktexts/" [project.scripts] run_acs_benchmark = "folktexts.cli.run_acs_benchmark:main" +download_models = "folktexts.cli.download_models:main" # flake8 [tool.flake8]