feat: added proposal for UNSAFE_BROADCAST partition
mrava87 committed Nov 14, 2024
1 parent f6dd7a9 commit 05ff24e
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions pylops_mpi/DistributedArray.py
@@ -14,10 +14,14 @@ class Partition(Enum):
     Distributing data among different processes.
-    - ``BROADCAST``: Distributes data to all processes.
+    - ``BROADCAST``: Distributes data to all processes
+      (ensuring that data is kept consistent across processes)
+    - ``UNSAFE_BROADCAST``: Distributes data to all processes
+      (without ensuring that data is kept consistent across processes)
     - ``SCATTER``: Distributes unique portions to each process.
     """
     BROADCAST = "Broadcast"
+    UNSAFE_BROADCAST = "UnsafeBroadcast"
     SCATTER = "Scatter"


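For orientation, a minimal usage sketch (not part of the diff; it assumes pylops_mpi re-exports DistributedArray and Partition at package level, and uses illustrative variable and script names) of how the three partitions are selected:

    # run with: mpiexec -n 4 python partition_demo.py
    from pylops_mpi import DistributedArray, Partition

    # replicated on every rank; writes are broadcast from rank 0 to stay consistent
    x_safe = DistributedArray(global_shape=(10,), partition=Partition.BROADCAST)
    # replicated on every rank; no broadcast on writes, consistency is the user's job
    x_unsafe = DistributedArray(global_shape=(10,), partition=Partition.UNSAFE_BROADCAST)
    # split along axis 0, one unique chunk per rank
    x_scat = DistributedArray(global_shape=(10,), partition=Partition.SCATTER)
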
@@ -41,7 +45,7 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
     local_shape : :obj:`tuple`
         Shape of the local array.
     """
-    if partition == Partition.BROADCAST:
+    if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
         local_shape = global_shape
     # Split the array
     else:
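
A small sketch of what this branch implies (assuming local_split is importable from pylops_mpi.DistributedArray, where it is defined): both broadcast flavours replicate, so every rank reports the full global shape as its local shape.

    from mpi4py import MPI
    from pylops_mpi.DistributedArray import Partition, local_split

    # both broadcast flavours replicate rather than split
    shape = local_split((10, 5), MPI.COMM_WORLD, Partition.UNSAFE_BROADCAST, axis=0)
    assert shape == (10, 5)   # holds on every rank
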
@@ -75,7 +79,7 @@ class DistributedArray:
         MPI Communicator over which array is distributed.
         Defaults to ``mpi4py.MPI.COMM_WORLD``.
     partition : :obj:`Partition`, optional
-        Broadcast or Scatter the array. Defaults to ``Partition.SCATTER``.
+        Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``.
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
@@ -102,8 +106,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise IndexError(f"Axis {axis} out of range for DistributedArray "
                              f"of shape {global_shape}")
         if partition not in Partition:
-            raise ValueError(f"Should be either {Partition.BROADCAST} "
-                             f"or {Partition.SCATTER}")
+            raise ValueError(f"Should be either {Partition.BROADCAST}, "
+                             f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}")
         self.dtype = dtype
         self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
@@ -128,6 +132,9 @@ def __setitem__(self, index, value):
         `Partition.SCATTER` - Local Arrays are assigned their
         unique values.
+
+        `Partition.UNSAFE_BROADCAST` - Local Arrays are assigned their
+        unique values.

         `Partition.BROADCAST` - The value at rank-0 is broadcasted
         and is assigned to all the ranks.
@@ -139,12 +146,10 @@
         Represents the value that will be assigned to the local array at
         the specified index positions.
         """
-        # if self.partition is Partition.BROADCAST:
-        #     self.local_array[index] = self.base_comm.bcast(value)
-        # else:
-        #     self.local_array[index] = value
-        # testing this... avoid broadcasting and just let the user store the same value in each rank
-        self.local_array[index] = value
+        if self.partition is Partition.BROADCAST:
+            self.local_array[index] = self.base_comm.bcast(value)
+        else:
+            self.local_array[index] = value

     @property
     def global_shape(self):
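
This restored branch is where the two replicated partitions diverge on assignment; a hedged sketch (illustrative script name, two ranks assumed, same package-level imports as above):

    # run with: mpiexec -n 2 python setitem_demo.py
    import numpy as np
    from mpi4py import MPI
    from pylops_mpi import DistributedArray, Partition

    rank = MPI.COMM_WORLD.Get_rank()

    # BROADCAST: the assigned value is bcast from rank 0 (mpi4py's default root),
    # so all replicas end up identical even though each rank passes its own value
    x = DistributedArray(global_shape=(4,), partition=Partition.BROADCAST)
    x[:] = np.full(4, rank)   # every rank ends up storing rank 0's zeros

    # UNSAFE_BROADCAST: plain local assignment, no communication;
    # here the replicas intentionally diverge (rank r stores r)
    y = DistributedArray(global_shape=(4,), partition=Partition.UNSAFE_BROADCAST)
    y[:] = np.full(4, rank)
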
@@ -288,7 +293,7 @@ def asarray(self):
         Global Array gathered at all ranks
         """
         # Since the global array was replicated at all ranks
-        if self.partition == Partition.BROADCAST:
+        if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             # Get only self.local_array.
             return self.local_array
         # Gather all the local arrays and apply concatenation.
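
A consequence of this branch, under the same import assumptions as the sketches above: for either replicated partition, asarray can skip the gather entirely.

    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    xdist = DistributedArray.to_dist(np.arange(8.0), partition=Partition.UNSAFE_BROADCAST)
    xglob = xdist.asarray()   # returns the local replica directly; no MPI gather
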
@@ -333,7 +338,7 @@ def to_dist(cls, x: NDArray,
                                       mask=mask,
                                       engine=get_module_name(get_array_module(x)),
                                       dtype=x.dtype)
-        if partition == Partition.BROADCAST:
+        if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             dist_array[:] = x
         else:
             slices = [slice(None)] * x.ndim
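
Here dist_array[:] = x routes through __setitem__, so the partition chosen above decides the cost: BROADCAST ships rank 0's x to all ranks, while UNSAFE_BROADCAST is a purely local copy that trusts every rank to already hold the same x. A sketch under the same assumptions:

    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    x = np.arange(8, dtype=np.float64)   # identical on every rank by construction

    # no bcast is issued; each rank just copies its own x into the replica
    xdist = DistributedArray.to_dist(x, partition=Partition.UNSAFE_BROADCAST)
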
@@ -352,7 +357,7 @@ def _check_local_shapes(self, local_shapes):
             raise ValueError(f"Length of local shapes is not equal to number of processes; "
                              f"{len(local_shapes)} != {self.size}")
         # Check if local shape == global shape
-        if self.partition is Partition.BROADCAST and local_shapes[self.rank] != self.global_shape:
+        if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] and local_shapes[self.rank] != self.global_shape:
             raise ValueError(f"Local shape is not equal to global shape at rank = {self.rank};"
                              f"{local_shapes[self.rank]} != {self.global_shape}")
         elif self.partition is Partition.SCATTER:
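
An illustration of the extended check (hypothetical shapes, same imports as above): with either replicated partition, any user-supplied local shape must equal the global shape, otherwise the ValueError above is raised.

    from mpi4py import MPI
    from pylops_mpi import DistributedArray, Partition

    nranks = MPI.COMM_WORLD.Get_size()
    # valid: every local shape equals the global shape
    x = DistributedArray(global_shape=(10,), partition=Partition.UNSAFE_BROADCAST,
                         local_shapes=[(10,)] * nranks)
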
@@ -481,9 +486,9 @@ def dot(self, dist_array):

         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
-            if self.partition is Partition.BROADCAST else self
+            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         y = DistributedArray.to_dist(x=dist_array.local_array) \
-            if self.partition is Partition.BROADCAST else dist_array
+            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten()))

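The re-scatter guards against double counting: a replicated operand would otherwise contribute its full local dot product on every rank, and the allreduce would multiply the result by the number of ranks. A sketch under the assumptions above:

    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    a = DistributedArray.to_dist(np.ones(8), partition=Partition.UNSAFE_BROADCAST)
    b = DistributedArray.to_dist(np.ones(8), partition=Partition.UNSAFE_BROADCAST)
    # each operand is re-scattered first, so every rank sums a disjoint
    # chunk and the allreduced result is 8.0 (not 8.0 * nranks)
    val = a.dot(b)
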
@@ -555,7 +560,7 @@ def norm(self, ord: Optional[int] = None,
         """
         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
-            if self.partition is Partition.BROADCAST else self
+            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         if axis == -1:
             # Flatten the local arrays and calculate norm
             return x._compute_vector_norm(x.local_array.flatten(), axis=0, ord=ord)
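
The same conversion protects norm; a brief sketch under the assumptions above:

    import numpy as np
    from pylops_mpi import DistributedArray, Partition

    a = DistributedArray.to_dist(np.ones(8), partition=Partition.UNSAFE_BROADCAST)
    n = a.norm(ord=2)   # global 2-norm: sqrt(8) on every rank, not sqrt(8 * nranks)
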