Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameters for MixedDataLoader #101

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 64 additions & 14 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,30 +265,69 @@ class MixedDataLoader(cebra_data.Loader):

Sampling can be configured in different modes:

1. Positive pairs always share their discrete variable.
1. Positive pairs always share their discrete variable (positive_sampling = "discrete_variable").
2. Positive pairs are drawn only based on their conditional,
not discrete variable.
not discrete variable (positive_sampling = "conditional").

When using the discrete variable, the prior distribution can either be uniform
(discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical").

Based on the selection of those parameters, the :py:class:`cebra.distributions.mixed.MixedTimeDeltaDistribution`,
:py:class:`cebra.distributions.discrete.DiscreteEmpirical`, or :py:class:`cebra.distributions.discrete.DiscreteUniform`
distributions are used for sampling.

Args:
conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
positive_sampling (str): either "discrete_variable" (default) or "conditional"
discrete_sampling_prior (str): either "empirical" (default) or "uniform"
"""

conditional: str = dataclasses.field(default="time_delta")
time_offset: int = dataclasses.field(default=10)
positive_sampling: str = dataclasses.field(default="discrete_variable")
discrete_sampling_prior: str = dataclasses.field(default="uniform")
stes marked this conversation as resolved.
Show resolved Hide resolved

@property
def dindex(self):
# TODO(stes) rename to discrete_index
warnings.warn("dindex is deprecated. Use discrete_index instead.",
DeprecationWarning)
return self.dataset.discrete_index

@property
def discrete_index(self):
return self.dataset.discrete_index

@property
def cindex(self):
# TODO(stes) rename to continuous_index
warnings.warn("cindex is deprecated. Use continuous_index instead.",
DeprecationWarning)
return self.dataset.continuous_index

@property
def continuous_index(self):
return self.dataset.continuous_index

def __post_init__(self):
super().__post_init__()
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
discrete=self.dindex,
continuous=self.cindex,
time_delta=self.time_offset)
if self.positive_sampling == "conditional":
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
discrete=self.discrete_index,
continuous=self.continuous_index,
time_delta=self.time_offset)
Comment on lines +313 to +317
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be the default behavior, that was how the class used to behave.

elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)
Comment on lines +318 to +321
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are these modes different from going for the empirical discrete / uniform discrete distribution in the first place? I think what we rather want is specify an option to the MixedTimeDeltaDistribution to support empirical vs. uniform

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but I understood that the current docstring of MixedDataLoader suggests that this is indeed the intended functionality:

1. Positive pairs always share their discrete variable.

Even though I agree that it wouldn't make sense in this case to call the MixedDataLoader

elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]:
raise ValueError(
f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but "
f"only accept 'uniform' or 'empirical' as potential values.")
else:
raise ValueError(
f"Invalid positive sampling mode: "
f"{self.positive_sampling} valid options are "
f"'conditional' or 'discrete_variable'.")

def get_indices(self, num_samples: int) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.
Expand All @@ -313,12 +352,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
class.
- Sample the negatives with matching discrete variable
"""
reference_idx = self.distribution.sample_prior(num_samples)
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
positive=self.distribution.sample_conditional(reference_idx),
)
if self.positive_sampling == "conditional":
reference_idx = self.distribution.sample_prior(num_samples)
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
positive=self.distribution.sample_conditional(reference_idx),
)
else:
# taken from the DiscreteDataLoader get_indices function
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference = self.discrete_index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
positive=positive_idx,
negative=negative_idx)
Comment on lines +363 to +371
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# taken from the DiscreteDataLoader get_indices function
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference = self.discrete_index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
positive=positive_idx,
negative=negative_idx)
return self.distribution.get_indices(num_samples)

should be equivalent. But I think this is actually not the desired functionality, as this then completely ignores the continuous index...

Comment on lines +355 to +371
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way this is setup currently means we either take both variables into account, or we ignore the continuous variables. I think that behavior is not necessarily intended (?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agree. As you suggested above, it would then makes sense to pass the MixedDataLoader.discrete_sampling_prior argument to MixedTimeDeltaDistribution directly, and adapt MixedTimeDeltaDistribution to not only sample DiscreteUniform distribution



@dataclasses.dataclass
Expand Down
31 changes: 31 additions & 0 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark):
benchmark(load_speed)


@parametrize_device
@pytest.mark.parametrize(
"conditional, positive_sampling, discrete_sampling_prior",
[
("time", "discrete_variable", "empirical"),
("time", "conditional", "empirical"),
("time", "discrete_variable", "uniform"),
("time", "conditional", "uniform"),
("time_delta", "discrete_variable", "empirical"),
("time_delta", "conditional", "empirical"),
("time_delta", "discrete_variable", "uniform"),
("time_delta", "conditional", "uniform"),
],
)
def test_mixed(
conditional, positive_sampling, discrete_sampling_prior, device, benchmark
):
dataset = RandomDataset(N=100, d=5, device=device)
loader = cebra.data.MixedDataLoader(
dataset=dataset,
num_steps=10,
batch_size=8,
conditional=conditional,
positive_sampling=positive_sampling,
discrete_sampling_prior=discrete_sampling_prior,
)
_assert_dataset_on_correct_device(loader, device)
load_speed = LoadSpeed(loader)
benchmark(load_speed)
Comment on lines +206 to +217
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should extend the test to check the properties of the positive and negative samples (e.g., check if the discrete labels match and so forth, as expected for each setting of parameters)



def _check_attributes(obj, is_list=False):
if is_list:
for obj_ in obj:
Expand Down
Loading