Skip to content

Commit

Permalink
Merge pull request #119 from PyLops/bug-stackedarraynorm
Browse files Browse the repository at this point in the history
bug: fixed StackedDistributedArray.norm to work with cupy arrays
  • Loading branch information
mrava87 authored Nov 26, 2024
2 parents e9bbecc + a36e19d commit e5d7b52
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,21 +818,22 @@ def norm(self, ord: Optional[int] = None):
ord : :obj:`int`, optional
Order of the norm.
"""
norms = np.array([distarray.norm(ord) for distarray in self.distarrays])
ncp = get_module(self.distarrays[0].engine)
norms = ncp.array([distarray.norm(ord) for distarray in self.distarrays])
ord = 2 if ord is None else ord
if ord in ['fro', 'nuc']:
raise ValueError(f"norm-{ord} not possible for vectors")
elif ord == 0:
# Count non-zero then sum reduction
norm = np.sum(norms)
elif ord == np.inf:
norm = ncp.sum(norms)
elif ord == ncp.inf:
# Calculate max followed by max reduction
norm = np.max(norms)
elif ord == -np.inf:
norm = ncp.max(norms)
elif ord == -ncp.inf:
# Calculate min followed by max reduction
norm = np.min(norms)
norm = ncp.min(norms)
else:
norm = np.power(np.sum(np.power(norms, ord)), 1. / ord)
norm = ncp.power(ncp.sum(ncp.power(norms, ord)), 1. / ord)
return norm

def conj(self):
Expand Down

0 comments on commit e5d7b52

Please sign in to comment.