Skip to content

Commit

Permalink
Handle batch size = None for goodness of fit computation
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Dec 16, 2024
1 parent ed79cac commit 087ac37
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
>>> import cebra
>>> import numpy as np
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
>>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
Expand Down Expand Up @@ -169,7 +169,7 @@ def goodness_of_fit_history(model):
>>> import cebra
>>> import numpy as np
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
Expand Down Expand Up @@ -210,6 +210,11 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
"""
if not hasattr(model, "state_dict_"):
raise RuntimeError("Fit the CEBRA model first.")
if model.batch_size is None:
raise ValueError(
"Computing the goodness of fit is not yet supported for "
"models trained on the full dataset (batchsize = None). "
)

nats_to_bits = np.log2(np.e)
num_sessions = model.num_sessions_
Expand Down

0 comments on commit 087ac37

Please sign in to comment.