From 6fb1fffdc87bb8cb5f1704e4006e6825b109d330 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 30 Jan 2025 16:26:06 -0800 Subject: [PATCH 1/3] scatter axis + gather axis primitives --- mlx/backend/common/indexing.cpp | 315 +++++++++++++++++++++-- mlx/backend/metal/CMakeLists.txt | 2 + mlx/backend/metal/indexing.cpp | 214 +++++++++++++++ mlx/backend/metal/jit/includes.h | 2 + mlx/backend/metal/kernels/gather_axis.h | 44 ++++ mlx/backend/metal/kernels/scatter_axis.h | 52 ++++ mlx/backend/no_cpu/primitives.cpp | 2 + mlx/backend/no_metal/primitives.cpp | 2 + mlx/ops.cpp | 85 +++--- mlx/primitives.cpp | 57 +++- mlx/primitives.h | 56 ++++ python/tests/test_ops.py | 9 + 12 files changed, 778 insertions(+), 62 deletions(-) create mode 100644 mlx/backend/metal/kernels/gather_axis.h create mode 100644 mlx/backend/metal/kernels/scatter_axis.h diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp index b0e354e325..29828447e4 100644 --- a/mlx/backend/common/indexing.cpp +++ b/mlx/backend/common/indexing.cpp @@ -16,11 +16,6 @@ inline size_t offset_neg_idx(IdxT idx, size_t size) { return (idx < 0) ? idx + size : idx; } -template <> -inline size_t offset_neg_idx(bool idx, size_t) { - return idx; -} - template <> inline size_t offset_neg_idx(uint32_t idx, size_t) { return idx; @@ -169,14 +164,11 @@ void Gather::eval_cpu(const std::vector& inputs, array& out) { std::vector inds(inputs.begin() + 1, inputs.end()); if (inds.empty()) { - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_); return; } switch (inds[0].dtype()) { - case bool_: - dispatch_gather(src, inds, out, axes_, slice_sizes_); - break; case uint8: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; @@ -201,12 +193,142 @@ void Gather::eval_cpu(const std::vector& inputs, array& out) { case int64: dispatch_gather(src, inds, out, axes_, slice_sizes_); break; + default: + throw std::runtime_error( + "[Gather::eval_cpu] Cannot gather with indices type."); + break; + } +} +template +void gather_axis( + const array& src, + const array& ind, + array& out, + const int axis) { + auto strides = ind.strides(); + strides.erase(strides.begin() + axis); + auto shape = ind.shape(); + shape.erase(shape.begin() + axis); + ContiguousIterator ind_it(shape, strides, src.ndim() - 1); + + strides = src.strides(); + strides.erase(strides.begin() + axis); + ContiguousIterator src_it(shape, strides, src.ndim() - 1); + + auto ind_ptr = ind.data(); + auto src_ptr = src.data(); + auto dst_ptr = out.data(); + auto ind_ax_stride = ind.strides(axis); + auto src_ax_stride = src.strides(axis); + auto dst_ax_stride = out.strides(axis); + auto ind_ax_size = ind.shape(axis); + auto src_ax_size = src.shape(axis); + + size_t size_pre = 1; + size_t size_post = 1; + for (int i = 0; i < axis; ++i) { + size_pre *= ind.shape(i); + } + for (int i = axis + 1; i < ind.ndim(); ++i) { + size_post *= ind.shape(i); + } + size_t stride_pre = size_post * ind_ax_size; + for (size_t i = 0; i < size_pre; i++) { + for (size_t k = 0; k < size_post; k++) { + for (int j = 0; j < ind_ax_size; ++j) { + auto ind_val = offset_neg_idx( + ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size); + dst_ptr[k + j * dst_ax_stride] = + src_ptr[src_it.loc + ind_val * src_ax_stride]; + } + ind_it.step(); + src_it.step(); + } + dst_ptr += stride_pre; + } +} + +template +void dispatch_gather_axis( + const array& src, + const array& inds, + array& out, + const int axis) { + switch (out.dtype()) { + case bool_: + 
gather_axis(src, inds, out, axis); + break; + case uint8: + gather_axis(src, inds, out, axis); + break; + case uint16: + gather_axis(src, inds, out, axis); + break; + case uint32: + gather_axis(src, inds, out, axis); + break; + case uint64: + gather_axis(src, inds, out, axis); + break; + case int8: + gather_axis(src, inds, out, axis); + break; + case int16: + gather_axis(src, inds, out, axis); + break; + case int32: + gather_axis(src, inds, out, axis); + break; + case int64: + gather_axis(src, inds, out, axis); + break; case float16: + gather_axis(src, inds, out, axis); + break; case float32: + gather_axis(src, inds, out, axis); + break; case bfloat16: + gather_axis(src, inds, out, axis); + break; case complex64: + gather_axis(src, inds, out, axis); + break; + } +} + +void GatherAxis::eval_cpu(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& src = inputs[0]; + auto& inds = inputs[1]; + switch (inds.dtype()) { + case uint8: + dispatch_gather_axis(src, inds, out, axis_); + break; + case uint16: + dispatch_gather_axis(src, inds, out, axis_); + break; + case uint32: + dispatch_gather_axis(src, inds, out, axis_); + break; + case uint64: + dispatch_gather_axis(src, inds, out, axis_); + break; + case int8: + dispatch_gather_axis(src, inds, out, axis_); + break; + case int16: + dispatch_gather_axis(src, inds, out, axis_); + break; + case int32: + dispatch_gather_axis(src, inds, out, axis_); + break; + case int64: + dispatch_gather_axis(src, inds, out, axis_); + break; + default: throw std::runtime_error( - "[Gather::eval] Cannot gather with floating point indices."); + "[GatherAxis::eval_cpu] Cannot gather with indices type."); break; } } @@ -296,14 +418,11 @@ void dispatch_scatter( const std::vector& axes, Scatter::ReduceType rtype) { if (inds.empty()) { - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds(out, inds, updates, axes, rtype); return; } switch (inds[0].dtype()) { - case bool_: - dispatch_scatter_inds(out, inds, updates, axes, rtype); - break; case uint8: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; @@ -328,12 +447,9 @@ void dispatch_scatter( case int64: dispatch_scatter_inds(out, inds, updates, axes, rtype); break; - case float16: - case float32: - case bfloat16: - case complex64: + default: throw std::runtime_error( - "[Scatter::eval_cpu] Cannot scatter with floating point indices."); + "[Scatter::eval_cpu] Cannot scatter with indices type."); } } @@ -390,4 +506,165 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { } } +template +void scatter_axis( + array& out, + const array idx, + const array& upd, + int axis, + const OpT& op) { + auto strides = idx.strides(); + strides.erase(strides.begin() + axis); + auto shape = idx.shape(); + shape.erase(shape.begin() + axis); + ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); + + strides = upd.strides(); + strides.erase(strides.begin() + axis); + ContiguousIterator upd_it(shape, strides, upd.ndim() - 1); + + auto idx_ptr = idx.data(); + auto upd_ptr = upd.data(); + auto dst_ptr = out.data(); + auto idx_ax_stride = idx.strides(axis); + auto upd_ax_stride = upd.strides(axis); + auto dst_ax_stride = out.strides(axis); + auto idx_ax_size = idx.shape(axis); + auto dst_ax_size = out.shape(axis); + + size_t size_pre = 1; + size_t size_post = 1; + for (int i = 0; i < axis; ++i) { + size_pre *= idx.shape(i); + } + for (int i = axis + 1; i < idx.ndim(); ++i) { + size_post *= idx.shape(i); + } + size_t stride_pre = size_post 
* dst_ax_size; + for (size_t i = 0; i < size_pre; i++) { + for (size_t k = 0; k < size_post; k++) { + for (int j = 0; j < idx_ax_size; ++j) { + auto ind_val = offset_neg_idx( + idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); + op(upd_ptr[upd_it.loc + j * upd_ax_stride], + dst_ptr + k + ind_val * dst_ax_stride); + } + idx_it.step(); + upd_it.step(); + } + dst_ptr += stride_pre; + } +} + +template +void dispatch_scatter_axis_op( + array& out, + const array& idx, + const array& updates, + int axis, + ScatterAxis::ReduceType rtype) { + switch (rtype) { + case ScatterAxis::None: + scatter_axis( + out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }); + break; + case ScatterAxis::Sum: + scatter_axis( + out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }); + break; + } +} + +template +void dispatch_scatter_axis( + array& out, + const array& idx, + const array& updates, + int axis, + ScatterAxis::ReduceType rtype) { + switch (idx.dtype()) { + case uint8: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case uint16: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case uint32: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case uint64: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case int8: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case int16: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case int32: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + case int64: + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + break; + default: + throw std::runtime_error( + "[ScatterAxis::eval_cpu] Cannot scatter with indices type."); + } +} + +void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() >= 2); + + auto& src = inputs[0]; + auto& idx = inputs[1]; + auto& updates = inputs[2]; + + // Copy src into out (copy allocates memory for out) + copy(src, out, CopyType::General); + + switch (src.dtype()) { + case bool_: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint8: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint16: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint32: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint64: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int8: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int16: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int32: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int64: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case float16: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case float32: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case bfloat16: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case complex64: + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_); + break; + } +} + } // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index b932a5654e..492b74810e 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -35,6 +35,8 @@ make_jit_source(ternary_ops) make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h) 
make_jit_source(scatter kernels/indexing.h) make_jit_source(gather kernels/indexing.h) +make_jit_source(gather_axis) +make_jit_source(scatter_axis) make_jit_source(hadamard) if(MLX_METAL_JIT) diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index b27765a406..13696842c0 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -388,4 +389,217 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } +void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { + auto& src = inputs[0]; + auto& idx = inputs[1]; + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + if (out.size() == 0) { + return; + } + + auto& s = stream(); + auto& d = metal::device(s.device); + + size_t ndim = src.ndim(); + + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; + + std::string kernel_name = fmt::format( + "gather_axis{0}{1}_{2}", + type_to_name(out), + type_to_name(idx), + large ? "int64_t" : "int"); + std::string lib_name = kernel_name; + kernel_name += src.flags().row_contiguous ? "c" : "nc"; + kernel_name += idx.flags().row_contiguous ? "c" : "nc"; + + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source = metal::utils(); + kernel_source += metal::gather_axis(); + std::string out_type_str = get_type_string(out.dtype()); + std::string idx_type_str = get_type_string(idx.dtype()); + for (int i = 0; i < 4; ++i) { + bool sc = i & 1; + bool ic = i & 2; + kernel_source += get_template_definition( + lib_name + (sc ? "c" : "nc") + (ic ? "c" : "nc"), + "gather_axis", + out_type_str, + idx_type_str, + large ? "int64_t" : "int", + sc ? "true" : "false", + ic ? 
"true" : "false"); + } + return kernel_source; + }); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kernel_name, lib); + compute_encoder.set_compute_pipeline_state(kernel); + + // Grid [size post, index size, size pre] + size_t size_pre = 1; + size_t size_post = 1; + for (int i = 0; i < axis_; ++i) { + size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + size_post *= idx.shape(i); + } + + int idx_ax_size = idx.shape(axis_); + auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre); + MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre); + + // Set all the buffers + compute_encoder.set_input_array(src, 0); + compute_encoder.set_input_array(idx, 1); + compute_encoder.set_output_array(out, 2); + + // Set source info + auto shape = idx.shape(); + shape.erase(shape.begin() + axis_); + compute_encoder.set_vector_bytes(shape, 3); + + auto strides = src.strides(); + strides.erase(strides.begin() + axis_); + compute_encoder.set_vector_bytes(strides, 4); + + strides = idx.strides(); + strides.erase(strides.begin() + axis_); + compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_bytes(ndim - 1, 6); + compute_encoder.set_bytes(axis_, 7); + compute_encoder.set_bytes(src.shape(axis_), 8); + compute_encoder.set_bytes(src.strides(axis_), 9); + compute_encoder.set_bytes(idx.strides(axis_), 10); + + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + +void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { + auto& src = inputs[0]; + auto& idx = inputs[1]; + auto& upd = inputs[2]; + + // Copy src into out + CopyType copy_type; + if (src.data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (src.flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(src, out, copy_type); + + // Empty update + if (upd.size() == 0) { + return; + } + + auto& s = stream(); + auto& d = metal::device(s.device); + + size_t ndim = src.ndim(); + + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; + + std::string op_name; + switch (reduce_type_) { + case ScatterAxis::None: + op_name = "none"; + break; + case ScatterAxis::Sum: + op_name = "sum"; + break; + } + + std::string kernel_name = fmt::format( + "scatter_axis{0}{1}_{2}_{3}", + type_to_name(out), + type_to_name(idx), + op_name, + large ? "int64_t" : "int"); + std::string lib_name = kernel_name; + kernel_name += upd.flags().row_contiguous ? "c" : "nc"; + kernel_name += idx.flags().row_contiguous ? "c" : "nc"; + + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source = metal::utils(); + kernel_source += metal::reduce_utils(); + kernel_source += metal::scatter_axis(); + std::string out_type_str = get_type_string(out.dtype()); + std::string idx_type_str = get_type_string(idx.dtype()); + std::string op_type; + switch (reduce_type_) { + case ScatterAxis::None: + op_type = "None"; + break; + case ScatterAxis::Sum: + op_type = "Sum<" + out_type_str + ">"; + break; + } + + for (int i = 0; i < 4; ++i) { + bool uc = i & 1; + bool ic = i & 2; + kernel_source += get_template_definition( + lib_name + (uc ? "c" : "nc") + (ic ? "c" : "nc"), + "scatter_axis", + out_type_str, + idx_type_str, + large ? "int64_t" : "int", + op_type, + uc ? "true" : "false", + ic ? 
"true" : "false"); + } + return kernel_source; + }); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kernel_name, lib); + compute_encoder.set_compute_pipeline_state(kernel); + + // Grid [size post, index size, size pre] + size_t size_pre = 1; + size_t size_post = 1; + for (int i = 0; i < axis_; ++i) { + size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + size_post *= idx.shape(i); + } + + int idx_ax_size = idx.shape(axis_); + auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre); + MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre); + + // Set all the buffers + compute_encoder.set_input_array(upd, 0); + compute_encoder.set_input_array(idx, 1); + compute_encoder.set_output_array(out, 2); + + // Set source info + auto shape = idx.shape(); + shape.erase(shape.begin() + axis_); + compute_encoder.set_vector_bytes(shape, 3); + + auto strides = upd.strides(); + strides.erase(strides.begin() + axis_); + compute_encoder.set_vector_bytes(strides, 4); + + strides = idx.strides(); + strides.erase(strides.begin() + axis_); + compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_bytes(ndim - 1, 6); + compute_encoder.set_bytes(axis_, 7); + compute_encoder.set_bytes(out.shape(axis_), 8); + compute_encoder.set_bytes(upd.strides(axis_), 9); + compute_encoder.set_bytes(idx.strides(axis_), 10); + + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 89b025c665..b14aa567b0 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -18,10 +18,12 @@ const char* binary(); const char* binary_two(); const char* copy(); const char* fft(); +const char* gather_axis(); const char* hadamard(); const char* quantized(); const char* ternary(); const char* scan(); +const char* scatter_axis(); const char* softmax(); const char* sort(); const char* reduce(); diff --git a/mlx/backend/metal/kernels/gather_axis.h b/mlx/backend/metal/kernels/gather_axis.h new file mode 100644 index 0000000000..bf490ade06 --- /dev/null +++ b/mlx/backend/metal/kernels/gather_axis.h @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +template +[[kernel]] void gather_axis( + const device T* src [[buffer(0)]], + const device IdxT* indices [[buffer(1)]], + device T* out [[buffer(2)]], + const constant int* shape [[buffer(3)]], + const constant int64_t* src_strides [[buffer(4)]], + const constant int64_t* idx_strides [[buffer(5)]], + const constant size_t& ndim [[buffer(6)]], + const constant int& axis [[buffer(7)]], + const constant int& axis_size [[buffer(8)]], + const constant size_t& src_ax_stride [[buffer(9)]], + const constant size_t& idx_ax_stride [[buffer(10)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + LocT elem_idx = index.z * static_cast(grid_dim.x); + LocT out_idx = elem_idx * grid_dim.y + index.x; + + LocT idx_loc = index.y * static_cast(idx_ax_stride); + if (IdxC) { + idx_loc += out_idx; + } else { + idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); + } + + auto idx_val = indices[idx_loc]; + if (is_signed_v) { + idx_val = (idx_val < 0) ? 
idx_val + axis_size : idx_val; + } + + LocT src_idx = idx_val * static_cast(src_ax_stride); + if (SrcC) { + src_idx += elem_idx * axis_size + index.x; + } else { + src_idx += elem_to_loc(elem_idx + index.x, shape, src_strides, ndim); + } + + out_idx += index.y * static_cast(grid_dim.x); + out[out_idx] = src[src_idx]; +} diff --git a/mlx/backend/metal/kernels/scatter_axis.h b/mlx/backend/metal/kernels/scatter_axis.h new file mode 100644 index 0000000000..73fd7ab4a3 --- /dev/null +++ b/mlx/backend/metal/kernels/scatter_axis.h @@ -0,0 +1,52 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +template < + typename T, + typename IdxT, + typename LocT, + typename Op, + bool UpdC, + bool IdxC> +[[kernel]] void scatter_axis( + const device T* upd [[buffer(0)]], + const device IdxT* indices [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* shape [[buffer(3)]], + const constant int64_t* upd_strides [[buffer(4)]], + const constant int64_t* idx_strides [[buffer(5)]], + const constant size_t& ndim [[buffer(6)]], + const constant int& axis [[buffer(7)]], + const constant int& out_axis_size [[buffer(8)]], + const constant size_t& upd_ax_stride [[buffer(9)]], + const constant size_t& idx_ax_stride [[buffer(10)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + Op op; + + LocT elem_idx = index.z * static_cast(grid_dim.x); + + LocT idx_loc = index.y * static_cast(idx_ax_stride); + if (IdxC) { + idx_loc += elem_idx * grid_dim.y + index.x; + } else { + idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); + } + + auto idx_val = indices[idx_loc]; + if (is_signed_v) { + idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val; + } + + LocT upd_idx = index.y * static_cast(upd_ax_stride); + if (UpdC) { + upd_idx += elem_idx * grid_dim.y + index.x; + } else { + upd_idx += elem_to_loc(elem_idx + index.x, shape, upd_strides, ndim); + } + + LocT out_idx = elem_idx * static_cast(out_axis_size) + + idx_val * grid_dim.x + index.x; + op.atomic_update(out, upd[upd_idx], out_idx); +} diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index acc5d7560f..ce04cd600b 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -65,6 +65,7 @@ NO_CPU(Flatten) NO_CPU(Floor) NO_CPU(Full) NO_CPU(Gather) +NO_CPU(GatherAxis) NO_CPU(GatherMM) NO_CPU(GatherQMM) NO_CPU(Greater) @@ -98,6 +99,7 @@ NO_CPU(Reshape) NO_CPU(Round) NO_CPU(Scan) NO_CPU(Scatter) +NO_CPU(ScatterAxis) NO_CPU(Select) NO_CPU(Sigmoid) NO_CPU(Sign) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index a4999aa9f0..f6d65ebe62 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -65,6 +65,7 @@ NO_GPU(Flatten) NO_GPU(Floor) NO_GPU(Full) NO_GPU(Gather) +NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) NO_GPU(Greater) @@ -98,6 +99,7 @@ NO_GPU(Reshape) NO_GPU(Round) NO_GPU(Scan) NO_GPU(Scatter) +NO_GPU(ScatterAxis) NO_GPU(Select) NO_GPU(Sigmoid) NO_GPU(Sign) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ce7216feac..7691ffa256 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -68,7 +68,7 @@ array indices_or_default( Shape shape(x.shape().begin(), x.shape().end() - 2); int total = std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); - return reshape(arange(total, uint32, s), shape, s); + return reshape(arange(total, uint32, s), std::move(shape), s); } std::pair extract_quantized_matmul_dims( @@ -3080,28 +3080,20 @@ array take_along_axis( // 
Allow negative axis axis = axis < 0 ? a.ndim() + axis : axis; - std::vector nd_indices; - Shape index_shape(a.ndim(), 1); - for (int i = 0; i < a.ndim(); ++i) { - if (i == axis) { - nd_indices.push_back(indices); - } else { - // Reshape so they can be broadcast - index_shape[i] = a.shape(i); - nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s)); - index_shape[i] = 1; - } - } - std::vector dims(a.ndim()); - std::iota(dims.begin(), dims.end(), 0); - Shape slice_sizes(a.ndim(), 1); - auto out = gather(a, nd_indices, dims, slice_sizes, s); - - // Squeeze out the slice shape - for (auto& d : dims) { - d += a.ndim(); + // Broadcast indices to input shape ignoring the take axis + auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s); + if (inputs[0].shape() != a.shape()) { + std::ostringstream msg; + msg << "[take_along_axis] Indices of shape " << indices.shape() + << " do not broadcast to array of shape " << a.shape() << "." + << std::endl; + throw std::invalid_argument(msg.str()); } - return squeeze(out, dims, s); + return array( + inputs[1].shape(), + a.dtype(), + std::make_shared(to_stream(s), axis), + std::move(inputs)); } array put_along_axis( @@ -3127,28 +3119,24 @@ array put_along_axis( // Allow negative axis axis = axis < 0 ? a.ndim() + axis : axis; - std::vector nd_indices; - Shape index_shape(a.ndim(), 1); - for (int i = 0; i < a.ndim(); ++i) { - if (i == axis) { - nd_indices.push_back(indices); - } else { - // Reshape so they can be broadcast - index_shape[i] = a.shape(i); - nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s)); - index_shape[i] = 1; - } - } + auto inputs = broadcast_arrays({indices, values}, s); + inputs.insert(inputs.begin(), a); - auto update = astype(broadcast_to(values, indices.shape(), s), a.dtype(), s); - { - auto update_shape = update.shape(); - update_shape.resize(update_shape.size() + a.ndim(), 1); - update = reshape(update, std::move(update_shape), s); + // Broadcast indices, values to src shape ignoring the take axis + inputs = broadcast_arrays(inputs, {axis - int(a.ndim())}, s); + if (inputs[0].shape() != a.shape()) { + std::ostringstream msg; + msg << "[take_along_axis] Indices of shape " << indices.shape() + << " do not broadcast to array of shape " << a.shape() << "." + << std::endl; + throw std::invalid_argument(msg.str()); } - std::vector dims(a.ndim()); - std::iota(dims.begin(), dims.end(), 0); - return scatter(a, nd_indices, update, dims, s); + inputs[2] = astype(inputs[2], a.dtype(), s); + return array( + inputs[0].shape(), + a.dtype(), + std::make_shared(to_stream(s), ScatterAxis::None, axis), + std::move(inputs)); } /** Scatter updates to given indices */ @@ -3962,6 +3950,19 @@ array gather_qmm( std::tie(lhs_indices, rhs_indices) = broadcast_arrays(lhs_indices, rhs_indices, s); + if (!issubdtype(lhs_indices.dtype(), integer)) { + throw std::invalid_argument( + "[gather_qmm] Got lhs_indices with invalid dtype. Indices must be integral."); + } + + if (!issubdtype(rhs_indices.dtype(), integer)) { + throw std::invalid_argument( + "[gather_qmm] Got rhs_indices with invalid dtype. 
Indices must be integral."); + } + + lhs_indices = astype(lhs_indices, uint32, s); + rhs_indices = astype(rhs_indices, uint32, s); + // Compute the full output shape auto out_shape = lhs_indices.shape(); out_shape.push_back(x.shape(-2)); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3346f01c8c..29a88efe7b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2098,6 +2098,52 @@ bool Gather::is_equivalent(const Primitive& other) const { return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_; } +std::pair, std::vector> GatherAxis::vmap( + const std::vector& inputs, + const std::vector& axes) { + return {{inputs[0]}, axes}; +} + +std::vector GatherAxis::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + for (int argnum : argnums) { + if (argnum > 0) { + // Grads w.r.t. indices are zero + vjps.push_back( + zeros(primals[argnum].shape(), primals[argnum].dtype(), stream())); + } else { + auto src = zeros_like(primals[0], stream()); + vjps.push_back( + put_along_axis(src, primals[1], cotangents[0], axis_, stream())); + } + } + return vjps; +} + +std::vector GatherAxis::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + if (argnums.size() > 1 || argnums[0] != 0) { + throw std::invalid_argument( + "[gather_axis] Cannot calculate JVP with respect to indices."); + } + return {take_along_axis(tangents[0], primals[1], axis_, stream())}; +} + +std::vector GatherAxis::output_shapes(const std::vector& inputs) { + return {inputs[1].shape()}; +} + +bool GatherAxis::is_equivalent(const Primitive& other) const { + auto& g_other = static_cast(other); + return axis_ == g_other.axis_; +} + std::vector Gather::output_shapes(const std::vector& inputs) { Shape out_shape; if (inputs.size() > 1) { @@ -2106,7 +2152,6 @@ std::vector Gather::output_shapes(const std::vector& inputs) { out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end()); return {std::move(out_shape)}; } - std::pair, std::vector> Greater::vmap( const std::vector& inputs, const std::vector& axes) { @@ -3621,6 +3666,16 @@ std::pair, std::vector> Scatter::vmap( return {{out}, {src_ax}}; } +std::vector ScatterAxis::output_shapes( + const std::vector& inputs) { + return {inputs[0].shape()}; +} + +bool ScatterAxis::is_equivalent(const Primitive& other) const { + auto& s_other = static_cast(other); + return reduce_type_ == s_other.reduce_type_ && axis_ == s_other.axis_; +} + std::vector Sigmoid::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 8158c88d67..db5219e0a9 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1097,6 +1097,27 @@ class Gather : public UnaryPrimitive { Shape slice_sizes_; }; +class GatherAxis : public UnaryPrimitive { + public: + explicit GatherAxis(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(GatherAxis) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + class Greater : public UnaryPrimitive { public: explicit Greater(Stream stream) : UnaryPrimitive(stream) {} @@ -1788,6 +1809,41 @@ class Scatter : public UnaryPrimitive { 
std::vector axes_; }; +class ScatterAxis : public UnaryPrimitive { + public: + enum ReduceType { Sum, None }; + + explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + // DEFINE_VMAP() + // DEFINE_GRADS() + + void print(std::ostream& os) override { + os << "ScatterAxis"; + switch (reduce_type_) { + case Sum: + os << " Sum"; + break; + case None: + break; + } + } + + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair state() const { + return {reduce_type_, axis_}; + } + + private: + ReduceType reduce_type_; + int axis_; +}; + class Sigmoid : public UnaryPrimitive { public: explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 97b8afa472..ce69c29f1a 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1150,6 +1150,15 @@ def test_put_along_axis(self): out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax) self.assertTrue(np.array_equal(a_np, out_mlx)) + source = mx.zeros((1, 1, 8, 32)) + indices = mx.array([0, 2, 4, 5]).reshape((1, 1, 4, 1)) + update = mx.array(1.0) + + out_mlx = mx.put_along_axis(source, indices, update, axis=-2) + out_np = np.array(source) + np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2) + self.assertTrue(np.array_equal(out_np, np.array(out_mlx))) + def test_split(self): a = mx.array([1, 2, 3]) splits = mx.split(a, 3) From aadf66e3fda3a014c267e48ef43e68882e8b15e2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 31 Jan 2025 08:16:02 -0800 Subject: [PATCH 2/3] add transforms --- mlx/ops.cpp | 79 ++++++++++++-------- mlx/ops.h | 8 ++ mlx/primitives.cpp | 133 +++++++++++++++++++++++++++++++++- mlx/primitives.h | 4 +- python/tests/test_autograd.py | 31 ++++++++ python/tests/test_vmap.py | 47 ++++++++++++ 6 files changed, 267 insertions(+), 35 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7691ffa256..51cfd87834 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3080,73 +3080,92 @@ array take_along_axis( // Allow negative axis axis = axis < 0 ? a.ndim() + axis : axis; - // Broadcast indices to input shape ignoring the take axis + // Broadcast indices and input ignoring the take axis auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s); - if (inputs[0].shape() != a.shape()) { - std::ostringstream msg; - msg << "[take_along_axis] Indices of shape " << indices.shape() - << " do not broadcast to array of shape " << a.shape() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } + + auto out_shape = inputs[1].shape(); return array( - inputs[1].shape(), + std::move(out_shape), a.dtype(), std::make_shared(to_stream(s), axis), std::move(inputs)); } -array put_along_axis( +array scatter_axis( const array& a, const array& indices, const array& values, int axis, - StreamOrDevice s /* = {} */) { + ScatterAxis::ReduceType mode, + StreamOrDevice s) { + std::string prefix = + (mode == ScatterAxis::None) ? 
"[put_along_axis]" : "[scatter_add_axis]"; if (axis + a.ndim() < 0 || axis >= static_cast(a.ndim())) { std::ostringstream msg; - msg << "[put_along_axis] Received invalid axis " << " for array with " - << a.ndim() << " dimensions."; + msg << prefix << " Received invalid axis " << " for array with " << a.ndim() + << " dimensions."; throw std::invalid_argument(msg.str()); } if (indices.ndim() != a.ndim()) { std::ostringstream msg; - msg << "[put_along_axis] Indices of dimension " << indices.ndim() + msg << prefix << " Indices of dimension " << indices.ndim() << " does not match array of dimension " << a.ndim() << "."; throw std::invalid_argument(msg.str()); } - // Allow negative axis - axis = axis < 0 ? a.ndim() + axis : axis; + auto upd = astype(values, a.dtype(), s); + + // Squeeze leading singletons out of update + if (upd.ndim() > indices.ndim()) { + std::vector sq_ax(upd.ndim() - indices.ndim()); + std::iota(sq_ax.begin(), sq_ax.end(), 0); + upd = squeeze(upd, sq_ax, s); + } - auto inputs = broadcast_arrays({indices, values}, s); + auto inputs = broadcast_arrays({indices, upd}, s); inputs.insert(inputs.begin(), a); - // Broadcast indices, values to src shape ignoring the take axis + // Allow negative axis + axis = axis < 0 ? a.ndim() + axis : axis; + + // Broadcast src, indices, values while ignoring the take axis inputs = broadcast_arrays(inputs, {axis - int(a.ndim())}, s); - if (inputs[0].shape() != a.shape()) { - std::ostringstream msg; - msg << "[take_along_axis] Indices of shape " << indices.shape() - << " do not broadcast to array of shape " << a.shape() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } - inputs[2] = astype(inputs[2], a.dtype(), s); + + auto out_shape = inputs[0].shape(); return array( - inputs[0].shape(), + std::move(out_shape), a.dtype(), - std::make_shared(to_stream(s), ScatterAxis::None, axis), + std::make_shared(to_stream(s), mode, axis), std::move(inputs)); } +array put_along_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s /* = {} */) { + return scatter_axis(a, indices, values, axis, ScatterAxis::None, s); +} + +array scatter_add_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s /* = {} */) { + return scatter_axis(a, indices, values, axis, ScatterAxis::Sum, s); +} + /** Scatter updates to given indices */ array scatter( const array& a, const std::vector& indices, const array& updates, const std::vector& axes, - Scatter::ReduceType mode /*= Scatter::ReduceType::None*/, - StreamOrDevice s /*= {}*/) { + Scatter::ReduceType mode, + StreamOrDevice s) { // Checks that indices, dimensions, and slice_sizes are all valid if (indices.size() > a.ndim()) { std::ostringstream msg; diff --git a/mlx/ops.h b/mlx/ops.h index 141cfde709..c0cbc2780b 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -968,6 +968,14 @@ array put_along_axis( int axis, StreamOrDevice s = {}); +/** Add the values into the array at the given indices along the axis */ +array scatter_add_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s = {}); + /** Scatter updates to the given indices. 
* * The parameters ``indices`` and ``axes`` determine the locations of ``a`` diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 29a88efe7b..4cde848318 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2101,7 +2101,29 @@ bool Gather::is_equivalent(const Primitive& other) const { std::pair, std::vector> GatherAxis::vmap( const std::vector& inputs, const std::vector& axes) { - return {{inputs[0]}, axes}; + bool vmap_in = axes[0] >= 0; + bool vmap_idx = axes[1] >= 0; + + auto in = inputs[0]; + auto idx = inputs[1]; + int out_ax; + if (vmap_in && vmap_idx) { + // reorder the vmap axes to the same location + idx = moveaxis(idx, axes[1], axes[0], stream()); + out_ax = axes[0]; + } else if (vmap_in) { + // expand just the indices dimension + idx = expand_dims(idx, axes[0], stream()); + out_ax = axes[0]; + } else if (vmap_idx) { + // expand just the input dimension + in = expand_dims(in, axes[1], stream()); + out_ax = axes[1]; + } else { + out_ax = -1; + } + int axis = (out_ax >= 0 && axis_ >= out_ax) ? axis_ + 1 : axis_; + return {{take_along_axis(in, idx, axis, stream())}, {out_ax}}; } std::vector GatherAxis::vjp( @@ -2117,8 +2139,11 @@ std::vector GatherAxis::vjp( zeros(primals[argnum].shape(), primals[argnum].dtype(), stream())); } else { auto src = zeros_like(primals[0], stream()); - vjps.push_back( - put_along_axis(src, primals[1], cotangents[0], axis_, stream())); + vjps.push_back(array( + src.shape(), + src.dtype(), + std::make_shared(stream(), ScatterAxis::Sum, axis_), + {src, primals[1], cotangents[0]})); } } return vjps; @@ -2152,6 +2177,7 @@ std::vector Gather::output_shapes(const std::vector& inputs) { out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end()); return {std::move(out_shape)}; } + std::pair, std::vector> Greater::vmap( const std::vector& inputs, const std::vector& axes) { @@ -3666,6 +3692,107 @@ std::pair, std::vector> Scatter::vmap( return {{out}, {src_ax}}; } +std::vector ScatterAxis::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + const auto& indices = primals[1]; + const auto& updates = primals[2]; + + std::vector vjps; + for (auto num : argnums) { + // Gradient wrt to the input array + if (num == 0) { + if (reduce_type_ == ScatterAxis::None) { + // Scatter 0s to the locations that were updated with the updates + vjps.push_back(put_along_axis( + cotangents[0], + indices, + zeros_like(updates, stream()), + axis_, + stream())); + } else { + // The input array values are kept so they all get gradients + vjps.push_back(cotangents[0]); + } + } else if (num == 2) { + vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream())); + } else { + throw std::invalid_argument( + "[scatter_axis] Cannot calculate VJP with respect to indices."); + } + } + return vjps; +} + +std::vector ScatterAxis::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + for (auto arg : argnums) { + if (arg == 1) { + throw std::invalid_argument( + "[scatter_axis] Cannot calculate JVP with respect to indices."); + } + } + if (argnums.size() == 2) { + return {array( + primals[0].shape(), + primals[0].dtype(), + std::make_shared(stream(), reduce_type_, axis_), + {tangents[0], primals[1], tangents[1]})}; + } else { + auto tan_a = + argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream()); + auto tan_b = + argnums[0] == 2 ? 
tangents[0] : zeros_like(primals[2], stream()); + return {array( + primals[0].shape(), + primals[0].dtype(), + std::make_shared(stream(), reduce_type_, axis_), + {tan_a, primals[1], tan_b})}; + } +} + +std::pair, std::vector> ScatterAxis::vmap( + const std::vector& inputs, + const std::vector& axes) { + // Find the first vmap axis + int out_ax = -1; + for (auto ax : axes) { + if (ax >= 0) { + out_ax = ax; + break; + } + } + + if (out_ax < 0) { + return { + {array( + inputs[0].shape(), + inputs[0].dtype(), + std::make_shared(stream(), reduce_type_, axis_), + inputs)}, + {-1}}; + } + + auto v_in = inputs; + for (int i = 0; i < axes.size(); ++i) { + if (axes[i] >= 0) { + // if out_ax >= 0 move axis o/w set out_ax + if (out_ax != axes[i]) { + v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream()); + } + } else { + v_in[i] = expand_dims(v_in[i], out_ax, stream()); + } + } + int axis = axis_ >= out_ax ? axis_ + 1 : axis_; + auto fn = reduce_type_ == Sum ? scatter_add_axis : put_along_axis; + return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}}; +} + std::vector ScatterAxis::output_shapes( const std::vector& inputs) { return {inputs[0].shape()}; diff --git a/mlx/primitives.h b/mlx/primitives.h index db5219e0a9..782ed7e275 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1819,8 +1819,8 @@ class ScatterAxis : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - // DEFINE_VMAP() - // DEFINE_GRADS() + DEFINE_VMAP() + DEFINE_GRADS() void print(std::ostream& os) override { os << "ScatterAxis"; diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 4ab7fb922a..b3281bd3e4 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -669,6 +669,37 @@ def test_matmul_jvps(self): _, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,)) self.assertTrue(mx.allclose(tangent, expected)) + def test_put_along_axis_grads(self): + a = mx.zeros((5, 1)) + b = mx.ones((2, 1)) + + def fun(a, b): + idx = mx.array([[0], [3]]) + return mx.put_along_axis(a, idx, b, axis=0) + + # Test VJP + cotan = mx.full((5, 1), 2.0) + _, (da, db) = mx.vjp(fun, (a, b), (cotan,)) + expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None] + expected_db = mx.array([2.0, 2.0])[:, None] + self.assertTrue(mx.allclose(expected_da, da)) + self.assertTrue(mx.allclose(expected_db, db)) + + # Test JVP + tan_a = mx.full((5, 1), 2.0) + tan_b = mx.full((2, 1), 3.0) + _, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b)) + expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None] + self.assertTrue(mx.allclose(expected, jout)) + + def fun(a): + idx = mx.array([[0], [3]]) + return mx.put_along_axis(a, idx, b, axis=0) + + _, (jout,) = mx.jvp(fun, (a,), (tan_a,)) + expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None] + self.assertTrue(mx.allclose(expected, jout)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 0789593c5d..b98bdb0fcb 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -549,6 +549,53 @@ def cat_constant(x): target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2) self.assertTrue(mx.array_equal(out, target)) + def test_vmap_take_along_axis(self): + a = mx.zeros((4, 5, 1)) + idx = mx.zeros((2, 4, 1), mx.int32) + + def fun(a, idx): + return mx.take_along_axis(a, idx, axis=0) + + out = mx.vmap(fun, in_axes=(0, 1))(a, idx) + self.assertEqual(out.shape, (4, 2, 1)) + + idx = mx.zeros((2, 
1), mx.int32) + + out = mx.vmap(fun, in_axes=(0, None))(a, idx) + self.assertEqual(out.shape, (4, 2, 1)) + + a = mx.zeros((5, 1)) + idx = mx.zeros((4, 2, 1), mx.int32) + + out = mx.vmap(fun, in_axes=(None, 0))(a, idx) + self.assertEqual(out.shape, (4, 2, 1)) + + def test_vmap_put_along_axis(self): + a = mx.zeros((4, 5, 1)) + idx = mx.ones((2, 4, 1), mx.int32) + upd = mx.ones((2, 4, 1)) + + def fun(a, idx, upd): + return mx.put_along_axis(a, idx, upd, axis=0) + + out = mx.vmap(fun, in_axes=(0, 1, 1))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + + upd = mx.ones((2, 1)) + out = mx.vmap(fun, in_axes=(0, 1, None))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + + idx = mx.ones((2, 1), mx.int32) + upd = mx.ones((2, 1)) + out = mx.vmap(fun, in_axes=(0, None, None))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + + a = mx.zeros((5, 1)) + idx = mx.ones((2, 4, 1), mx.int32) + upd = mx.ones((2, 4, 1)) + out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + if __name__ == "__main__": unittest.main() From 199baf0f0fb5f9da062cefc3b74bc1aa5a2412ea Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 31 Jan 2025 14:35:13 -0800 Subject: [PATCH 3/3] comment --- mlx/backend/common/indexing.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp index 29828447e4..6798f0245e 100644 --- a/mlx/backend/common/indexing.cpp +++ b/mlx/backend/common/indexing.cpp @@ -461,7 +461,9 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { auto& updates = inputs.back(); // Copy src into out (copy allocates memory for out) - copy(src, out, CopyType::General); + auto ctype = + src.flags().row_contiguous ? CopyType::Vector : CopyType::General; + copy(src, out, ctype); switch (src.dtype()) { case bool_: @@ -621,7 +623,9 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { auto& updates = inputs[2]; // Copy src into out (copy allocates memory for out) - copy(src, out, CopyType::General); + auto ctype = + src.flags().row_contiguous ? CopyType::Vector : CopyType::General; + copy(src, out, ctype); switch (src.dtype()) { case bool_:
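
The primitives added above (GatherAxis / ScatterAxis) back the public `mx.take_along_axis` and `mx.put_along_axis` ops that the test diffs exercise. The sketch below is illustrative only and is not part of the patches: it assumes standard `mlx.core` and `numpy` installs, uses arbitrary example shapes, and checks the axis-wise semantics against NumPy.

# Minimal usage sketch (assumptions: stock mlx.core / numpy, example
# shapes chosen here for illustration; not taken from the patches above).
import mlx.core as mx
import numpy as np

a = mx.arange(12, dtype=mx.float32).reshape(3, 4)
idx = mx.array([[0, 2, 1, 3], [2, 0, 3, 1], [1, 1, 0, 2]])

# GatherAxis: pick one element per index along the axis, matching
# np.take_along_axis.
gathered = mx.take_along_axis(a, idx, axis=1)
assert np.array_equal(
    np.array(gathered), np.take_along_axis(np.array(a), np.array(idx), axis=1)
)

# ScatterAxis with the None reduction: write values at the indexed
# positions, matching np.put_along_axis; values are cast to a's dtype.
vals = mx.full(idx.shape, 7.0)
updated = mx.put_along_axis(a, idx, vals, axis=1)
ref = np.array(a)
np.put_along_axis(ref, np.array(idx), 7.0, axis=1)
assert np.array_equal(np.array(updated), ref)

# The VJP of take_along_axis scatter-adds the cotangent back into the
# source (ScatterAxis with the Sum reduction), so repeated indices
# accumulate rather than overwrite.
cotan = mx.ones_like(gathered)
_, (da,) = mx.vjp(lambda x: mx.take_along_axis(x, idx, axis=1), (a,), (cotan,))
assert da.shape == a.shape

This mirrors the gradient and vmap tests added in the second patch: the forward ops lower to GatherAxis/ScatterAxis, and the gradient of take_along_axis uses the Sum reduction so that duplicate indices contribute additively.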