diff --git a/curvlinops/utils.py b/curvlinops/utils.py index e06ca77..6c60455 100644 --- a/curvlinops/utils.py +++ b/curvlinops/utils.py @@ -5,8 +5,6 @@ from numpy import cumsum from torch import Tensor -from curvlinops._torch_base import PyTorchLinearOperator - def split_list(x: Union[List, Tuple], sizes: List[int]) -> List[List]: """Split a list into multiple lists of specified size. @@ -70,11 +68,11 @@ def allclose_report( return close -def assert_is_square(A: Union[Tensor, PyTorchLinearOperator]) -> int: +def assert_is_square(A) -> int: """Assert that a matrix or linear operator is square. Args: - A: Matrix or linear operator to be checked. + A: Matrix or linear operator to be checked. Must have a ``.shape`` attribute. Returns: The dimension of the square matrix. @@ -88,13 +86,11 @@ def assert_is_square(A: Union[Tensor, PyTorchLinearOperator]) -> int: return dim -def assert_matvecs_subseed_dim( - A: Union[Tensor, PyTorchLinearOperator], num_matvecs: int -): +def assert_matvecs_subseed_dim(A, num_matvecs: int): """Assert that the number of matrix-vector products is smaller than the dimension. Args: - A: Matrix or linear operator to be checked. + A: Matrix or linear operator to be checked. Must have a ``.shape`` attribute. num_matvecs: Number of matrix-vector products. Raises: