Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions utils/plot_helpers.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,47 @@
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.metrics import confusion_matrix

def plot_regression_line(X, y, model):
plt.figure()
plt.scatter(X, y, color="blue", label="Data")
y_pred = model.predict(X)
plt.plot(X, y_pred, color="red", label="Prediction")
plt.legend()
return plt
def plot_confusion_matrix(y_true, y_pred, labels=None, cmap="Blues", annotate=True, normalize=False, model_name="Model"):
"""
Plots a confusion matrix with optional annotations and color customization.

def plot_confusion_matrix(y_true, y_pred, labels):
cm = confusion_matrix(y_true, y_pred)
plt.figure()
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
plt.xlabel("Predicted")
plt.ylabel("Actual")
return plt
Parameters:
y_true (array-like): True labels
y_pred (array-like): Predicted labels
labels (list): List of class labels for axes
cmap (str): Color map for the plot (default: 'Blues')
annotate (bool): Whether to show cell values
normalize (bool): Normalize values to show percentages
model_name (str): Name of the model for the plot title

def plot_roc_curve(y_true, y_scores):
fpr, tpr, _ = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.plot([0, 1], [0, 1], linestyle="--")
plt.legend()
return plt
Returns:
fig (matplotlib.figure.Figure): Confusion matrix figure for Streamlit display
"""
# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred, labels=labels)
if normalize:
cm = cm.astype("float") / cm.sum(axis=1, keepdims=True)

# Set up figure
fig, ax = plt.subplots(figsize=(6, 5))

sns.heatmap(
cm,
annot=annotate,
fmt=".2f" if normalize else "d",
cmap=cmap,
xticklabels=labels if labels is not None else sorted(set(y_true)),
yticklabels=labels if labels is not None else sorted(set(y_true)),
cbar=True,
ax=ax,
linewidths=0.5,
linecolor="gray"
)

ax.set_xlabel("Predicted Labels", fontsize=11)
ax.set_ylabel("True Labels", fontsize=11)
ax.set_title(f"Confusion Matrix - {model_name}", fontsize=13)
plt.tight_layout()

return fig