Commit

saving json next to each feature importance pickle
AndreFCruz committed Jun 29, 2024
1 parent 437feb5 commit f6fba04
Showing 6 changed files with 30 additions and 13 deletions.
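
In short: each permutation-importance result is now saved twice, as a pickle of the raw sklearn result object and as a JSON file next to it that maps each feature to the mean and standard deviation of its importance. A hypothetical sketch of the JSON structure (feature names and values are illustrative, not taken from this commit):

{
    "AGEP": {"imp_mean": 0.112, "imp_std": 0.004},
    "SCHL": {"imp_mean": 0.087, "imp_std": 0.006}
}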
2 changes: 2 additions & 0 deletions docs/conf.py
@@ -14,8 +14,10 @@
# Import package version programmatically
import sys
from pathlib import Path

sys.path.insert(0, Path(__file__).parent.as_posix())
from folktexts._version import __version__

release = __version__
version = __version__

2 changes: 1 addition & 1 deletion folktexts/cli/_utils.py
@@ -1,8 +1,8 @@
"""Utils for the folktexts cmd-line interface.
"""
from __future__ import annotations

import logging
from pathlib import Path


31 changes: 22 additions & 9 deletions folktexts/cli/eval_feature_importance.py
@@ -1,11 +1,12 @@
#!/usr/bin/env python
import logging
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path

from sklearn.inspection import permutation_importance

from folktexts._io import save_json, save_pickle
from folktexts.classifier import LLMClassifier
from folktexts.dataset import Dataset
from folktexts.llm_utils import get_model_folder_path, load_model_tokenizer
@@ -59,6 +60,16 @@ def setup_arg_parser() -> ArgumentParser:
    return parser


def parse_feature_importance(results: dict, columns: list[str]) -> dict:
    """Parse the results dictionary of sklearn's permutation_importance."""
    parsed_r = defaultdict(dict)
    for idx, col in enumerate(columns):
        parsed_r[col]["imp_mean"] = results.importances_mean[idx]
        parsed_r[col]["imp_std"] = results.importances_std[idx]

    return parsed_r


def compute_feature_importance(
    llm_clf: LLMClassifier,
    dataset: Dataset,
@@ -87,10 +98,11 @@ def compute_feature_importance(
    gbm_clf.fit(X_train, y_train)

    r = permutation_importance(gbm_clf, **permutation_kwargs)
    gbm_imp_file_path = results_dir / f"feature-importance.{llm_clf.task.name}.GBM.pkl"
    save_pickle(obj=r, path=gbm_imp_file_path)
    save_json(
        parse_feature_importance(results=r, columns=X_test.columns),
        path=gbm_imp_file_path.with_suffix(".json"))

    # Print results:
    print("GBM feature importance:")
@@ -107,10 +119,11 @@

    # LLM feature importance
    r = permutation_importance(llm_clf, **permutation_kwargs)
    llm_imp_file_path = results_dir / f"feature-importance.{llm_clf.task.name}.{llm_clf.model_name}.pkl"
    save_pickle(obj=r, path=llm_imp_file_path)
    save_json(
        parse_feature_importance(results=r, columns=X_test.columns),
        path=llm_imp_file_path.with_suffix(".json"))

print("LLM feature importance:")
for i in r.importances_mean.argsort()[::-1]:
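
For reference, a minimal standalone sketch of the pattern above, saving both formats side by side with a plain sklearn model (the synthetic data, model choice, and file name are illustrative assumptions, not the repository's code):

import json
import pickle
from pathlib import Path

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.inspection import permutation_importance

# Fit a small model on synthetic data.
X, y = make_classification(n_samples=200, n_features=4, random_state=42)
clf = GradientBoostingClassifier().fit(X, y)

# The result bunch exposes .importances_mean and .importances_std arrays.
r = permutation_importance(clf, X, y, n_repeats=5, random_state=42)

# Same parsing logic as parse_feature_importance above.
columns = [f"feat_{i}" for i in range(X.shape[1])]
parsed = {
    col: {"imp_mean": float(r.importances_mean[i]),
          "imp_std": float(r.importances_std[i])}
    for i, col in enumerate(columns)
}

# Pickle the raw result; write the JSON summary next to it.
pkl_path = Path("feature-importance.example.GBM.pkl")
pkl_path.write_bytes(pickle.dumps(r))
pkl_path.with_suffix(".json").write_text(json.dumps(parsed, indent=2))
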
4 changes: 3 additions & 1 deletion folktexts/cli/run_acs_benchmark.py
@@ -7,6 +7,8 @@
from argparse import ArgumentParser
from pathlib import Path

DEFAULT_ACS_TASK = "ACSIncome"

DEFAULT_BATCH_SIZE = 30
DEFAULT_CONTEXT_SIZE = 500
DEFAULT_SEED = 42
@@ -24,9 +26,9 @@ def list_of_strings(arg):
# List of command-line arguments, with type and helper string
cli_args = [
("--model", str, "[str] Model name or path to model saved on disk"),
("--task", str, "[str] Name of the ACS task to run the experiment on"),
("--results-dir", str, "[str] Directory under which this experiment's results will be saved"),
("--data-dir", str, "[str] Root folder to find datasets on"),
("--task", str, "[str] Name of the ACS task to run the experiment on", DEFAULT_ACS_TASK),
("--few-shot", int, "[int] Use few-shot prompting with the given number of shots", False),
("--batch-size", int, "[int] The batch size to use for inference", False, DEFAULT_BATCH_SIZE),
("--context-size", int, "[int] The maximum context size when prompting the LLM", False, DEFAULT_CONTEXT_SIZE),
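
The cli_args tuples above carry a flag name, a type, and a help string, optionally followed by a required marker and/or a default value. One way such tuples could be unpacked into argparse (this helper and its convention are illustrative assumptions, not the repository's actual code):

from argparse import ArgumentParser

def add_cli_args(parser: ArgumentParser, cli_args: list) -> None:
    # Assumed convention: a boolean extra marks the argument as required;
    # any other extra is treated as its default value.
    for flag, arg_type, help_str, *extras in cli_args:
        kwargs = {"type": arg_type, "help": help_str}
        for extra in extras:
            if isinstance(extra, bool):
                kwargs["required"] = extra
            else:
                kwargs["default"] = extra
        parser.add_argument(flag, **kwargs)

parser = ArgumentParser()
add_cli_args(parser, [("--task", str, "Name of the ACS task", "ACSIncome")])
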
3 changes: 1 addition & 2 deletions tests/conftest.py
@@ -2,9 +2,8 @@
"""
from __future__ import annotations

import numpy as np
import pytest

TEST_CAUSAL_LMS = [
"hf-internal-testing/tiny-random-gpt2",
1 change: 1 addition & 0 deletions tests/test_load_huggingface_model.py
@@ -2,6 +2,7 @@
"""

from transformers import PreTrainedModel, PreTrainedTokenizerBase

from folktexts.llm_utils import load_model_tokenizer


