Add improved goodness of fit implementation #190
@stes, about what I implemented in #202 that I do see here: I think it would be good to have a really basic function where you provide the loss and the batch size, so that it is easily usable in the PyTorch implementation as well. Also, it would be nice to test for the default
The build issue is fixed, and once #205 is merged, tests should pass here as well.
Thank you @stes! This looks nice! Some minor suggestions on the docstrings, and maybe add some tests for the different corner cases based on the arguments provided in infonce_to_goodness_of_fit.
    """Compute the InfoNCE loss on a *single session* dataset on the model.

    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss.
It computes the goodness of fit score from the InfoNCE loss, no?
Suggested change:
-    """Compute the InfoNCE loss on a *single session* dataset on the model.
-    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss.
+    """Compute the goodness of fit score on a *single session* dataset on the model.
+    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
+    for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
+    to derive the goodness of fit from the InfoNCE loss.
    return infonce_to_goodness_of_fit(loss, cebra_model)


def goodness_of_fit_history(model):
Suggested change:
-def goodness_of_fit_history(model):
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:

    Args:
        infonce: The InfoNCE loss, either a single value or an iterable of values.
        model: The trained CEBRA model
Suggested change:
-        model: The trained CEBRA model
+        model: The trained CEBRA model.
        num_sessions = 1
    else:
        if batch_size is None or num_sessions is None:
            raise ValueError("batch_size should be provided if model is not provided.")
Suggested change:
-            raise ValueError("batch_size should be provided if model is not provided.")
+            raise ValueError(
+                f"batch_size ({batch_size}) and num_sessions ({num_sessions}) "
+                "should be provided if model is not provided."
+            )
@@ -383,3 +383,67 @@ def test_sklearn_runs_consistency():
     with pytest.raises(ValueError, match="Invalid.*embeddings"):
         _, _, _ = cebra_sklearn_metrics.consistency_score(
             invalid_embeddings_runs, between="runs")
Ideally, some tests are needed here to check that infonce_to_goodness_of_fit raises errors depending on model / batch_size (= None) / num_sessions (= None, num_sessions greater than the actual number of sessions, etc.).
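A minimal sketch of what such checks could look like, kept free of CEBRA imports so it runs on its own. `check_gof_arguments` and `raises_value_error` are hypothetical stand-ins that only mirror the `batch_size` / `num_sessions` validation branch quoted in this review; they are not the function added in this PR. In the actual test file the same cases would more naturally be written with `pytest.raises`, and the case of `num_sessions` exceeding the model's real number of sessions needs a trained model, so it is not covered here.

```python
from typing import Optional, Tuple


def check_gof_arguments(batch_size: Optional[int],
                        num_sessions: Optional[int]) -> Tuple[int, int]:
    """Hypothetical stand-in for the no-model branch; not the CEBRA function."""
    if batch_size is None or num_sessions is None:
        raise ValueError(
            f"batch_size ({batch_size}) and num_sessions ({num_sessions}) "
            "should be provided if model is not provided.")
    return batch_size, num_sessions


def raises_value_error(batch_size: Optional[int],
                       num_sessions: Optional[int]) -> bool:
    """Return True if the stand-in rejects the given arguments."""
    try:
        check_gof_arguments(batch_size, num_sessions)
    except ValueError:
        return True
    return False


# Corner cases named in the review: either argument (or both) left as None.
invalid_cases = [(None, 1), (512, None), (None, None)]
results = [raises_value_error(b, n) for b, n in invalid_cases]
```

Each entry in `results` should be `True`, while a fully specified call such as `check_gof_arguments(512, 1)` should pass through unchanged.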
This adds a better goodness of fit measure. Instead of the old variant which simply matched the InfoNCE and depends on the batch size, the proposed measure
The conversion is quite simply done via
This measure is also used in DeWolf et al., 2024, Eq. (43)
Application example (GoF improves from 0 to a larger value during training):
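The formula and the plot are not reproduced in the text above, so the following is only a rough sketch under stated assumptions: that the goodness of fit is the chance-level loss `log(batch_size * num_sessions)` minus the InfoNCE loss, converted from nats to bits. The function name `infonce_to_gof_bits` and the loss values are made up for illustration and are not taken from this PR.

```python
import math
from typing import Iterable, List


def infonce_to_gof_bits(infonce: Iterable[float],
                        batch_size: int,
                        num_sessions: int = 1) -> List[float]:
    """Hypothetical conversion (an assumption, not the PR's code)."""
    # Assumed chance level: the loss of a fully collapsed embedding, in nats.
    chance_level = math.log(batch_size * num_sessions)
    # Dividing by log(2) converts nats to bits.
    return [(chance_level - loss) / math.log(2) for loss in infonce]


# Made-up loss history: starts at the assumed chance level, then decreases.
batch_size = 512
loss_history = [math.log(batch_size), 5.5, 5.0, 4.5]
gof_history = infonce_to_gof_bits(loss_history, batch_size)
```

Under these assumptions `gof_history` starts at 0 bits and grows as the loss drops, and, unlike the raw loss, the starting point does not shift when the batch size changes.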
Close https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/669