Skip to content

Commit

Permalink
hash is now deterministic :)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 21, 2024
1 parent 35a4151 commit 005088d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
40 changes: 26 additions & 14 deletions folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@
acs_income_qa = MultipleChoiceQA(
column=acs_income_threshold.apply_to_column_name("PINCP"),
text="What is this person's estimated yearly income?",
choices=[
choices=(
Choice("Below $50,000", 0),
Choice("Above $50,000", 1),
],
),
)

acs_income_numeric_qa = DirectNumericQA(
Expand All @@ -211,10 +211,10 @@
acs_pubcov_og_qa = MultipleChoiceQA(
column="PUBCOV",
text="Does this person have public health insurance coverage?",
choices=[
choices=(
Choice("Yes, person is covered by public health insurance", 1),
Choice("No, person is not covered by public health insurance", 2), # NOTE: value=2 for no public coverage!
],
),
)

acs_pubcov_og_target_col = ColumnToText(
Expand All @@ -231,10 +231,10 @@
acs_pubcov_qa = MultipleChoiceQA(
column=acs_public_coverage_threshold.apply_to_column_name("PUBCOV"),
text="Does this person have public health insurance coverage?",
choices=[
choices=(
Choice("Yes, person is covered by public health insurance", 1),
Choice("No, person is not covered by public health insurance", 0), # NOTE: value=0 for no public coverage!
],
),
)

acs_pubcov_target_col = ColumnToText(
Expand Down Expand Up @@ -302,10 +302,10 @@
acs_mobility_qa = MultipleChoiceQA(
column=acs_mobility_threshold.apply_to_column_name("MIG"),
text="Has this person moved in the last year?",
choices=[
choices=(
Choice("No, person has lived in the same house for the last year", 1),
Choice("Yes, person has moved in the last year", 0),
],
),
)

acs_mobility_target_col = ColumnToText(
Expand Down Expand Up @@ -400,10 +400,10 @@
acs_employment_qa = MultipleChoiceQA(
column=acs_employment_threshold.apply_to_column_name("ESR"),
text="What is this person's employment status?",
choices=[
choices=(
Choice("Employed civilian", 1),
Choice("Unemployed or in the military", 0),
],
),
)

acs_employment_target_col = ColumnToText(
Expand Down Expand Up @@ -448,10 +448,10 @@
acs_commute_time_qa = MultipleChoiceQA(
column=acs_travel_time_threshold.apply_to_column_name("JWMNP"),
text="What is this person's commute time?",
choices=[
choices=(
Choice("Longer than 20 minutes", 1),
Choice("Less than 20 minutes", 0),
],
),
)

acs_travel_time_target_col = ColumnToText(
Expand Down Expand Up @@ -502,14 +502,26 @@
},
)

# GCL: Grandparent Living with Grandchildren
# Text mapping for the "GCL" column: coded values 1 and 2 are rendered with the
# strings in `value_map`; `missing_value_fill` supplies the text for the
# not-applicable case described in its own wording.
# NOTE(review): `use_value_map_only=True` presumably tells ColumnToText to rely
# solely on `value_map` (no question object) -- confirm against ColumnToText.
acs_gcl_col = ColumnToText(
"GCL",
short_description="grandparent living with grandchildren",
use_value_map_only=True,
value_map={
1: "Household includes grandparent living with grandchildren",
2: "Household does not include grandparents living with grandchildren",
},
missing_value_fill="N/A (less than 30 years old, or living in institutional group quarters)",
)

# HINS2: Health Insurance Coverage through Private Company (Thresholded)
acs_health_ins_2_qa = MultipleChoiceQA(
column=acs_health_insurance_threshold.apply_to_column_name("HINS2"),
text="Has this person purchased health insurance directly from an insurance company?",
choices=[
choices=(
Choice("Yes, this person has health insurance through a private company", 1),
Choice("No, this person either has insurance through other means or is uninsured", 0),
],
),
)

acs_health_ins_2_target_col = ColumnToText(
Expand Down
2 changes: 1 addition & 1 deletion folktexts/col_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(

# Else, log an informational notice if both were provided (as they may use inconsistent value maps)
elif self._value_map is not None and self._question is not None:
logging.warning(
logging.info(
f"Got both `value_map` and `question` for column '{self.name}'. "
f"Please make sure value mappings are consistent.")

Expand Down
18 changes: 13 additions & 5 deletions folktexts/qa_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import torch
from transformers import AutoTokenizer

from ._utils import hash_dict

# Minimum probability density assigned to all valid answers
# > small models will be worse at using valid answers...
ANSWER_PROB_THRESHOLD = 0.1
Expand All @@ -24,7 +26,7 @@
_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


@dataclass(frozen=True, eq=True)
@dataclass(frozen=True)
class QAInterface(ABC):
"""An interface for a question-answering system."""

Expand Down Expand Up @@ -57,8 +59,11 @@ def get_answer_from_model_output(
"""
raise NotImplementedError

# Hash computed from the instance's field values: `dataclasses.asdict(self)` is
# passed to `hash_dict`, and the string it returns is parsed as a base-16 integer.
# NOTE(review): per the commit message this is meant to be deterministic across
# runs -- confirm that `hash_dict` is built on a stable digest.
# NOTE(review): `@dataclass(frozen=True)` (eq defaults to True) generates a new
# `__hash__` for any class that does not define one in its own body, so dataclass
# subclasses do not inherit this method automatically -- TODO confirm that
# DirectNumericQA declares its own `__hash__`.
def __hash__(self) -> int:
return int(hash_dict(dataclasses.asdict(self)), 16)

@dataclass(frozen=True, eq=True)

@dataclass(frozen=True)
class DirectNumericQA(QAInterface):
"""Represents a direct numeric question.
Expand Down Expand Up @@ -188,15 +193,18 @@ class MultipleChoiceQA(QAInterface):
"""Represents a multiple-choice question and its answer keys."""

num_forward_passes: int = 1 # NOTE: overrides superclass default
choices: list[Choice] = dataclasses.field(default_factory=list)
_answer_keys_source: list[str] = dataclasses.field(default_factory=lambda: list(_ALPHABET))
# NOTE(review): the two list-typed lines above appear to be the pre-change side of
# this diff hunk; the tuple-typed lines below look like their replacements.
# Variable-length tuples are annotated `tuple[X, ...]`; a bare `tuple[X]` denotes
# a tuple of exactly one element.
choices: tuple[Choice, ...] = dataclasses.field(default_factory=tuple)
_answer_keys_source: tuple[str, ...] = dataclasses.field(default_factory=lambda: tuple(_ALPHABET))

# Validates the freshly constructed question.
# Raises ValueError when `choices` is empty, or when there are more choices than
# entries in `_answer_keys_source` (each choice needs its own answer key).
def __post_init__(self):
if not self.choices:
raise ValueError("Choices must be provided.")
if len(self.choices) > len(self._answer_keys_source):
raise ValueError("Number of choices must be less than or equal to the number of answer keys.")

# Hash computed from the instance's field values: `dataclasses.asdict(self)` is
# passed to `hash_dict`, and the string it returns is parsed as a base-16 integer.
# NOTE(review): presumably declared on this class explicitly because a
# `@dataclass(frozen=True)` class without its own `__hash__` gets a generated one
# that replaces any inherited method -- confirm this class's decorator arguments.
def __hash__(self) -> int:
return int(hash_dict(dataclasses.asdict(self)), 16)

@classmethod
def create_question_from_value_map(
cls,
Expand All @@ -206,7 +214,7 @@ def create_question_from_value_map(
**kwargs,
) -> "MultipleChoiceQA":
"""Constructs a question from a value map."""
choices = [Choice(text, str(value)) for value, text in value_map.items()]
choices = tuple(Choice(text, str(value)) for value, text in value_map.items())

# Set default question text
kwargs.setdefault("text", f"What is this person's {attribute}?")
Expand Down

0 comments on commit 005088d

Please sign in to comment.