linalg solve backend
abeleinin committed Oct 3, 2024
1 parent 5900e32 commit 45dbbca
Showing 15 changed files with 340 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/src/python/linalg.rst
@@ -16,3 +16,4 @@ Linear Algebra
cross
qr
svd
solve
1 change: 1 addition & 0 deletions mlx/backend/accelerate/primitives.cpp
@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Solve)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
@@ -51,6 +51,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/solve.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

1 change: 1 addition & 0 deletions mlx/backend/common/default_primitives.cpp
@@ -114,6 +114,7 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Solve)

namespace {

131 changes: 131 additions & 0 deletions mlx/backend/common/solve.cpp
@@ -0,0 +1,131 @@
// Copyright © 2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

#include <cassert>

namespace mlx::core {

namespace {

// Wrapper to account for differences between LAPACK implementations
// (chiefly, how the 'trans' character is passed to the Fortran routine).
int sgetrs_wrapper(char trans, int N, int NRHS, int* ipiv, float* a, float* b) {
int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
sgetrs_(
/* trans */ &trans,
/* n */ &N,
/* nrhs */ &NRHS,
/* a */ a,
/* lda */ &N,
/* ipiv */ ipiv,
/* b */ b,
/* ldb */ &N,
/* info */ &info,
/* trans_len = */ static_cast<size_t>(1));
#else
sgetrs_(
/* trans */ &trans,
/* n */ &N,
/* nrhs */ &NRHS,
/* a */ a,
/* lda */ &N,
/* ipiv */ ipiv,
/* b */ b,
/* ldb */ &N,
/* info */ &info);
#endif

return info;
}

} // namespace

void solve_impl(const array& a, const array& b, array& out) {
int N = a.shape(-2);
int NRHS = out.shape(-1);
std::vector<int> ipiv(N);

// Copy b into out and make it col-contiguous
auto flags = out.flags();
flags.col_contiguous = true;
flags.row_contiguous = false;
std::vector<size_t> strides(a.ndim(), 0);
std::copy(out.strides().begin(), out.strides().end(), strides.begin());
strides[a.ndim() - 2] = 1;
strides[a.ndim() - 1] = N;

out.set_data(
allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags);
copy_inplace(b, out, CopyType::GeneralGeneral);

// LAPACK clobbers the input, so we have to make a copy. The copy doesn't need
// to be col-contiguous because sgetrs has a transpose parameter (trans='T').
array a_cpy(a.shape(), float32, nullptr, {});
copy(
a,
a_cpy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

float* a_ptr = a_cpy.data<float>();
float* out_ptr = out.data<float>();
int* ipiv_ptr = ipiv.data();

int info;
size_t num_matrices = a.size() / (N * N);
for (size_t i = 0; i < num_matrices; i++) {
// Compute LU factorization of A
MLX_LAPACK_FUNC(sgetrf)
(/* m */ &N,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &N,
/* ipiv */ ipiv_ptr,
/* info */ &info);

if (info != 0) {
std::stringstream ss;
ss << "solve_impl: sgetrf_ failed with code " << info
<< ((info > 0) ? " because the matrix is singular"
: " because an argument had an illegal value");
throw std::runtime_error(ss.str());
}

static constexpr char trans = 'T';
// Solve the system using the LU factors from sgetrf
info = sgetrs_wrapper(trans, N, NRHS, ipiv_ptr, a_ptr, out_ptr);

if (info != 0) {
std::stringstream ss;
ss << "solve_impl: sgetrs_ failed with code " << info;
throw std::runtime_error(ss.str());
}

// Advance pointers to the next matrix
a_ptr += N * N;
out_ptr += N * NRHS;
}
}

void Solve::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
if (inputs[0].dtype() != float32 || inputs[1].dtype() != float32) {
throw std::runtime_error("[Solve::eval] only supports float32.");
}
solve_impl(inputs[0], inputs[1], outputs[0]);
}

} // namespace mlx::core
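To make the trans='T' trick concrete: a row-major A buffer read column-major by LAPACK holds A^T, so sgetrf factors A^T and sgetrs with trans='T' then solves (A^T)^T X = AX = B. A minimal SciPy sketch of the same getrf/getrs sequence (illustration only, not part of this commit; lu_factor/lu_solve wrap those routines):

import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)).astype(np.float32)
B = rng.standard_normal((4, 2)).astype(np.float32)

# A row-major A buffer seen column-major is A.T, so factor A.T ...
lu, piv = lu_factor(A.T)
# ... and request the transposed solve: (A.T).T X = A X = B.
X = lu_solve((lu, piv), B, trans=1)

assert np.allclose(A @ X, B, atol=1e-4)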
6 changes: 6 additions & 0 deletions mlx/backend/metal/primitives.cpp
@@ -432,4 +432,10 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}

void Solve::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Solve::eval_gpu] Metal Solve NYI.");
}

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
@@ -108,5 +108,6 @@ NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Inverse)
NO_CPU(View)
NO_CPU_MULTI(Solve)

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_metal/primitives.cpp
@@ -111,6 +111,7 @@ NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU(View)
NO_GPU_MULTI(Solve)

namespace fast {
NO_GPU_MULTI(LayerNorm)
43 changes: 43 additions & 0 deletions mlx/linalg.cpp
@@ -454,4 +454,47 @@ array cross(
return concatenate(outputs, axis, s);
}

array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (a.dtype() != float32 && b.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::solve] Input array must have type float32. Received array "
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::solve] Arrays must have >= 2 dimensions. Received array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (b.ndim() < 1) {
std::ostringstream msg;
msg << "[linalg::solve] Array must have >= 1 dimension. Received array with "
<< b.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
std::ostringstream msg;
msg << "[linalg::solve] First input must be a square matrix. Received array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != b.shape(b.ndim() - 2)) {
std::ostringstream msg;
msg << "[linalg::solve] Last dimension of first input with shape "
<< a.shape() << " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}

auto out_type = promote_types(a.dtype(), b.dtype());

return array(
b.shape(), out_type, std::make_shared<Solve>(to_stream(s)), {a, b});
}

} // namespace mlx::core::linalg
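A sketch of the shape contract enforced above (hypothetical usage via the Python binding added later in this commit): ``a`` is ``(..., N, N)``, ``b`` is ``(..., N, K)``, and the output takes ``b``'s shape.

import mlx.core as mx

a = mx.random.uniform(shape=(2, 3, 3))    # batch of square matrices
b = mx.random.uniform(shape=(2, 3, 4))    # matching right-hand sides
x = mx.linalg.solve(a, b, stream=mx.cpu)  # shape (2, 3, 4), float32
print(mx.allclose(a @ x, b, atol=1e-4))   # True

# A trailing-dimension mismatch, e.g. b with shape (2, 4, 4), raises
# the invalid_argument error above.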
2 changes: 2 additions & 0 deletions mlx/linalg.h
@@ -74,6 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

array solve(const array& a, const array& b, StreamOrDevice s = {});

/**
* Compute the cross product of two arrays along the given axis.
*/
11 changes: 11 additions & 0 deletions mlx/primitives.cpp
@@ -4110,6 +4110,17 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}

std::pair<std::vector<array>, std::vector<int>> Solve::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
return {{linalg::solve(a, b, stream())}, {0}};
}
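A sketch of what this rule enables from Python (illustrative, not part of the commit; the stream is pinned to the CPU since the Metal kernel is NYI):

import mlx.core as mx

a = mx.random.uniform(shape=(8, 3, 3))
b = mx.random.uniform(shape=(8, 3, 2))

# Mapped axes are moved to the front, then the batched solve runs once.
solve_cpu = lambda a, b: mx.linalg.solve(a, b, stream=mx.cpu)
x = mx.vmap(solve_cpu, in_axes=(0, 0))(a, b)
print(x.shape)  # (8, 3, 2)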

std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
16 changes: 16 additions & 0 deletions mlx/primitives.h
@@ -2168,4 +2168,20 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};

class Solve : public Primitive {
public:
explicit Solve(Stream stream) : Primitive(stream) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;

DEFINE_VMAP()
DEFINE_PRINT(Solve)

private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};

} // namespace mlx::core
21 changes: 21 additions & 0 deletions python/src/linalg.cpp
@@ -405,4 +405,25 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: The cross product of ``a`` and ``b`` along the specified axis.
)pbdoc");
m.def(
"solve",
&solve,
"a"_a,
"b"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the solution to a square system of linear equations AX = B.

Args:
a (array): The square coefficient matrix ``A``.
b (array): The right-hand side ``B``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: The unique solution to the system AX = B.
)pbdoc");
}
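A small worked example of the new binding (illustrative; mirrors the tests below):

import mlx.core as mx

a = mx.array([[3.0, 1.0], [1.0, 2.0]])
b = mx.array([9.0, 8.0])
x = mx.linalg.solve(a, b, stream=mx.cpu)
print(x)  # [2, 3]: 3*2 + 1*3 = 9 and 1*2 + 2*3 = 8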
54 changes: 54 additions & 0 deletions python/tests/test_linalg.py
@@ -268,6 +268,60 @@ def test_cross_product(self):
with self.assertRaises(ValueError):
mx.linalg.cross(a, b)

def test_solve(self):
mx.random.seed(7)

# Test 3x3 matrix with 1D rhs
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
b = mx.array([11.0, 35.0, 28.0])

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected))

# Test symmetric positive-definite matrix
N = 5
a = mx.random.uniform(shape=(N, N))
a = mx.matmul(a, a.T) + N * mx.eye(N)
b = mx.random.uniform(shape=(N, 1))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))

# Test batch dimension
a = mx.random.uniform(shape=(5, 5, 4, 4))
b = mx.random.uniform(shape=(5, 5, 4, 1))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))

# Test large matrix
N = 1000
a = mx.random.uniform(shape=(N, N))
b = mx.random.uniform(shape=(N, 1))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-2))

# Test multi-column rhs
a = mx.random.uniform(shape=(5, 5))
b = mx.random.uniform(shape=(5, 8))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))

# Test batched multi-column rhs
a = mx.concat([a, a, a, a, a, a]).reshape((3, 2, 5, 5))
b = mx.concat([b, b, b, b, b, b]).reshape((3, 2, 5, 8))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))


if __name__ == "__main__":
unittest.main()