Last complete batch indexes for batched inference can go above the input length #199

Closed
@CeliaBenquet

Description

When using the code implemented in #168, I get an error whenever the input size satisfies len(inputs) % batch_size < offset.right, i.e. whenever the last (incomplete) batch contains fewer samples than offset.right. In that case the penultimate batch (the last complete one) fails: its batch_end_idx, extended by the right offset, is larger than the input length.
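The index arithmetic behind this can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code from #168. It assumes consecutive batches of `batch_size` samples. It also assumes that every batch except the last is extended to the right using the real samples that follow it, so a batch covering `[start, end)` needs inputs up to the exclusive index `end + offset_right - 1`. The final batch is assumed to be padded separately. The right offset of 18 for "offset36-model" is likewise an assumption. The helper name `overflowing_batches` is hypothetical.

```python
def overflowing_batches(n_samples: int, batch_size: int, offset_right: int):
    """Return (batch_index, padded_end) for every non-final batch whose
    right-extended end index would exceed n_samples.

    Assumption (not taken from #168): a batch covering [start, end) needs
    inputs up to the exclusive index end + offset_right - 1.
    """
    overflowing = []
    # Start indices of all batches; the last batch may be incomplete.
    starts = list(range(0, n_samples, batch_size))
    # The final batch is assumed to be padded separately, so it is skipped.
    for batch_idx, start in enumerate(starts[:-1]):
        end = start + batch_size
        padded_end = end + offset_right - 1
        if padded_end > n_samples:
            overflowing.append((batch_idx, padded_end))
    return overflowing


# Numbers from the reproduction: 20111 samples, batch_size=300 and an
# assumed right offset of 18 for "offset36-model".
print(20111 % 300)                          # 11 samples in the last batch
print(overflowing_batches(20111, 300, 18))  # [(66, 20117)]
```

Under these assumptions the only failing batch is the one covering samples 19800 to 20100, which is consistent with the batch_end_idx (20117) reported in the traceback of the reproduction.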

Code to reproduce (from branch #168):

import cebra
import cebra.data
import numpy as np

train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100),
                                 continuous=np.random.rand(20111, 2))

model = cebra.CEBRA(max_iterations=10, verbose=True,
                    model_architecture="offset36-model-more-dropout",
                    device="cuda_if_available")
model.fit(train.neural, train.continuous)

# 20111 % 300 == 11 samples remain for the last batch, fewer than the model's right offset.
embedding = model.transform(train.neural, batch_size=300)

Error:

  File "/CEBRA-dev/cebra/solver/base.py", line 634, in transform
    output = _batched_transform(
             ^^^^^^^^^^^^^^^^^^^
  File "/CEBRA-dev/cebra/solver/base.py", line 248, in _batched_transform
    batched_data = _get_batch(inputs=inputs,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/CEBRA-dev/cebra/solver/base.py", line 153, in _get_batch
    _check_indices(indices[0], indices[1], offset, len(inputs))
  File "/CEBRA-dev/cebra/solver/base.py", line 81, in _check_indices
    raise ValueError(
ValueError: batch_end_idx (20117) cannot exceed the length of inputs (20111).
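Until this is fixed, the error can presumably be avoided by choosing a batch size for which the condition len(inputs) % batch_size < offset.right does not hold. The helper below, `find_safe_batch_size`, is a hypothetical sketch that searches downward from a preferred batch size. The right offset of 18 is an assumed value for "offset36-model" and should be replaced by the actual offset of the model in use.

```python
def find_safe_batch_size(n_samples: int, preferred: int, offset_right: int) -> int:
    """Return the largest batch size <= preferred whose last (incomplete)
    batch holds at least offset_right samples, so that
    n_samples % batch_size < offset_right does not occur.

    Hypothetical helper; raises ValueError if no such batch size exists.
    """
    for batch_size in range(min(preferred, n_samples), 0, -1):
        # The remainder is the size of the trailing incomplete batch.
        if n_samples % batch_size >= offset_right:
            return batch_size
    raise ValueError("no batch size satisfies the remainder condition")


# offset_right=18 is an assumed value for "offset36-model".
batch_size = find_safe_batch_size(20111, preferred=300, offset_right=18)
print(batch_size, 20111 % batch_size)  # 299 78
```

For the reproduction above this yields batch_size=299, leaving 78 samples in the last batch, and that value could then be passed as `batch_size` to `model.transform`.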

I will propose a solution.

Metadata

Labels

bug (Something isn't working)
