diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 9a1dd5a6..46e3b8ca 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -178,9 +178,11 @@ def goodness_of_fit_history(model): return infonce_to_goodness_of_fit(infonce, model) -def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]], - model: cebra_sklearn_cebra.CEBRA) -> np.ndarray: - """Given a trained CEBRA model, return goodness of fit metric +def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray], + model: Optional[cebra_sklearn_cebra.CEBRA] = None, + batch_size: Optional[int] = None, + num_sessions: Optional[int] = None) -> Union[float, np.ndarray]: + """Given a trained CEBRA model, return goodness of fit metric. The goodness of fit ranges from 0 (lowest meaningful value) to a positive number with the unit "bits", the higher the @@ -199,27 +201,41 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]], S = \\log N - \\text{InfoNCE} + To use this function, either provide a trained CEBRA model or the + batch size and number of sessions. + Args: + infonce: The InfoNCE loss, either a single value or an iterable of values. model: The trained CEBRA model + batch_size: The batch size used to train the model. + num_sessions: The number of sessions used to train the model. Returns: Numpy array containing the goodness of fit values, measured in bits Raises: RuntimeError: If the provided model is not fit to data. + ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided. """ - if not hasattr(model, "state_dict_"): - raise RuntimeError("Fit the CEBRA model first.") - if model.batch_size is None: - raise ValueError( - "Computing the goodness of fit is not yet supported for " - "models trained on the full dataset (batchsize = None). " - ) + if model is not None: + if batch_size is not None or num_sessions is not None: + raise ValueError("batch_size and num_sessions should not be provided if model is provided.") + if not hasattr(model, "state_dict_"): + raise RuntimeError("Fit the CEBRA model first.") + if model.batch_size is None: + raise ValueError( + "Computing the goodness of fit is not yet supported for " + "models trained on the full dataset (batchsize = None). " + ) + batch_size = model.batch_size + num_sessions = model.num_sessions_ + if num_sessions is None: + num_sessions = 1 + else: + if batch_size is None or num_sessions is None: + raise ValueError("batch_size should be provided if model is not provided.") nats_to_bits = np.log2(np.e) - num_sessions = model.num_sessions_ - if num_sessions is None: - num_sessions = 1 chance_level = np.log(model.batch_size * num_sessions) return (chance_level - infonce) * nats_to_bits