Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU LU factorization and linear solvers #1451

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/src/python/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ Linear Algebra

.. currentmodule:: mlx.core.linalg

.. autosummary::
:toctree: _autosummary
.. autosummary::
:toctree: _autosummary

inv
tri_inv
Expand All @@ -18,3 +18,7 @@ Linear Algebra
svd
eigvalsh
eigh
lu
lu_factor
solve
solve_triangular
1 change: 1 addition & 0 deletions mlx/backend/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
Expand Down
88 changes: 88 additions & 0 deletions mlx/backend/cpu/luf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

void lu_factor_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices) {
int M = a.shape(-2);
int N = a.shape(-1);

// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
auto flags = lu.flags();
flags.col_contiguous = ndim == 2;
flags.row_contiguous = false;
flags.contiguous = true;
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);

auto a_ptr = lu.data<float>();

pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();

int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
MLX_LAPACK_FUNC(sgetrf)
(/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);

if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}

// Subtract 1 to get 0-based index
for (int j = 0; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}

// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
}

void LUF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}

} // namespace mlx::core
6 changes: 6 additions & 0 deletions mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ void Eigh::eval_gpu(
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
}

void LUF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}

void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)
NO_CPU(Minimum)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
Expand Down
132 changes: 131 additions & 1 deletion mlx/linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {

array tri_inv(
const array& a,
bool upper /* = true */,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
return inv_impl(a, /*tri=*/true, upper, s);
}
Expand Down Expand Up @@ -519,4 +519,134 @@ std::pair<array, array> eigh(
return std::make_pair(out[0], out[1]);
}

void validate_lu(
const array& a,
const StreamOrDevice& stream,
const std::string& fname) {
check_cpu_stream(stream, fname);
if (a.dtype() != float32) {
std::ostringstream msg;
msg << fname << " Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (a.ndim() < 2) {
std::ostringstream msg;
msg << fname
<< " Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(fname + " Only defined for square matrices.");
}
}

std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
int m = a.shape()[a.shape().size() - 2];
int n = a.shape()[a.shape().size() - 1];

Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
pivots_shape.push_back(std::min(m, n));

return array::make_arrays(
{a.shape(), pivots_shape, pivots_shape},
{a.dtype(), uint32, uint32},
std::make_shared<LUF>(to_stream(s)),
{astype(a, a.dtype(), s)});
}

std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
validate_lu(a, s, "[linalg::lu]");

auto out = lu_helper(a, s);
auto& LU = out[0];
auto& row_pivots = out[2];

int N = a.shape(-1);
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
auto U = triu(LU, /* k = */ 0, s);
return {row_pivots, L, U};
}

std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
validate_lu(a, s, "[linalg::lu_factor]");
auto out = lu_helper(a, s);
return std::make_pair(out[0], out[1]);
}

void validate_solve(
const array& a,
const array& b,
const StreamOrDevice& stream,
const std::string& fname) {
check_cpu_stream(stream, fname);
if (a.ndim() < 2) {
std::ostringstream msg;
msg << fname << " First input must have >= 2 dimensions. "
<< "Received array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (b.ndim() < 1) {
std::ostringstream msg;
msg << fname << " Second input must have >= 1 dimensions. "
<< "Received array with " << b.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
std::ostringstream msg;
msg << fname << " First input must be a square matrix. "
<< "Received array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}

int lastDim = b.ndim() > 1 ? -2 : -1;
if (a.shape(-1) != b.shape(lastDim)) {
std::ostringstream msg;
msg << fname << " Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}

auto out_type = promote_types(a.dtype(), b.dtype());
if (out_type != float32) {
std::ostringstream msg;
msg << fname << " Input arrays must promote to float32. Received arrays "
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
throw std::invalid_argument(msg.str());
}
}

array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
validate_solve(a, b, s, "[linalg::solve]");

// P, L, U matrices
const auto luf = lu(a, s);
auto perm = argsort(luf[0], -1, s);
int take_axis = -1;
if (b.ndim() >= 2) {
perm = expand_dims(perm, -1, s);
take_axis -= 1;
}
auto pb = take_along_axis(b, perm, take_axis);
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
return solve_triangular(luf[2], y, /* upper = */ true, s);
}

array solve_triangular(
const array& a,
const array& b,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
validate_solve(a, b, s, "[linalg::solve_triangular]");
auto a_inv = tri_inv(a, upper, s);
return matmul(a_inv, b, s);
}

} // namespace mlx::core::linalg
12 changes: 12 additions & 0 deletions mlx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

std::vector<array> lu(const array& a, StreamOrDevice s = {});

std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});

array solve(const array& a, const array& b, StreamOrDevice s = {});

array solve_triangular(
const array& a,
const array& b,
bool upper = false,
StreamOrDevice s = {});

/**
* Compute the cross product of two arrays along the given axis.
*/
Expand Down
13 changes: 12 additions & 1 deletion mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
: Primitive(stream),
uplo_(std::move(uplo)),
compute_eigenvectors_(compute_eigenvectors) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
Expand All @@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
bool compute_eigenvectors_;
};

/* LU Factorization primitive. */
class LUF : public Primitive {
public:
explicit LUF(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;

DEFINE_PRINT(LUF)
};

} // namespace mlx::core
Loading