Skip to content

Commit

Permalink
creating arbitrary acs tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent db2c337 commit 2029a8c
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 27 deletions.
4 changes: 2 additions & 2 deletions folktexts/acs/acs_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from folktexts.qa_interface import MultipleChoiceQA as _MultipleChoiceQA

from . import acs_columns
from .acs_tasks import _acs_columns_map
from .acs_tasks import acs_columns_map

# Map of numeric ACS questions
acs_numeric_qa_map: dict[str, object] = {
Expand All @@ -25,7 +25,7 @@
# ... include all multiple-choice questions defined in the column descriptions
acs_multiple_choice_qa_map.update({
col_to_text.name: col_to_text.question
for col_to_text in _acs_columns_map.values()
for col_to_text in acs_columns_map.values()
if (
isinstance(col_to_text, ColumnToText)
and col_to_text._question is not None
Expand Down
34 changes: 30 additions & 4 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

# Map of ACS column names to ColumnToText objects
_acs_columns_map: dict[str, object] = {
acs_columns_map: dict[str, object] = {
col_mapper.name: col_mapper
for col_mapper in acs_columns.__dict__.values()
if isinstance(col_mapper, _ColumnToText)
Expand All @@ -37,13 +37,39 @@ class ACSTaskMetadata(TaskMetadata):
# The ACS task object from the folktables package
folktables_obj: BasicProblem = None

@classmethod
def make_task(
    cls,
    name: str,
    description: str,
    features: list[str],
    target: str,
    target_threshold: Threshold | None = None,
    sensitive_attribute: str | None = None,
) -> ACSTaskMetadata:
    """Create an ACS task from an arbitrary set of feature and target columns.

    Every feature column and the target column must have an entry in
    `acs_columns_map`. The resulting task is not linked to any folktables
    task object (`folktables_obj` is set to None).

    Parameters
    ----------
    name : str
        Name of the task; forwarded to the class constructor.
    description : str
        Description of the task; forwarded to the class constructor.
    features : list[str]
        Names of the columns to use as features.
    target : str
        Name of the target column.
    target_threshold : Threshold, optional
        Forwarded to the class constructor, by default None.
    sensitive_attribute : str, optional
        Forwarded to the class constructor, by default None.

    Returns
    -------
    ACSTaskMetadata
        The newly constructed task metadata object.

    Raises
    ------
    ValueError
        If any feature column or the target column has no entry in
        `acs_columns_map`.
    """
    # Materialize once so any sequence of names (list, tuple, ...) is
    # accepted; `features + [target]` raises TypeError for non-lists.
    features = list(features)

    # Validate that every requested column has a mapping to text, and
    # report exactly which ones do not.
    missing_cols = [
        col for col in (*features, target)
        if col not in acs_columns_map
    ]
    if missing_cols:
        raise ValueError(
            "Not all columns have mappings to text descriptions. "
            f"Missing mappings for: {missing_cols}"
        )

    # TODO: verify end-to-end that tasks built without a folktables object
    # work with the rest of the pipeline.
    return cls(
        name=name,
        description=description,
        features=features,
        target=target,
        cols_to_text=acs_columns_map,
        target_threshold=target_threshold,
        sensitive_attribute=sensitive_attribute,
        folktables_obj=None,
    )

@classmethod
def make_folktables_task(
cls,
name: str,
description: str,
target_threshold: Threshold = None,
) -> "ACSTaskMetadata":
) -> ACSTaskMetadata:

# Get the task object from the folktables package
try:
Expand All @@ -56,7 +82,7 @@ def make_folktables_task(
description=description,
features=folktables_task.features,
target=folktables_task.target,
cols_to_text=_acs_columns_map,
cols_to_text=acs_columns_map,
sensitive_attribute=folktables_task.group,
target_threshold=target_threshold,
folktables_obj=folktables_task,
Expand Down Expand Up @@ -125,6 +151,6 @@ def __hash__(self) -> int:
*acs_travel_time_task.features,
})),
target="HINS2",
cols_to_text=_acs_columns_map,
cols_to_text=acs_columns_map,
target_threshold=acs_health_insurance_threshold,
)
1 change: 1 addition & 0 deletions folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def run(self, results_root_dir: str | Path, fit_threshold: int | bool = 0) -> fl
predictions_save_path=test_predictions_save_path,
labels=y_test, # used only to save alongside predictions in disk
)
self._y_test_scores = self.llm_clf._get_positive_class_scores(self._y_test_scores)

# If requested, fit the threshold on a small portion of the train set
if fit_threshold:
Expand Down
16 changes: 5 additions & 11 deletions folktexts/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def fit(self, X, y, *, false_pos_cost=1.0, false_neg_cost=1.0, **kwargs):
"""Uses the provided data sample to fit the prediction threshold."""

# Compute risk estimates for the data
y_pred_scores = self.predict_proba(X, **kwargs)
if len(y_pred_scores.shape) > 1:
y_pred_scores = y_pred_scores[:, -1]
y_pred_scores = self._get_positive_class_scores(
self.predict_proba(X, **kwargs)
)

# Compute the best threshold for the given data
self.threshold = compute_best_threshold(
Expand Down Expand Up @@ -172,7 +172,7 @@ def _make_predictions_multiclass(pos_class_scores: np.ndarray) -> np.ndarray:

def predict(
self,
data: pd.DataFrame | Dataset,
data: pd.DataFrame,
batch_size: int = None,
context_size: int = None,
predictions_save_path: str | Path = None,
Expand All @@ -186,13 +186,7 @@ def predict(
predictions_save_path=predictions_save_path,
labels=labels,
)
if isinstance(risk_scores, dict):
return {
data_type: (self._get_positive_class_scores(data_scores) >= self.threshold).astype(int)
for data_type, data_scores in risk_scores.items()
}
else:
return (self._get_positive_class_scores(risk_scores) >= self.threshold).astype(int)
return (self._get_positive_class_scores(risk_scores) >= self.threshold).astype(int)

def _load_predictions_from_disk(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@

# Path to the executable script to run
# EXECUTABLE_PATH = Path(__file__).parent.resolve() / "run_acs_benchmark.py"
EXECUTABLE_PATH = Path(__file__).parent.resolve() / "eval_feature_importance.py"
# TODO ^ pass executable path as cmd line arg
logging.warning(f"Using executable path: {EXECUTABLE_PATH}")
# EXECUTABLE_PATH = Path(__file__).parent.resolve() / "eval_feature_importance.py"

##################
# Global configs #
Expand Down Expand Up @@ -76,7 +74,8 @@


# Function that defines common settings among all LLM-as-clf experiments
def make_llm_as_clf_experiment(
def make_llm_clf_experiment(
executable_path: str,
model_name: str,
task: str,
results_root_dir: str,
Expand Down Expand Up @@ -111,7 +110,7 @@ def make_llm_as_clf_experiment(

# Define experiment
exp = Experiment(
executable_path=EXECUTABLE_PATH.as_posix(),
executable_path=executable_path,
kwargs=dict(
model=model_path,
task=task,
Expand Down Expand Up @@ -146,6 +145,13 @@ def setup_arg_parser() -> argparse.ArgumentParser:
# Init parser
parser = argparse.ArgumentParser(description="Launch experiments to evaluate LLMs as classifiers.")

parser.add_argument(
"--executable-path",
type=str,
help="[string] Path to the executable script to run.",
required=True,
)

parser.add_argument(
"--results-root-dir",
type=str,
Expand Down Expand Up @@ -194,11 +200,13 @@ def main():
# Parse extra kwargs
from ._utils import cmd_line_args_to_kwargs
extra_kwargs = cmd_line_args_to_kwargs(extra_kwargs)
# TODO: use the run_acs_benchmark.py parser to parse extra kwargs
# with `setup_arg_parser().convert_arg_line_to_args(extra_kwargs)` !!!

# Prepare command-line arguments
models = args.model or LLM_MODELS
tasks = args.task or ACS_TASKS
executable_path = Path(args.executable_path).resolve()
if not executable_path.exists() or not executable_path.is_file():
raise FileNotFoundError(f"Executable script not found at '{executable_path}'.")

# Load experiment from JSON file if provided
if args.experiment_json:
Expand All @@ -209,7 +217,8 @@ def main():
# Otherwise, run all experiments planned
else:
all_experiments = [
make_llm_as_clf_experiment(
make_llm_clf_experiment(
executable_path=executable_path.as_posix(),
model_name=model,
task=task,
results_root_dir=args.results_root_dir,
Expand Down
6 changes: 5 additions & 1 deletion folktexts/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def create_task_with_feature_subset(self, feature_subset: Iterable[str]):

# Check if features are a subset of the original features
if not set(feature_subset).issubset(self.features):
raise ValueError("`feature_subset` must be a subset of the original features.")
# Report the offending features, i.e. those requested in `feature_subset`
# that are absent from the task's original features. The operands were
# previously swapped, which listed the unrequested original features instead.
# NOTE(review): `feature_subset` is typed as an Iterable; if a one-shot
# iterator is passed, the preceding subset check has already consumed it
# and this set will be empty -- confirm callers always pass a collection.
raise ValueError(
    "`feature_subset` must be a subset of the original features; "
    "following features are not in the original set: "
    f"{set(feature_subset) - set(self.features)}"
)

# Return new TaskMetadata object
return dataclasses.replace(
Expand Down
2 changes: 1 addition & 1 deletion requirements/main.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
folktables~=0.0.12
scikit-learn>=1.3
numpy
pandas
tqdm
scikit-learn
accelerate
transformers
torch
Expand Down

0 comments on commit 2029a8c

Please sign in to comment.