diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index 153669d50..cdd0a9e5a 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -400,41 +400,29 @@ def initializer(a: mx.array) -> mx.array: def orthogonal( gain: float = 1.0, dtype: mx.Dtype = mx.float32 ) -> Callable[[mx.array], mx.array]: - r"""An orthogonal initializer. - - Generates a 2D orthogonal matrix by: - 1. Determine a square size based on max(rows, cols). - 2. Sampling a random square matrix from a normal distribution. - 3. Performing a CPU-based QR decomposition on this square matrix. - 4. Adjusting the sign of Q to ensure a unique orthogonal matrix. - 5. Slicing Q to the desired shape (rows, cols) if non-square. - 6. Scaling by `gain`. + r"""An initializer that returns an orthogonal matrix. Args: gain (float, optional): Scaling factor for the orthogonal matrix. - Defaults to 1.0. - dtype (Dtype, optional): Data type of the array. Defaults to float32. + Default: ``1.0``. + dtype (Dtype, optional): Data type of the array. Default: ``float32``. Returns: - Callable[[mx.array], mx.array]: An initializer function that produces - an orthogonal matrix. + Callable[[array], array]: An initializer that returns + an orthogonal matrix with the same shape as the input. """ def initializer(a: mx.array) -> mx.array: - if a.ndim < 2: - raise ValueError( - f"Orthogonal initialization requires at least 2D tensor but got {a.ndim}D." - ) - if a.ndim > 2: + if a.ndim != 2: raise ValueError( - "Orthogonal initialization currently only supports 2D arrays." + "Orthogonal initialization requires a 2D array but got" + f" a {a.ndim}D array." 
) rows, cols = a.shape n = max(rows, cols) - # Generate a square random matrix - rmat = mx.random.normal(shape=(n, n), dtype=dtype) + rmat = mx.random.normal(shape=(n, n)) # Perform QR decomposition on CPU q, r = mx.linalg.qr(rmat, stream=mx.cpu) @@ -448,6 +436,6 @@ def initializer(a: mx.array) -> mx.array: # Scale Q by gain q = q * gain - return q + return q.astype(dtype) return initializer