Skip to content

Commit

Permalink
created Threshold class
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 21, 2024
1 parent b0bca6d commit 8597ef1
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 49 deletions.
6 changes: 0 additions & 6 deletions folktexts/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,3 @@ def hash_function(func, length: int = 8) -> str:
def standardize_path(path: str | Path) -> str:
    """Return `path` as a standardized POSIX-style string.

    The home-directory shorthand (`~`) is expanded first, then the path is
    resolved and rendered with forward slashes.
    """
    expanded = Path(path).expanduser()
    resolved = expanded.resolve()
    return resolved.as_posix()


def get_thresholded_column_name(column_name: str, threshold: float | int, *, op: str = ">") -> str:
    """Build the standardized name of a thresholded column.

    The result has the form ``<column_name>_<op>_<threshold>``. Float
    thresholds are rendered with two decimal places and the decimal point is
    replaced by an underscore; any other threshold is rendered with `str()`.
    """
    if isinstance(threshold, float):
        # e.g. 0.5 -> "0.50" -> "0_50", so the name contains no dots
        suffix = format(threshold, ".2f").replace(".", "_")
    else:
        suffix = str(threshold)
    return f"{column_name}_{op}_{suffix}"
75 changes: 43 additions & 32 deletions folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@

from ..col_to_text import ColumnToText
from ..qa_interface import Choice, DirectNumericQA, MultipleChoiceQA
from .._utils import get_thresholded_column_name
from ._utils import parse_pums_code
from .acs_thresholds import (
acs_income_threshold,
acs_publiccoverage_threshold,
acs_mobility_threshold,
acs_employment_threshold,
acs_traveltime_threshold,
)

# Path to ACS codebook files
ACS_OCCP_FILE = Path(__file__).parent / "data" / "OCCP-codes-acs.txt"
Expand Down Expand Up @@ -173,8 +179,9 @@
value_map=lambda x: f"${int(x):,}",
)

# PINCP: Yearly Income (Thresholded)
acs_income_qa = MultipleChoiceQA(
column=get_thresholded_column_name("PINCP", 50_000),
column=acs_income_threshold.apply_to_column_name("PINCP"),
text="What is this person's estimated yearly income?",
choices=[
Choice("Below $50,000", 0),
Expand All @@ -183,7 +190,7 @@
)

acs_income_numeric_qa = DirectNumericQA(
column=get_thresholded_column_name("PINCP", 50_000),
column=acs_income_threshold.apply_to_column_name("PINCP"),
text=(
"What is the probability that this person's estimated yearly income is "
"above $50,000 ?"
Expand All @@ -193,47 +200,48 @@
)

acs_income_target_col = ColumnToText(
name=get_thresholded_column_name("PINCP", 50_000),
name=acs_income_threshold.apply_to_column_name("PINCP"),
short_description="yearly income",
missing_value_fill="N/A (less than 15 years old)",
question=acs_income_qa,
)

"""
acs_income_brackets = ColumnToText(
"PINCP_brackets",
short_description="yearly income",
missing_value_fill="N/A (less than 15 years old)",
question=MultipleChoiceQA(
column="PINCP_brackets",
text="What is this person's estimated yearly income?",
choices=[
Choice("Less than $25,000", data_value="(0.0, 25000.0]", numeric_value=12_500),
Choice("Between $25,000 and $50,000", data_value="(25000.0, 50000.0]", numeric_value=37_500),
Choice("Between $50,000 and $100,000", data_value="(50000.0, 100000.0]", numeric_value=75_000),
Choice("Above $100,000", data_value="(100000.0, inf]", numeric_value=150_000),
],
),
# PUBCOV: Public Health Coverage (Original)
acs_pubcov_og_qa = MultipleChoiceQA(
column="PUBCOV",
text="Does this person have public health insurance coverage?",
choices=[
Choice("Yes, person is covered by public health insurance", 1),
Choice("No, person is not covered by public health insurance", 2), # NOTE: value=2 for no public coverage!
],
)
"""

# PUBCOV: Public Health Coverage
# NOTE: in folktables the negative choice has value `0` instead of `2`
acs_pubcov_og_target_col = ColumnToText(
"PUBCOV",
short_description="public health coverage status",
value_map={
1: "Covered by public health insurance",
2: "Not covered by public health insurance",
},
question=acs_pubcov_og_qa,
)

# PUBCOV: Public Health Coverage (Thresholded)
acs_pubcov_qa = MultipleChoiceQA(
column="PUBCOV",
column=acs_publiccoverage_threshold.apply_to_column_name("PUBCOV"),
text="Does this person have public health insurance coverage?",
choices=[
Choice("Yes, person is covered by public health insurance", 1),
Choice("No, person is not covered by public health insurance", 2),
Choice("No, person is not covered by public health insurance", 0), # NOTE: value=0 for no public coverage!
],
)

acs_pubcov_target_col = ColumnToText(
"PUBCOV",
name=acs_publiccoverage_threshold.apply_to_column_name("PUBCOV"),
short_description="public health coverage status",
value_map={
1: "Covered by public health insurance",
2: "Not covered by public health insurance",
0: "Not covered by public health insurance",
},
question=acs_pubcov_qa,
)
Expand Down Expand Up @@ -289,8 +297,9 @@
},
)

# MIG: Mobility Status (Thresholded)
acs_mobility_qa = MultipleChoiceQA(
column=get_thresholded_column_name("MIG", 1, op="=="),
column=acs_mobility_threshold.apply_to_column_name("MIG"),
text="Has this person moved in the last year?",
choices=[
Choice("No, person has lived in the same house for the last year", 1),
Expand All @@ -299,7 +308,7 @@
)

acs_mobility_target_col = ColumnToText(
name=get_thresholded_column_name("MIG", 1, op="=="),
name=acs_mobility_threshold.apply_to_column_name("MIG"),
short_description="mobility status over the last year",
question=acs_mobility_qa,
use_value_map_only=True,
Expand Down Expand Up @@ -386,8 +395,9 @@
missing_value_fill="N/A (less than 16 years old)",
)

# ESR: Employment Status (Thresholded)
acs_employment_qa = MultipleChoiceQA(
column=get_thresholded_column_name("ESR", 1, op="=="),
column=acs_employment_threshold.apply_to_column_name("ESR"),
text="What is this person's employment status?",
choices=[
Choice("Employed civilian", 1),
Expand All @@ -396,7 +406,7 @@
)

acs_employment_target_col = ColumnToText(
name=get_thresholded_column_name("ESR", 1, op="=="),
name=acs_employment_threshold.apply_to_column_name("ESR"),
short_description="employment status",
question=acs_employment_qa,
use_value_map_only=True,
Expand Down Expand Up @@ -433,8 +443,9 @@
missing_value_fill="N/A (not a worker, or worker who worked at home)",
)

# JWMNP: Commute Time (Thresholded)
acs_commute_time_qa = MultipleChoiceQA(
column=get_thresholded_column_name("JWMNP", 20),
column=acs_traveltime_threshold.apply_to_column_name("JWMNP"),
text="What is this person's commute time?",
choices=[
Choice("Longer than 20 minutes", 1),
Expand All @@ -443,7 +454,7 @@
)

acs_travel_time_target_col = ColumnToText(
name=get_thresholded_column_name("JWMNP", 20),
name=acs_traveltime_threshold.apply_to_column_name("JWMNP"),
short_description="commute time",
question=acs_commute_time_qa,
use_value_map_only=True,
Expand Down
17 changes: 14 additions & 3 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from .._utils import hash_dict
from ..col_to_text import ColumnToText as _ColumnToText
from ..task import TaskMetadata
from ..threshold import Threshold
from . import acs_columns
from .acs_thresholds import (
acs_employment_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_publiccoverage_threshold,
acs_traveltime_threshold,
)

# Map of ACS column names to ColumnToText objects
_acs_columns_map: dict[str, object] = {
Expand All @@ -32,7 +40,7 @@ def make_folktables_task(
cls,
name: str,
description: str,
target_threshold: float | int = None,
target_threshold: Threshold = None,
) -> "ACSTaskMetadata":

# Get the task object from the folktables package
Expand Down Expand Up @@ -65,26 +73,29 @@ def __hash__(self) -> int:
acs_income_task = ACSTaskMetadata.make_folktables_task(
name="ACSIncome",
description="predict whether an individual's income is above $50,000",
target_threshold=50000,
target_threshold=acs_income_threshold,
)

acs_public_coverage_task = ACSTaskMetadata.make_folktables_task(
name="ACSPublicCoverage",
description="predict whether an individual is covered by public health insurance",
target_threshold=acs_publiccoverage_threshold,
)

acs_mobility_task = ACSTaskMetadata.make_folktables_task(
name="ACSMobility",
description="predict whether an individual had the same residential address one year ago",
target_threshold=acs_mobility_threshold,
)

acs_employment_task = ACSTaskMetadata.make_folktables_task(
name="ACSEmployment",
description="predict whether an individual is employed",
target_threshold=acs_employment_threshold,
)

acs_travel_time_task = ACSTaskMetadata.make_folktables_task(
name="ACSTravelTime",
description="predict whether an individual has a commute to work that is longer than 20 minutes",
target_threshold=20,
target_threshold=acs_traveltime_threshold,
)
19 changes: 19 additions & 0 deletions folktexts/acs/acs_thresholds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Threshold instances for ACS / folktables tasks.
"""
from folktexts.threshold import Threshold


# ACSIncome task: positive class is a value strictly above 50,000
acs_income_threshold = Threshold(value=50_000, op=">")

# ACSPublicCoverage task: positive class is a value equal to 1
acs_publiccoverage_threshold = Threshold(value=1, op="==")

# ACSMobility task: positive class is a value equal to 1
acs_mobility_threshold = Threshold(value=1, op="==")

# ACSEmployment task: positive class is a value equal to 1
acs_employment_threshold = Threshold(value=1, op="==")

# ACSTravelTime task: positive class is a value strictly above 20
acs_traveltime_threshold = Threshold(value=20, op=">")
3 changes: 1 addition & 2 deletions folktexts/cli/launch_acs_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
"""
Python script to launch condor jobs for all experiments.
"""Helper script to launch htcondor jobs for all experiments.
"""
import argparse
import math
Expand Down
3 changes: 1 addition & 2 deletions folktexts/cli/rerun_experiment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
"""
Python script to re-run a single experiment locally.
"""Helper script to re-run a single experiment locally.
"""
from argparse import ArgumentParser
from subprocess import call
Expand Down
9 changes: 5 additions & 4 deletions folktexts/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

import pandas as pd

from ._utils import hash_dict, get_thresholded_column_name
from ._utils import hash_dict
from .col_to_text import ColumnToText
from .threshold import Threshold


@dataclass
Expand All @@ -32,7 +33,7 @@ class TaskMetadata:
A mapping between column names and their textual descriptions.
sensitive_attribute : str, optional
The name of the column used as the sensitive attribute data (if provided).
target_threshold : float, optional
target_threshold : Threshold, optional
The threshold used to binarize the target column (if provided).
"""
name: str
Expand All @@ -41,7 +42,7 @@ class TaskMetadata:
target: str
cols_to_text: dict[str, ColumnToText]
sensitive_attribute: str = None
target_threshold: float | int = None
target_threshold: Threshold = None

# Class-level task storage
_tasks: ClassVar[dict[str, "TaskMetadata"]] = field(default={}, init=False, repr=False)
Expand All @@ -65,7 +66,7 @@ def get_target(self) -> str:
if self.target_threshold is None:
return self.target
else:
return get_thresholded_column_name(self.target, self.target_threshold)
return self.target_threshold.apply_to_column_name(self.target)

@classmethod
def get_task(cls, name: str) -> "TaskMetadata":
Expand Down
43 changes: 43 additions & 0 deletions folktexts/threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Helper class for defining binarization thresholds.
"""
from __future__ import annotations
import operator
import dataclasses
from typing import ClassVar

import pandas as pd


@dataclasses.dataclass(frozen=True)
class Threshold:
    """A threshold value paired with the comparison operator that applies it.

    Instances are immutable (frozen dataclass). The operator symbol is
    validated on construction against the keys of `valid_ops`.

    Attributes
    ----------
    value : float | int
        The value that data is compared against.
    op : str
        The comparison operator symbol; one of the keys of `valid_ops`.
    """
    value: float | int
    op: str

    # Supported operator symbols mapped to the callables implementing them.
    valid_ops: ClassVar[dict] = {
        ">": operator.gt,
        "<": operator.lt,
        ">=": operator.ge,
        "<=": operator.le,
        "==": operator.eq,
    }

    def __post_init__(self):
        """Reject operator symbols that are not keys of `valid_ops`.

        Raises
        ------
        ValueError
            If `op` is not a supported comparison operator.
        """
        if self.op not in self.valid_ops:
            raise ValueError(f"Invalid comparison operator '{self.op}'.")

    def __str__(self):
        """Return the operator symbol immediately followed by the value."""
        return f"{self.op}{self.value}"

    def apply_to_column_data(self, data: float | int | pd.Series) -> int | pd.Series:
        """Apply the threshold comparison to a pandas Series or a scalar.

        Parameters
        ----------
        data : float | int | pd.Series
            The data to compare against `value`. Scalars may be built-in
            numbers or NumPy numeric scalars (e.g., a single cell taken out of
            a DataFrame).

        Returns
        -------
        int | pd.Series
            For a Series, a Series of the comparison results cast to int;
            for a scalar, the comparison result as a built-in `int` (1 or 0).

        Raises
        ------
        TypeError
            If `data` is neither a pandas Series nor a number.
        """
        compare = self.valid_ops[self.op]

        if isinstance(data, pd.Series):
            return compare(data, self.value).astype(int)

        # `is_number` also accepts NumPy numeric scalars such as `np.int64`,
        # which are not instances of the built-in `int`.
        if pd.api.types.is_number(data):
            return int(compare(data, self.value))

        raise TypeError(f"Invalid data type '{type(data)}'.")

    def apply_to_column_name(self, column_name: str) -> str:
        """Return `column_name` with this threshold's string form appended."""
        return column_name + str(self)

0 comments on commit 8597ef1

Please sign in to comment.