Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

computing feature importance #3

Merged
merged 8 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions folktexts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._version import __version__, __version_info__
from .task import TaskMetadata
from .acs import ACSDataset, ACSTaskMetadata
from .benchmark import BenchmarkConfig, CalibrationBenchmark
from .classifier import LLMClassifier
from .acs import ACSDataset, ACSTaskMetadata
from .task import TaskMetadata
2 changes: 1 addition & 1 deletion folktexts/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import json
import logging
import operator
from contextlib import contextmanager
from datetime import datetime
from functools import partial, reduce
from pathlib import Path
from contextlib import contextmanager

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from ._utils import parse_pums_code
from .acs_thresholds import (
acs_employment_threshold,
acs_health_insurance_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_public_coverage_threshold,
acs_travel_time_threshold,
acs_health_insurance_threshold,
)

# Path to ACS codebook files
Expand Down
12 changes: 5 additions & 7 deletions folktexts/acs/acs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
DEFAULT_TEST_SIZE = 0.1
DEFAULT_VAL_SIZE = None
DEFAULT_VAL_SIZE = 0.1
DEFAULT_SEED = 42

DEFAULT_SURVEY_YEAR = "2018"
Expand Down Expand Up @@ -116,17 +116,15 @@ def task(self, new_task: ACSTaskMetadata):
# Parse data rows for new ACS task
self._data = self._parse_task_data(self._full_acs_data, new_task)

# Re-Make train/test/val split
# Re-make train/test/val split
self._train_indices, self._test_indices, self._val_indices = (
self._make_train_test_val_split(
self._data, self.test_size, self.val_size, self._rng)
)

# Check if task columns are in the data
if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={new_task.features}, target={new_task.get_target()}")
# Check if sub-sampling is necessary (it's applied only to train/test/val indices)
if self.subsampling is not None:
self._subsample_train_test_val_indices(self.subsampling)

self._task = new_task

Expand Down
6 changes: 3 additions & 3 deletions folktexts/acs/acs_questions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""A collection of instantiated ACS column objects and ACS tasks."""
from __future__ import annotations

from folktexts.col_to_text import ColumnToText
from folktexts.qa_interface import DirectNumericQA as _DirectNumericQA
from folktexts.qa_interface import MultipleChoiceQA as _MultipleChoiceQA
from folktexts.col_to_text import ColumnToText

from . import acs_columns
from .acs_tasks import _acs_columns_map
from .acs_tasks import acs_columns_map

# Map of numeric ACS questions
acs_numeric_qa_map: dict[str, object] = {
Expand All @@ -25,7 +25,7 @@
# ... include all multiple-choice questions defined in the column descriptions
acs_multiple_choice_qa_map.update({
col_to_text.name: col_to_text.question
for col_to_text in _acs_columns_map.values()
for col_to_text in acs_columns_map.values()
if (
isinstance(col_to_text, ColumnToText)
and col_to_text._question is not None
Expand Down
37 changes: 31 additions & 6 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
from . import acs_columns
from .acs_thresholds import (
acs_employment_threshold,
acs_health_insurance_threshold,
acs_income_poverty_ratio_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_public_coverage_threshold,
acs_travel_time_threshold,
acs_income_poverty_ratio_threshold,
acs_health_insurance_threshold,
)

# Map of ACS column names to ColumnToText objects
_acs_columns_map: dict[str, object] = {
acs_columns_map: dict[str, object] = {
col_mapper.name: col_mapper
for col_mapper in acs_columns.__dict__.values()
if isinstance(col_mapper, _ColumnToText)
Expand All @@ -37,13 +37,38 @@ class ACSTaskMetadata(TaskMetadata):
# The ACS task object from the folktables package
folktables_obj: BasicProblem = None

@classmethod
def make_task(
cls,
name: str,
description: str,
features: list[str],
target: str,
target_threshold: Threshold = None,
sensitive_attribute: str = None,
) -> ACSTaskMetadata:
# Validate columns mappings exist
if not all(col in acs_columns_map for col in (features + [target])):
raise ValueError("Not all columns have mappings to text descriptions.")

return cls(
name=name,
description=description,
features=features,
target=target,
cols_to_text=acs_columns_map,
target_threshold=target_threshold,
sensitive_attribute=sensitive_attribute,
folktables_obj=None,
)

@classmethod
def make_folktables_task(
cls,
name: str,
description: str,
target_threshold: Threshold = None,
) -> "ACSTaskMetadata":
) -> ACSTaskMetadata:

# Get the task object from the folktables package
try:
Expand All @@ -56,7 +81,7 @@ def make_folktables_task(
description=description,
features=folktables_task.features,
target=folktables_task.target,
cols_to_text=_acs_columns_map,
cols_to_text=acs_columns_map,
sensitive_attribute=folktables_task.group,
target_threshold=target_threshold,
folktables_obj=folktables_task,
Expand Down Expand Up @@ -125,6 +150,6 @@ def __hash__(self) -> int:
*acs_travel_time_task.features,
})),
target="HINS2",
cols_to_text=_acs_columns_map,
cols_to_text=acs_columns_map,
target_threshold=acs_health_insurance_threshold,
)
1 change: 1 addition & 0 deletions folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def run(self, results_root_dir: str | Path, fit_threshold: int | bool = 0) -> fl
predictions_save_path=test_predictions_save_path,
labels=y_test, # used only to save alongside predictions in disk
)
self._y_test_scores = self.llm_clf._get_positive_class_scores(self._y_test_scores)

# If requested, fit the threshold on a small portion of the train set
if fit_threshold:
Expand Down
Loading
Loading