Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Nov 5, 2023
1 parent 0aaa475 commit 2ead2e1
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions dummy-files/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,30 @@ def diagonal(
axis2: int = -1,
out: Optional[JaxArray] = None,
) -> JaxArray:
"""
Returns the diagonal of the input array.
Parameters
----------
x : array_like
Input array from which the diagonals are taken.
offset : int, optional
Offset of the diagonal from the main diagonal. Default is 0.
axis1 : int, optional
Axis to take the first diagonals from. Default is -2.
axis2 : int, optional
Axis to take the second diagonals from. Default is -1.
Returns
-------
ret : ndarray
The extracted diagonal or diagonals.
Raises
------
IvyError
If the input array has less than two dimensions.
"""
if x.dtype != bool and not jnp.issubdtype(x.dtype, jnp.integer):
ret = jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
ret_edited = jnp.diagonal(
Expand All @@ -161,6 +185,30 @@ def tensorsolve(
axes: Optional[Union[int, Tuple[Sequence[int], Sequence[int]]]] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
"""
Solves the tensor equation ``a x = b`` for x.
Parameters
----------
a : array_like
Coefficient tensor, of shape (M, N, ..., M, N)
b : array_like
Right-hand side tensor, of shape (M, N, ..., Q)
axes : 2-tuple of lists of ints, optional
Axes in `a` to apply solve to (for each vector in `b`).
Default is ``(-2, -1)``
Returns
-------
x : ndarray
Solution tensor, shape ``(M, N, ..., Q)``.
Raises
------
LinAlgError
If `a` is singular or not `a.shape[axis] == b.shape[axis]` for all
`axis` in ``axes[0]``
"""
return jnp.linalg.tensorsolve(x1, x2, axes)


Expand All @@ -171,6 +219,28 @@ def tensorsolve(
def eigh(
x: JaxArray, /, *, UPLO: str = "L", out: Optional[JaxArray] = None
) -> Tuple[JaxArray]:
"""Computes the eigenvalues and eigenvectors of a Hermitian or real symmetric matrix.
Parameters
----------
x : array_like
Matrix whose eigenvalues and eigenvectors will be computed.
UPLO : {'L', 'U'}, optional
Specifies whether the calculation is done with the lower triangular part of matrix ('L', default) or the upper triangular part ('U').
out: tuple of arrays, optional
Tuple of output arrays. The first array contains the eigenvalues and the second contains the eigenvectors.
Returns
-------
eigenvalues : ndarray
Array containing the eigenvalues of the input matrix.
eigenvectors : ndarray
Array containing the eigenvectors of the input matrix.
"""
result_tuple = NamedTuple(
"eigh", [("eigenvalues", JaxArray), ("eigenvectors", JaxArray)]
)
Expand All @@ -185,11 +255,67 @@ def eigh(
def eigvalsh(
x: JaxArray, /, *, UPLO: str = "L", out: Optional[JaxArray] = None
) -> JaxArray:
"""
Computes the eigenvalues of a complex Hermitian or real symmetric matrix.
Parameters
----------
x : array_like
Input array. Must be a square 2-D array.
UPLO : {'L', 'U'}, optional
Specifies whether the calculation is done with the lower triangular part of
`x` ('L', default) or the upper triangular part ('U').
out : ndarray, optional
A location in which to store the results. If provided, it must have the same
shape as the eigenvalues.
Returns
-------
eigenvalues : ndarray
The eigenvalues, each repeated according to its multiplicity.
Raises
------
LinAlgError
If the eigenvalue computation does not converge.
Examples
--------
>>> x = np.array([[1, -2j], [2j, 5]])
>>> ivy.eigvalsh(x)
array([ 0.+2.23606802j, 0.-2.23606802j])
"""
return jnp.linalg.eigvalsh(x, UPLO=UPLO)


@with_unsupported_dtypes({"0.4.19 and below": ("complex",)}, backend_version)
def inner(x1: JaxArray, x2: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
"""Computes the inner product of two arrays.
Parameters
----------
x1 : array_like
First array to compute the inner product against.
x2 : array_like
Second array to compute the inner product against.
out : optional
Output array. If not provided, a new array will be created.
Returns
-------
ret : ndarray
Inner product of `x1` and `x2`.
Examples
--------
>>> x1 = [1, 2, 3]
>>> x2 = [4, 5, 6]
>>> inner(x1, x2)
32
"""
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
return jnp.inner(x1, x2)

Expand Down

0 comments on commit 2ead2e1

Please sign in to comment.