import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(
    y_true,
    y_pred,
    labels=None,
    cmap="Blues",
    annotate=True,
    normalize=False,
    model_name="Model",
):
    """
    Plot a confusion matrix as a heatmap with optional annotations.

    Parameters:
        y_true (array-like): True labels.
        y_pred (array-like): Predicted labels.
        labels (list): Class labels, in the order they should appear on both
            axes. When None, the sorted union of the labels found in
            ``y_true`` and ``y_pred`` is used, which is the same set
            ``confusion_matrix`` itself falls back to.
        cmap (str): Color map for the heatmap (default: 'Blues').
        annotate (bool): Whether to write the value inside each cell.
        normalize (bool): If True, divide each row by its total so cells show
            the fraction of each true class (formatted with two decimals).
            Rows for classes with no true samples are left as zeros.
        model_name (str): Name of the model, shown in the plot title.

    Returns:
        fig (matplotlib.figure.Figure): The figure holding the heatmap, e.g.
            for Streamlit display.
    """
    # Resolve the class list once so the matrix and the tick labels always
    # describe the same classes in the same order, even when a class shows up
    # only in the predictions.
    if labels is None:
        labels = unique_labels(y_true, y_pred)
    tick_labels = list(labels)

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class listed in `labels` but absent from y_true has a zero row
        # total; dividing by 1 instead keeps that row at 0 rather than NaN.
        row_totals[row_totals == 0] = 1
        cm = cm.astype("float") / row_totals

    fig, ax = plt.subplots(figsize=(6, 5))

    sns.heatmap(
        cm,
        annot=annotate,
        fmt=".2f" if normalize else "d",
        cmap=cmap,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        cbar=True,
        ax=ax,
        linewidths=0.5,
        linecolor="gray",
    )

    ax.set_xlabel("Predicted Labels", fontsize=11)
    ax.set_ylabel("True Labels", fontsize=11)
    ax.set_title(f"Confusion Matrix - {model_name}", fontsize=13)
    # Lay out this figure explicitly instead of whichever figure pyplot
    # currently considers active.
    fig.tight_layout()

    return fig