Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scatter axis + gather axis primitives #1813

Merged
merged 3 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 301 additions & 20 deletions mlx/backend/common/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}

template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}

template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
Expand Down Expand Up @@ -169,14 +164,11 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
std::vector<array> inds(inputs.begin() + 1, inputs.end());

if (inds.empty()) {
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}

switch (inds[0].dtype()) {
case bool_:
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
break;
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
Expand All @@ -201,12 +193,142 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
}
template <typename T, typename IdxT>
void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);

strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);

auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
auto dst_ptr = out.data<T>();
auto ind_ax_stride = ind.strides(axis);
auto src_ax_stride = src.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto ind_ax_size = ind.shape(axis);
auto src_ax_size = src.shape(axis);

size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= ind.shape(i);
}
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
}
ind_it.step();
src_it.step();
}
dst_ptr += stride_pre;
}
}

template <typename IdxT>
void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break;
}
}

void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[Gather::eval] Cannot gather with floating point indices.");
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
}
Expand Down Expand Up @@ -296,14 +418,11 @@ void dispatch_scatter(
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
return;
}

switch (inds[0].dtype()) {
case bool_:
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
break;
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
Expand All @@ -328,12 +447,9 @@ void dispatch_scatter(
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
case float16:
case float32:
case bfloat16:
case complex64:
default:
throw std::runtime_error(
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
"[Scatter::eval_cpu] Cannot scatter with indices type.");
}
}

Expand All @@ -345,7 +461,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& updates = inputs.back();

// Copy src into out (copy allocates memory for out)
copy(src, out, CopyType::General);
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);

switch (src.dtype()) {
case bool_:
Expand Down Expand Up @@ -390,4 +508,167 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}

template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);

strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);

auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
auto dst_ptr = out.data<T>();
auto idx_ax_stride = idx.strides(axis);
auto upd_ax_stride = upd.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);

size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
dst_ptr += stride_pre;
}
}

template <typename InT, typename IdxT>
void dispatch_scatter_axis_op(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
break;
}
}

template <typename InT>
void dispatch_scatter_axis(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
break;
default:
throw std::runtime_error(
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
}
}

void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);

auto& src = inputs[0];
auto& idx = inputs[1];
auto& updates = inputs[2];

// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);

switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
}

} // namespace mlx::core
2 changes: 2 additions & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

make_jit_source(hadamard)

if(MLX_METAL_JIT)
Expand Down
Loading