Skip to content

Commit

Permalink
CLI script for evaluating feature importance
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent ade90a5 commit d804223
Showing 1 changed file with 59 additions and 35 deletions.
94 changes: 59 additions & 35 deletions folktexts/cli/evaluate_llm_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@
from lightgbm import LGBMClassifier
from sklearn.inspection import permutation_importance

from folktexts.classifier import LLMClassifier
from folktexts.dataset import Dataset
from folktexts.llm_utils import load_model_tokenizer, get_model_folder_path
from folktexts._io import save_pickle


# Local paths
ROOT_DIR = Path("/fast/groups/sf") # CLUSTER dir
# ROOT_DIR = Path("~").expanduser().resolve() # LOCAL dir
# DEFAULT_ROOT_DIR = Path("/fast/groups/sf") # CLUSTER dir
DEFAULT_ROOT_DIR = Path("~").expanduser().resolve() # LOCAL dir

MODELS_DIR = ROOT_DIR / "huggingface-models"
DATA_DIR = ROOT_DIR / "data"
RESULTS_ROOT_DIR = ROOT_DIR / "folktexts-results"
DEFAULT_MODELS_DIR = DEFAULT_ROOT_DIR / "huggingface-models"
DEFAULT_DATA_DIR = DEFAULT_ROOT_DIR / "data"
DEFAULT_RESULTS_DIR = Path("folktexts-results")

# MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# MODEL_NAME = "google/gemma-2b" # NOTE: this is among the smallest models

TASK_NAME = "ACSIncome"
DEFAULT_TASK_NAME = "ACSIncome"

DEFAULT_CONTEXT_SIZE = 500
DEFAULT_BATCH_SIZE = 30
DEFAULT_SEED = 42

DEFAULT_PERMUTATION_REPEATS = 5


def setup_arg_parser() -> ArgumentParser:

Expand All @@ -37,12 +37,15 @@ def setup_arg_parser() -> ArgumentParser:
# List of command-line arguments, with type and helper string
cli_args = [
("--model", str, "[str] Model name or path to model saved on disk"),
("--task-name", str, "[str] Name of the ACS task to run the experiment on"),
("--results-dir", str, "[str] Directory under which this experiment's results will be saved"),
("--data-dir", str, "[str] Root folder to find datasets on"),
("--task-name", str, "[str] Name of the ACS task to run the experiment on", False, DEFAULT_TASK_NAME),
("--results-dir", str, "[str] Directory under which this experiment's results will be saved", False, DEFAULT_RESULTS_DIR),
("--data-dir", str, "[str] Root folder to find datasets on", False, DEFAULT_DATA_DIR),
("--models-dir", str, "[str] Root folder to find huggingface models on", False, DEFAULT_MODELS_DIR),
("--scorer", str, "[str] Name of the scorer to use for evaluation", False, "roc_auc"),
("--batch-size", int, "[int] The batch size to use for inference", False, DEFAULT_BATCH_SIZE),
("--context-size", int, "[int] The maximum context size when prompting the LLM", False, DEFAULT_CONTEXT_SIZE),
("--subsampling", float, "[float] Which fraction of the dataset to use (if omitted will use all data)", False),
("--fit-threshold", int, "[int] Whether to fit the prediction threshold, and on how many samples", False),
("--seed", int, "[int] Random seed -- to set for reproducibility", False, DEFAULT_SEED),
]

Expand All @@ -57,46 +60,57 @@ def setup_arg_parser() -> ArgumentParser:
return parser


def compute_feature_importance(llm_clf, dataset):
def compute_feature_importance(
llm_clf: LLMClassifier,
dataset: Dataset,
scorer: str,
results_dir: Path,
fit_threshold=None,
seed=DEFAULT_SEED,
) -> dict:

# # Optionally, fit the LLM classifier's threshold on a few data samples.
# llm_clf.fit(*dataset[:1000])
# Optionally, fit the LLM classifier's threshold on a few data samples.
if fit_threshold and isinstance(fit_threshold, int):
llm_clf.fit(*dataset[:fit_threshold])

# Get train and test data
X_train, y_train = dataset.get_train()
X_test, y_test = dataset.get_test()

permutation_kwargs = dict(
X=X_test, y=y_test,
scoring="roc_auc",
n_repeats=5,
random_state=SEED,
scoring=scorer,
n_repeats=DEFAULT_PERMUTATION_REPEATS,
random_state=seed,
)

# Baseline: GBM feature importance
gbm_clf = LGBMClassifier()
gbm_clf.fit(X_train, y_train)

r = permutation_importance(gbm_clf, **permutation_kwargs)

save_pickle(obj=r, path=f"permutation-importance.{TASK_NAME}.GBM.pkl")
save_pickle(
obj=r,
path=results_dir / f"permutation-importance.{llm_clf.task.name}.GBM.pkl",
)

# Print results:
print("GBM feature importance:")
for i in r.importances_mean.argsort()[::-1]:
# if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
print(
f"{X_test.columns[i]:<8}"
f"{r.importances_mean[i]:.3f}"
f" +/- {r.importances_std[i]:.3f}")

# LLM feature importance
r = permutation_importance(llm_clf, **permutation_kwargs)
save_pickle(obj=r, path=f"permutation-importance.{TASK_NAME}.{llm_clf.model_name}.pkl")
save_pickle(
obj=r,
path=results_dir / f"permutation-importance.{llm_clf.task.name}.{llm_clf.model_name}.pkl",
)

print("LLM feature importance:")
for i in r.importances_mean.argsort()[::-1]:
# if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
print(
f"{X_test.columns[i]:<8}"
f"{r.importances_mean[i]:.3f}"
Expand All @@ -107,41 +121,51 @@ def compute_feature_importance(llm_clf, dataset):

def main():
# Parse arguments from command line
args = setup_arg_parser().parse_args() # TODO: use args to set up the experiment
args = setup_arg_parser().parse_args()

# Set logging level
logging.getLogger().setLevel(logging.INFO)

# Load model and tokenizer
model_folder_path = get_model_folder_path(model_name=MODEL_NAME, root_dir=MODELS_DIR)
model_folder_path = get_model_folder_path(model_name=args.model, root_dir=args.models_dir)
model, tokenizer = load_model_tokenizer(model_folder_path)

results_dir = RESULTS_ROOT_DIR / Path(model_folder_path).name
# Set-up results directory
results_dir = Path(args.results_dir) / Path(model_folder_path).name
results_dir.mkdir(exist_ok=True, parents=True)
results_dir
logging.info(f"Saving results to {results_dir.as_posix()}")

# Load Task and Dataset
from folktexts.acs import ACSTaskMetadata
task = ACSTaskMetadata.get_task(TASK_NAME)
task = ACSTaskMetadata.get_task(args.task_name)

from folktexts.acs import ACSDataset
dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)
dataset = ACSDataset.make_from_task(task=task, cache_dir=args.data_dir)

# Optionally, subsample dataset # TODO: use command line argument
# dataset.subsample(0.1)
# print(f"{dataset.subsampling=}")
# Optionally, subsample dataset
if args.subsampling:
dataset.subsample(args.subsampling) # subsample in-place
logging.info(f"{dataset.subsampling=}")

# Construct LLM Classifier
from folktexts.classifier import LLMClassifier
llm_clf = LLMClassifier(
model=model,
tokenizer=tokenizer,
task=task,
batch_size=32,
batch_size=args.batch_size,
context_size=args.context_size,
)

# Compute feature importance
compute_feature_importance(llm_clf, tokenizer, dataset)
compute_feature_importance(
llm_clf,
dataset=dataset,
scorer=args.scorer,
results_dir=results_dir,
fit_threshold=args.fit_threshold,
seed=args.seed,
)


if __name__ == "__main__":
Expand Down

0 comments on commit d804223

Please sign in to comment.