Skip to content

Commit

Permalink
creating arbitrary acs tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent ed4ff03 commit b385138
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
4 changes: 2 additions & 2 deletions folktexts/acs/acs_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from folktexts.qa_interface import MultipleChoiceQA as _MultipleChoiceQA

from . import acs_columns
from .acs_tasks import _acs_columns_map
from .acs_tasks import acs_columns_map

# Map of numeric ACS questions
acs_numeric_qa_map: dict[str, object] = {
Expand All @@ -25,7 +25,7 @@
# ... include all multiple-choice questions defined in the column descriptions
acs_multiple_choice_qa_map.update({
col_to_text.name: col_to_text.question
for col_to_text in _acs_columns_map.values()
for col_to_text in acs_columns_map.values()
if (
isinstance(col_to_text, ColumnToText)
and col_to_text._question is not None
Expand Down
34 changes: 30 additions & 4 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

# Map of ACS column names to ColumnToText objects
_acs_columns_map: dict[str, object] = {
acs_columns_map: dict[str, object] = {
col_mapper.name: col_mapper
for col_mapper in acs_columns.__dict__.values()
if isinstance(col_mapper, _ColumnToText)
Expand All @@ -37,13 +37,39 @@ class ACSTaskMetadata(TaskMetadata):
# The ACS task object from the folktables package
folktables_obj: BasicProblem = None

@classmethod
def make_task(
cls,
name: str,
description: str,
features: list[str],
target: str,
target_threshold: Threshold = None,
sensitive_attribute: str = None,
) -> ACSTaskMetadata:
# Validate columns mappings exist
if not all(col in acs_columns_map for col in (features + [target])):
raise ValueError("Not all columns have mappings to text descriptions.")

# TODO: CHECK IF THIS WORKS!
return cls(
name=name,
description=description,
features=features,
target=target,
cols_to_text=acs_columns_map,
target_threshold=target_threshold,
sensitive_attribute=sensitive_attribute,
folktables_obj=None,
)

@classmethod
def make_folktables_task(
cls,
name: str,
description: str,
target_threshold: Threshold = None,
) -> "ACSTaskMetadata":
) -> ACSTaskMetadata:

# Get the task object from the folktables package
try:
Expand All @@ -56,7 +82,7 @@ def make_folktables_task(
description=description,
features=folktables_task.features,
target=folktables_task.target,
cols_to_text=_acs_columns_map,
cols_to_text=acs_columns_map,
sensitive_attribute=folktables_task.group,
target_threshold=target_threshold,
folktables_obj=folktables_task,
Expand Down Expand Up @@ -125,6 +151,6 @@ def __hash__(self) -> int:
*acs_travel_time_task.features,
})),
target="HINS2",
cols_to_text=_acs_columns_map,
cols_to_text=acs_columns_map,
target_threshold=acs_health_insurance_threshold,
)
2 changes: 1 addition & 1 deletion requirements/main.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
folktables~=0.0.12
scikit-learn>=1.3
numpy
pandas
tqdm
scikit-learn
accelerate
transformers
torch
Expand Down

0 comments on commit b385138

Please sign in to comment.