diff --git a/nvalchemi/data/batch.py b/nvalchemi/data/batch.py index c26629e..604dac9 100644 --- a/nvalchemi/data/batch.py +++ b/nvalchemi/data/batch.py @@ -583,6 +583,11 @@ def get_data(self, idx: int) -> AtomicData: """ if idx < 0: idx = self.num_graphs + idx + if not (0 <= idx < self.num_graphs): + raise IndexError( + f"graph index {idx} is out of range for batch with " + f"{self.num_graphs} graph(s)" + ) data: dict[str, Any] = {} diff --git a/test/data/test_batch.py b/test/data/test_batch.py index f1db7cd..a0d9a38 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -281,6 +281,14 @@ def test_get_data_negative_index(self): last = batch.get_data(-1) assert last.num_nodes == 3 + @pytest.mark.parametrize("idx", [-3, -4, -5, -100, 2, 3, 100]) + def test_get_data_out_of_range_raises(self, idx): + d1 = _minimal_atomic_data(2) + d2 = _minimal_atomic_data(3) + batch = Batch.from_data_list([d1, d2]) + with pytest.raises(IndexError): + batch.get_data(idx) + def test_to_data_list(self): d1 = _minimal_atomic_data(2) d2 = _minimal_atomic_data(3)