Skip to content

Commit

Permalink
add tests and improve implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Oct 27, 2024
1 parent 5df5ca2 commit db9df82
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 18 deletions.
70 changes: 52 additions & 18 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,15 @@ def infonce_loss(
return avg_loss


def goodness_of_fit_score(
cebra_model: cebra_sklearn_cebra.CEBRA,
X: Union[npt.NDArray, torch.Tensor],
*y,
session_id: Optional[int] = None,
num_batches: int = 500,
correct_by_batchsize: bool = False,
) -> float:
def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
X: Union[npt.NDArray, torch.Tensor],
*y,
session_id: Optional[int] = None,
num_batches: int = 500) -> float:
"""Compute the InfoNCE loss on a *single session* dataset on the model.
This function uses the :func:`infonce_loss` function to compute the InfoNCE loss.
Args:
cebra_model: The model to use to compute the InfoNCE loss on the samples.
X: A 2D data matrix, corresponding to a *single session* recording.
Expand All @@ -127,23 +126,60 @@ def goodness_of_fit_score(
for multisession, set to ``None`` for single session.
num_batches: The number of iterations to consider to evaluate the model on the new data.
Higher values will give a more accurate estimate. Set it to at least 500 iterations.
Returns:
The average GoF score estimated over ``num_batches`` batches from the data distribution.
Related:
:func:`infonce_to_goodness_of_fit`
Example:
>>> import cebra
>>> import numpy as np
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
>>> gof = cebra.goodness_of_fit_score(cebra_model, neural_data)
"""
loss = infonce_loss(cebra_model=cebra_model,
X=X,
loss = infonce_loss(cebra_model,
X,
*y,
session_id=session_id,
num_batches=500,
num_batches=num_batches,
correct_by_batchsize=False)
return infonce_to_goodness_of_fit(loss, cebra_model)


def goodness_of_fit_score(model):
def goodness_of_fit_history(model):
"""Return the history of the goodness of fit score.
Args:
model: A trained CEBRA model.
Returns:
A numpy array containing the goodness of fit values, measured in bits.
Related:
:func:`infonce_to_goodness_of_fit`
Example:
>>> import cebra
>>> import numpy as np
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
>>> gof_history = cebra.goodness_of_fit_history(cebra_model)
"""
infonce = np.array(model.state_dict_["log"]["total"])
return infonce_to_goodness_of_fit(infonce, model)


def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
model: cebra.CEBRA) -> np.ndarray:
model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
"""Given a trained CEBRA model, return goodness of fit metric
The goodness of fit ranges from 0 (lowest meaningful value)
Expand All @@ -161,18 +197,16 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
.. math::
S = \log N - \text{InfoNCE}
S = \\log N - \\text{InfoNCE}
Args:
model: The trained CEBRA model
Returns:
Numpy array containing the goodness of fit
values, measured in bits
Numpy array containing the goodness of fit values, measured in bits
Raises:
``RuntimeError``, if provided model is not
fit to data.
``RuntimeError``, if provided model is not fit to data.
"""
if not hasattr(model, "state_dict_"):
raise RuntimeError("Fit the CEBRA model first.")
Expand Down
64 changes: 64 additions & 0 deletions tests/test_sklearn_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,67 @@ def test_sklearn_runs_consistency():
with pytest.raises(ValueError, match="Invalid.*embeddings"):
_, _, _ = cebra_sklearn_metrics.consistency_score(
invalid_embeddings_runs, between="runs")


@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_score(seed):
"""
Ensure that the GoF score is close to 0 for a model fit on random data.
"""
cebra_model = cebra_sklearn_cebra.CEBRA(
model_architecture="offset1-model",
max_iterations=5,
batch_size=512,
)
X = torch.tensor(np.random.uniform(0, 1, (5000, 50)))
y = torch.tensor(np.random.uniform(0, 1, (5000, 5)))
cebra_model.fit(X, y)
score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
X,
y,
session_id=0,
num_batches=500)
assert isinstance(score, float)
assert np.isclose(score, 0, atol=0.01)


@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_history(seed):
"""
Ensure that the GoF score is higher for a model fit on data with underlying
structure than for a model fit on random data.
"""

# Generate data
generator = torch.Generator().manual_seed(seed)
X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
y_linear = X @ linear_map

def _fit_and_get_history(X, y):
cebra_model = cebra_sklearn_cebra.CEBRA(
model_architecture="offset1-model",
max_iterations=150,
batch_size=512,
device="cpu")
cebra_model.fit(X, y)
history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
# NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
# due to numerical issues.
return history[5:]

history_random = _fit_and_get_history(X, y_random)
history_linear = _fit_and_get_history(X, y_linear)

assert isinstance(history_random, np.ndarray)
assert history_random.shape[0] > 0
# NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
# due to numerical issues.
history_random_non_negative = history_random[history_random >= 0]
np.testing.assert_allclose(history_random_non_negative, 0, atol=0.05)

assert isinstance(history_linear, np.ndarray)
assert history_linear.shape[0] > 0

assert np.all(history_linear[-20:] > history_random[-20:])

0 comments on commit db9df82

Please sign in to comment.