Skip to content

Implement MPI_Gatherv Wrapper for String Array as DataType and add bindings for MPI_LOGICAL and MPI_CHARACTER #126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ module mpi
integer, parameter :: MPI_DOUBLE_PRECISION = -10004
integer, parameter :: MPI_REAL4 = -10013
integer, parameter :: MPI_REAL8 = -10014
integer, parameter :: MPI_CHARACTER = -10003
integer, parameter :: MPI_LOGICAL = -10005

integer, parameter :: MPI_COMM_TYPE_SHARED = 1
integer, parameter :: MPI_PROC_NULL = -1
Expand Down Expand Up @@ -75,6 +77,7 @@ module mpi
interface MPI_Gatherv
module procedure MPI_Gatherv_int
module procedure MPI_Gatherv_real
module procedure MPI_Gatherv_character
end interface MPI_Gatherv

interface MPI_Wtime
Expand Down Expand Up @@ -170,14 +173,18 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info
end function handle_mpi_info_f2c

integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) result(c_datatype)
use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int
use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int, c_mpi_logical, c_mpi_character
integer, intent(in) :: datatype_f
if (datatype_f == MPI_REAL4) then
c_datatype = c_mpi_float
else if (datatype_f == MPI_REAL8 .OR. datatype_f == MPI_DOUBLE_PRECISION) then
c_datatype = c_mpi_double
else if (datatype_f == MPI_INTEGER) then
c_datatype = c_mpi_int
else if (datatype_f == MPI_CHARACTER) then
c_datatype = c_mpi_character
else if (datatype_f == MPI_LOGICAL) then
c_datatype = c_mpi_logical
end if
end function

Expand Down Expand Up @@ -852,6 +859,42 @@ subroutine MPI_Gatherv_real(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
end if
end subroutine MPI_Gatherv_real

subroutine MPI_Gatherv_character(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, root, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_gatherv
character(len=*), intent(in), target :: sendbuf(*)
integer, intent(in) :: sendcount
integer, intent(in) :: sendtype
character(len=*), intent(out), target :: recvbuf(*)
integer, dimension(:), intent(in) :: recvcounts
integer, dimension(:), intent(in) :: displs
integer, intent(in) :: recvtype
integer, intent(in) :: root
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
type(c_ptr) :: c_sendbuf, c_recvbuf
integer(c_int) :: local_ierr

c_sendbuf = c_loc(sendbuf)
c_recvbuf = c_loc(recvbuf)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
c_comm = handle_mpi_comm_f2c(comm)

! Call C MPI_Gatherv
local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, &
c_recvbuf, recvcounts, displs, c_recvtype, &
root, c_comm)

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Gatherv failed with error code: ", local_ierr
end if
end subroutine MPI_Gatherv_character

subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f, c_mpi_statuses_ignore
Expand Down
2 changes: 2 additions & 0 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ module mpi_c_bindings
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character

interface

Expand Down
4 changes: 4 additions & 0 deletions src/mpi_constants.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ void* c_MPI_IN_PLACE = MPI_IN_PLACE;
MPI_Op c_MPI_SUM = MPI_SUM;

MPI_Op c_MPI_MAX = MPI_MAX;

MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL;

MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER;
89 changes: 89 additions & 0 deletions tests/gatherv_2.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
program gatherv_pfunit_1
use mpi
implicit none

! Context type to mimic pFUnit
type :: mpi_context
integer :: root
integer :: mpiCommunicator
end type mpi_context

type(mpi_context) :: this
integer :: ierr, rank, size, i, j, total
character(len=10), allocatable :: sendBuffer(:), recvBuffer(:)
integer, allocatable :: counts(:), displacements(:)
logical :: error
integer :: numEntries

! Initialize MPI
call MPI_Init(ierr)
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
call MPI_Comm_size(MPI_COMM_WORLD, size, ierr)

! Set up context
this%root = 0
this%mpiCommunicator = MPI_COMM_WORLD

! Each process sends 'rank + 1' strings of length 10
numEntries = rank + 1
allocate(sendBuffer(numEntries))
do i = 1, numEntries
write(sendBuffer(i), '(A,I0)') 'proc', rank ! Dummy words: "proc0", "proc1", etc.
end do

! Allocate receive buffers on root
if (rank == this%root) then
allocate(counts(size))
allocate(displacements(size))
total = 0
do i = 0, size - 1
counts(i+1) = (i + 1) * 10 ! Number of characters from each process
displacements(i+1) = total
total = total + counts(i+1)
end do
allocate(recvBuffer(total / 10)) ! Total number of strings
recvBuffer = ''
else
allocate(counts(1), displacements(1), recvBuffer(1)) ! Dummy for non-root
end if

! Perform MPI_Gatherv as in pFUnit
call MPI_Gatherv(sendBuffer, numEntries * 10, MPI_CHARACTER, &
recvBuffer, counts, displacements, MPI_CHARACTER, &
this%root, this%mpiCommunicator, ierr)

! Verify results on root
error = .false.
if (rank == this%root) then
total = 0
do i = 0, size - 1
do j = 1, i + 1
if (trim(recvBuffer(total + j)) /= 'proc'//trim(adjustl(int2str(i)))) then
print *, "Error at rank ", i, " index ", j, &
": expected 'proc", i, "', got '", &
trim(recvBuffer(total + j)), "'"
error = .true.
end if
end do
total = total + i + 1
end do
if (.not. error) then
print *, "MPI_Gatherv pFUnit test passed on root"
end if
end if

! Clean up
deallocate(sendBuffer, recvBuffer, counts, displacements)
call MPI_Finalize(ierr)

if (error) stop 1

contains
! Helper function to convert integer to string
function int2str(i) result(str)
integer, intent(in) :: i
character(len=10) :: str
write(str, '(I0)') i
end function int2str

end program gatherv_pfunit_1
Loading