Add improved goodness of fit implementation #190

@@ -108,6 +108,138 @@ def infonce_loss(
    return avg_loss

def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
                          X: Union[npt.NDArray, torch.Tensor],
                          *y,
                          session_id: Optional[int] = None,
                          num_batches: int = 500) -> float:
    """Compute the goodness of fit score on a *single session* dataset on the model.

    This function uses :func:`infonce_loss` to compute the InfoNCE loss.

    Args:
        cebra_model: The model to use to compute the InfoNCE loss on the samples.
        X: A 2D data matrix, corresponding to a *single session* recording.
        y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
            discrete index passed as a 1D array. Each index has to match the length of ``X``.
        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
            for multisession, set to ``None`` for single session.
        num_batches: The number of iterations to consider to evaluate the model on the new data.
            Higher values will give a more accurate estimate. Set it to at least 500 iterations.

    Returns:
        The average GoF score estimated over ``num_batches`` batches from the data distribution.

    Related:
        :func:`infonce_to_goodness_of_fit`

    Example:

        >>> import cebra
        >>> import numpy as np
        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
        >>> cebra_model.fit(neural_data)
        CEBRA(batch_size=512, max_iterations=10)
        >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
    """
    loss = infonce_loss(cebra_model,
                        X,
                        *y,
                        session_id=session_id,
                        num_batches=num_batches,
                        correct_by_batchsize=False)
    return infonce_to_goodness_of_fit(loss, cebra_model)

def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
"""Return the history of the goodness of fit score. | ||||||||||||
|
||||||||||||
Args: | ||||||||||||
model: A trained CEBRA model. | ||||||||||||
|
||||||||||||
Returns: | ||||||||||||
A numpy array containing the goodness of fit values, measured in bits. | ||||||||||||
|
||||||||||||
Related: | ||||||||||||
:func:`infonce_to_goodness_of_fit` | ||||||||||||
|
||||||||||||
Example: | ||||||||||||
|
||||||||||||
>>> import cebra | ||||||||||||
>>> import numpy as np | ||||||||||||
>>> neural_data = np.random.uniform(0, 1, (1000, 20)) | ||||||||||||
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512) | ||||||||||||
>>> cebra_model.fit(neural_data) | ||||||||||||
CEBRA(batch_size=512, max_iterations=10) | ||||||||||||
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model) | ||||||||||||
""" | ||||||||||||
infonce = np.array(model.state_dict_["log"]["total"]) | ||||||||||||
return infonce_to_goodness_of_fit(infonce, model) | ||||||||||||
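
Since the returned array is the per-iteration goodness of fit in bits, a natural follow-up is to inspect the training curve. A minimal sketch, assuming matplotlib is installed and ``cebra_model`` is the fitted estimator from the docstring example above:

    import matplotlib.pyplot as plt

    gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
    plt.plot(gof_history)  # higher is better; chance level is 0 bits
    plt.xlabel("Training iteration")
    plt.ylabel("Goodness of fit (bits)")
    plt.show()
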
def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
                               model: Optional[cebra_sklearn_cebra.CEBRA] = None,
                               batch_size: Optional[int] = None,
                               num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
    """Given a trained CEBRA model, return goodness of fit metric.

    The goodness of fit ranges from 0 (lowest meaningful value)
    to a positive number with the unit "bits", the higher the
    better.

    Values lower than 0 bits are possible, but these only occur
    due to numerical effects. A perfectly collapsed embedding
    (e.g., because the data cannot be fit with the provided
    auxiliary variables) will have a goodness of fit of 0.

    The conversion between the generalized InfoNCE metric that
    CEBRA is trained with and the goodness of fit computed with this
    function is

    .. math::

        S = \\log N - \\text{InfoNCE}

    To use this function, either provide a trained CEBRA model or the
    batch size and number of sessions.

    Args:
        infonce: The InfoNCE loss, either a single value or an iterable of values.
        model: The trained CEBRA model.
        batch_size: The batch size used to train the model.
        num_sessions: The number of sessions used to train the model.

    Returns:
        The goodness of fit, measured in bits: a float if ``infonce`` is a
        single value, otherwise a numpy array.

    Raises:
        RuntimeError: If the provided model is not fit to data.
        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided,
            or if ``model`` is ``None`` and ``batch_size`` or ``num_sessions`` is missing.
    """
    if model is not None:
        if batch_size is not None or num_sessions is not None:
            raise ValueError(
                "batch_size and num_sessions should not be provided if model is provided."
            )
        if not hasattr(model, "state_dict_"):
            raise RuntimeError("Fit the CEBRA model first.")
        if model.batch_size is None:
            raise ValueError(
                "Computing the goodness of fit is not yet supported for "
                "models trained on the full dataset (batch_size = None).")
        batch_size = model.batch_size
        num_sessions = model.num_sessions_
        if num_sessions is None:
            num_sessions = 1
    else:
        if batch_size is None or num_sessions is None:
            raise ValueError(
                "batch_size and num_sessions should be provided if model is not provided."
            )

    nats_to_bits = np.log2(np.e)
    chance_level = np.log(batch_size * num_sessions)
    return (chance_level - infonce) * nats_to_bits
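
As a sanity check on this conversion, here is a small worked example outside the diff; the values are hypothetical and chosen so the result is exactly one bit:

    import numpy as np

    # Hypothetical setup: batch_size=512, a single session, and an InfoNCE
    # loss of log(256) nats, i.e., log(2) below the chance level log(512).
    batch_size, num_sessions = 512, 1
    infonce = np.log(256)

    nats_to_bits = np.log2(np.e)
    chance_level = np.log(batch_size * num_sessions)
    gof = (chance_level - infonce) * nats_to_bits
    print(gof)  # (log 512 - log 256) * log2(e) = log(2) / log(2) = 1.0 bit
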
def _consistency_scores(
    embeddings: List[Union[npt.NDArray, torch.Tensor]],
    datasets: List[Union[int, str]],
@@ -383,3 +383,67 @@ def test_sklearn_runs_consistency():
    with pytest.raises(ValueError, match="Invalid.*embeddings"):
        _, _, _ = cebra_sklearn_metrics.consistency_score(
            invalid_embeddings_runs, between="runs")

Review comment: Ideally, some tests to test …
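
A minimal sketch of one such test, responding to the comment above; it assumes the ``infonce_to_goodness_of_fit`` signature from this PR, and the test name and parameter values are illustrative:

    @pytest.mark.parametrize("batch_size, num_sessions", [(512, 1), (256, 2)])
    def test_infonce_to_goodness_of_fit(batch_size, num_sessions):
        # At chance level, InfoNCE == log(batch_size * num_sessions), so the
        # goodness of fit should be exactly 0 bits.
        chance = float(np.log(batch_size * num_sessions))
        gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
            chance, batch_size=batch_size, num_sessions=num_sessions)
        assert np.isclose(gof, 0.0)

        # Omitting both the model and the explicit parameters is invalid.
        with pytest.raises(ValueError):
            cebra_sklearn_metrics.infonce_to_goodness_of_fit(chance)
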
@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_score(seed):
    """
    Ensure that the GoF score is close to 0 for a model fit on random data.
    """
    cebra_model = cebra_sklearn_cebra.CEBRA(
        model_architecture="offset1-model",
        max_iterations=5,
        batch_size=512,
    )
    # Seed the generator so the parametrized seed is actually used.
    generator = torch.Generator().manual_seed(seed)
    X = torch.rand(5000, 50, dtype=torch.float32, generator=generator)
    y = torch.rand(5000, 5, dtype=torch.float32, generator=generator)
    cebra_model.fit(X, y)
    score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
                                                        X,
                                                        y,
                                                        session_id=0,
                                                        num_batches=500)
    assert isinstance(score, float)
    assert np.isclose(score, 0, atol=0.01)
@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_history(seed):
    """
    Ensure that the GoF score is higher for a model fit on data with underlying
    structure than for a model fit on random data.
    """

    # Generate data
    generator = torch.Generator().manual_seed(seed)
    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
    y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
    linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
    y_linear = X @ linear_map

    def _fit_and_get_history(X, y):
        cebra_model = cebra_sklearn_cebra.CEBRA(
            model_architecture="offset1-model",
            max_iterations=150,
            batch_size=512,
            device="cpu")
        cebra_model.fit(X, y)
        history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
        # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
        # due to numerical issues.
        return history[5:]

    history_random = _fit_and_get_history(X, y_random)
    history_linear = _fit_and_get_history(X, y_linear)

    assert isinstance(history_random, np.ndarray)
    assert history_random.shape[0] > 0
    # NOTE(stes): Values below 0 can occur due to numerical effects, so only
    # the non-negative part of the history is checked against 0 here.
    history_random_non_negative = history_random[history_random >= 0]
    np.testing.assert_allclose(history_random_non_negative, 0, atol=0.05)

    assert isinstance(history_linear, np.ndarray)
    assert history_linear.shape[0] > 0

    assert np.all(history_linear[-20:] > history_random[-20:])

Review comment: It computes the goodness of fit score from the InfoNCE loss, no?