Skip to content

Commit

Permalink
minor bug fix on shape output of predict_proba
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent b385138 commit 2104021
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
1 change: 1 addition & 0 deletions folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def run(self, results_root_dir: str | Path, fit_threshold: int | bool = 0) -> fl
predictions_save_path=test_predictions_save_path,
labels=y_test, # used only to save alongside predictions in disk
)
self._y_test_scores = self.llm_clf._get_positive_class_scores(self._y_test_scores)

# If requested, fit the threshold on a small portion of the train set
if fit_threshold:
Expand Down
16 changes: 5 additions & 11 deletions folktexts/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def fit(self, X, y, *, false_pos_cost=1.0, false_neg_cost=1.0, **kwargs):
"""Uses the provided data sample to fit the prediction threshold."""

# Compute risk estimates for the data
y_pred_scores = self.predict_proba(X, **kwargs)
if len(y_pred_scores.shape) > 1:
y_pred_scores = y_pred_scores[:, -1]
y_pred_scores = self._get_positive_class_scores(
self.predict_proba(X, **kwargs)
)

# Compute the best threshold for the given data
self.threshold = compute_best_threshold(
Expand Down Expand Up @@ -172,7 +172,7 @@ def _make_predictions_multiclass(pos_class_scores: np.ndarray) -> np.ndarray:

def predict(
self,
data: pd.DataFrame | Dataset,
data: pd.DataFrame,
batch_size: int = None,
context_size: int = None,
predictions_save_path: str | Path = None,
Expand All @@ -186,13 +186,7 @@ def predict(
predictions_save_path=predictions_save_path,
labels=labels,
)
if isinstance(risk_scores, dict):
return {
data_type: (self._get_positive_class_scores(data_scores) >= self.threshold).astype(int)
for data_type, data_scores in risk_scores.items()
}
else:
return (self._get_positive_class_scores(risk_scores) >= self.threshold).astype(int)
return (self._get_positive_class_scores(risk_scores) >= self.threshold).astype(int)

def _load_predictions_from_disk(
self,
Expand Down
1 change: 1 addition & 0 deletions folktexts/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def evaluate_predictions(
y_pred_binary = (y_pred_scores >= threshold).astype(int)

# Evaluate binary predictions
import ipdb; ipdb.set_trace()
results.update(evaluate_binary_predictions(y_true, y_pred_binary))

# Add loss functions as proxies for calibration
Expand Down
6 changes: 5 additions & 1 deletion folktexts/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def create_task_with_feature_subset(self, feature_subset: Iterable[str]):

# Check if features are a subset of the original features
if not set(feature_subset).issubset(self.features):
raise ValueError("`feature_subset` must be a subset of the original features.")
raise ValueError(
f"`feature_subset` must be a subset of the original features; "
f"following features are not in the original set: "
f"{set(self.features) - set(feature_subset)}"
)

# Return new TaskMetadata object
return dataclasses.replace(
Expand Down

0 comments on commit 2104021

Please sign in to comment.