@@ -178,9 +178,11 @@ def goodness_of_fit_history(model):
178
178
return infonce_to_goodness_of_fit (infonce , model )
179
179
180
180
181
def infonce_to_goodness_of_fit(
    infonce: Union[float, np.ndarray],
    # Forward-reference string so the annotation does not require the module
    # at function-definition time.
    model: Optional["cebra_sklearn_cebra.CEBRA"] = None,
    batch_size: Optional[int] = None,
    num_sessions: Optional[int] = None,
) -> Union[float, np.ndarray]:
    """Given a trained CEBRA model, return goodness of fit metric.

    The goodness of fit ranges from 0 (lowest meaningful value)
    to a positive number with the unit "bits", the higher the
    better.

    .. math::

        S = \\log N - \\text{InfoNCE}

    To use this function, either provide a trained CEBRA model or the
    batch size and number of sessions.

    Args:
        infonce: The InfoNCE loss, either a single value or an iterable of values.
        model: The trained CEBRA model.
        batch_size: The batch size used to train the model.
        num_sessions: The number of sessions used to train the model.

    Returns:
        Numpy array containing the goodness of fit values, measured in bits.

    Raises:
        RuntimeError: If the provided model is not fit to data.
        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided,
            or if neither is.
    """
    if model is not None:
        # Exactly one source of (batch_size, num_sessions) is allowed.
        if batch_size is not None or num_sessions is not None:
            raise ValueError(
                "batch_size and num_sessions should not be provided if model is provided."
            )
        if not hasattr(model, "state_dict_"):
            raise RuntimeError("Fit the CEBRA model first.")
        if model.batch_size is None:
            raise ValueError(
                "Computing the goodness of fit is not yet supported for "
                "models trained on the full dataset (batchsize = None). ")
        batch_size = model.batch_size
        num_sessions = model.num_sessions_
        # Single-session models report num_sessions_ as None.
        if num_sessions is None:
            num_sessions = 1
    else:
        if batch_size is None or num_sessions is None:
            raise ValueError(
                "batch_size should be provided if model is not provided.")

    nats_to_bits = np.log2(np.e)
    # BUG FIX: use the local ``batch_size`` here. The previous code read
    # ``model.batch_size``, which raises AttributeError when ``model`` is None
    # and (batch_size, num_sessions) were passed explicitly.
    chance_level = np.log(batch_size * num_sessions)
    return (chance_level - infonce) * nats_to_bits
225
241
0 commit comments