
Joss paper of jaxDecomp #20

Merged
23 commits merged into main from joss-paper on Jul 18, 2024
Conversation

@ASKabalan (Collaborator)

Adding a draft of JOSS paper

@ASKabalan ASKabalan force-pushed the joss-paper branch 2 times, most recently from aa5403d to d053753 Compare July 4, 2024 17:41
@EiffL (Member) commented Jul 7, 2024

Lol, @ASKabalan ^^ so many force pushes, can we forbid the force pushes from now on?

@EiffL (Member) commented Jul 9, 2024

Thanks @ASKabalan for the draft, it's a very good start. I have some high-level comments that I will add here, and I may also leave some specific comments on the markdown file.

My main overarching comment is that this is not a jaxpm paper, it's a jaxDecomp paper. PM simulations are just one potential example of a real-world application, but not the only raison d'être of the library.

  • Motivation: Currently you open the abstract with cosmological simulations, but that is not the right level for this paper. This is a software paper for a distribution library. So I think the story should be different. You can for instance take a look at how mpi4jax structured their abstract: https://joss.theoj.org/papers/10.21105/joss.03419

I think our story here in the abstract could be the following:

  1. JAX has been a powerful tool for scientific computations, not just machine learning (cite e.g. jax-cosmo ;-) or jax-md)
  2. Until very recently general distributed computing (multinode) was not easy in JAX, which hinders the applicability of the framework for HPC tasks.
  3. Some solutions have been proposed in the past for SPMD, in particular mpi4jax. However, mpi4jax has limitations: it is not compatible with the JAX array distribution logic, and it is limited to "small" messages of less than 2 GB.
  4. Over the last year, a huge amount of progress has happened in JAX regarding its native support for SPMD through the unified jax.Array API and the merge of jit and pjit.
  5. However, not all native JAX operations have a specialized distribution strategy, so pjitting a program can currently lead to more communication than necessary for some operations. In particular, the key operation we are concerned with is the 3D FFT.
  6. To alleviate these limitations, we introduce jaxDecomp, a JAX wrapper for the cuDecomp domain decomposition library, which provides JAX primitives with highly efficient CUDA implementations for key operations needed for HPC simulation tasks, namely 3D FFTs and halo exchanges.
  7. Being implemented as JAX primitives, jaxDecomp builds directly on top of the distributed Array strategy in JAX and is compatible with JAX transformations such as jax.grad and jax.jit.
  8. Through cuDecomp, jaxDecomp provides low-level NCCL, CUDA-aware MPI, and NVSHMEM backends for distributed array transpose operations.
  • Statement of Need: Here the main point is that we should motivate why native JAX distribution might not be enough. We can say that for numerical simulations on HPC systems we want to allow for peak performance, and that performance is bottlenecked by inter-GPU communications. While it is technically possible to write, for instance, a distributed FFT in native JAX, our aim here is to go for unbeatable performance by using a highly optimized and dedicated CUDA library as the backend.
    • To do things properly, we might want to compare in the benchmark the performance of a simple distributed FFT op written in native JAX. By that I mean something like this:
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Create an array
x = jax.random.normal(jax.random.key(0), (32, 32, 32))

# Distribute the array over a 2x2x1 device mesh (last axis kept local)
sharding = PositionalSharding(mesh_utils.create_device_mesh((2, 2, 1)))
x = jax.device_put(x, sharding)

# Perform a 1D FFT along the last dimension, then transpose the array
x = jnp.fft.fft(x).transpose(2, 0, 1)  # [z', x, y]
x = jnp.fft.fft(x).transpose(2, 0, 1)  # [y', z', x]
x = jnp.fft.fft(x)                     # [y', z', x']

If we have such a comparison in the benchmark, we can refer to it here as a statement of need.

Then at the end of the statement, we can mention a real world application, and that's where we can talk about PM simulations for cosmology. We can in particular mention FlowPM (distributed but in TF, so useless) and pmwd (not distributed and so limited to 512 volumes).

  • Implementation:
    1. I think here we want to start by explaining how we build the wrapper around the cuDecomp operations. So, mention that we use the custom_op tool, and maybe mention something about the strategy you have built to preserve the state of cuDecomp between kernel calls.
    2. We want to explain the concept of domain decomposition, explain the pencils and slabs decompositions supported by the library, and explain how one would build a distributed domain in JAX.
    3. Once we have explained the above, we can go more into a description of the key operations, 3D FFTs and halo exchange.

You can add a couple of lines of code to illustrate the API for points 2 and 3 above.
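For point 2, a minimal sketch of building the distributed domain with the standard JAX sharding machinery could look like the following (the mesh shape and array size are placeholders, and it assumes 4 devices):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2x2 process grid: a pencil decomposition of the 3D field along its first two axes
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))  # the z axis stays local to each device

# Each device holds a (128, 128, 256) pencil of the global (256, 256, 256) array
field = jax.device_put(jnp.zeros((256, 256, 256)), sharding)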

  • Example of Application: Here you can talk about LPT simulations and link to the example script. You can give a little bit of context for what these simulations are, then illustrate with some actual code how one would compute the gravitational potential from a density field using the jaxdecomp library. Something like:
def potential(delta):
  delta_k = pfft3d(delta)                 # distributed forward 3D FFT of the density field
  kvec = ...                              # Fourier-space wavevectors along each axis
  laplace_kernel = 1 / kk                 # inverse Laplacian kernel, with kk the squared wavenumber
  potential_k = delta_k * laplace_kernel  # solve the Poisson equation in Fourier space
  return ipfft3d(potential_k)             # distributed inverse 3D FFT
  • Benchmark: Here I think we could change the legend of the plots. For "JAX", what I think you mean is "single-GPU jnp.fft.fftn"? For the other lines, it's not clear how many GPUs you have used. And as mentioned earlier, we should include a comparison to a native JAX distributed FFT built from transpose and 1D FFT operations... And cross our fingers that this doesn't beat us ^^
    We should also include a point where we reach several nodes, to see how the performance scales between nodes. It will be a function of the interconnect.
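As a side note on methodology, a minimal timing harness for these runs could be as simple as the sketch below (function and array names are placeholders; each call is synchronized with block_until_ready so we time actual execution rather than dispatch):

import time
import jax

def benchmark(fn, x, n_iter=10):
    fn(x).block_until_ready()               # compile and warm up once
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn(x).block_until_ready()           # synchronize each iteration
    return (time.perf_counter() - t0) / n_iter

# e.g. compare benchmark(pfft3d, field) against benchmark(jnp.fft.fftn, field) on one GPU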


## Distributed Halo Exchange

In a particle mesh simulation, we use the 3D FFT to estimate the force field acting on the particles. The force field is then interpolated to the particles, and the particles are moved accordingly. The particles that are close to the boundary of the local domain need to be updated using data from the neighboring domains. This is done using a halo exchange operation: we pad each slice of the simulation, then perform a halo exchange to update the particles that are close to the boundary of the local domain.
Review comment from a Member:
We shouldn't motivate the halo exchange from the PM simulation; halo exchanges are very common operations in distributed computing: https://wgropp.cs.illinois.edu/courses/cs598-s15/lectures/lecture25.pdf (first result on Google).

Review comment from a Member:

So here I think we just want to explain that cuDecomp allows for the exchange of border regions, which is a pattern necessary to handle border crossing in many types of simulations.
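For illustration only (this is not the jaxDecomp API), the border-exchange pattern itself can be sketched in native JAX with shard_map and lax.ppermute, assuming a periodic 1D slab decomposition in which each local slab already carries pre-allocated halo regions:

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P
from functools import partial

# One slab per device along the first axis of the field
n_slabs = jax.device_count()
mesh = Mesh(mesh_utils.create_device_mesh((n_slabs,)), axis_names=('x',))
halo = 2  # width of the exchanged border region

@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def halo_exchange(slab):
    # slab is the local halo-padded slice: [left halo | interior | right halo]
    to_right = [(i, (i + 1) % n_slabs) for i in range(n_slabs)]
    to_left = [(i, (i - 1) % n_slabs) for i in range(n_slabs)]
    # Each slab receives its left halo from the left neighbour's right border,
    # and its right halo from the right neighbour's left border (periodic boundaries)
    left_halo = lax.ppermute(slab[-2 * halo:-halo], 'x', to_right)
    right_halo = lax.ppermute(slab[halo:2 * halo], 'x', to_left)
    return jnp.concatenate([left_halo, slab[halo:-halo], right_halo], axis=0)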

@EiffL (Member) commented Jul 9, 2024

Taking a careful look at this figure, I think there might be a problem on the right:
[image]
Not the right colors...

But maybe we also don't want to include the reduction, because it's not clear what that means... In this figure, it looks like you are replacing the border region with the one from the neighboring slice. In general, what one does with the halo region depends on the simulation one runs. So maybe there is no need to include it in this figure.

@ASKabalan (Collaborator, Author)
What's wrong with the colors?

@EiffL (Member) commented Jul 9, 2024

lol, I tried to better highlight one of the issues:
[image]

@ASKabalan (Collaborator, Author)
Oh ..
I am going to remove the reduction step anyway.. it is confusing

@EiffL (Member) commented Jul 9, 2024

Note: I found this previous implementation of a 3D distributed FFT that I had made using xmap:
https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/dev/test_pfft.py

@partial(xmap,
         in_axes={  0: 'x', 1: 'y' },
         out_axes=['x', 'y', ...],
         axis_resources={  'x': 'nx',  'y': 'ny' })
@jax.jit
def pfft3d(mesh):
    # [x, y, z]
    mesh = jnp.fft.fft(mesh)  # Transform on z
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now x is exposed, [z,y,x]
    mesh = jnp.fft.fft(mesh)  # Transform on x
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now y is exposed, [z,x,y]
    mesh = jnp.fft.fft(mesh)  # Transform on y
    # [z, x, y]
    return mesh

@partial(xmap,
         in_axes={  0: 'x',  1: 'y' },
         out_axes=['x', 'y', ...],
         axis_resources={  'x': 'nx',  'y': 'ny' })
@jax.jit
def pifft3d(mesh):
    # [z, x, y]
    mesh = jnp.fft.ifft(mesh)  # Transform on y
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now x is exposed, [z,y,x]
    mesh = jnp.fft.ifft(mesh)  # Transform on x
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now z is exposed, [x,y,z]
    mesh = jnp.fft.ifft(mesh)  # Transform on z
    # [x, y, z]
    return mesh

Something like this, but using shard_map, is probably what we want to benchmark jaxDecomp against.
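For reference, a rough shard_map equivalent might look like the sketch below. Treat it as a starting point for the benchmark rather than a verified implementation (the mesh shape and array size are placeholders, and it assumes 4 devices and dimensions divisible by the mesh sizes):

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from functools import partial

# 2x2 process mesh: pencil decomposition over the first two axes of the field
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=P('x', 'y', None),
         out_specs=P('y', None, 'x'))
def pfft3d(field):
    # Local block: (Nx/Px, Ny/Py, Nz), with z fully local
    field = jnp.fft.fft(field, axis=2)                                   # FFT along z
    # Redistribute: z becomes sharded over 'x', x becomes fully local
    field = lax.all_to_all(field, 'x', split_axis=2, concat_axis=0, tiled=True)
    field = jnp.fft.fft(field, axis=0)                                   # FFT along x
    # Redistribute: x becomes sharded over 'y', y becomes fully local
    field = lax.all_to_all(field, 'y', split_axis=0, concat_axis=1, tiled=True)
    field = jnp.fft.fft(field, axis=1)                                   # FFT along y
    return field

# Usage: a 3D field sharded over the (x, y) mesh axes
field = jax.device_put(jnp.zeros((256, 256, 256), jnp.complex64),
                       NamedSharding(mesh, P('x', 'y', None)))
field_k = pfft3d(field)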

@EiffL (Member) commented Jul 14, 2024

Could you push the benchmark scripts @ASKabalan when you are back? Curious to see if we can gain a bit of performance.

@ASKabalan (Collaborator, Author) commented Jul 14, 2024

@EiffL
The benchmarks are now on GitHub:
JAX: https://github.com/ASKabalan/jaxdecomp-benchmarks/blob/main/scripts/jaxfft.py
jaxDecomp: https://github.com/ASKabalan/jaxdecomp-benchmarks/blob/main/scripts/pfft3d.py
mpi4jax: https://github.com/ASKabalan/jaxdecomp-benchmarks/blob/main/scripts/mpi4jaxfft.py

I am trying to make MPI4JAX work
Do you want me to put them in the main repo?

@ASKabalan merged commit 067ee89 into main on Jul 18, 2024 (2 checks passed).