Skip to content

Commit

Permalink
Working simplified ML module
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonWilde committed Oct 25, 2024
1 parent b13ee52 commit a5776a4
Show file tree
Hide file tree
Showing 5 changed files with 781 additions and 655 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: check-toml
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.8
rev: v0.6.9
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
193 changes: 193 additions & 0 deletions daindex/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import warnings
from typing import Callable, Protocol, runtime_checkable

import numpy as np
import pandas as pd

from daindex.core import deterioration_index


@runtime_checkable
class ProbabilisticModel(Protocol):
def predict_proba(self, X: np.ndarray) -> np.ndarray: ...


def obtain_det_alo_index(
df,
cohort_name,
scores,
score_bounds: tuple[float, float],
det_feature: str,
det_threshold: float,
det_label: str,
det_feature_func: Callable,
det_list_lengths: list[int],
min_det_v: float,
max_det_v: float,
is_discrete: bool,
reverse: bool,
optimise_bandwidth: bool,
) -> tuple[int, float]:
"""
Calculates the Deterioration Allocation Index (DAI) for a given cohort.
Args:
df: The DataFrame containing the data.
cohort_name: The name of the cohort.
scores: The list of scores corresponding to the DataFrame rows.
score_bounds: The lower and upper bounds for the scores to be considered.
det_feature: The feature to be used for DAI calculation.
det_threshold: The threshold value for the deterioration index.
det_label: The label for the deterioration index.
det_feature_func: A function to apply to each row to extract the feature value.
det_list_lengths: A list of acceptable lengths for the det_list, in descending order.
min_det_v: The minimum value for the deterioration index.
max_det_v: The maximum value for the deterioration index.
is_discrete: Whether the feature is discrete.
reverse: Whether to reverse the order of the feature values.
optimise_bandwidth: Whether to search for the optimal bandwidth.
Returns:
A tuple containing:
- int: The length of the det_list.
- float: The k-step value from the deterioration index calculation.
Raises:
UserWarning: If the number of samples is sub-optimal or insufficient for DAI calculation.
"""
# sort det_list_lengths into descending order
det_list_lengths.sort(reverse=True)
lb, ub = score_bounds

det_list = []
i = 0
for idx, r in df.iterrows():
p = scores[i]
if lb <= p <= ub:
if det_feature_func is not None:
det_list.append(det_feature_func(r))
else:
det_list.append(r[det_feature])
i += 1
for det_list_length in det_list_lengths:
if len(det_list) >= det_list_length:
if det_list_length != det_list_lengths[0]:
warnings.warn(
f"Sub-optimal number of samples for DAI calculation, {len(det_list)} is acceptable but {det_list_lengths[0]} is preferred."
)
break
else:
warnings.warn(f"Insufficient number of samples for DAI calculation, {len(det_list)} < {det_list_length}.")
return len(det_list), 0

X = np.array(det_list)
di_ret = deterioration_index(
X[~np.isnan(X)].reshape(-1, 1),
min_det_v,
max_det_v,
threshold=det_threshold,
plot_title=f"{cohort_name} | {det_label}",
is_discrete=is_discrete,
reverse=reverse,
optimise_bandwidth=optimise_bandwidth,
do_plot=False,
)
return len(det_list), di_ret["k-step"]


def get_scores(models: list[ProbabilisticModel], df: pd.DataFrame, feature_list: list[str]) -> np.ndarray:
"""
Computes the mean predicted probabilities for a list of models.
Args:
models: A list of trained model objects that have a `predict_proba` method.
df: A DataFrame containing the data to be used for predictions.
feature_list: A list of column names in `df` to be used as features for prediction.
Returns:
An array of mean predicted probabilities for the positive class.
"""
predicted_probs = np.array([m.predict_proba(df[feature_list].to_numpy()) for m in models])
return predicted_probs[:, :, 1].mean(axis=0)


def get_da_values_on_predictions(
df: pd.DataFrame,
cohort_name: str,
model: list[ProbabilisticModel] | ProbabilisticModel,
feature_list: list[str],
det_feature: str,
det_threshold: float,
det_label: str = None,
det_feature_func: Callable = None,
det_list_lengths: list[int] = [20, 10, 5],
steps: int = 50,
is_discrete: bool = False,
reverse: bool = False,
optimise_bandwidth: bool = True,
score_margin_multiplier: float = 5.0,
) -> list[tuple[int, float]]:
"""
Compute the deterioration index on the predictions of the model.
Args:
df: The DataFrame containing the data.
cohort_name: The name of the cohort being analyzed.
model: The predictive model or list of models to be used.
feature_list: List of feature names to be used for prediction.
det_feature: The feature name representing the deterioration metric.
det_threshold: The threshold value for determining deterioration.
det_label: The label for the deterioration index. Defaults to None to be replaced by `det_feature`.
det_feature_func: The function to be used to extract the deterioration feature. Defaults to None as we assume it is in the `df`.
det_list_lengths: The list of acceptable lengths to be used for the deterioration index calculation. The first element is the preferred minimum length. Defaults to [20, 10, 5].
steps: The number of steps for the deterioration index calculation. Defaults to 50.
is_discrete: Flag indicating if the deterioration feature is discrete. Defaults to False.
reverse: Flag indicating if the deterioration index should be reversed. Defaults to False.
optimise_bandwidth: Flag indicating if the bandwidth of the kde should be searched. Defaults to True.
score_margin_multiplier: The multiplier to be used for the score margin, 1.0 results in no overlap, some overlap works well, essentially this smooths. Defaults to 1.5.
Returns:
A list of deterioration index values computed at each step,
the length of the list of scores within the score range (there is a small upper and lower bound applied)
and finally the da index value itself.
"""

det_label = det_label if det_label is not None else det_feature
models = model if isinstance(model, list) else [model]
ret = []

min_det_v = min(np.min(df[det_feature]), np.min(df[det_feature]))
max_det_v = max(np.max(df[det_feature]), np.max(df[det_feature]))

scores = get_scores(models, df, feature_list)
step_scores = np.linspace(0, 1, steps + 1)[1:] - 1 / (2 * steps)

for s in step_scores:
step_score_bounds = (
max(0.0, s - (1 / (2 * steps)) * score_margin_multiplier),
min(1.0, s + (1 / (2 * steps)) * score_margin_multiplier),
)
ret.append(
(
s,
*obtain_det_alo_index(
df,
cohort_name,
scores,
step_score_bounds,
det_feature,
det_threshold,
det_label,
det_feature_func,
det_list_lengths,
min_det_v,
max_det_v,
is_discrete,
reverse,
optimise_bandwidth,
),
)
)

return ret
21 changes: 12 additions & 9 deletions daindex/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ def vis_DA_indices(data: np.ndarray, label: str) -> tuple[float, float, np.ndarr
def calc_ratios(a1: float, a2: float, da1: float, da2: float) -> dict[str, float | str]:
ratio = (a2 - a1) / a1
decision_ratio = ((da2 - da1) / da1) if da1 != 0 else "N/A"
return {"AUC Ratio": ratio, "Decision AUC Ratio": decision_ratio}
return {"Full": ratio, "Decision": decision_ratio}


def get_DA_curve(
def get_da_curve(
d1: np.ndarray,
d2: np.ndarray,
g1_label: str,
Expand All @@ -139,7 +139,7 @@ def get_DA_curve(
config: dict = {},
decision_boundary: float = 0.5,
verbose: bool = True,
) -> dict[str, float | str]:
) -> tuple[dict[str, float | str], plt.Figure]:
"""
do DA curve visualisation
"""
Expand Down Expand Up @@ -171,8 +171,8 @@ def get_DA_curve(
ratios = calc_ratios(a1, a2, da1, da2)

if verbose:
print("AUC\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(a1, a2, ratios["AUC"]))
print("Decision AUC\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(da1, da2, ratios["Decision AUC"]))
print("Ratio\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(a1, a2, ratios["Full"]))
print("Decision Ratio\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(da1, da2, ratios["Decision"]))

# figure finishing up
plt.xlabel(allocation_label, fontsize=font_size)
Expand All @@ -184,15 +184,18 @@ def get_DA_curve(

plt.legend(fontsize=font_size, loc="best")

return ratios
fig = plt.gcf()
plt.close()

return ratios, fig


def calc_area(data: np.ndarray) -> tuple[float, float]:
wd = data[np.where(data[:, 1] > 0)][:, [0, 2, 1]]
return area_under_curve(wd)


def get_DA_ratios(
def get_da_ratios(
d1: np.ndarray, d2: np.ndarray, verbose: bool = False, return_areas: bool = False
) -> dict[str, float | str]:
"""
Expand All @@ -215,7 +218,7 @@ def get_DA_ratios(
ratios.update({"AUC1": a1, "AUC2": a2, "Decision AUC1": da1, "Decision AUC2": da2})

if verbose:
print("AUC\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(a1, a2, ratios["AUC"]))
print("Decision AUC\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(da1, da2, ratios["Decision AUC"]))
print("Ratio\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(a1, a2, ratios["Full"]))
print("Decision Ratio\t{0:.6f}\t{1:.6f}\t{2:.2%}".format(da1, da2, ratios["Decision"]))

return ratios
Loading

0 comments on commit a5776a4

Please sign in to comment.