From 283de064072ce74f9f60929ce85c4aca32cbb542 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 21 Jun 2023 11:43:39 +0200 Subject: [PATCH 01/45] first proposal for batching in tranform method --- cebra/solver/base.py | 56 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index c350ba35..91588637 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -285,7 +285,40 @@ def decoding(self, train_loader, valid_loader): return decode_metric @torch.no_grad() - def transform(self, inputs: torch.Tensor) -> torch.Tensor: + def _transform(self, inputs, session_id): + output = self.model(inputs) + return output + + + @torch.no_grad() + def _batched_transform(self, inputs, session_id, batch_size): + num_samples = inputs.shape[0] + num_batches = (num_samples + batch_size - 1) // batch_size + output = [] + + for i in range(num_batches): + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, num_samples) + batched_data = inputs[start_idx:end_idx] + output_batch = self.model(batched_data) + output.append(output_batch) + + output = torch.cat(output) + return output + + + # OPTION 2 + #num_samples = inputs.shape[0] + #num_batches = (num_samples + batch_size - 1) // batch_size + #output = [self.model(inputs[i * batch_size : min((i + 1) * batch_size, num_samples)]) for i in range(num_batches)] + #output = torch.cat(output) + #return output + + @torch.no_grad() + def transform(self, + inputs: torch.Tensor, + session_id: Optional[int] = None, + batch_size: Optional[int] = None) -> torch.Tensor: """Compute the embedding. This function by default only applies the ``forward`` function @@ -293,17 +326,26 @@ def transform(self, inputs: torch.Tensor) -> torch.Tensor: Args: inputs: The input signal - + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + Returns: The output embedding. - - TODO: - * Remove eval mode """ - self.model.eval() - return self.model(inputs) + + + if batch_size is not None: + #TODO: padding properly with convolutions!! + output = self._batched_transform(inputs, session_id, batch_size) + + else: + output = self._transform(inputs, session_id) + return output + + @abc.abstractmethod def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, return the model outputs. From 202e379bc5423bc9e1358aa104cadc94e20e5331 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 22 Jun 2023 16:02:30 +0200 Subject: [PATCH 02/45] first running version of padding with batched inference --- cebra/solver/base.py | 74 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 91588637..52bef6b9 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -44,6 +44,7 @@ import cebra.models from cebra.solver.util import Meter from cebra.solver.util import ProgressBar +import numpy as np @dataclasses.dataclass @@ -57,6 +58,11 @@ class Solver(abc.ABC, cebra.io.HasDevice): criterion: The criterion computed from the similarities between positive pairs and negative pairs. The criterion can have trainable parameters on its own. optimizer: A PyTorch optimizer for updating model and criterion parameters. + pad_before_transform: If ``False``, no padding is applied to the input sequence. 
+ and the output sequence will be smaller than the input sequence due to the + receptive field of the model. If the input sequence is ``n`` steps long, + and a model with receptive field ``m`` is used, the output sequence would + only be ``n-m+1`` steps long. history: Deprecated since 0.0.2. Use :py:attr:`log`. decode_history: Deprecated since 0.0.2. Use a hook during training for validation and decoding. See the arguments of :py:meth:`fit`. @@ -69,6 +75,7 @@ class Solver(abc.ABC, cebra.io.HasDevice): model: torch.nn.Module criterion: torch.nn.Module optimizer: torch.optim.Optimizer + pad_before_transform: bool = True history: List = dataclasses.field(default_factory=list) decode_history: List = dataclasses.field(default_factory=list) log: Dict = dataclasses.field(default_factory=lambda: ({ @@ -95,6 +102,7 @@ def state_dict(self) -> dict: return { "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), + "pad_before_transform": self.pad_before_transform, "loss": torch.tensor(self.history), "decode": self.decode_history, "criterion": self.criterion.state_dict(), @@ -130,6 +138,8 @@ def _get(key): self.criterion.load_state_dict(_get("criterion")) if _contains("optimizer"): self.optimizer.load_state_dict(_get("optimizer")) + if _contains("pad_before_transform"): + self.pad_before_transform = _get("pad_before_transform") # TODO(stes): This will be deprecated at some point; the "log" attribute # holds the same information. if _contains("loss"): @@ -286,12 +296,55 @@ def decoding(self, train_loader, valid_loader): @torch.no_grad() def _transform(self, inputs, session_id): + + #model = self.select_model(n_inputs_features=inputs.shape[1], + # session_id=session_id) + #model.to(inputs.device) + #offset = model.get_offset() +# + #model.eval() +# + #if self.pad_before_transform: + # device = inputs.device + # inputs = np.pad(inputs.cpu().numpy(), + # ((offset.left, offset.right - 1), (0, 0)), + # mode="edge") + # inputs = torch.from_numpy(inputs).float().to(device) +# + #if isinstance(model, cebra.models.ConvolutionalModelMixin): + # # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + # inputs = inputs.transpose(1, 0).unsqueeze(0) + # output = model(inputs).squeeze(0).transpose(1, 0) + #else: + # # Standard evaluation, (T, C, dt) + # output = model(inputs) + output = self.model(inputs) return output + + def _get_batched_data_with_padding(self, inputs, offset, start_idx, end_idx, batch_id, num_batches): + + if batch_id == 0: + batched_data = inputs[start_idx:(end_idx+offset.right)] + batched_data = np.pad(batched_data.cpu().numpy(), + ((offset.left, 0), (0, 0)), + mode="edge") + + elif batch_id == num_batches - 1: #Last batch + batched_data = inputs[(start_idx-offset.left):end_idx] + batched_data = np.pad(batched_data.cpu().numpy(), + ((0, offset.right-1), (0, 0)), + mode="edge") + + else: # Middle batches + batched_data = inputs[(start_idx-offset.left):(end_idx+offset.right-1)] + + return torch.from_numpy(batched_data) if isinstance(batched_data, np.ndarray) else batched_data + @torch.no_grad() - def _batched_transform(self, inputs, session_id, batch_size): + def _batched_transform(self, inputs, offset, session_id, batch_size): num_samples = inputs.shape[0] num_batches = (num_samples + batch_size - 1) // batch_size output = [] @@ -300,12 +353,23 @@ def _batched_transform(self, inputs, session_id, batch_size): start_idx = i * batch_size end_idx = min((i + 1) * batch_size, num_samples) batched_data = inputs[start_idx:end_idx] - output_batch = self.model(batched_data) + + 
if self.pad_before_transform: + batched_data = self._get_batched_data_with_padding(inputs, offset, start_idx, end_idx, i, num_batches) + + if isinstance(self.model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + batched_data = batched_data.transpose(1, 0).unsqueeze(0) + output_batch = self.model(batched_data).squeeze(0).transpose(1, 0) + else: + output_batch = self.model(batched_data) + + output.append(output_batch) output = torch.cat(output) + return output - # OPTION 2 #num_samples = inputs.shape[0] @@ -334,11 +398,13 @@ def transform(self, The output embedding. """ + offset = self.model.get_offset() + if batch_size is not None: #TODO: padding properly with convolutions!! - output = self._batched_transform(inputs, session_id, batch_size) + output = self._batched_transform(inputs, offset, session_id, batch_size) else: output = self._transform(inputs, session_id) From 1f1989d699253a487887c551aa67361e4ebcb79b Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 23 Jun 2023 11:53:00 +0200 Subject: [PATCH 03/45] start tests --- tests/test_solver.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_solver.py b/tests/test_solver.py index 46efd319..633c1df0 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -29,6 +29,7 @@ import cebra.datasets import cebra.models import cebra.solver +import numpy as np device = "cpu" @@ -168,3 +169,36 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): assert isinstance(log, dict) solver.fit(loader) + + +def test_batched_transform(data_name, loader_initfunc, solver_initfunc): + """ + test to know if we are getting the batches right without padding + """ + + loader = _get_loader(data_name, loader_initfunc) + model = _make_model(loader.dataset) + criterion = cebra.models.InfoNCE() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer, + pad_before_transform = False) + + solver.fit(loader) + + # batched_transform + batch_size = 1024 + + # should pad_before_transform be an argument of the transform() method? + embedding_batched = solver.transform(batch_size = batch_size) + embedding = solver.transform(batch_size = None) + + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding) + + + # TODO: how can I check that the batches are correct? + # maybe it is good enough if I compare to the embedding + # without batch size. \ No newline at end of file From 866566024df667e1d9419b6cfd3dc6a168780ee1 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 27 Sep 2023 17:57:07 +0200 Subject: [PATCH 04/45] add pad_before_transform to fit function and add support for convolutional models in _transform --- cebra/solver/base.py | 180 +++++++++++++++++++++++++++---------------- 1 file changed, 112 insertions(+), 68 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 52bef6b9..21b40d14 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -58,11 +58,6 @@ class Solver(abc.ABC, cebra.io.HasDevice): criterion: The criterion computed from the similarities between positive pairs and negative pairs. The criterion can have trainable parameters on its own. optimizer: A PyTorch optimizer for updating model and criterion parameters. - pad_before_transform: If ``False``, no padding is applied to the input sequence. 
- and the output sequence will be smaller than the input sequence due to the - receptive field of the model. If the input sequence is ``n`` steps long, - and a model with receptive field ``m`` is used, the output sequence would - only be ``n-m+1`` steps long. history: Deprecated since 0.0.2. Use :py:attr:`log`. decode_history: Deprecated since 0.0.2. Use a hook during training for validation and decoding. See the arguments of :py:meth:`fit`. @@ -75,7 +70,6 @@ class Solver(abc.ABC, cebra.io.HasDevice): model: torch.nn.Module criterion: torch.nn.Module optimizer: torch.optim.Optimizer - pad_before_transform: bool = True history: List = dataclasses.field(default_factory=list) decode_history: List = dataclasses.field(default_factory=list) log: Dict = dataclasses.field(default_factory=lambda: ({ @@ -102,7 +96,6 @@ def state_dict(self) -> dict: return { "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), - "pad_before_transform": self.pad_before_transform, "loss": torch.tensor(self.history), "decode": self.decode_history, "criterion": self.criterion.state_dict(), @@ -138,8 +131,6 @@ def _get(key): self.criterion.load_state_dict(_get("criterion")) if _contains("optimizer"): self.optimizer.load_state_dict(_get("optimizer")) - if _contains("pad_before_transform"): - self.pad_before_transform = _get("pad_before_transform") # TODO(stes): This will be deprecated at some point; the "log" attribute # holds the same information. if _contains("loss"): @@ -294,95 +285,137 @@ def decoding(self, train_loader, valid_loader): ) return decode_metric - @torch.no_grad() - def _transform(self, inputs, session_id): + def _select_model(self, inputs: torch.Tensor, session_id: int): + is_multisession = False #TODO: take care of this + self.num_sessions = self.loader.dataset.num_sessions if is_multisession else None + if self.num_sessions is not None: # multisession implementation + if session_id is None: + raise RuntimeError( + "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." + ) + if session_id >= self.num_sessions or session_id < 0: + raise RuntimeError( + f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." + ) + if self.n_features_[session_id] != X.shape[1]: + raise ValueError( + f"Invalid input shape: model for session {session_id} requires an input of shape" + f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})." + ) - #model = self.select_model(n_inputs_features=inputs.shape[1], - # session_id=session_id) - #model.to(inputs.device) - #offset = model.get_offset() -# - #model.eval() -# - #if self.pad_before_transform: - # device = inputs.device - # inputs = np.pad(inputs.cpu().numpy(), - # ((offset.left, offset.right - 1), (0, 0)), - # mode="edge") - # inputs = torch.from_numpy(inputs).float().to(device) -# - #if isinstance(model, cebra.models.ConvolutionalModelMixin): - # # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - # inputs = inputs.transpose(1, 0).unsqueeze(0) - # output = model(inputs).squeeze(0).transpose(1, 0) - #else: - # # Standard evaluation, (T, C, dt) - # output = model(inputs) + model = self.model[session_id] + #model.to(self.device_) #TODO: do I need to do this? 
- output = self.model(inputs) - return output + else: # single session + if session_id is not None and session_id > 0: + raise RuntimeError( + f"Invalid session_id {session_id}: single session models only takes an optional null session_id." + ) + model = self.model + + offset = model.get_offset() + return model, offset + + def _get_batched_data_with_padding(self, + inputs: torch.Tensor, + offset: cebra.data.Offset, + start_batch_idx: int, + end_batch_idx: int, + batch_id: int, + num_batches: int) -> torch.Tensor: - def _get_batched_data_with_padding(self, inputs, offset, start_idx, end_idx, batch_id, num_batches): + """ + Given the start_batch_idx, end_batch_idx, adds padding. + For the first batch it adds 0 to left, data to right + For the last batch it adds data to left, 0 to right + For the middle batches if adds data both to left and right - if batch_id == 0: - batched_data = inputs[start_idx:(end_idx+offset.right)] + Args: + inputs + offset: + start_batch_idx: + end_batch_idx: + offset: cebra.datatypes.Offset + + """ + print(start_batch_idx, end_batch_idx) + if batch_id == 0: # First batch + batched_data = inputs[start_batch_idx:(end_batch_idx+offset.right-1)] batched_data = np.pad(batched_data.cpu().numpy(), ((offset.left, 0), (0, 0)), mode="edge") - + elif batch_id == num_batches - 1: #Last batch - batched_data = inputs[(start_idx-offset.left):end_idx] + batched_data = inputs[(start_batch_idx-offset.left):end_batch_idx] batched_data = np.pad(batched_data.cpu().numpy(), ((0, offset.right-1), (0, 0)), mode="edge") - - else: # Middle batches - batched_data = inputs[(start_idx-offset.left):(end_idx+offset.right-1)] + else: # Middle batches + batched_data = inputs[(start_batch_idx-offset.left):(end_batch_idx+offset.right-1)] + + print(inputs.shape, batched_data.shape) return torch.from_numpy(batched_data) if isinstance(batched_data, np.ndarray) else batched_data @torch.no_grad() - def _batched_transform(self, inputs, offset, session_id, batch_size): + def _batched_transform(self, model, inputs, offset, batch_size, pad_before_transform) -> torch.Tensor: num_samples = inputs.shape[0] num_batches = (num_samples + batch_size - 1) // batch_size output = [] for i in range(num_batches): - start_idx = i * batch_size - end_idx = min((i + 1) * batch_size, num_samples) - batched_data = inputs[start_idx:end_idx] - - if self.pad_before_transform: - batched_data = self._get_batched_data_with_padding(inputs, offset, start_idx, end_idx, i, num_batches) - - if isinstance(self.model, cebra.models.ConvolutionalModelMixin): + start_batch_idx = i * batch_size + end_batch_idx = min((i + 1) * batch_size, num_samples) + + if pad_before_transform: + batched_data = self._get_batched_data_with_padding( + inputs=inputs, + offset=offset, + start_batch_idx=start_batch_idx, + end_batch_idx=end_batch_idx, + batch_id=i, + num_batches=num_batches) + else: + batched_data = inputs[start_batch_idx:end_batch_idx] + + if isinstance(model, cebra.models.ConvolutionalModelMixin): # Fully convolutional evaluation, switch (T, C) -> (1, C, T) batched_data = batched_data.transpose(1, 0).unsqueeze(0) - output_batch = self.model(batched_data).squeeze(0).transpose(1, 0) + output_batch = model(batched_data).squeeze(0).transpose(1, 0) else: - output_batch = self.model(batched_data) + output_batch = model(batched_data) - output.append(output_batch) - output = torch.cat(output) return output - # OPTION 2 - #num_samples = inputs.shape[0] - #num_batches = (num_samples + batch_size - 1) // batch_size - #output = [self.model(inputs[i * 
batch_size : min((i + 1) * batch_size, num_samples)]) for i in range(num_batches)] - #output = torch.cat(output) - #return output + @torch.no_grad() + def _transform(self, model, inputs, offset, pad_before_transform) -> torch.Tensor: + + if pad_before_transform: + inputs = np.pad(inputs, ((offset.left, offset.right - 1), (0, 0)), mode="edge") + inputs = torch.from_numpy(inputs) + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + inputs = inputs.transpose(1, 0).unsqueeze(0) + output = model(inputs).squeeze(0).transpose(1, 0) + else: + output = model(inputs) + + return output @torch.no_grad() def transform(self, inputs: torch.Tensor, + pad_before_transform: bool = True, #TODO: what should be the default? session_id: Optional[int] = None, batch_size: Optional[int] = None) -> torch.Tensor: + + """Compute the embedding. This function by default only applies the ``forward`` function @@ -390,6 +423,11 @@ def transform(self, Args: inputs: The input signal + pad_before_transform: If ``False``, no padding is applied to the input sequence. + and the output sequence will be smaller than the input sequence due to the + receptive field of the model. If the input sequence is ``n`` steps long, + and a model with receptive field ``m`` is used, the output sequence would + only be ``n-m+1`` steps long. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. @@ -397,21 +435,27 @@ def transform(self, Returns: The output embedding. """ + model, offset = self._select_model(inputs, session_id) + model.eval() - offset = self.model.get_offset() - + if len(offset) < 2 and pad_before_transform: + raise ValueError("Padding does not make sense when the offset of the model is < 2") - if batch_size is not None: - #TODO: padding properly with convolutions!! - output = self._batched_transform(inputs, offset, session_id, batch_size) + output = self._batched_transform(model=model, + inputs=inputs, + offset=offset, + batch_size=batch_size, + pad_before_transform=pad_before_transform,) else: - output = self._transform(inputs, session_id) + output = self._transform(model=model, + inputs=inputs, + offset=offset, + pad_before_transform=pad_before_transform) return output - @abc.abstractmethod def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, return the model outputs. 
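
As an aside on the padding scheme the patches above implement: the sketch below is a minimal, self-contained illustration (not part of any patch) of why edge-padding the first batch only on the left, the last batch only on the right, and letting middle batches borrow real neighbouring samples reproduces the full-sequence result. The Conv1d model and the offset values (left=2, right=3) are illustrative assumptions, not CEBRA's actual models or defaults.

import numpy as np
import torch

torch.manual_seed(0)
left, right = 2, 3                        # assumed receptive-field offsets
conv = torch.nn.Conv1d(3, 4, kernel_size=left + right, bias=False)
inputs = torch.randn(100, 3)              # (time, channels)

def full_transform(x):
    # pad the whole sequence once with edge values, then run the conv: (T, C) -> (1, C, T) -> (T, out)
    x = np.pad(x.numpy(), ((left, right - 1), (0, 0)), mode="edge")
    x = torch.from_numpy(x).float()
    return conv(x.T.unsqueeze(0)).squeeze(0).T

def batched_transform(x, batch_size):
    n, outs = len(x), []
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        if start == 0:                    # first batch: edge-pad on the left only
            chunk = np.pad(x[start:end + right - 1].numpy(), ((left, 0), (0, 0)), mode="edge")
        elif end == n:                    # last batch: edge-pad on the right only
            chunk = np.pad(x[start - left:end].numpy(), ((0, right - 1), (0, 0)), mode="edge")
        else:                             # middle batches: borrow real samples on both sides
            chunk = x[start - left:end + right - 1].numpy()
        chunk = torch.from_numpy(chunk).float()
        outs.append(conv(chunk.T.unsqueeze(0)).squeeze(0).T)
    return torch.cat(outs)

with torch.no_grad():
    assert torch.allclose(full_transform(inputs), batched_transform(inputs, 32), atol=1e-5)

Each batch contributes exactly batch_size embedding steps, so the concatenated batched output matches the single-pass output sample for sample.
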
From 8d5b114e085bfa0080cb57623bd4f1c058795670 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 27 Sep 2023 17:58:19 +0200 Subject: [PATCH 05/45] remove print statements --- cebra/solver/base.py | 137 ++++++++++++++++++++++--------------------- 1 file changed, 70 insertions(+), 67 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 21b40d14..a243fe2e 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -35,6 +35,7 @@ from typing import Callable, Dict, List, Literal, Optional, Union import literate_dataclasses as dataclasses +import numpy as np import torch import tqdm @@ -44,7 +45,6 @@ import cebra.models from cebra.solver.util import Meter from cebra.solver.util import ProgressBar -import numpy as np @dataclasses.dataclass @@ -286,7 +286,7 @@ def decoding(self, train_loader, valid_loader): return decode_metric def _select_model(self, inputs: torch.Tensor, session_id: int): - is_multisession = False #TODO: take care of this + is_multisession = False #TODO: take care of this self.num_sessions = self.loader.dataset.num_sessions if is_multisession else None if self.num_sessions is not None: # multisession implementation if session_id is None: @@ -305,7 +305,7 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): model = self.model[session_id] #model.to(self.device_) #TODO: do I need to do this? - + else: # single session if session_id is not None and session_id > 0: raise RuntimeError( @@ -315,16 +315,12 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): offset = model.get_offset() return model, offset - - - def _get_batched_data_with_padding(self, - inputs: torch.Tensor, - offset: cebra.data.Offset, - start_batch_idx: int, - end_batch_idx: int, - batch_id: int, - num_batches: int) -> torch.Tensor: + def _get_batched_data_with_padding(self, inputs: torch.Tensor, + offset: cebra.data.Offset, + start_batch_idx: int, end_batch_idx: int, + batch_id: int, + num_batches: int) -> torch.Tensor: """ Given the start_batch_idx, end_batch_idx, adds padding. 
For the first batch it adds 0 to left, data to right @@ -332,35 +328,37 @@ def _get_batched_data_with_padding(self, For the middle batches if adds data both to left and right Args: - inputs - offset: - start_batch_idx: - end_batch_idx: + inputs + offset: + start_batch_idx: + end_batch_idx: offset: cebra.datatypes.Offset """ - print(start_batch_idx, end_batch_idx) - if batch_id == 0: # First batch - batched_data = inputs[start_batch_idx:(end_batch_idx+offset.right-1)] + if batch_id == 0: # First batch + batched_data = inputs[start_batch_idx:(end_batch_idx + + offset.right - 1)] batched_data = np.pad(batched_data.cpu().numpy(), - ((offset.left, 0), (0, 0)), - mode="edge") + ((offset.left, 0), (0, 0)), + mode="edge") - elif batch_id == num_batches - 1: #Last batch - batched_data = inputs[(start_batch_idx-offset.left):end_batch_idx] + elif batch_id == num_batches - 1: #Last batch + batched_data = inputs[(start_batch_idx - offset.left):end_batch_idx] batched_data = np.pad(batched_data.cpu().numpy(), - ((0, offset.right-1), (0, 0)), - mode="edge") + ((0, offset.right - 1), (0, 0)), + mode="edge") + + else: # Middle batches + batched_data = inputs[(start_batch_idx - + offset.left):(end_batch_idx + offset.right - + 1)] - else: # Middle batches - batched_data = inputs[(start_batch_idx-offset.left):(end_batch_idx+offset.right-1)] - - print(inputs.shape, batched_data.shape) - return torch.from_numpy(batched_data) if isinstance(batched_data, np.ndarray) else batched_data - + return torch.from_numpy(batched_data) if isinstance( + batched_data, np.ndarray) else batched_data @torch.no_grad() - def _batched_transform(self, model, inputs, offset, batch_size, pad_before_transform) -> torch.Tensor: + def _batched_transform(self, model, inputs, offset, batch_size, + pad_before_transform) -> torch.Tensor: num_samples = inputs.shape[0] num_batches = (num_samples + batch_size - 1) // batch_size output = [] @@ -368,35 +366,37 @@ def _batched_transform(self, model, inputs, offset, batch_size, pad_before_trans for i in range(num_batches): start_batch_idx = i * batch_size end_batch_idx = min((i + 1) * batch_size, num_samples) - + if pad_before_transform: batched_data = self._get_batched_data_with_padding( - inputs=inputs, - offset=offset, - start_batch_idx=start_batch_idx, - end_batch_idx=end_batch_idx, - batch_id=i, - num_batches=num_batches) + inputs=inputs, + offset=offset, + start_batch_idx=start_batch_idx, + end_batch_idx=end_batch_idx, + batch_id=i, + num_batches=num_batches) else: batched_data = inputs[start_batch_idx:end_batch_idx] - + if isinstance(model, cebra.models.ConvolutionalModelMixin): # Fully convolutional evaluation, switch (T, C) -> (1, C, T) batched_data = batched_data.transpose(1, 0).unsqueeze(0) output_batch = model(batched_data).squeeze(0).transpose(1, 0) else: output_batch = model(batched_data) - + output.append(output_batch) output = torch.cat(output) - + return output @torch.no_grad() - def _transform(self, model, inputs, offset, pad_before_transform) -> torch.Tensor: - + def _transform(self, model, inputs, offset, + pad_before_transform) -> torch.Tensor: + if pad_before_transform: - inputs = np.pad(inputs, ((offset.left, offset.right - 1), (0, 0)), mode="edge") + inputs = np.pad(inputs, ((offset.left, offset.right - 1), (0, 0)), + mode="edge") inputs = torch.from_numpy(inputs) if isinstance(model, cebra.models.ConvolutionalModelMixin): @@ -405,17 +405,16 @@ def _transform(self, model, inputs, offset, pad_before_transform) -> torch.Tenso output = model(inputs).squeeze(0).transpose(1, 0) 
else: output = model(inputs) - + return output @torch.no_grad() - def transform(self, - inputs: torch.Tensor, - pad_before_transform: bool = True, #TODO: what should be the default? - session_id: Optional[int] = None, - batch_size: Optional[int] = None) -> torch.Tensor: - - + def transform( + self, + inputs: torch.Tensor, + pad_before_transform: bool = True, #TODO: what should be the default? + session_id: Optional[int] = None, + batch_size: Optional[int] = None) -> torch.Tensor: """Compute the embedding. This function by default only applies the ``forward`` function @@ -424,14 +423,14 @@ def transform(self, Args: inputs: The input signal pad_before_transform: If ``False``, no padding is applied to the input sequence. - and the output sequence will be smaller than the input sequence due to the - receptive field of the model. If the input sequence is ``n`` steps long, - and a model with receptive field ``m`` is used, the output sequence would + and the output sequence will be smaller than the input sequence due to the + receptive field of the model. If the input sequence is ``n`` steps long, + and a model with receptive field ``m`` is used, the output sequence would only be ``n-m+1`` steps long. - session_id: The session ID, an :py:class:`int` between 0 and - the number of sessions -1 for multisession, and set to + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to ``None`` for single session. - + Returns: The output embedding. """ @@ -439,14 +438,18 @@ def transform(self, model.eval() if len(offset) < 2 and pad_before_transform: - raise ValueError("Padding does not make sense when the offset of the model is < 2") - + raise ValueError( + "Padding does not make sense when the offset of the model is < 2" + ) + if batch_size is not None: - output = self._batched_transform(model=model, - inputs=inputs, - offset=offset, - batch_size=batch_size, - pad_before_transform=pad_before_transform,) + output = self._batched_transform( + model=model, + inputs=inputs, + offset=offset, + batch_size=batch_size, + pad_before_transform=pad_before_transform, + ) else: output = self._transform(model=model, @@ -455,7 +458,7 @@ def transform(self, pad_before_transform=pad_before_transform) return output - + @abc.abstractmethod def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, return the model outputs. 
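
The docstrings above keep repeating the same length bookkeeping: without padding, a model with receptive field m maps an n-step input to n - m + 1 embedding steps, while edge padding with (left, right - 1) restores the full length n. A quick stand-alone check of that arithmetic, with a plain Conv1d standing in for a CEBRA model (the numbers n=50, m=10 are arbitrary assumptions):

import torch

n, m = 50, 10                                   # sequence length and receptive field (assumed)
model = torch.nn.Conv1d(1, 4, kernel_size=m)
x = torch.randn(n, 1)

with torch.no_grad():
    unpadded = model(x.T.unsqueeze(0)).squeeze(0).T
    left, right = m // 2, m - m // 2            # split the receptive field like an Offset(left, right)
    padded_in = torch.cat([x[:1].repeat(left, 1), x, x[-1:].repeat(right - 1, 1)])
    padded = model(padded_in.T.unsqueeze(0)).squeeze(0).T

assert unpadded.shape[0] == n - m + 1           # 41 embedding steps
assert padded.shape[0] == n                     # 50 embedding steps, same length as the input
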
From 32c5ecd28d4ebce8b1063d18cd5a849327e85b76 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 27 Sep 2023 18:22:12 +0200 Subject: [PATCH 06/45] first passing test --- tests/test_solver.py | 61 +++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/tests/test_solver.py b/tests/test_solver.py index 633c1df0..06fea193 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -21,6 +21,7 @@ # import itertools +import numpy as np import pytest import torch from torch import nn @@ -29,7 +30,6 @@ import cebra.datasets import cebra.models import cebra.solver -import numpy as np device = "cpu" @@ -171,34 +171,53 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): solver.fit(loader) -def test_batched_transform(data_name, loader_initfunc, solver_initfunc): - """ - test to know if we are getting the batches right without padding - """ +def create_model(model_name, dataset): + return cebra.models.init(model_name, + num_neurons=dataset.input_dimension, + num_units=128, + num_output=5) + + +single_session_tests_transform = [] +for model_name in ["offset1-model", "offset10-model"]: + for args in [ + ("demo-discrete", model_name, cebra.data.DiscreteDataLoader), + ("demo-continuous", model_name, cebra.data.ContinuousDataLoader), + ("demo-mixed", model_name, cebra.data.MixedDataLoader), + ]: + single_session_tests_transform.append( + (*args, cebra.solver.SingleSessionSolver)) + + +@pytest.mark.parametrize( + "data_name, model_name, loader_initfunc, solver_initfunc", + single_session_tests_transform) +def test_batched_transform_no_padding(data_name, model_name, loader_initfunc, + solver_initfunc): + batch_size = 1024 + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset) + dataset.offset = model.get_offset() + loader_kwargs = dict(num_steps=10, batch_size=32) + loader = loader_initfunc(dataset, **loader_kwargs) - loader = _get_loader(data_name, loader_initfunc) - model = _make_model(loader.dataset) criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, criterion=criterion, - optimizer=optimizer, - pad_before_transform = False) - + optimizer=optimizer) solver.fit(loader) - # batched_transform - batch_size = 1024 - - # should pad_before_transform be an argument of the transform() method? - embedding_batched = solver.transform(batch_size = batch_size) - embedding = solver.transform(batch_size = None) + embedding_batched = solver.transform(inputs=loader.dataset.neural, + batch_size=batch_size, + pad_before_transform=False) - assert embedding_batched.shape == embedding.shape - assert np.allclose(embedding_batched, embedding) + embedding = solver.transform(inputs=loader.dataset.neural, + pad_before_transform=False) + if not isinstance(model, cebra.models.ConvolutionalModelMixin): + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-02) - # TODO: how can I check that the batches are correct? - # maybe it is good enough if I compare to the embedding - # without batch size. \ No newline at end of file + #TODO: what tests can I do with convolutional models when there is no padding? 
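
The test added above compares the batched path against the unbatched one. The underlying invariant is worth spelling out: for a per-sample model (an "offset1" model with receptive field 1), splitting the input into batches cannot change the embedding at all, so the two paths must agree up to floating-point noise. A toy version of that check, with a small MLP as a hypothetical stand-in for an offset1 model:

import torch

torch.manual_seed(0)
toy_net = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.GELU(), torch.nn.Linear(8, 4))
x = torch.randn(2000, 3)

with torch.no_grad():
    full = toy_net(x)                                                  # transform in one go
    batched = torch.cat([toy_net(x[i:i + 512]) for i in range(0, len(x), 512)])

assert batched.shape == full.shape
assert torch.allclose(batched, full, atol=1e-6)

For convolutional models the outputs near batch borders only match once the per-batch padding from the earlier patches is applied, which is why the test above still carries a TODO for the unpadded convolutional case.
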
From 9928f635a0deaa8d8f6c95b91b38816b783eba4e Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 28 Sep 2023 11:48:47 +0200 Subject: [PATCH 07/45] add support for hybrid models --- cebra/solver/base.py | 19 ++++-- tests/test_solver.py | 138 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 126 insertions(+), 31 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index a243fe2e..125c25c8 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -286,8 +286,10 @@ def decoding(self, train_loader, valid_loader): return decode_metric def _select_model(self, inputs: torch.Tensor, session_id: int): - is_multisession = False #TODO: take care of this - self.num_sessions = self.loader.dataset.num_sessions if is_multisession else None + """ Select the right model based on the type of solver we have.""" + + self.num_sessions = self.loader.dataset.num_sessions if isinstance( + inputs, list) else None if self.num_sessions is not None: # multisession implementation if session_id is None: raise RuntimeError( @@ -304,14 +306,23 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): ) model = self.model[session_id] - #model.to(self.device_) #TODO: do I need to do this? + model.to(self.device_) #TODO: why do I need to do this? else: # single session if session_id is not None and session_id > 0: raise RuntimeError( f"Invalid session_id {session_id}: single session models only takes an optional null session_id." ) - model = self.model + + if isinstance( + self, + cebra.solver.single_session.SingleSessionHybridSolver): + # NOTE: This is different from the sklearn API implementation. The issue is that here the + # model is a cebra.models.MultiObjective instance, and therefore to do inference I need + # to get the module inside this model. 
+ model = self.model.module + else: + model = self.model offset = model.get_offset() return model, offset diff --git a/tests/test_solver.py b/tests/test_solver.py index 06fea193..5412b697 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -171,32 +171,51 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): solver.fit(loader) -def create_model(model_name, dataset): +def create_model(model_name, input_dimension): return cebra.models.init(model_name, - num_neurons=dataset.input_dimension, + num_neurons=input_dimension, num_units=128, num_output=5) single_session_tests_transform = [] -for model_name in ["offset1-model", "offset10-model"]: - for args in [ - ("demo-discrete", model_name, cebra.data.DiscreteDataLoader), - ("demo-continuous", model_name, cebra.data.ContinuousDataLoader), - ("demo-mixed", model_name, cebra.data.MixedDataLoader), - ]: - single_session_tests_transform.append( - (*args, cebra.solver.SingleSessionSolver)) +for padding in [True, False]: + for model_name in ["offset1-model", "offset10-model"]: + for args in [ + ("demo-discrete", model_name, padding, + cebra.data.DiscreteDataLoader), + ("demo-continuous", model_name, padding, + cebra.data.ContinuousDataLoader), + ("demo-mixed", model_name, padding, cebra.data.MixedDataLoader), + ]: + single_session_tests_transform.append( + (*args, cebra.solver.SingleSessionSolver)) + +single_session_hybrid_tests_transform = [] +for padding in [True, False]: + for model_name in ["offset1-model", "offset10-model"]: + for args in [("demo-continuous", model_name, padding, + cebra.data.HybridDataLoader)]: + single_session_hybrid_tests_transform.append( + (*args, cebra.solver.SingleSessionHybridSolver)) + +multi_session_tests_transform = [] +for padding in [True, False]: + for model_name in ["offset1-model", "offset10-model"]: + for args in [("demo-continuous-multisession", model_name, padding, + cebra.data.ContinuousMultiSessionDataLoader)]: + multi_session_tests_transform.append( + (*args, cebra.solver.MultiSessionSolver)) @pytest.mark.parametrize( - "data_name, model_name, loader_initfunc, solver_initfunc", - single_session_tests_transform) -def test_batched_transform_no_padding(data_name, model_name, loader_initfunc, - solver_initfunc): + "data_name, model_name, padding, loader_initfunc, solver_initfunc", + single_session_tests_transform + single_session_hybrid_tests_transform) +def test_batched_transform_singlesession(data_name, model_name, padding, + loader_initfunc, solver_initfunc): batch_size = 1024 dataset = cebra.datasets.init(data_name) - model = create_model(model_name, dataset) + model = create_model(model_name, dataset.input_dimension) dataset.offset = model.get_offset() loader_kwargs = dict(num_steps=10, batch_size=32) loader = loader_initfunc(dataset, **loader_kwargs) @@ -209,15 +228,80 @@ def test_batched_transform_no_padding(data_name, model_name, loader_initfunc, optimizer=optimizer) solver.fit(loader) - embedding_batched = solver.transform(inputs=loader.dataset.neural, - batch_size=batch_size, - pad_before_transform=False) - - embedding = solver.transform(inputs=loader.dataset.neural, - pad_before_transform=False) - - if not isinstance(model, cebra.models.ConvolutionalModelMixin): - assert embedding_batched.shape == embedding.shape - assert np.allclose(embedding_batched, embedding, rtol=1e-02) - - #TODO: what tests can I do with convolutional models when there is no padding? 
+ if len(model.get_offset()) < 2 and padding: + with pytest.raises(ValueError): + solver.transform(inputs=loader.dataset.neural, + pad_before_transform=padding) + + with pytest.raises(ValueError): + solver.transform(inputs=loader.dataset.neural, + batch_size=batch_size, + pad_before_transform=padding) + else: + embedding_batched = solver.transform(inputs=loader.dataset.neural, + batch_size=batch_size, + pad_before_transform=padding) + + embedding = solver.transform(inputs=loader.dataset.neural, + pad_before_transform=padding) + + if padding: + if isinstance(model, cebra.models.ConvolutionalModelMixin): + assert embedding_batched.shape == embedding.shape + assert embedding_batched.shape == embedding.shape + + else: + if isinstance(model, cebra.models.ConvolutionalModelMixin): + #TODO: what to check here exactly? + pass + else: + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-02) + + +# def test_batched_transform_multisession(data_name, model_name, padding, loader_initfunc, solver_initfunc): +# batch_size = 1024 +# dataset = cebra.datasets.init(data_name) +# model = nn.ModuleList( +# [create_model(model_name, dataset.input_dimension) for dataset in dataset.iter_sessions()]) +# dataset.offset = model[0].get_offset() +# loader_kwargs = dict(num_steps=10, batch_size=32) +# loader = loader_initfunc(dataset, **loader_kwargs) + +# criterion = cebra.models.InfoNCE() +# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# solver = solver_initfunc(model=model, +# criterion=criterion, +# optimizer=optimizer) +# solver.fit(loader) + +# if len(model.get_offset()) < 2 and padding: +# with pytest.raises(ValueError): +# solver.transform(inputs=loader.dataset.neural, +# pad_before_transform=padding) + +# with pytest.raises(ValueError): +# solver.transform(inputs=loader.dataset.neural, +# batch_size=batch_size, +# pad_before_transform=padding) +# else: +# embedding_batched = solver.transform(inputs=loader.dataset.neural, +# batch_size=batch_size, +# pad_before_transform=padding) + +# embedding = solver.transform(inputs=loader.dataset.neural, +# pad_before_transform=padding) + +# if padding: +# if isinstance(model, cebra.models.ConvolutionalModelMixin): +# assert embedding_batched.shape == embedding.shape +# assert embedding_batched.shape == embedding.shape + +# else: +# if isinstance(model, cebra.models.ConvolutionalModelMixin): +# #TODO: what to check here exactly? +# pass +# else: +# assert embedding_batched.shape == embedding.shape +# assert np.allclose(embedding_batched, embedding, rtol=1e-02) From be5630aed262e9036523e1727f748e977df7b5f7 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 28 Sep 2023 13:40:13 +0200 Subject: [PATCH 08/45] rewrite transform in sklearn API --- cebra/integrations/sklearn/cebra.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 077d3c47..2c9eba2b 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1200,11 +1200,17 @@ def fit( def transform(self, X: Union[npt.NDArray, torch.Tensor], + pad_before_transform: bool = True, session_id: Optional[int] = None) -> npt.NDArray: """Transform an input sequence and return the embedding. Args: X: A numpy array or torch tensor of size ``time x dimension``. + pad_before_transform: If ``False``, no padding is applied to the input sequence. 
+ and the output sequence will be smaller than the input sequence due to the + receptive field of the model. If the input sequence is ``n`` steps long, + and a model with receptive field ``m`` is used, the output sequence would + only be ``n-m+1`` steps long. session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for multisession, set to ``None`` for single session. @@ -1224,27 +1230,13 @@ def transform(self, """ sklearn_utils_validation.check_is_fitted(self, "n_features_") - model, offset = self._select_model(X, session_id) - # Input validation X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) input_dtype = X.dtype with torch.no_grad(): - model.eval() - - if self.pad_before_transform: - X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), - mode="edge") - X = torch.from_numpy(X).float().to(self.device_) - - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - X = X.transpose(1, 0).unsqueeze(0) - output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) - else: - # Standard evaluation, (T, C, dt) - output = model(X).cpu().numpy() + output = self.solver_.transform( + X, pad_before_transform=pad_before_transform) if input_dtype == "float64": return output.astype(input_dtype) From 1300b2052ccc27d2eb7077de145c0f662202cd29 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 16 Oct 2023 16:41:25 +0200 Subject: [PATCH 09/45] baseline version of a torch.Datset --- cebra/solver/util.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/cebra/solver/util.py b/cebra/solver/util.py index 584eb0da..c7dc7533 100644 --- a/cebra/solver/util.py +++ b/cebra/solver/util.py @@ -25,6 +25,7 @@ from typing import Dict import literate_dataclasses as dataclasses +import torch import tqdm @@ -106,3 +107,31 @@ def set_description(self, stats: Dict[str, float]): """ if self.use_tqdm: self.iterator.set_description(_description(stats)) + + +def initalize_torch_dataloader(inputs: torch.Tensor, batch_size: int): + """ + Initializes a torch DataLoader. + Args: + inputs: NxD tensor + batch_size: what happens when is None? it should return the whole dataset. + """ + + class TorchDataset(torch.utils.data.Dataset): + + def __init__(self, inputs): + self.inputs = inputs + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + return self.data[idx] + + # TODO: I need to implement the padding inside the dataset, otherwise + # I can't properly do this afterwards I think. + + # I wrote the simplest version possible of a torch.utils.data.Dataset, + # but should be extended with the padding. 
+ + return torch.util.data.DataLoader(TorchDataset, batch_size=batch_size) From bc6af241dceb8183a142627f756cc2c6d4c2973a Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 20 Oct 2023 17:22:58 +0200 Subject: [PATCH 10/45] move batching logic outside solver --- cebra/solver/base.py | 97 +++++++++++--------------------------------- cebra/solver/util.py | 65 +++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 90 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 125c25c8..b282b27f 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -43,6 +43,7 @@ import cebra.data import cebra.io import cebra.models +import cebra.solver.util as cebra_solver_util from cebra.solver.util import Meter from cebra.solver.util import ProgressBar @@ -285,6 +286,17 @@ def decoding(self, train_loader, valid_loader): ) return decode_metric + def _inference_transform(self, model, inputs): + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + inputs = inputs.transpose(1, 0).unsqueeze(0) + output = model(inputs).squeeze(0).transpose(1, 0) + else: + output = model(inputs) + + return output + def _select_model(self, inputs: torch.Tensor, session_id: int): """ Select the right model based on the type of solver we have.""" @@ -327,78 +339,23 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): offset = model.get_offset() return model, offset - def _get_batched_data_with_padding(self, inputs: torch.Tensor, - offset: cebra.data.Offset, - start_batch_idx: int, end_batch_idx: int, - batch_id: int, - num_batches: int) -> torch.Tensor: - """ - Given the start_batch_idx, end_batch_idx, adds padding. - For the first batch it adds 0 to left, data to right - For the last batch it adds data to left, 0 to right - For the middle batches if adds data both to left and right - - Args: - inputs - offset: - start_batch_idx: - end_batch_idx: - offset: cebra.datatypes.Offset - - """ - if batch_id == 0: # First batch - batched_data = inputs[start_batch_idx:(end_batch_idx + - offset.right - 1)] - batched_data = np.pad(batched_data.cpu().numpy(), - ((offset.left, 0), (0, 0)), - mode="edge") - - elif batch_id == num_batches - 1: #Last batch - batched_data = inputs[(start_batch_idx - offset.left):end_batch_idx] - batched_data = np.pad(batched_data.cpu().numpy(), - ((0, offset.right - 1), (0, 0)), - mode="edge") - - else: # Middle batches - batched_data = inputs[(start_batch_idx - - offset.left):(end_batch_idx + offset.right - - 1)] - - return torch.from_numpy(batched_data) if isinstance( - batched_data, np.ndarray) else batched_data - @torch.no_grad() def _batched_transform(self, model, inputs, offset, batch_size, pad_before_transform) -> torch.Tensor: - num_samples = inputs.shape[0] - num_batches = (num_samples + batch_size - 1) // batch_size output = [] - - for i in range(num_batches): - start_batch_idx = i * batch_size - end_batch_idx = min((i + 1) * batch_size, num_samples) - - if pad_before_transform: - batched_data = self._get_batched_data_with_padding( - inputs=inputs, - offset=offset, - start_batch_idx=start_batch_idx, - end_batch_idx=end_batch_idx, - batch_id=i, - num_batches=num_batches) - else: - batched_data = inputs[start_batch_idx:end_batch_idx] - - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - batched_data = batched_data.transpose(1, 0).unsqueeze(0) - output_batch = model(batched_data).squeeze(0).transpose(1, 0) - else: - 
output_batch = model(batched_data) - + batches = cebra_solver_util.get_batches_of_data( + inputs=inputs, + batch_size=batch_size, + padding=pad_before_transform, + offset=offset) + + # NOTE: If we move this inside the `cebra_solver_util.get_batches_of_data`or similar + # we avoid a second for loop. Is it good practice to do inference outside the solver? + for batch in batches: + output_batch = self._inference_transform(model, batch) output.append(output_batch) - output = torch.cat(output) + output = torch.cat(output) return output @torch.no_grad() @@ -410,13 +367,7 @@ def _transform(self, model, inputs, offset, mode="edge") inputs = torch.from_numpy(inputs) - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - inputs = inputs.transpose(1, 0).unsqueeze(0) - output = model(inputs).squeeze(0).transpose(1, 0) - else: - output = model(inputs) - + output = self._inference_transform(model, inputs) return output @torch.no_grad() diff --git a/cebra/solver/util.py b/cebra/solver/util.py index c7dc7533..4137dab7 100644 --- a/cebra/solver/util.py +++ b/cebra/solver/util.py @@ -25,8 +25,13 @@ from typing import Dict import literate_dataclasses as dataclasses +import numpy as np import torch import tqdm +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +import cebra.data def _description(stats: Dict[str, float]): @@ -109,15 +114,13 @@ def set_description(self, stats: Dict[str, float]): self.iterator.set_description(_description(stats)) -def initalize_torch_dataloader(inputs: torch.Tensor, batch_size: int): - """ - Initializes a torch DataLoader. - Args: - inputs: NxD tensor - batch_size: what happens when is None? it should return the whole dataset. - """ +def get_batches_of_data(inputs: torch.Tensor, + batch_size: int, + padding: bool, + offset: cebra.data.Offset = None): + batches = [] - class TorchDataset(torch.utils.data.Dataset): + class IndexDataset(Dataset): def __init__(self, inputs): self.inputs = inputs @@ -126,12 +129,40 @@ def __len__(self): return len(self.inputs) def __getitem__(self, idx): - return self.data[idx] - - # TODO: I need to implement the padding inside the dataset, otherwise - # I can't properly do this afterwards I think. - - # I wrote the simplest version possible of a torch.utils.data.Dataset, - # but should be extended with the padding. 
- - return torch.util.data.DataLoader(TorchDataset, batch_size=batch_size) + return idx + + index_dataset = IndexDataset(inputs) + index_dataloader = DataLoader(index_dataset, batch_size=batch_size) + for batch_id, index_batch in enumerate(index_dataloader): + + start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] + if padding: + if offset is None: + raise ValueError("offset needs to be set if padding is True.") + + if batch_id == 0: + indices = start_batch_idx, (end_batch_idx + offset.right) + batched_data = inputs[slice(*indices)] + batched_data = np.pad(batched_data.cpu().numpy(), + ((offset.left, 0), (0, 0)), + mode="edge") + + elif batch_id == len(index_dataloader) - 1: + indices = (start_batch_idx - offset.left), end_batch_idx + batched_data = inputs[slice(*indices)] + batched_data = np.pad(batched_data.cpu().numpy(), + ((0, offset.right), (0, 0)), + mode="edge") + else: # Middle batches + indices = start_batch_idx - offset.left, end_batch_idx + offset.right + batched_data = inputs[slice(*indices)] + + else: + indices = start_batch_idx, end_batch_idx + batched_data = inputs[slice(*indices)] + + batched_data = torch.from_numpy(batched_data) if isinstance( + batched_data, np.ndarray) else batched_data + batches.append(batched_data) + + return batches From ec377b9fca5c11b8325c0de3bda11ec5a85c2e6c Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 27 Oct 2023 13:43:05 +0200 Subject: [PATCH 11/45] move functionality to base file in solver and separate in functions --- cebra/solver/base.py | 139 ++++++++++++++++++++++++++++++++----------- cebra/solver/util.py | 58 ------------------ 2 files changed, 105 insertions(+), 92 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index b282b27f..d38d8c88 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -38,6 +38,8 @@ import numpy as np import torch import tqdm +from torch.utils.data import DataLoader +from torch.utils.data import Dataset import cebra import cebra.data @@ -48,6 +50,102 @@ from cebra.solver.util import ProgressBar +def _inference_transform(model, inputs): + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + inputs = inputs.transpose(1, 0).unsqueeze(0) + output = model(inputs).squeeze(0).transpose(1, 0) + else: + output = model(inputs) + return output + + +def _process_batch(inputs: torch.Tensor, add_padding: bool, + offset: cebra.data.Offset, start_batch_idx: int, + end_batch_idx: int) -> torch.Tensor: + """ + Process a batch of input data, optionally applying padding based on specified parameters. + + Args: + inputs: The input data to be processed. + add_padding: Indicates whether padding should be applied before inference. + offset: Offset configuration for padding. If add_padding is True, + offset must be set. If add_padding is False, offset is not used and can be None. + start_batch_idx: The starting index of the current batch. + end_batch_idx: The last index of the current batch. + + Returns: + torch.Tensor: The (potentially) padded data. + + Raises: + ValueError: If pad_beforadd_paddinge_transform is True and offset is not provided. 
+ """ + + if add_padding: + if offset is None: + raise ValueError("offset needs to be set if add_padding is True.") + + if start_batch_idx == 0: # First batch + indices = start_batch_idx, (end_batch_idx + offset.right - 1) + batched_data = inputs[slice(*indices)] + batched_data = np.pad(batched_data.cpu().numpy(), + ((offset.left, 0), (0, 0)), + mode="edge") + + elif end_batch_idx == len(inputs): # Last batch + indices = (start_batch_idx - offset.left), end_batch_idx + batched_data = inputs[slice(*indices)] + batched_data = np.pad(batched_data.cpu().numpy(), + ((0, offset.right - 1), (0, 0)), + mode="edge") + else: # Middle batches + indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 + batched_data = inputs[slice(*indices)] + + else: + indices = start_batch_idx, end_batch_idx + batched_data = inputs[slice(*indices)] + + batched_data = torch.from_numpy(batched_data) if isinstance( + batched_data, np.ndarray) else batched_data + return batched_data + + +def _batched_transform(model, + inputs: torch.Tensor, + batch_size: int, + pad_before_transform: bool, + offset=None) -> torch.Tensor: + + class IndexDataset(Dataset): + + def __init__(self, inputs): + self.inputs = inputs + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + return idx + + index_dataset = IndexDataset(inputs) + index_dataloader = DataLoader(index_dataset, batch_size=batch_size) + + output = [] + for batch_id, index_batch in enumerate(index_dataloader): + start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] + 1 + batched_data = _process_batch(inputs=inputs, + add_padding=pad_before_transform, + offset=offset, + start_batch_idx=start_batch_idx, + end_batch_idx=end_batch_idx) + output_batch = _inference_transform(model, batched_data) + output.append(output_batch) + + output = torch.cat(output) + return output + + @dataclasses.dataclass class Solver(abc.ABC, cebra.io.HasDevice): """Solver base class. @@ -286,22 +384,14 @@ def decoding(self, train_loader, valid_loader): ) return decode_metric - def _inference_transform(self, model, inputs): - - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - inputs = inputs.transpose(1, 0).unsqueeze(0) - output = model(inputs).squeeze(0).transpose(1, 0) - else: - output = model(inputs) - - return output - def _select_model(self, inputs: torch.Tensor, session_id: int): + #NOTE: In the torch API the inputs will be a torch tensor. Then in the + # sklearn API we will convert it to numpy array. """ Select the right model based on the type of solver we have.""" - self.num_sessions = self.loader.dataset.num_sessions if isinstance( - inputs, list) else None + # before: self.loader.dataset.num_sessions + self.num_sessions = len(inputs) if isinstance(inputs, list) else None + if self.num_sessions is not None: # multisession implementation if session_id is None: raise RuntimeError( @@ -339,25 +429,6 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): offset = model.get_offset() return model, offset - @torch.no_grad() - def _batched_transform(self, model, inputs, offset, batch_size, - pad_before_transform) -> torch.Tensor: - output = [] - batches = cebra_solver_util.get_batches_of_data( - inputs=inputs, - batch_size=batch_size, - padding=pad_before_transform, - offset=offset) - - # NOTE: If we move this inside the `cebra_solver_util.get_batches_of_data`or similar - # we avoid a second for loop. Is it good practice to do inference outside the solver? 
- for batch in batches: - output_batch = self._inference_transform(model, batch) - output.append(output_batch) - - output = torch.cat(output) - return output - @torch.no_grad() def _transform(self, model, inputs, offset, pad_before_transform) -> torch.Tensor: @@ -367,7 +438,7 @@ def _transform(self, model, inputs, offset, mode="edge") inputs = torch.from_numpy(inputs) - output = self._inference_transform(model, inputs) + output = _inference_transform(model, inputs) return output @torch.no_grad() @@ -405,7 +476,7 @@ def transform( ) if batch_size is not None: - output = self._batched_transform( + output = _batched_transform( model=model, inputs=inputs, offset=offset, diff --git a/cebra/solver/util.py b/cebra/solver/util.py index 4137dab7..af9529f7 100644 --- a/cebra/solver/util.py +++ b/cebra/solver/util.py @@ -28,10 +28,6 @@ import numpy as np import torch import tqdm -from torch.utils.data import DataLoader -from torch.utils.data import Dataset - -import cebra.data def _description(stats: Dict[str, float]): @@ -112,57 +108,3 @@ def set_description(self, stats: Dict[str, float]): """ if self.use_tqdm: self.iterator.set_description(_description(stats)) - - -def get_batches_of_data(inputs: torch.Tensor, - batch_size: int, - padding: bool, - offset: cebra.data.Offset = None): - batches = [] - - class IndexDataset(Dataset): - - def __init__(self, inputs): - self.inputs = inputs - - def __len__(self): - return len(self.inputs) - - def __getitem__(self, idx): - return idx - - index_dataset = IndexDataset(inputs) - index_dataloader = DataLoader(index_dataset, batch_size=batch_size) - for batch_id, index_batch in enumerate(index_dataloader): - - start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] - if padding: - if offset is None: - raise ValueError("offset needs to be set if padding is True.") - - if batch_id == 0: - indices = start_batch_idx, (end_batch_idx + offset.right) - batched_data = inputs[slice(*indices)] - batched_data = np.pad(batched_data.cpu().numpy(), - ((offset.left, 0), (0, 0)), - mode="edge") - - elif batch_id == len(index_dataloader) - 1: - indices = (start_batch_idx - offset.left), end_batch_idx - batched_data = inputs[slice(*indices)] - batched_data = np.pad(batched_data.cpu().numpy(), - ((0, offset.right), (0, 0)), - mode="edge") - else: # Middle batches - indices = start_batch_idx - offset.left, end_batch_idx + offset.right - batched_data = inputs[slice(*indices)] - - else: - indices = start_batch_idx, end_batch_idx - batched_data = inputs[slice(*indices)] - - batched_data = torch.from_numpy(batched_data) if isinstance( - batched_data, np.ndarray) else batched_data - batches.append(batched_data) - - return batches From 6f9ca989dacbc878bdc3a26410761ff06809830e Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 27 Oct 2023 13:43:32 +0200 Subject: [PATCH 12/45] add test_select_model for single session --- tests/test_solver.py | 67 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/tests/test_solver.py b/tests/test_solver.py index 5412b697..0318e04b 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -207,6 +207,68 @@ def create_model(model_name, input_dimension): multi_session_tests_transform.append( (*args, cebra.solver.MultiSessionSolver)) +single_session_tests_select_model = [] +single_session_hybrid_tests_select_model = [] +for model_name in ["offset1-model", "offset10-model"]: + for session_id in [None, 0, 5]: + for args in [ + ("demo-discrete", model_name, session_id), + ("demo-continuous", 
model_name, session_id), + ("demo-mixed", model_name, session_id), + ]: + single_session_tests_select_model.append( + (*args, cebra.solver.SingleSessionSolver)) + single_session_hybrid_tests_select_model.append( + (*args, cebra.solver.SingleSessionHybridSolver)) + +multi_session_tests_select_model = [] +for model_name in ["offset1-model", "offset10-model"]: + for session_id in [None, 0, 1, 4]: + for args in [("demo-continuous-multisession", model_name, session_id)]: + multi_session_tests_select_model.append( + (*args, cebra.solver.MultiSessionSolver)) + + +@pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", + single_session_tests_select_model + + single_session_hybrid_tests_select_model) +def test_select_model_single_session(data_name, model_name, session_id, + solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset.input_dimension) + offset = model.get_offset() + solver = solver_initfunc(model=model, criterion=None, optimizer=None) + + if session_id is not None and session_id > 0: + with pytest.raises(RuntimeError): + solver._select_model(dataset.neural, session_id=session_id) + else: + model_, offset_ = solver._select_model(dataset.neural, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + + +#@pytest.mark.parametrize( +# "data_name, model_name,session_id,solver_initfunc", +# single_session_tests_select_model + single_session_hybrid_tests_select_model) +#def test_select_model_multi_session(data_name, model_name, session_id, solver_initfunc): +# dataset = cebra.datasets.init(data_name) +# model = nn.ModuleList( +# [create_model(model_name, dataset.input_dimension) for dataset in dataset.iter_sessions()]) +# offset = model[0].get_offset() +# solver = solver_initfunc(model=model, +# criterion=None, +# optimizer=None) +# +# if session_id is not None and session_id > 0: +# with pytest.raises(RuntimeError): +# solver._select_model(dataset.neural, session_id=session_id) +# else: +# model_, offset_ = solver._select_model(dataset.neural, session_id=session_id) +# assert offset.left == offset_.left and offset.right == offset_.right +# assert model == model_ + @pytest.mark.parametrize( "data_name, model_name, padding, loader_initfunc, solver_initfunc", @@ -229,6 +291,7 @@ def test_batched_transform_singlesession(data_name, model_name, padding, solver.fit(loader) if len(model.get_offset()) < 2 and padding: + pytest.skip("not relevant for now.") with pytest.raises(ValueError): solver.transform(inputs=loader.dataset.neural, pad_before_transform=padding) @@ -255,7 +318,9 @@ def test_batched_transform_singlesession(data_name, model_name, padding, #TODO: what to check here exactly? 
pass else: - assert embedding_batched.shape == embedding.shape + #print(model) + assert embedding_batched.shape == embedding.shape, (padding, + model) assert np.allclose(embedding_batched, embedding, rtol=1e-02) From fbe7eb420d7e89b143ef5ec68abb49f845d1ab9e Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 27 Oct 2023 16:18:56 +0200 Subject: [PATCH 13/45] add checks and test for _process_batch --- cebra/solver/base.py | 36 +++++++++++++-- tests/test_solver.py | 106 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 4 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index d38d8c88..43403911 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -81,25 +81,53 @@ def _process_batch(inputs: torch.Tensor, add_padding: bool, ValueError: If pad_beforadd_paddinge_transform is True and offset is not provided. """ + def _check_indices(indices, inputs): + if (indices[0] < 0) or (indices[1] > inputs.shape[0]): + raise ValueError( + f"offset {offset} is too big for the length of the inputs ({len(inputs)}) " + f"The indices {indices} do not match the inputs length {len(inputs)}." + ) + + if start_batch_idx < 0 or end_batch_idx < 0: + raise ValueError( + f"start_batch_idx ({start_batch_idx}) and end_batch_idx ({end_batch_idx}) must be non-negative." + ) + + if start_batch_idx > end_batch_idx: + raise ValueError( + f"start_batch_idx ({start_batch_idx}) cannot be greater than end_batch_idx ({end_batch_idx})." + ) + + if end_batch_idx > len(inputs): + raise ValueError( + f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({len(inputs)})." + ) + if add_padding: if offset is None: raise ValueError("offset needs to be set if add_padding is True.") + if not isinstance(offset, cebra.data.Offset): + raise ValueError("offset must be an instance of cebra.data.Offset") + if start_batch_idx == 0: # First batch indices = start_batch_idx, (end_batch_idx + offset.right - 1) + _check_indices(indices, inputs) batched_data = inputs[slice(*indices)] - batched_data = np.pad(batched_data.cpu().numpy(), - ((offset.left, 0), (0, 0)), + batched_data = np.pad(array=batched_data.cpu().numpy(), + pad_width=((offset.left, 0), (0, 0)), mode="edge") elif end_batch_idx == len(inputs): # Last batch indices = (start_batch_idx - offset.left), end_batch_idx + _check_indices(indices, inputs) batched_data = inputs[slice(*indices)] - batched_data = np.pad(batched_data.cpu().numpy(), - ((0, offset.right - 1), (0, 0)), + batched_data = np.pad(array=batched_data.cpu().numpy(), + pad_width=((0, offset.right - 1), (0, 0)), mode="edge") else: # Middle batches indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 + _check_indices(indices, inputs) batched_data = inputs[slice(*indices)] else: diff --git a/tests/test_solver.py b/tests/test_solver.py index 0318e04b..6911d102 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -229,6 +229,112 @@ def create_model(model_name, input_dimension): (*args, cebra.solver.MultiSessionSolver)) +@pytest.mark.parametrize( + "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", + [ + # Test case 1: No padding + (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 1, + torch.tensor([[1, 2]])), # first batch + (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 2, + torch.tensor([[1, 2], [3, 4]])), # first batch + (torch.tensor([[1, 2], [3, 4]]), False, None, 1, 2, + torch.tensor([[3, 4]])), # last batch + + # Test case 2: First batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 
8, 9]]), + True, + cebra.data.Offset(1, 1), + 0, + 2, + torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + + # Test case 3: Last batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 3), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [7, 8, 9], [7, 8, 9] + ]), + ), + + # Test case 4: Middle batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 1, + 2, + torch.tensor([[4, 5, 6]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 2), + 1, + 2, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 1), + 1, + 2, + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 2), + 1, + 2, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + + # Examples that throw an error: + + # Padding without offset (should raise an error) + (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), + # Negative start_batch_idx or end_batch_idx (should raise an error) + (torch.tensor([[1, 2]]), False, None, -1, 2, ValueError), + # out of bound indices because offset is too large + (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( + 5, 5), 1, 2, ValueError), + ], +) +def test_process_batch(inputs, add_padding, offset, start_batch_idx, + end_batch_idx, expected_output): + if expected_output == ValueError: + with pytest.raises(ValueError): + cebra.solver.base._process_batch(inputs, add_padding, offset, + start_batch_idx, end_batch_idx) + else: + result = cebra.solver.base._process_batch(inputs, add_padding, offset, + start_batch_idx, + end_batch_idx) + assert torch.equal(result, expected_output) + + @pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", single_session_tests_select_model + single_session_hybrid_tests_select_model) From 463b0f8a8890770b1d7bf23abe52a97d4ca22d72 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 30 Oct 2023 12:54:13 +0100 Subject: [PATCH 14/45] add test_select_model for multisession --- cebra/solver/base.py | 20 +++++++++------ tests/test_solver.py | 58 ++++++++++++++++++++++++++++---------------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 43403911..b9682f47 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -32,7 +32,7 @@ import abc import os -from typing import Callable, Dict, List, Literal, Optional, Union +from typing import Callable, Dict, Iterable, List, Literal, Optional, Union import literate_dataclasses as dataclasses import numpy as np @@ -78,7 +78,7 @@ def _process_batch(inputs: torch.Tensor, add_padding: bool, torch.Tensor: The (potentially) padded data. Raises: - ValueError: If pad_beforadd_paddinge_transform is True and offset is not provided. + ValueError: If add_padding is True and offset is not provided. """ def _check_indices(indices, inputs): @@ -314,6 +314,12 @@ def fit( * Refine the API here. Drop the validation entirely, and implement this via a hook? 
""" + self.num_sessions = loader.dataset.num_sessions if loader.dataset.num_sessions is not None else None + self.n_features = ([ + loader.dataset.get_input_dimension(session_id) + for session_id in range(loader.dataset.num_sessions) + ] if self.num_sessions is not None else loader.dataset.input_dimension) + self.to(loader.device) iterator = self._get_loader(loader) @@ -417,9 +423,6 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): # sklearn API we will convert it to numpy array. """ Select the right model based on the type of solver we have.""" - # before: self.loader.dataset.num_sessions - self.num_sessions = len(inputs) if isinstance(inputs, list) else None - if self.num_sessions is not None: # multisession implementation if session_id is None: raise RuntimeError( @@ -429,14 +432,13 @@ def _select_model(self, inputs: torch.Tensor, session_id: int): raise RuntimeError( f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." ) - if self.n_features_[session_id] != X.shape[1]: + if self.n_features[session_id] != inputs.shape[1]: raise ValueError( f"Invalid input shape: model for session {session_id} requires an input of shape" - f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})." + f"(n_samples, {self.n_features[session_id]}), got (n_samples, {inputs.shape[1]})." ) model = self.model[session_id] - model.to(self.device_) #TODO: why do I need to do this? else: # single session if session_id is not None and session_id > 0: @@ -495,6 +497,8 @@ def transform( Returns: The output embedding. """ + #TODO: add check like sklearn? + # #sklearn_utils_validation.check_is_fitted(self, "n_features_") model, offset = self._select_model(inputs, session_id) model.eval() diff --git a/tests/test_solver.py b/tests/test_solver.py index 6911d102..72376bfa 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -222,8 +222,8 @@ def create_model(model_name, input_dimension): (*args, cebra.solver.SingleSessionHybridSolver)) multi_session_tests_select_model = [] -for model_name in ["offset1-model", "offset10-model"]: - for session_id in [None, 0, 1, 4]: +for model_name in ["offset10-model"]: + for session_id in [None, 0, 1, 5, 2, 6, 4]: for args in [("demo-continuous-multisession", model_name, session_id)]: multi_session_tests_select_model.append( (*args, cebra.solver.MultiSessionSolver)) @@ -355,25 +355,41 @@ def test_select_model_single_session(data_name, model_name, session_id, assert model == model_ -#@pytest.mark.parametrize( -# "data_name, model_name,session_id,solver_initfunc", -# single_session_tests_select_model + single_session_hybrid_tests_select_model) -#def test_select_model_multi_session(data_name, model_name, session_id, solver_initfunc): -# dataset = cebra.datasets.init(data_name) -# model = nn.ModuleList( -# [create_model(model_name, dataset.input_dimension) for dataset in dataset.iter_sessions()]) -# offset = model[0].get_offset() -# solver = solver_initfunc(model=model, -# criterion=None, -# optimizer=None) -# -# if session_id is not None and session_id > 0: -# with pytest.raises(RuntimeError): -# solver._select_model(dataset.neural, session_id=session_id) -# else: -# model_, offset_ = solver._select_model(dataset.neural, session_id=session_id) -# assert offset.left == offset_.left and offset.right == offset_.right -# assert model == model_ +@pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", + multi_session_tests_select_model) +def 
test_select_model_multi_session(data_name, model_name, session_id, + solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + + offset = model[0].get_offset() + solver = solver_initfunc(model=model, + criterion=cebra.models.InfoNCE(), + optimizer=torch.optim.Adam(model.parameters(), + lr=1e-3)) + + loader_kwargs = dict(num_steps=10, batch_size=32) + loader = cebra.data.ContinuousMultiSessionDataLoader( + dataset, **loader_kwargs) + solver.fit(loader) + + for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): + inputs = dataset_.neural + + if session_id is None or session_id >= dataset.num_sessions: + with pytest.raises(RuntimeError): + solver._select_model(inputs, session_id=session_id) + elif i != session_id: + with pytest.raises(ValueError): + solver._select_model(inputs, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ @pytest.mark.parametrize( From 52191714431a97da3af79860dc87729eafa75e46 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Tue, 31 Oct 2023 16:07:49 +0100 Subject: [PATCH 15/45] make self.num_sessions compatible with single session training --- cebra/solver/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index b9682f47..acc98333 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -314,7 +314,8 @@ def fit( * Refine the API here. Drop the validation entirely, and implement this via a hook? """ - self.num_sessions = loader.dataset.num_sessions if loader.dataset.num_sessions is not None else None + self.num_sessions = loader.dataset.num_sessions if hasattr( + loader.dataset, "num_sessions") else None self.n_features = ([ loader.dataset.get_input_dimension(session_id) for session_id in range(loader.dataset.num_sessions) From f9bd1a6660b494f1c14a93f391235c72ddcabaa6 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 1 Nov 2023 12:11:22 +0100 Subject: [PATCH 16/45] improve test_batched_transform_singlesession --- tests/test_solver.py | 86 ++++++++++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 34 deletions(-) diff --git a/tests/test_solver.py b/tests/test_solver.py index 72376bfa..0bdf2cbf 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -178,35 +178,6 @@ def create_model(model_name, input_dimension): num_output=5) -single_session_tests_transform = [] -for padding in [True, False]: - for model_name in ["offset1-model", "offset10-model"]: - for args in [ - ("demo-discrete", model_name, padding, - cebra.data.DiscreteDataLoader), - ("demo-continuous", model_name, padding, - cebra.data.ContinuousDataLoader), - ("demo-mixed", model_name, padding, cebra.data.MixedDataLoader), - ]: - single_session_tests_transform.append( - (*args, cebra.solver.SingleSessionSolver)) - -single_session_hybrid_tests_transform = [] -for padding in [True, False]: - for model_name in ["offset1-model", "offset10-model"]: - for args in [("demo-continuous", model_name, padding, - cebra.data.HybridDataLoader)]: - single_session_hybrid_tests_transform.append( - (*args, cebra.solver.SingleSessionHybridSolver)) - -multi_session_tests_transform = [] -for padding in [True, False]: - for model_name in ["offset1-model", "offset10-model"]: - for args in [("demo-continuous-multisession", model_name, padding, - 
cebra.data.ContinuousMultiSessionDataLoader)]: - multi_session_tests_transform.append( - (*args, cebra.solver.MultiSessionSolver)) - single_session_tests_select_model = [] single_session_hybrid_tests_select_model = [] for model_name in ["offset1-model", "offset10-model"]: @@ -392,12 +363,59 @@ def test_select_model_multi_session(data_name, model_name, session_id, assert model == model_ +#this is a very crucial test. should be checked for different choices of offsets, +# dataset sizes (also edge cases like dataset size 1001 and batch size 1000 -> is the padding properly handled?) +#try to isolate this from the remaining tests, and make it really rigorous with a lot of test cases. + +models = [ + "offset1-model", "offset10-model" +] # there is an issue with subsampe models e.g. "offset4-model-2x-subsample" +batch_size_inference = [99_999] #1, 1000 + +single_session_tests_transform = [] +for padding in [True, False]: + for model_name in models: + for batch_size in batch_size_inference: + for args in [ + ("demo-discrete", model_name, padding, batch_size, + cebra.data.DiscreteDataLoader), + ("demo-continuous", model_name, padding, batch_size, + cebra.data.ContinuousDataLoader), + ("demo-mixed", model_name, padding, batch_size, + cebra.data.MixedDataLoader), + ]: + single_session_tests_transform.append( + (*args, cebra.solver.SingleSessionSolver)) + +single_session_hybrid_tests_transform = [] +for padding in [True, False]: + for model_name in models: + for batch_size in batch_size_inference: + for args in [("demo-continuous", model_name, padding, batch_size, + cebra.data.HybridDataLoader)]: + single_session_hybrid_tests_transform.append( + (*args, cebra.solver.SingleSessionHybridSolver)) + +#multi_session_tests_transform = [] +#for padding in [True, False]: +# for model_name in ["offset1-model", "offset5-model", "offset10-model"]: +# for args in [("demo-continuous-multisession", model_name, padding, +# cebra.data.ContinuousMultiSessionDataLoader)]: +# multi_session_tests_transform.append( +# (*args, cebra.solver.MultiSessionSolver)) + + @pytest.mark.parametrize( - "data_name, model_name, padding, loader_initfunc, solver_initfunc", + "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", single_session_tests_transform + single_session_hybrid_tests_transform) -def test_batched_transform_singlesession(data_name, model_name, padding, - loader_initfunc, solver_initfunc): - batch_size = 1024 +def test_batched_transform_singlesession( + data_name, + model_name, + padding, + batch_size_inference, + loader_initfunc, + solver_initfunc, +): dataset = cebra.datasets.init(data_name) model = create_model(model_name, dataset.input_dimension) dataset.offset = model.get_offset() @@ -420,7 +438,7 @@ def test_batched_transform_singlesession(data_name, model_name, padding, with pytest.raises(ValueError): solver.transform(inputs=loader.dataset.neural, - batch_size=batch_size, + batch_size=batch_size_inference, pad_before_transform=padding) else: embedding_batched = solver.transform(inputs=loader.dataset.neural, From e23a7ef3d936b4c7e3530b46bbc3679d2b710e00 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Tue, 7 Nov 2023 18:14:55 +0100 Subject: [PATCH 17/45] make it work with small batches --- cebra/solver/base.py | 27 ++++++-- tests/test_solver.py | 151 ++++++++++++++++++++++++++----------------- 2 files changed, 112 insertions(+), 66 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index acc98333..1026dfe2 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py 
@@ -103,6 +103,17 @@ def _check_indices(indices, inputs): f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({len(inputs)})." ) + def _check_batch_size_length(indices_batch, offset): + batch_size_lenght = indices_batch[1] - indices_batch[0] + print("batch_size ll", add_padding, indices, batch_size_lenght, + len(offset)) + if batch_size_lenght <= len(offset): + raise ValueError( + f"The batch has length {batch_size_lenght} which " + f"is smaller or equal than the required offset length {len(offset)}." + f"Either choose a model with smaller offset or the batch shoud contain more samples." + ) + if add_padding: if offset is None: raise ValueError("offset needs to be set if add_padding is True.") @@ -112,7 +123,8 @@ def _check_indices(indices, inputs): if start_batch_idx == 0: # First batch indices = start_batch_idx, (end_batch_idx + offset.right - 1) - _check_indices(indices, inputs) + #_check_indices(indices, inputs) + _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] batched_data = np.pad(array=batched_data.cpu().numpy(), pad_width=((offset.left, 0), (0, 0)), @@ -120,18 +132,21 @@ def _check_indices(indices, inputs): elif end_batch_idx == len(inputs): # Last batch indices = (start_batch_idx - offset.left), end_batch_idx - _check_indices(indices, inputs) + #_check_indices(indices, inputs) + _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] batched_data = np.pad(array=batched_data.cpu().numpy(), pad_width=((0, offset.right - 1), (0, 0)), mode="edge") else: # Middle batches indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 - _check_indices(indices, inputs) + #_check_indices(indices, inputs) + _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] else: indices = start_batch_idx, end_batch_idx + _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] batched_data = torch.from_numpy(batched_data) if isinstance( @@ -139,11 +154,9 @@ def _check_indices(indices, inputs): return batched_data -def _batched_transform(model, - inputs: torch.Tensor, - batch_size: int, +def _batched_transform(model, inputs: torch.Tensor, batch_size: int, pad_before_transform: bool, - offset=None) -> torch.Tensor: + offset: cebra.data.Offset) -> torch.Tensor: class IndexDataset(Dataset): diff --git a/tests/test_solver.py b/tests/test_solver.py index 0bdf2cbf..12794477 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -368,9 +368,11 @@ def test_select_model_multi_session(data_name, model_name, session_id, #try to isolate this from the remaining tests, and make it really rigorous with a lot of test cases. models = [ - "offset1-model", "offset10-model" + "offset1-model", + "offset10-model", + #"offset1-model", "offset10-model", ] # there is an issue with subsampe models e.g. 
"offset4-model-2x-subsample" -batch_size_inference = [99_999] #1, 1000 +batch_size_inference = [23432, 99_999] #1, 1000 single_session_tests_transform = [] for padding in [True, False]: @@ -396,17 +398,9 @@ def test_select_model_multi_session(data_name, model_name, session_id, single_session_hybrid_tests_transform.append( (*args, cebra.solver.SingleSessionHybridSolver)) -#multi_session_tests_transform = [] -#for padding in [True, False]: -# for model_name in ["offset1-model", "offset5-model", "offset10-model"]: -# for args in [("demo-continuous-multisession", model_name, padding, -# cebra.data.ContinuousMultiSessionDataLoader)]: -# multi_session_tests_transform.append( -# (*args, cebra.solver.MultiSessionSolver)) - @pytest.mark.parametrize( - "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", + "data_name,model_name,padding,batch_size_inference,loader_initfunc,solver_initfunc", single_session_tests_transform + single_session_hybrid_tests_transform) def test_batched_transform_singlesession( data_name, @@ -430,7 +424,12 @@ def test_batched_transform_singlesession( optimizer=optimizer) solver.fit(loader) - if len(model.get_offset()) < 2 and padding: + smallest_batch_length = loader.dataset.neural.shape[0] - batch_size + offset_ = model.get_offset() + #print("here!", smallest_batch_length, len(offset_)) + padding_left = offset_.left if padding else 0 + + if len(offset_) < 2 and padding: pytest.skip("not relevant for now.") with pytest.raises(ValueError): solver.transform(inputs=loader.dataset.neural, @@ -438,8 +437,21 @@ def test_batched_transform_singlesession( with pytest.raises(ValueError): solver.transform(inputs=loader.dataset.neural, - batch_size=batch_size_inference, + batch_size=batch_size, + pad_before_transform=padding) + + # NOTE: We need to add padding_left because if padding is True, + # the batch size is not "smallest_batch_length". and the smallest + # batch will always be at the end so the last batch we need to add + # offset.left. 
+ #TODO: this wont work in the case where the data is less than + #the offset from the beginning, i.e len(data) = 10, len(offset) = 10 + elif smallest_batch_length + padding_left <= len(offset_): + with pytest.raises(ValueError): + solver.transform(inputs=loader.dataset.neural, + batch_size=batch_size, pad_before_transform=padding) + else: embedding_batched = solver.transform(inputs=loader.dataset.neural, batch_size=batch_size, @@ -464,49 +476,70 @@ def test_batched_transform_singlesession( assert np.allclose(embedding_batched, embedding, rtol=1e-02) -# def test_batched_transform_multisession(data_name, model_name, padding, loader_initfunc, solver_initfunc): -# batch_size = 1024 -# dataset = cebra.datasets.init(data_name) -# model = nn.ModuleList( -# [create_model(model_name, dataset.input_dimension) for dataset in dataset.iter_sessions()]) -# dataset.offset = model[0].get_offset() -# loader_kwargs = dict(num_steps=10, batch_size=32) -# loader = loader_initfunc(dataset, **loader_kwargs) - -# criterion = cebra.models.InfoNCE() -# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -# solver = solver_initfunc(model=model, -# criterion=criterion, -# optimizer=optimizer) -# solver.fit(loader) - -# if len(model.get_offset()) < 2 and padding: -# with pytest.raises(ValueError): -# solver.transform(inputs=loader.dataset.neural, -# pad_before_transform=padding) - -# with pytest.raises(ValueError): -# solver.transform(inputs=loader.dataset.neural, -# batch_size=batch_size, -# pad_before_transform=padding) -# else: -# embedding_batched = solver.transform(inputs=loader.dataset.neural, -# batch_size=batch_size, -# pad_before_transform=padding) - -# embedding = solver.transform(inputs=loader.dataset.neural, -# pad_before_transform=padding) - -# if padding: -# if isinstance(model, cebra.models.ConvolutionalModelMixin): -# assert embedding_batched.shape == embedding.shape -# assert embedding_batched.shape == embedding.shape - -# else: -# if isinstance(model, cebra.models.ConvolutionalModelMixin): -# #TODO: what to check here exactly? 
-# pass -# else: -# assert embedding_batched.shape == embedding.shape -# assert np.allclose(embedding_batched, embedding, rtol=1e-02) +multi_session_tests_transform = [] +for padding in [True, False]: + for model_name in models: + for batch_size in batch_size_inference: + for args in [ + ("demo-continuous-multisession", model_name, padding, + batch_size, cebra.data.ContinuousMultiSessionDataLoader) + ]: + multi_session_tests_transform.append( + (*args, cebra.solver.MultiSessionSolver)) + + +@pytest.mark.parametrize( + "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", + multi_session_tests_transform) +def test_batched_transform_multisession(data_name, model_name, padding, + batch_size_inference, loader_initfunc, + solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + dataset.offset = model[0].get_offset() + loader_kwargs = dict(num_steps=10, batch_size=32) + loader = loader_initfunc(dataset, **loader_kwargs) + + criterion = cebra.models.InfoNCE() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer) + solver.fit(loader) + + #if len(model[0].get_offset()) < 2 and padding: + # with pytest.raises(ValueError): + # solver.transform(inputs=loader.dataset.neural, + # pad_before_transform=padding) + + +# +# with pytest.raises(ValueError): +# solver.transform(inputs=loader.dataset.neural, +# batch_size=batch_size, +# pad_before_transform=padding) +#else: +# embedding_batched = solver.transform(inputs=loader.dataset.neural, +# batch_size=batch_size, +# pad_before_transform=padding) +# +# embedding = solver.transform(inputs=loader.dataset.neural, +# pad_before_transform=padding) +# +# if padding: +# if isinstance(model, cebra.models.ConvolutionalModelMixin): +# assert embedding_batched.shape == embedding.shape +# assert embedding_batched.shape == embedding.shape +# +# else: +# if isinstance(model, cebra.models.ConvolutionalModelMixin): +# #TODO: what to check here exactly? +# pass +# else: +# assert embedding_batched.shape == embedding.shape +# assert np.allclose(embedding_batched, embedding, rtol=1e-02) +# From 19c3f8709edb738f50ebcefd1026df75d7dbed29 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 8 Nov 2023 13:33:20 +0100 Subject: [PATCH 18/45] make test with multisession work --- tests/test_solver.py | 91 ++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/tests/test_solver.py b/tests/test_solver.py index 12794477..7c433bdc 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -370,9 +370,10 @@ def test_select_model_multi_session(data_name, model_name, session_id, models = [ "offset1-model", "offset10-model", + "offset40-model-4x-subsample", #"offset1-model", "offset10-model", -] # there is an issue with subsampe models e.g. "offset4-model-2x-subsample" -batch_size_inference = [23432, 99_999] #1, 1000 +] # there is an issue with "offset4-model-2x-subsample" because it's not a convolutional model. 
+batch_size_inference = [23432, 99_999] # 99_999 single_session_tests_transform = [] for padding in [True, False]: @@ -500,6 +501,19 @@ def test_batched_transform_multisession(data_name, model_name, padding, for dataset in dataset.iter_sessions() ]) dataset.offset = model[0].get_offset() + + n_samples = dataset._datasets[0].neural.shape[0] + assert all( + d.neural.shape[0] == n_samples for d in dataset._datasets + ), "for this set all of the sessions need ot have same number of samples." + + smallest_batch_length = n_samples - batch_size + offset_ = model[0].get_offset() + #print("here!", smallest_batch_length, len(offset_)) + padding_left = offset_.left if padding else 0 + for d in dataset._datasets: + d.offset = offset_ + #dataset._datasets[0].offset = cebra.data.Offset(0, 1) loader_kwargs = dict(num_steps=10, batch_size=32) loader = loader_initfunc(dataset, **loader_kwargs) @@ -511,35 +525,48 @@ def test_batched_transform_multisession(data_name, model_name, padding, optimizer=optimizer) solver.fit(loader) - #if len(model[0].get_offset()) < 2 and padding: - # with pytest.raises(ValueError): - # solver.transform(inputs=loader.dataset.neural, - # pad_before_transform=padding) + # Transform each session with the right model, by providing the corresponding session ID + for i, inputs in enumerate(dataset.iter_sessions()): + if len(offset_) < 2 and padding: + with pytest.raises(ValueError): + embedding = solver.transform(inputs=inputs.neural, + session_id=i, + pad_before_transform=padding) -# -# with pytest.raises(ValueError): -# solver.transform(inputs=loader.dataset.neural, -# batch_size=batch_size, -# pad_before_transform=padding) -#else: -# embedding_batched = solver.transform(inputs=loader.dataset.neural, -# batch_size=batch_size, -# pad_before_transform=padding) -# -# embedding = solver.transform(inputs=loader.dataset.neural, -# pad_before_transform=padding) -# -# if padding: -# if isinstance(model, cebra.models.ConvolutionalModelMixin): -# assert embedding_batched.shape == embedding.shape -# assert embedding_batched.shape == embedding.shape -# -# else: -# if isinstance(model, cebra.models.ConvolutionalModelMixin): -# #TODO: what to check here exactly? -# pass -# else: -# assert embedding_batched.shape == embedding.shape -# assert np.allclose(embedding_batched, embedding, rtol=1e-02) -# + with pytest.raises(ValueError): + embedding_batched = solver.transform( + inputs=inputs.neural, + session_id=i, + pad_before_transform=padding, + batch_size=batch_size) + + elif smallest_batch_length + padding_left <= len(offset_): + with pytest.raises(ValueError): + solver.transform(inputs=inputs.neural, + batch_size=batch_size, + session_id=i, + pad_before_transform=padding) + + else: + model_ = model[i] + embedding = solver.transform(inputs=inputs.neural, + session_id=i, + pad_before_transform=padding) + embedding_batched = solver.transform(inputs=inputs.neural, + session_id=i, + pad_before_transform=padding, + batch_size=batch_size) + + if padding: + if isinstance(model_, cebra.models.ConvolutionalModelMixin): + assert embedding_batched.shape == embedding.shape + assert embedding_batched.shape == embedding.shape + + else: + if isinstance(model_, cebra.models.ConvolutionalModelMixin): + #TODO: what to check here exactly? 
+ pass + else: + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-02) From 87bebac38dca71387e819f749611954430480943 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Nov 2023 12:21:31 +0100 Subject: [PATCH 19/45] change to torch padding --- cebra/solver/base.py | 47 +++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 1026dfe2..25b4ecb6 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -37,6 +37,7 @@ import literate_dataclasses as dataclasses import numpy as np import torch +import torch.nn.functional as F import tqdm from torch.utils.data import DataLoader from torch.utils.data import Dataset @@ -51,6 +52,10 @@ def _inference_transform(model, inputs): + + #TODO: I am not sure what is the best way with dealing with the types and + # device when using batched inference. This works for now. + inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device) if isinstance(model, cebra.models.ConvolutionalModelMixin): # Fully convolutional evaluation, switch (T, C) -> (1, C, T) inputs = inputs.transpose(1, 0).unsqueeze(0) @@ -126,18 +131,24 @@ def _check_batch_size_length(indices_batch, offset): #_check_indices(indices, inputs) _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] - batched_data = np.pad(array=batched_data.cpu().numpy(), - pad_width=((offset.left, 0), (0, 0)), - mode="edge") + batched_data = F.pad(batched_data.T, (offset.left, 0), + 'replicate').T + + #batched_data = np.pad(array=batched_data.cpu().numpy(), + # pad_width=((offset.left, 0), (0, 0)), + # mode="edge") elif end_batch_idx == len(inputs): # Last batch indices = (start_batch_idx - offset.left), end_batch_idx #_check_indices(indices, inputs) _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] - batched_data = np.pad(array=batched_data.cpu().numpy(), - pad_width=((0, offset.right - 1), (0, 0)), - mode="edge") + batched_data = F.pad(batched_data.T, (0, offset.right - 1), + 'replicate').T + + #batched_data = np.pad(array=batched_data.cpu().numpy(), + # pad_width=((0, offset.right - 1), (0, 0)), + # mode="edge") else: # Middle batches indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 #_check_indices(indices, inputs) @@ -149,8 +160,8 @@ def _check_batch_size_length(indices_batch, offset): _check_batch_size_length(indices, offset) batched_data = inputs[slice(*indices)] - batched_data = torch.from_numpy(batched_data) if isinstance( - batched_data, np.ndarray) else batched_data + #batched_data = torch.from_numpy(batched_data) if isinstance( + # batched_data, np.ndarray) else batched_data return batched_data @@ -486,12 +497,11 @@ def _transform(self, model, inputs, offset, return output @torch.no_grad() - def transform( - self, - inputs: torch.Tensor, - pad_before_transform: bool = True, #TODO: what should be the default? - session_id: Optional[int] = None, - batch_size: Optional[int] = None) -> torch.Tensor: + def transform(self, + inputs: torch.Tensor, + pad_before_transform: bool = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = None) -> torch.Tensor: """Compute the embedding. This function by default only applies the ``forward`` function @@ -500,13 +510,14 @@ def transform( Args: inputs: The input signal pad_before_transform: If ``False``, no padding is applied to the input sequence. 
- and the output sequence will be smaller than the input sequence due to the - receptive field of the model. If the input sequence is ``n`` steps long, - and a model with receptive field ``m`` is used, the output sequence would - only be ``n-m+1`` steps long. + and the output sequence will be smaller than the input sequence due to the + receptive field of the model. If the input sequence is ``n`` steps long, + and a model with receptive field ``m`` is used, the output sequence would + only be ``n-m+1`` steps long. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. + batch_size: If not None, batched inference will be applied. Returns: The output embedding. From f0303e01881c78195c709052f6359bf2575e2109 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Nov 2023 12:21:39 +0100 Subject: [PATCH 20/45] add argument to sklearn api --- cebra/integrations/sklearn/cebra.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 2c9eba2b..d9294706 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1201,16 +1201,18 @@ def fit( def transform(self, X: Union[npt.NDArray, torch.Tensor], pad_before_transform: bool = True, + batch_size: Optional[int] = None, session_id: Optional[int] = None) -> npt.NDArray: """Transform an input sequence and return the embedding. Args: X: A numpy array or torch tensor of size ``time x dimension``. pad_before_transform: If ``False``, no padding is applied to the input sequence. - and the output sequence will be smaller than the input sequence due to the - receptive field of the model. If the input sequence is ``n`` steps long, - and a model with receptive field ``m`` is used, the output sequence would - only be ``n-m+1`` steps long. + and the output sequence will be smaller than the input sequence due to the + receptive field of the model. If the input sequence is ``n`` steps long, + and a model with receptive field ``m`` is used, the output sequence would + only be ``n-m+1`` steps long. + batch_size: session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for multisession, set to ``None`` for single session. @@ -1233,10 +1235,15 @@ def transform(self, # Input validation X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) input_dtype = X.dtype + #print(type(X)) + #print(X.dtype) with torch.no_grad(): output = self.solver_.transform( - X, pad_before_transform=pad_before_transform) + inputs=X, + pad_before_transform=pad_before_transform, + session_id=session_id, + batch_size=batch_size) if input_dtype == "float64": return output.astype(input_dtype) From 8c8be85d00073b98b9a674161c16e7a6c4b8ca75 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Nov 2023 12:43:08 +0100 Subject: [PATCH 21/45] add torch padding to _transform --- cebra/solver/base.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 25b4ecb6..28dd7832 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -56,6 +56,7 @@ def _inference_transform(model, inputs): #TODO: I am not sure what is the best way with dealing with the types and # device when using batched inference. This works for now. 
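    # Note on the line below: `inputs.type(torch.FloatTensor)` always creates a
    # float32 copy on the CPU before `.to(...)` moves it to the model's device.
    # A possible alternative (a sketch, not what this patch does) would be a
    # single call such as
    #     inputs.to(device=next(model.parameters()).device, dtype=torch.float32)
    # which avoids the intermediate CPU copy when the input already lives on
    # the target device.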
inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device) + if isinstance(model, cebra.models.ConvolutionalModelMixin): # Fully convolutional evaluation, switch (T, C) -> (1, C, T) inputs = inputs.transpose(1, 0).unsqueeze(0) @@ -110,8 +111,6 @@ def _check_indices(indices, inputs): def _check_batch_size_length(indices_batch, offset): batch_size_lenght = indices_batch[1] - indices_batch[0] - print("batch_size ll", add_padding, indices, batch_size_lenght, - len(offset)) if batch_size_lenght <= len(offset): raise ValueError( f"The batch has length {batch_size_lenght} which " @@ -489,10 +488,8 @@ def _transform(self, model, inputs, offset, pad_before_transform) -> torch.Tensor: if pad_before_transform: - inputs = np.pad(inputs, ((offset.left, offset.right - 1), (0, 0)), - mode="edge") - inputs = torch.from_numpy(inputs) - + inputs = F.pad(inputs.T, (offset.left, offset.right - 1), + 'replicate').T output = _inference_transform(model, inputs) return output From 59df4026b1b8598f7e5978881f8a9d2f115869fe Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Nov 2023 12:52:17 +0100 Subject: [PATCH 22/45] convert to torch if numpy array as inputs --- cebra/integrations/sklearn/cebra.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index d9294706..1121ee98 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1233,10 +1233,13 @@ def transform(self, sklearn_utils_validation.check_is_fitted(self, "n_features_") # Input validation + #TODO: if inputs are in cuda, then it throws an error, deal with this. X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) input_dtype = X.dtype - #print(type(X)) - #print(X.dtype) + + if isinstance(X, np.ndarray): + X = torch.from_numpy(X) + # TODO: which type and device should be put there? with torch.no_grad(): output = self.solver_.transform( From 1aadc8b39d2f309cead0f04582ce47adb902e2b5 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 15 Nov 2023 18:04:04 +0100 Subject: [PATCH 23/45] add distinction between pad with data and pad with zeros and modify test accordingly --- cebra/solver/base.py | 73 ++++++++++++++++---------------------------- tests/test_solver.py | 45 ++++++++------------------- 2 files changed, 38 insertions(+), 80 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 28dd7832..5282e00c 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -66,11 +66,10 @@ def _inference_transform(model, inputs): return output -def _process_batch(inputs: torch.Tensor, add_padding: bool, - offset: cebra.data.Offset, start_batch_idx: int, - end_batch_idx: int) -> torch.Tensor: +def _pad_with_data(inputs: torch.Tensor, offset: cebra.data.Offset, + start_batch_idx: int, end_batch_idx: int) -> torch.Tensor: """ - Process a batch of input data, optionally applying padding based on specified parameters. + Pads a batch of input data with its own data (maybe this is not called padding) Args: inputs: The input data to be processed. @@ -118,49 +117,18 @@ def _check_batch_size_length(indices_batch, offset): f"Either choose a model with smaller offset or the batch shoud contain more samples." 
) - if add_padding: - if offset is None: - raise ValueError("offset needs to be set if add_padding is True.") - - if not isinstance(offset, cebra.data.Offset): - raise ValueError("offset must be an instance of cebra.data.Offset") - - if start_batch_idx == 0: # First batch - indices = start_batch_idx, (end_batch_idx + offset.right - 1) - #_check_indices(indices, inputs) - _check_batch_size_length(indices, offset) - batched_data = inputs[slice(*indices)] - batched_data = F.pad(batched_data.T, (offset.left, 0), - 'replicate').T - - #batched_data = np.pad(array=batched_data.cpu().numpy(), - # pad_width=((offset.left, 0), (0, 0)), - # mode="edge") - - elif end_batch_idx == len(inputs): # Last batch - indices = (start_batch_idx - offset.left), end_batch_idx - #_check_indices(indices, inputs) - _check_batch_size_length(indices, offset) - batched_data = inputs[slice(*indices)] - batched_data = F.pad(batched_data.T, (0, offset.right - 1), - 'replicate').T - - #batched_data = np.pad(array=batched_data.cpu().numpy(), - # pad_width=((0, offset.right - 1), (0, 0)), - # mode="edge") - else: # Middle batches - indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 - #_check_indices(indices, inputs) - _check_batch_size_length(indices, offset) - batched_data = inputs[slice(*indices)] + if start_batch_idx == 0: # First batch + indices = start_batch_idx, (end_batch_idx + offset.right - 1) - else: - indices = start_batch_idx, end_batch_idx - _check_batch_size_length(indices, offset) - batched_data = inputs[slice(*indices)] + elif end_batch_idx == len(inputs): # Last batch + indices = (start_batch_idx - offset.left), end_batch_idx + + else: # Middle batches + indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 - #batched_data = torch.from_numpy(batched_data) if isinstance( - # batched_data, np.ndarray) else batched_data + #_check_batch_size_length(indices, offset) + #TODO: modify this check_batch_size to pass test. + batched_data = inputs[slice(*indices)] return batched_data @@ -185,11 +153,22 @@ def __getitem__(self, idx): output = [] for batch_id, index_batch in enumerate(index_dataloader): start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] + 1 - batched_data = _process_batch(inputs=inputs, - add_padding=pad_before_transform, + + # This applies to all batches. + batched_data = _pad_with_data(inputs=inputs, offset=offset, start_batch_idx=start_batch_idx, end_batch_idx=end_batch_idx) + + if pad_before_transform: + if start_batch_idx == 0: # First batch + batched_data = F.pad(batched_data.T, (offset.left, 0), + 'replicate').T + + elif end_batch_idx == len(inputs): # Last batch + batched_data = F.pad(batched_data.T, (0, offset.right - 1), + 'replicate').T + output_batch = _inference_transform(model, batched_data) output.append(output_batch) diff --git a/tests/test_solver.py b/tests/test_solver.py index 7c433bdc..335166d0 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -373,7 +373,7 @@ def test_select_model_multi_session(data_name, model_name, session_id, "offset40-model-4x-subsample", #"offset1-model", "offset10-model", ] # there is an issue with "offset4-model-2x-subsample" because it's not a convolutional model. 
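# A small worked example of the batching above (illustrative values only):
# with cebra.data.Offset(2, 3) and 100 input samples split into batches of 40,
# `_pad_with_data` selects
#   first  batch [0, 40)   -> slice [0, 42)   (start, end + offset.right - 1)
#   middle batch [40, 80)  -> slice [38, 82)  (start - offset.left, end + offset.right - 1)
#   last   batch [80, 100) -> slice [78, 100) (start - offset.left, end)
# so each batch carries the real neighbouring samples it needs, while the
# F.pad(..., 'replicate') calls in `_batched_transform` only repeat edge values
# at the very beginning and end of the full sequence (edge padding, not zeros).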
-batch_size_inference = [23432, 99_999] # 99_999 +batch_size_inference = [23432] # 99_999 single_session_tests_transform = [] for padding in [True, False]: @@ -427,7 +427,6 @@ def test_batched_transform_singlesession( smallest_batch_length = loader.dataset.neural.shape[0] - batch_size offset_ = model.get_offset() - #print("here!", smallest_batch_length, len(offset_)) padding_left = offset_.left if padding else 0 if len(offset_) < 2 and padding: @@ -447,11 +446,13 @@ def test_batched_transform_singlesession( # offset.left. #TODO: this wont work in the case where the data is less than #the offset from the beginning, i.e len(data) = 10, len(offset) = 10 - elif smallest_batch_length + padding_left <= len(offset_): - with pytest.raises(ValueError): - solver.transform(inputs=loader.dataset.neural, - batch_size=batch_size, - pad_before_transform=padding) + + #elif smallest_batch_length + padding_left <= len(offset_): + # print('here') + # with pytest.raises(ValueError): + # solver.transform(inputs=loader.dataset.neural, + # batch_size=batch_size, + # pad_before_transform=padding) else: embedding_batched = solver.transform(inputs=loader.dataset.neural, @@ -461,20 +462,8 @@ def test_batched_transform_singlesession( embedding = solver.transform(inputs=loader.dataset.neural, pad_before_transform=padding) - if padding: - if isinstance(model, cebra.models.ConvolutionalModelMixin): - assert embedding_batched.shape == embedding.shape - assert embedding_batched.shape == embedding.shape - - else: - if isinstance(model, cebra.models.ConvolutionalModelMixin): - #TODO: what to check here exactly? - pass - else: - #print(model) - assert embedding_batched.shape == embedding.shape, (padding, - model) - assert np.allclose(embedding_batched, embedding, rtol=1e-02) + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-02) multi_session_tests_transform = [] @@ -558,15 +547,5 @@ def test_batched_transform_multisession(data_name, model_name, padding, pad_before_transform=padding, batch_size=batch_size) - if padding: - if isinstance(model_, cebra.models.ConvolutionalModelMixin): - assert embedding_batched.shape == embedding.shape - assert embedding_batched.shape == embedding.shape - - else: - if isinstance(model_, cebra.models.ConvolutionalModelMixin): - #TODO: what to check here exactly? - pass - else: - assert embedding_batched.shape == embedding.shape - assert np.allclose(embedding_batched, embedding, rtol=1e-02) + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-02) From bc8ee250b2643f9c44d98fd434872c121515a080 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 17 Nov 2023 15:59:52 +0100 Subject: [PATCH 24/45] differentiate between data padding and zero padding --- cebra/solver/base.py | 98 +++++------ tests/test_solver.py | 384 +++++++++++++++++++++---------------------- 2 files changed, 229 insertions(+), 253 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 5282e00c..2cecab08 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -66,56 +66,32 @@ def _inference_transform(model, inputs): return output -def _pad_with_data(inputs: torch.Tensor, offset: cebra.data.Offset, - start_batch_idx: int, end_batch_idx: int) -> torch.Tensor: - """ - Pads a batch of input data with its own data (maybe this is not called padding) - - Args: - inputs: The input data to be processed. - add_padding: Indicates whether padding should be applied before inference. 
- offset: Offset configuration for padding. If add_padding is True, - offset must be set. If add_padding is False, offset is not used and can be None. - start_batch_idx: The starting index of the current batch. - end_batch_idx: The last index of the current batch. - - Returns: - torch.Tensor: The (potentially) padded data. - - Raises: - ValueError: If add_padding is True and offset is not provided. - """ - - def _check_indices(indices, inputs): - if (indices[0] < 0) or (indices[1] > inputs.shape[0]): - raise ValueError( - f"offset {offset} is too big for the length of the inputs ({len(inputs)}) " - f"The indices {indices} do not match the inputs length {len(inputs)}." - ) +def _check_indices(start_batch_idx, end_batch_idx, offset, num_samples): if start_batch_idx < 0 or end_batch_idx < 0: raise ValueError( f"start_batch_idx ({start_batch_idx}) and end_batch_idx ({end_batch_idx}) must be non-negative." ) - if start_batch_idx > end_batch_idx: raise ValueError( f"start_batch_idx ({start_batch_idx}) cannot be greater than end_batch_idx ({end_batch_idx})." ) + if end_batch_idx > num_samples: + raise ValueError( + f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({num_samples})." + ) - if end_batch_idx > len(inputs): + batch_size_lenght = end_batch_idx - start_batch_idx + if batch_size_lenght <= len(offset): raise ValueError( - f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({len(inputs)})." + f"The batch has length {batch_size_lenght} which " + f"is smaller or equal than the required offset length {len(offset)}." + f"Either choose a model with smaller offset or the batch shoud contain more samples." ) - def _check_batch_size_length(indices_batch, offset): - batch_size_lenght = indices_batch[1] - indices_batch[0] - if batch_size_lenght <= len(offset): - raise ValueError( - f"The batch has length {batch_size_lenght} which " - f"is smaller or equal than the required offset length {len(offset)}." - f"Either choose a model with smaller offset or the batch shoud contain more samples." - ) + +def _get_batch(inputs: torch.Tensor, offset: cebra.data.Offset, + start_batch_idx: int, end_batch_idx: int) -> torch.Tensor: if start_batch_idx == 0: # First batch indices = start_batch_idx, (end_batch_idx + offset.right - 1) @@ -126,12 +102,25 @@ def _check_batch_size_length(indices_batch, offset): else: # Middle batches indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 - #_check_batch_size_length(indices, offset) - #TODO: modify this check_batch_size to pass test. + _check_indices(indices[0], indices[1], offset, len(inputs)) batched_data = inputs[slice(*indices)] return batched_data +def _add_zero_padding(batched_data: torch.Tensor, offset: cebra.data.Offset, + start_batch_idx: int, end_batch_idx: int, + number_of_samples: int): + + if start_batch_idx == 0: # First batch + batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T + + elif end_batch_idx == number_of_samples: # Last batch + batched_data = F.pad(batched_data.T, (0, offset.right - 1), + 'replicate').T + + return batched_data + + def _batched_transform(model, inputs: torch.Tensor, batch_size: int, pad_before_transform: bool, offset: cebra.data.Offset) -> torch.Tensor: @@ -153,21 +142,17 @@ def __getitem__(self, idx): output = [] for batch_id, index_batch in enumerate(index_dataloader): start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] + 1 - - # This applies to all batches. 
- batched_data = _pad_with_data(inputs=inputs, - offset=offset, - start_batch_idx=start_batch_idx, - end_batch_idx=end_batch_idx) + batched_data = _get_batch(inputs=inputs, + offset=offset, + start_batch_idx=start_batch_idx, + end_batch_idx=end_batch_idx) if pad_before_transform: - if start_batch_idx == 0: # First batch - batched_data = F.pad(batched_data.T, (offset.left, 0), - 'replicate').T - - elif end_batch_idx == len(inputs): # Last batch - batched_data = F.pad(batched_data.T, (0, offset.right - 1), - 'replicate').T + batched_data = _add_zero_padding(batched_data=batched_data, + offset=offset, + start_batch_idx=start_batch_idx, + end_batch_idx=end_batch_idx, + number_of_samples=len(inputs)) output_batch = _inference_transform(model, batched_data) output.append(output_batch) @@ -503,10 +488,11 @@ def transform(self, model, offset = self._select_model(inputs, session_id) model.eval() - if len(offset) < 2 and pad_before_transform: - raise ValueError( - "Padding does not make sense when the offset of the model is < 2" - ) + #TODO: should we add this error? + #if len(offset) < 2 and pad_before_transform: + # raise ValueError( + # "Padding does not make sense when the offset of the model is < 2" + # ) if batch_size is not None: output = _batched_transform( diff --git a/tests/test_solver.py b/tests/test_solver.py index 335166d0..1661003a 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -199,169 +199,165 @@ def create_model(model_name, input_dimension): multi_session_tests_select_model.append( (*args, cebra.solver.MultiSessionSolver)) - -@pytest.mark.parametrize( - "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", - [ - # Test case 1: No padding - (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 1, - torch.tensor([[1, 2]])), # first batch - (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 2, - torch.tensor([[1, 2], [3, 4]])), # first batch - (torch.tensor([[1, 2], [3, 4]]), False, None, 1, 2, - torch.tensor([[3, 4]])), # last batch - - # Test case 2: First batch with padding - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(1, 1), - 0, - 2, - torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(1, 1), - 0, - 3, - torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), - ), - - # Test case 3: Last batch with padding - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(0, 1), - 1, - 3, - torch.tensor([[4, 5, 6], [7, 8, 9]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(1, 3), - 1, - 3, - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [7, 8, 9], [7, 8, 9] - ]), - ), - - # Test case 4: Middle batch with padding - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(0, 1), - 1, - 2, - torch.tensor([[4, 5, 6]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(0, 2), - 1, - 2, - torch.tensor([[4, 5, 6], [7, 8, 9]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(1, 1), - 1, - 2, - torch.tensor([[1, 2, 3], [4, 5, 6]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(1, 2), - 1, - 2, - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - ), - - # Examples that throw an error: - - # Padding without offset (should raise an error) - (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), - # Negative start_batch_idx or 
end_batch_idx (should raise an error) - (torch.tensor([[1, 2]]), False, None, -1, 2, ValueError), - # out of bound indices because offset is too large - (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( - 5, 5), 1, 2, ValueError), - ], -) -def test_process_batch(inputs, add_padding, offset, start_batch_idx, - end_batch_idx, expected_output): - if expected_output == ValueError: - with pytest.raises(ValueError): - cebra.solver.base._process_batch(inputs, add_padding, offset, - start_batch_idx, end_batch_idx) - else: - result = cebra.solver.base._process_batch(inputs, add_padding, offset, - start_batch_idx, - end_batch_idx) - assert torch.equal(result, expected_output) - - -@pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", - single_session_tests_select_model + - single_session_hybrid_tests_select_model) -def test_select_model_single_session(data_name, model_name, session_id, - solver_initfunc): - dataset = cebra.datasets.init(data_name) - model = create_model(model_name, dataset.input_dimension) - offset = model.get_offset() - solver = solver_initfunc(model=model, criterion=None, optimizer=None) - - if session_id is not None and session_id > 0: - with pytest.raises(RuntimeError): - solver._select_model(dataset.neural, session_id=session_id) - else: - model_, offset_ = solver._select_model(dataset.neural, - session_id=session_id) - assert offset.left == offset_.left and offset.right == offset_.right - assert model == model_ - - -@pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", - multi_session_tests_select_model) -def test_select_model_multi_session(data_name, model_name, session_id, - solver_initfunc): - dataset = cebra.datasets.init(data_name) - model = nn.ModuleList([ - create_model(model_name, dataset.input_dimension) - for dataset in dataset.iter_sessions() - ]) - - offset = model[0].get_offset() - solver = solver_initfunc(model=model, - criterion=cebra.models.InfoNCE(), - optimizer=torch.optim.Adam(model.parameters(), - lr=1e-3)) - - loader_kwargs = dict(num_steps=10, batch_size=32) - loader = cebra.data.ContinuousMultiSessionDataLoader( - dataset, **loader_kwargs) - solver.fit(loader) - - for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): - inputs = dataset_.neural - - if session_id is None or session_id >= dataset.num_sessions: - with pytest.raises(RuntimeError): - solver._select_model(inputs, session_id=session_id) - elif i != session_id: - with pytest.raises(ValueError): - solver._select_model(inputs, session_id=session_id) - else: - model_, offset_ = solver._select_model(inputs, - session_id=session_id) - assert offset.left == offset_.left and offset.right == offset_.right - assert model == model_ - +# @pytest.mark.parametrize( +# "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", +# [ +# # Test case 1: No padding +# (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 1, +# torch.tensor([[1, 2]])), # first batch +# (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 2, +# torch.tensor([[1, 2], [3, 4]])), # first batch +# (torch.tensor([[1, 2], [3, 4]]), False, None, 1, 2, +# torch.tensor([[3, 4]])), # last batch + +# # Test case 2: First batch with padding +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(1, 1), +# 0, +# 2, +# torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6]]), +# ), +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(1, 1), +# 0, +# 3, +# torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 
8, 9]]), +# ), + +# # Test case 3: Last batch with padding +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(0, 1), +# 1, +# 3, +# torch.tensor([[4, 5, 6], [7, 8, 9]]), +# ), +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(1, 3), +# 1, +# 3, +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [7, 8, 9], [7, 8, 9] +# ]), +# ), + +# # Test case 4: Middle batch with padding +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(0, 1), +# 1, +# 2, +# torch.tensor([[4, 5, 6]]), +# ), +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(0, 2), +# 1, +# 2, +# torch.tensor([[4, 5, 6], [7, 8, 9]]), +# ), +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(1, 1), +# 1, +# 2, +# torch.tensor([[1, 2, 3], [4, 5, 6]]), +# ), +# ( +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# True, +# cebra.data.Offset(1, 2), +# 1, +# 2, +# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), +# ), + +# # Examples that throw an error: + +# # Padding without offset (should raise an error) +# (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), +# # Negative start_batch_idx or end_batch_idx (should raise an error) +# (torch.tensor([[1, 2]]), False, None, -1, 2, ValueError), +# # out of bound indices because offset is too large +# (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( +# 5, 5), 1, 2, ValueError), +# ], +# ) +# def test__get_batch(inputs, add_padding, offset, start_batch_idx, +# end_batch_idx, expected_output): +# if expected_output == ValueError: +# with pytest.raises(ValueError): +# cebra.solver.base._get_batch(inputs, add_padding, offset, +# start_batch_idx, end_batch_idx) +# else: +# result = cebra.solver.base._get_batch(inputs, add_padding, offset, +# start_batch_idx, +# end_batch_idx) +# assert torch.equal(result, expected_output) + +# @pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", +# single_session_tests_select_model + +# single_session_hybrid_tests_select_model) +# def test_select_model_single_session(data_name, model_name, session_id, +# solver_initfunc): +# dataset = cebra.datasets.init(data_name) +# model = create_model(model_name, dataset.input_dimension) +# offset = model.get_offset() +# solver = solver_initfunc(model=model, criterion=None, optimizer=None) + +# if session_id is not None and session_id > 0: +# with pytest.raises(RuntimeError): +# solver._select_model(dataset.neural, session_id=session_id) +# else: +# model_, offset_ = solver._select_model(dataset.neural, +# session_id=session_id) +# assert offset.left == offset_.left and offset.right == offset_.right +# assert model == model_ + +# @pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", +# multi_session_tests_select_model) +# def test_select_model_multi_session(data_name, model_name, session_id, +# solver_initfunc): +# dataset = cebra.datasets.init(data_name) +# model = nn.ModuleList([ +# create_model(model_name, dataset.input_dimension) +# for dataset in dataset.iter_sessions() +# ]) + +# offset = model[0].get_offset() +# solver = solver_initfunc(model=model, +# criterion=cebra.models.InfoNCE(), +# optimizer=torch.optim.Adam(model.parameters(), +# lr=1e-3)) + +# loader_kwargs = dict(num_steps=10, batch_size=32) +# loader = cebra.data.ContinuousMultiSessionDataLoader( +# dataset, **loader_kwargs) +# solver.fit(loader) + +# for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): 
+# inputs = dataset_.neural + +# if session_id is None or session_id >= dataset.num_sessions: +# with pytest.raises(RuntimeError): +# solver._select_model(inputs, session_id=session_id) +# elif i != session_id: +# with pytest.raises(ValueError): +# solver._select_model(inputs, session_id=session_id) +# else: +# model_, offset_ = solver._select_model(inputs, +# session_id=session_id) +# assert offset.left == offset_.left and offset.right == offset_.right +# assert model == model_ #this is a very crucial test. should be checked for different choices of offsets, # dataset sizes (also edge cases like dataset size 1001 and batch size 1000 -> is the padding properly handled?) @@ -373,7 +369,7 @@ def test_select_model_multi_session(data_name, model_name, session_id, "offset40-model-4x-subsample", #"offset1-model", "offset10-model", ] # there is an issue with "offset4-model-2x-subsample" because it's not a convolutional model. -batch_size_inference = [23432] # 99_999 +batch_size_inference = [40_000, 99_990, 99_999] # 99_999 single_session_tests_transform = [] for padding in [True, False]: @@ -429,31 +425,25 @@ def test_batched_transform_singlesession( offset_ = model.get_offset() padding_left = offset_.left if padding else 0 - if len(offset_) < 2 and padding: - pytest.skip("not relevant for now.") - with pytest.raises(ValueError): - solver.transform(inputs=loader.dataset.neural, - pad_before_transform=padding) + #if len(offset_) < 2 and padding: + # pytest.skip("not relevant for now.") + # with pytest.raises(ValueError): + # solver.transform(inputs=loader.dataset.neural, + # pad_before_transform=padding) + # + # with pytest.raises(ValueError): + # solver.transform(inputs=loader.dataset.neural, + # batch_size=batch_size, + # pad_before_transform=padding) + #TODO: this wont work in the case where the data is less than + #the offset from the beginning, i.e len(data) = 10, len(offset) = 10 + if smallest_batch_length <= len(offset_): with pytest.raises(ValueError): solver.transform(inputs=loader.dataset.neural, batch_size=batch_size, pad_before_transform=padding) - # NOTE: We need to add padding_left because if padding is True, - # the batch size is not "smallest_batch_length". and the smallest - # batch will always be at the end so the last batch we need to add - # offset.left. 
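The invariant that the branch below is meant to exercise, written out as a sketch (all names come from the surrounding test scope; the tolerances are assumed): once padding is handled correctly, batched and non-batched inference should agree.

embedding = solver.transform(inputs=loader.dataset.neural,
                             pad_before_transform=padding)
embedding_batched = solver.transform(inputs=loader.dataset.neural,
                                     batch_size=batch_size,
                                     pad_before_transform=padding)
assert embedding_batched.shape == embedding.shape
assert torch.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4)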
- #TODO: this wont work in the case where the data is less than - #the offset from the beginning, i.e len(data) = 10, len(offset) = 10 - - #elif smallest_batch_length + padding_left <= len(offset_): - # print('here') - # with pytest.raises(ValueError): - # solver.transform(inputs=loader.dataset.neural, - # batch_size=batch_size, - # pad_before_transform=padding) - else: embedding_batched = solver.transform(inputs=loader.dataset.neural, batch_size=batch_size, @@ -517,20 +507,20 @@ def test_batched_transform_multisession(data_name, model_name, padding, # Transform each session with the right model, by providing the corresponding session ID for i, inputs in enumerate(dataset.iter_sessions()): - if len(offset_) < 2 and padding: - with pytest.raises(ValueError): - embedding = solver.transform(inputs=inputs.neural, - session_id=i, - pad_before_transform=padding) - - with pytest.raises(ValueError): - embedding_batched = solver.transform( - inputs=inputs.neural, - session_id=i, - pad_before_transform=padding, - batch_size=batch_size) - - elif smallest_batch_length + padding_left <= len(offset_): + # if len(offset_) < 2 and padding: + # with pytest.raises(ValueError): + # embedding = solver.transform(inputs=inputs.neural, + # session_id=i, + # pad_before_transform=padding) + # + # with pytest.raises(ValueError): + # embedding_batched = solver.transform( + # inputs=inputs.neural, + # session_id=i, + # pad_before_transform=padding, + # batch_size=batch_size) + + if smallest_batch_length <= len(offset_): with pytest.raises(ValueError): solver.transform(inputs=inputs.neural, batch_size=batch_size, From 5e7a14c3cc80f3d35887a38cccb6a33b580bef3a Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 24 Nov 2023 13:22:45 +0100 Subject: [PATCH 25/45] remove float16 --- cebra/integrations/sklearn/cebra.py | 9 +++++---- cebra/integrations/sklearn/utils.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 1121ee98..555966fb 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1235,7 +1235,7 @@ def transform(self, # Input validation #TODO: if inputs are in cuda, then it throws an error, deal with this. X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) - input_dtype = X.dtype + #input_dtype = X.dtype if isinstance(X, np.ndarray): X = torch.from_numpy(X) @@ -1248,10 +1248,11 @@ def transform(self, session_id=session_id, batch_size=batch_size) - if input_dtype == "float64": - return output.astype(input_dtype) + #TODO: check if this is safe. + return output.numpy(force=True) - return output + #if input_dtype == "float64": + # return output.astype(input_dtype) def fit_transform( self, diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 455213a3..0ec01aa1 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -78,7 +78,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: X, accept_sparse=False, accept_large_sparse=False, - dtype=("float16", "float32", "float64"), + # NOTE: remove float16 because F.pad does not allow float16. 
+ dtype=("float32", "float64"), order=None, copy=False, force_all_finite=True, From 928d88247c94a0d42fc159ef1c233999262ebbe0 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 27 Nov 2023 12:09:18 +0100 Subject: [PATCH 26/45] change argument position --- cebra/integrations/sklearn/cebra.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 555966fb..39f73aa2 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1200,18 +1200,12 @@ def fit( def transform(self, X: Union[npt.NDArray, torch.Tensor], - pad_before_transform: bool = True, batch_size: Optional[int] = None, session_id: Optional[int] = None) -> npt.NDArray: """Transform an input sequence and return the embedding. Args: X: A numpy array or torch tensor of size ``time x dimension``. - pad_before_transform: If ``False``, no padding is applied to the input sequence. - and the output sequence will be smaller than the input sequence due to the - receptive field of the model. If the input sequence is ``n`` steps long, - and a model with receptive field ``m`` is used, the output sequence would - only be ``n-m+1`` steps long. batch_size: session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for multisession, set to ``None`` for single session. @@ -1244,7 +1238,7 @@ def transform(self, with torch.no_grad(): output = self.solver_.transform( inputs=X, - pad_before_transform=pad_before_transform, + pad_before_transform=self.pad_before_transform, session_id=session_id, batch_size=batch_size) From 07bac1cbe39c162f7ab1709c769f71d68167fe94 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 27 Nov 2023 12:12:00 +0100 Subject: [PATCH 27/45] clean test --- tests/test_solver.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/test_solver.py b/tests/test_solver.py index 1661003a..0b0eb823 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -425,17 +425,6 @@ def test_batched_transform_singlesession( offset_ = model.get_offset() padding_left = offset_.left if padding else 0 - #if len(offset_) < 2 and padding: - # pytest.skip("not relevant for now.") - # with pytest.raises(ValueError): - # solver.transform(inputs=loader.dataset.neural, - # pad_before_transform=padding) - # - # with pytest.raises(ValueError): - # solver.transform(inputs=loader.dataset.neural, - # batch_size=batch_size, - # pad_before_transform=padding) - #TODO: this wont work in the case where the data is less than #the offset from the beginning, i.e len(data) = 10, len(offset) = 10 if smallest_batch_length <= len(offset_): @@ -507,19 +496,6 @@ def test_batched_transform_multisession(data_name, model_name, padding, # Transform each session with the right model, by providing the corresponding session ID for i, inputs in enumerate(dataset.iter_sessions()): - # if len(offset_) < 2 and padding: - # with pytest.raises(ValueError): - # embedding = solver.transform(inputs=inputs.neural, - # session_id=i, - # pad_before_transform=padding) - # - # with pytest.raises(ValueError): - # embedding_batched = solver.transform( - # inputs=inputs.neural, - # session_id=i, - # pad_before_transform=padding, - # batch_size=batch_size) - if smallest_batch_length <= len(offset_): with pytest.raises(ValueError): solver.transform(inputs=inputs.neural, From 0823b54efa549ceed51b1cc2fd25d82d8eb5afa0 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 27 Nov 2023 12:18:15 +0100 Subject: [PATCH 28/45] clean 
test --- tests/test_solver.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_solver.py b/tests/test_solver.py index 0b0eb823..f84edeb5 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -425,8 +425,6 @@ def test_batched_transform_singlesession( offset_ = model.get_offset() padding_left = offset_.left if padding else 0 - #TODO: this wont work in the case where the data is less than - #the offset from the beginning, i.e len(data) = 10, len(offset) = 10 if smallest_batch_length <= len(offset_): with pytest.raises(ValueError): solver.transform(inputs=loader.dataset.neural, @@ -477,11 +475,9 @@ def test_batched_transform_multisession(data_name, model_name, padding, smallest_batch_length = n_samples - batch_size offset_ = model[0].get_offset() - #print("here!", smallest_batch_length, len(offset_)) padding_left = offset_.left if padding else 0 for d in dataset._datasets: d.offset = offset_ - #dataset._datasets[0].offset = cebra.data.Offset(0, 1) loader_kwargs = dict(num_steps=10, batch_size=32) loader = loader_initfunc(dataset, **loader_kwargs) From 9fe3af351cddabdc37886bcea1f251997be03bce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Tue, 26 Mar 2024 20:46:16 +0100 Subject: [PATCH 29/45] Fix warning --- cebra/solver/base.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 2cecab08..643ae8b8 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -111,12 +111,18 @@ def _add_zero_padding(batched_data: torch.Tensor, offset: cebra.data.Offset, start_batch_idx: int, end_batch_idx: int, number_of_samples: int): + reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1) + if start_batch_idx == 0: # First batch - batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T + batched_data = F.pad(batched_data.permute(*reversed_dims), + (offset.left, 0), 'replicate').permute(*reversed_dims) + #batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T elif end_batch_idx == number_of_samples: # Last batch - batched_data = F.pad(batched_data.T, (0, offset.right - 1), - 'replicate').T + batched_data = F.pad(batched_data.permute(*reversed_dims), + (0, offset.right - 1), 'replicate').permute(*reversed_dims) + #batched_data = F.pad(batched_data.T, (0, offset.right - 1), 'replicate').T + return batched_data From b417a239ed01e32f16d85ef9a7005987f8e60b7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:42:53 +0200 Subject: [PATCH 30/45] Improve modularity remove duplicate code and todos --- cebra/integrations/sklearn/cebra.py | 44 +--- cebra/integrations/sklearn/metrics.py | 3 +- cebra/solver/base.py | 329 +++++++++++++++----------- cebra/solver/multi_session.py | 66 +++++- cebra/solver/single_session.py | 95 +++++++- 5 files changed, 359 insertions(+), 178 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 39f73aa2..adabd874 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -791,33 +791,7 @@ def _configure_for_all( def _select_model(self, X: Union[npt.NDArray, torch.Tensor], session_id: int): - # Choose the model and get its corresponding offset - if self.num_sessions is not None: # multisession implementation - if session_id is None: - raise RuntimeError( - "No session_id provided: multisession model requires a 
session_id to choose the model corresponding to your data shape." - ) - if session_id >= self.num_sessions or session_id < 0: - raise RuntimeError( - f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." - ) - if self.n_features_[session_id] != X.shape[1]: - raise ValueError( - f"Invalid input shape: model for session {session_id} requires an input of shape" - f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})." - ) - - model = self.model_[session_id] - model.to(self.device_) - else: # single session - if session_id is not None and session_id > 0: - raise RuntimeError( - f"Invalid session_id {session_id}: single session models only takes an optional null session_id." - ) - model = self.model_ - - offset = model.get_offset() - return model, offset + return self.solver_._select_model(X, session_id=session_id) def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): """Check that the input labels are compatible with the labels used to fit the model. @@ -1224,16 +1198,16 @@ def transform(self, >>> embedding = cebra_model.transform(dataset) """ - + self.solver_._check_is_session_id_valid(session_id=session_id) sklearn_utils_validation.check_is_fitted(self, "n_features_") - # Input validation - #TODO: if inputs are in cuda, then it throws an error, deal with this. + + if torch.is_tensor(X) and X.device.type == "cuda": + X = X.detach().cpu() + X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) - #input_dtype = X.dtype if isinstance(X, np.ndarray): X = torch.from_numpy(X) - # TODO: which type and device should be put there? with torch.no_grad(): output = self.solver_.transform( @@ -1242,11 +1216,7 @@ def transform(self, session_id=session_id, batch_size=batch_size) - #TODO: check if this is safe. - return output.numpy(force=True) - - #if input_dtype == "float64": - # return output.astype(input_dtype) + return output.detach().cpu().numpy() def fit_transform( self, diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 9712d021..59a961b3 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -83,7 +83,8 @@ def infonce_loss( f"got {len(y[0])} sessions.") model, _ = cebra_model._select_model( - X, session_id) # check session_id validity and corresponding model + X, session_id=session_id + ) # check session_id validity and corresponding model cebra_model._check_labels_types(y, session_id=session_id) dataset, is_multisession = cebra_model._prepare_data(X, y) # single session diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 643ae8b8..5f3acb35 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -32,7 +32,8 @@ import abc import os -from typing import Callable, Dict, Iterable, List, Literal, Optional, Union +from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, + Union) import literate_dataclasses as dataclasses import numpy as np @@ -51,37 +52,35 @@ from cebra.solver.util import ProgressBar -def _inference_transform(model, inputs): - - #TODO: I am not sure what is the best way with dealing with the types and - # device when using batched inference. This works for now. 
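A usage sketch for the batched path through the sklearn API (the data shape, model choice and batch sizes are assumed; the batch_size argument to transform is what this series adds):

import numpy as np
import cebra

X = np.random.uniform(0, 1, (10_000, 30)).astype("float32")
cebra_model = cebra.CEBRA(model_architecture="offset10-model",
                          batch_size=512,
                          max_iterations=10,
                          output_dimension=8)
cebra_model.fit(X)

embedding = cebra_model.transform(X)                             # one forward pass
embedding_batched = cebra_model.transform(X, batch_size=1_000)   # chunks of 1000 samples
assert embedding.shape == embedding_batched.shape == (10_000, 8)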
- inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device) - - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - inputs = inputs.transpose(1, 0).unsqueeze(0) - output = model(inputs).squeeze(0).transpose(1, 0) - else: - output = model(inputs) - return output - - -def _check_indices(start_batch_idx, end_batch_idx, offset, num_samples): +def _check_indices(batch_start_idx: int, batch_end_idx: int, + offset: cebra.data.Offset, num_samples: int): + """Check that indexes in a batch are in a correct range. + + First and last index must be positive integers, smaller than the total length of inputs + in the dataset, the first index must be smaller than the last and the batch size cannot + be smaller than the offset of the model. + + Args: + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the first sample in the batch. + offset: Model offset. + num_samples: Total number of samples in the input. + """ - if start_batch_idx < 0 or end_batch_idx < 0: + if batch_start_idx < 0 or batch_end_idx < 0: raise ValueError( - f"start_batch_idx ({start_batch_idx}) and end_batch_idx ({end_batch_idx}) must be non-negative." + f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be positive integers." ) - if start_batch_idx > end_batch_idx: + if batch_start_idx > batch_end_idx: raise ValueError( - f"start_batch_idx ({start_batch_idx}) cannot be greater than end_batch_idx ({end_batch_idx})." + f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})." ) - if end_batch_idx > num_samples: + if batch_end_idx > num_samples: raise ValueError( - f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({num_samples})." + f"batch_end_idx ({batch_end_idx}) cannot exceed the length of inputs ({num_samples})." ) - batch_size_lenght = end_batch_idx - start_batch_idx + batch_size_lenght = batch_end_idx - batch_start_idx if batch_size_lenght <= len(offset): raise ValueError( f"The batch has length {batch_size_lenght} which " @@ -91,45 +90,123 @@ def _check_indices(start_batch_idx, end_batch_idx, offset, num_samples): def _get_batch(inputs: torch.Tensor, offset: cebra.data.Offset, - start_batch_idx: int, end_batch_idx: int) -> torch.Tensor: + batch_start_idx: int, batch_end_idx: int) -> torch.Tensor: + """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`. - if start_batch_idx == 0: # First batch - indices = start_batch_idx, (end_batch_idx + offset.right - 1) + Args: + inputs: Input data. + offset: Model offset. + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the first sample in the batch. - elif end_batch_idx == len(inputs): # Last batch - indices = (start_batch_idx - offset.left), end_batch_idx + Returns: + The batch. 
+ """ - else: # Middle batches - indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1 + if batch_start_idx == 0: # First batch + indices = batch_start_idx, (batch_end_idx + offset.right - 1) + elif batch_end_idx == len(inputs): # Last batch + indices = (batch_start_idx - offset.left), batch_end_idx + else: + indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1 _check_indices(indices[0], indices[1], offset, len(inputs)) batched_data = inputs[slice(*indices)] return batched_data -def _add_zero_padding(batched_data: torch.Tensor, offset: cebra.data.Offset, - start_batch_idx: int, end_batch_idx: int, - number_of_samples: int): +def _add_batched_zero_padding(batched_data: torch.Tensor, + offset: cebra.data.Offset, batch_start_idx: int, + batch_end_idx: int, + num_samples: int) -> torch.Tensor: + """Add zero padding to the input data before inference. - reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1) - - if start_batch_idx == 0: # First batch - batched_data = F.pad(batched_data.permute(*reversed_dims), - (offset.left, 0), 'replicate').permute(*reversed_dims) - #batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T + Args: + batched_data: Data to apply the inference on. + offset (cebra.data.Offset): _description_ + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the first sample in the batch. + num_samples (int): Total number of samples in the data. - elif end_batch_idx == number_of_samples: # Last batch - batched_data = F.pad(batched_data.permute(*reversed_dims), - (0, offset.right - 1), 'replicate').permute(*reversed_dims) - #batched_data = F.pad(batched_data.T, (0, offset.right - 1), 'replicate').T + Returns: + The padded batch. + """ + reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1) + if batch_start_idx == 0: # First batch + batched_data = F.pad(batched_data.permute(*reversed_dims), + (offset.left, 0), + 'replicate').permute(*reversed_dims) + elif batch_end_idx == num_samples: # Last batch + batched_data = F.pad(batched_data.permute(*reversed_dims), + (0, offset.right - 1), + 'replicate').permute(*reversed_dims) return batched_data -def _batched_transform(model, inputs: torch.Tensor, batch_size: int, - pad_before_transform: bool, +def _inference_transform(model: cebra.models.Model, + inputs: torch.Tensor) -> torch.Tensor: + """Compute the embedding on the inputs using the model provided. + + Args: + model: Model to use for inference. + inputs: Data. + + Returns: + The embedding. + """ + #TODO(rodrigo): I am not sure what is the best way with dealing with the types and + # device when using batched inference. This works for now. + inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device) + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + inputs = inputs.transpose(1, 0).unsqueeze(0) + output = model(inputs).squeeze(0).transpose(1, 0) + else: + output = model(inputs) + return output + + +def _transform( + model: cebra.models.Model, + inputs: torch.Tensor, + pad_before_transform: bool, + offset: cebra.data.Offset, +) -> torch.Tensor: + """Compute the embedding. + + Args: + model: The model to use for inference. + inputs: Input data. + pad_before_transform: If True, the input data is zero padded before inference. + offset: Model offset. + + Returns: + The embedding. 
+ """ + if pad_before_transform: + inputs = F.pad(inputs.T, (offset.left, offset.right - 1), 'replicate').T + output = _inference_transform(model, inputs) + return output + + +def _batched_transform(model: cebra.models.Model, inputs: torch.Tensor, + batch_size: int, pad_before_transform: bool, offset: cebra.data.Offset) -> torch.Tensor: + """Compute the embedding on batched inputs. + + Args: + model: The model to use for inference. + inputs: Input data. + batch_size: Integer corresponding to the batch size. + pad_before_transform: If True, the input data is zero padded before inference. + offset: Model offset. + + Returns: + The embedding. + """ class IndexDataset(Dataset): @@ -146,19 +223,20 @@ def __getitem__(self, idx): index_dataloader = DataLoader(index_dataset, batch_size=batch_size) output = [] - for batch_id, index_batch in enumerate(index_dataloader): - start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] + 1 + for index_batch in index_dataloader: + batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1 batched_data = _get_batch(inputs=inputs, offset=offset, - start_batch_idx=start_batch_idx, - end_batch_idx=end_batch_idx) + batch_start_idx=batch_start_idx, + batch_end_idx=batch_end_idx) if pad_before_transform: - batched_data = _add_zero_padding(batched_data=batched_data, - offset=offset, - start_batch_idx=start_batch_idx, - end_batch_idx=end_batch_idx, - number_of_samples=len(inputs)) + batched_data = _add_batched_zero_padding( + batched_data=batched_data, + offset=offset, + batch_start_idx=batch_start_idx, + batch_end_idx=batch_end_idx, + num_samples=len(inputs)) output_batch = _inference_transform(model, batched_data) output.append(output_batch) @@ -265,13 +343,9 @@ def num_parameters(self) -> int: """Total number of parameters in the encoder and criterion.""" return sum(p.numel() for p in self.parameters()) - def parameters(self): - """Iterate over all parameters.""" - for parameter in self.model.parameters(): - yield parameter - - for parameter in self.criterion.parameters(): - yield parameter + @abc.abstractmethod + def parameters(self, session_id: Optional[int] = None): + raise NotImplementedError def _get_loader(self, loader): return ProgressBar( @@ -279,6 +353,10 @@ def _get_loader(self, loader): "tqdm" if self.tqdm_on else "off", ) + @abc.abstractmethod + def _set_fitted_params(self, loader: cebra.data.Loader): + raise NotImplementedError + def fit( self, loader: cebra.data.Loader, @@ -306,14 +384,6 @@ def fit( TODO: * Refine the API here. Drop the validation entirely, and implement this via a hook? """ - - self.num_sessions = loader.dataset.num_sessions if hasattr( - loader.dataset, "num_sessions") else None - self.n_features = ([ - loader.dataset.get_input_dimension(session_id) - for session_id in range(loader.dataset.num_sessions) - ] if self.num_sessions is not None else loader.dataset.input_dimension) - self.to(loader.device) iterator = self._get_loader(loader) @@ -341,6 +411,8 @@ def fit( save_hook(num_steps, self) self.save(logdir, f"checkpoint_{num_steps:#07d}.pth") + self._set_fitted_params(loader) + def step(self, batch: cebra.data.Batch) -> dict: """Perform a single gradient update. @@ -377,8 +449,9 @@ def validation(self, Args: loader: Data loader, which is an iterator over `cebra.data.Batch` instances. Each batch contains reference, positive and negative input samples. - session_id: The session ID, an integer between 0 and the number of sessions in the - multisession model, set to None for single session. 
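The index-batching pattern used by _batched_transform above, reduced to a standalone sketch (the sequence length of 10 and batch size of 4 are assumed): a DataLoader over plain indices yields contiguous index blocks, and each block is turned into batch_start_idx/batch_end_idx bounds.

from torch.utils.data import DataLoader, Dataset

class IndexDataset(Dataset):

    def __init__(self, inputs):
        self.inputs = inputs

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return idx

index_dataloader = DataLoader(IndexDataset(range(10)), batch_size=4)
for index_batch in index_dataloader:
    batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
    print(batch_start_idx.item(), batch_end_idx.item())
# prints: 0 4, then 4 8, then 8 10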
+ session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. Returns: Loss averaged over iterations on data batch. @@ -412,56 +485,43 @@ def decoding(self, train_loader, valid_loader): ) return decode_metric - def _select_model(self, inputs: torch.Tensor, session_id: int): - #NOTE: In the torch API the inputs will be a torch tensor. Then in the - # sklearn API we will convert it to numpy array. - """ Select the right model based on the type of solver we have.""" - - if self.num_sessions is not None: # multisession implementation - if session_id is None: - raise RuntimeError( - "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." - ) - if session_id >= self.num_sessions or session_id < 0: - raise RuntimeError( - f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." - ) - if self.n_features[session_id] != inputs.shape[1]: - raise ValueError( - f"Invalid input shape: model for session {session_id} requires an input of shape" - f"(n_samples, {self.n_features[session_id]}), got (n_samples, {inputs.shape[1]})." - ) - - model = self.model[session_id] - - else: # single session - if session_id is not None and session_id > 0: - raise RuntimeError( - f"Invalid session_id {session_id}: single session models only takes an optional null session_id." - ) - - if isinstance( - self, - cebra.solver.single_session.SingleSessionHybridSolver): - # NOTE: This is different from the sklearn API implementation. The issue is that here the - # model is a cebra.models.MultiObjective instance, and therefore to do inference I need - # to get the module inside this model. - model = self.model.module - else: - model = self.model + @abc.abstractmethod + def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): + """Check that the inputs can be infered using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + """ + raise NotImplementedError - offset = model.get_offset() - return model, offset + @abc.abstractmethod + def _check_is_session_id_valid(self, session_id: Optional[int] = None): + raise NotImplementedError - @torch.no_grad() - def _transform(self, model, inputs, offset, - pad_before_transform) -> torch.Tensor: + @abc.abstractmethod + def _select_model( + self, inputs: Union[torch.Tensor, + List[torch.Tensor]], session_id: Optional[int] + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. - if pad_before_transform: - inputs = F.pad(inputs.T, (offset.left, offset.right - 1), - 'replicate').T - output = _inference_transform(model, inputs) - return output + Returns: + The model (first returns) and the offset of the model (second returns). 
+ """ + raise NotImplementedError @torch.no_grad() def transform(self, @@ -489,17 +549,16 @@ def transform(self, Returns: The output embedding. """ - #TODO: add check like sklearn? - # #sklearn_utils_validation.check_is_fitted(self, "n_features_") + if not hasattr(self, "n_features"): + raise ValueError( + f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this estimator.") model, offset = self._select_model(inputs, session_id) - model.eval() - #TODO: should we add this error? - #if len(offset) < 2 and pad_before_transform: - # raise ValueError( - # "Padding does not make sense when the offset of the model is < 2" - # ) + if len(offset) < 2 and pad_before_transform: + pad_before_transform = False + model.eval() if batch_size is not None: output = _batched_transform( model=model, @@ -508,12 +567,11 @@ def transform(self, batch_size=batch_size, pad_before_transform=pad_before_transform, ) - else: - output = self._transform(model=model, - inputs=inputs, - offset=offset, - pad_before_transform=pad_before_transform) + output = _transform(model=model, + inputs=inputs, + offset=offset, + pad_before_transform=pad_before_transform) return output @@ -539,6 +597,7 @@ def load(self, logdir, filename="checkpoint.pth"): """Load the experiment from its checkpoint file. Args: + logdir: Log directory. filename (str): Checkpoint name for loading the experiment. """ @@ -549,6 +608,12 @@ def load(self, logdir, filename="checkpoint.pth"): checkpoint = torch.load(savepath, map_location=self.device) self.load_state_dict(checkpoint, strict=True) + if hasattr(self.model, "n_features"): + n_features = self.model.n_features + self.n_features = ([ + session_n_features for session_n_features in n_features + ] if isinstance(n_features, list) else n_features) + def save(self, logdir, filename="checkpoint_last.pth"): """Save the model and optimizer params. diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 7f103708..666dafb8 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -43,6 +43,15 @@ class MultiSessionSolver(abc_.Solver): _variant_name = "multi-session" + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters.""" + self._check_is_session_id_valid(session_id=session_id) + for parameter in self.model[session_id].parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter + def _mix(self, array: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: shape = array.shape n, m = shape[:2] @@ -116,6 +125,61 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: negative=neg.view(-1, num_features), ) + def _set_fitted_params(self, loader: cebra.data.Loader): + self.num_sessions = loader.dataset.num_sessions + self.n_features = [ + loader.dataset.get_input_dimension(session_id) + for session_id in range(loader.dataset.num_sessions) + ] + + def _check_is_inputs_valid(self, inputs: torch.Tensor, + session_id: Optional[int]): + """Check that the inputs can be infered using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. 
+ """ + if self.n_features[session_id] != inputs.shape[1]: + raise ValueError( + f"Invalid input shape: model for session {session_id} requires an input of shape" + f"(n_samples, {self.n_features[session_id]}), got (n_samples, {inputs.shape[1]})." + ) + + def _check_is_session_id_valid(self, session_id: Optional[int]): + if session_id is None: + raise RuntimeError( + "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." + ) + if session_id >= self.num_sessions or session_id < 0: + raise RuntimeError( + f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." + ) + + def _select_model(self, inputs: torch.Tensor, session_id: Optional[int]): + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model (first returns) and the offset of the model (second returns). + """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_inputs_valid(inputs, session_id=session_id) + + model = self.model[session_id] + offset = model.get_offset() + return model, offset + def validation(self, loader, session_id: Optional[int] = None): """Compute score of the model on data. @@ -147,7 +211,7 @@ def validation(self, loader, session_id: Optional[int] = None): @register("multi-session-aux") -class MultiSessionAuxVariableSolver(abc_.Solver): +class MultiSessionAuxVariableSolver(MultiSessionSolver): """Multi session training, contrasting neural data against behavior.""" _variant_name = "multi-session-aux" diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index ded526e9..0ac603e2 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -21,11 +21,8 @@ # """Single session solvers embed a single pair of time series.""" -import abc import copy -import os -from collections.abc import Iterable -from typing import List +from typing import List, Optional, Tuple, Union import literate_dataclasses as dataclasses import torch @@ -42,11 +39,72 @@ class SingleSessionSolver(abc_.Solver): """Single session training with a symmetric encoder. This solver assumes that reference, positive and negative samples - are processed by the same features encoder. + are processed by the same features encoder and that a single session + is provided to that encoder. """ _variant_name = "single-session" + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters.""" + self._check_is_session_id_valid(session_id=session_id) + for parameter in self.model.parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter + + def _set_fitted_params(self, loader: cebra.data.Loader): + self.num_sessions = None + self.n_features = loader.dataset.input_dimension + + def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): + """Check that the inputs can be infered using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. 
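Sketched usage of the per-session checks defined here (a fitted MultiSessionSolver called solver and a multisession dataset with a get_session accessor are assumed): each session is transformed with its own session_id, and mixing sessions fails the input-dimension check whenever the sessions have different numbers of neurons.

embeddings = [
    solver.transform(dataset.get_session(i).neural, session_id=i)
    for i in range(solver.num_sessions)
]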
+ """ + if self.n_features != inputs.shape[1]: + raise ValueError( + f"Invalid input shape: model for session {session_id} requires an input of shape" + f"(n_samples, {self.n_features}), got (n_samples, {inputs.shape[1]})." + ) + + def _check_is_session_id_valid(self, session_id: Optional[int] = None): + if session_id is not None and session_id > 0: + raise RuntimeError( + f"Invalid session_id {session_id}: single session models only takes an optional null session_id." + ) + + def _select_model( + self, inputs: Union[torch.Tensor, + List[torch.Tensor]], session_id: Optional[int] + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model (first returns) and the offset of the model (second returns). + """ + self._check_is_inputs_valid(inputs, session_id=session_id) + self._check_is_session_id_valid(session_id=session_id) + + model = self.model + offset = model.get_offset() + return model, offset + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, computes the feature representation/embedding. @@ -94,7 +152,7 @@ def get_embedding(self, data: torch.Tensor) -> torch.Tensor: @register("single-session-aux") @dataclasses.dataclass -class SingleSessionAuxVariableSolver(abc_.Solver): +class SingleSessionAuxVariableSolver(SingleSessionSolver): """Single session training for reference and positive/negative samples. This solver processes reference samples with a model different from @@ -131,7 +189,7 @@ def _inference(self, batch): @register("single-session-hybrid") @dataclasses.dataclass -class SingleSessionHybridSolver(abc_.MultiobjectiveSolver): +class SingleSessionHybridSolver(abc_.MultiobjectiveSolver, SingleSessionSolver): """Single session training, contrasting neural data against behavior.""" _variant_name = "single-session-hybrid" @@ -149,6 +207,29 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: behavior_neg), cebra.data.Batch( time_ref, time_pos, time_neg) + def _select_model( + self, inputs: Union[torch.Tensor, + List[torch.Tensor]], session_id: Optional[int] + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model (first returns) and the offset of the model (second returns). 
+ """ + self._check_is_inputs_valid(inputs, session_id=session_id) + self._check_is_session_id_valid(session_id=session_id) + + model = self.model.module + offset = model.get_offset() + return model, offset + @register("single-session-full") @dataclasses.dataclass From 83c16691d081c90e51b0e90d6d4d306f74457d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:41:44 +0200 Subject: [PATCH 31/45] Add tests to solver --- cebra/data/base.py | 4 + cebra/data/multi_session.py | 15 +- cebra/data/single_session.py | 14 +- cebra/integrations/sklearn/cebra.py | 4 +- cebra/solver/base.py | 90 +++-- cebra/solver/single_session.py | 5 +- tests/test_solver.py | 592 ++++++++++++++++++---------- 7 files changed, 458 insertions(+), 266 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index d2ee47b5..874ed58b 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -196,6 +196,7 @@ def load_batch(self, index: BatchIndex) -> Batch: """ raise NotImplementedError() + @abc.abstractmethod def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. @@ -205,6 +206,7 @@ def configure_for(self, model: "cebra.models.Model"): Args: model: The model to configure the dataset for. """ + raise NotImplementedError self.offset = model.get_offset() @@ -230,6 +232,8 @@ class Loader(abc.ABC, cebra.io.HasDevice): doc="""A dataset instance specifying a ``__getitem__`` function.""", ) + time_offset: int = dataclasses.field(default=10) + num_steps: int = dataclasses.field( default=None, doc= diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 8cd74286..a8d56d10 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -111,6 +111,18 @@ def configure_for(self, model): for session in self.iter_sessions(): session.configure_for(model) + def configure_for(self, model: "cebra.models.Model"): + """Configure the dataset offset for the provided model. + + Call this function before indexing the dataset. This sets the + :py:attr:`offset` attribute of the dataset. + + Args: + model: The model to configure the dataset for. + """ + for i, session in enumerate(self.iter_sessions()): + session.configure_for(model[i]) + @dataclasses.dataclass class MultiSessionLoader(cebra_data.Loader): @@ -121,8 +133,6 @@ class MultiSessionLoader(cebra_data.Loader): dimension, it is better to use a :py:class:`cebra.data.single_session.MixedDataLoader`. """ - time_offset: int = dataclasses.field(default=10) - def __post_init__(self): super().__post_init__() self.sampler = cebra_distr.MultisessionSampler(self.dataset, @@ -151,7 +161,6 @@ class ContinuousMultiSessionDataLoader(MultiSessionLoader): """Contrastive learning conditioned on a continuous behavior variable.""" conditional: str = "time_delta" - time_offset: int = dataclasses.field(default=10) @property def index(self): diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index c27b10f5..71cd0c3e 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -72,6 +72,17 @@ def load_batch(self, index: BatchIndex) -> Batch: reference=self[index.reference], ) + def configure_for(self, model: "cebra.models.Model"): + """Configure the dataset offset for the provided model. + + Call this function before indexing the dataset. This sets the + :py:attr:`offset` attribute of the dataset. + + Args: + model: The model to configure the dataset for. 
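A small sketch of how configure_for is meant to be called before indexing a dataset (the dataset name and model hyperparameters are assumed, following the tests in this series); for multisession datasets the same call expects one model per session, e.g. a torch.nn.ModuleList.

import cebra.datasets
import cebra.models

dataset = cebra.datasets.init("demo-continuous")
model = cebra.models.init("offset10-model", dataset.input_dimension, 32, 3)

dataset.configure_for(model)  # sets dataset.offset from the model's receptive field
assert len(dataset.offset) == len(model.get_offset())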
+ """ + self.offset = model.get_offset() + @dataclasses.dataclass class DiscreteDataLoader(cebra_data.Loader): @@ -192,7 +203,6 @@ class ContinuousDataLoader(cebra_data.Loader): and become equivalent to time contrastive learning. """, ) - time_offset: int = dataclasses.field(default=10) delta: float = dataclasses.field(default=0.1) def __post_init__(self): @@ -274,7 +284,6 @@ class MixedDataLoader(cebra_data.Loader): """ conditional: str = dataclasses.field(default="time_delta") - time_offset: int = dataclasses.field(default=10) @property def dindex(self): @@ -337,7 +346,6 @@ class HybridDataLoader(cebra_data.Loader): """ conditional: str = dataclasses.field(default="time_delta") - time_offset: int = dataclasses.field(default=10) delta: float = dataclasses.field(default=0.1) @property diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index adabd874..4240074f 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -776,8 +776,6 @@ def _configure_for_all( f"receptive fields/offsets larger than 1 via the sklearn API. " f"Please use a different model, or revert to the pytorch " f"API for training.") - - d.configure_for(model[n]) else: if not isinstance(model, cebra.models.ConvolutionalModelMixin): if len(model.get_offset()) > 1: @@ -787,7 +785,7 @@ def _configure_for_all( f"Please use a different model, or revert to the pytorch " f"API for training.") - dataset.configure_for(model) + dataset.configure_for(model) def _select_model(self, X: Union[npt.NDArray, torch.Tensor], session_id: int): diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 5f3acb35..ec33f23e 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -37,6 +37,7 @@ import literate_dataclasses as dataclasses import numpy as np +import numpy.typing as npt import torch import torch.nn.functional as F import tqdm @@ -89,32 +90,6 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, ) -def _get_batch(inputs: torch.Tensor, offset: cebra.data.Offset, - batch_start_idx: int, batch_end_idx: int) -> torch.Tensor: - """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`. - - Args: - inputs: Input data. - offset: Model offset. - batch_start_idx: Index of the first sample in the batch. - batch_end_idx: Index of the first sample in the batch. - - Returns: - The batch. - """ - - if batch_start_idx == 0: # First batch - indices = batch_start_idx, (batch_end_idx + offset.right - 1) - elif batch_end_idx == len(inputs): # Last batch - indices = (batch_start_idx - offset.left), batch_end_idx - else: - indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1 - - _check_indices(indices[0], indices[1], offset, len(inputs)) - batched_data = inputs[slice(*indices)] - return batched_data - - def _add_batched_zero_padding(batched_data: torch.Tensor, offset: cebra.data.Offset, batch_start_idx: int, batch_end_idx: int, @@ -145,6 +120,45 @@ def _add_batched_zero_padding(batched_data: torch.Tensor, return batched_data +def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset], + batch_start_idx: int, batch_end_idx: int, + pad_before_transform: bool) -> torch.Tensor: + """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`. + + Args: + inputs: Input data. + offset: Model offset. + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the first sample in the batch. + pad_before_transform: If True zero-pad the batched data. + + Returns: + The batch. 
+ """ + if offset is None: + raise ValueError(f"offset cannot be null.") + + if batch_start_idx == 0: # First batch + indices = batch_start_idx, (batch_end_idx + offset.right - 1) + elif batch_end_idx == len(inputs): # Last batch + indices = (batch_start_idx - offset.left), batch_end_idx + else: + indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1 + + _check_indices(indices[0], indices[1], offset, len(inputs)) + batched_data = inputs[slice(*indices)] + + if pad_before_transform: + batched_data = _add_batched_zero_padding( + batched_data=batched_data, + offset=offset, + batch_start_idx=batch_start_idx, + batch_end_idx=batch_end_idx, + num_samples=len(inputs)) + + return batched_data + + def _inference_transform(model: cebra.models.Model, inputs: torch.Tensor) -> torch.Tensor: """Compute the embedding on the inputs using the model provided. @@ -156,9 +170,7 @@ def _inference_transform(model: cebra.models.Model, Returns: The embedding. """ - #TODO(rodrigo): I am not sure what is the best way with dealing with the types and - # device when using batched inference. This works for now. - inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device) + inputs = inputs.float().to(next(model.parameters()).device) if isinstance(model, cebra.models.ConvolutionalModelMixin): # Fully convolutional evaluation, switch (T, C) -> (1, C, T) @@ -228,15 +240,8 @@ def __getitem__(self, idx): batched_data = _get_batch(inputs=inputs, offset=offset, batch_start_idx=batch_start_idx, - batch_end_idx=batch_end_idx) - - if pad_before_transform: - batched_data = _add_batched_zero_padding( - batched_data=batched_data, - offset=offset, - batch_start_idx=batch_start_idx, - batch_end_idx=batch_end_idx, - num_samples=len(inputs)) + batch_end_idx=batch_end_idx, + pad_before_transform=pad_before_transform) output_batch = _inference_transform(model, batched_data) output.append(output_batch) @@ -549,6 +554,15 @@ def transform(self, Returns: The output embedding. """ + if isinstance(inputs, list): + raise NotImplementedError( + "Inputs to transform() should be the data for a single session." + ) + + elif not isinstance(inputs, torch.Tensor): + raise ValueError( + f"Inputs should be a torch.Tensor, not {type(inputs)}.") + if not hasattr(self, "n_features"): raise ValueError( f"This {type(self).__name__} instance is not fitted yet. 
Call 'fit' with " diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index 0ac603e2..b941a8ba 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -227,7 +227,10 @@ def _select_model( self._check_is_session_id_valid(session_id=session_id) model = self.model.module - offset = model.get_offset() + if hasattr(model, 'get_offset'): + offset = model.get_offset() + else: + offset = None return model, offset diff --git a/tests/test_solver.py b/tests/test_solver.py index f84edeb5..4bb17232 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -35,72 +35,121 @@ single_session_tests = [] for args in [ - ("demo-discrete", cebra.data.DiscreteDataLoader), - ("demo-continuous", cebra.data.ContinuousDataLoader), - ("demo-mixed", cebra.data.MixedDataLoader), + ("demo-discrete", cebra.data.DiscreteDataLoader, "offset10-model"), + ("demo-discrete", cebra.data.DiscreteDataLoader, "offset1-model"), + ("demo-discrete", cebra.data.DiscreteDataLoader, "offset1-model"), + ("demo-discrete", cebra.data.DiscreteDataLoader, "offset10-model"), + ("demo-continuous", cebra.data.ContinuousDataLoader, "offset10-model"), + ("demo-continuous", cebra.data.ContinuousDataLoader, "offset1-model"), + ("demo-mixed", cebra.data.MixedDataLoader, "offset10-model"), + ("demo-mixed", cebra.data.MixedDataLoader, "offset1-model"), ]: single_session_tests.append((*args, cebra.solver.SingleSessionSolver)) single_session_hybrid_tests = [] -for args in [("demo-continuous", cebra.data.HybridDataLoader)]: +for args in [("demo-continuous", cebra.data.HybridDataLoader, "offset10-model"), + ("demo-continuous", cebra.data.HybridDataLoader, "offset1-model")]: single_session_hybrid_tests.append( (*args, cebra.solver.SingleSessionHybridSolver)) multi_session_tests = [] -for args in [("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader)]: +for args in [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"), + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"), +]: multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) - # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) -print(single_session_tests) +# multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) -def _get_loader(data_name, loader_initfunc): - data = cebra.datasets.init(data_name) - kwargs = dict(num_steps=10, batch_size=32) +def _get_loader(data, loader_initfunc): + kwargs = dict(num_steps=5, batch_size=32) loader = loader_initfunc(data, **kwargs) return loader -def _make_model(dataset): - # TODO flexible input dimension - return nn.Sequential( - nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), - nn.Flatten(start_dim=1, end_dim=-1), - ) +OUTPUT_DIMENSION = 3 -def _make_behavior_model(dataset): +def _make_model(dataset, model_architecture="offset10-model"): # TODO flexible input dimension - return nn.Sequential( - nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), - nn.Flatten(start_dim=1, end_dim=-1), - ) + # return nn.Sequential( + # nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), + # nn.Flatten(start_dim=1, end_dim=-1), + # ) + return cebra.models.init(model_architecture, dataset.input_dimension, 32, + OUTPUT_DIMENSION) -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - single_session_tests) -def test_single_session(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, 
loader_initfunc) - model = _make_model(loader.dataset) +# def _make_behavior_model(dataset): +# # TODO flexible input dimension +# return nn.Sequential( +# nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), +# nn.Flatten(start_dim=1, end_dim=-1), +# ) + + +@pytest.mark.parametrize( + "data_name, loader_initfunc, model_architecture, solver_initfunc", + single_session_tests) +def test_single_session(data_name, loader_initfunc, model_architecture, + solver_initfunc): + data = cebra.datasets.init(data_name) + loader = _get_loader(data, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) + offset = model.get_offset() criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, criterion=criterion, - optimizer=optimizer) + optimizer=optimizer, + tqdm_on=False) batch = next(iter(loader)) - assert batch.reference.shape == (32, loader.dataset.input_dimension, 10) + assert batch.reference.shape[:2] == (32, loader.dataset.input_dimension) log = solver.step(batch) assert isinstance(log, dict) + X = loader.dataset.neural + with pytest.raises(ValueError, match="not.*fitted"): + solver.transform(X) + solver.fit(loader) + assert solver.num_sessions == None + assert solver.n_features == X.shape[1] + + embedding = solver.transform(X) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(torch.Tensor(X)) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, session_id=0) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, pad_before_transform=False) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0] - len(offset) + 1, OUTPUT_DIMENSION) + + with pytest.raises(ValueError, match="torch.Tensor"): + solver.transform(X.numpy()) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X, session_id=2) -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - single_session_tests) -def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc): + for param in solver.parameters(): + assert isinstance(param, torch.Tensor) + + +@pytest.mark.parametrize( + "data_name, loader_initfunc, model_architecture, solver_initfunc", + single_session_tests) +def test_single_session_auxvar(data_name, loader_initfunc, model_architecture, + solver_initfunc): return # TODO loader = _get_loader(data_name, loader_initfunc) @@ -124,12 +173,16 @@ def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc): solver.fit(loader) -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - single_session_hybrid_tests) -def test_single_session_hybrid(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, loader_initfunc) - model = cebra.models.init("offset10-model", loader.dataset.input_dimension, - 32, 3) +@pytest.mark.parametrize( + "data_name, loader_initfunc, model_architecture, solver_initfunc", + single_session_hybrid_tests) +def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, + solver_initfunc): + data = cebra.datasets.init(data_name) + loader = _get_loader(data, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) + offset = model.get_offset() criterion = 
cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, @@ -142,16 +195,50 @@ def test_single_session_hybrid(data_name, loader_initfunc, solver_initfunc): log = solver.step(batch) assert isinstance(log, dict) + X = loader.dataset.neural + with pytest.raises(ValueError, match="not.*fitted"): + solver.transform(X) + solver.fit(loader) + assert solver.num_sessions == None + assert solver.n_features == X.shape[1] -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - multi_session_tests) -def test_multi_session(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, loader_initfunc) + embedding = solver.transform(X) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(torch.Tensor(X)) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, session_id=0) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, pad_before_transform=False) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X.shape[0] - len(offset) + 1, OUTPUT_DIMENSION) + + with pytest.raises(ValueError, match="torch.Tensor"): + solver.transform(X.numpy()) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X, session_id=2) + + for param in solver.parameters(): + assert isinstance(param, torch.Tensor) + + +@pytest.mark.parametrize( + "data_name, loader_initfunc, model_architecture, solver_initfunc", + multi_session_tests) +def test_multi_session(data_name, loader_initfunc, model_architecture, + solver_initfunc): + data = cebra.datasets.init(data_name) + loader = _get_loader(data, loader_initfunc) + model = nn.ModuleList([ + _make_model(dataset, model_architecture) + for dataset in data.iter_sessions() + ]) + data.configure_for(model) criterion = cebra.models.InfoNCE() - model = nn.ModuleList( - [_make_model(dataset) for dataset in loader.dataset.iter_sessions()]) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, @@ -160,22 +247,178 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): batch = next(iter(loader)) for session_id, dataset in enumerate(loader.dataset.iter_sessions()): - assert batch[session_id].reference.shape == (32, - dataset.input_dimension, - 10) + assert batch[session_id].reference.shape[:2] == ( + 32, dataset.input_dimension) assert batch[session_id].index is not None log = solver.step(batch) assert isinstance(log, dict) + X = [ + loader.dataset.get_session(i).neural + for i in range(loader.dataset.num_sessions) + ] + with pytest.raises(ValueError, match="not.*fitted"): + solver.transform(X[0], session_id=0) + solver.fit(loader) + assert solver.num_sessions == 3 + assert solver.n_features == [X[i].shape[1] for i in range(len(X))] + + embedding = solver.transform(X[0], session_id=0) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X[1], session_id=1) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X[1].shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X[0], session_id=0, pad_before_transform=False) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X[0].shape[0] - + 
len(solver.model[0].get_offset()) + 1, + OUTPUT_DIMENSION) + + with pytest.raises(ValueError, match="torch.Tensor"): + embedding = solver.transform(X[0].numpy(), session_id=0) + + with pytest.raises(ValueError, match="shape"): + embedding = solver.transform(X[1], session_id=0) + with pytest.raises(ValueError, match="shape"): + embedding = solver.transform(X[0], session_id=1) + + with pytest.raises(RuntimeError, match="No.*session_id"): + embedding = solver.transform(X[0]) + with pytest.raises(RuntimeError, match="single.*session"): + embedding = solver.transform(X) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X[0], session_id=5) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X[0], session_id=-1) + + for param in solver.parameters(session_id=0): + assert isinstance(param, torch.Tensor) + + with pytest.raises(RuntimeError, match="No.*session_id"): + for param in solver.parameters(): + assert isinstance(param, torch.Tensor) + + +@pytest.mark.parametrize( + "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", + [ + # Test case 1: No padding + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( + 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch + (torch.tensor( + [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset( + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch + + # Test case 2: First batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 0, + 2, + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + + # Test case 3: Last batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15]]), + True, + cebra.data.Offset(1, 2), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + ), + + # Test case 4: Middle batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(1, 1), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15]]), + True, + cebra.data.Offset(0, 1), + 2, + 4, + torch.tensor([[7, 8, 9], [10, 11, 12]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(0, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + + # Examples that throw an error: + + # Padding without offset (should raise an error) + (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), + # Negative start_batch_idx or end_batch_idx (should raise an error) + (torch.tensor([[1, 2]]), False, cebra.data.Offset( + 0, 1), -1, 2, ValueError), + # out of bound indices because offset is too large + (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( + 5, 5), 1, 2, ValueError), + # Batch length is smaller than offset. 
+ (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset( + 0, 1), 0, 1, ValueError), # first batch + ], +) +def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx, + expected_output): + if expected_output == ValueError: + with pytest.raises(ValueError): + cebra.solver.base._get_batch(inputs, offset, start_batch_idx, + end_batch_idx, add_padding) + else: + result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx, + end_batch_idx, add_padding) + assert torch.equal(result, expected_output) + def create_model(model_name, input_dimension): return cebra.models.init(model_name, num_neurons=input_dimension, num_units=128, - num_output=5) + num_output=OUTPUT_DIMENSION) single_session_tests_select_model = [] @@ -183,9 +426,11 @@ def create_model(model_name, input_dimension): for model_name in ["offset1-model", "offset10-model"]: for session_id in [None, 0, 5]: for args in [ - ("demo-discrete", model_name, session_id), - ("demo-continuous", model_name, session_id), - ("demo-mixed", model_name, session_id), + ("demo-discrete", model_name, session_id, + cebra.data.DiscreteDataLoader), + ("demo-continuous", model_name, session_id, + cebra.data.ContinuousDataLoader), + ("demo-mixed", model_name, session_id, cebra.data.MixedDataLoader), ]: single_session_tests_select_model.append( (*args, cebra.solver.SingleSessionSolver)) @@ -195,169 +440,79 @@ def create_model(model_name, input_dimension): multi_session_tests_select_model = [] for model_name in ["offset10-model"]: for session_id in [None, 0, 1, 5, 2, 6, 4]: - for args in [("demo-continuous-multisession", model_name, session_id)]: + for args in [("demo-continuous-multisession", model_name, session_id, + cebra.data.ContinuousMultiSessionDataLoader)]: multi_session_tests_select_model.append( (*args, cebra.solver.MultiSessionSolver)) -# @pytest.mark.parametrize( -# "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", -# [ -# # Test case 1: No padding -# (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 1, -# torch.tensor([[1, 2]])), # first batch -# (torch.tensor([[1, 2], [3, 4]]), False, None, 0, 2, -# torch.tensor([[1, 2], [3, 4]])), # first batch -# (torch.tensor([[1, 2], [3, 4]]), False, None, 1, 2, -# torch.tensor([[3, 4]])), # last batch - -# # Test case 2: First batch with padding -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(1, 1), -# 0, -# 2, -# torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6]]), -# ), -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(1, 1), -# 0, -# 3, -# torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# ), - -# # Test case 3: Last batch with padding -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(0, 1), -# 1, -# 3, -# torch.tensor([[4, 5, 6], [7, 8, 9]]), -# ), -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(1, 3), -# 1, -# 3, -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [7, 8, 9], [7, 8, 9] -# ]), -# ), - -# # Test case 4: Middle batch with padding -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(0, 1), -# 1, -# 2, -# torch.tensor([[4, 5, 6]]), -# ), -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(0, 2), -# 1, -# 2, -# torch.tensor([[4, 5, 6], [7, 8, 9]]), -# ), -# ( -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(1, 1), -# 1, -# 2, -# torch.tensor([[1, 2, 3], [4, 5, 6]]), -# ), -# ( -# 
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# True, -# cebra.data.Offset(1, 2), -# 1, -# 2, -# torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), -# ), - -# # Examples that throw an error: - -# # Padding without offset (should raise an error) -# (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), -# # Negative start_batch_idx or end_batch_idx (should raise an error) -# (torch.tensor([[1, 2]]), False, None, -1, 2, ValueError), -# # out of bound indices because offset is too large -# (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( -# 5, 5), 1, 2, ValueError), -# ], -# ) -# def test__get_batch(inputs, add_padding, offset, start_batch_idx, -# end_batch_idx, expected_output): -# if expected_output == ValueError: -# with pytest.raises(ValueError): -# cebra.solver.base._get_batch(inputs, add_padding, offset, -# start_batch_idx, end_batch_idx) -# else: -# result = cebra.solver.base._get_batch(inputs, add_padding, offset, -# start_batch_idx, -# end_batch_idx) -# assert torch.equal(result, expected_output) - -# @pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", -# single_session_tests_select_model + -# single_session_hybrid_tests_select_model) -# def test_select_model_single_session(data_name, model_name, session_id, -# solver_initfunc): -# dataset = cebra.datasets.init(data_name) -# model = create_model(model_name, dataset.input_dimension) -# offset = model.get_offset() -# solver = solver_initfunc(model=model, criterion=None, optimizer=None) - -# if session_id is not None and session_id > 0: -# with pytest.raises(RuntimeError): -# solver._select_model(dataset.neural, session_id=session_id) -# else: -# model_, offset_ = solver._select_model(dataset.neural, -# session_id=session_id) -# assert offset.left == offset_.left and offset.right == offset_.right -# assert model == model_ - -# @pytest.mark.parametrize("data_name, model_name,session_id,solver_initfunc", -# multi_session_tests_select_model) -# def test_select_model_multi_session(data_name, model_name, session_id, -# solver_initfunc): -# dataset = cebra.datasets.init(data_name) -# model = nn.ModuleList([ -# create_model(model_name, dataset.input_dimension) -# for dataset in dataset.iter_sessions() -# ]) - -# offset = model[0].get_offset() -# solver = solver_initfunc(model=model, -# criterion=cebra.models.InfoNCE(), -# optimizer=torch.optim.Adam(model.parameters(), -# lr=1e-3)) - -# loader_kwargs = dict(num_steps=10, batch_size=32) -# loader = cebra.data.ContinuousMultiSessionDataLoader( -# dataset, **loader_kwargs) -# solver.fit(loader) - -# for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): -# inputs = dataset_.neural - -# if session_id is None or session_id >= dataset.num_sessions: -# with pytest.raises(RuntimeError): -# solver._select_model(inputs, session_id=session_id) -# elif i != session_id: -# with pytest.raises(ValueError): -# solver._select_model(inputs, session_id=session_id) -# else: -# model_, offset_ = solver._select_model(inputs, -# session_id=session_id) -# assert offset.left == offset_.left and offset.right == offset_.right -# assert model == model_ + +@pytest.mark.parametrize( + "data_name, model_name ,session_id, loader_initfunc, solver_initfunc", + single_session_tests_select_model + + single_session_hybrid_tests_select_model) +def test_select_model_single_session(data_name, model_name, session_id, + loader_initfunc, solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset.input_dimension) + 
dataset.configure_for(model) + loader = _get_loader(dataset, loader_initfunc=loader_initfunc) + offset = model.get_offset() + solver = solver_initfunc(model=model, criterion=None, optimizer=None) + + with pytest.raises(ValueError): + solver.n_features = 1000 + solver._select_model(inputs=dataset.neural, session_id=0) + + solver.n_features = dataset.neural.shape[1] + if session_id is not None and session_id > 0: + with pytest.raises(RuntimeError): + solver._select_model(inputs=dataset.neural, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs=dataset.neural, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + + +@pytest.mark.parametrize( + "data_name, model_name, session_id, loader_initfunc, solver_initfunc", + multi_session_tests_select_model) +def test_select_model_multi_session(data_name, model_name, session_id, + loader_initfunc, solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + dataset.configure_for(model) + loader = _get_loader(dataset, loader_initfunc=loader_initfunc) + + offset = model[0].get_offset() + solver = solver_initfunc(model=model, + criterion=cebra.models.InfoNCE(), + optimizer=torch.optim.Adam(model.parameters(), + lr=1e-3)) + + loader_kwargs = dict(num_steps=10, batch_size=32) + loader = cebra.data.ContinuousMultiSessionDataLoader( + dataset, **loader_kwargs) + solver.fit(loader) + + for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): + inputs = dataset_.neural + + if session_id is None or session_id >= dataset.num_sessions: + with pytest.raises(RuntimeError): + solver._select_model(inputs, session_id=session_id) + elif i != session_id: + with pytest.raises(ValueError): + solver._select_model(inputs, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + #this is a very crucial test. should be checked for different choices of offsets, # dataset sizes (also edge cases like dataset size 1001 and batch size 1000 -> is the padding properly handled?) @@ -367,9 +522,10 @@ def create_model(model_name, input_dimension): "offset1-model", "offset10-model", "offset40-model-4x-subsample", - #"offset1-model", "offset10-model", + "offset1-model", + "offset10-model", ] # there is an issue with "offset4-model-2x-subsample" because it's not a convolutional model. 
-batch_size_inference = [40_000, 99_990, 99_999] # 99_999 +batch_size_inference = [40_000, 99_990, 99_999] single_session_tests_transform = [] for padding in [True, False]: @@ -397,9 +553,9 @@ def create_model(model_name, input_dimension): @pytest.mark.parametrize( - "data_name,model_name,padding,batch_size_inference,loader_initfunc,solver_initfunc", + "data_name, model_name, padding, batch_size_inference, loader_initfunc, solver_initfunc", single_session_tests_transform + single_session_hybrid_tests_transform) -def test_batched_transform_singlesession( +def test_batched_transform_single_session( data_name, model_name, padding, @@ -458,9 +614,9 @@ def test_batched_transform_singlesession( @pytest.mark.parametrize( "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", multi_session_tests_transform) -def test_batched_transform_multisession(data_name, model_name, padding, - batch_size_inference, loader_initfunc, - solver_initfunc): +def test_batched_transform_multi_session(data_name, model_name, padding, + batch_size_inference, loader_initfunc, + solver_initfunc): dataset = cebra.datasets.init(data_name) model = nn.ModuleList([ create_model(model_name, dataset.input_dimension) From 9c46eb97d830402917bbb3b8a8365fb6a9d26c30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:44:35 +0200 Subject: [PATCH 32/45] Remove unused import in solver/utils --- cebra/solver/util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cebra/solver/util.py b/cebra/solver/util.py index af9529f7..584eb0da 100644 --- a/cebra/solver/util.py +++ b/cebra/solver/util.py @@ -25,8 +25,6 @@ from typing import Dict import literate_dataclasses as dataclasses -import numpy as np -import torch import tqdm From c845ec3ef611f7e2330079a6a2a3fd4e16155712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:52:53 +0200 Subject: [PATCH 33/45] Fix test plot --- cebra/integrations/sklearn/cebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 4240074f..39a64073 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1196,8 +1196,8 @@ def transform(self, >>> embedding = cebra_model.transform(dataset) """ - self.solver_._check_is_session_id_valid(session_id=session_id) sklearn_utils_validation.check_is_fitted(self, "n_features_") + self.solver_._check_is_session_id_valid(session_id=session_id) if torch.is_tensor(X) and X.device.type == "cuda": X = X.detach().cpu() From 9db3e3701ec89b93020918473f55b8f193216998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 12:00:20 +0200 Subject: [PATCH 34/45] Add some coverage --- cebra/solver/base.py | 13 ++++++++++++- cebra/solver/multi_session.py | 19 +++++++++++++++++++ cebra/solver/single_session.py | 16 ++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index ec33f23e..6fb786b4 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -360,6 +360,12 @@ def _get_loader(self, loader): @abc.abstractmethod def _set_fitted_params(self, loader: cebra.data.Loader): + """Set parameters once the solver is fitted. + + Args: + loader: Loader used to fit the solver. 
+ """ + raise NotImplementedError def fit( @@ -507,6 +513,11 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): @abc.abstractmethod def _check_is_session_id_valid(self, session_id: Optional[int] = None): + """Check that the session ID provided is valid for the solver instance. + + Args: + session_id: The session ID to check. + """ raise NotImplementedError @abc.abstractmethod @@ -530,7 +541,7 @@ def _select_model( @torch.no_grad() def transform(self, - inputs: torch.Tensor, + inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray], pad_before_transform: bool = True, session_id: Optional[int] = None, batch_size: Optional[int] = None) -> torch.Tensor: diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 666dafb8..f10f36a6 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -126,6 +126,17 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: ) def _set_fitted_params(self, loader: cebra.data.Loader): + """Set parameters once the solver is fitted. + + In multi session solver, the number of session is set to the number of + sessions in the dataset of the loader and the number of + features is set as a list corresponding to the number of neurons in + each dataset. + + Args: + loader: Loader used to fit the solver. + """ + self.num_sessions = loader.dataset.num_sessions self.n_features = [ loader.dataset.get_input_dimension(session_id) @@ -152,6 +163,14 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, ) def _check_is_session_id_valid(self, session_id: Optional[int]): + """Check that the session ID provided is valid for the solver instance. + + The session ID must be non-null and between 0 and the number session in the dataset. + + Args: + session_id: The session ID to check. + """ + if session_id is None: raise RuntimeError( "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index b941a8ba..eb75db0e 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -55,6 +55,14 @@ def parameters(self, session_id: Optional[int] = None): yield parameter def _set_fitted_params(self, loader: cebra.data.Loader): + """Set parameters once the solver is fitted. + + In single session solver, the number of session is set to None and the number of + features is set to the number of neurons in the dataset. + + Args: + loader: Loader used to fit the solver. + """ self.num_sessions = None self.n_features = loader.dataset.input_dimension @@ -77,6 +85,14 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): ) def _check_is_session_id_valid(self, session_id: Optional[int] = None): + """Check that the session ID provided is valid for the solver instance. + + The session ID must be null or equal to 0. + + Args: + session_id: The session ID to check. + """ + if session_id is not None and session_id > 0: raise RuntimeError( f"Invalid session_id {session_id}: single session models only takes an optional null session_id." 
From 8e5f9332768ed328b23623eba4cd20225f5bd83c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 13:27:34 +0200 Subject: [PATCH 35/45] Fix save/load --- cebra/integrations/sklearn/cebra.py | 5 +++ cebra/solver/base.py | 11 +++-- tests/test_solver.py | 62 +++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 39a64073..c3fd9c9e 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1417,6 +1417,11 @@ def load(cls, else: cebra_ = _check_type_checkpoint(checkpoint) + n_features = cebra_.n_features_ + cebra_.solver_.n_features = ([ + session_n_features for session_n_features in n_features + ] if isinstance(n_features, list) else n_features) + return cebra_ def to(self, device: Union[str, torch.device]): diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 6fb786b4..d60c4515 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -633,13 +633,12 @@ def load(self, logdir, filename="checkpoint.pth"): checkpoint = torch.load(savepath, map_location=self.device) self.load_state_dict(checkpoint, strict=True) - if hasattr(self.model, "n_features"): - n_features = self.model.n_features - self.n_features = ([ - session_n_features for session_n_features in n_features - ] if isinstance(n_features, list) else n_features) + n_features = self.n_features + self.n_features = ([ + session_n_features for session_n_features in n_features + ] if isinstance(n_features, list) else n_features) - def save(self, logdir, filename="checkpoint_last.pth"): + def save(self, logdir, filename="checkpoint.pth"): """Save the model and optimizer params. Args: diff --git a/tests/test_solver.py b/tests/test_solver.py index 4bb17232..8ebef4a0 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -19,7 +19,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import copy import itertools +import tempfile import numpy as np import pytest @@ -91,6 +93,48 @@ def _make_model(dataset, model_architecture="offset10-model"): # ) +def _assert_same_state_dict(first, second): + assert first.keys() == second.keys() + for key in first: + if isinstance(first[key], torch.Tensor): + assert torch.allclose(first[key], second[key]), key + elif isinstance(first[key], dict): + _assert_same_state_dict(first[key], second[key]), key + else: + assert first[key] == second[key] + + +def check_if_fit(model): + """Check if a model was already fit. + + Args: + model: The model to check. + + Returns: + True if the model was already fit. 
+ """ + return hasattr(model, "n_features_") + + +def _assert_equal(original_solver, loaded_solver): + for k in original_solver.model.state_dict(): + assert original_solver.model.state_dict()[k].all( + ) == loaded_solver.model.state_dict()[k].all() + assert check_if_fit(loaded_solver) == check_if_fit(original_solver) + + if check_if_fit(loaded_solver): + _assert_same_state_dict(original_solver.state_dict_, + loaded_solver.state_dict_) + X = np.random.normal(0, 1, (100, 1)) + + if loaded_solver.num_sessions is not None: + assert np.allclose(loaded_solver.transform(X, session_id=0), + original_solver.transform(X, session_id=0)) + else: + assert np.allclose(loaded_solver.transform(X), + original_solver.transform(X)) + + @pytest.mark.parametrize( "data_name, loader_initfunc, model_architecture, solver_initfunc", single_session_tests) @@ -144,6 +188,12 @@ def test_single_session(data_name, loader_initfunc, model_architecture, for param in solver.parameters(): assert isinstance(param, torch.Tensor) + fitted_solver = copy.deepcopy(solver) + with tempfile.TemporaryDirectory() as temp_dir: + solver.save(temp_dir) + solver.load(temp_dir) + _assert_equal(fitted_solver, solver) + @pytest.mark.parametrize( "data_name, loader_initfunc, model_architecture, solver_initfunc", @@ -225,6 +275,12 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, for param in solver.parameters(): assert isinstance(param, torch.Tensor) + fitted_solver = copy.deepcopy(solver) + with tempfile.TemporaryDirectory() as temp_dir: + solver.save(temp_dir) + solver.load(temp_dir) + _assert_equal(fitted_solver, solver) + @pytest.mark.parametrize( "data_name, loader_initfunc, model_architecture, solver_initfunc", @@ -302,6 +358,12 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, for param in solver.parameters(): assert isinstance(param, torch.Tensor) + fitted_solver = copy.deepcopy(solver) + with tempfile.TemporaryDirectory() as temp_dir: + solver.save(temp_dir) + solver.load(temp_dir) + _assert_equal(fitted_solver, solver) + @pytest.mark.parametrize( "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", From d08e400f2846b546dc43ef2ec68ea76bbce0d8dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:28:36 +0200 Subject: [PATCH 36/45] Remove duplicate configure_for in multi dataset --- cebra/data/multi_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index a8d56d10..1758deb3 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -106,11 +106,6 @@ def load_batch(self, index: BatchIndex) -> List[Batch]: ) for session_id, session in enumerate(self.iter_sessions()) ] - def configure_for(self, model): - self.offset = model.get_offset() - for session in self.iter_sessions(): - session.configure_for(model) - def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. 
From 0c693dd1b005a437faf5388eab061a256b82ae81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:24:44 +0200 Subject: [PATCH 37/45] Make save/load cleaner --- cebra/solver/base.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index d60c4515..f9ae3d82 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -296,7 +296,7 @@ def state_dict(self) -> dict: the model was trained with. """ - return { + state_dict = { "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "loss": torch.tensor(self.history), @@ -306,6 +306,13 @@ def state_dict(self) -> dict: "log": self.log, } + if hasattr(self, "n_features"): + state_dict["n_features"] = self.n_features + if hasattr(self, "num_sessions"): + state_dict["num_sessions"] = self.num_sessions + + return state_dict + def load_state_dict(self, state_dict: dict, strict: bool = True): """Update the solver state with the given state_dict. @@ -343,6 +350,12 @@ def _get(key): if _contains("log"): self.log = _get("log") + # Not defined if the model was saved before being fitted. + if "n_features" in state_dict: + self.n_features = _get("n_features") + if "num_sessions" in state_dict: + self.num_sessions = _get("num_sessions") + @property def num_parameters(self) -> int: """Total number of parameters in the encoder and criterion.""" @@ -633,11 +646,6 @@ def load(self, logdir, filename="checkpoint.pth"): checkpoint = torch.load(savepath, map_location=self.device) self.load_state_dict(checkpoint, strict=True) - n_features = self.n_features - self.n_features = ([ - session_n_features for session_n_features in n_features - ] if isinstance(n_features, list) else n_features) - def save(self, logdir, filename="checkpoint.pth"): """Save the model and optimizer params. From 794867bf58fc078de09623f33d944dce815aa704 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:58:33 +0200 Subject: [PATCH 38/45] Fix codespell errors --- cebra/solver/base.py | 4 ++-- cebra/solver/multi_session.py | 2 +- cebra/solver/single_session.py | 2 +- tests/test_solver.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index f9ae3d82..1d8bb9ce 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -86,7 +86,7 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, raise ValueError( f"The batch has length {batch_size_lenght} which " f"is smaller or equal than the required offset length {len(offset)}." - f"Either choose a model with smaller offset or the batch shoud contain more samples." + f"Either choose a model with smaller offset or the batch should contain more samples." ) @@ -511,7 +511,7 @@ def decoding(self, train_loader, valid_loader): @abc.abstractmethod def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): - """Check that the inputs can be infered using the selected model. + """Check that the inputs can be inferred using the selected model. Note: This method checks that the number of neurons in the input is similar to the input dimension to the selected model. 
diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 350266af..87d906d4 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -144,7 +144,7 @@ def _set_fitted_params(self, loader: cebra.data.Loader): def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: Optional[int]): - """Check that the inputs can be infered using the selected model. + """Check that the inputs can be inferred using the selected model. Note: This method checks that the number of neurons in the input is similar to the input dimension to the selected model. diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index eb75db0e..e0927a21 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -67,7 +67,7 @@ def _set_fitted_params(self, loader: cebra.data.Loader): self.n_features = loader.dataset.input_dimension def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): - """Check that the inputs can be infered using the selected model. + """Check that the inputs can be inferred using the selected model. Note: This method checks that the number of neurons in the input is similar to the input dimension to the selected model. diff --git a/tests/test_solver.py b/tests/test_solver.py index ffe01d4a..63caed67 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -683,7 +683,7 @@ def test_batched_transform_multi_session(data_name, model_name, padding, n_samples = dataset._datasets[0].neural.shape[0] assert all( d.neural.shape[0] == n_samples for d in dataset._datasets - ), "for this set all of the sessions need ot have same number of samples." + ), # all sessions need to have same number of samples smallest_batch_length = n_samples - batch_size offset_ = model[0].get_offset() From 0bb654940b81a30107a8b93acf6400c14c7bd125 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:10:25 +0200 Subject: [PATCH 39/45] Fix docs compilation errors --- cebra/data/multi_session.py | 6 +++--- docs/source/conf.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index f9c4ca47..0af2793c 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -30,7 +30,7 @@ import torch import cebra.data as cebra_data -import cebra.distributions as cebra_distr +import cebra.distributions from cebra.data.datatypes import Batch from cebra.data.datatypes import BatchIndex @@ -130,7 +130,7 @@ class MultiSessionLoader(cebra_data.Loader): def __post_init__(self): super().__post_init__() - self.sampler = cebra_distr.MultisessionSampler(self.dataset, + self.sampler = cebra.distributions.MultisessionSampler(self.dataset, self.time_offset) def get_indices(self, num_samples: int) -> List[BatchIndex]: @@ -169,7 +169,7 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader): # Overwrite sampler with the discrete implementation # Generalize MultisessionSampler to avoid doing this? 
def __post_init__(self): - self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset) + self.sampler = cebra.distributions.DiscreteMultisessionSampler(self.dataset) @property def index(self): diff --git a/docs/source/conf.py b/docs/source/conf.py index be839ddf..025a988b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -104,7 +104,7 @@ def get_years(start_year=2021): intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "torch": ("https://pytorch.org/docs/master/", None), + "torch": ("https://pytorch.org/docs/stable/", None), "sklearn": ("https://scikit-learn.org/stable", None), "numpy": ("https://numpy.org/doc/stable/", None), "matplotlib": ("https://matplotlib.org/stable/", None), From 04a102ffb733ba0a962fe0d4cb8ba89721fc4d5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:11:30 +0200 Subject: [PATCH 40/45] Fix formatting --- cebra/data/multi_session.py | 7 ++++--- tests/test_datasets.py | 5 ++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 0af2793c..be2e556b 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -130,8 +130,8 @@ class MultiSessionLoader(cebra_data.Loader): def __post_init__(self): super().__post_init__() - self.sampler = cebra.distributions.MultisessionSampler(self.dataset, - self.time_offset) + self.sampler = cebra.distributions.MultisessionSampler( + self.dataset, self.time_offset) def get_indices(self, num_samples: int) -> List[BatchIndex]: ref_idx = self.sampler.sample_prior(self.batch_size) @@ -169,7 +169,8 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader): # Overwrite sampler with the discrete implementation # Generalize MultisessionSampler to avoid doing this? def __post_init__(self): - self.sampler = cebra.distributions.DiscreteMultisessionSampler(self.dataset) + self.sampler = cebra.distributions.DiscreteMultisessionSampler( + self.dataset) @property def index(self): diff --git a/tests/test_datasets.py b/tests/test_datasets.py index adbfab64..98885d07 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -153,9 +153,8 @@ def test_allen(): @pytest.mark.requires_dataset -@pytest.mark.parametrize("options", - cebra.datasets.get_options("*", - expand_parametrized=False)) +@pytest.mark.parametrize( + "options", cebra.datasets.get_options("*", expand_parametrized=False)) def test_options(options): assert len(options) > 0 assert len(multisubject_options) > 0 From 7aab28251b38f5b5069b7839ce4790fce0211bbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:22:54 +0200 Subject: [PATCH 41/45] Fix extra docs errors --- cebra/data/multi_session.py | 2 +- cebra/data/single_session.py | 2 +- cebra/solver/base.py | 4 ++-- tests/test_solver.py | 4 +++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index be2e556b..9d10fbfc 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -110,7 +110,7 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. This sets the - :py:attr:`offset` attribute of the dataset. + :py:attr:`cebra_data.Dataset.offset` attribute of the dataset. Args: model: The model to configure the dataset for. 
diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 71cd0c3e..169ebcb6 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -76,7 +76,7 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. This sets the - :py:attr:`offset` attribute of the dataset. + :py:attr:`cebra_data.Dataset.offset` attribute of the dataset. Args: model: The model to configure the dataset for. diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 1d8bb9ce..0b5549cf 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -185,7 +185,7 @@ def _transform( model: cebra.models.Model, inputs: torch.Tensor, pad_before_transform: bool, - offset: cebra.data.Offset, + offset: cebra.data.datatypes.Offset, ) -> torch.Tensor: """Compute the embedding. @@ -206,7 +206,7 @@ def _transform( def _batched_transform(model: cebra.models.Model, inputs: torch.Tensor, batch_size: int, pad_before_transform: bool, - offset: cebra.data.Offset) -> torch.Tensor: + offset: cebra.data.datatypes.Offset) -> torch.Tensor: """Compute the embedding on batched inputs. Args: diff --git a/tests/test_solver.py b/tests/test_solver.py index 63caed67..d93c90e9 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -65,6 +65,7 @@ # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) + def _get_loader(data, loader_initfunc): kwargs = dict(num_steps=5, batch_size=32) loader = loader_initfunc(data, **kwargs) @@ -574,6 +575,7 @@ def test_select_model_multi_session(data_name, model_name, session_id, assert offset.left == offset_.left and offset.right == offset_.right assert model == model_ + models = [ "offset1-model", "offset10-model", @@ -683,7 +685,7 @@ def test_batched_transform_multi_session(data_name, model_name, padding, n_samples = dataset._datasets[0].neural.shape[0] assert all( d.neural.shape[0] == n_samples for d in dataset._datasets - ), # all sessions need to have same number of samples + ), "for this set all of the sessions need to have same number of samples." smallest_batch_length = n_samples - batch_size offset_ = model[0].get_offset() From ffa66eb79891aac77134ff787cacff0bddf26a3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:18:58 +0200 Subject: [PATCH 42/45] Fix offset in docs --- cebra/data/multi_session.py | 2 +- cebra/data/single_session.py | 2 +- cebra/solver/base.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index 9d10fbfc..f9686769 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -110,7 +110,7 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. This sets the - :py:attr:`cebra_data.Dataset.offset` attribute of the dataset. + :py:attr:`cebra.data.Dataset.offset` attribute of the dataset. Args: model: The model to configure the dataset for. diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 169ebcb6..9270c98b 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -76,7 +76,7 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. 
This sets the - :py:attr:`cebra_data.Dataset.offset` attribute of the dataset. + :py:attr:`cebra.data.Dataset.offset` attribute of the dataset. Args: model: The model to configure the dataset for. diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 0b5549cf..af617838 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -91,14 +91,15 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, def _add_batched_zero_padding(batched_data: torch.Tensor, - offset: cebra.data.Offset, batch_start_idx: int, + offset: cebra.data.Offset, + batch_start_idx: int, batch_end_idx: int, num_samples: int) -> torch.Tensor: """Add zero padding to the input data before inference. Args: batched_data: Data to apply the inference on. - offset (cebra.data.Offset): _description_ + offset: Offset of the model to consider when padding. batch_start_idx: Index of the first sample in the batch. batch_end_idx: Index of the first sample in the batch. num_samples (int): Total number of samples in the data. From 7f58607d969ffe5085b63abd69d5259744cc79db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:50:20 +0200 Subject: [PATCH 43/45] Remove attribute ref --- cebra/data/multi_session.py | 2 +- cebra/data/single_session.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index f9686769..cff61038 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -110,7 +110,7 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. This sets the - :py:attr:`cebra.data.Dataset.offset` attribute of the dataset. + `offset` attribute of the dataset. Args: model: The model to configure the dataset for. diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 9270c98b..a821db97 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -76,7 +76,7 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. This sets the - :py:attr:`cebra.data.Dataset.offset` attribute of the dataset. + `offset` attribute of the dataset. Args: model: The model to configure the dataset for. From c2544c759478ee962e0a37992a35155df08d2b43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 19 Sep 2024 13:55:19 +0200 Subject: [PATCH 44/45] Add review updates --- cebra/data/base.py | 1 - cebra/integrations/sklearn/cebra.py | 60 +++++++- cebra/solver/base.py | 35 +++-- cebra/solver/multi_session.py | 6 +- tests/test_sklearn.py | 220 +++++++++++++++++++++++++++- tests/test_solver.py | 6 +- 6 files changed, 300 insertions(+), 28 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index 874ed58b..54ae4579 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -207,7 +207,6 @@ def configure_for(self, model: "cebra.models.Model"): model: The model to configure the dataset for. 
""" raise NotImplementedError - self.offset = model.get_offset() @dataclasses.dataclass diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index ce50b7ea..bdae8ca7 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -1202,7 +1202,7 @@ def transform(self, sklearn_utils_validation.check_is_fitted(self, "n_features_") self.solver_._check_is_session_id_valid(session_id=session_id) - if torch.is_tensor(X) and X.device.type == "cuda": + if torch.is_tensor(X): X = X.detach().cpu() X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) @@ -1210,6 +1210,10 @@ def transform(self, if isinstance(X, np.ndarray): X = torch.from_numpy(X) + if batch_size is not None and batch_size < 1: + raise ValueError( + f"Batch size should be at least 1, got {batch_size}") + with torch.no_grad(): output = self.solver_.transform( inputs=X, @@ -1219,6 +1223,60 @@ def transform(self, return output.detach().cpu().numpy() + # Deprecated, kept for testing. + def transform_deprecated(self, + X: Union[npt.NDArray, torch.Tensor], + session_id: Optional[int] = None) -> npt.NDArray: + """Transform an input sequence and return the embedding. + + Args: + X: A numpy array or torch tensor of size ``time x dimension``. + session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for + multisession, set to ``None`` for single session. + + Returns: + A :py:func:`numpy.array` of size ``time x output_dimension``. + + Example: + + >>> import cebra + >>> import numpy as np + >>> dataset = np.random.uniform(0, 1, (1000, 30)) + >>> cebra_model = cebra.CEBRA(max_iterations=10) + >>> cebra_model.fit(dataset) + CEBRA(max_iterations=10) + >>> embedding = cebra_model.transform(dataset) + + """ + + sklearn_utils_validation.check_is_fitted(self, "n_features_") + model, offset = self._select_model(X, session_id) + + # Input validation + X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) + input_dtype = X.dtype + + with torch.no_grad(): + model.eval() + + if self.pad_before_transform: + X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), + mode="edge") + X = torch.from_numpy(X).float().to(self.device_) + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + X = X.transpose(1, 0).unsqueeze(0) + output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) + else: + # Standard evaluation, (T, C, dt) + output = model(X).cpu().numpy() + + if input_dtype == "float64": + return output.astype(input_dtype) + + return output + def fit_transform( self, X: Union[npt.NDArray, torch.Tensor], diff --git a/cebra/solver/base.py b/cebra/solver/base.py index af617838..7f0cbef1 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -81,18 +81,17 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, f"batch_end_idx ({batch_end_idx}) cannot exceed the length of inputs ({num_samples})." ) - batch_size_lenght = batch_end_idx - batch_start_idx - if batch_size_lenght <= len(offset): + batch_size_length = batch_end_idx - batch_start_idx + if batch_size_length <= len(offset): raise ValueError( - f"The batch has length {batch_size_lenght} which " + f"The batch has length {batch_size_length} which " f"is smaller or equal than the required offset length {len(offset)}." f"Either choose a model with smaller offset or the batch should contain more samples." 
) def _add_batched_zero_padding(batched_data: torch.Tensor, - offset: cebra.data.Offset, - batch_start_idx: int, + offset: cebra.data.Offset, batch_start_idx: int, batch_end_idx: int, num_samples: int) -> torch.Tensor: """Add zero padding to the input data before inference. @@ -409,6 +408,7 @@ def fit( TODO: * Refine the API here. Drop the validation entirely, and implement this via a hook? """ + self._set_fitted_params(loader) self.to(loader.device) iterator = self._get_loader(loader) @@ -436,8 +436,6 @@ def fit( save_hook(num_steps, self) self.save(logdir, f"checkpoint_{num_steps:#07d}.pth") - self._set_fitted_params(loader) - def step(self, batch: cebra.data.Batch) -> dict: """Perform a single gradient update. @@ -553,6 +551,10 @@ def _select_model( """ raise NotImplementedError + @property + def is_fitted(self): + return hasattr(self, "n_features") + @torch.no_grad() def transform(self, inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray], @@ -579,19 +581,24 @@ def transform(self, Returns: The output embedding. """ + if not self.is_fitted: + raise ValueError( + f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this estimator.") + + if batch_size is not None and batch_size < 1: + raise ValueError( + f"Batch size should be at least 1, got {batch_size}") + if isinstance(inputs, list): - raise NotImplementedError( - "Inputs to transform() should be the data for a single session." + raise ValueError( + "Inputs to transform() should be the data for a single session, but received a list." ) elif not isinstance(inputs, torch.Tensor): raise ValueError( f"Inputs should be a torch.Tensor, not {type(inputs)}.") - if not hasattr(self, "n_features"): - raise ValueError( - f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this estimator.") model, offset = self._select_model(inputs, session_id) if len(offset) < 2 and pad_before_transform: @@ -647,7 +654,7 @@ def load(self, logdir, filename="checkpoint.pth"): checkpoint = torch.load(savepath, map_location=self.device) self.load_state_dict(checkpoint, strict=True) - def save(self, logdir, filename="checkpoint.pth"): + def save(self, logdir, filename="checkpoint_last.pth"): """Save the model and optimizer params. 
Args: diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 87d906d4..b4be2125 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -44,9 +44,9 @@ class MultiSessionSolver(abc_.Solver): def parameters(self, session_id: Optional[int] = None): """Iterate over all parameters.""" - self._check_is_session_id_valid(session_id=session_id) - for parameter in self.model[session_id].parameters(): - yield parameter + if session_id is not None: + for parameter in self.model[session_id].parameters(): + yield parameter for parameter in self.criterion.parameters(): yield parameter diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index e409c0e3..0644aef7 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -231,7 +231,7 @@ def iterate_models(): ) in itertools.product( [ "offset10-model", "offset10-model-mse", "offset1-model", - "resample-model" + "offset40-model-4x-subsample" ], _DEVICES, ["euclidean", "cosine"], @@ -343,6 +343,20 @@ def test_sklearn(model_architecture, device): assert cebra_model.num_sessions is None embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) + + if model_architecture in [ + "offset36-model-cpu", "offset36-model-dropout-cpu", + "offset36-model-more-dropout-cpu", + "offset40-model-4x-subsample-cpu", + "offset20-model-4x-subsample-cpu", "offset36-model-cuda", + "offset36-model-dropout-cuda", "offset36-model-more-dropout-cuda", + "offset40-model-4x-subsample-cuda", + "offset20-model-4x-subsample-cuda" + ]: + with pytest.raises(ValueError, match="required.*offset.*length"): + embedding = cebra_model.transform(X, batch_size=10) # continuous behavior contrastive cebra_model.fit(X, y_c1, y_c2) @@ -354,9 +368,17 @@ def test_sklearn(model_architecture, device): assert isinstance(embedding, np.ndarray) embedding = cebra_model.transform(X, session_id=0) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, session_id=0, batch_size=50) + assert isinstance(embedding, np.ndarray) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=2) + with pytest.raises(ValueError, match="Batch.*size"): + embedding = cebra_model.transform(X, batch_size=0) + with pytest.raises(ValueError, match="Batch.*size"): + embedding = cebra_model.transform(X, batch_size=-10) with pytest.raises(ValueError, match="Invalid.*labels"): cebra_model.fit(X, [y_c1, y_c1_s2]) with pytest.raises(ValueError, match="Invalid.*samples"): @@ -369,11 +391,15 @@ def test_sklearn(model_architecture, device): cebra_model.fit(X, y_d) embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) # mixed cebra_model.fit(X, y_c1, y_c2, y_d) embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) # multi-session discrete behavior contrastive cebra_model.fit([X, X_s2], [y_d, y_d_s2]) @@ -387,6 +413,9 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X_s2, session_id=1) assert isinstance(embedding, np.ndarray) assert embedding.shape == (X_s2.shape[0], output_dimension) + embedding = cebra_model.transform(X_s2, 
+    assert isinstance(embedding, np.ndarray)
+    assert embedding.shape == (X_s2.shape[0], output_dimension)
 
     with pytest.raises(ValueError, match="shape"):
         embedding = cebra_model.transform(X_s2, session_id=0)
@@ -411,6 +440,9 @@ def test_sklearn(model_architecture, device):
     embedding = cebra_model.transform(X_s2, session_id=1)
     assert isinstance(embedding, np.ndarray)
     assert embedding.shape == (X_s2.shape[0], output_dimension)
+    embedding = cebra_model.transform(X_s2, session_id=1, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
+    assert embedding.shape == (X_s2.shape[0], output_dimension)
 
     with pytest.raises(ValueError, match="shape"):
         embedding = cebra_model.transform(X_s2, session_id=0)
@@ -442,6 +474,9 @@ def test_sklearn(model_architecture, device):
     embedding = cebra_model.transform(X, session_id=2)
     assert isinstance(embedding, np.ndarray)
     assert embedding.shape == (X.shape[0], output_dimension)
+    embedding = cebra_model.transform(X, session_id=2, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
+    assert embedding.shape == (X.shape[0], output_dimension)
 
     with pytest.raises(ValueError, match="shape"):
         embedding = cebra_model.transform(X_s2, session_id=0)
@@ -467,6 +502,9 @@ def test_sklearn(model_architecture, device):
     embedding = cebra_model.transform(X, session_id=2)
     assert isinstance(embedding, np.ndarray)
     assert embedding.shape == (X.shape[0], output_dimension)
+    embedding = cebra_model.transform(X, session_id=2, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
+    assert embedding.shape == (X.shape[0], output_dimension)
 
     with pytest.raises(ValueError, match="shape"):
         embedding = cebra_model.transform(X_s2, session_id=0)
@@ -711,6 +749,8 @@ def check_first_layer_dim(model, X):
     check_first_layer_dim(cebra_model, X_s2)
     embedding = cebra_model.transform(X_s2)
     assert isinstance(embedding, np.ndarray)
+    embedding = cebra_model.transform(X_s2, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
 
     cebra_model.fit(X, y_c1, y_c2, adapt=True)
     check_first_layer_dim(cebra_model, X)
@@ -718,6 +758,8 @@ def check_first_layer_dim(model, X):
     assert isinstance(embedding, np.ndarray)
     embedding = cebra_model.transform(X, session_id=0)
     assert isinstance(embedding, np.ndarray)
+    embedding = cebra_model.transform(X, session_id=0, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
 
     with pytest.raises(RuntimeError, match="Invalid.*session_id"):
         embedding = cebra_model.transform(X, session_id=2)
@@ -730,11 +772,15 @@ def check_first_layer_dim(model, X):
     check_first_layer_dim(cebra_model, X_s2)
     embedding = cebra_model.transform(X_s2)
     assert isinstance(embedding, np.ndarray)
+    embedding = cebra_model.transform(X_s2, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
 
     cebra_model.fit(X, y_c1, y_c2, y_d, adapt=True)
     check_first_layer_dim(cebra_model, X)
     embedding = cebra_model.transform(X)
     assert isinstance(embedding, np.ndarray)
+    embedding = cebra_model.transform(X, batch_size=50)
+    assert isinstance(embedding, np.ndarray)
 
     with pytest.raises(NotImplementedError, match=".*multisession.*"):
         cebra_model.fit([X, X_s2], [y_c1, y_c1_s2], adapt=True)
@@ -848,8 +894,8 @@ def test_sklearn_full(model_architecture, device, pad_before_transform):
 
 
 @pytest.mark.parametrize("model_architecture,device",
-                         [("resample-model", "cpu"),
-                          ("resample5-model", "cpu")])
+                         [("offset40-model-4x-subsample", "cpu"),
+                          ("offset20-model-4x-subsample", "cpu")])
 def test_sklearn_resampling_model(model_architecture, device):
     cebra_model = cebra_sklearn_cebra.CEBRA(
         model_architecture=model_architecture,
@@ -869,10 +915,12 @@ def test_sklearn_resampling_model(model_architecture, device):
     cebra_model.fit(X, y_c1)
     output = cebra_model.transform(X)
     assert output.shape == (250, 4)
+    output = cebra_model.transform(X, batch_size=100)
+    assert output.shape == (250, 4)
 
 
 @pytest.mark.parametrize("model_architecture,device",
-                         [("resample1-model", "cpu")])
+                         [("offset4-model-2x-subsample", "cpu")])
 def test_sklearn_resampling_model_not_yet_supported(model_architecture, device):
     cebra_model = cebra_sklearn_cebra.CEBRA(
         model_architecture=model_architecture, max_iterations=5)
@@ -1294,3 +1342,167 @@ def test_check_device():
     torch.backends.mps.is_built = lambda: False
     with pytest.raises(ValueError):
         cebra_sklearn_utils.check_device(device)
+
+
+@_util.parametrize_slow(
+    arg_names="model_architecture,device",
+    fast_arguments=list(
+        itertools.islice(
+            itertools.product(
+                cebra_sklearn_cebra.CEBRA.supported_model_architectures(),
+                _DEVICES),
+            2,
+        )),
+    slow_arguments=list(
+        itertools.product(
+            cebra_sklearn_cebra.CEBRA.supported_model_architectures(),
+            _DEVICES)),
+)
+def test_new_transform(model_architecture, device):
+    """
+    This is a test that the original sklearn transform returns the same output as
+    the new sklearn transform that uses the pytorch solver transform.
+    """
+    output_dimension = 4
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture=model_architecture,
+        time_offsets=10,
+        learning_rate=3e-4,
+        max_iterations=5,
+        device=device,
+        output_dimension=output_dimension,
+        batch_size=42,
+        verbose=True,
+    )
+
+    # example dataset
+    X = np.random.uniform(0, 1, (1000, 50))
+    X_s2 = np.random.uniform(0, 1, (800, 30))
+    X_s3 = np.random.uniform(0, 1, (1000, 30))
+    y_c1 = np.random.uniform(0, 1, (1000, 5))
+    y_c1_s2 = np.random.uniform(0, 1, (800, 5))
+    y_c2 = np.random.uniform(0, 1, (1000, 2))
+    y_c2_s2 = np.random.uniform(0, 1, (800, 2))
+    y_d = np.random.randint(0, 10, (1000,))
+    y_d_s2 = np.random.randint(0, 10, (800,))
+
+    # time contrastive
+    cebra_model.fit(X)
+    embedding1 = cebra_model.transform(X)
+    embedding2 = cebra_model.transform_deprecated(X)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # continuous behavior contrastive
+    cebra_model.fit(X, y_c1, y_c2)
+    assert cebra_model.num_sessions is None
+
+    embedding1 = cebra_model.transform(X)
+    embedding2 = cebra_model.transform_deprecated(X)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(torch.Tensor(X))
+    embedding2 = cebra_model.transform_deprecated(torch.Tensor(X))
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
+    embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # tensor input
+    cebra_model.fit(torch.Tensor(X), torch.Tensor(y_c1), torch.Tensor(y_c2))
+
+    # discrete behavior contrastive
+    cebra_model.fit(X, y_d)
+    embedding1 = cebra_model.transform(X)
+    embedding2 = cebra_model.transform_deprecated(X)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # mixed
+    cebra_model.fit(X, y_c1, y_c2, y_d)
+    embedding1 = cebra_model.transform(X)
+    embedding2 = cebra_model.transform_deprecated(X)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # multi-session discrete behavior contrastive
+    cebra_model.fit([X, X_s2], [y_d, y_d_s2])
+
+    embedding1 = cebra_model.transform(X, session_id=0)
+    embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
+    embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(X_s2, session_id=1)
+    embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # multi-session continuous behavior contrastive
+    cebra_model.fit([X, X_s2], [y_c1, y_c1_s2])
+
+    embedding1 = cebra_model.transform(X, session_id=0)
+    embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
+    embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(X_s2, session_id=1)
+    embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # multi-session tensor inputs
+    cebra_model.fit(
+        [torch.Tensor(X), torch.Tensor(X_s2)],
+        [torch.Tensor(y_c1), torch.Tensor(y_c1_s2)],
+    )
+
+    # multi-session discrete behavior contrastive, more than two sessions
+    cebra_model.fit([X, X_s2, X], [y_d, y_d_s2, y_d])
+
+    embedding1 = cebra_model.transform(X, session_id=0)
+    embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(X_s2, session_id=1)
+    embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(X, session_id=2)
+    embedding2 = cebra_model.transform_deprecated(X, session_id=2)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    # multi-session continuous behavior contrastive, more than two sessions
+    cebra_model.fit([X, X_s2, X], [y_c1, y_c1_s2, y_c1])
+
+    embedding1 = cebra_model.transform(X, session_id=0)
+    embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(X_s2, session_id=1)
+    embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
+
+    embedding1 = cebra_model.transform(X, session_id=2)
+    embedding2 = cebra_model.transform_deprecated(X, session_id=2)
+    assert np.allclose(embedding1, embedding2, rtol=1e-5,
+                       atol=1e-8), "Arrays are not close enough"
diff --git a/tests/test_solver.py b/tests/test_solver.py
index d93c90e9..c27a9e41 100644
--- a/tests/test_solver.py
+++ b/tests/test_solver.py
@@ -344,7 +344,7 @@ def
test_multi_session(data_name, loader_initfunc, model_architecture, with pytest.raises(RuntimeError, match="No.*session_id"): embedding = solver.transform(X[0]) - with pytest.raises(RuntimeError, match="single.*session"): + with pytest.raises(ValueError, match="single.*session"): embedding = solver.transform(X) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = solver.transform(X[0], session_id=5) @@ -354,10 +354,6 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, for param in solver.parameters(session_id=0): assert isinstance(param, torch.Tensor) - with pytest.raises(RuntimeError, match="No.*session_id"): - for param in solver.parameters(): - assert isinstance(param, torch.Tensor) - fitted_solver = copy.deepcopy(solver) with tempfile.TemporaryDirectory() as temp_dir: solver.save(temp_dir) From e1b7cc76bdeb87fcdcac2978cfc8fba8058d78cd Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 27 Oct 2024 19:08:10 +0100 Subject: [PATCH 45/45] apply ruff auto-fixes --- cebra/__init__.py | 6 ++-- cebra/__main__.py | 4 --- cebra/config.py | 1 - cebra/data/base.py | 3 -- cebra/data/datasets.py | 7 ---- cebra/data/datatypes.py | 3 -- cebra/data/helper.py | 10 +++--- cebra/data/multi_session.py | 2 -- cebra/data/single_session.py | 7 ++-- cebra/datasets/allen/ca_movie.py | 4 --- cebra/datasets/allen/ca_movie_decoding.py | 5 --- cebra/datasets/allen/combined.py | 20 ++--------- cebra/datasets/allen/make_neuropixel.py | 2 -- cebra/datasets/allen/neuropixel_movie.py | 14 +------- .../allen/neuropixel_movie_decoding.py | 8 ----- cebra/datasets/allen/single_session_ca.py | 8 ----- cebra/datasets/gaussian_mixture.py | 4 --- cebra/datasets/generate_synthetic_data.py | 1 - cebra/datasets/hippocampus.py | 2 -- cebra/datasets/make_neuropixel.py | 1 - cebra/datasets/monkey_reaching.py | 5 +-- cebra/distributions/base.py | 3 +- cebra/distributions/continuous.py | 5 ++- cebra/distributions/index.py | 7 ++-- cebra/distributions/mixed.py | 1 - cebra/integrations/deeplabcut.py | 2 +- cebra/integrations/sklearn/cebra.py | 31 ++++++++--------- cebra/integrations/sklearn/helpers.py | 2 +- cebra/integrations/sklearn/metrics.py | 12 +++---- cebra/models/criterions.py | 2 +- cebra/models/model.py | 2 -- cebra/models/projector.py | 2 +- cebra/solver/base.py | 34 ++++++++----------- cebra/solver/multi_session.py | 25 ++++++-------- cebra/solver/single_session.py | 26 +++++++------- cebra/solver/supervised.py | 8 ----- tests/test_datasets.py | 8 ++--- tests/test_sklearn.py | 4 +-- tests/test_solver.py | 5 ++- 39 files changed, 91 insertions(+), 205 deletions(-) diff --git a/cebra/__init__.py b/cebra/__init__.py index fd4cf58c..b361a441 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -33,7 +33,7 @@ from cebra.integrations.sklearn.decoder import L1LinearRegressor is_sklearn_available = True -except ImportError as e: +except ImportError: # silently fail for now pass @@ -42,7 +42,7 @@ from cebra.integrations.matplotlib import * is_matplotlib_available = True -except ImportError as e: +except ImportError: # silently fail for now pass @@ -51,7 +51,7 @@ from cebra.integrations.plotly import * is_plotly_available = True -except ImportError as e: +except ImportError: # silently fail for now pass diff --git a/cebra/__main__.py b/cebra/__main__.py index 6c7c18bf..4ba66993 100644 --- a/cebra/__main__.py +++ b/cebra/__main__.py @@ -27,11 +27,7 @@ import argparse import sys -import numpy as np -import torch - import cebra -import cebra.distributions as cebra_distr def 
train(parser, kwargs): diff --git a/cebra/config.py b/cebra/config.py index ba6e3922..a960721f 100644 --- a/cebra/config.py +++ b/cebra/config.py @@ -21,7 +21,6 @@ # import argparse import json -from dataclasses import MISSING from typing import Literal, Optional import literate_dataclasses as dataclasses diff --git a/cebra/data/base.py b/cebra/data/base.py index 54ae4579..e35e20c5 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -22,11 +22,8 @@ """Base classes for datasets and loaders.""" import abc -import collections -from typing import List import literate_dataclasses as dataclasses -import numpy as np import torch import cebra.data.assets as cebra_data_assets diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 0b7f191d..9fa815c2 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -21,21 +21,14 @@ # """Pre-defined datasets.""" -import abc -import collections import types from typing import List, Tuple, Union -import literate_dataclasses as dataclasses import numpy as np import numpy.typing as npt import torch -from numpy.typing import NDArray import cebra.data as cebra_data -import cebra.distributions -from cebra.data.datatypes import Batch -from cebra.data.datatypes import BatchIndex class TensorDataset(cebra_data.SingleSessionDataset): diff --git a/cebra/data/datatypes.py b/cebra/data/datatypes.py index 11583909..4b2ac8a2 100644 --- a/cebra/data/datatypes.py +++ b/cebra/data/datatypes.py @@ -20,9 +20,6 @@ # limitations under the License. # import collections -from typing import Tuple - -import torch __all__ = ["Batch", "BatchIndex", "Offset"] diff --git a/cebra/data/helper.py b/cebra/data/helper.py index c324a80f..d2a1cfe3 100644 --- a/cebra/data/helper.py +++ b/cebra/data/helper.py @@ -181,14 +181,14 @@ def fit( elif ref_data.shape[0] == data.shape[0] and (ref_label is None or label is None): raise ValueError( - f"Missing labels: the data to align are the same shape but you provided only " - f"one of the sets of labels. Either provide both the reference and alignment " - f"labels or none.") + "Missing labels: the data to align are the same shape but you provided only " + "one of the sets of labels. 
Either provide both the reference and alignment " + "labels or none.") else: if ref_label is None or label is None: raise ValueError( - f"Missing labels: the data to align are not the same shape, " - f"provide labels to align the data and reference data.") + "Missing labels: the data to align are not the same shape, " + "provide labels to align the data and reference data.") if len(ref_label.shape) == 1: ref_label = np.expand_dims(ref_label, axis=1) diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index cff61038..ebae8b6f 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -22,11 +22,9 @@ """Datasets and loaders for multi-session training.""" import abc -import collections from typing import List import literate_dataclasses as dataclasses -import numpy as np import torch import cebra.data as cebra_data diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index a821db97..0c575ed7 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -26,12 +26,9 @@ """ import abc -import collections import warnings -from typing import List import literate_dataclasses as dataclasses -import numpy as np import torch import cebra.data as cebra_data @@ -365,8 +362,8 @@ def __post_init__(self): if self.conditional != "time_delta": raise NotImplementedError( - f"Hybrid training is currently only implemented using the ``time_delta`` " - f"continual distribution.") + "Hybrid training is currently only implemented using the ``time_delta`` " + "continual distribution.") self.time_distribution = cebra.distributions.TimeContrastive( time_offset=self.time_offset, diff --git a/cebra/datasets/allen/ca_movie.py b/cebra/datasets/allen/ca_movie.py index f11e5e93..fa25f72a 100644 --- a/cebra/datasets/allen/ca_movie.py +++ b/cebra/datasets/allen/ca_movie.py @@ -29,11 +29,8 @@ """ -import glob -import hashlib import pathlib -import h5py import joblib import numpy as np import pandas as pd @@ -46,7 +43,6 @@ import cebra.data from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS diff --git a/cebra/datasets/allen/ca_movie_decoding.py b/cebra/datasets/allen/ca_movie_decoding.py index 12d6cc64..8bb164cc 100644 --- a/cebra/datasets/allen/ca_movie_decoding.py +++ b/cebra/datasets/allen/ca_movie_decoding.py @@ -29,11 +29,8 @@ """ -import glob -import hashlib import pathlib -import h5py import joblib import numpy as np import pandas as pd @@ -41,12 +38,10 @@ import torch from numpy.random import Generator from numpy.random import PCG64 -from sklearn.decomposition import PCA import cebra.data from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS from cebra.datasets.allen import SEEDS_DISJOINT diff --git a/cebra/datasets/allen/combined.py b/cebra/datasets/allen/combined.py index bfaca9b3..a05eb17c 100644 --- a/cebra/datasets/allen/combined.py +++ b/cebra/datasets/allen/combined.py @@ -31,22 +31,8 @@ """ -import glob -import hashlib - -import h5py -import joblib -import numpy as np -import pandas as pd -import scipy.io -import torch -from numpy.random import Generator -from numpy.random import PCG64 -from sklearn.decomposition import PCA - import cebra.data from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen 
import ca_movie from cebra.datasets.allen import ca_movie_decoding from cebra.datasets.allen import neuropixel_movie @@ -80,7 +66,7 @@ def __init__(self, num_neurons=1000, seed=111, area="VISp"): ) def __repr__(self): - return f"CaNeuropixelDataset" + return "CaNeuropixelDataset" @parametrize( @@ -117,7 +103,7 @@ def __init__(self, ) def __repr__(self): - return f"CaNeuropixelMovieOneCorticesDataset" + return "CaNeuropixelMovieOneCorticesDataset" @parametrize( @@ -152,4 +138,4 @@ def __init__(self, group, num_neurons, seed, cortex, split_flag="train"): ) def __repr__(self): - return f"CaNeuropixelMovieOneCorticesDisjointDataset" + return "CaNeuropixelMovieOneCorticesDisjointDataset" diff --git a/cebra/datasets/allen/make_neuropixel.py b/cebra/datasets/allen/make_neuropixel.py index 5c0568b7..1eabfe9f 100644 --- a/cebra/datasets/allen/make_neuropixel.py +++ b/cebra/datasets/allen/make_neuropixel.py @@ -31,14 +31,12 @@ """ import argparse -import glob import pathlib import h5py import joblib as jl import numpy as np import numpy.typing as npt -import pandas as pd from cebra.datasets import get_datapath diff --git a/cebra/datasets/allen/neuropixel_movie.py b/cebra/datasets/allen/neuropixel_movie.py index 51011407..f9b9c3ea 100644 --- a/cebra/datasets/allen/neuropixel_movie.py +++ b/cebra/datasets/allen/neuropixel_movie.py @@ -26,24 +26,12 @@ *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. """ -import glob -import hashlib import pathlib -import h5py import joblib -import numpy as np -import pandas as pd -import scipy.io -import torch -from numpy.random import Generator -from numpy.random import PCG64 -from sklearn.decomposition import PCA - -import cebra.data + from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import ca_movie from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS diff --git a/cebra/datasets/allen/neuropixel_movie_decoding.py b/cebra/datasets/allen/neuropixel_movie_decoding.py index a99f367d..4ff1ebc2 100644 --- a/cebra/datasets/allen/neuropixel_movie_decoding.py +++ b/cebra/datasets/allen/neuropixel_movie_decoding.py @@ -26,25 +26,17 @@ *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92. 
""" -import glob -import hashlib import pathlib -import h5py import joblib import numpy as np -import pandas as pd -import scipy.io import torch from numpy.random import Generator from numpy.random import PCG64 -from sklearn.decomposition import PCA import cebra.data -from cebra.datasets import allen from cebra.datasets import get_datapath from cebra.datasets import parametrize -from cebra.datasets import register from cebra.datasets.allen import ca_movie_decoding from cebra.datasets.allen import NUM_NEURONS from cebra.datasets.allen import SEEDS diff --git a/cebra/datasets/allen/single_session_ca.py b/cebra/datasets/allen/single_session_ca.py index f207a1bc..5a3eea4d 100644 --- a/cebra/datasets/allen/single_session_ca.py +++ b/cebra/datasets/allen/single_session_ca.py @@ -28,25 +28,17 @@ *http://observatory.brain-map.org/visualcoding """ -import glob -import hashlib import pathlib -import h5py -import joblib import numpy as np -import pandas as pd import scipy.io import torch -from numpy.random import Generator -from numpy.random import PCG64 from sklearn.decomposition import PCA import cebra.data from cebra.datasets import get_datapath from cebra.datasets import init from cebra.datasets import parametrize -from cebra.datasets import register _DEFAULT_DATADIR = get_datapath() diff --git a/cebra/datasets/gaussian_mixture.py b/cebra/datasets/gaussian_mixture.py index f5508838..05fd971d 100644 --- a/cebra/datasets/gaussian_mixture.py +++ b/cebra/datasets/gaussian_mixture.py @@ -20,17 +20,13 @@ # limitations under the License. # import pathlib -from typing import Tuple import joblib as jl -import literate_dataclasses as dataclasses import numpy as np -import sklearn import torch import cebra.data import cebra.io -from cebra.datasets import get_datapath from cebra.datasets import parametrize from cebra.datasets import register diff --git a/cebra/datasets/generate_synthetic_data.py b/cebra/datasets/generate_synthetic_data.py index 8a243d6d..0fc33963 100644 --- a/cebra/datasets/generate_synthetic_data.py +++ b/cebra/datasets/generate_synthetic_data.py @@ -26,7 +26,6 @@ """ import argparse import pathlib -import sys import joblib as jl import keras diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index a32209a3..92537b8e 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -31,12 +31,10 @@ """ -import hashlib import pathlib import joblib import numpy as np -import scipy.io import sklearn.model_selection import sklearn.neighbors import torch diff --git a/cebra/datasets/make_neuropixel.py b/cebra/datasets/make_neuropixel.py index 7c097f38..65029f94 100644 --- a/cebra/datasets/make_neuropixel.py +++ b/cebra/datasets/make_neuropixel.py @@ -36,7 +36,6 @@ import joblib as jl import numpy as np import numpy.typing as npt -import pandas as pd def _filter_units( diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 23fc5a6c..a07e24fd 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -28,14 +28,11 @@ """ -import hashlib import pathlib -import pickle as pk from typing import Union import joblib as jl import numpy as np -import scipy.io import torch import cebra.data @@ -72,7 +69,7 @@ def _load_data( try: from nlb_tools.nwb_interface import NWBDataset - except ImportError as e: + except ImportError: raise ImportError( "Could not import the nlb_tools package required for data loading " "the raw reaching datasets in NWB format. 
" diff --git a/cebra/distributions/base.py b/cebra/distributions/base.py index 990d7e79..07ad9ae4 100644 --- a/cebra/distributions/base.py +++ b/cebra/distributions/base.py @@ -31,7 +31,6 @@ """ import abc -import functools import torch @@ -82,7 +81,7 @@ def to(self, device: str): self._generator = torch.Generator(device=device) try: self._generator.set_state(state.to(device)) - except (TypeError, RuntimeError) as e: + except (TypeError, RuntimeError): # TODO(https://discuss.pytorch.org/t/cuda-rng-state-does-not-change-when-re-seeding-why-is-that/47917/3) self._generator.manual_seed(self.seed) diff --git a/cebra/distributions/continuous.py b/cebra/distributions/continuous.py index c4235d48..ad95fdf6 100644 --- a/cebra/distributions/continuous.py +++ b/cebra/distributions/continuous.py @@ -23,7 +23,6 @@ from typing import Literal, Optional -import numpy as np import torch import cebra.data @@ -112,8 +111,8 @@ def __init__( abc_.HasGenerator.__init__(self, device=device, seed=seed) if continuous is None and num_samples is None: raise ValueError( - f"Supply either a continuous index (which will be used to infer the dataset size) " - f"or alternatively the number of datapoints using the num_samples argument." + "Supply either a continuous index (which will be used to infer the dataset size) " + "or alternatively the number of datapoints using the num_samples argument." ) if continuous is not None and num_samples is not None: if len(continuous) != num_samples: diff --git a/cebra/distributions/index.py b/cebra/distributions/index.py index 0ee0959a..724e86e4 100644 --- a/cebra/distributions/index.py +++ b/cebra/distributions/index.py @@ -30,7 +30,6 @@ discrete labels should be converted accordingly. """ -import numpy as np import torch import cebra.data @@ -188,9 +187,9 @@ def __init__(self, discrete, continuous): "of samples.") if len(discrete.shape) > 1: raise ValueError( - f"Discrete indexing information needs to be limited to a 1d " - f"array/tensor. Multi-dimensional discrete indices should be " - f"reformatted first.") + "Discrete indexing information needs to be limited to a 1d " + "array/tensor. Multi-dimensional discrete indices should be " + "reformatted first.") # TODO(stes): Once a helper function exists, the error message should # mention it. diff --git a/cebra/distributions/mixed.py b/cebra/distributions/mixed.py index 14fb8a61..7221fd99 100644 --- a/cebra/distributions/mixed.py +++ b/cebra/distributions/mixed.py @@ -27,7 +27,6 @@ """ from typing import Literal -import numpy as np import torch import cebra.io diff --git a/cebra/integrations/deeplabcut.py b/cebra/integrations/deeplabcut.py index c265b09a..4c5b292d 100644 --- a/cebra/integrations/deeplabcut.py +++ b/cebra/integrations/deeplabcut.py @@ -160,7 +160,7 @@ def load_data(self, pcutoff: float = 0.6) -> npt.NDArray: ) elif self.dlc_df.columns.nlevels == 4: raise NotImplementedError( - f"Multi-animals DLC files are not handled. Please provide a single-animal file." + "Multi-animals DLC files are not handled. Please provide a single-animal file." 
) dlc_df_coords = ( diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index bdae8ca7..97beaaaa 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -21,9 +21,7 @@ # """Define the CEBRA model.""" -import copy import itertools -import warnings from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union) @@ -33,7 +31,6 @@ import sklearn.utils.validation as sklearn_utils_validation import torch from sklearn.base import BaseEstimator -from sklearn.base import ClassifierMixin from sklearn.base import TransformerMixin from torch import nn @@ -274,8 +271,8 @@ def _require_arg(key): "Until then, please train using the PyTorch API.")) else: raise RuntimeError( - f"Index combination not covered. Please report this issue and add the following " - f"information to your bug report: \n" + error_message) + "Index combination not covered. Please report this issue and add the following " + "information to your bug report: \n" + error_message) def _check_type_checkpoint(checkpoint): @@ -776,18 +773,18 @@ def _configure_for_all( cebra.models.ConvolutionalModelMixin): if len(model[n].get_offset()) > 1: raise ValueError( - f"It is not yet supported to run non-convolutional models with " - f"receptive fields/offsets larger than 1 via the sklearn API. " - f"Please use a different model, or revert to the pytorch " - f"API for training.") + "It is not yet supported to run non-convolutional models with " + "receptive fields/offsets larger than 1 via the sklearn API. " + "Please use a different model, or revert to the pytorch " + "API for training.") else: if not isinstance(model, cebra.models.ConvolutionalModelMixin): if len(model.get_offset()) > 1: raise ValueError( - f"It is not yet supported to run non-convolutional models with " - f"receptive fields/offsets larger than 1 via the sklearn API. " - f"Please use a different model, or revert to the pytorch " - f"API for training.") + "It is not yet supported to run non-convolutional models with " + "receptive fields/offsets larger than 1 via the sklearn API. " + "Please use a different model, or revert to the pytorch " + "API for training.") dataset.configure_for(model) @@ -1466,12 +1463,12 @@ def load(cls, if isinstance(checkpoint, dict) and backend == "torch": raise RuntimeError( - f"Cannot use 'torch' backend with a dictionary-based checkpoint. " - f"Please try a different backend.") + "Cannot use 'torch' backend with a dictionary-based checkpoint. " + "Please try a different backend.") if not isinstance(checkpoint, dict) and backend == "sklearn": raise RuntimeError( - f"Cannot use 'sklearn' backend a non dictionary-based checkpoint. " - f"Please try a different backend.") + "Cannot use 'sklearn' backend a non dictionary-based checkpoint. 
" + "Please try a different backend.") if backend == "sklearn": cebra_ = _load_cebra_with_sklearn_backend(checkpoint) diff --git a/cebra/integrations/sklearn/helpers.py b/cebra/integrations/sklearn/helpers.py index 06095c1e..9127aaa2 100644 --- a/cebra/integrations/sklearn/helpers.py +++ b/cebra/integrations/sklearn/helpers.py @@ -42,7 +42,7 @@ def _get_min_max( for label in labels: if any(isinstance(l, str) for l in label): raise ValueError( - f"Invalid labels dtype, expect floats or integers, got string") + "Invalid labels dtype, expect floats or integers, got string") min = np.min(label) if min > np.min(label) else min max = np.max(label) if max < np.max(label) else max return min, max diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 59a961b3..d07f9359 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -188,7 +188,7 @@ def _consistency_datasets( if labels is None: raise ValueError( "Missing labels, computing consistency between datasets requires labels, expect " - f"a set of labels for each embedding.") + "a set of labels for each embedding.") if len(embeddings) != len(labels): raise ValueError( "Invalid set of labels, computing consistency between datasets requires labels, " @@ -274,8 +274,8 @@ def _consistency_runs( if not all(embeddings[0].shape[0] == embeddings[i].shape[0] for i in range(1, len(embeddings))): raise ValueError( - f"Invalid embeddings, all embeddings should be the same shape to be compared in a between-runs way." - f"If your embeddings are coming from different models, you can use between-datasets" + "Invalid embeddings, all embeddings should be the same shape to be compared in a between-runs way." + "If your embeddings are coming from different models, you can use between-datasets" ) run_ids = np.arange(len(embeddings)) @@ -354,11 +354,11 @@ def consistency_score( if between == "runs": if labels is not None: raise ValueError( - f"No labels should be provided for between-runs consistency.") + "No labels should be provided for between-runs consistency.") if dataset_ids is not None: raise ValueError( - f"No dataset ID should be provided for between-runs consistency." - f"All embeddings should be computed on the same dataset.") + "No dataset ID should be provided for between-runs consistency." 
+ "All embeddings should be computed on the same dataset.") scores, pairs, ids = _consistency_runs(embeddings=embeddings,) elif between == "datasets": scores, pairs, ids = _consistency_datasets( diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index 8dbdc2b4..d2a5a04f 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -33,7 +33,7 @@ """ import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from torch import nn diff --git a/cebra/models/model.py b/cebra/models/model.py index f4a5d862..7631ba86 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -22,10 +22,8 @@ """Neural network models and criterions for training CEBRA models.""" import abc -import literate_dataclasses as dataclasses import torch import torch.nn.functional as F -import tqdm from torch import nn import cebra.data diff --git a/cebra/models/projector.py b/cebra/models/projector.py index 0c924296..dd7388bc 100644 --- a/cebra/models/projector.py +++ b/cebra/models/projector.py @@ -134,7 +134,7 @@ def features(self, inp, index): return self._features[index](inp) def forward(self, inp): - raise NotImplemented() + raise NotImplementedError() def get_offset(self) -> cebra.data.Offset: return cebra.data.Offset(5, 5) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 7f0cbef1..b28f4848 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -32,15 +32,12 @@ import abc import os -from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, - Union) +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import literate_dataclasses as dataclasses -import numpy as np import numpy.typing as npt import torch import torch.nn.functional as F -import tqdm from torch.utils.data import DataLoader from torch.utils.data import Dataset @@ -48,7 +45,6 @@ import cebra.data import cebra.io import cebra.models -import cebra.solver.util as cebra_solver_util from cebra.solver.util import Meter from cebra.solver.util import ProgressBar @@ -56,9 +52,9 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, offset: cebra.data.Offset, num_samples: int): """Check that indexes in a batch are in a correct range. - - First and last index must be positive integers, smaller than the total length of inputs - in the dataset, the first index must be smaller than the last and the batch size cannot + + First and last index must be positive integers, smaller than the total length of inputs + in the dataset, the first index must be smaller than the last and the batch size cannot be smaller than the offset of the model. Args: @@ -101,7 +97,7 @@ def _add_batched_zero_padding(batched_data: torch.Tensor, offset: Offset of the model to consider when padding. batch_start_idx: Index of the first sample in the batch. batch_end_idx: Index of the first sample in the batch. - num_samples (int): Total number of samples in the data. + num_samples (int): Total number of samples in the data. Returns: The padded batch. @@ -136,7 +132,7 @@ def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset], The batch. 
""" if offset is None: - raise ValueError(f"offset cannot be null.") + raise ValueError("offset cannot be null.") if batch_start_idx == 0: # First batch indices = batch_start_idx, (batch_end_idx + offset.right - 1) @@ -427,7 +423,7 @@ def fit( validation_loss = self.validation(valid_loader) if self.best_loss is None or validation_loss < self.best_loss: self.best_loss = validation_loss - self.save(logdir, f"checkpoint_best.pth") + self.save(logdir, "checkpoint_best.pth") if save_model: if decode: self.decode_history.append( @@ -511,11 +507,11 @@ def decoding(self, train_loader, valid_loader): @abc.abstractmethod def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): """Check that the inputs can be inferred using the selected model. - + Note: This method checks that the number of neurons in the input is similar to the input dimension to the selected model. - - Args: + + Args: inputs: Data to infer using the selected model. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to @@ -526,8 +522,8 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): @abc.abstractmethod def _check_is_session_id_valid(self, session_id: Optional[int] = None): """Check that the session ID provided is valid for the solver instance. - - Args: + + Args: session_id: The session ID to check. """ raise NotImplementedError @@ -539,14 +535,14 @@ def _select_model( ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], cebra.data.datatypes.Offset]: """ Select the model based on the input dimension and session ID. - - Args: + + Args: inputs: Data to infer using the selected model. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. - Returns: + Returns: The model (first returns) and the offset of the model (second returns). """ raise NotImplementedError diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index b4be2125..2c2153c2 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -21,11 +21,8 @@ # """Solver implementations for multi-session datasetes.""" -import abc -from collections.abc import Iterable from typing import List, Optional -import literate_dataclasses as dataclasses import torch import cebra @@ -126,10 +123,10 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: def _set_fitted_params(self, loader: cebra.data.Loader): """Set parameters once the solver is fitted. - + In multi session solver, the number of session is set to the number of sessions in the dataset of the loader and the number of - features is set as a list corresponding to the number of neurons in + features is set as a list corresponding to the number of neurons in each dataset. Args: @@ -145,11 +142,11 @@ def _set_fitted_params(self, loader: cebra.data.Loader): def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: Optional[int]): """Check that the inputs can be inferred using the selected model. - + Note: This method checks that the number of neurons in the input is similar to the input dimension to the selected model. - - Args: + + Args: inputs: Data to infer using the selected model. 
session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to @@ -163,10 +160,10 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, def _check_is_session_id_valid(self, session_id: Optional[int]): """Check that the session ID provided is valid for the solver instance. - + The session ID must be non-null and between 0 and the number session in the dataset. - - Args: + + Args: session_id: The session ID to check. """ @@ -181,14 +178,14 @@ def _check_is_session_id_valid(self, session_id: Optional[int]): def _select_model(self, inputs: torch.Tensor, session_id: Optional[int]): """ Select the model based on the input dimension and session ID. - - Args: + + Args: inputs: Data to infer using the selected model. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. - Returns: + Returns: The model (first returns) and the offset of the model (second returns). """ self._check_is_session_id_valid(session_id=session_id) diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index e0927a21..62570a57 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -56,7 +56,7 @@ def parameters(self, session_id: Optional[int] = None): def _set_fitted_params(self, loader: cebra.data.Loader): """Set parameters once the solver is fitted. - + In single session solver, the number of session is set to None and the number of features is set to the number of neurons in the dataset. @@ -68,11 +68,11 @@ def _set_fitted_params(self, loader: cebra.data.Loader): def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): """Check that the inputs can be inferred using the selected model. - + Note: This method checks that the number of neurons in the input is similar to the input dimension to the selected model. - - Args: + + Args: inputs: Data to infer using the selected model. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to @@ -86,10 +86,10 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): def _check_is_session_id_valid(self, session_id: Optional[int] = None): """Check that the session ID provided is valid for the solver instance. - + The session ID must be null or equal to 0. - - Args: + + Args: session_id: The session ID to check. """ @@ -104,14 +104,14 @@ def _select_model( ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], cebra.data.datatypes.Offset]: """ Select the model based on the input dimension and session ID. - - Args: + + Args: inputs: Data to infer using the selected model. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. - Returns: + Returns: The model (first returns) and the offset of the model (second returns). """ self._check_is_inputs_valid(inputs, session_id=session_id) @@ -229,14 +229,14 @@ def _select_model( ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], cebra.data.datatypes.Offset]: """ Select the model based on the input dimension and session ID. - - Args: + + Args: inputs: Data to infer using the selected model. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. - Returns: + Returns: The model (first returns) and the offset of the model (second returns). 
""" self._check_is_inputs_valid(inputs, session_id=session_id) diff --git a/cebra/solver/supervised.py b/cebra/solver/supervised.py index f69308e6..f4e4f95c 100644 --- a/cebra/solver/supervised.py +++ b/cebra/solver/supervised.py @@ -25,17 +25,9 @@ It is inclear whether these will be kept. Consider the implementation as experimental/outdated, and the API for this particular package unstable. """ -import abc -from collections.abc import Iterable -from typing import List -import literate_dataclasses as dataclasses import torch -import tqdm -import cebra -import cebra.data -import cebra.models import cebra.solver.base as abc_ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 98885d07..c9f9fb2f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -68,7 +68,6 @@ def test_demo(): @pytest.mark.requires_dataset def test_hippocampus(): - from cebra.datasets import hippocampus pytest.skip("Outdated") dataset = cebra.datasets.init("rat-hippocampus-single") @@ -99,7 +98,6 @@ def test_hippocampus(): @pytest.mark.requires_dataset def test_monkey(): - from cebra.datasets import monkey_reaching dataset = cebra.datasets.init( "area2-bump-pos-active-passive", @@ -111,7 +109,6 @@ def test_monkey(): @pytest.mark.requires_dataset def test_allen(): - from cebra.datasets import allen pytest.skip("Test takes too long") @@ -153,8 +150,9 @@ def test_allen(): @pytest.mark.requires_dataset -@pytest.mark.parametrize( - "options", cebra.datasets.get_options("*", expand_parametrized=False)) +@pytest.mark.parametrize("options", + cebra.datasets.get_options("*", + expand_parametrized=False)) def test_options(options): assert len(options) > 0 assert len(multisubject_options) > 0 diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 0644aef7..e1e09e5d 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1145,7 +1145,7 @@ def test_move_cpu_to_cuda_device(device): def test_move_cpu_to_mps_device(device): if not cebra.helper._is_mps_availabe(torch): - pytest.skip(f"MPS device is not available") + pytest.skip("MPS device is not available") X = np.random.uniform(0, 1, (10, 5)) cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model", @@ -1360,7 +1360,7 @@ def test_check_device(): ) def test_new_transform(model_architecture, device): """ - This is a test that the original sklearn transform returns the same output as + This is a test that the original sklearn transform returns the same output as the new sklearn transform that uses the pytorch solver transform. """ output_dimension = 4 diff --git a/tests/test_solver.py b/tests/test_solver.py index c27a9e41..68e2a43e 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -20,7 +20,6 @@ # limitations under the License. # import copy -import itertools import tempfile import numpy as np @@ -506,8 +505,8 @@ def create_model(model_name, input_dimension): @pytest.mark.parametrize( "data_name, model_name ,session_id, loader_initfunc, solver_initfunc", - single_session_tests_select_model + - single_session_hybrid_tests_select_model) + single_session_tests_select_model + single_session_hybrid_tests_select_model +) def test_select_model_single_session(data_name, model_name, session_id, loader_initfunc, solver_initfunc): dataset = cebra.datasets.init(data_name)