@@ -138,7 +138,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
         >>> import cebra
         >>> import numpy as np
         >>> neural_data = np.random.uniform(0, 1, (1000, 20))
-        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
         >>> cebra_model.fit(neural_data)
-        CEBRA(max_iterations=10)
+        CEBRA(batch_size=512, max_iterations=10)
         >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
@@ -169,7 +169,7 @@ def goodness_of_fit_history(model):
         >>> import cebra
         >>> import numpy as np
         >>> neural_data = np.random.uniform(0, 1, (1000, 20))
-        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
         >>> cebra_model.fit(neural_data)
-        CEBRA(max_iterations=10)
+        CEBRA(batch_size=512, max_iterations=10)
         >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
@@ -210,6 +210,11 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
     """
     if not hasattr(model, "state_dict_"):
         raise RuntimeError("Fit the CEBRA model first.")
+    if model.batch_size is None:
+        raise ValueError(
+            "Computing the goodness of fit is not yet supported for "
+            "models trained on the full dataset (batchsize = None). "
+        )
 
     nats_to_bits = np.log2(np.e)
     num_sessions = model.num_sessions_
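
For context on the lines this hunk leads into: `nats_to_bits = np.log2(np.e)` converts the InfoNCE loss from nats to bits, and the new guard rejects `batch_size=None` because the chance-level loss depends on the batch size. Below is a minimal standalone sketch of that conversion, assuming the chance level is `np.log(batch_size * num_sessions)`; `infonce_to_bits_sketch` is a hypothetical helper written for illustration, not part of the cebra API.

import numpy as np

def infonce_to_bits_sketch(infonce, batch_size, num_sessions=1):
    """Hypothetical sketch: convert an InfoNCE loss (in nats) to a
    goodness-of-fit score in bits, measured as improvement over chance.

    Assumes the chance-level loss of an untrained model is
    log(batch_size * num_sessions) nats, which is why a model trained
    on the full dataset (batch_size=None) cannot be scored this way.
    """
    nats_to_bits = np.log2(np.e)  # 1 nat = log2(e) ~ 1.4427 bits
    chance_level = np.log(batch_size * num_sessions)
    return (chance_level - infonce) * nats_to_bits

# A loss exactly at chance maps to 0 bits:
print(infonce_to_bits_sketch(np.log(512), batch_size=512))  # 0.0
# A loss below chance maps to a positive score:
print(infonce_to_bits_sketch(5.0, batch_size=512))  # ~1.79 bits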