Skip to content

Commit

Permalink
Merge branch 'main' into icarosadero-patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
stes authored Dec 17, 2024
2 parents aaf912f + 5f46c32 commit 74fc232
Show file tree
Hide file tree
Showing 67 changed files with 521 additions and 361 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@ jobs:
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.2.2", "2.4.0"]
sklearn-version: ["latest"]
include:
- os: windows-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "latest"
- os: ubuntu-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "legacy"

runs-on: ${{ matrix.os }}

Expand All @@ -32,7 +38,7 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}

- name: Checkout code
uses: actions/checkout@v2
Expand All @@ -48,6 +54,11 @@ jobs:
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets,integrations]'
- name: Check sklearn legacy version
if: matrix.sklearn-version == 'legacy'
run: |
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
- name: Run the formatter
run: |
make format
Expand Down
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ repos:
- id: isort
additional_dependencies:
- pyproject.toml
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.280
hooks:
- id: ruff
10 changes: 5 additions & 5 deletions cebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from cebra.integrations.sklearn.decoder import L1LinearRegressor

is_sklearn_available = True
except ImportError as e:
except ImportError:
# silently fail for now
pass

Expand All @@ -42,7 +42,7 @@
from cebra.integrations.matplotlib import *

is_matplotlib_available = True
except ImportError as e:
except ImportError:
# silently fail for now
pass

Expand All @@ -51,7 +51,7 @@
from cebra.integrations.plotly import *

is_plotly_available = True
except ImportError as e:
except ImportError:
# silently fail for now
pass

Expand Down Expand Up @@ -92,11 +92,11 @@ def __getattr__(key):

return CEBRA
elif key == "KNNDecoder":
from cebra.integrations.sklearn.decoder import KNNDecoder
from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811

return KNNDecoder
elif key == "L1LinearRegressor":
from cebra.integrations.sklearn.decoder import L1LinearRegressor
from cebra.integrations.sklearn.decoder import L1LinearRegressor # noqa: F811

return L1LinearRegressor
elif not key.startswith("_"):
Expand Down
4 changes: 0 additions & 4 deletions cebra/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
import argparse
import sys

import numpy as np
import torch

import cebra
import cebra.distributions as cebra_distr


def train(parser, kwargs):
Expand Down
1 change: 0 additions & 1 deletion cebra/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#
import argparse
import json
from dataclasses import MISSING
from typing import Literal, Optional

import literate_dataclasses as dataclasses
Expand Down
3 changes: 0 additions & 3 deletions cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@
"""Base classes for datasets and loaders."""

import abc
import collections
from typing import List

import literate_dataclasses as dataclasses
import numpy as np
import torch

import cebra.data.assets as cebra_data_assets
Expand Down
51 changes: 36 additions & 15 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,16 @@
#
"""Pre-defined datasets."""

import abc
import collections
import types
from typing import List, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import literate_dataclasses as dataclasses
import numpy as np
import numpy.typing as npt
import torch
from numpy.typing import NDArray

import cebra.data as cebra_data
import cebra.distributions
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex
import cebra.helper as cebra_helper
from cebra.data.datatypes import Offset


class TensorDataset(cebra_data.SingleSessionDataset):
Expand Down Expand Up @@ -71,26 +66,52 @@ def __init__(self,
neural: Union[torch.Tensor, npt.NDArray],
continuous: Union[torch.Tensor, npt.NDArray] = None,
discrete: Union[torch.Tensor, npt.NDArray] = None,
offset: int = 1,
offset: Offset = Offset(0, 1),
device: str = "cpu"):
super().__init__(device=device)
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
self.discrete = self._to_tensor(discrete, torch.LongTensor)
self.neural = self._to_tensor(neural, check_dtype="float").float()
self.continuous = self._to_tensor(continuous, check_dtype="float")
self.discrete = self._to_tensor(discrete, check_dtype="int")
if self.continuous is None and self.discrete is None:
raise ValueError(
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
)
self.offset = offset

def _to_tensor(self, array, check_dtype=None):
def _to_tensor(
self,
array: Union[torch.Tensor, npt.NDArray],
check_dtype: Optional[Literal["int",
"float"]] = None) -> torch.Tensor:
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
Args:
array: Array to check.
check_dtype: If not `None`, the dtype family (``"int"`` or ``"float"``) to which
the values in `array` must belong. Defaults to None.
Returns:
The `array` as a :py:class:`torch.Tensor`.
"""
if array is None:
return None
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if check_dtype is not None:
if not isinstance(array, check_dtype):
raise TypeError(f"{type(array)} instead of {check_dtype}.")
if check_dtype not in ["int", "float"]:
raise ValueError(
f"check_dtype must be 'int' or 'float', got {check_dtype}")
if (check_dtype == "int" and not cebra_helper._is_integer(array)
) or (check_dtype == "float" and
not cebra_helper._is_floating(array)):
raise TypeError(
f"Array has type {array.dtype} instead of {check_dtype}.")
if cebra_helper._is_floating(array):
array = array.float()
if cebra_helper._is_integer(array):
# NOTE(stes): Required for standardizing number format on
# windows machines.
array = array.long()
return array

@property
Expand Down
3 changes: 0 additions & 3 deletions cebra/data/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
# limitations under the License.
#
import collections
from typing import Tuple

import torch

__all__ = ["Batch", "BatchIndex", "Offset"]

Expand Down
23 changes: 14 additions & 9 deletions cebra/data/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:
For each dataset, the data and labels to align the data on is provided.
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
4. The resulting orthogonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
the labels of the reference dataset (``ref_label``) are selected and used to sample
from the dataset to align (``data``).
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
of samples ``subsample``.
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
on those subsampled datasets.
4. The resulting orthogonal matrix ``_transform`` can be used to map the original ``data``
to the ``ref_data``.
Note:
``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number
Expand Down Expand Up @@ -181,14 +186,14 @@ def fit(
elif ref_data.shape[0] == data.shape[0] and (ref_label is None or
label is None):
raise ValueError(
f"Missing labels: the data to align are the same shape but you provided only "
f"one of the sets of labels. Either provide both the reference and alignment "
f"labels or none.")
"Missing labels: the data to align are the same shape but you provided only "
"one of the sets of labels. Either provide both the reference and alignment "
"labels or none.")
else:
if ref_label is None or label is None:
raise ValueError(
f"Missing labels: the data to align are not the same shape, "
f"provide labels to align the data and reference data.")
"Missing labels: the data to align are not the same shape, "
"provide labels to align the data and reference data.")

if len(ref_label.shape) == 1:
ref_label = np.expand_dims(ref_label, axis=1)
Expand Down
3 changes: 2 additions & 1 deletion cebra/data/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,8 @@ def load(
- if no key is provided, the first data structure found upon iteration of the collection will be loaded;
- if a key is provided, it needs to correspond to an existing item of the collection;
- if a key is provided, the data value accessed needs to be a data structure;
- the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
- the function loads data for only one data structure, even if the file contains more. The function can be
called again with the corresponding key to get the other ones.
Args:
file: The path to the given file to load, in a supported format.
Expand Down
2 changes: 0 additions & 2 deletions cebra/data/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
"""Datasets and loaders for multi-session training."""

import abc
import collections
from typing import List

import literate_dataclasses as dataclasses
import numpy as np
import torch

import cebra.data as cebra_data
Expand Down
11 changes: 3 additions & 8 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@
"""

import abc
import collections
import warnings
from typing import List

import literate_dataclasses as dataclasses
import numpy as np
import torch

import cebra.data as cebra_data
Expand Down Expand Up @@ -353,18 +350,16 @@ def __post_init__(self):
# here might be sub-optimal. The final behavior should be determined after
# e.g. integrating the FAISS dataloader back in.
super().__post_init__()
index = self.index.to(self.device)

if self.conditional != "time_delta":
raise NotImplementedError(
f"Hybrid training is currently only implemented using the ``time_delta`` "
f"continual distribution.")
"Hybrid training is currently only implemented using the ``time_delta`` "
"continual distribution.")

self.time_distribution = cebra.distributions.TimeContrastive(
time_offset=self.time_offset,
num_samples=len(self.dataset.neural),
device=self.device,
)
device=self.device)
self.behavior_distribution = cebra.distributions.TimedeltaDistribution(
self.dataset.continuous_index, self.time_offset, device=self.device)

Expand Down
2 changes: 0 additions & 2 deletions cebra/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def get_datapath(path: str = None) -> str:
from cebra.datasets.monkey_reaching import *
from cebra.datasets.synthetic_data import *
except ModuleNotFoundError as e:
import warnings

warnings.warn(f"Could not initialize one or more datasets: {e}. "
f"For using the datasets, consider installing the "
f"[datasets] extension via pip.")
Expand Down
18 changes: 9 additions & 9 deletions cebra/datasets/allen/ca_movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@
"""Allen pseudomouse Ca dataset.
References:
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
*https://github.com/zivlab/visual_drift
*http://observatory.brain-map.org/visualcoding
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
"Representational drift in the mouse visual cortex."
Current biology 31.19 (2021): 4327-4339.
* de Vries, Saskia EJ, et al.
"A large-scale standardized physiological survey reveals functional
organization of the mouse visual cortex."
Nature neuroscience 23.1 (2020): 138-151.
* https://github.com/zivlab/visual_drift
* http://observatory.brain-map.org/visualcoding
"""

import glob
import hashlib
import pathlib

import h5py
import joblib
import numpy as np
import pandas as pd
Expand All @@ -46,7 +47,6 @@
import cebra.data
from cebra.datasets import get_datapath
from cebra.datasets import parametrize
from cebra.datasets import register
from cebra.datasets.allen import NUM_NEURONS
from cebra.datasets.allen import SEEDS

Expand Down
Loading

0 comments on commit 74fc232

Please sign in to comment.