Skip to content

Commit

Permalink
minor bug fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent c73d4e3 commit d2d5caa
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 22 deletions.
4 changes: 2 additions & 2 deletions folktexts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._version import __version__, __version_info__
from .task import TaskMetadata
from .acs import ACSDataset, ACSTaskMetadata
from .benchmark import BenchmarkConfig, CalibrationBenchmark
from .classifier import LLMClassifier
from .acs import ACSDataset, ACSTaskMetadata
from .task import TaskMetadata
2 changes: 1 addition & 1 deletion folktexts/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import json
import logging
import operator
from contextlib import contextmanager
from datetime import datetime
from functools import partial, reduce
from pathlib import Path
from contextlib import contextmanager

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion folktexts/acs/acs_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from ._utils import parse_pums_code
from .acs_thresholds import (
acs_employment_threshold,
acs_health_insurance_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_public_coverage_threshold,
acs_travel_time_threshold,
acs_health_insurance_threshold,
)

# Path to ACS codebook files
Expand Down
2 changes: 1 addition & 1 deletion folktexts/acs/acs_questions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""A collection of instantiated ACS column objects and ACS tasks."""
from __future__ import annotations

from folktexts.col_to_text import ColumnToText
from folktexts.qa_interface import DirectNumericQA as _DirectNumericQA
from folktexts.qa_interface import MultipleChoiceQA as _MultipleChoiceQA
from folktexts.col_to_text import ColumnToText

from . import acs_columns
from .acs_tasks import _acs_columns_map
Expand Down
4 changes: 2 additions & 2 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from . import acs_columns
from .acs_thresholds import (
acs_employment_threshold,
acs_health_insurance_threshold,
acs_income_poverty_ratio_threshold,
acs_income_threshold,
acs_mobility_threshold,
acs_public_coverage_threshold,
acs_travel_time_threshold,
acs_income_poverty_ratio_threshold,
acs_health_insurance_threshold,
)

# Map of ACS column names to ColumnToText objects
Expand Down
16 changes: 9 additions & 7 deletions folktexts/cli/eval_feature_importance.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
#!/usr/bin/env python
import logging
from pathlib import Path
from argparse import ArgumentParser
from pathlib import Path

from lightgbm import LGBMClassifier
from sklearn.inspection import permutation_importance

from folktexts._io import save_pickle
from folktexts.classifier import LLMClassifier
from folktexts.dataset import Dataset
from folktexts.llm_utils import load_model_tokenizer, get_model_folder_path
from folktexts._io import save_pickle

from folktexts.llm_utils import get_model_folder_path, load_model_tokenizer

# Local paths
# DEFAULT_ROOT_DIR = Path("/fast/groups/sf") # CLUSTER dir
Expand Down Expand Up @@ -82,10 +80,14 @@ def compute_feature_importance(
)

# Baseline: GBM feature importance
gbm_clf = LGBMClassifier()
print("Running baseline GBM feature importance...")
# TODO! lightgbm seems to be failing on M1 Macs - check!
from xgboost import XGBClassifier
gbm_clf = XGBClassifier()
# from lightgbm import LGBMClassifier
# gbm_clf = LGBMClassifier()
gbm_clf.fit(X_train, y_train)

print("Running baseline GBM feature importance...")
r = permutation_importance(gbm_clf, **permutation_kwargs)
save_pickle(
obj=r,
Expand Down
16 changes: 8 additions & 8 deletions folktexts/cli/launch_acs_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env python3
"""Helper script to launch htcondor jobs for all experiments.
"""Launch htcondor jobs for all ACS benchmark experiments.
"""
import argparse
import logging
import math
from pathlib import Path
from pprint import pprint
Expand Down Expand Up @@ -31,19 +32,18 @@

# Models save directory
MODELS_DIR = ROOT_DIR / "huggingface-models"
# MODELS_DIR = ROOT_DIR / "data" / "huggingface-models" # on local machine

# Path to the executable script to run
EXECUTABLE_PATH = Path(__file__).parent.resolve() / "run_acs_benchmark.py"
# EXECUTABLE_PATH = Path(__file__).parent.resolve() / "run_acs_benchmark.py"
EXECUTABLE_PATH = Path(__file__).parent.resolve() / "eval_feature_importance.py"
# TODO ^ pass executable path as cmd line arg
logging.warning(f"Using executable path: {EXECUTABLE_PATH}")

##################
# Global configs #
##################
BATCH_SIZE = 15
CONTEXT_SIZE = 700
CORRECT_ORDER_BIAS = True

VERBOSE = True
BATCH_SIZE = 20
CONTEXT_SIZE = 600

JOB_CPUS = 4
JOB_MEMORY_GB = 60
Expand Down
1 change: 1 addition & 0 deletions requirements/cluster.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
htcondor
xgboost

0 comments on commit d2d5caa

Please sign in to comment.