diff --git a/qiskit/quantum_info/__init__.py b/qiskit/quantum_info/__init__.py index 97a515f10858..a3832ef98c23 100644 --- a/qiskit/quantum_info/__init__.py +++ b/qiskit/quantum_info/__init__.py @@ -39,6 +39,7 @@ PauliList pauli_basis get_clifford_gate_names + operator_schmidt_decomposition .. _quantum_info_states: @@ -143,6 +144,7 @@ double_commutator, pauli_basis, get_clifford_gate_names, + operator_schmidt_decomposition, ) from .operators.channel import PTM, Chi, Choi, Kraus, Stinespring, SuperOp from .operators.dihedral import CNOTDihedral diff --git a/qiskit/quantum_info/operators/__init__.py b/qiskit/quantum_info/operators/__init__.py index 958cba4261ef..93c8edc64a26 100644 --- a/qiskit/quantum_info/operators/__init__.py +++ b/qiskit/quantum_info/operators/__init__.py @@ -27,3 +27,4 @@ get_clifford_gate_names, ) from .utils import anti_commutator, commutator, double_commutator +from .operator_schmidt_decomposition import operator_schmidt_decomposition diff --git a/qiskit/quantum_info/operators/operator_schmidt_decomposition.py b/qiskit/quantum_info/operators/operator_schmidt_decomposition.py new file mode 100644 index 000000000000..e9f9fdf52eff --- /dev/null +++ b/qiskit/quantum_info/operators/operator_schmidt_decomposition.py @@ -0,0 +1,262 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Operator Schmidt decomposition utilities. +""" +from __future__ import annotations + +from typing import Any +from collections.abc import Sequence +import numpy as np + +from qiskit.exceptions import QiskitError + + +def _permutation_matrix_from_qubit_order(new_order: Sequence[int], n: int) -> np.ndarray: + """ + Return the ``(2**n) x (2**n)`` permutation matrix ``P`` that reorders **little‑endian** qubits. + + Little‑endian convention: qubit 0 is the **least significant bit** (LSB). + + Mapping (bits → indices): + * ``new_order[k]`` gives which **original** qubit becomes bit‑position ``k`` in the **new** + representation (with ``k=0`` the LSB). + * For a computational basis state with original bitstring + ``b = (b_{n-1} ... b_1 b_0)`` we form the new bitstring ``b'`` by + ``b'_k = b_{ new_order[k] }``. + * Index mapping: ``i' = sum_k b'_k 2^k``. + + Action: + * States: ``|psi_new> = P |psi_old>``. + * Operators: ``U_new = P U_old P^T`` (``P`` is real; ``P^T = P.conj().T``). + + Args: + new_order: A permutation of ``range(n)`` where entry ``k`` is the original qubit index that + becomes bit‑position ``k`` in the new ordering (LSB is ``k=0``). + n: Number of qubits. + + Returns: + P: A boolean permutation matrix of shape ``(2**n, 2**n)`` such that the above actions hold. + + Raises: + QiskitError: If ``new_order`` is not a permutation of ``range(n)`` or sizes mismatch. + + Example: + For ``n=3`` and ``new_order = [2, 0, 1]`` (LSB first): + - New LSB (k=0) is original qubit 2, + - New middle bit (k=1) is original qubit 0, + - New MSB (k=2) is original qubit 1. + If original state has bits ``(b2 b1 b0)``, the new index corresponds to bits ``(b1 b0 b2)`` + in MSB→LSB order. + """ + # Validate + if not isinstance(n, int) or n < 0: + raise QiskitError("`n` must be a non‑negative integer.") + if len(new_order) != n: + raise QiskitError(f"`new_order` must have length n={n}.") + if set(new_order) != set(range(n)): + raise QiskitError("`new_order` must be a permutation of range(n).") + + dim = 2**n + indices = np.arange(dim, dtype=np.int64) # original indices i + + # Extract original bits b_q (q=0 is LSB) for each index. + bits = (indices[:, None] >> np.arange(n, dtype=np.int64)) & 1 # shape (dim, n) + + # Reorder bits so that new bit‑position k gets original bit from new_order[k]. + reordered_bits = bits[:, new_order] # shape (dim, n) + + # Convert reordered bits to new indices i' + new_indices = np.sum(reordered_bits << np.arange(n, dtype=np.int64), axis=1) + + # Build permutation matrix with columns permuted by new_indices + return np.eye(dim, dtype=bool)[:, new_indices] + + +def _check_inputs( + op: np.ndarray, qargs: Sequence[int] +) -> tuple[int, tuple[int, ...], tuple[int, ...]]: + if not isinstance(op, np.ndarray): + raise QiskitError("`op` must be a numpy.ndarray.") + if op.ndim != 2 or op.shape[0] != op.shape[1]: + raise QiskitError("`op` must be a square matrix.") + n_float = np.log2(op.shape[0]) + if not np.isclose(n_float, int(n_float)): + raise QiskitError("`op` dimension must be a power of 2.") + n = int(round(n_float)) + if op.shape != (2**n, 2**n): + raise QiskitError(f"`op` must have shape {(2**n, 2**n)}.") + + subset_a = tuple(sorted({int(q) for q in qargs})) + if any(q < 0 or q >= n for q in subset_a): + raise QiskitError(f"All indices in `qargs` must be in [0, {n-1}].") + subset_b = tuple(sorted(set(range(n)) - set(subset_a))) + if not subset_a or not subset_b: + raise QiskitError("`qargs` must be a strict, non‑empty subset of the qubit indices.") + return n, subset_a, subset_b + + +def _realign_row_major(u_perm: np.ndarray, dim_a: int, dim_b: int) -> np.ndarray: + """Return realignment ``realigned`` of ``u_perm`` for bipartition ``A (MSB) ⊗ B (LSB)``. + + ``realigned[(iA, jA), (iB, jB)] = u_perm[(iA, iB), (jA, jB)]`` via reshape+transpose. + """ + u4 = u_perm.reshape(dim_a, dim_b, dim_a, dim_b) # (iA, iB, jA, jB) + realigned = np.transpose(u4, (0, 2, 1, 3)) # (iA, jA, iB, jB) + return realigned.reshape(dim_a * dim_a, dim_b * dim_b) + + +def operator_schmidt_decomposition( + op: np.ndarray, + qargs: Sequence[int], + *, + k: int | None = None, + return_reconstruction: bool = False, +) -> dict[str, Any]: + r""" + Compute the operator Schmidt decomposition of ``op`` across the bipartition + defined by ``qargs`` (subsystem :math:`A`) and its complement (subsystem :math:`B`). + + Given an operator :math:`U` acting on :math:`n` qubits, and a bipartition + :math:`\mathcal{H} = \mathcal{H}_A \otimes \mathcal{H}_B` with + :math:`\dim(\mathcal{H}_A) = 2^{|A|}`, :math:`\dim(\mathcal{H}_B) = 2^{|B|}`, + the operator Schmidt decomposition is + + .. math:: + + U \;=\; \sum_{r=1}^{R} s_r \, A_r \otimes B_r, + + where :math:`s_r \ge 0` are the singular values of the **realigned** matrix, + and :math:`A_r, B_r` are matrices on :math:`\mathcal{H}_A` and + :math:`\mathcal{H}_B`, respectively. + + **Basis and permutation.** + The decomposition is computed in a **permuted basis** where the qubit order is + ``[Sc, S]`` (complement first, then selected subset). In this basis we have + + .. math:: + + U_{\text{perm}} \;=\; \sum_{r} A_r \otimes B_r, + + with :math:`A_r` acting on subsystem :math:`A` (MSB block) and :math:`B_r` on + subsystem :math:`B` (LSB block). The original operator satisfies + + .. math:: + + U \;=\; P^\top\, U_{\text{perm}}\, P, + + where ``P`` is the permutation matrix mapping the original qubit order to ``[Sc, S]``. + + **Truncation (top-``k`` terms).** + If ``k`` is provided, the returned factors correspond to the best rank-``k`` approximation + (in Frobenius norm) of the realigned matrix; i.e., only the top-``k`` singular components + are used to construct the factors and (optionally) the reconstruction. The array + ``singular_values`` in the return value always contains the **full** spectrum so that you + can inspect or post-process the tail; metadata about truncation and the Frobenius error of + the discarded part are also returned. + + Args: + op: Complex matrix of shape ``(2**n, 2**n)`` (unitary or not). + qargs: Qubit indices belonging to subsystem :math:`A`. Little‑endian ordering is used + in Qiskit (qubit 0 is the least significant bit). + k: If not ``None``, keep only the top-``k`` Schmidt terms. Must be a positive integer. + If ``k`` exceeds the number of available singular values, it is clipped. + return_reconstruction: If ``True``, also return the reconstruction + (sum of kept terms) mapped back to the **original** qubit order. + + Returns: + dict: + * ``partition``: ``{"S": tuple, "Sc": tuple}`` with the chosen split. + * ``permutation``: dict with: + - ``new_order``: tuple of qubit indices in the permuted order ``[Sc, S]``. + - ``matrix``: the permutation matrix ``P`` (shape ``(2**n, 2**n)``, real). + * ``singular_values``: 1D ``np.ndarray`` of **all** singular values (descending). + * ``A_factors``: list of ``np.ndarray`` of shape ``(2**|S|, 2**|S|)`` for the **kept** + terms, in the permuted basis (A on MSB block). + * ``B_factors``: list of ``np.ndarray`` of shape ``(2**|Sc|, 2**|Sc|)`` for the **kept** + terms, in the permuted basis (B on LSB block). + * ``truncation``: dict with: + - ``kept_terms``: number of terms kept (``k`` after clipping; equals full rank if + ``k`` is ``None``). + - ``discarded_terms``: number of discarded terms. + - ``frobenius_error``: Frobenius norm of the discarded tail (equal for the + realigned matrix and the permuted operator). + - ``relative_frobenius_error``: ``frobenius_error / np.linalg.norm(singular_values)``. + * ``reconstruction``: optional ``np.ndarray`` of the (possibly truncated) reconstruction + in **original** qubit order (present only when ``return_reconstruction=True``). + + Raises: + QiskitError: If inputs are malformed (non‑power‑of‑two dimensions, invalid ``qargs``) + or ``k`` is not a positive integer when provided. + """ + n, subset_a, subset_b = _check_inputs(op, qargs) + + # Permute to [Sc, S] so B occupies LSB block, A occupies MSB block. + perm = _permutation_matrix_from_qubit_order(list(subset_b) + list(subset_a), n) + u_perm = perm @ op @ perm.T + + dim_a, dim_b = 2 ** len(subset_a), 2 ** len(subset_b) + + # Realign and SVD + realigned = _realign_row_major(u_perm, dim_a, dim_b) + u_left, sing_vals, vh = np.linalg.svd(realigned, full_matrices=False) + vcols = vh.conj().T + + # Determine number of terms to keep + total_terms = len(sing_vals) # = min(dim_a*dim_a, dim_b*dim_b) + if k is None: + num = total_terms + else: + if not (isinstance(k, int) and k > 0): + raise QiskitError("`k` must be a positive integer if provided.") + num = min(k, total_terms) + + # Build factors so that sum kron(A_r, B_r) == u_perm (permuted basis), truncated if needed. + a_factors: list[np.ndarray] = [] + b_factors: list[np.ndarray] = [] + for i in range(num): + vec_a = u_left[:, i] * np.sqrt(sing_vals[i]) + vec_b = np.conj(vcols[:, i]) * np.sqrt(sing_vals[i]) + a_factors.append(vec_a.reshape(dim_a, dim_a)) + b_factors.append(vec_b.reshape(dim_b, dim_b)) + + # Truncation metadata + tail = sing_vals[num:] + fro_err = float(np.sqrt(np.sum(tail**2))) if tail.size else 0.0 + denom = np.linalg.norm(sing_vals) + rel_err = float(fro_err / denom) if denom > 0 else 0.0 + + out: dict[str, Any] = { + "partition": {"S": subset_a, "Sc": subset_b}, + "permutation": { + "new_order": tuple(subset_b) + tuple(subset_a), + "matrix": perm, + }, + "singular_values": sing_vals.copy(), # full spectrum (not truncated) + "A_factors": a_factors, # truncated list (num terms) + "B_factors": b_factors, # truncated list (num terms) + "truncation": { + "kept_terms": num, + "discarded_terms": total_terms - num, + "frobenius_error": fro_err, + "relative_frobenius_error": rel_err, + }, + } + + if return_reconstruction: + u_rec = np.zeros_like(op, dtype=np.complex128) + for i in range(num): + u_rec += np.kron(a_factors[i], b_factors[i]) + # Map back to original qubit order + out["reconstruction"] = perm.T @ u_rec @ perm + + return out diff --git a/releasenotes/notes/add-operator-schmidt-decomposition-ea5cec089113124c.yaml b/releasenotes/notes/add-operator-schmidt-decomposition-ea5cec089113124c.yaml new file mode 100644 index 000000000000..1ee51ef2879b --- /dev/null +++ b/releasenotes/notes/add-operator-schmidt-decomposition-ea5cec089113124c.yaml @@ -0,0 +1,16 @@ +--- +features: + - | + Added :func:`operator_schmidt_decomposition`. This function + computes the **operator Schmidt decomposition** of an operator acting on `n` qubits + across a specified bipartition. It returns: + + * The full set of singular values (Schmidt coefficients). + * Lists of Schmidt factors `A_r` and `B_r` in the permuted basis. + * Partition and permutation metadata, including the permutation matrix. + * Optional reconstruction of the operator in the original qubit order. + * Truncation support: keep only the top-`k` Schmidt terms with Frobenius-optimal + error reporting. + + This utility is useful for analyzing entanglement structure of operators, validating + tensor decompositions, and benchmarking low-rank approximations. \ No newline at end of file diff --git a/test/python/quantum_info/operators/test_operator_schmidt_decomposition.py b/test/python/quantum_info/operators/test_operator_schmidt_decomposition.py new file mode 100644 index 000000000000..7d1331d5bbe2 --- /dev/null +++ b/test/python/quantum_info/operators/test_operator_schmidt_decomposition.py @@ -0,0 +1,376 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Tests for operator Schmidt decomposition utility (unittest + DDT parameterization). +""" +from __future__ import annotations + +from test import QiskitTestCase, slow_test + +import itertools +import unittest +from collections.abc import Iterable + +import numpy as np +import numpy.testing as npt +from ddt import ddt, idata, unpack + +from qiskit.exceptions import QiskitError +from qiskit.quantum_info import random_unitary +from qiskit.quantum_info.operators.operator_schmidt_decomposition import ( + operator_schmidt_decomposition, +) + +# Tolerances consistent with Qiskit’s double-precision checks. +from qiskit.quantum_info.operators.predicates import ATOL_DEFAULT, RTOL_DEFAULT + +ATOL = ATOL_DEFAULT +RTOL = RTOL_DEFAULT + +# Always-on small seed set (keeps default runs fast and deterministic). +SEEDS_FAST = [7, 11, 19, 23, 42] + +# Optional heavy stress, controlled by env var (e.g., QISKIT_SLOW_TESTS=1). +SEEDS_STRESS = list(range(100, 120)) + + +def _fro_error(a_mat: np.ndarray, b_mat: np.ndarray) -> float: + return np.linalg.norm(a_mat - b_mat, ord="fro") + + +# ---------- Helper case generators (evaluated at import time for DDT) ---------- + + +def _cases_exact_unitary() -> Iterable[tuple[int, int, tuple[int, ...]]]: + # (seed, n, subset_a) + for seed in SEEDS_FAST: + for n_qubits in (1, 2, 3): + for r_size in range(1, n_qubits): + for subset_a in itertools.combinations(range(n_qubits), r_size): + yield (seed, n_qubits, subset_a) + + +def _cases_exact_dense() -> Iterable[tuple[int, int, tuple[int, ...]]]: + for seed in SEEDS_FAST: + for n_qubits in (2, 3): # n=1 would be trivial + for r_size in range(1, n_qubits): + for subset_a in itertools.combinations(range(n_qubits), r_size): + yield (seed, n_qubits, subset_a) + + +def _cases_qargs_order_irrelevant() -> Iterable[tuple[int, tuple[int, ...], tuple[int, ...]]]: + # removed unused local 'n_qubits' + for seed in SEEDS_FAST: + # Same subsets, different order: + yield (seed, (2, 0), (0, 2)) + yield (seed, (1, 0), (0, 1)) + yield (seed, (2, 1), (1, 2)) + + +def _cases_singular_values_props() -> Iterable[tuple[int]]: + for seed in SEEDS_FAST: + yield (seed,) + + +def _cases_rank1_kron() -> Iterable[tuple[int]]: + for seed in SEEDS_FAST: + yield (seed,) + + +def _cases_truncation_meta() -> Iterable[tuple[int, int, tuple[int, ...]]]: + # (seed, n, subset_a) + for seed in SEEDS_FAST: + yield (seed, 3, (1,)) + yield (seed, 3, (0, 2)) + yield (seed, 4, (1, 2)) + + +def _cases_truncation_low_rank() -> Iterable[tuple[int]]: + # vary rank p; n=2 fixed inside + return [(2,), (3,)] + + +def _cases_permutation(seed_list=None) -> Iterable[tuple[int, tuple[int, ...]]]: + if seed_list is None: + seed_list = SEEDS_FAST + for seed in seed_list: + for subset_a in [(0,), (1,), (2,), (0, 2)]: + yield (seed, subset_a) + + +def _cases_k_validation() -> Iterable[tuple[int]]: + return [(0,), (-3,)] + + +# ------------------------- Main test class (fast set) -------------------------- + + +@ddt +class TestOperatorSchmidtDecomposition(QiskitTestCase): + """Fast test suite for OSD.""" + + @idata(list(_cases_exact_unitary())) + @unpack + def test_exact_reconstruction_random_unitary( + self, seed: int, n_qubits: int, subset_a: tuple[int, ...] + ): + """Exact reconstruction (full sum) for random unitaries.""" + unitary = np.array(random_unitary(2**n_qubits, seed=seed), dtype=complex) + out = operator_schmidt_decomposition(unitary, subset_a, return_reconstruction=True) + self.assertAlmostEqual(_fro_error(unitary, out["reconstruction"]), 0.0, delta=ATOL) + + @idata(list(_cases_exact_dense())) + @unpack + def test_exact_reconstruction_random_dense( + self, seed: int, n_qubits: int, subset_a: tuple[int, ...] + ): + """Exact reconstruction for random dense (nonunitary) operators.""" + rng = np.random.default_rng(seed) + dim = 2**n_qubits + op = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)) + out = operator_schmidt_decomposition(op, subset_a, return_reconstruction=True) + self.assertAlmostEqual(_fro_error(op, out["reconstruction"]), 0.0, delta=ATOL) + + @idata(list(_cases_qargs_order_irrelevant())) + @unpack + def test_qargs_order_irrelevant( + self, seed: int, subset_a_perm1: tuple[int, ...], subset_a_perm2: tuple[int, ...] + ): + """Singular values are invariant under reordering within the same subset.""" + n_qubits = 3 + unitary = np.array(random_unitary(2**n_qubits, seed=seed), dtype=complex) + out1 = operator_schmidt_decomposition(unitary, subset_a_perm1) + out2 = operator_schmidt_decomposition(unitary, subset_a_perm2) + npt.assert_allclose( + np.sort(out1["singular_values"]), + np.sort(out2["singular_values"]), + rtol=RTOL, + atol=ATOL, + ) + + @idata(list(_cases_singular_values_props())) + @unpack + def test_singular_values_properties(self, seed: int): + """SV sanity: nonnegative, descending, and Frobenius identity.""" + n_qubits = 3 + rng = np.random.default_rng(seed) + dim = 2**n_qubits + op = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)) + out = operator_schmidt_decomposition(op, (1,)) + sing_vals = out["singular_values"] + self.assertTrue(np.all(sing_vals >= -ATOL)) # nonnegative (within numerical noise) + self.assertTrue(np.all(sing_vals[:-1] + ATOL >= sing_vals[1:])) # descending + fro_sq = np.linalg.norm(op, ord="fro") ** 2 + self.assertAlmostEqual(np.sum(sing_vals**2), fro_sq, delta=max(ATOL, RTOL * abs(fro_sq))) + + @idata(list(_cases_singular_values_props())) + @unpack + def test_schmidt_factors_orthogonality(self, seed: int): + """Hilbert–Schmidt orthogonality and normalization of Schmidt factors.""" + n_qubits = 3 + rng = np.random.default_rng(seed) + dim = 2**n_qubits + op = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)) + subset_a = (0, 2) + out = operator_schmidt_decomposition(op, subset_a) + sing_vals = out["singular_values"] + a_factors = out["A_factors"] + b_factors = out["B_factors"] + + gram_a = np.array( + [[np.vdot(x_mat.ravel(), y_mat.ravel()) for y_mat in a_factors] for x_mat in a_factors] + ) + gram_b = np.array( + [[np.vdot(x_mat.ravel(), y_mat.ravel()) for y_mat in b_factors] for x_mat in b_factors] + ) + diag_s = np.diag(sing_vals) + npt.assert_allclose(gram_a, diag_s, rtol=RTOL, atol=ATOL) + npt.assert_allclose(gram_b, diag_s, rtol=RTOL, atol=ATOL) + + # Orthonormality after normalization + a_norm = [ + x_mat / np.sqrt(sing_vals[i]) if sing_vals[i] > 0 else x_mat + for i, x_mat in enumerate(a_factors) + ] + b_norm = [ + y_mat / np.sqrt(sing_vals[i]) if sing_vals[i] > 0 else y_mat + for i, y_mat in enumerate(b_factors) + ] + gram_a_n = np.array( + [[np.vdot(x_mat.ravel(), y_mat.ravel()) for y_mat in a_norm] for x_mat in a_norm] + ) + gram_b_n = np.array( + [[np.vdot(x_mat.ravel(), y_mat.ravel()) for y_mat in b_norm] for x_mat in b_norm] + ) + identity_mat = np.eye(len(sing_vals), dtype=complex) + npt.assert_allclose(gram_a_n, identity_mat, rtol=RTOL, atol=ATOL) + npt.assert_allclose(gram_b_n, identity_mat, rtol=RTOL, atol=ATOL) + + @idata(list(_cases_rank1_kron())) + @unpack + def test_rank1_kron_has_single_singular_value(self, seed: int): + """A ⊗ B has a single nonzero Schmidt value: ||A||_F * ||B||_F.""" + rng = np.random.default_rng(seed) + a_mat = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2)) + b_mat = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2)) + op = np.kron(a_mat, b_mat) + out = operator_schmidt_decomposition(op, qargs=[1]) + sing_vals = out["singular_values"] + s0 = np.linalg.norm(a_mat, ord="fro") * np.linalg.norm(b_mat, ord="fro") + self.assertAlmostEqual(sing_vals[0], s0, delta=max(ATOL, RTOL * abs(s0))) + if len(sing_vals) > 1: + npt.assert_allclose(sing_vals[1:], np.zeros_like(sing_vals[1:]), rtol=RTOL, atol=ATOL) + + @idata(list(_cases_k_validation())) + @unpack + def test_k_validation(self, k_bad: int): + """Non‑positive k should raise QiskitError (parameterized).""" + op = np.eye(4, dtype=complex) + with self.assertRaises(QiskitError): + operator_schmidt_decomposition(op, [0], k=k_bad) + + # --- Truncation & permutation --- + + @idata(list(_cases_truncation_meta())) + @unpack + def test_truncation_frobenius_optimality_and_metadata( + self, seed: int, n_qubits: int, subset_a: tuple[int, ...] + ): + """Top‑k truncation gives Frobenius‑optimal tail error and correct metadata.""" + rng = np.random.default_rng(seed) + dim = 2**n_qubits + op = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)) + + # Try multiple k values per case for robustness + for k_val in (1, 2, 3): + out = operator_schmidt_decomposition(op, subset_a, k=k_val, return_reconstruction=True) + sing_vals = out["singular_values"] + total_terms = len(sing_vals) + + kept = out["truncation"]["kept_terms"] + discarded = out["truncation"]["discarded_terms"] + self.assertEqual(kept, min(k_val, total_terms)) + self.assertEqual(discarded, max(0, total_terms - k_val)) + + expected_tail = np.sqrt(np.sum(sing_vals[k_val:] ** 2)) if k_val < total_terms else 0.0 + fro_err = out["truncation"]["frobenius_error"] + self.assertAlmostEqual( + fro_err, expected_tail, delta=max(ATOL, RTOL * max(1.0, expected_tail)) + ) + + denom = np.linalg.norm(sing_vals) + rel_expected = (expected_tail / denom) if denom > 0 else 0.0 + self.assertAlmostEqual( + out["truncation"]["relative_frobenius_error"], + rel_expected, + delta=max(ATOL, RTOL * max(1.0, rel_expected)), + ) + + # Reconstruction is in original order; its error equals tail error. + op_rec = out["reconstruction"] + self.assertIsInstance(op_rec, np.ndarray) + self.assertAlmostEqual( + _fro_error(op, op_rec), + expected_tail, + delta=max(ATOL, RTOL * max(1.0, expected_tail)), + ) + + @idata(list(_cases_truncation_low_rank())) + @unpack + def test_truncation_exact_low_rank_sum_of_krons(self, rank_terms: int): + """Operators with Schmidt rank p are reconstructed exactly when k >= p.""" + rng = np.random.default_rng(123 + rank_terms) # vary seed with p + a_list = [rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2)) for _ in range(rank_terms)] + b_list = [rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2)) for _ in range(rank_terms)] + op = sum(np.kron(a_mat, b_mat) for a_mat, b_mat in zip(a_list, b_list)) + + out_full = operator_schmidt_decomposition( + op, qargs=[1], k=rank_terms, return_reconstruction=True + ) + self.assertAlmostEqual(_fro_error(op, out_full["reconstruction"]), 0.0, delta=1e-10) + + out_k1 = operator_schmidt_decomposition(op, qargs=[1], k=1, return_reconstruction=True) + sing_vals = out_k1["singular_values"] + expected_tail = np.sqrt(np.sum(sing_vals[1:] ** 2)) if len(sing_vals) > 1 else 0.0 + if expected_tail > 0: + self.assertAlmostEqual( + _fro_error(op, out_k1["reconstruction"]), + expected_tail, + delta=max(ATOL, RTOL * max(1.0, expected_tail)), + ) + + @idata(list(_cases_permutation())) + @unpack + def test_permutation_new_order_and_matrix_contract(self, seed: int, subset_a: tuple[int, ...]): + """new_order == Sc + S and P maps U to the basis where sum kron(A,B) holds.""" + n_qubits = 3 + unitary = np.array(random_unitary(2**n_qubits, seed=seed), dtype=complex) + out = operator_schmidt_decomposition(unitary, subset_a) + part_info = out["partition"] + perm_info = out["permutation"] + + expected_order = tuple(part_info["Sc"]) + tuple(part_info["S"]) + self.assertEqual(tuple(perm_info["new_order"]), expected_order) + + perm_matrix = perm_info["matrix"] + self.assertEqual(perm_matrix.shape, (2**n_qubits, 2**n_qubits)) + npt.assert_allclose(perm_matrix @ perm_matrix.T, np.eye(2**n_qubits), rtol=RTOL, atol=ATOL) + npt.assert_allclose(perm_matrix.T @ perm_matrix, np.eye(2**n_qubits), rtol=RTOL, atol=ATOL) + self.assertTrue(np.all((np.abs(perm_matrix) < ATOL) | (np.abs(perm_matrix - 1) < ATOL))) + + # In the permuted basis, Up == sum_i kron(A_i, B_i) (full, untruncated case). + out_full = operator_schmidt_decomposition(unitary, subset_a, k=None) + a_factors = out_full["A_factors"] + b_factors = out_full["B_factors"] + up_from_factors = np.zeros_like(unitary, dtype=np.complex128) + for a_mat, b_mat in zip(a_factors, b_factors): + up_from_factors += np.kron(a_mat, b_mat) + up_direct = perm_matrix @ unitary @ perm_matrix.T + npt.assert_allclose(up_from_factors, up_direct, rtol=1e-11, atol=1e-11) + + +# ----------------------------- Stress tests (DDT) ------------------------------ + + +@ddt +class TestOperatorSchmidtDecompositionStress(QiskitTestCase): + """Stress tests (marked as @slow_test).""" + + @slow_test + @idata(list(_cases_permutation(SEEDS_STRESS))) + @unpack + def test_exact_reconstruction_unitary_stress(self, seed: int, subset_a: tuple[int, ...]): + """Stress: exact reconstruction over many seeds/partitions (unitary inputs).""" + n_qubits = 3 + unitary = np.array(random_unitary(2**n_qubits, seed=seed), dtype=complex) + out = operator_schmidt_decomposition(unitary, subset_a, return_reconstruction=True) + self.assertAlmostEqual(_fro_error(unitary, out["reconstruction"]), 0.0, delta=ATOL) + + @slow_test + @idata([(seed,) for seed in SEEDS_STRESS]) + @unpack + def test_singular_values_properties_stress(self, seed: int): + """Stress: SV nonnegativity/ordering and Frobenius identity (dense random inputs).""" + n_qubits = 3 + rng = np.random.default_rng(seed) + dim = 2**n_qubits + op = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)) + out = operator_schmidt_decomposition(op, [1]) + sing_vals = out["singular_values"] + self.assertTrue(np.all(sing_vals >= -ATOL)) + self.assertTrue(np.all(sing_vals[:-1] + ATOL >= sing_vals[1:])) + fro_sq = np.linalg.norm(op, ord="fro") ** 2 + self.assertAlmostEqual(np.sum(sing_vals**2), fro_sq, delta=max(ATOL, RTOL * abs(fro_sq))) + + +if __name__ == "__main__": + unittest.main()