
Last complete batch indexes for batched inference can go above the input length #199

Open
CeliaBenquet opened this issue Nov 19, 2024 · 0 comments

CeliaBenquet (Member) commented Nov 19, 2024

When using the batched inference implemented in #168, I get an error in some cases where the input size satisfies `len(inputs) % batch_size < offset.right`, i.e., when the last (incomplete) batch contains fewer samples than `offset.right`. In that case the error is raised on the penultimate batch (the last complete one): its `batch_end_idx`, which extends past the batch to cover the right-hand context required by the model offset, is larger than the input length.

Code to reproduce (from branch #168):

```python
import cebra.data
import numpy as np

train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100),
                                 continuous=np.random.rand(20111, 2))

model = cebra.CEBRA(max_iterations=10, verbose=True, model_architecture="offset36-model-more-dropout", device="cuda_if_available")
model.fit(train.neural, train.continuous)

embedding = model.transform(train.neural, batch_size=300)
```

Error:

```
  File "/CEBRA-dev/cebra/solver/base.py", line 634, in transform
    output = _batched_transform(
             ^^^^^^^^^^^^^^^^^^^
  File "/CEBRA-dev/cebra/solver/base.py", line 248, in _batched_transform
    batched_data = _get_batch(inputs=inputs,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/CEBRA-dev/cebra/solver/base.py", line 153, in _get_batch
    _check_indices(indices[0], indices[1], offset, len(inputs))
  File "/CEBRA-dev/cebra/solver/base.py", line 81, in _check_indices
    raise ValueError(
ValueError: batch_end_idx (20117) cannot exceed the length of inputs (20111).
```
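The index in the error message can be reproduced with plain arithmetic on the sizes from the reproduction. The sketch below is not CEBRA code: it assumes (hypothetically) that each complete batch is extended on the right by `offset_right - 1` samples of real data, and that the `offset36-model-more-dropout` architecture has a right offset of 18. Under those assumptions it yields the reported value of 20117. `last_complete_batch_window_end` is a hypothetical helper name.

```python
# Hypothetical sketch, not the CEBRA implementation. Assumptions:
#  - each complete batch [start, start + batch_size) is extended on the right
#    by offset_right - 1 samples taken from the following data;
#  - offset_right = 18 for "offset36-model-more-dropout".

def last_complete_batch_window_end(n_inputs: int, batch_size: int, offset_right: int) -> int:
    """Return the exclusive end index of the input slice needed by the
    last complete batch under the assumed padding convention."""
    n_complete = n_inputs // batch_size      # number of full batches
    batch_end = n_complete * batch_size      # end of the last full batch
    return batch_end + offset_right - 1      # plus right-hand context


n_inputs, batch_size, offset_right = 20111, 300, 18
remainder = n_inputs % batch_size
window_end = last_complete_batch_window_end(n_inputs, batch_size, offset_right)

print(f"samples in last incomplete batch: {remainder}")
print(f"last complete batch needs inputs up to index {window_end}")
print(f"exceeds input length: {window_end > n_inputs}")
```

Here only 11 samples follow the last complete batch, while 17 are needed for its right-hand context, so the window ends 6 samples past the end of the input.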

I will propose a solution.
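One possible way to avoid the overshoot, sketched here as a hypothetical helper and not necessarily the fix that will be proposed for this issue, is to merge a trailing batch that is too short into the preceding batch. The merged batch then becomes the final batch and is padded like any last batch, so no complete batch needs right-hand context that does not exist. `batch_boundaries` and the `offset_right - 1` context convention are assumptions for illustration.

```python
from typing import List, Tuple


def batch_boundaries(n_inputs: int, batch_size: int, offset_right: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: split [0, n_inputs) into consecutive batches.

    If the trailing incomplete batch has fewer than offset_right - 1 samples
    (the right-hand context the previous batch is assumed to need), merge it
    into the previous batch so that the merged batch becomes the last one.
    """
    bounds = [
        (start, min(start + batch_size, n_inputs))
        for start in range(0, n_inputs, batch_size)
    ]
    if len(bounds) > 1:
        last_start, last_end = bounds[-1]
        if last_end - last_start < offset_right - 1:
            prev_start, _ = bounds[-2]
            # Replace the last two batches with a single merged batch.
            bounds[-2:] = [(prev_start, last_end)]
    return bounds


bounds = batch_boundaries(20111, 300, 18)
print(len(bounds), bounds[-1])
```

With the sizes from the reproduction, the 11-sample remainder is folded into the previous batch, which then covers inputs 19800 to 20111. The trade-off is that the final batch can hold up to `batch_size + offset_right - 2` samples instead of at most `batch_size`.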

CeliaBenquet added the bug (Something isn't working) label on Nov 19, 2024.
CeliaBenquet self-assigned this on Nov 19, 2024.