Skip to content

Commit

Permalink
updated readme example code
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 12, 2024
1 parent cd4d44a commit 7afbddd
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 47 deletions.
35 changes: 19 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Package documentation can be found [here](https://socialfoundations.github.io/fo
**Table of contents:**
- [Installing](#installing)
- [Basic setup](#basic-setup)
- [Usage](#usage)
- [Example usage](#example-usage)
- [Benchmark options](#benchmark-options)
- [License and terms of use](#license-and-terms-of-use)

Expand Down Expand Up @@ -63,43 +63,43 @@ mkdir data

4. Download transformers model and tokenizer
```
python -m folktexts.cli.download_models --model "google/gemma-2b" --save-dir models
download_models --model "google/gemma-2b" --save-dir models
```

5. Run benchmark on a given task

```
python -m folktexts.cli.run_acs_benchmark --results-dir results --data-dir data --task-name "ACSIncome" --model models/google--gemma-2b
run_acs_benchmark --results-dir results --data-dir data --task-name "ACSIncome" --model models/google--gemma-2b
```

Run `python -m folktexts.cli.run_acs_benchmark --help` to get a list of all
available benchmark flags.
Run `run_acs_benchmark --help` to get a list of all available benchmark flags.


## Usage

To use one of the pre-defined survey prediction tasks, simply use the following
code snippet:
## Example usage

```py
from folktexts.acs import ACSDataset, ACSTaskMetadata
from folktexts.llm_utils import load_model_tokenizer
model, tokenizer = load_model_tokenizer("gpt2") # using tiny model as an example

from folktexts.acs import ACSDataset
acs_task_name = "ACSIncome"

# Create an object that classifies data using an LLM
from folktexts import LLMClassifier
clf = LLMClassifier(
model=model,
tokenizer=tokenizer,
task=ACSTaskMetadata.get_task(acs_task_name),
task=acs_task_name,
)

# Use a dataset or feed in your own data
dataset = ACSDataset(acs_task_name)
dataset = ACSDataset(acs_task_name) # use `.subsample(0.01)` to get faster approximate results

# Get risk score predictions out of the model
y_scores = clf.predict_proba(dataset)

# Optionally, you can fit the threshold based on a small portion of the data
clf.fit(dataset[0:100])
clf.fit(*dataset[0:100])

# ...in order to get more accurate binary predictions
clf.predict(dataset)
Expand All @@ -109,12 +109,15 @@ from folktexts.benchmark import CalibrationBenchmark
benchmark_results = CalibrationBenchmark(clf, dataset).run(results_root_dir=".")
```

<!-- TODO: add code to showcase example functionality, including the
LLMClassifier (maybe the above code is fine for this), the benchmark, and
creating a custom ACS prediction task -->

## Benchmark options

```
usage: run_acs_benchmark.py [-h] --model MODEL --task-name TASK_NAME --results-dir RESULTS_DIR --data-dir DATA_DIR [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD]
[--subsampling SUBSAMPLING] [--seed SEED] [--dont-correct-order-bias] [--chat-prompt] [--direct-risk-prompting] [--reuse-few-shot-examples] [--use-feature-subset [USE_FEATURE_SUBSET ...]]
[--use-population-filter [USE_POPULATION_FILTER ...]] [--logger-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}]
usage: run_acs_benchmark [-h] --model MODEL --task-name TASK_NAME --results-dir RESULTS_DIR --data-dir DATA_DIR [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD] [--subsampling SUBSAMPLING] [--seed SEED] [--dont-correct-order-bias] [--chat-prompt] [--direct-risk-prompting] [--reuse-few-shot-examples] [--use-feature-subset [USE_FEATURE_SUBSET ...]]
[--use-population-filter [USE_POPULATION_FILTER ...]] [--logger-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}]
Run an LLM as a classifier experiment.
Expand Down
6 changes: 6 additions & 0 deletions folktexts/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,9 @@ def hash_function(func, length: int = 8) -> str:
def standardize_path(path: str | Path) -> str:
    """Return `path` as a standardized, POSIX-style string.

    The path is passed through `expanduser()` and `resolve()` before being
    rendered with forward slashes via `as_posix()`.
    """
    resolved = Path(path).expanduser().resolve()
    return resolved.as_posix()


def get_thresholded_column_name(column_name: str, threshold: float | int) -> str:
    """Build the standardized name of a thresholded column.

    Float thresholds are rendered with two decimal places and the decimal
    point replaced by an underscore (e.g. `0.5` -> `"0_50"`); any other
    threshold is rendered with `str()`.
    """
    if isinstance(threshold, float):
        suffix = format(threshold, ".2f").replace(".", "_")
    else:
        suffix = str(threshold)

    return f"{column_name}_binary_{suffix}"
6 changes: 0 additions & 6 deletions folktexts/acs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,6 @@
from typing import Callable


def get_thresholded_column_name(column_name: str, threshold: float | int) -> str:
"""Standardizes naming of thresholded columns."""
threshold_str = f"{threshold:.2f}".replace(".", "_") if isinstance(threshold, float) else str(threshold)
return f"{column_name}_binary_{threshold_str}"


def parse_pums_code(
value: int,
file: str | Path,
Expand Down
3 changes: 2 additions & 1 deletion folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from ..col_to_text import ColumnToText
from ..qa_interface import Choice, DirectNumericQA, MultipleChoiceQA
from ._utils import get_thresholded_column_name, parse_pums_code
from .._utils import get_thresholded_column_name
from ._utils import parse_pums_code

# Path to ACS codebook files
ACS_OCCP_FILE = Path(__file__).parent / "data" / "OCCP-codes-acs.txt"
Expand Down
9 changes: 5 additions & 4 deletions folktexts/acs/acs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from folktables.load_acs import state_list

from ..dataset import Dataset
from ._utils import get_thresholded_column_name
from .._utils import get_thresholded_column_name
from .acs_tasks import ACSTaskMetadata # noqa # load ACS tasks

DEFAULT_ACS_DATA_DIR = Path("~/data").expanduser().resolve()
Expand Down Expand Up @@ -60,9 +60,10 @@ def __init__(
# Threshold the target column if necessary
# > use standardized ACS naming convention
if task.target_threshold is not None:
thresholded_target = get_thresholded_column_name(task.target, task.target_threshold)
data[thresholded_target] = (data[task.target] >= task.target_threshold).astype(int)
task.target = thresholded_target
# TODO: the target should be thresholded in the task definition, not here!
thresholded_target = task.get_target()
if thresholded_target not in data.columns:
data[thresholded_target] = (data[task.target] >= task.target_threshold).astype(int)

super().__init__(
data=data,
Expand Down
8 changes: 4 additions & 4 deletions folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
self,
llm_clf: LLMClassifier,
dataset: Dataset | str,
config: BenchmarkConfig,
config: BenchmarkConfig = BenchmarkConfig.default_config(),
):
self.llm_clf = llm_clf
self.dataset = dataset
Expand Down Expand Up @@ -457,12 +457,12 @@ def make_benchmark(
# Load the QA interface to be used for risk-score prompting
if config.direct_risk_prompting:
logging.warning(f"Untested feature: direct_risk_prompting={config.direct_risk_prompting}") # TODO!
question = acs_numeric_qa_map[task.target]
question = acs_numeric_qa_map[task.get_target()]
else:
question = acs_multiple_choice_qa_map[task.target]
question = acs_multiple_choice_qa_map[task.get_target()]

# Set the task's target question
task.cols_to_text[task.target]._question = question
task.cols_to_text[task.get_target()]._question = question

# Construct the LLMClassifier object
llm_inference_kwargs = {"correct_order_bias": config.correct_order_bias}
Expand Down
1 change: 0 additions & 1 deletion folktexts/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,6 @@ def _compute_risk_estimates_for_dataset(
results = {
data_type: self._compute_risk_estimates_for_dataframe(
df=df,
dataset=dataset,
batch_size=batch_size,
context_size=context_size,
)
Expand Down
6 changes: 5 additions & 1 deletion folktexts/cli/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def is_bf16_compatible() -> bool:
return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


if __name__ == "__main__":
def main():
# Parse command-line arguments
parser = setup_arg_parser()
args = parser.parse_args()
Expand Down Expand Up @@ -114,3 +114,7 @@ def is_bf16_compatible() -> bool:
# Empty VRAM if GPU is available
if torch.cuda.is_available():
torch.cuda.empty_cache()


if __name__ == "__main__":
main()
16 changes: 8 additions & 8 deletions folktexts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def task(self) -> TaskMetadata:
def task(self, task: TaskMetadata):
logging.info(f"Updating dataset's task from '{self.task.name}' to '{task.name}'.")
# Check if task columns are in the data
if not all(col in self.data.columns for col in (task.features + [task.target])):
if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={task.features}, target={task.target}")
f"features={task.features}, target={task.get_target()}")

self._task = task

Expand Down Expand Up @@ -221,7 +221,7 @@ def get_features_data(self) -> pd.DataFrame:
return self.data[self.task.features]

def get_target_data(self) -> pd.Series:
return self.data[self.task.target]
return self.data[self.task.get_target()]

def get_sensitive_attribute_data(self) -> pd.Series:
if self.task.sensitive_attribute is not None:
Expand All @@ -240,7 +240,7 @@ def get_data_split(self, split: str) -> tuple[pd.DataFrame, pd.Series]:

def get_train(self):
train_data = self.data.iloc[self._train_indices]
return train_data[self.task.features], train_data[self.task.target]
return train_data[self.task.features], train_data[self.task.get_target()]

def sample_n_train_examples(
self,
Expand Down Expand Up @@ -270,24 +270,24 @@ def sample_n_train_examples(

return (
self.data.iloc[example_indices][self.task.features],
self.data.iloc[example_indices][self.task.target],
self.data.iloc[example_indices][self.task.get_target()],
)

def get_test(self):
test_data = self.data.iloc[self._test_indices]
return test_data[self.task.features], test_data[self.task.target]
return test_data[self.task.features], test_data[self.task.get_target()]

def get_val(self):
if self._val_indices is None:
return None
val_data = self.data.iloc[self._val_indices]
return val_data[self.task.features], val_data[self.task.target]
return val_data[self.task.features], val_data[self.task.get_target()]

def __getitem__(self, i) -> tuple[pd.DataFrame, pd.Series]:
"""Returns the i-th training sample."""
curr_indices = self._train_indices[i]
curr_data = self.data.iloc[curr_indices]
return curr_data[self.task.features], curr_data[self.task.target]
return curr_data[self.task.features], curr_data[self.task.get_target()]

def __iter__(self):
"""Iterates over the training data."""
Expand Down
2 changes: 1 addition & 1 deletion folktexts/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def show_or_save(fig, fig_name: str):

# If the group is too small of a fraction, skip (curve will be too erratic)
if len(group_indices) / len(sensitive_attribute) < group_size_threshold:
logging.warning(f"Skipping group {s_value} plot as it's too small.")
logging.warning(f"Skipping group {group_value_map(s_value)} plot as it's too small.")
continue

# Plot global calibration curve
Expand Down
15 changes: 11 additions & 4 deletions folktexts/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pandas as pd

from ._utils import hash_dict
from ._utils import hash_dict, get_thresholded_column_name
from .col_to_text import ColumnToText


Expand Down Expand Up @@ -60,6 +60,13 @@ def __hash__(self) -> int:
hashable_params["question_hash"] = hash(self.question)
return int(hash_dict(hashable_params), 16)

def get_target(self) -> str:
    """Return the name of the target column for this task.

    When `self.target_threshold` is set, the name is derived from
    `self.target` and the threshold via `get_thresholded_column_name`;
    otherwise `self.target` is returned unchanged.
    """
    threshold = self.target_threshold
    if threshold is not None:
        return get_thresholded_column_name(self.target, threshold)
    return self.target

@classmethod
def get_task(cls, name: str) -> "TaskMetadata":
if name not in cls._tasks:
Expand All @@ -68,9 +75,9 @@ def get_task(cls, name: str) -> "TaskMetadata":

@property
def question(self):
if self.cols_to_text[self.target]._question is None:
raise ValueError(f"No question provided for the target column '{self.target}'.")
return self.cols_to_text[self.target].question
if self.cols_to_text[self.get_target()]._question is None:
raise ValueError(f"No question provided for the target column '{self.get_target()}'.")
return self.cols_to_text[self.get_target()].question

def get_row_description(self, row: pd.Series) -> str:
"""Encode a description of a given data row in textual form."""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]

version = "0.0.10"
version = "0.0.11"
requires-python = ">=3.8"
dynamic = [
"readme",
Expand Down Expand Up @@ -63,6 +63,7 @@ documentation = "https://socialfoundations.github.io/folktexts/"

[project.scripts]
run_acs_benchmark = "folktexts.cli.run_acs_benchmark:main"
download_models = "folktexts.cli.download_models:main"

# flake8
[tool.flake8]
Expand Down

0 comments on commit 7afbddd

Please sign in to comment.