Skip to content

Commit

Permalink
catch stream errors earlier to avoid aborts (#1801)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Jan 27, 2025
1 parent 28091aa commit 2235dee
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
31 changes: 25 additions & 6 deletions mlx/linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@

namespace mlx::core::linalg {

void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
if (to_stream(s).device == Device::gpu) {
throw std::invalid_argument(
prefix +
" This op is not yet supported on the GPU. "
"Explicitly pass a CPU stream to run it.");
}
}

Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
Expand Down Expand Up @@ -174,6 +183,7 @@ array norm(
}

std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::qr]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::qr] Arrays must type float32. Received array "
Expand Down Expand Up @@ -201,6 +211,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
}

std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have type float32. Received array "
Expand Down Expand Up @@ -239,6 +250,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
}

array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
check_cpu_stream(s, "[linalg::inv]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::inv] Arrays must type float32. Received array "
Expand Down Expand Up @@ -279,6 +291,7 @@ array cholesky(
const array& a,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::cholesky]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must type float32. Received array "
Expand Down Expand Up @@ -307,6 +320,7 @@ array cholesky(
}

array pinv(const array& a, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::pinv]");
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::pinv] Arrays must type float32. Received array "
Expand Down Expand Up @@ -353,24 +367,25 @@ array cholesky_inv(
const array& L,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::cholesky_inv]");
if (L.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must type float32. Received array "
msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
<< "with type " << L.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (L.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
msg << "[linalg::cholesky_inv] Arrays must have >= 2 dimensions. Received array "
"with "
<< L.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (L.shape(-1) != L.shape(-2)) {
throw std::invalid_argument(
"[linalg::cholesky] Cholesky inverse is only defined for square "
"[linalg::cholesky_inv] Cholesky inverse is only defined for square "
"matrices.");
}

Expand Down Expand Up @@ -454,7 +469,11 @@ array cross(
return concatenate(outputs, axis, s);
}

void validate_eigh(const array& a, const std::string fname) {
void validate_eigh(
const array& a,
const StreamOrDevice& stream,
const std::string fname) {
check_cpu_stream(stream, fname);
if (a.dtype() != float32) {
std::ostringstream msg;
msg << fname << " Arrays must have type float32. Received array "
Expand All @@ -478,7 +497,7 @@ array eigvalsh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigvalsh]");
validate_eigh(a, s, "[linalg::eigvalsh]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
Expand All @@ -491,7 +510,7 @@ std::pair<array, array> eigh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigh]");
validate_eigh(a, s, "[linalg::eigh]");
auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
Expand Down
2 changes: 1 addition & 1 deletion python/mlx/nn/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def initializer(a: mx.array) -> mx.array:
raise ValueError("Only tensors with 2 dimensions are supported")

rows, cols = a.shape
num_zeros = int(mx.ceil(sparsity * cols))
num_zeros = int(math.ceil(sparsity * cols))

order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
Expand Down
3 changes: 3 additions & 0 deletions tests/random_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ TEST_CASE("test random normal") {
}

TEST_CASE("test random multivariate_normal") {
// Scope switch to the cpu for SVDs
StreamContext sc(Device::cpu);

{
auto mean = zeros({3});
auto cov = eye(3);
Expand Down

0 comments on commit 2235dee

Please sign in to comment.