From c98169c71727ed3a9dc34a76b66bfdc2a6a3c52a Mon Sep 17 00:00:00 2001 From: timonmerk Date: Sun, 29 Oct 2023 17:27:30 +0100 Subject: [PATCH 1/5] add positive sampling options for MixedDataLoader --- cebra/data/single_session.py | 59 +++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7802b78..eb09344 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -268,27 +268,47 @@ class MixedDataLoader(cebra_data.Loader): 1. Positive pairs always share their discrete variable. 2. Positive pairs are drawn only based on their conditional, not discrete variable. + + Args: + conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional` + time_offset (int): :py:attr:`cebra.CEBRA.time_offsets` + positive_sampling (str): either "discrete_variable" (default) or "conditional" + discrete_sampling_prior (str): either "empirical" (default) or "uniform" """ conditional: str = dataclasses.field(default="time_delta") time_offset: int = dataclasses.field(default=10) + positive_sampling: str = dataclasses.field(default="discrete_variable") + discrete_sampling_prior: str = dataclasses.field(default="uniform") @property - def dindex(self): - # TODO(stes) rename to discrete_index + def discrete_index(self): return self.dataset.discrete_index @property - def cindex(self): - # TODO(stes) rename to continuous_index + def continuous_index(self): return self.dataset.continuous_index def __post_init__(self): super().__post_init__() - self.distribution = cebra.distributions.MixedTimeDeltaDistribution( - discrete=self.dindex, - continuous=self.cindex, - time_delta=self.time_offset) + if self.positive_sampling == "conditional": + self.distribution = cebra.distributions.MixedTimeDeltaDistribution( + discrete=self.discrete_index, + continuous=self.continuous_index, + time_delta=self.time_offset) + elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical": + self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index) + elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform": + self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index) + elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]: + raise ValueError( + f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but " + f"only accept 'uniform' or 'empirical' as potential values.") + else: + raise ValueError( + f"Invalid positive sampling mode: " + f"{self.positive_sampling} valid options are " + f"'conditional' or 'discrete_variable'.") def get_indices(self, num_samples: int) -> BatchIndex: """Samples indices for reference, positive and negative examples. @@ -313,12 +333,23 @@ def get_indices(self, num_samples: int) -> BatchIndex: class. - Sample the negatives with matching discrete variable """ - reference_idx = self.distribution.sample_prior(num_samples) - return BatchIndex( - reference=reference_idx, - negative=self.distribution.sample_prior(num_samples), - positive=self.distribution.sample_conditional(reference_idx), - ) + if self.positive_sampling == "conditional": + reference_idx = self.distribution.sample_prior(num_samples) + return BatchIndex( + reference=reference_idx, + negative=self.distribution.sample_prior(num_samples), + positive=self.distribution.sample_conditional(reference_idx), + ) + else: + # taken from the DiscreteDataLoader get_indices function + reference_idx = self.distribution.sample_prior(num_samples * 2) + negative_idx = reference_idx[num_samples:] + reference_idx = reference_idx[:num_samples] + reference = self.discrete_index[reference_idx] + positive_idx = self.distribution.sample_conditional(reference) + return BatchIndex(reference=reference_idx, + positive=positive_idx, + negative=negative_idx) @dataclasses.dataclass From e82aad7997b46d5f3fe5aea8d5ba1d2fffe6c881 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 17:27:21 +0100 Subject: [PATCH 2/5] add deprecation warning for cindex and dindex --- cebra/data/single_session.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index eb09344..1fddc74 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -281,10 +281,22 @@ class MixedDataLoader(cebra_data.Loader): positive_sampling: str = dataclasses.field(default="discrete_variable") discrete_sampling_prior: str = dataclasses.field(default="uniform") + @property + def dindex(self): + warnings.warn("dindex is deprecated. Use discrete_index instead.", + DeprecationWarning) + return self.dataset.discrete_index + @property def discrete_index(self): return self.dataset.discrete_index + @property + def cindex(self): + warnings.warn("cindex is deprecated. Use continuous_index instead.", + DeprecationWarning) + return self.dataset.continuous_index + @property def continuous_index(self): return self.dataset.continuous_index From cba88ae6a5a3c8357d820384827dc52fdcc46559 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 17:31:48 +0100 Subject: [PATCH 3/5] add test for MixedDataLoader including additional keywords --- tests/test_loader.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_loader.py b/tests/test_loader.py index 562f64a..e2d819c 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark): benchmark(load_speed) +@parametrize_device +@pytest.mark.parametrize( + "conditional, positive_sampling, discrete_sampling_prior", + [ + ("time", "discrete_variable", "empirical"), + ("time", "conditional", "empirical"), + ("time", "discrete_variable", "uniform"), + ("time", "conditional", "uniform"), + ("time_delta", "discrete_variable", "empirical"), + ("time_delta", "conditional", "empirical"), + ("time_delta", "discrete_variable", "uniform"), + ("time_delta", "conditional", "uniform"), + ], +) +def test_mixed( + conditional, positive_sampling, discrete_sampling_prior, device, benchmark +): + dataset = RandomDataset(N=100, d=5, device=device) + loader = cebra.data.MixedDataLoader( + dataset=dataset, + num_steps=10, + batch_size=8, + conditional=conditional, + positive_sampling=positive_sampling, + discrete_sampling_prior=discrete_sampling_prior, + ) + _assert_dataset_on_correct_device(loader, device) + load_speed = LoadSpeed(loader) + benchmark(load_speed) + + def _check_attributes(obj, is_list=False): if is_list: for obj_ in obj: From f217c323bd44701794ee8c7b6bc87bb42c1d9b6a Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 18:06:07 +0100 Subject: [PATCH 4/5] add improved docstring description --- cebra/data/single_session.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 1fddc74..61cebb7 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -265,9 +265,16 @@ class MixedDataLoader(cebra_data.Loader): Sampling can be configured in different modes: - 1. Positive pairs always share their discrete variable. + 1. Positive pairs always share their discrete variable (positive_sampling = "discrete_variable"). 2. Positive pairs are drawn only based on their conditional, - not discrete variable. + not discrete variable (positive_sampling = "conditional"). + + When using the discrete variable, the prior distribution can either be uniform + (discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical"). + + Based on the selection of those parameters, the :py:class:`cebra.distributions.MixedTimeDeltaDistribution`, + :py:class:`cebra.distributions.DiscreteEmpirical`, or :py:class:`cebra.distributions.DiscreteUniform` + distributions are used for sampling. Args: conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional` From f91b64a86d24c7924b9e057b300a4baecf52e570 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 18:21:55 +0100 Subject: [PATCH 5/5] fix docstring sphinx link --- cebra/data/single_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 61cebb7..993078f 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -272,8 +272,8 @@ class MixedDataLoader(cebra_data.Loader): When using the discrete variable, the prior distribution can either be uniform (discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical"). - Based on the selection of those parameters, the :py:class:`cebra.distributions.MixedTimeDeltaDistribution`, - :py:class:`cebra.distributions.DiscreteEmpirical`, or :py:class:`cebra.distributions.DiscreteUniform` + Based on the selection of those parameters, the :py:class:`cebra.distributions.mixed.MixedTimeDeltaDistribution`, + :py:class:`cebra.distributions.discrete.DiscreteEmpirical`, or :py:class:`cebra.distributions.discrete.DiscreteUniform` distributions are used for sampling. Args: