Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nvalchemi/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,11 @@ def get_data(self, idx: int) -> AtomicData:
"""
if idx < 0:
idx = self.num_graphs + idx
if not (0 <= idx < self.num_graphs):
raise IndexError(
f"graph index {idx} is out of range for batch with "
f"{self.num_graphs} graph(s)"
)
Comment on lines 584 to +590
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Error message reports adjusted index, not the original

When a negative index is adjusted before the bounds check, the error message reports the adjusted (intermediate) index rather than the value the caller actually passed. For example, get_data(-5) on a 2-graph batch adjusts to idx = -3 and then raises:

IndexError: graph index -3 is out of range for batch with 2 graph(s)

The caller passed -5, so seeing -3 in the message is surprising and hard to debug. Preserving the original value before adjustment and referencing it in the message would make the error much clearer:

Suggested change
if idx < 0:
idx = self.num_graphs + idx
if not (0 <= idx < self.num_graphs):
raise IndexError(
f"graph index {idx} is out of range for batch with "
f"{self.num_graphs} graph(s)"
)
original_idx = idx
if idx < 0:
idx = self.num_graphs + idx
if not (0 <= idx < self.num_graphs):
raise IndexError(
f"graph index {original_idx} is out of range for batch with "
f"{self.num_graphs} graph(s)"
)


data: dict[str, Any] = {}

Expand Down
8 changes: 8 additions & 0 deletions test/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,14 @@ def test_get_data_negative_index(self):
last = batch.get_data(-1)
assert last.num_nodes == 3

@pytest.mark.parametrize("idx", [-3, -4, -5, -100, 2, 3, 100])
def test_get_data_out_of_range_raises(self, idx):
d1 = _minimal_atomic_data(2)
d2 = _minimal_atomic_data(3)
batch = Batch.from_data_list([d1, d2])
with pytest.raises(IndexError):
batch.get_data(idx)

def test_to_data_list(self):
d1 = _minimal_atomic_data(2)
d2 = _minimal_atomic_data(3)
Expand Down
Loading