Skip to content

Commit e26dd40

Browse files
authored
CUDA graph compatible neighbor backward (#194)
* Test backwards CUDA graph compatibility * Preliminary work on manual double backward * Manually implement MaskedIndexAdd for the backward * Use a dummy atom to mimic masked_index_add. Graph compatible but might be inefficient in some cases * Remove unnecessary test
1 parent 20c5de6 commit e26dd40

File tree

2 files changed

+92
-11
lines changed

2 files changed

+92
-11
lines changed

tests/test_neighbors.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def test_jit_script_compatible(
545545
@pytest.mark.parametrize("include_transpose", [True, False])
546546
@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
547547
@pytest.mark.parametrize("dtype", [torch.float32])
548-
def test_cuda_graph_compatible(
548+
def test_cuda_graph_compatible_forward(
549549
device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype
550550
):
551551
if device == "cuda" and not torch.cuda.is_available():
@@ -563,7 +563,7 @@ def test_cuda_graph_compatible(
563563
# Ensure there is at least one pair
564564
pos[0, :] = torch.zeros(3)
565565
pos[1, :] = torch.zeros(3)
566-
pos.requires_grad = True
566+
pos.requires_grad_(True)
567567
if box_type is None:
568568
box = None
569569
else:
@@ -618,3 +618,76 @@ def test_cuda_graph_compatible(
618618
assert np.allclose(neighbors, ref_neighbors)
619619
assert np.allclose(distances, ref_distances)
620620
assert np.allclose(distance_vecs, ref_distance_vecs)
621+
622+
@pytest.mark.parametrize("device", ["cuda"])
623+
@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
624+
@pytest.mark.parametrize("n_batches", [1, 128])
625+
@pytest.mark.parametrize("cutoff", [1.0])
626+
@pytest.mark.parametrize("loop", [True, False])
627+
@pytest.mark.parametrize("include_transpose", [True, False])
628+
@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
629+
@pytest.mark.parametrize("dtype", [torch.float32])
630+
def test_cuda_graph_compatible_backward(
631+
device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype
632+
):
633+
if device == "cuda" and not torch.cuda.is_available():
634+
pytest.skip("CUDA not available")
635+
if box_type == "triclinic" and strategy == "cell":
636+
pytest.skip("Triclinic only supported for brute force")
637+
torch.manual_seed(4321)
638+
n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,))
639+
batch = torch.repeat_interleave(
640+
torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
641+
).to(device)
642+
cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
643+
lbox = 10.0
644+
pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox
645+
# Ensure there is at least one pair
646+
pos[0, :] = torch.zeros(3)
647+
pos[1, :] = torch.zeros(3)
648+
pos.requires_grad_(True)
649+
if box_type is None:
650+
box = None
651+
else:
652+
box = (
653+
torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]])
654+
.to(pos.dtype)
655+
.to(device)
656+
)
657+
ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
658+
pos, batch, loop, include_transpose, cutoff, box
659+
)
660+
max_num_pairs = ref_neighbors.shape[1]
661+
s = torch.cuda.Stream()
662+
s.wait_stream(torch.cuda.current_stream())
663+
with torch.cuda.stream(s):
664+
nl = OptimizedDistance(
665+
cutoff_lower=0.0,
666+
loop=loop,
667+
cutoff_upper=cutoff,
668+
max_num_pairs=max_num_pairs,
669+
strategy=strategy,
670+
box=box,
671+
return_vecs=True,
672+
include_transpose=include_transpose,
673+
check_errors=False,
674+
resize_to_fit=False,
675+
)
676+
batch.to(device)
677+
678+
graph = torch.cuda.CUDAGraph()
679+
# Warm up
680+
neighbors, distappnces, distance_vecs = nl(pos, batch)
681+
for _ in range(10):
682+
neighbors, distances, distance_vecs = nl(pos, batch)
683+
distances.sum().backward()
684+
pos.grad.data.zero_()
685+
torch.cuda.synchronize()
686+
687+
# Capture
688+
with torch.cuda.graph(graph):
689+
neighbors, distances, distance_vecs = nl(pos, batch)
690+
distances.sum().backward()
691+
pos.grad.data.zero_()
692+
graph.replay()
693+
torch.cuda.synchronize()

torchmdnet/neighbors/neighbors_cuda.cu

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ public:
3838
return {neighbors, deltas, distances, i_curr_pair};
3939
}
4040

41+
using Slice = torch::indexing::Slice;
42+
4143
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
4244
auto saved = ctx->get_saved_variables();
4345
auto edge_index = saved[0];
@@ -46,17 +48,23 @@ public:
4648
auto num_atoms = ctx->saved_data["num_atoms"].toInt();
4749
auto grad_edge_vec = grad_outputs[1];
4850
auto grad_edge_weight = grad_outputs[2];
49-
auto r0 = edge_weight.nonzero().squeeze(-1);
50-
auto grad_positions = torch::zeros({num_atoms, 3}, edge_vec.options());
51+
auto zero_mask = edge_weight == 0;
52+
auto zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_edge_vec);
5153
// We need to avoid dividing by 0. Otherwise Autograd fills the gradient with NaNs in the
5254
// case of a double backwards. This is why we index_select like this.
53-
auto grad_distances_ =
54-
(edge_vec.index_select(0, r0) / edge_weight.index_select(0, r0).unsqueeze(-1)) *
55-
grad_edge_weight.index_select(0, r0).unsqueeze(-1);
56-
auto edge_index_no_r0 = edge_index.index_select(1, r0);
57-
auto result = grad_edge_vec.index_select(0, r0) + grad_distances_;
58-
grad_positions.index_add_(0, edge_index_no_r0[0], result);
59-
grad_positions.index_add_(0, edge_index_no_r0[1], -result);
55+
auto grad_distances_ = edge_vec / edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1) *
56+
grad_edge_weight.masked_fill(zero_mask, 0).unsqueeze(-1);
57+
auto result = grad_edge_vec.masked_fill(zero_mask3, 0) + grad_distances_;
58+
// Given that there is no masked_index_add function, in order to make the operation
59+
// CUDA-graph compatible I need to transform masked indices into a dummy value (num_atoms)
60+
// and then exclude that value from the output.
61+
// TODO: replace this once masked_index_add or masked_scatter_add are available
62+
auto grad_positions_ = torch::zeros({num_atoms + 1, 3}, edge_vec.options());
63+
auto edge_index_ =
64+
edge_index.masked_fill(zero_mask.unsqueeze(0).expand_as(edge_index), num_atoms);
65+
grad_positions_.index_add_(0, edge_index_[0], result);
66+
grad_positions_.index_add_(0, edge_index_[1], -result);
67+
auto grad_positions = grad_positions_.index({Slice(0, num_atoms), Slice()});
6068
Tensor ignore;
6169
return {ignore, grad_positions, ignore, ignore, ignore, ignore,
6270
ignore, ignore, ignore, ignore, ignore};

0 commit comments

Comments
 (0)