Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Jan 13, 2025
1 parent bb4ca39 commit 7fdd830
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions python/mlx/nn/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,41 +400,29 @@ def initializer(a: mx.array) -> mx.array:
def orthogonal(
gain: float = 1.0, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
r"""An orthogonal initializer.
Generates a 2D orthogonal matrix by:
1. Determine a square size based on max(rows, cols).
2. Sampling a random square matrix from a normal distribution.
3. Performing a CPU-based QR decomposition on this square matrix.
4. Adjusting the sign of Q to ensure a unique orthogonal matrix.
5. Slicing Q to the desired shape (rows, cols) if non-square.
6. Scaling by `gain`.
r"""An initializer that returns an orthogonal matrix.
Args:
gain (float, optional): Scaling factor for the orthogonal matrix.
Defaults to 1.0.
dtype (Dtype, optional): Data type of the array. Defaults to float32.
Default: ``1.0``.
dtype (Dtype, optional): Data type of the array. Default: ``float32``.
Returns:
Callable[[mx.array], mx.array]: An initializer function that produces
an orthogonal matrix.
Callable[[array], array]: An initializer that returns
an orthogonal matrix with the same shape as the input.
"""

def initializer(a: mx.array) -> mx.array:
if a.ndim < 2:
raise ValueError(
f"Orthogonal initialization requires at least 2D tensor but got {a.ndim}D."
)
if a.ndim > 2:
if a.ndim != 2:
raise ValueError(
"Orthogonal initialization currently only supports 2D arrays."
"Orthogonal initialization requires a 2D array but got"
f" a {a.ndim}D array."
)

rows, cols = a.shape
n = max(rows, cols)

# Generate a square random matrix
rmat = mx.random.normal(shape=(n, n), dtype=dtype)
rmat = mx.random.normal(shape=(n, n))

# Perform QR decomposition on CPU
q, r = mx.linalg.qr(rmat, stream=mx.cpu)
Expand All @@ -448,6 +436,6 @@ def initializer(a: mx.array) -> mx.array:

# Scale Q by gain
q = q * gain
return q
return q.astype(dtype)

return initializer

0 comments on commit 7fdd830

Please sign in to comment.