Skip to content

Commit

Permalink
backwards-compatible sparse check
Browse files Browse the repository at this point in the history
  • Loading branch information
timmysilv committed Dec 2, 2024
1 parent 12dc38d commit 5928766
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
7 changes: 6 additions & 1 deletion flamingpy/cv/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def invert_permutation(p):
return p_inverted


def issparse(array):
"""Check if an array is sparse. Backwards-compatible with old SciPy versions."""
return isinstance(array, getattr(sp, "sparray", sp.coo_matrix))


def SCZ_mat(adj, sparse=True):
"""Return a symplectic matrix corresponding to CZ gate application.
Expand Down Expand Up @@ -59,7 +64,7 @@ def SCZ_mat(adj, sparse=True):
# Construct symplectic
symplectic = block_func([[identity, zeros], [adj, identity]])

if not sparse and isinstance(symplectic, sp.coo_array):
if not sparse and issparse(symplectic):
return symplectic.toarray()

return symplectic
Expand Down
12 changes: 5 additions & 7 deletions tests/cv/test_cv_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import scipy.sparse as sp

from flamingpy.codes.graphs import EGraph
from flamingpy.cv.ops import invert_permutation, SCZ_mat, SCZ_apply
from flamingpy.cv.ops import invert_permutation, SCZ_mat, SCZ_apply, issparse

now = datetime.now()
int_time = int(str(now.year) + str(now.month) + str(now.day) + str(now.hour) + str(now.minute))
Expand All @@ -50,13 +50,11 @@ def random_graph(request):
class TestSCZ:
"""Tests for symplectic CZ matrices."""

@pytest.mark.parametrize(
"sparse, expected_out_type", sorted([(True, sp.coo_array), (False, np.ndarray)])
)
def test_SCZ_mat_sparse_param(self, random_graph, sparse, expected_out_type):
@pytest.mark.parametrize("sparse", [True, False])
def test_SCZ_mat_sparse_param(self, random_graph, sparse):
"""Tests the SCZ_mat function outputs sparse or dense arrays."""
SCZ = SCZ_mat(random_graph[2], sparse=sparse)
assert isinstance(SCZ, expected_out_type)
assert issparse(SCZ) if sparse else isinstance(SCZ, np.ndarray)

def test_SCZ_mat(self, random_graph):
"""Tests the SCZ_mat function."""
Expand All @@ -65,7 +63,7 @@ def test_SCZ_mat(self, random_graph):
# Check if SCZ_mat adjusts type of output matrix based on
# type of input.
assert isinstance(SCZ, np.ndarray)
assert isinstance(SCZ_sparse, sp.coo_array)
assert isinstance(SCZ_sparse, sp.sparray)
# Check that structure of SCZ matrix is correct.
for mat in (SCZ, SCZ_sparse.toarray()):
assert np.array_equal(mat[:N, :N], np.identity(N))
Expand Down

0 comments on commit 5928766

Please sign in to comment.