added cli script for evaluating feature importance
AndreFCruz committed Jun 26, 2024
1 parent d8db5fd commit ade90a5
Showing 2 changed files with 193 additions and 92 deletions.
137 changes: 45 additions & 92 deletions folktexts/classifier.py
@@ -81,6 +81,10 @@ def __init__(
self._inference_kwargs.setdefault("context_size", DEFAULT_CONTEXT_SIZE)
self._inference_kwargs.setdefault("batch_size", DEFAULT_BATCH_SIZE)

# Fixed sklearn parameters
self.classes_ = np.array([0, 1])
self._is_fitted = False
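
These two attributes make the classifier duck-type as a fitted sklearn estimator: sklearn's `check_is_fitted` helper treats any instance attribute with a trailing underscore (such as `classes_`) as evidence that fitting has happened, which keeps sklearn tooling such as `permutation_importance` usable with an `LLMClassifier`. A minimal standalone sketch of the convention (hypothetical stub, not folktexts code):

import numpy as np
from sklearn.utils.validation import check_is_fitted

class TwoClassStub:
    """Bare-bones estimator that sklearn treats as already fitted."""

    def __init__(self):
        self.classes_ = np.array([0, 1])  # trailing underscore => "fitted"

    def fit(self, X, y):
        return self

check_is_fitted(TwoClassStub())  # passes without raising NotFittedError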

@property
def model(self) -> AutoModelForCausalLM:
return self._model
@@ -161,6 +165,11 @@ def _get_positive_class_scores(risk_scores: np.ndarray) -> np.ndarray:
else:
return risk_scores

@staticmethod
def _make_predictions_multiclass(pos_class_scores: np.ndarray) -> np.ndarray:
"""Converts positive class scores to multiclass scores."""
return np.column_stack([1 - pos_class_scores, pos_class_scores])
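
The column stack mirrors sklearn's `predict_proba` contract for binary classifiers: column 0 holds P(y=0), column 1 holds P(y=1), and each row sums to one. A quick illustration:

import numpy as np

pos_class_scores = np.array([0.1, 0.8, 0.5])
proba = np.column_stack([1 - pos_class_scores, pos_class_scores])
# proba == [[0.9, 0.1],
#           [0.2, 0.8],
#           [0.5, 0.5]]   # column 0: P(y=0); column 1: P(y=1)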

def predict(
self,
data: pd.DataFrame | Dataset,
@@ -188,63 +197,35 @@ def predict(
def _load_predictions_from_disk(
self,
predictions_save_path: str | Path,
data: pd.DataFrame | Dataset,
) -> np.ndarray | dict[str, np.ndarray] | None:
data: pd.DataFrame,
) -> np.ndarray | None:
"""Attempts to load pre-computed predictions from disk."""
predictions_save_path = Path(predictions_save_path)

# If DF, try to load predictions as a CSV file
if isinstance(data, pd.DataFrame):
predictions_save_path = predictions_save_path.with_suffix(".csv")
predictions_df = pd.read_csv(predictions_save_path, index_col=0)

# Check if index matches our current dataframe
if predictions_df.index.equals(data.index):
return predictions_df[SCORE_COL_NAME].values
else:
logging.error("Saved predictions do not match the current dataframe.")

# If Dataset, try to load predictions as a pickled dict
elif isinstance(data, Dataset):
from ._io import load_pickle
predictions_save_path = predictions_save_path.with_suffix(".pkl")
predictions_dict = load_pickle(predictions_save_path)
if not isinstance(predictions_dict, dict):
logging.error("Loaded predictions are not in the expected dictionary format.")
return None

# Check if the predictions' indices match the current dataset
if all(
preds_array.index.equals(data.get_data_split(data_type)[0].index)
for data_type, preds_array in predictions_dict.items()
):
return {
data_type: predictions_dict[data_type][SCORE_COL_NAME].values
for data_type in predictions_dict.keys()
}
else:
logging.error("Saved predictions do not match the current dataset splits.")
# Load predictions from disk
predictions_save_path = Path(predictions_save_path).with_suffix(".csv")
predictions_df = pd.read_csv(predictions_save_path, index_col=0)

# Check if index matches our current dataframe
if predictions_df.index.equals(data.index):
return predictions_df[SCORE_COL_NAME].values
else:
logging.error(f"Cannot load predictions from disk for data type {type(data)}.")
logging.error("Saved predictions do not match the current dataframe.")
return None
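
The cache hit hinges on an exact index match after a CSV round-trip. A self-contained sketch of that round-trip (the file name and the "risk_score" column stand in for the real path and `SCORE_COL_NAME`):

import pandas as pd

data = pd.DataFrame({"age": [25, 40]}, index=[10, 11])
preds = pd.DataFrame({"risk_score": [0.2, 0.7]}, index=data.index)
preds.to_csv("preds.csv", index=True)

loaded = pd.read_csv("preds.csv", index_col=0)
assert loaded.index.equals(data.index)  # cache counts only on an exact match
scores = loaded["risk_score"].values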

def predict_proba(
self,
data: pd.DataFrame | Dataset,
data: pd.DataFrame,
batch_size: int = None,
context_size: int = None,
predictions_save_path: str | Path = None,
labels: pd.Series | np.ndarray = None,
) -> np.ndarray | dict[str, np.ndarray]:
) -> np.ndarray:
"""Returns probability estimates for the given data.
Parameters
----------
data : pd.DataFrame | Dataset
The data to compute risk estimates for. Can be a pandas DataFrame or
a Dataset object. If a Dataset object is provided, will compute risk
scores for all available data splits.
data : pd.DataFrame
The DataFrame to compute risk estimates for.
batch_size : int, optional
The batch size to use when running inference.
context_size : int, optional
@@ -260,9 +241,8 @@ def predict_proba(
Returns
-------
risk_scores : np.ndarray | dict[str, np.ndarray]
The risk scores for the given data, or a dictionary of data split
name to risk scores if a Dataset object was provided.
risk_scores : np.ndarray
The risk scores for the given data.
"""
# Validate arguments
if labels is not None and predictions_save_path is None:
@@ -276,62 +256,35 @@ def predict_proba(
result = self._load_predictions_from_disk(predictions_save_path, data=data)
if result is not None:
logging.info(f"Loaded predictions from {predictions_save_path}.")
return result
return self._make_predictions_multiclass(result)
else:
logging.error(
f"Failed to load predictions from {predictions_save_path}. "
f"Re-computing predictions and overwriting local file..."
)

# Compute risk estimates
# (if local save path was not provided or does not match current data)
if isinstance(data, pd.DataFrame):
risk_scores = self._compute_risk_estimates_for_dataframe(
df=data,
batch_size=batch_size,
context_size=context_size,
)

# Save to disk if `predictions_save_path` is provided
if predictions_save_path is not None:
predictions_save_path = Path(predictions_save_path).with_suffix(".csv")

predictions_df = pd.DataFrame(risk_scores, index=data.index, columns=[SCORE_COL_NAME])
predictions_df[LABEL_COL_NAME] = labels
predictions_df.to_csv(predictions_save_path, index=True, mode="w")

return risk_scores

elif isinstance(data, Dataset):
# TODO: save predictions in a standardized file format when given a Dataset
# > we have the dataset name, splits, and seed, so we can safely save predictions for future use
scores_dict = self._compute_risk_estimates_for_dataset(
dataset=data,
batch_size=batch_size,
context_size=context_size,
)
if not isinstance(data, pd.DataFrame):
raise ValueError(
f"`data` must be a pd.DataFrame, received {type(data)} instead.")

# Save to disk if `predictions_save_path` is provided
if predictions_save_path is not None:
predictions_save_path = Path(predictions_save_path).with_suffix(".pkl")
# Compute risk estimates
risk_scores = self.compute_risk_estimates_for_dataframe(
df=data,
batch_size=batch_size,
context_size=context_size,
)

from ._io import save_pickle
logging.warning( # TODO
f"Saving dataset predictions to {predictions_save_path.as_posix()}. "
f"TODO: remove pickling functionality and save everything as csv files "
f"to re-use file-handling code from `_compute_risk_estimates_for_dataframe`."
)
save_pickle(scores_dict, predictions_save_path)
# Save to disk if `predictions_save_path` is provided
if predictions_save_path is not None:
predictions_save_path = Path(predictions_save_path).with_suffix(".csv")

return scores_dict
predictions_df = pd.DataFrame(risk_scores, index=data.index, columns=[SCORE_COL_NAME])
predictions_df[LABEL_COL_NAME] = labels
predictions_df.to_csv(predictions_save_path, index=True, mode="w")

else:
raise ValueError(
f"`data` must be a pandas DataFrame or a Dataset object; "
f"received {type(data)} instead."
)
return self._make_predictions_multiclass(risk_scores)
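
Putting the pieces together, a typical call site looks like the sketch below; `llm_clf`, `X_test`, `y_test`, and the save path are illustrative names, not part of this diff:

# Returns an (n_samples, 2) array and caches scores at the given CSV path.
proba = llm_clf.predict_proba(
    data=X_test,
    predictions_save_path="results/preds.csv",
    labels=y_test,
)
pos_scores = proba[:, 1]  # positive-class risk scores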

def _compute_risk_estimates_for_dataframe(
def compute_risk_estimates_for_dataframe(
self,
df: pd.DataFrame,
batch_size: int = None,
@@ -415,12 +368,12 @@ def _compute_risk_estimates_for_dataframe(
assert not np.isclose(risk_scores, fill_value).any()
return risk_scores

def _compute_risk_estimates_for_dataset(
def compute_risk_estimates_for_dataset(
self,
dataset: Dataset,
batch_size: int = None,
context_size: int = None,
):
) -> dict[str, np.ndarray]:
"""Computes risk estimates for each row in the dataset.
Parameters
@@ -446,7 +399,7 @@ def _compute_risk_estimates_for_dataset(
data_types["val"] = dataset.get_val()[0]

results = {
data_type: self._compute_risk_estimates_for_dataframe(
data_type: self.compute_risk_estimates_for_dataframe(
df=df,
batch_size=batch_size,
context_size=context_size,
148 changes: 148 additions & 0 deletions folktexts/cli/evaluate_llm_feature_importance.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python
import logging
from pathlib import Path
from argparse import ArgumentParser

from lightgbm import LGBMClassifier
from sklearn.inspection import permutation_importance

from folktexts.llm_utils import load_model_tokenizer, get_model_folder_path
from folktexts._io import save_pickle


# Local paths
ROOT_DIR = Path("/fast/groups/sf") # CLUSTER dir
# ROOT_DIR = Path("~").expanduser().resolve() # LOCAL dir

MODELS_DIR = ROOT_DIR / "huggingface-models"
DATA_DIR = ROOT_DIR / "data"
RESULTS_ROOT_DIR = ROOT_DIR / "folktexts-results"

# MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# MODEL_NAME = "google/gemma-2b" # NOTE: this is among the smallest models

TASK_NAME = "ACSIncome"

DEFAULT_CONTEXT_SIZE = 500
DEFAULT_BATCH_SIZE = 30
DEFAULT_SEED = 42


def setup_arg_parser() -> ArgumentParser:

# Init parser
parser = ArgumentParser(description="Evaluate LLM feature importance on a given ACS task.")

# List of command-line arguments, with type and helper string
cli_args = [
("--model", str, "[str] Model name or path to model saved on disk"),
("--task-name", str, "[str] Name of the ACS task to run the experiment on"),
("--results-dir", str, "[str] Directory under which this experiment's results will be saved"),
("--data-dir", str, "[str] Root folder to find datasets on"),
("--batch-size", int, "[int] The batch size to use for inference", False, DEFAULT_BATCH_SIZE),
("--context-size", int, "[int] The maximum context size when prompting the LLM", False, DEFAULT_CONTEXT_SIZE),
("--subsampling", float, "[float] Which fraction of the dataset to use (if omitted will use all data)", False),
("--seed", int, "[int] Random seed -- to set for reproducibility", False, DEFAULT_SEED),
]

for arg in cli_args:
parser.add_argument(
arg[0],
type=arg[1],
help=arg[2],
required=(arg[3] if len(arg) > 3 else True), # NOTE: required by default
default=(arg[4] if len(arg) > 4 else None), # default value if provided
)
return parser


def compute_feature_importance(llm_clf, dataset):

# # Optionally, fit the LLM classifier's threshold on a few data samples.
# llm_clf.fit(*dataset[:1000])

# Get train and test data
X_train, y_train = dataset.get_train()
X_test, y_test = dataset.get_test()

permutation_kwargs = dict(
X=X_test, y=y_test,
scoring="roc_auc",
n_repeats=5,
random_state=DEFAULT_SEED,
)

# Baseline: GBM feature importance
gbm_clf = LGBMClassifier()
gbm_clf.fit(X_train, y_train)

r = permutation_importance(gbm_clf, **permutation_kwargs)

save_pickle(obj=r, path=f"permutation-importance.{TASK_NAME}.GBM.pkl")

# Print results:
print("GBM feature importance:")
for i in r.importances_mean.argsort()[::-1]:
# if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
print(
f"{X_test.columns[i]:<8}"
f"{r.importances_mean[i]:.3f}"
f" +/- {r.importances_std[i]:.3f}")

# LLM feature importance
r = permutation_importance(llm_clf, **permutation_kwargs)
save_pickle(obj=r, path=f"permutation-importance.{TASK_NAME}.{llm_clf.model_name}.pkl")

print("LLM feature importance:")
for i in r.importances_mean.argsort()[::-1]:
# if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
print(
f"{X_test.columns[i]:<8}"
f"{r.importances_mean[i]:.3f}"
f" +/- {r.importances_std[i]:.3f}")

print(X_test.columns.tolist())
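
For reference, `permutation_importance` scores a feature by shuffling that column and measuring how much the model's score drops; features the model genuinely relies on produce large drops. A simplified single-column version of the idea (illustrative, not the sklearn implementation):

import numpy as np
from sklearn.metrics import roc_auc_score

def single_column_importance(clf, X, y, column, n_repeats=5, seed=42):
    """Mean/std drop in AUC when `column` is shuffled."""
    rng = np.random.default_rng(seed)
    baseline = roc_auc_score(y, clf.predict_proba(X)[:, 1])
    drops = []
    for _ in range(n_repeats):
        X_perm = X.copy()
        X_perm[column] = rng.permutation(X_perm[column].values)
        drops.append(baseline - roc_auc_score(y, clf.predict_proba(X_perm)[:, 1]))
    return np.mean(drops), np.std(drops)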


def main():
# Parse arguments from command line
args = setup_arg_parser().parse_args() # TODO: use args to set up the experiment

# Set logging level
logging.getLogger().setLevel(logging.INFO)

# Load model and tokenizer
model_folder_path = get_model_folder_path(model_name=MODEL_NAME, root_dir=MODELS_DIR)
model, tokenizer = load_model_tokenizer(model_folder_path)

results_dir = RESULTS_ROOT_DIR / Path(model_folder_path).name
results_dir.mkdir(exist_ok=True, parents=True)

# Load Task and Dataset
from folktexts.acs import ACSTaskMetadata
task = ACSTaskMetadata.get_task(TASK_NAME)

from folktexts.acs import ACSDataset
dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)

# Optionally, subsample dataset # TODO: use command line argument
# dataset.subsample(0.1)
# print(f"{dataset.subsampling=}")

# Construct LLM Classifier
from folktexts.classifier import LLMClassifier
llm_clf = LLMClassifier(
model=model,
tokenizer=tokenizer,
task=task,
batch_size=32,
)

# Compute feature importance
compute_feature_importance(llm_clf, dataset)


if __name__ == "__main__":
main()
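
For reference, an invocation consistent with the parser defined above (paths are placeholders, and note that the parsed arguments are not yet wired into the experiment, per the TODO in `main`):

python evaluate_llm_feature_importance.py \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --task-name ACSIncome \
    --results-dir results/ \
    --data-dir data/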
