Skip to content

Commit

Permalink
minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 21, 2024
1 parent 0142332 commit 35a4151
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 791 deletions.
13 changes: 13 additions & 0 deletions folktexts/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime
from functools import partial, reduce
from pathlib import Path
from contextlib import contextmanager

import numpy as np

Expand Down Expand Up @@ -89,3 +90,15 @@ def hash_function(func, length: int = 8) -> str:
def standardize_path(path: str | Path) -> str:
"""Represents a posix path as a standardized string."""
return Path(path).expanduser().resolve().as_posix()


@contextmanager
def suppress_logging(new_level):
"""Suppresses all logs of a given level within a context block."""
logger = logging.getLogger()
previous_level = logger.level
logger.setLevel(new_level)
try:
yield
finally:
logger.setLevel(previous_level)
1 change: 1 addition & 0 deletions folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@

# HINS2: Health Insurance Coverage through Private Company (Thresholded)
acs_health_ins_2_qa = MultipleChoiceQA(
column=acs_health_insurance_threshold.apply_to_column_name("HINS2"),
text="Has this person purchased health insurance directly from an insurance company?",
choices=[
Choice("Yes, this person has health insurance through a private company", 1),
Expand Down
5 changes: 2 additions & 3 deletions folktexts/acs/acs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,14 @@ def __init__(
task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task

# Keep only rows used in this task
data = task_obj.folktables_obj._preprocess(data)
if isinstance(task_obj, ACSTaskMetadata) and task_obj.folktables_obj is not None:
data = task_obj.folktables_obj._preprocess(data)

# Threshold the target column if necessary
# > use standardized ACS naming convention
if task_obj.target_threshold is not None:
thresholded_target = task_obj.get_target()
if thresholded_target not in data.columns:
# data[thresholded_target] = (data[task_obj.target] >= task_obj.target_threshold).astype(int)
import ipdb; ipdb.set_trace() # TODO: check this works!
data[thresholded_target] = task_obj.target_threshold.apply_to_column_data(data[task_obj.target])

super().__init__(
Expand Down
15 changes: 8 additions & 7 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,19 @@ def __hash__(self) -> int:

# Dummy/test ACS task to predict health insurance coverage using all other available features
acs_full_task = TaskMetadata(
name="ACSHealthInsurance-full",
name="ACSHealthInsurance-test",
description=(
"predict whether an individual has purchased health insurance directly "
"from an insurance company (as opposed to being insured through an "
"employer, Medicare, Medicaid, or any other source)"
),
features=list(
reduce(
lambda t1, t2: set(t1.features) | set(t2.features),
[acs_income_task, acs_public_coverage_task, acs_mobility_task, acs_employment_task, acs_travel_time_task],
)
),
features=list({
*acs_income_task.features,
*acs_public_coverage_task.features,
*acs_mobility_task.features,
*acs_employment_task.features,
*acs_travel_time_task.features,
}),
target="HINS2",
cols_to_text=_acs_columns_map,
target_threshold=acs_health_insurance_threshold,
Expand Down
7 changes: 5 additions & 2 deletions folktexts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

import copy
import logging
import warnings
from abc import ABC

import numpy as np
import pandas as pd

from ._utils import hash_dict, is_valid_number
from ._utils import hash_dict, is_valid_number, suppress_logging
from .task import TaskMetadata

DEFAULT_TEST_SIZE = 0.1
Expand Down Expand Up @@ -178,7 +179,9 @@ def _subsample_inplace(self, subsampling: float) -> "Dataset":

def subsample(self, subsampling: float) -> "Dataset":
"""Create a new dataset whose samples are a fraction of this dataset."""
return copy.copy(self)._subsample_inplace(subsampling)
with suppress_logging(logging.WARNING):
self_copy = copy.copy(self)
return self_copy._subsample_inplace(subsampling)

def _filter_inplace(
self,
Expand Down
2 changes: 1 addition & 1 deletion folktexts/qa_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class MultipleChoiceQA(QAInterface):

num_forward_passes: int = 1 # NOTE: overrides superclass default
choices: list[Choice] = dataclasses.field(default_factory=list)
_answer_keys_source: list[str] = list(_ALPHABET)
_answer_keys_source: list[str] = dataclasses.field(default_factory=lambda: list(_ALPHABET))

def __post_init__(self):
if not self.choices:
Expand Down
Loading

0 comments on commit 35a4151

Please sign in to comment.