minor bug fixes
AndreFCruz committed Jun 12, 2024
1 parent 2d108af commit d550aae
Showing 14 changed files with 550 additions and 8,459 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
@@ -1,4 +1,4 @@
-name: Tests with tox
+name: Tests

 on:
   push:
60 changes: 53 additions & 7 deletions README.md
@@ -25,6 +25,7 @@ Package documentation can be found [here](https://socialfoundations.github.io/fo
 - [Installing](#installing)
 - [Basic setup](#basic-setup)
 - [Usage](#usage)
+- [Benchmark options](#benchmark-options)
 - [License and terms of use](#license-and-terms-of-use)


@@ -37,8 +38,9 @@ pip install folktexts
```

## Basic setup
> You'll need to go through these steps to run the benchmark tasks.
-1. Create condo environment
+1. Create conda environment

```
conda create -n folktexts python=3.11
@@ -56,19 +58,18 @@ pip install folktexts
 ```
 mkdir results
 mkdir models
-mkdir datasets
+mkdir data
 ```

-3. Download transformers model and tokenizer into models folder
-
+4. Download transformers model and tokenizer
 ```
 python -m folktexts.cli.download_models --model "google/gemma-2b" --save-dir models
 ```

-4. Run benchmark
+5. Run benchmark on a given task

 ```
-python -m folktexts.cli.run_acs_benchmark --results-dir results --data-dir datasets --task-name "ACSIncome" --model models/google--gemma-2b
+python -m folktexts.cli.run_acs_benchmark --results-dir results --data-dir data --task-name "ACSIncome" --model models/google--gemma-2b
 ```

 Run `python -m folktexts.cli.run_acs_benchmark --help` to get a list of all
@@ -77,6 +78,9 @@ available benchmark flags.

 ## Usage

+To use one of the pre-defined survey prediction tasks, simply use the following
+code snippet:
+
 ```py
 from folktexts.acs import ACSDataset, ACSTaskMetadata
 acs_task_name = "ACSIncome"
@@ -94,7 +98,7 @@ dataset = ACSDataset(acs_task_name)
 # Get risk score predictions out of the model
 y_scores = clf.predict_proba(dataset)

-# Optionally, can fit the threshold based on a small portion of the data
+# Optionally, you can fit the threshold based on a small portion of the data
 clf.fit(dataset[0:100])

 # ...in order to get more accurate binary predictions
@@ -105,6 +109,48 @@ from folktexts.benchmark import CalibrationBenchmark
 benchmark_results = CalibrationBenchmark(clf, dataset, results_dir="results").run()
 ```

+## Benchmark options
+
+```
+usage: run_acs_benchmark.py [-h] --model MODEL --task-name TASK_NAME --results-dir RESULTS_DIR --data-dir DATA_DIR [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD]
+                            [--subsampling SUBSAMPLING] [--seed SEED] [--dont-correct-order-bias] [--chat-prompt] [--direct-risk-prompting] [--reuse-few-shot-examples] [--logger-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}]
+                            [--use-feature-subset [USE_FEATURE_SUBSET ...]] [--use-population-filter [USE_POPULATION_FILTER ...]]
+
+Run an LLM as a classifier experiment.
+
+options:
+  -h, --help            show this help message and exit
+  --model MODEL         [str] Model name or path to model saved on disk
+  --task-name TASK_NAME
+                        [str] Name of the ACS task to run the experiment on
+  --results-dir RESULTS_DIR
+                        [str] Directory under which this experiment's results will be saved
+  --data-dir DATA_DIR   [str] Root folder to find datasets on
+  --few-shot FEW_SHOT   [int] Use few-shot prompting with the given number of shots
+  --batch-size BATCH_SIZE
+                        [int] The batch size to use for inference
+  --context-size CONTEXT_SIZE
+                        [int] The maximum context size when prompting the LLM
+  --fit-threshold FIT_THRESHOLD
+                        [int] Whether to fit the prediction threshold, and on how many samples
+  --subsampling SUBSAMPLING
+                        [float] Which fraction of the dataset to use (if omitted will use all data)
+  --seed SEED           [int] Random seed -- to set for reproducibility
+  --dont-correct-order-bias
+                        [bool] Whether to avoid correcting ordering bias, by default will correct it
+  --chat-prompt         [bool] Whether to use chat-based prompting (for instruct models)
+  --direct-risk-prompting
+                        [bool] Whether to directly prompt for risk-estimates instead of multiple-choice Q&A
+  --reuse-few-shot-examples
+                        [bool] Whether to reuse the same samples for few-shot prompting (or sample new ones every time)
+  --logger-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
+                        [str] The logging level to use for the experiment
+  --use-feature-subset [USE_FEATURE_SUBSET ...]
+                        [str] Optional subset of features to use for prediction
+  --use-population-filter [USE_POPULATION_FILTER ...]
+                        [str] Optional population filter for this benchmark; must follow the format 'column_name=value' to filter the dataset by a specific value.
+```
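
The help text for `--use-population-filter` describes a `column_name=value` format but the diff does not show how those strings are parsed. A minimal sketch of one way to do it, assuming a hypothetical helper `parse_population_filter` (not taken from folktexts) and purely illustrative column names:

```python
def parse_population_filter(items: list[str]) -> dict[str, str]:
    """Hypothetical helper: turn ['COL=value', ...] into {'COL': 'value', ...}."""
    filters: dict[str, str] = {}
    for item in items:
        # Split on the first "=" only, so values may themselves contain "="
        column, sep, value = item.partition("=")
        column, value = column.strip(), value.strip()
        if not sep or not column or not value:
            raise ValueError(f"Expected 'column_name=value', got {item!r}")
        filters[column] = value
    return filters


# Illustrative column names and values only
print(parse_population_filter(["SEX=2", "AGEP=30"]))
```

Values are kept as strings here; whether the real CLI casts them to numbers is not visible in this commit.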


## License and terms of use

2 changes: 2 additions & 0 deletions folktexts/__init__.py
@@ -1,2 +1,4 @@
 from ._version import __version__, __version_info__
 from .acs import ACSDataset, ACSTaskMetadata
+from .benchmark import BenchmarkConfig, CalibrationBenchmark
+from .classifier import LLMClassifier
2 changes: 2 additions & 0 deletions folktexts/_io.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import logging
 import pickle
4 changes: 3 additions & 1 deletion folktexts/_utils.py
@@ -20,7 +20,9 @@ def is_valid_number(num) -> bool:
 def safe_division(a: float, b: float, *, worst_result: float):
     """Try to divide the given arguments and return `worst_result` if unsuccessful."""
     if b == 0 or not is_valid_number(a) or not is_valid_number(b):
-        logging.warning(f"Error in the following division: {a} / {b}")
+        logging.info(
+            f"Using `worst_result={worst_result}` in place of the following "
+            f"division: {a} / {b}")
         return worst_result
     else:
         return a / b
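
To illustrate the behavior of `safe_division` after this change, here is a self-contained sketch. The body of `is_valid_number` is not shown in the diff, so the version below (finite `int` or `float`) is an assumption made only so the example runs.

```python
import logging
import math


def is_valid_number(num) -> bool:
    # Assumption: the real helper's body is not part of this diff;
    # here a number counts as valid if it is a finite int or float.
    return isinstance(num, (int, float)) and math.isfinite(num)


def safe_division(a: float, b: float, *, worst_result: float):
    """Divide a by b, returning `worst_result` when the division is not possible."""
    if b == 0 or not is_valid_number(a) or not is_valid_number(b):
        logging.info(
            f"Using `worst_result={worst_result}` in place of the following "
            f"division: {a} / {b}")
        return worst_result
    return a / b


print(safe_division(1.0, 4.0, worst_result=0.0))           # 0.25
print(safe_division(1.0, 0.0, worst_result=0.0))           # 0.0
print(safe_division(float("nan"), 2.0, worst_result=1.0))  # 1.0
```

With the level lowered from `warning` to `info`, the fallback message is hidden under Python's default logging configuration unless the logger level is set to `INFO` or lower.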
5 changes: 4 additions & 1 deletion folktexts/acs/acs_dataset.py
@@ -2,6 +2,7 @@
 """
 from __future__ import annotations

+import logging
 from pathlib import Path

 from folktables import ACSDataSource
@@ -34,7 +35,9 @@ def __init__(
     ):
         # Create "folktables" sub-folder under the given cache dir
         cache_dir = Path(cache_dir or DEFAULT_ACS_DATA_DIR).expanduser().resolve() / "folktables"
-        cache_dir.mkdir(exist_ok=True, parents=False)
+        if not cache_dir.exists():
+            logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
+            cache_dir.mkdir(exist_ok=True, parents=False)

         # Load ACS data source
         print("Loading ACS data...")
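
The cache-directory pattern in this hunk can be tried in isolation. The sketch below wraps it in a hypothetical helper, `ensure_cache_dir` (the name and the use of a temporary directory are assumptions for illustration), and shows that the warning is only emitted on first creation.

```python
import logging
import tempfile
from pathlib import Path


def ensure_cache_dir(root: Path) -> Path:
    """Hypothetical helper mirroring the pattern in the hunk above."""
    cache_dir = root.expanduser().resolve() / "folktables"
    if not cache_dir.exists():
        logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
        # parents=False: raises FileNotFoundError if `root` itself does not exist
        cache_dir.mkdir(exist_ok=True, parents=False)
    return cache_dir


with tempfile.TemporaryDirectory() as tmp:
    first = ensure_cache_dir(Path(tmp))   # creates the folder and logs a warning
    second = ensure_cache_dir(Path(tmp))  # folder already exists: nothing logged
    print(first == second, first.is_dir())
```

Because `parents=False` is kept, the data directory passed in must already exist, which is consistent with the `mkdir data` step in the README.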
18 changes: 10 additions & 8 deletions folktexts/benchmark.py
@@ -23,6 +23,7 @@
 from .task import TaskMetadata

 DEFAULT_SEED = 42
+DEFAULT_FIT_THRESHOLD_N = 100


 @dataclasses.dataclass(frozen=True, eq=True)
@@ -41,8 +42,9 @@ class BenchmarkConfig:
     seed: int = DEFAULT_SEED

     @classmethod
-    def default_config(cls):
-        return cls()
+    def default_config(cls, **changes):
+        """Returns the default configuration with optional changes."""
+        return cls(**changes)

     @classmethod
     def load_from_disk(cls, path: str | Path):
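
A short sketch of how the new `default_config(**changes)` signature can be used. `ExampleConfig` is a stand-in for `BenchmarkConfig`: only the `seed` field appears in this diff, and `batch_size` is an illustrative assumption.

```python
import dataclasses
from typing import Optional

DEFAULT_SEED = 42


@dataclasses.dataclass(frozen=True, eq=True)
class ExampleConfig:
    """Stand-in for BenchmarkConfig; `batch_size` is an illustrative field."""
    batch_size: Optional[int] = None
    seed: int = DEFAULT_SEED

    @classmethod
    def default_config(cls, **changes):
        """Return the default configuration with optional overrides."""
        return cls(**changes)


base = ExampleConfig.default_config()
custom = ExampleConfig.default_config(seed=7)
print(base.seed, custom.seed)  # 42 7

# The dataclass is frozen, so later tweaks go through dataclasses.replace
tweaked = dataclasses.replace(custom, batch_size=16)
print(tweaked.batch_size, tweaked.seed)  # 16 7
```

Since the keyword arguments are forwarded to the constructor, an unknown field name raises `TypeError` instead of being silently ignored.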
@@ -54,6 +56,7 @@ def save_to_disk(self, path: str | Path):
         save_json(dataclasses.asdict(self), path)

     def __hash__(self) -> int:
+        """Generates a unique hash for the configuration."""
         cfg = dataclasses.asdict(self)
         cfg["feature_subset"] = tuple(cfg["feature_subset"]) if cfg["feature_subset"] else None
         cfg["population_filter_hash"] = (
@@ -159,7 +162,7 @@ def _get_predictions_save_path(self, data_split: str) -> Path:
         assert data_split in ("train", "val", "test")
         return self.results_dir / f"{self.dataset.get_name()}.{data_split}_predictions.csv"

-    def run(self, fit_threshold: int | False = False) -> float:
+    def run(self, fit_threshold: int | bool = False) -> float:
         """Run the calibration benchmark experiment."""

         # Get test data
@@ -181,15 +184,14 @@ def run(self, fit_threshold: int | False = False) -> float:

         # If requested, fit the threshold on a small portion of the train set
         if fit_threshold:
-            if not is_valid_number(fit_threshold):
+            if fit_threshold is True:
+                fit_threshold = DEFAULT_FIT_THRESHOLD_N
+            elif not is_valid_number(fit_threshold) or fit_threshold <= 0:
                 raise ValueError(f"Invalid fit_threshold={fit_threshold}")

+            logging.info(f"Fitting threshold on {fit_threshold} train samples")
             X_train, y_train = self.dataset.sample_n_train_examples(fit_threshold)
             self.llm_clf.fit(X_train, y_train)
-            logging.info(
-                f"Fitted threshold on {len(y_train)} train examples; "
-                f"threshold={self.llm_clf.threshold:.3f};"
-            )

         # Evaluate test risk scores
         self._results = evaluate_predictions(
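
The `fit_threshold` handling in this hunk accepts `False`, `True`, or a positive sample count. A minimal sketch of that resolution logic in isolation, assuming a hypothetical helper `resolve_fit_threshold` and a stand-in for `is_valid_number` (whose body is not shown in the diff):

```python
import math
from typing import Optional, Union

DEFAULT_FIT_THRESHOLD_N = 100


def resolve_fit_threshold(fit_threshold: Union[int, bool] = False) -> Optional[int]:
    """Hypothetical helper: number of train samples to fit on, or None to skip."""
    if not fit_threshold:
        return None  # False (or 0): do not fit a threshold
    if fit_threshold is True:
        return DEFAULT_FIT_THRESHOLD_N  # a bare True selects the default count
    # Stand-in for `is_valid_number`: finite int or float
    is_number = isinstance(fit_threshold, (int, float)) and math.isfinite(fit_threshold)
    if not is_number or fit_threshold <= 0:
        raise ValueError(f"Invalid fit_threshold={fit_threshold}")
    return int(fit_threshold)


print(resolve_fit_threshold(False))  # None
print(resolve_fit_threshold(True))   # 100
print(resolve_fit_threshold(50))     # 50
```

The identity check `is True` comes first because `bool` is a subclass of `int` in Python: without it, `True` would compare equal to `1` and could be treated as a request for a single sample.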
6 changes: 3 additions & 3 deletions folktexts/cli/launch_acs_benchmarks.py
@@ -31,7 +31,7 @@
 ACS_DATA_DIR = ROOT_DIR / "data"

 # Directory to save results in (make sure it exists)
-RESULTS_DIR = ROOT_DIR / "folktexts-results"
+RESULTS_DIR = ROOT_DIR / "folktexts-results" / "acs-benchmarks"
 RESULTS_DIR.mkdir(exist_ok=True, parents=False)

 # Models save directory
@@ -75,8 +75,8 @@
     "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "meta-llama/Meta-Llama-3-70B",
     "meta-llama/Meta-Llama-3-70B-Instruct",
-    # "mistralai/Mixtral-8x22B-v0.1",
-    # "mistralai/Mixtral-8x22B-Instruct-v0.1",
+    "mistralai/Mixtral-8x22B-v0.1",
+    "mistralai/Mixtral-8x22B-Instruct-v0.1",
 ]


2 changes: 1 addition & 1 deletion folktexts/dataset.py
@@ -132,7 +132,7 @@ def __copy__(self) -> "Dataset":
         )
         dataset._train_indices = self._train_indices.copy()
         dataset._test_indices = self._test_indices.copy()
-        dataset._val_indices = self._val_indices.copy()
+        dataset._val_indices = self._val_indices.copy() if self._val_indices is not None else None
         dataset._rng = copy.deepcopy(self._rng)

         return dataset
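
The `__copy__` fix guards against a dataset that has no validation split. A self-contained sketch of the same guard, using a stand-in class `SplitIndices` with plain lists (both are assumptions for illustration; the real `Dataset` class is only partly visible in this diff):

```python
import copy
from typing import List, Optional


class SplitIndices:
    """Stand-in for a dataset holding train/test and optional validation indices."""

    def __init__(self, train: List[int], test: List[int],
                 val: Optional[List[int]] = None):
        self.train, self.test, self.val = train, test, val

    def __copy__(self) -> "SplitIndices":
        return SplitIndices(
            self.train.copy(),
            self.test.copy(),
            # Guard: calling .copy() on None would raise AttributeError
            self.val.copy() if self.val is not None else None,
        )


original = SplitIndices(train=[0, 1, 2], test=[3, 4])
duplicate = copy.copy(original)  # dispatches to __copy__
duplicate.train.append(99)       # does not affect the original
print(duplicate.val, original.train)  # None [0, 1, 2]
```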