Skip to content

Commit

Permalink
added healthinsurance task
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 21, 2024
1 parent 62a7b15 commit 0142332
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 19 deletions.
43 changes: 37 additions & 6 deletions folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
acs_employment_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_publiccoverage_threshold,
acs_traveltime_threshold,
acs_public_coverage_threshold,
acs_travel_time_threshold,
acs_health_insurance_threshold,
)

# Path to ACS codebook files
Expand Down Expand Up @@ -228,7 +229,7 @@

# PUBCOV: Public Health Coverage (Thresholded)
acs_pubcov_qa = MultipleChoiceQA(
column=acs_publiccoverage_threshold.apply_to_column_name("PUBCOV"),
column=acs_public_coverage_threshold.apply_to_column_name("PUBCOV"),
text="Does this person have public health insurance coverage?",
choices=[
Choice("Yes, person is covered by public health insurance", 1),
Expand All @@ -237,7 +238,7 @@
)

acs_pubcov_target_col = ColumnToText(
name=acs_publiccoverage_threshold.apply_to_column_name("PUBCOV"),
name=acs_public_coverage_threshold.apply_to_column_name("PUBCOV"),
short_description="public health coverage status",
value_map={
1: "Covered by public health insurance",
Expand Down Expand Up @@ -445,7 +446,7 @@

# JWMNP: Commute Time (Thresholded)
acs_commute_time_qa = MultipleChoiceQA(
column=acs_traveltime_threshold.apply_to_column_name("JWMNP"),
column=acs_travel_time_threshold.apply_to_column_name("JWMNP"),
text="What is this person's commute time?",
choices=[
Choice("Longer than 20 minutes", 1),
Expand All @@ -454,7 +455,7 @@
)

acs_travel_time_target_col = ColumnToText(
name=acs_traveltime_threshold.apply_to_column_name("JWMNP"),
name=acs_travel_time_threshold.apply_to_column_name("JWMNP"),
short_description="commute time",
question=acs_commute_time_qa,
use_value_map_only=True,
Expand Down Expand Up @@ -486,3 +487,33 @@
short_description="income-to-poverty ratio",
value_map=lambda x: f"{x:.2f}",
)

# HINS2: Health Insurance Coverage through Private Company
acs_health_ins_2_col = ColumnToText(
"HINS2",
short_description="acquired health insurance directly from an insurance company",
use_value_map_only=True,
value_map={
1: "Person has purchased insurance directly from an insurance company",
2: (
"Person has not purchased insurance directly from an insurance "
"company (is either uninsured or insured through another source)",
)
},
)

# HINS2: Health Insurance Coverage through Private Company (Thresholded)
acs_health_ins_2_qa = MultipleChoiceQA(
text="Has this person purchased health insurance directly from an insurance company?",
choices=[
Choice("Yes, this person has health insurance through a private company", 1),
Choice("No, this person either has insurance through other means or is uninsured", 0),
],
)

acs_health_ins_2_target_col = ColumnToText(
name=acs_health_insurance_threshold.apply_to_column_name("HINS2"),
short_description="acquired health insurance directly from an insurance company",
question=acs_health_ins_2_qa,
use_value_map_only=True,
)
24 changes: 21 additions & 3 deletions folktexts/acs/acs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from folktables.load_acs import state_list

from ..dataset import Dataset
from .acs_tasks import ACSTaskMetadata # noqa # load ACS tasks
from .acs_tasks import ACSTaskMetadata

DEFAULT_ACS_DATA_DIR = Path("~/data").expanduser().resolve()
DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
DEFAULT_SEED = 42

DEFAULT_SURVEY_YEAR = "2018"
Expand All @@ -32,8 +32,26 @@ def __init__(
seed: int = DEFAULT_SEED,
**kwargs,
):
"""Construct an ACSDataset object.
Parameters
----------
task : str | ACSTaskMetadata
The name of the ACS task or the task object itself.
cache_dir : str | Path, optional
The directory where ACS data is (or will be) saved to, by default
uses DEFAULT_DATA_DIR.
survey_year : str, optional
The year from which to load survey data, by default DEFAULT_SURVEY_YEAR.
horizon : str, optional
The time horizon of survey data to load, by default DEFAULT_SURVEY_HORIZON.
survey : str, optional
The name of the survey unit to load, by default DEFAULT_SURVEY_UNIT.
seed : int, optional
The random seed, by default DEFAULT_SEED.
"""
# Create "folktables" sub-folder under the given cache dir
cache_dir = Path(cache_dir or DEFAULT_ACS_DATA_DIR).expanduser().resolve() / "folktables"
cache_dir = Path(cache_dir or DEFAULT_DATA_DIR).expanduser().resolve() / "folktables"
if not cache_dir.exists():
logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
cache_dir.mkdir(exist_ok=True, parents=False)
Expand Down
37 changes: 33 additions & 4 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from dataclasses import asdict, dataclass
from functools import reduce

import folktables
from folktables import BasicProblem
Expand All @@ -16,8 +17,10 @@
acs_employment_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_publiccoverage_threshold,
acs_traveltime_threshold,
acs_public_coverage_threshold,
acs_travel_time_threshold,
acs_income_poverty_ratio_threshold,
acs_health_insurance_threshold,
)

# Map of ACS column names to ColumnToText objects
Expand Down Expand Up @@ -79,7 +82,7 @@ def __hash__(self) -> int:
acs_public_coverage_task = ACSTaskMetadata.make_folktables_task(
name="ACSPublicCoverage",
description="predict whether an individual is covered by public health insurance",
target_threshold=acs_publiccoverage_threshold,
target_threshold=acs_public_coverage_threshold,
)

acs_mobility_task = ACSTaskMetadata.make_folktables_task(
Expand All @@ -97,5 +100,31 @@ def __hash__(self) -> int:
acs_travel_time_task = ACSTaskMetadata.make_folktables_task(
name="ACSTravelTime",
description="predict whether an individual has a commute to work that is longer than 20 minutes",
target_threshold=acs_traveltime_threshold,
target_threshold=acs_travel_time_threshold,
)

acs_income_poverty_ratio_task = ACSTaskMetadata.make_folktables_task(
name="ACSIncomePovertyRatio",
description="predict whether an individual's income-to-poverty ratio is below 2.5",
target_threshold=acs_income_poverty_ratio_threshold,
)


# Dummy/test ACS task to predict health insurance coverage using all other available features
acs_full_task = TaskMetadata(
name="ACSHealthInsurance-full",
description=(
"predict whether an individual has purchased health insurance directly "
"from an insurance company (as opposed to being insured through an "
"employer, Medicare, Medicaid, or any other source)"
),
features=list(
reduce(
lambda t1, t2: set(t1.features) | set(t2.features),
[acs_income_task, acs_public_coverage_task, acs_mobility_task, acs_employment_task, acs_travel_time_task],
)
),
target="HINS2",
cols_to_text=_acs_columns_map,
target_threshold=acs_health_insurance_threshold,
)
10 changes: 8 additions & 2 deletions folktexts/acs/acs_thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
acs_income_threshold = Threshold(50_000, ">")

# ACSPublicCoverage task
acs_publiccoverage_threshold = Threshold(1, "==")
acs_public_coverage_threshold = Threshold(1, "==")

# ACSMobility task
acs_mobility_threshold = Threshold(1, "==")
Expand All @@ -15,4 +15,10 @@
acs_employment_threshold = Threshold(1, "==")

# ACSTravelTime task
acs_traveltime_threshold = Threshold(20, ">")
acs_travel_time_threshold = Threshold(20, ">")

# ACSIncomePovertyRatio task
acs_income_poverty_ratio_threshold = Threshold(250, "<")

# ACSHealthInsurance task
acs_health_insurance_threshold = Threshold(1, "==")
4 changes: 2 additions & 2 deletions folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,9 @@ def make_benchmark(

# Construct the LLMClassifier object
llm_inference_kwargs = {"correct_order_bias": config.correct_order_bias}
if config.batch_size:
if config.batch_size is not None:
llm_inference_kwargs["batch_size"] = config.batch_size
if config.context_size:
if config.context_size is not None:
llm_inference_kwargs["context_size"] = config.context_size

llm_clf = LLMClassifier(
Expand Down
2 changes: 1 addition & 1 deletion folktexts/qa_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class MultipleChoiceQA(QAInterface):

num_forward_passes: int = 1 # NOTE: overrides superclass default
choices: list[Choice] = dataclasses.field(default_factory=list)
_answer_keys_source: list[str] = _ALPHABET
_answer_keys_source: list[str] = list(_ALPHABET)

def __post_init__(self):
if not self.choices:
Expand Down
10 changes: 9 additions & 1 deletion folktexts/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@

@dataclasses.dataclass(frozen=True)
class Threshold:
"""A class to represent a threshold value and its comparison operator."""
"""A class to represent a threshold value and its comparison operator.
Attributes
----------
value : float | int
The threshold value to compare against.
op : str
The comparison operator to use. One of '>', '<', '>=', '<=', '=='.
"""
value: float | int
op: str

Expand Down

0 comments on commit 0142332

Please sign in to comment.