macOS MPS support for torchmd-net? #242

Open
dondavidsb opened this issue Dec 4, 2023 · 7 comments

@dondavidsb

Hi,

I'm new to the PyTorch ecosystem, and I wonder whether it is possible to use Metal Performance Shaders (MPS) for accelerated PyTorch training on Mac with torchmd-net.

Thank you in advance.
David

@RaulPPelaez
Collaborator

Hi,
We do not have a macOS build available on conda-forge yet. I tried to build one, without success (I have neither a macOS machine to debug on nor experience with the OS): see conda-forge/torchmd-net-feedstock#1. It actually compiles, but a test fails and I do not know how to debug it.

AFAIK @sef43 was trying to build torchmd-net for macOS; maybe he can share his insights :P

@dondavidsb
Author

Okay, thank you for your quick reply.

@sef43
Collaborator

sef43 commented Dec 5, 2023

Yes, there are two parts to this.

  1. The package builds on macOS. I will share the instructions soon.
  2. MPS acceleration is not currently supported, because not all PyTorch operations are supported by MPS yet. I can do a quick investigation to find out what would need to be changed or worked around.

@sef43
Collaborator

sef43 commented Dec 5, 2023

You can install torchmd-net on macOS by following the Linux build-from-source instructions.
pip install -e . will work, provided you have the dependencies from conda installed.

You will need to set num_workers: 0 in the config YAML files; otherwise the error seen in the feedstock will occur: conda-forge/torchmd-net-feedstock#1
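
If you have several config files to patch, the num_workers change can be scripted. The following is only a sketch: the helper name set_num_workers_zero and the example config values are made up for illustration, and it assumes num_workers is a top-level key written on its own line. It edits the text directly rather than parsing YAML, so any comment on that line is lost.

```python
import re

# Matches a top-level "num_workers: <anything>" line (assumption: the key is
# not nested and not indented).
_NUM_WORKERS = re.compile(r"^num_workers[ \t]*:.*$", flags=re.MULTILINE)


def set_num_workers_zero(yaml_text: str) -> str:
    """Return yaml_text with the top-level num_workers key forced to 0."""
    if _NUM_WORKERS.search(yaml_text):
        return _NUM_WORKERS.sub("num_workers: 0", yaml_text)
    # Key absent: append it, keeping exactly one trailing newline.
    return yaml_text.rstrip("\n") + "\nnum_workers: 0\n"


# Hypothetical config fragment, not copied from any real torchmd-net file.
example = "batch_size: 32\nnum_workers: 6\nlr: 0.0004\n"
print(set_num_workers_zero(example))
```

For a single file it is simpler to edit the line by hand.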

You will also need to run on the CPU, because MPS does not work yet:

  • First problem:
    [W MPSFallback.mm:13] Warning: The operator 'torchmdnet_extensions::get_neighbor_pairs' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
    This operator can fall back to the CPU, so it can still run.

  • Second problem:
    /AppleInternal/Library/BuildRoots/f0468ab4-4115-11ed-8edc-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayGatherND.mm:234: failed assertion Rank of updates array (1) must be greater than or equal to inner-most dimension of indices array (2306)
    zsh: abort PYTORCH_ENABLE_MPS_FALLBACK=1 torchmd-train --conf ET-QM9.yaml
    This is probably the open issue "MPS internal error in torch.gather when last dimension is a singleton dimension" (pytorch/pytorch#94765).
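
Since MPS does not work yet, any script of your own that drives the model should pick the CPU explicitly. Below is a minimal sketch. choose_device is a hypothetical helper, not part of torchmd-net, and how torchmd-train itself selects its device is not covered in this thread. The only PyTorch call it relies on is torch.backends.mps.is_available().

```python
def choose_device(allow_mps: bool = False) -> str:
    """Return the device string to use; defaults to "cpu".

    MPS is only returned when the caller opts in *and* PyTorch reports the
    backend as available. Given the failures described above, leave
    allow_mps off for now.
    """
    try:
        import torch
    except ImportError:
        # No PyTorch installed: nothing to accelerate, report CPU.
        return "cpu"
    # Older PyTorch releases have no torch.backends.mps module at all.
    mps = getattr(torch.backends, "mps", None)
    if allow_mps and mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(choose_device())
```

Keeping MPS behind an explicit opt-in means the same script can be switched over later without other changes, once the missing operations are implemented upstream.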

@dondavidsb
Author

Okay, I will keep an eye on how this evolves. Thank you very much!

@sef43
Collaborator

sef43 commented Dec 5, 2023

There are other unsupported PyTorch operations in the get_neighbor_pairs code, e.g.:
Warning: The operator 'aten::tril_indices' is not currently supported on the MPS backend

This is not something we can easily fix on our end. We will have to wait for PyTorch to implement the operations we need: pytorch/pytorch#77764
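
To see which operators a given run is missing, you can collect the fallback warnings from a saved log. This is a rough sketch: the function name unsupported_mps_ops is made up, and the pattern assumes the warning wording stays exactly as quoted in this thread.

```python
import re

# Wording taken from the warnings quoted in this thread; it may differ in
# other PyTorch versions.
_UNSUPPORTED = re.compile(
    r"The operator '([^']+)' is not currently supported on the MPS backend"
)


def unsupported_mps_ops(log_text: str) -> list:
    """Return the sorted, de-duplicated operator names found in log_text."""
    return sorted(set(_UNSUPPORTED.findall(log_text)))


log = (
    "[W MPSFallback.mm:13] Warning: The operator "
    "'torchmdnet_extensions::get_neighbor_pairs' is not currently supported "
    "on the MPS backend and will fall back to run on the CPU.\n"
    "Warning: The operator 'aten::tril_indices' is not currently supported "
    "on the MPS backend\n"
)
for op in unsupported_mps_ops(log):
    print(op)
```

The resulting names can then be checked against the tracking issue pytorch/pytorch#77764.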

@RaulPPelaez
Collaborator

RaulPPelaez commented Dec 5, 2023

Just for completeness: once all required operations are available in the MPS backend, the neighbor extension can be made MPS-compatible by adding a new TORCH_LIBRARY_IMPL registration that simply copy-pastes the CPU one:

// CPU registration
TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) {
  m.impl("get_neighbor_pairs",
         [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
            const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
            const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
            bool include_transpose) {
           return forward(positions, batch, box_vectors, use_periodic, cutoff_lower,
                          cutoff_upper, max_num_pairs, loop, include_transpose);
         });
}

// Same body, registered under the MPS dispatch key
TORCH_LIBRARY_IMPL(torchmdnet_extensions, MPS, m) {
  m.impl("get_neighbor_pairs",
         [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
            const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
            const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
            bool include_transpose) {
           return forward(positions, batch, box_vectors, use_periodic, cutoff_lower,
                          cutoff_upper, max_num_pairs, loop, include_transpose);
         });
}
