feat: added mask to DistributedArray
mrava87 committed Oct 23, 2024
1 parent 6d3b1e8 commit cc699df
Showing 3 changed files with 86 additions and 18 deletions.
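The new optional ``mask`` splits the base communicator into sub-communicators, so that reductions such as dot products and norms are performed only within each subset of ranks. A hypothetical usage sketch, not part of the commit (assumes 4 MPI ranks; mask values are illustrative):

```python
# Run with e.g.: mpiexec -n 4 python sketch.py
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray

rank = MPI.COMM_WORLD.Get_rank()

# Ranks 0-1 form one group, ranks 2-3 another; reductions stay within each group
mask = [0, 0, 1, 1]
x = DistributedArray(global_shape=16, mask=mask, dtype=np.float64)
x[:] = float(rank)

# Each sub-communicator contains the 2 ranks that share a mask value
print(rank, x.mask, x.sub_comm.Get_size())
```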
82 changes: 71 additions & 11 deletions pylops_mpi/DistributedArray.py
@@ -5,6 +5,7 @@
from enum import Enum

from pylops.utils import DTypeLike, NDArray
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.backend import get_module, get_array_module, get_module_name


@@ -78,7 +79,10 @@ class DistributedArray:
axis : :obj:`int`, optional
Axis along which distribution occurs. Defaults to ``0``.
local_shapes : :obj:`list`, optional
List of tuples representing local shapes at each rank.
List of tuples or integers representing local shapes at each rank.
mask : :obj:`list`, optional
Mask defining subsets of ranks to consider when performing 'global'
operations on the distributed array such as dot product or norm.
engine : :obj:`str`, optional
Engine used to store array (``numpy`` or ``cupy``)
dtype : :obj:`str`, optional
@@ -88,7 +92,8 @@ class DistributedArray:
def __init__(self, global_shape: Union[Tuple, Integral],
base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
partition: Partition = Partition.SCATTER, axis: int = 0,
local_shapes: Optional[List[Tuple]] = None,
local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
mask: Optional[List[Integral]] = None,
engine: Optional[str] = "numpy",
dtype: Optional[DTypeLike] = np.float64):
if isinstance(global_shape, Integral):
@@ -100,10 +105,14 @@ def __init__(self, global_shape: Union[Tuple, Integral],
raise ValueError(f"Should be either {Partition.BROADCAST} "
f"or {Partition.SCATTER}")
self.dtype = dtype
self._global_shape = global_shape
self._global_shape = _value_or_sized_to_tuple(global_shape)
self._base_comm = base_comm
self._partition = partition
self._axis = axis
self._mask = mask
self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank)

local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
self._check_local_shapes(local_shapes)
self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
partition, axis)
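The sub-communicator above is built with ``MPI.Comm.Split``, using each rank's mask entry as the colour. A standalone mpi4py sketch of the same mechanism (illustrative values, assumes 4 ranks):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

mask = [0, 0, 1, 1]                       # one colour per rank
sub_comm = comm.Split(color=mask[rank], key=rank)

# Each rank now belongs to a communicator holding only the ranks with its colour
print(f"rank {rank}: sub size={sub_comm.Get_size()}, sub rank={sub_comm.Get_rank()}")
```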
@@ -165,6 +174,16 @@ def local_shape(self):
"""
return self._local_shape

@property
def mask(self):
"""Mask of the Distributed array
Returns
-------
mask : :obj:`list`
"""
return self._mask

@property
def engine(self):
"""Engine of the Distributed array
@@ -246,6 +265,16 @@ def local_shapes(self):
"""
return self.base_comm.allgather(self.local_shape)

@property
def sub_comm(self):
"""MPI Sub-Communicator
Returns
-------
sub_comm : :obj:`MPI.Comm`
"""
return self._sub_comm

def asarray(self):
"""Global view of the array
@@ -269,7 +298,8 @@ def to_dist(cls, x: NDArray,
base_comm: MPI.Comm = MPI.COMM_WORLD,
partition: Partition = Partition.SCATTER,
axis: int = 0,
local_shapes: Optional[List[Tuple]] = None):
local_shapes: Optional[List[Tuple]] = None,
mask: Optional[List[Integral]] = None):
"""Convert A Global Array to a Distributed Array
Parameters
@@ -284,6 +314,9 @@ def to_dist(cls, x: NDArray,
Axis of Distribution
local_shapes : :obj:`list`, optional
Local Shapes at each rank.
mask : :obj:`list`, optional
Mask defining subsets of ranks to consider when performing 'global'
operations on the distributed array such as dot product or norm.
Returns
----------
@@ -295,6 +328,7 @@ def to_dist(cls, x: NDArray,
partition=partition,
axis=axis,
local_shapes=local_shapes,
mask=mask,
engine=get_module_name(get_array_module(x)),
dtype=x.dtype)
if partition == Partition.BROADCAST:
@@ -336,6 +370,12 @@ def _check_partition_shape(self, dist_array):
raise ValueError(f"Local Array Shape Mismatch - "
f"{self.local_shape} != {dist_array.local_shape}")

def _check_mask(self, dist_array):
"""Check mask of the Array
"""
if not np.array_equal(self.mask, dist_array.mask):
raise ValueError("Mask of both the arrays must be same")

def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
"""MPI Allreduce operation
"""
@@ -345,12 +385,22 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
self.base_comm.Allreduce(send_buf, recv_buf, op)
return recv_buf

def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
"""MPI Allreduce operation with subcommunicator
"""
if recv_buf is None:
return self.sub_comm.allreduce(send_buf, op)
# For MIN and MAX which require recv_buf
self.sub_comm.Allreduce(send_buf, recv_buf, op)
return recv_buf

def __neg__(self):
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
mask=self.mask,
engine=self.engine,
dtype=self.dtype)
arr[:] = -self.local_array
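The new ``_allreduce_subcomm`` mirrors ``_allreduce`` but reduces over ``sub_comm``; as the inline comment notes, MIN/MAX go through the buffer-based ``Allreduce`` with an explicit receive buffer. A minimal mpi4py illustration of the two call styles (independent of the class, values for demonstration only):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
local = np.array([float(comm.Get_rank())])

# Pickle-based reduce: returns the result directly (used here for SUM)
total = comm.allreduce(local.sum(), op=MPI.SUM)

# Buffer-based reduce: fills an explicit receive buffer (used here for MIN/MAX)
recv = np.zeros_like(local)
comm.Allreduce(local, recv, op=MPI.MAX)

print(comm.Get_rank(), total, recv)
```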
@@ -378,11 +428,13 @@ def add(self, dist_array):
"""Distributed Addition of arrays
"""
self._check_partition_shape(dist_array)
self._check_mask(dist_array)
SumArray = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
dtype=self.dtype,
partition=self.partition,
local_shapes=self.local_shapes,
mask=self.mask,
engine=self.engine,
axis=self.axis)
SumArray[:] = self.local_array + dist_array.local_array
@@ -392,6 +444,7 @@ def iadd(self, dist_array):
"""Distributed In-place Addition of arrays
"""
self._check_partition_shape(dist_array)
self._check_mask(dist_array)
self[:] = self.local_array + dist_array.local_array
return self

@@ -400,12 +453,14 @@ def multiply(self, dist_array):
"""
if isinstance(dist_array, DistributedArray):
self._check_partition_shape(dist_array)
self._check_mask(dist_array)

ProductArray = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
dtype=self.dtype,
partition=self.partition,
local_shapes=self.local_shapes,
mask=self.mask,
engine=self.engine,
axis=self.axis)
if isinstance(dist_array, DistributedArray):
@@ -420,13 +475,15 @@ def dot(self, dist_array):
"""Distributed Dot Product
"""
self._check_partition_shape(dist_array)
self._check_mask(dist_array)

# Convert to Partition.SCATTER if Partition.BROADCAST
x = DistributedArray.to_dist(x=self.local_array) \
if self.partition is Partition.BROADCAST else self
y = DistributedArray.to_dist(x=dist_array.local_array) \
if self.partition is Partition.BROADCAST else dist_array
# Flatten the local arrays and calculate dot product
return self._allreduce(np.dot(x.local_array.flatten(), y.local_array.flatten()))
return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten()))

def _compute_vector_norm(self, local_array: NDArray,
axis: int, ord: Optional[int] = None):
@@ -453,20 +510,20 @@ def _compute_vector_norm(self, local_array: NDArray,
raise ValueError(f"norm-{ord} not possible for vectors")
elif ord == 0:
# Count non-zero then sum reduction
recv_buf = self._allreduce(np.count_nonzero(local_array, axis=axis).astype(np.float64))
recv_buf = self._allreduce_subcomm(np.count_nonzero(local_array, axis=axis).astype(np.float64))
elif ord == np.inf:
# Calculate max followed by max reduction
recv_buf = self._allreduce(np.max(np.abs(local_array), axis=axis).astype(np.float64),
recv_buf, op=MPI.MAX)
recv_buf = self._allreduce_subcomm(np.max(np.abs(local_array), axis=axis).astype(np.float64),
recv_buf, op=MPI.MAX)
recv_buf = np.squeeze(recv_buf, axis=axis)
elif ord == -np.inf:
# Calculate min followed by min reduction
recv_buf = self._allreduce(np.min(np.abs(local_array), axis=axis).astype(np.float64),
recv_buf, op=MPI.MIN)
recv_buf = self._allreduce_subcomm(np.min(np.abs(local_array), axis=axis).astype(np.float64),
recv_buf, op=MPI.MIN)
recv_buf = np.squeeze(recv_buf, axis=axis)

else:
recv_buf = self._allreduce(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis))
recv_buf = self._allreduce_subcomm(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis))
recv_buf = np.power(recv_buf, 1. / ord)
return recv_buf
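With a mask set, ``dot`` and the vector norms above reduce over the sub-communicator only, so each group of ranks obtains the result for its own portion of the data. A hedged sketch of the expected behaviour (assumes 4 ranks; numbers are illustrative):

```python
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray

rank = MPI.COMM_WORLD.Get_rank()
mask = [0, 0, 1, 1]

x = DistributedArray(global_shape=8, mask=mask)
x[:] = 1.0

# Each rank holds 2 elements; dot reduces only within its 2-rank mask group,
# so every rank should obtain 4.0 rather than the global 8.0
print(rank, x.dot(x))
```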

@@ -500,6 +557,7 @@ def conj(self):
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
mask=self.mask,
engine=self.engine,
dtype=self.dtype)
conj[:] = self.local_array.conj()
@@ -513,6 +571,7 @@ def copy(self):
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
mask=self.mask,
engine=self.engine,
dtype=self.dtype)
arr[:] = self.local_array
@@ -535,6 +594,7 @@ def ravel(self, order: Optional[str] = "C"):
local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
arr = DistributedArray(global_shape=np.prod(self.global_shape),
local_shapes=local_shapes,
mask=self.mask,
partition=self.partition,
engine=self.engine,
dtype=self.dtype)
12 changes: 9 additions & 3 deletions pylops_mpi/basicoperators/BlockDiag.py
@@ -1,7 +1,8 @@
import numpy as np
from scipy.sparse.linalg._interface import _get_dtype
from mpi4py import MPI
from typing import Optional, Sequence
from typing import Optional, Sequence, Union, List
from numbers import Integral

from pylops import LinearOperator
from pylops.utils import DTypeLike
@@ -28,6 +29,9 @@ class MPIBlockDiag(MPILinearOperator):
One or more :class:`pylops.LinearOperator` to be stacked.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
mask : :obj:`list`, optional
Mask defining subsets of ranks to consider when performing 'global' operations on
the distributed array such as dot product or norm.
dtype : :obj:`str`, optional
Type of elements in input array.
@@ -95,8 +99,10 @@ class MPIBlockDiag(MPILinearOperator):

def __init__(self, ops: Sequence[LinearOperator],
base_comm: MPI.Comm = MPI.COMM_WORLD,
mask: Optional[List[Integral]] = None,
dtype: Optional[DTypeLike] = None):
self.ops = ops
self.mask = mask
mops = np.zeros(len(self.ops), dtype=np.int64)
nops = np.zeros(len(self.ops), dtype=np.int64)
for iop, oper in enumerate(self.ops):
@@ -116,7 +122,7 @@ def __init__(self, ops: Sequence[LinearOperator],
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
engine=x.engine, dtype=self.dtype)
mask=self.mask, engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
y1.append(oper.matvec(x.local_array[self.mmops[iop]:
@@ -128,7 +134,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
engine=x.engine, dtype=self.dtype)
mask=self.mask, engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
y1.append(oper.rmatvec(x.local_array[self.nnops[iop]:
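In ``MPIBlockDiag`` the mask is simply forwarded to the ``DistributedArray`` outputs of ``_matvec``/``_rmatvec``. A hypothetical construction, not taken from the repository examples (operator sizes, import paths and mask values are assumptions; assumes 4 ranks):

```python
import numpy as np
from mpi4py import MPI
from pylops import MatrixMult
from pylops_mpi import DistributedArray, MPIBlockDiag

rank = MPI.COMM_WORLD.Get_rank()
mask = [0, 0, 1, 1]

# One small dense operator per rank
Op = MatrixMult(np.eye(5) + np.ones((5, 5)))
BDiag = MPIBlockDiag(ops=[Op, ], mask=mask)

x = DistributedArray(global_shape=5 * MPI.COMM_WORLD.Get_size(), mask=mask)
x[:] = 1.0
y = BDiag @ x            # y carries the same mask, so y.dot(y) reduces per group
print(rank, y.mask)
```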
10 changes: 6 additions & 4 deletions pylops_mpi/waveeqprocessing/MDC.py
@@ -20,7 +20,7 @@ def _MDC(G, nt, nv, nfmax, dt=1., dr=1., twosided=True,
Used to be able to provide operators from different libraries to
MDC. It operates in the same way as public method
(PoststackLinearModelling) but has additional input parameters allowing
(MPIMDC) but has additional input parameters allowing
passing a different operator and additional arguments to be passed to such
operator.
@@ -81,8 +81,10 @@ def MPIMDC(G, nt, nv, nfreq, dt=1., dr=1., twosided=True,
base_comm: MPI.Comm = MPI.COMM_WORLD):
r"""Multi-dimensional convolution.
Apply multi-dimensional convolution between two datasets. Model and data
should be provided after flattening 2- or 3-dimensional arrays of size
Apply multi-dimensional convolution between two datasets in a distributed
fashion, with ``G`` distributed over ranks across the frequency axis.
Model and data are broadcasted and should be provided after flattening
2- or 3-dimensional arrays of size
:math:`[n_t \times n_r (\times n_{vs})]` and
:math:`[n_t \times n_s (\times n_{vs})]` (or :math:`2*n_t-1` for
``twosided=True``), respectively.
@@ -91,7 +93,7 @@ def MPIMDC(G, nt, nv, nfreq, dt=1., dr=1., twosided=True,
----------
G : :obj:`numpy.ndarray`
Multi-dimensional convolution kernel in frequency domain of size
:math:`[n_{fmax} \times n_s \times n_r]`
:math:`[n_{f,rank} \times n_s \times n_r]`
nt : :obj:`int`
Number of samples along time axis for model and data (note that this
must be equal to ``2*n_t-1`` when working with ``twosided=True``).
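Given the clarified docstring, ``G`` is expected to be passed already split across ranks along the frequency axis. A hypothetical instantiation sketch (array sizes, ``dt`` and the even frequency split are assumptions; only parameters visible in this diff are used):

```python
import numpy as np
from mpi4py import MPI
from pylops_mpi.waveeqprocessing.MDC import MPIMDC

comm = MPI.COMM_WORLD
nfreq, ns, nr, nt, nv = 64, 10, 10, 100, 1

# Each rank holds only its slice of the kernel along the frequency axis
nf_rank = nfreq // comm.Get_size()
G_local = (np.random.randn(nf_rank, ns, nr)
           + 1j * np.random.randn(nf_rank, ns, nr))

MDCop = MPIMDC(G_local, nt=nt, nv=nv, nfreq=nfreq, dt=0.004, dr=1.0, twosided=False)
```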
