From 2235dee906e6f8dc0c25a235287452148d8bd0bd Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 27 Jan 2025 14:05:43 -0800 Subject: [PATCH] catch stream errors earlier to avoid aborts (#1801) --- mlx/linalg.cpp | 31 +++++++++++++++++++++++++------ python/mlx/nn/init.py | 2 +- tests/random_tests.cpp | 3 +++ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 989dae52fd..c4a21a881a 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -10,6 +10,15 @@ namespace mlx::core::linalg { +void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) { + if (to_stream(s).device == Device::gpu) { + throw std::invalid_argument( + prefix + + " This op is not yet supported on the GPU. " + "Explicitly pass a CPU stream to run it."); + } +} + Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } @@ -174,6 +183,7 @@ array norm( } std::pair qr(const array& a, StreamOrDevice s /* = {} */) { + check_cpu_stream(s, "[linalg::qr]"); if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::qr] Arrays must type float32. Received array " @@ -201,6 +211,7 @@ std::pair qr(const array& a, StreamOrDevice s /* = {} */) { } std::vector svd(const array& a, StreamOrDevice s /* = {} */) { + check_cpu_stream(s, "[linalg::svd]"); if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::svd] Input array must have type float32. Received array " @@ -239,6 +250,7 @@ std::vector svd(const array& a, StreamOrDevice s /* = {} */) { } array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) { + check_cpu_stream(s, "[linalg::inv]"); if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::inv] Arrays must type float32. Received array " @@ -279,6 +291,7 @@ array cholesky( const array& a, bool upper /* = false */, StreamOrDevice s /* = {} */) { + check_cpu_stream(s, "[linalg::cholesky]"); if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::cholesky] Arrays must type float32. Received array " @@ -307,6 +320,7 @@ array cholesky( } array pinv(const array& a, StreamOrDevice s /* = {} */) { + check_cpu_stream(s, "[linalg::pinv]"); if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::pinv] Arrays must type float32. Received array " @@ -353,16 +367,17 @@ array cholesky_inv( const array& L, bool upper /* = false */, StreamOrDevice s /* = {} */) { + check_cpu_stream(s, "[linalg::cholesky_inv]"); if (L.dtype() != float32) { std::ostringstream msg; - msg << "[linalg::cholesky] Arrays must type float32. Received array " + msg << "[linalg::cholesky_inv] Arrays must type float32. Received array " << "with type " << L.dtype() << "."; throw std::invalid_argument(msg.str()); } if (L.ndim() < 2) { std::ostringstream msg; - msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array " + msg << "[linalg::cholesky_inv] Arrays must have >= 2 dimensions. Received array " "with " << L.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); @@ -370,7 +385,7 @@ array cholesky_inv( if (L.shape(-1) != L.shape(-2)) { throw std::invalid_argument( - "[linalg::cholesky] Cholesky inverse is only defined for square " + "[linalg::cholesky_inv] Cholesky inverse is only defined for square " "matrices."); } @@ -454,7 +469,11 @@ array cross( return concatenate(outputs, axis, s); } -void validate_eigh(const array& a, const std::string fname) { +void validate_eigh( + const array& a, + const StreamOrDevice& stream, + const std::string fname) { + check_cpu_stream(stream, fname); if (a.dtype() != float32) { std::ostringstream msg; msg << fname << " Arrays must have type float32. Received array " @@ -478,7 +497,7 @@ array eigvalsh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, "[linalg::eigvalsh]"); + validate_eigh(a, s, "[linalg::eigvalsh]"); Shape out_shape(a.shape().begin(), a.shape().end() - 1); return array( std::move(out_shape), @@ -491,7 +510,7 @@ std::pair eigh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, "[linalg::eigh]"); + validate_eigh(a, s, "[linalg::eigh]"); auto out = array::make_arrays( {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, {a.dtype(), a.dtype()}, diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index cdd0a9e5a9..68a9221c45 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -385,7 +385,7 @@ def initializer(a: mx.array) -> mx.array: raise ValueError("Only tensors with 2 dimensions are supported") rows, cols = a.shape - num_zeros = int(mx.ceil(sparsity * cols)) + num_zeros = int(math.ceil(sparsity * cols)) order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1) a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 9d51c82b1c..d6a3815247 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -421,6 +421,9 @@ TEST_CASE("test random normal") { } TEST_CASE("test random multivariate_normal") { + // Scope switch to the cpu for SVDs + StreamContext sc(Device::cpu); + { auto mean = zeros({3}); auto cov = eye(3);