diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index db3ddeecf1..4b0cea1235 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py
index d4fd755c4c..cdd0a9e5a9 100644
--- a/python/mlx/nn/init.py
+++ b/python/mlx/nn/init.py
@@ -395,3 +395,47 @@ def initializer(a: mx.array) -> mx.array:
         return a
 
     return initializer
+
+
+def orthogonal(
+    gain: float = 1.0, dtype: mx.Dtype = mx.float32
+) -> Callable[[mx.array], mx.array]:
+    r"""An initializer that returns an orthogonal matrix.
+
+    Args:
+        gain (float, optional): Scaling factor for the orthogonal matrix.
+            Default: ``1.0``.
+        dtype (Dtype, optional): Data type of the array. Default: ``float32``.
+
+    Returns:
+        Callable[[array], array]: An initializer that returns
+        an orthogonal matrix with the same shape as the input.
+    """
+
+    def initializer(a: mx.array) -> mx.array:
+        if a.ndim != 2:
+            raise ValueError(
+                "Orthogonal initialization requires a 2D array but got"
+                f" a {a.ndim}D array."
+            )
+
+        rows, cols = a.shape
+        n = max(rows, cols)
+
+        rmat = mx.random.normal(shape=(n, n))
+
+        # Perform the QR decomposition on the CPU stream
+        q, r = mx.linalg.qr(rmat, stream=mx.cpu)
+
+        # Fix the signs of Q's columns using the diagonal of R so the result is unique
+        d = mx.diag(r)
+        q = q * mx.sign(d)
+
+        # Slice Q down to the requested shape
+        q = q[:rows, :cols]
+
+        # Scale Q by the gain
+        q = q * gain
+        return q.astype(dtype)
+
+    return initializer
diff --git a/python/tests/test_init.py b/python/tests/test_init.py
index f2fa179fdd..4b209736fa 100644
--- a/python/tests/test_init.py
+++ b/python/tests/test_init.py
@@ -106,6 +106,34 @@ def test_sparse(self):
         with self.assertRaises(ValueError):
             result = initializer(mx.zeros((1,)))
 
+    def test_orthogonal(self):
+        initializer = init.orthogonal(gain=1.0, dtype=mx.float32)
+
+        # Test with a square matrix
+        shape = (4, 4)
+        result = initializer(mx.zeros(shape, dtype=mx.float32))
+        self.assertEqual(result.shape, shape)
+        self.assertEqual(result.dtype, mx.float32)
+
+        I = result @ result.T
+        eye = mx.eye(shape[0], dtype=mx.float32)
+        self.assertTrue(
+            mx.allclose(I, eye, atol=1e-5), "Orthogonal init failed on a square matrix."
+        )
+
+        # Test with a rectangular matrix: more rows than cols
+        shape = (6, 4)
+        result = initializer(mx.zeros(shape, dtype=mx.float32))
+        self.assertEqual(result.shape, shape)
+        self.assertEqual(result.dtype, mx.float32)
+
+        I = result.T @ result
+        eye = mx.eye(shape[1], dtype=mx.float32)
+        self.assertTrue(
+            mx.allclose(I, eye, atol=1e-5),
+            "Orthogonal init failed on a rectangular matrix.",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
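
For reference, a minimal usage sketch of the new initializer. It assumes `orthogonal` is exported from `mlx.nn.init` alongside the existing initializers, as the tests above do; the template array passed to the initializer only supplies the target shape.

```python
import mlx.core as mx
import mlx.nn as nn

# Build the initializer; the zeros array below is just a 2D shape template.
init_fn = nn.init.orthogonal(gain=1.0)
w = init_fn(mx.zeros((6, 4)))

# For a tall matrix the columns are orthonormal, so w.T @ w is close to
# the 4x4 identity (scaled by gain**2 for gains other than 1.0).
print(mx.allclose(w.T @ w, mx.eye(4), atol=1e-5))  # True
```

Generating an `n x n` matrix with `n = max(rows, cols)` and slicing, as the implementation does, keeps either the rows or the columns of the result orthonormal, matching the behavior of orthogonal initializers in other frameworks.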