
Optimization of the graph network #48

Open
5 of 9 tasks
raimis opened this issue Nov 23, 2021 · 14 comments

@raimis
Collaborator

raimis commented Nov 23, 2021

Optimization of the graph network (TorchMD_GN) with NNPOps (https://github.com/openmm/NNPOps).

In a special case, TorchMD_GN is equivalent to SchNet (#45 (comment)), which is already supported by NNPOps:

TorchMD_GN(rbf_type="gauss", trainable_rbf=False, activation="ssp", neighbor_embedding=False)

In general, TorchMD_GN needs the following:

TorchMD_GN(rbf_type="expnorm", trainable_rbf=True, activation="silu", neighbor_embedding=True)
  • Implement the exponentially-modified Gaussian in CFConv (rbf_type="expnorm")
  • Allow passing arbitrary RBF positions to CFConv (trainable_rbf=True)
  • Implement the SILU activation in CFConv (activation="silu")
  • Reuse CFConv to accelerate the neighbor embedding (neighbor_embedding=True)
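For reference, the exponentially-modified Gaussian ("expnorm") basis from the first task can be sketched as follows. This is a simplified illustration: it omits the cosine cutoff factor and the exact parameterization TorchMD_GN uses, and the function name is ours, not an NNPOps or torchmd-net API.

```python
import torch

def expnorm_rbf(dist, means, betas):
    # Exponential-normal RBF (PhysNet-style): a Gaussian applied to
    # exp(-r) instead of r, so short distances are resolved more finely.
    # `means` and `betas` are the centers and widths; with
    # trainable_rbf=True they would be learnable parameters.
    return torch.exp(-betas * (torch.exp(-dist.unsqueeze(-1)) - means) ** 2)

dist = torch.linspace(0.5, 5.0, steps=10)   # pairwise distances
means = torch.linspace(0.0, 1.0, steps=16)  # centers in exp(-r) space
betas = torch.full((16,), 4.0)              # widths
features = expnorm_rbf(dist, means, betas)  # shape (10, 16)
```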
@raimis raimis self-assigned this Nov 23, 2021
@raimis
Collaborator Author

raimis commented Nov 24, 2021

Regarding the interface, it should look and work like this:

# Create or load a model in any way
model = TorchMD_GN()

# Optional: train or do what ever you want with the model

# Optimize the model
from torchmdnet.optimize import optimize
optimized_model = optimize(model, some_optimization_options)

# Do the inference with the model
results = optimized_model.forward(z, pos, batch)

# Optional: convert the model into TorchScript and save for external use (e.g. OpenMM-Torch)
torch.jit.script(optimized_model).save('model.pt')

It is similar to what is being implemented for the TorchANI optimization (https://github.com/raimis/NNPOps/blob/opt_ani/README.md#example).
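One way such an `optimize` function could work is module substitution: walk the model and swap registered layers for accelerated drop-in equivalents. This is a minimal sketch under that assumption; the registry, the `from_module` hook, and the stand-in classes are all hypothetical, not the actual torchmdnet.optimize implementation.

```python
import copy
import torch
from torch import nn

# Hypothetical registry mapping generic module classes to accelerated
# equivalents (e.g. a PyTorch Geometric CFConv -> an NNPOps CFConv).
ACCELERATED = {}

def optimize(model: nn.Module) -> nn.Module:
    """Return a copy of `model` with registered submodules swapped for
    accelerated equivalents; forward() keeps the same signature."""
    model = copy.deepcopy(model)
    for name, child in model.named_children():
        fast_cls = ACCELERATED.get(type(child))
        if fast_cls is not None:
            setattr(model, name, fast_cls.from_module(child))
        else:
            setattr(model, name, optimize(child))  # recurse into containers
    return model

# Demonstration with stand-in classes (not real TorchMD_GN layers):
class SlowLinear(nn.Linear):
    pass

class FastLinear(nn.Linear):
    @classmethod
    def from_module(cls, m):
        new = cls(m.in_features, m.out_features)
        new.load_state_dict(m.state_dict())  # keep the trained weights
        return new

ACCELERATED[SlowLinear] = FastLinear
net = nn.Sequential(SlowLinear(4, 4))
optimized_net = optimize(net)
```

Because the original model is deep-copied, training and saving the unoptimized model remain unaffected, which matches the workflow above.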

@PhilippThoelke @stefdoerr @giadefa any comments?

@raimis
Collaborator Author

raimis commented Dec 9, 2021

At the moment, it seems all the PyTorch Geometric packages are broken (pyg-team/pytorch_geometric#3660).

@raimis
Collaborator Author

raimis commented Feb 22, 2022

@peastman I have just finished integrating NNPOps (#50). The performance (https://github.com/torchmd/torchmd-net/blob/main/benchmarks/graph_network.ipynb) is only 2-3 times better for small molecules (10-100 atoms), and there is no significant improvement for larger ones.

I'll try to profile to get better insight. At some point, we should discuss whether we can make any further improvements.

cc: @giadefa

@peastman
Collaborator

It would be useful to separate out all the different optimizations in NNPOps. Can you identify the effect of each one separately?

Back when we first started designing it, we discussed requirements and decided it would be optimized for molecules of about 100 atoms. The code is all designed around that assumption. If we want good performance on much larger molecules, it would need to be written differently. For example, it uses a O(n^2) algorithm to build the neighbor list, which is very fast for small molecules and very slow for large ones.
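The O(n^2) construction described above amounts to computing one full distance matrix and thresholding it. A sketch under that assumption (the function name is illustrative, not the NNPOps API):

```python
import torch

def brute_force_neighbors(pos, cutoff):
    # Compare every atom pair at once via a single (n, n) distance matrix.
    # Very fast for ~100 atoms, but time and memory grow as O(n^2),
    # which is why large systems need a different algorithm.
    dist = torch.cdist(pos, pos)
    mask = (dist < cutoff) & ~torch.eye(len(pos), dtype=torch.bool)
    return mask.nonzero().T  # (2, num_pairs) edge index

pos = torch.tensor([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [9.0, 0.0, 0.0]])
edges = brute_force_neighbors(pos, cutoff=2.0)  # only atoms 0 and 1 are close
```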

@giadefa
Contributor

giadefa commented Feb 22, 2022 via email

@peastman
Collaborator

That would definitely need code changes to be efficient. You want it to know it only needs to check each atom against the other atoms in its own copy, not all the other copies. Spreading the copies out through space is also inaccurate. The further an atom is from the origin, the less precisely its position can be specified.

@giadefa
Contributor

giadefa commented Feb 22, 2022 via email

@peastman
Collaborator

It's possible. Can you open an issue on the NNPOps repository describing exactly how you would want it to work?

@giadefa
Contributor

giadefa commented Feb 23, 2022

@raimis can you make an issue there, as you probably know the details of what you need in NNPOps.

@raimis
Collaborator Author

raimis commented Feb 23, 2022

Just before going into NNPOps, I checked how much CUDA Graphs (https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) can help.

CUDA Graphs don't work with TorchMD_GN due to rusty1s/pytorch_cluster#123. To circumvent this, I have implemented a fake neighbor search (#60), which assumes that all the atoms are neighbors, a.k.a. brute force.

The results (https://github.com/raimis/torchmd-net/blob/poc_cuda_graph/benchmarks/graph_network.ipynb) are promising:
[benchmark plot]

  • For alanine dipeptide (ALA2, 22 atoms) and testosterone (TST, 49 atoms), the brute force approach with CUDA Graphs beat everything else.
  • For chignolin (CLN, 166 atoms), the brute force approach is no longer the best, and for larger systems it runs out of memory.
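The key property of the fake neighbor search is that the all-pairs edge index depends only on the atom count, so tensor shapes stay static across steps, which is what CUDA Graph capture requires. An illustrative sketch (not the #60 implementation):

```python
import torch

def all_pairs_edges(num_atoms, device="cpu"):
    # Declare every atom pair a neighbor. The result depends only on the
    # atom count, so it can be built once and the model graph captured and
    # replayed; the price is O(n^2) edges, which exhausts GPU memory for
    # large systems.
    idx = torch.arange(num_atoms, device=device)
    row = idx.repeat_interleave(num_atoms)
    col = idx.repeat(num_atoms)
    mask = row != col  # drop self-interactions
    return torch.stack([row[mask], col[mask]])

edges = all_pairs_edges(4)  # 4 atoms -> 4 * 3 directed pairs
```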

Ping: @giadefa @peastman @claudi

@giadefa
Contributor

giadefa commented Feb 23, 2022 via email

@raimis
Collaborator Author

raimis commented Feb 23, 2022

The current implementation doesn't support batching, but it could be implemented.

@peastman
Collaborator

That's interesting. It tells us that for the smaller molecules, the computation time is just dominated by kernel launch overhead.
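Launch overhead is exactly what CUDA Graphs address: record the kernel sequence once, then replay it with a single launch. A hedged sketch of the standard inference-time capture pattern from the PyTorch CUDA Graphs documentation, with an eager fallback when no GPU (or no CUDA input) is available:

```python
import torch

def capture_forward(model, example_inputs):
    """Capture one forward pass into a CUDA Graph and return a callable
    that replays it; falls back to eager execution on CPU."""
    if not torch.cuda.is_available() or not example_inputs[0].is_cuda:
        return lambda *xs: model(*xs)  # eager fallback for CPU-only runs
    static_in = [x.clone() for x in example_inputs]
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream), torch.no_grad():
        for _ in range(3):  # warm-up launches on a side stream
            model(*static_in)
    torch.cuda.current_stream().wait_stream(stream)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph), torch.no_grad():
        static_out = model(*static_in)

    def replay(*xs):
        for dst, src in zip(static_in, xs):
            dst.copy_(src)  # refresh the captured input buffers in place
        graph.replay()      # one launch replays the whole kernel sequence
        return static_out
    return replay
```

Capture requires static shapes, which is why the all-pairs neighbor trick above was needed for TorchMD_GN.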

@raimis
Collaborator Author

raimis commented Apr 25, 2022

Optimization: round 2

I have written optimized kernels for the neighbor search (#61) and message passing (#69). The kernels are drop-in replacements for the generic kernels from PyTorch Geometric and include several optimizations.

Speed:

  • kernels use just the new kernels

  • kernels+graphs use the new kernels and CUDA Graphs

  • Other benchmarks as in the previous plot (Optimization of the graph network #48 (comment))
    [benchmark plot]

  • There is a significant speed-up for the small molecules, as it removes even more overhead.

  • For large molecules, the speed is comparable to @peastman's kernels, as the time is dominated by the computation itself.

  • The new kernels fail with STMV, but not due to lack of memory. I still need to debug the cause.

Full details in the notebook: https://github.com/raimis/torchmd-net/blob/poc_cuda_graph_2/benchmarks/graph_network.ipynb
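For context, the generic PyTorch Geometric-style message passing that the fused kernels replace amounts to a gather, an elementwise multiply, and a scatter-add, each a separate kernel launch with intermediate tensors. An illustrative sketch of that unfused baseline (names are ours, not the #69 API):

```python
import torch

def cfconv_message_pass(x, edge_index, edge_weight):
    # Generic (unfused) message passing: gather neighbor features, weight
    # them with the filter values, and scatter-add back per atom. The
    # optimized kernels fuse these steps into a single CUDA kernel,
    # avoiding the intermediate `messages` tensor and extra launches.
    row, col = edge_index
    messages = x[col] * edge_weight.unsqueeze(-1)
    out = torch.zeros_like(x)
    out.index_add_(0, row, messages)  # aggregate messages per target atom
    return out

x = torch.ones(3, 2)                         # per-atom features
edge_index = torch.tensor([[0, 0], [1, 2]])  # edges 1->0 and 2->0
edge_weight = torch.tensor([1.0, 2.0])       # filter values per edge
out = cfconv_message_pass(x, edge_index, edge_weight)
```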
