Skip to content

Commit

Permalink
Merge pull request #19 from DifferentiableUniverseInitiative/u/ASKaba…
Browse files Browse the repository at this point in the history
…lan/clean-up

Clean up for JOSS paper
  • Loading branch information
EiffL authored Jul 8, 2024
2 parents b3978e4 + ac8f2ef commit d551967
Show file tree
Hide file tree
Showing 16 changed files with 1,374 additions and 573 deletions.
19 changes: 9 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# Change log


<!-- Template for documenting changes
## jaxdecomp 0.0.1

* Changes
* Change 1
* Breaking Changes
* Change 1
* Deprecations
* Some things that are getting deprecated
* Bugs
* Bug 1
-->
* New version compatible with JAX 0.4.30
* jaxDecomp now works in a multi-host environment
* Added custom partitioning for FFTs
* Added custom partitioning for halo exchange
* Added custom partitioning for slice_pad and slice_unpad
* Add example for multi-host FFTs in `examples/jaxdecomp_lpt.py`


## jaxdecomp 0.0.1
## jaxdecomp 0.0.1rc2
* Changes
* Added utility to run autotuning

Expand Down
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ cmake_minimum_required(VERSION 3.19...3.25)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
# Latest JAX v0.4.26 no longer supports cuda 11.8
set(NVHPC_CUDA_VERSION 12.2)
# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8
set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" )

# Build debug
# set(CMAKE_BUILD_TYPE Debug)
add_subdirectory(third_party/cuDecomp)
Expand All @@ -15,8 +17,8 @@ option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
option(CUDECOMP_BUILD_EXTRAS "Build benchmark, examples, and tests" OFF)

set(CUDECOMP_CUDA_CC_LIST "70;80" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")

# 70: Volta, 80: Ampere, 89: RTX 4060
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")

# Add pybind11 and cuDecomp subdirectories
add_subdirectory(pybind11)
Expand Down
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,12 @@ with mesh:
# Add halo regions to our array
padding_width = ((32,32),(32,32),(32,32)) # Has to a tuple of tuples
padded_array = jaxdecomp.slice_pad(recarray, padding_width , pdims)
# Perform a halo exchange + reduce
exchanged_reduced = jaxdecomp.halo_exchange(padded_array,
# Perform a halo exchange
exchanged_array = jaxdecomp.halo_exchange(padded_array,
halo_extents=(32,32,32),
halo_periods=(True,True,True),
reduce_halo=True)
halo_periods=(True,True,True))
# Remove the halo regions
recarray = jaxdecomp.slice_unpad(exchanged_reduced, padding_width, pdims)
recarray = jaxdecomp.slice_unpad(exchanged_array, padding_width, pdims)

# Gather the results (only if it fits on CPU memory)
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)
Expand Down
11 changes: 7 additions & 4 deletions include/halo.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ class haloDescriptor_t {
~haloDescriptor_t() = default;

bool operator==(const haloDescriptor_t& other) const {
return (double_precision == other.double_precision && halo_extents == other.halo_extents &&
halo_periods == other.halo_periods && axis == other.axis && config.gdims[0] == other.config.gdims[0] &&
config.gdims[1] == other.config.gdims[1] && config.gdims[2] == other.config.gdims[2] &&
config.pdims[0] == other.config.pdims[0] && config.pdims[1] == other.config.pdims[1]);
return (double_precision == other.double_precision && halo_extents[0] == other.halo_extents[0] &&
halo_extents[1] == other.halo_extents[1] && halo_extents[2] == other.halo_extents[2] &&
halo_periods[0] == other.halo_periods[0] && halo_periods[1] == other.halo_periods[1] &&
halo_periods[2] == other.halo_periods[2] && axis == other.axis &&
config.gdims[0] == other.config.gdims[0] && config.gdims[1] == other.config.gdims[1] &&
config.gdims[2] == other.config.gdims[2] && config.pdims[0] == other.config.pdims[0] &&
config.pdims[1] == other.config.pdims[1]);
}
};

Expand Down
Loading

0 comments on commit d551967

Please sign in to comment.