Skip to content

Commit 1d55ead

Browse files
committed
adapt GoF implementation
1 parent f2e6257 commit 1d55ead

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,11 @@ def goodness_of_fit_history(model):
178178
return infonce_to_goodness_of_fit(infonce, model)
179179

180180

181-
def infonce_to_goodness_of_fit(
        infonce: Union[float, np.ndarray],
        model: Optional["cebra_sklearn_cebra.CEBRA"] = None,
        batch_size: Optional[int] = None,
        num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
    """Given a trained CEBRA model, return goodness of fit metric.

    The goodness of fit ranges from 0 (lowest meaningful value)
    to a positive number with the unit "bits", the higher the
    better. It is computed from the InfoNCE loss as

    .. math::

        S = \\log N - \\text{InfoNCE}

    where :math:`N` is ``batch_size * num_sessions``, and the result is
    converted from nats to bits.

    To use this function, either provide a trained CEBRA model or the
    batch size and number of sessions.

    Args:
        infonce: The InfoNCE loss, either a single value or an array of values.
        model: The trained CEBRA model.
        batch_size: The batch size used to train the model.
        num_sessions: The number of sessions used to train the model.

    Returns:
        Numpy array (or scalar) containing the goodness of fit values,
        measured in bits.

    Raises:
        RuntimeError: If the provided model is not fit to data.
        ValueError: If both ``model`` and ``(batch_size, num_sessions)``
            are provided, or if neither is fully specified.
    """
    if model is not None:
        # The model carries its own training configuration; passing both a
        # model and explicit sizes would be ambiguous.
        if batch_size is not None or num_sessions is not None:
            raise ValueError(
                "batch_size and num_sessions should not be provided if model is provided."
            )
        if not hasattr(model, "state_dict_"):
            raise RuntimeError("Fit the CEBRA model first.")
        if model.batch_size is None:
            raise ValueError(
                "Computing the goodness of fit is not yet supported for "
                "models trained on the full dataset (batchsize = None). ")
        batch_size = model.batch_size
        num_sessions = model.num_sessions_
        # Single-session models report num_sessions_ as None.
        if num_sessions is None:
            num_sessions = 1
    else:
        if batch_size is None or num_sessions is None:
            # NOTE(review): message previously only mentioned batch_size,
            # although both values are required.
            raise ValueError(
                "batch_size and num_sessions should be provided if model is not provided."
            )

    nats_to_bits = np.log2(np.e)
    # Bug fix: use the resolved local ``batch_size`` instead of
    # ``model.batch_size`` — the latter raised AttributeError on the
    # model-is-None path introduced by this refactor.
    chance_level = np.log(batch_size * num_sessions)
    return (chance_level - infonce) * nats_to_bits
225241

0 commit comments

Comments
 (0)