fix(data): preserve other batch edge_index in Batch.append() #46

Ryan-Reese wants to merge 4 commits into NVIDIA:main
Conversation
Batch.append() temporarily offsets the other batch's edge_index by the receiver's atom count for correct concatenation, but never restores the original values. This silently corrupts the input batch — any subsequent use of the appended batch (e.g. in inflight batching or multi-stage pipelines) produces wrong edge indices. Save and restore the original edge_index reference around the concatenation in a try/finally block to keep the input batch intact even if concatenation raises. Also reject shared-storage aliasing (batch.append(batch) or shared MultiLevelStorage) which cannot work correctly with in-place concatenation.
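The save-and-restore pattern described above can be sketched as follows. This is a simplified, list-backed stand-in for the real classes: `TinyBatch` and its dict storage are illustrative assumptions, not the actual ALCHEMI `Batch`/`MultiLevelStorage` API.

```python
# Simplified sketch of the fix; TinyBatch and its dict-backed storage are
# illustrative stand-ins, not the real ALCHEMI Batch/MultiLevelStorage API.
class TinyBatch:
    def __init__(self, num_atoms, edge_index):
        self.num_atoms = num_atoms
        # edge_index: list of (src, dst) atom-index pairs (a tensor in the real code)
        self._data = {"edge_index": edge_index}

    def append(self, other):
        # Guard: in-place concatenation cannot work if both sides alias
        # the same underlying storage.
        if other is self or other._data is self._data:
            raise ValueError(
                "Cannot append a Batch that shares storage with the receiver."
            )
        saved_ei = other._data["edge_index"]
        try:
            # Temporarily offset other's indices by the receiver's atom count.
            other._data["edge_index"] = [
                (s + self.num_atoms, d + self.num_atoms) for s, d in saved_ei
            ]
            self._data["edge_index"] = (
                self._data["edge_index"] + other._data["edge_index"]
            )
            self.num_atoms += other.num_atoms
        finally:
            # Restore the original reference even if concatenation raised,
            # so the input batch is never silently corrupted.
            other._data["edge_index"] = saved_ei
```

With this shape, `b1.append(b2)` grows `b1` while `b2._data["edge_index"]` keeps its original values, and `b1.append(b1)` raises immediately instead of corrupting both sides.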
Greptile Summary: This PR fixes a silent mutation bug in `Batch.append()`.
Reviews (4): last reviewed commit "ci: retrigger checks"
nvalchemi/data/batch.py (outdated review comment)
```python
finally:
    # Restore other's edge_index to avoid mutating the input batch.
    if saved_ei is not None:
        other._edges_group._data["edge_index"] = saved_ei
```
`other._edges_group` re-evaluated in `finally` instead of using the cached local

The `finally` block calls `other._edges_group` (a property that does `self._storage.groups.get("edges")`) instead of reusing the already-resolved `other_edges` local variable. Under normal operation both refer to the same object, but using the local is slightly cheaper and makes the symmetry with the setup code explicit:
```diff
 finally:
     # Restore other's edge_index to avoid mutating the input batch.
     if saved_ei is not None:
-        other._edges_group._data["edge_index"] = saved_ei
+        other_edges._data["edge_index"] = saved_ei
```
```python
if other is self or other._storage is self._storage:
    raise ValueError(
        "Cannot append a Batch that shares storage with the "
        "receiver (would corrupt both). Use "
        "batch.append(batch.clone()) instead."
    )
```
Missing test coverage for the new guard and the core fix

Neither the `ValueError` for self-aliasing nor the `edge_index` preservation behavior described in the PR has a corresponding test. The existing `test_append` only checks graph/node counts on `b1` and would not catch a regression. Consider adding two assertions:
```python
# test self-aliasing guard
with pytest.raises(ValueError, match="shares storage"):
    b1.append(b1)

# test other's edge_index is not mutated
b2_ei_before = b2.edge_index.clone()
b1.append(b2)
assert torch.equal(b2.edge_index, b2_ei_before)
```

Add three tests to TestBatchMutation covering the new `append()` safety behaviour:

- `test_append_preserves_other_edge_index`: verifies the input batch's `edge_index` is not mutated after append
- `test_append_self_raises`: verifies self-append raises `ValueError`
- `test_append_shared_storage_raises`: verifies shared-storage append raises `ValueError`

Use the cached `other_edges` variable in the finally block per review feedback.
/ok to test efa2639
Coverage is failing - I'll check out the branch locally and confirm it's okay.
@laserkelvin thanks :) |
laserkelvin left a comment:

I've confirmed locally that coverage is fine and tests pass.
```python
        group.concatenate(other_group)
    else:
        group.extend_for_appended_graphs(n_other)
try:
```
My concerns with this bit are twofold:

- There's no exception being handled explicitly.
- There's overhead associated with try/except/finally; it's not huge, but it can be non-negligible.

If possible, I would rewrite it without the try block, unless there is something you're guarding against.
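As a rough sanity check on the overhead concern, a `timeit` sketch like the one below can compare a loop body with and without a non-raising `try/finally`. The functions are toy stand-ins (not the actual `append()` code), and absolute numbers vary by interpreter and hardware; on CPython 3.11+ the zero-cost exception implementation makes the happy path especially cheap.

```python
# Rough micro-benchmark sketch of try/finally overhead on the happy path.
# Toy functions, not the real append(); numbers depend on the interpreter.
import timeit

def plain(xs):
    total = 0
    for x in xs:
        total += x
    return total

def guarded(xs):
    total = 0
    try:
        for x in xs:
            total += x
    finally:
        pass  # the restore step would go here
    return total

data = list(range(100))
t_plain = timeit.timeit(lambda: plain(data), number=10_000)
t_guarded = timeit.timeit(lambda: guarded(data), number=10_000)
print(f"plain: {t_plain:.4f}s  guarded: {t_guarded:.4f}s")
```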
ALCHEMI Toolkit Pull Request
Description
`Batch.append()` temporarily offsets the `other` batch's `edge_index` by the receiver's atom count for correct concatenation, but never restores the original values. This silently corrupts the input batch: any subsequent use of the appended batch produces wrong edge indices.

Type of Change
Related Issues
Relates to #21 (Data layer gaps and limitations)
Changes Made
- Save `other`'s original `edge_index` reference before applying the node offset
- Use `try/finally` to restore `other`'s `edge_index` even if concatenation raises
- Reject shared-storage aliasing (`batch.append(batch)` or shared `MultiLevelStorage`), which cannot work correctly with in-place concatenation

Reproduction
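A minimal sketch of the pre-fix behaviour, using a toy dict-backed batch rather than the real `Batch` class (the `buggy_append` name and dict layout are illustrative assumptions):

```python
# Toy reproduction of the pre-fix bug: the offset applied to other's
# edge_index is never undone. Dict-backed stand-in, not the real Batch.
def buggy_append(receiver, other):
    n = receiver["num_atoms"]
    other["edge_index"] = [(s + n, d + n) for s, d in other["edge_index"]]
    receiver["edge_index"] += other["edge_index"]
    receiver["num_atoms"] += other["num_atoms"]

b1 = {"num_atoms": 3, "edge_index": [(0, 1)]}
b2 = {"num_atoms": 2, "edge_index": [(0, 1)]}
buggy_append(b1, b2)
print(b2["edge_index"])  # the input batch now holds offset indices: [(3, 4)]
```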
Testing
- Tests pass (`make pytest`)
- Lint passes (`make lint`)
- `test/data/test_batch.py` tests pass with no regressions

Checklist
Additional Notes
Known limitation (pre-existing):
`append()` is not transactional with respect to `self`: if `group.concatenate()` raises mid-loop, `self` may be left partially updated. This is a pre-existing architectural issue unrelated to this fix and would require cloning all storage groups before mutation. The `try/finally` added here only guarantees `other` is restored.
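One possible direction for that limitation is to stage the mutation on copies and commit only on success. The sketch below assumes a dict-of-lists group layout, which is illustrative rather than the ALCHEMI storage schema:

```python
# Hypothetical sketch of a transactional append: stage changes on deep
# copies of every storage group and commit only if all succeed. The
# dict-of-lists group layout here is illustrative, not the ALCHEMI schema.
import copy

def transactional_append(self_groups, other_groups):
    staged = copy.deepcopy(self_groups)
    for name, other_group in other_groups.items():
        # Any failure here raises before self_groups is touched.
        staged[name] = staged.get(name, []) + list(other_group)
    self_groups.clear()
    self_groups.update(staged)  # commit: all groups at once, or none at all
```

The trade-off is the cloning cost on every append, which is likely why the current code mutates in place.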