Skip to content

Commit

Permalink
Add predictions / models functions for da index calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonWilde committed Nov 14, 2024
1 parent 19ecab1 commit 0eb6a40
Show file tree
Hide file tree
Showing 3 changed files with 374 additions and 317 deletions.
92 changes: 85 additions & 7 deletions daindex/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ProbabilisticModel(Protocol):
def predict_proba(self, X: np.ndarray) -> np.ndarray: ...


def obtain_det_alo_index(
def obtain_da_index(
df,
cohort_name,
scores,
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_scores(models: list[ProbabilisticModel], df: pd.DataFrame, feature_list:
return predicted_probs[:, :, 1].mean(axis=0)


def get_da_values_on_predictions(
def get_da_values_on_models(
df: pd.DataFrame,
cohort_name: str,
model: list[ProbabilisticModel] | ProbabilisticModel,
Expand Down Expand Up @@ -145,20 +145,20 @@ def get_da_values_on_predictions(
is_discrete: Flag indicating if the deterioration feature is discrete. Defaults to False.
reverse: Flag indicating if the deterioration index should be reversed. Defaults to False.
optimise_bandwidth: Flag indicating if the bandwidth of the kde should be searched. Defaults to True.
score_margin_multiplier: The multiplier to be used for the score margin, 1.0 results in no overlap, some overlap works well, essentially this smooths. Defaults to 1.5.
score_margin_multiplier: The multiplier to be used for the score margin, 1.0 results in no overlap, some overlap works well, essentially this smooths. Defaults to 5.0.
Returns:
A list of deterioration index values computed at each step,
the length of the list of scores within the score range (there is a small upper and lower bound applied)
and finally the da index value itself.
"""

det_label = det_label if det_label is not None else det_feature
det_label = det_label or det_feature
models = model if isinstance(model, list) else [model]
ret = []

min_det_v = min(np.min(df[det_feature]), np.min(df[det_feature]))
max_det_v = max(np.max(df[det_feature]), np.max(df[det_feature]))
min_det_v = np.min(df[det_feature])
max_det_v = np.max(df[det_feature])

scores = get_scores(models, df, feature_list)
step_scores = np.linspace(0, 1, steps + 1)[1:] - 1 / (2 * steps)
Expand All @@ -171,7 +171,85 @@ def get_da_values_on_predictions(
ret.append(
(
s,
*obtain_det_alo_index(
*obtain_da_index(
df,
cohort_name,
scores,
step_score_bounds,
det_feature,
det_threshold,
det_label,
det_feature_func,
det_list_lengths,
min_det_v,
max_det_v,
is_discrete,
reverse,
optimise_bandwidth,
),
)
)

return ret


def get_da_values_on_predictions(
df: pd.DataFrame,
cohort_name: str,
preds_col: str,
det_feature: str,
det_threshold: float,
det_label: str = None,
det_feature_func: Callable = None,
det_list_lengths: list[int] = [20, 10, 5],
steps: int = 50,
is_discrete: bool = False,
reverse: bool = False,
optimise_bandwidth: bool = True,
score_margin_multiplier: float = 5.0,
) -> list[tuple[int, float]]:
"""
Compute the deterioration index on the predictions of the model.
Args:
df: The DataFrame containing the data.
cohort_name: The name of the cohort being analyzed.
preds_col: The column name containing the predictions.
det_feature: The feature name representing the deterioration metric.
det_threshold: The threshold value for determining deterioration.
det_label: The label for the deterioration index. Defaults to None to be replaced by `det_feature`.
det_feature_func: The function to be used to extract the deterioration feature. Defaults to None as we assume it is in the `df`.
det_list_lengths: The list of acceptable lengths to be used for the deterioration index calculation. The first element is the preferred minimum length. Defaults to [20, 10, 5].
steps: The number of steps for the deterioration index calculation. Defaults to 50.
is_discrete: Flag indicating if the deterioration feature is discrete. Defaults to False.
reverse: Flag indicating if the deterioration index should be reversed. Defaults to False.
optimise_bandwidth: Flag indicating if the bandwidth of the kde should be searched. Defaults to True.
score_margin_multiplier: The multiplier to be used for the score margin, 1.0 results in no overlap, some overlap works well, essentially this smooths. Defaults to 1.5.
Returns:
A list of deterioration index values computed at each step,
the length of the list of scores within the score range (there is a small upper and lower bound applied)
and finally the da index value itself.
"""

det_label = det_label or det_feature
ret = []

min_det_v = np.min(df[det_feature])
max_det_v = np.max(df[det_feature])

scores = df[preds_col].values
step_scores = np.linspace(0, 1, steps + 1)[1:] - 1 / (2 * steps)

for s in step_scores:
step_score_bounds = (
max(0.0, s - (1 / (2 * steps)) * score_margin_multiplier),
min(1.0, s + (1 / (2 * steps)) * score_margin_multiplier),
)
ret.append(
(
s,
*obtain_da_index(
df,
cohort_name,
scores,
Expand Down
Loading

0 comments on commit 0eb6a40

Please sign in to comment.