Skip to content

Commit

Permalink
minor bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent d8dee1b commit c73d4e3
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions folktexts/cli/eval_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,10 @@ def compute_feature_importance(
seed=DEFAULT_SEED,
) -> dict:

# Optionally, fit the LLM classifier's threshold on a few data samples.
if fit_threshold and isinstance(fit_threshold, int):
llm_clf.fit(*dataset[:fit_threshold])

# Get train and test data
X_train, y_train = dataset.get_train()
X_test, y_test = dataset.get_test()
logging.info(f"{X_test.shape=}")

permutation_kwargs = dict(
X=X_test, y=y_test,
Expand All @@ -88,10 +85,11 @@ def compute_feature_importance(
gbm_clf = LGBMClassifier()
gbm_clf.fit(X_train, y_train)

print("Running baseline GBM feature importance...")
r = permutation_importance(gbm_clf, **permutation_kwargs)
save_pickle(
obj=r,
path=results_dir / f"permutation-importance.{llm_clf.task.name}.GBM.pkl",
path=results_dir / f"feature-importance.{llm_clf.task.name}.GBM.pkl",
)

# Print results:
Expand All @@ -102,11 +100,16 @@ def compute_feature_importance(
f"{r.importances_mean[i]:.3f}"
f" +/- {r.importances_std[i]:.3f}")

# Optionally, fit the LLM classifier's threshold on a few data samples.
if fit_threshold and isinstance(fit_threshold, int):
X_train_sample, y_train_sample = dataset.sample_n_train_examples(n=fit_threshold)
llm_clf.fit(X_train_sample, y_train_sample)

# LLM feature importance
r = permutation_importance(llm_clf, **permutation_kwargs)
save_pickle(
obj=r,
path=results_dir / f"permutation-importance.{llm_clf.task.name}.{llm_clf.model_name}.pkl",
path=results_dir / f"feature-importance.{llm_clf.task.name}.{llm_clf.model_name}.pkl",
)

print("LLM feature importance:")
Expand Down Expand Up @@ -140,7 +143,11 @@ def main():
task = ACSTaskMetadata.get_task(args.task_name)

from folktexts.acs import ACSDataset
dataset = ACSDataset.make_from_task(task=task, cache_dir=args.data_dir)
dataset = ACSDataset.make_from_task(
task=task,
cache_dir=args.data_dir,
seed=args.seed,
)

# Optionally, subsample dataset
if args.subsampling:
Expand Down

0 comments on commit c73d4e3

Please sign in to comment.