diff --git a/src/mpi.f90 b/src/mpi.f90 index c05c4ce..15a4b57 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -8,6 +8,8 @@ module mpi integer, parameter :: MPI_DOUBLE_PRECISION = -10004 integer, parameter :: MPI_REAL4 = -10013 integer, parameter :: MPI_REAL8 = -10014 + integer, parameter :: MPI_CHARACTER = -10003 + integer, parameter :: MPI_LOGICAL = -10005 integer, parameter :: MPI_COMM_TYPE_SHARED = 1 integer, parameter :: MPI_PROC_NULL = -1 @@ -75,6 +77,7 @@ module mpi interface MPI_Gatherv module procedure MPI_Gatherv_int module procedure MPI_Gatherv_real + module procedure MPI_Gatherv_character end interface MPI_Gatherv interface MPI_Wtime @@ -170,7 +173,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info end function handle_mpi_info_f2c integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) result(c_datatype) - use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int + use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int, c_mpi_logical, c_mpi_character integer, intent(in) :: datatype_f if (datatype_f == MPI_REAL4) then c_datatype = c_mpi_float @@ -178,6 +181,10 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) resul c_datatype = c_mpi_double else if (datatype_f == MPI_INTEGER) then c_datatype = c_mpi_int + else if (datatype_f == MPI_CHARACTER) then + c_datatype = c_mpi_character + else if (datatype_f == MPI_LOGICAL) then + c_datatype = c_mpi_logical end if end function @@ -852,6 +859,42 @@ subroutine MPI_Gatherv_real(sendbuf, sendcount, sendtype, recvbuf, recvcounts, & end if end subroutine MPI_Gatherv_real + subroutine MPI_Gatherv_character(sendbuf, sendcount, sendtype, recvbuf, recvcounts, & + displs, recvtype, root, comm, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_gatherv + character(len=*), intent(in), target :: sendbuf(*) + integer, intent(in) :: sendcount + integer, intent(in) :: sendtype + character(len=*), intent(out), target :: recvbuf(*) + integer, dimension(:), intent(in) :: recvcounts + integer, dimension(:), intent(in) :: displs + integer, intent(in) :: recvtype + integer, intent(in) :: root + integer, intent(in) :: comm + integer, optional, intent(out) :: ierror + integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm + type(c_ptr) :: c_sendbuf, c_recvbuf + integer(c_int) :: local_ierr + + c_sendbuf = c_loc(sendbuf) + c_recvbuf = c_loc(recvbuf) + c_sendtype = handle_mpi_datatype_f2c(sendtype) + c_recvtype = handle_mpi_datatype_f2c(recvtype) + c_comm = handle_mpi_comm_f2c(comm) + + ! Call C MPI_Gatherv + local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, & + c_recvbuf, recvcounts, displs, c_recvtype, & + root, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Gatherv failed with error code: ", local_ierr + end if + end subroutine MPI_Gatherv_character + subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror) use iso_c_binding, only: c_int, c_ptr use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f, c_mpi_statuses_ignore diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 88340be..e7c3751 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -18,6 +18,8 @@ module mpi_c_bindings integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character interface diff --git a/src/mpi_constants.c b/src/mpi_constants.c index 18532a5..61eef1f 100644 --- a/src/mpi_constants.c +++ b/src/mpi_constants.c @@ -17,3 +17,7 @@ void* c_MPI_IN_PLACE = MPI_IN_PLACE; MPI_Op c_MPI_SUM = MPI_SUM; MPI_Op c_MPI_MAX = MPI_MAX; + +MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL; + +MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER; \ No newline at end of file diff --git a/tests/gatherv_3.f90 b/tests/gatherv_3.f90 new file mode 100644 index 0000000..3f2e6a5 --- /dev/null +++ b/tests/gatherv_3.f90 @@ -0,0 +1,60 @@ +program gatherv_3 + use mpi + implicit none + + integer :: rank, nprocs, ierr + integer :: local_data(3) + integer, allocatable :: gathered_data(:) + integer, allocatable :: counts(:), displacements(:) + + call MPI_INIT(ierr) + call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr) + call MPI_COMM_SIZE(MPI_COMM_WORLD, nprocs, ierr) + + local_data = (/ rank*3+1, rank*3+2, rank*3+3 /) + + allocate(counts(nprocs)) + allocate(displacements(nprocs)) + + call gatherIntegers(local_data, gathered_data, counts, displacements, rank, nprocs, MPI_COMM_WORLD) + + deallocate(counts, displacements) + if (rank == 0) deallocate(gathered_data) + + call MPI_FINALIZE(ierr) + +contains + + subroutine gatherIntegers(local_data, gathered_data, counts, displacements, rank, nprocs, comm) + integer, intent(in) :: local_data(:) + integer, allocatable, intent(out) :: gathered_data(:) + integer, intent(out) :: counts(:) + integer, intent(out) :: displacements(:) + integer, intent(in) :: rank, nprocs, comm + + integer :: i, total_elements, ierr + integer :: local_size + + local_size = size(local_data) + counts = local_size + + displacements(1) = 0 + do i = 2, nprocs + displacements(i) = displacements(i-1) + counts(i-1) + end do + + total_elements = local_size * nprocs + + if (rank == 0) then + allocate(gathered_data(total_elements)) + else + allocate(gathered_data(1)) + end if + + call MPI_GatherV(local_data, local_size, MPI_INTEGER, & + gathered_data, counts, displacements, MPI_INTEGER, & + 0, comm, ierr) + + end subroutine gatherIntegers + +end program gatherv_3