Skip to content

Commit f2e6257

Browse files
authored
Handle batch size = None for goodness of fit computation
1 parent f43971f commit f2e6257

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
138138
>>> import cebra
139139
>>> import numpy as np
140140
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
141-
>>> cebra_model = cebra.CEBRA(max_iterations=10)
141+
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
142142
>>> cebra_model.fit(neural_data)
143143
CEBRA(max_iterations=10)
144144
>>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
@@ -169,7 +169,7 @@ def goodness_of_fit_history(model):
169169
>>> import cebra
170170
>>> import numpy as np
171171
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
172-
>>> cebra_model = cebra.CEBRA(max_iterations=10)
172+
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
173173
>>> cebra_model.fit(neural_data)
174174
CEBRA(max_iterations=10)
175175
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
@@ -210,6 +210,11 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
210210
"""
211211
if not hasattr(model, "state_dict_"):
212212
raise RuntimeError("Fit the CEBRA model first.")
213+
if model.batch_size is None:
214+
raise ValueError(
215+
"Computing the goodness of fit is not yet supported for "
216+
"models trained on the full dataset (batchsize = None). "
217+
)
213218

214219
nats_to_bits = np.log2(np.e)
215220
num_sessions = model.num_sessions_

0 commit comments

Comments
 (0)