CUDA graph compatible neighbor backward (#194)
* Test backwards CUDA graph compatibility

* Preliminary work on manual double backward

* Manually implement MaskedIndexAdd for the backward

* Use a dummy atom to mimic masked_index_add. Graph compatible but might be inefficient in some cases (see the sketch after this list)

* Remove unnecessary test
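
The dummy-atom idea from the commit message can be illustrated with a small, self-contained PyTorch sketch (illustration only; `dest`, `index`, `values` and `mask` are made-up names, not identifiers from this repository): since PyTorch has no `masked_index_add`, masked rows are redirected to one extra scratch row so that every shape stays fixed, and the scratch row is sliced off afterwards.

```python
import torch

def masked_index_add(dest, index, values, mask):
    # Emulate a masked index_add_ without nonzero()/boolean indexing, so the
    # operation keeps static shapes and stays CUDA-graph friendly.
    num_rows = dest.shape[0]
    # Append one scratch row that absorbs all masked contributions.
    padded = torch.cat([dest, torch.zeros_like(dest[:1])], dim=0)
    # Redirect masked entries to the scratch row; shapes never change.
    safe_index = index.masked_fill(mask, num_rows)
    padded.index_add_(0, safe_index, values)
    # Drop the scratch row again.
    return padded[:num_rows]

# Example: accumulate per-pair values onto atoms while ignoring padded pairs.
dest = torch.zeros(4, 3)
index = torch.tensor([0, 1, 3, 0])
values = torch.ones(4, 3)
mask = torch.tensor([False, False, True, True])  # the last two pairs are padding
print(masked_index_add(dest, index, values, mask))  # only rows 0 and 1 receive updates
```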
RaulPPelaez authored Jun 26, 2023
1 parent 20c5de6 commit e26dd40
Showing 2 changed files with 92 additions and 11 deletions.
77 changes: 75 additions & 2 deletions tests/test_neighbors.py
@@ -545,7 +545,7 @@ def test_jit_script_compatible(
 @pytest.mark.parametrize("include_transpose", [True, False])
 @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
 @pytest.mark.parametrize("dtype", [torch.float32])
-def test_cuda_graph_compatible(
+def test_cuda_graph_compatible_forward(
     device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype
 ):
     if device == "cuda" and not torch.cuda.is_available():
@@ -563,7 +563,7 @@ def test_cuda_graph_compatible(
     # Ensure there is at least one pair
     pos[0, :] = torch.zeros(3)
     pos[1, :] = torch.zeros(3)
-    pos.requires_grad = True
+    pos.requires_grad_(True)
     if box_type is None:
         box = None
     else:
@@ -618,3 +618,76 @@
     assert np.allclose(neighbors, ref_neighbors)
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
+
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("cutoff", [1.0])
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_cuda_graph_compatible_backward(
device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype
):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
if box_type == "triclinic" and strategy == "cell":
pytest.skip("Triclinic only supported for brute force")
torch.manual_seed(4321)
n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,))
batch = torch.repeat_interleave(
torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
).to(device)
cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
lbox = 10.0
pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox
# Ensure there is at least one pair
pos[0, :] = torch.zeros(3)
pos[1, :] = torch.zeros(3)
pos.requires_grad_(True)
if box_type is None:
box = None
else:
box = (
torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]])
.to(pos.dtype)
.to(device)
)
ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
pos, batch, loop, include_transpose, cutoff, box
)
max_num_pairs = ref_neighbors.shape[1]
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
nl = OptimizedDistance(
cutoff_lower=0.0,
loop=loop,
cutoff_upper=cutoff,
max_num_pairs=max_num_pairs,
strategy=strategy,
box=box,
return_vecs=True,
include_transpose=include_transpose,
check_errors=False,
resize_to_fit=False,
)
batch.to(device)

graph = torch.cuda.CUDAGraph()
# Warm up
neighbors, distappnces, distance_vecs = nl(pos, batch)
for _ in range(10):
neighbors, distances, distance_vecs = nl(pos, batch)
distances.sum().backward()
pos.grad.data.zero_()
torch.cuda.synchronize()

# Capture
with torch.cuda.graph(graph):
neighbors, distances, distance_vecs = nl(pos, batch)
distances.sum().backward()
pos.grad.data.zero_()
graph.replay()
torch.cuda.synchronize()
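
The new test follows the standard CUDA Graphs recipe: a few eager warm-up iterations on a side stream (required, so that allocations and autograd setup happen outside the capture), a single forward plus backward captured with `torch.cuda.graph`, and a final `replay()`. A generic sketch of that recipe follows; `model` and `static_input` are placeholder names standing in for `OptimizedDistance` and `pos`, not torchmd-net API.

```python
import torch

def capture_forward_backward(model, static_input, n_warmup=3):
    # Warm up eagerly on a side stream so allocations and autograd state are
    # created before capture begins.
    static_input.requires_grad_(True)
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmup):
            model(static_input).sum().backward()
            static_input.grad.zero_()
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward + backward; every tensor touched here becomes a
    # static buffer that each replay reuses.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_input)
        static_out.sum().backward()
    return graph, static_out

# To rerun the captured work on new data, refill the static input in place and
# replay; static_out and static_input.grad are overwritten by the replay:
#   with torch.no_grad():
#       static_input.copy_(new_data)
#   graph.replay()
```

Replays skip the per-kernel CPU launch overhead, which is the motivation for making the neighbor backward capturable in the first place.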
26 changes: 17 additions & 9 deletions torchmdnet/neighbors/neighbors_cuda.cu
@@ -38,6 +38,8 @@ public:
         return {neighbors, deltas, distances, i_curr_pair};
     }
 
+    using Slice = torch::indexing::Slice;
+
     static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
         auto saved = ctx->get_saved_variables();
         auto edge_index = saved[0];
@@ -46,17 +48,23 @@ public:
         auto num_atoms = ctx->saved_data["num_atoms"].toInt();
         auto grad_edge_vec = grad_outputs[1];
         auto grad_edge_weight = grad_outputs[2];
-        auto r0 = edge_weight.nonzero().squeeze(-1);
-        auto grad_positions = torch::zeros({num_atoms, 3}, edge_vec.options());
+        auto zero_mask = edge_weight == 0;
+        auto zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_edge_vec);
         // We need to avoid dividing by 0. Otherwise Autograd fills the gradient with NaNs in the
         // case of a double backwards. This is why we index_select like this.
-        auto grad_distances_ =
-            (edge_vec.index_select(0, r0) / edge_weight.index_select(0, r0).unsqueeze(-1)) *
-            grad_edge_weight.index_select(0, r0).unsqueeze(-1);
-        auto edge_index_no_r0 = edge_index.index_select(1, r0);
-        auto result = grad_edge_vec.index_select(0, r0) + grad_distances_;
-        grad_positions.index_add_(0, edge_index_no_r0[0], result);
-        grad_positions.index_add_(0, edge_index_no_r0[1], -result);
+        auto grad_distances_ = edge_vec / edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1) *
+                               grad_edge_weight.masked_fill(zero_mask, 0).unsqueeze(-1);
+        auto result = grad_edge_vec.masked_fill(zero_mask3, 0) + grad_distances_;
+        // Given that there is no masked_index_add function, in order to make the operation
+        // CUDA-graph compatible I need to transform masked indices into a dummy value (num_atoms)
+        // and then exclude that value from the output.
+        // TODO: replace this once masked_index_add or masked_scatter_add are available
+        auto grad_positions_ = torch::zeros({num_atoms + 1, 3}, edge_vec.options());
+        auto edge_index_ =
+            edge_index.masked_fill(zero_mask.unsqueeze(0).expand_as(edge_index), num_atoms);
+        grad_positions_.index_add_(0, edge_index_[0], result);
+        grad_positions_.index_add_(0, edge_index_[1], -result);
+        auto grad_positions = grad_positions_.index({Slice(0, num_atoms), Slice()});
         Tensor ignore;
         return {ignore, grad_positions, ignore, ignore, ignore, ignore,
                 ignore, ignore, ignore, ignore, ignore};
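
Two details of the new backward are worth spelling out. First, the old path compacted the pair list with `nonzero()`, whose output size depends on the data and therefore forces a CPU-GPU synchronization; synchronizing operations are not allowed while a CUDA graph is being captured, so the compaction had to go. Second, the denominator is sanitized with `masked_fill(zero_mask, 1)` before the division; masking the result afterwards would not be enough, because autograd would still differentiate through the division by zero and produce NaNs (here this surfaces on a double backward of the distances, since the code above is itself a backward). A toy PyTorch sketch of that second pitfall, using made-up names `x` and `w`, is below.

```python
import torch

w = torch.tensor([2.0, 0.0], requires_grad=True)  # plays the role of edge_weight
x = torch.tensor([3.0, 5.0], requires_grad=True)
mask = w == 0

def grad_x(div):
    # Mask the result of the division and differentiate, keeping the graph so
    # that a further (double) backward through the result would be possible.
    out = div.masked_fill(mask, 0.0).sum()
    (g,) = torch.autograd.grad(out, x, create_graph=True)
    return g

# Naive: divide first, mask afterwards. The masked entry still contributes
# 0 * (1 / 0) = NaN to the gradient.
print(grad_x(x / w))                          # tensor([0.5000, nan], ...)

# Safe: sanitize the denominator before dividing, as the kernel above does.
print(grad_x(x / w.masked_fill(mask, 1.0)))   # tensor([0.5000, 0.0000], ...)
```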
