fix(data): add bounds check for out-of-range indices in Batch.get_data()#49
fix(data): add bounds check for out-of-range indices in Batch.get_data()#49Ryan-Reese wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
Batch.get_data() adjusts negative indices via `idx = num_graphs + idx` but does not validate the result. When the adjusted index is still negative (e.g. get_data(-5) on a 3-graph batch), Python's negative indexing silently wraps around, returning node-level data from one graph mixed with system-level data from another — silent data corruption. Add a bounds check after the adjustment to raise IndexError for any index outside [0, num_graphs).
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
Greptile SummaryThis PR correctly fixes a silent data-corruption bug in One minor usability concern was identified:
Additionally, the PR's commit is missing the required Important Files Changed
Reviews (1): Last reviewed commit: "fix(data): add bounds check for out-of-r..." | Re-trigger Greptile |
| if idx < 0: | ||
| idx = self.num_graphs + idx | ||
| if not (0 <= idx < self.num_graphs): | ||
| raise IndexError( | ||
| f"graph index {idx} is out of range for batch with " | ||
| f"{self.num_graphs} graph(s)" | ||
| ) |
There was a problem hiding this comment.
Error message reports adjusted index, not the original
When a negative index is adjusted before the bounds check, the error message reports the adjusted (intermediate) index rather than the value the caller actually passed. For example, get_data(-5) on a 2-graph batch adjusts to idx = -3 and then raises:
IndexError: graph index -3 is out of range for batch with 2 graph(s)
The caller passed -5, so seeing -3 in the message is surprising and hard to debug. Preserving the original value before adjustment and referencing it in the message would make the error much clearer:
| if idx < 0: | |
| idx = self.num_graphs + idx | |
| if not (0 <= idx < self.num_graphs): | |
| raise IndexError( | |
| f"graph index {idx} is out of range for batch with " | |
| f"{self.num_graphs} graph(s)" | |
| ) | |
| original_idx = idx | |
| if idx < 0: | |
| idx = self.num_graphs + idx | |
| if not (0 <= idx < self.num_graphs): | |
| raise IndexError( | |
| f"graph index {original_idx} is out of range for batch with " | |
| f"{self.num_graphs} graph(s)" | |
| ) |
|
I posted a fix in #50: if that gets merged in before this PR then we'll see if the workflow runs as intended. Otherwise the tests pass like I said, so I would be fine with merging after you address my comment |
Description
Batch.get_data()adjusts negative indices viaidx = num_graphs + idxbut does not validate the result. When the adjusted index is still negative (e.g.get_data(-5)on a 3-graph batch), Python's negative indexing silently wraps around on the internal storage tensors. Node-level data (positions, forces) may come from one graph while system-level data (energies, cell) comes from a different graph, returning inconsistent data.Reproduction:
Type of Change
Changes Made
Batch.get_data()to raiseIndexErrorfor any index outside[0, num_graphs)test_get_data_out_of_range_raisesTesting
make pytest)make lint)Checklist