Skip to content

Commit

Permalink
adapt GoF implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Dec 16, 2024
1 parent f2e6257 commit 1d55ead
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,11 @@ def goodness_of_fit_history(model):
return infonce_to_goodness_of_fit(infonce, model)


def infonce_to_goodness_of_fit(
        infonce: Union[float, np.ndarray],
        model: Optional["cebra_sklearn_cebra.CEBRA"] = None,
        batch_size: Optional[int] = None,
        num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
    """Convert an InfoNCE loss into the goodness of fit metric.

    The goodness of fit ranges from 0 (lowest meaningful value)
    to a positive number with the unit "bits", the higher the
    better.

    The conversion between the InfoNCE loss and the goodness of fit
    computed by this function is

    .. math::

        S = \\log N - \\text{InfoNCE}

    where :math:`N` is ``batch_size * num_sessions``. The difference is
    computed in nats and converted to bits before it is returned.

    To use this function, either provide a trained CEBRA model or the
    batch size and number of sessions.

    Args:
        infonce: The InfoNCE loss, either a single value or an array of
            values.
        model: The trained CEBRA model. Its ``batch_size`` and
            ``num_sessions_`` attributes are used; a ``num_sessions_``
            of ``None`` is treated as a single session.
        batch_size: The batch size used to train the model. Only used
            when ``model`` is not given.
        num_sessions: The number of sessions used to train the model.
            Only used when ``model`` is not given.

    Returns:
        The goodness of fit, measured in bits. The result has the same
        shape as ``infonce``.

    Raises:
        RuntimeError: If the provided model is not fit to data.
        ValueError: If both ``model`` and ``batch_size`` or
            ``num_sessions`` are provided, if ``model`` was trained with
            ``batch_size = None``, or if ``model`` is not provided and
            ``batch_size`` or ``num_sessions`` is missing.
    """
    if model is not None:
        if batch_size is not None or num_sessions is not None:
            raise ValueError(
                "batch_size and num_sessions should not be provided if "
                "model is provided.")
        if not hasattr(model, "state_dict_"):
            raise RuntimeError("Fit the CEBRA model first.")
        if model.batch_size is None:
            raise ValueError(
                "Computing the goodness of fit is not yet supported for "
                "models trained on the full dataset (batchsize = None). ")
        batch_size = model.batch_size
        num_sessions = model.num_sessions_
        if num_sessions is None:
            # A missing session count is treated as a single session.
            num_sessions = 1
    elif batch_size is None or num_sessions is None:
        raise ValueError(
            "batch_size and num_sessions should be provided if model is "
            "not provided.")

    # Use the resolved local values rather than the attributes of ``model``,
    # which is None when batch_size/num_sessions are passed explicitly.
    nats_to_bits = np.log2(np.e)
    chance_level = np.log(batch_size * num_sessions)
    return (chance_level - infonce) * nats_to_bits

Expand Down

0 comments on commit 1d55ead

Please sign in to comment.