From aadf66e3fda3a014c267e48ef43e68882e8b15e2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 31 Jan 2025 08:16:02 -0800 Subject: [PATCH] add transforms --- mlx/ops.cpp | 79 ++++++++++++-------- mlx/ops.h | 8 ++ mlx/primitives.cpp | 133 +++++++++++++++++++++++++++++++++- mlx/primitives.h | 4 +- python/tests/test_autograd.py | 31 ++++++++ python/tests/test_vmap.py | 47 ++++++++++++ 6 files changed, 267 insertions(+), 35 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7691ffa256..51cfd87834 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3080,73 +3080,92 @@ array take_along_axis( // Allow negative axis axis = axis < 0 ? a.ndim() + axis : axis; - // Broadcast indices to input shape ignoring the take axis + // Broadcast indices and input ignoring the take axis auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s); - if (inputs[0].shape() != a.shape()) { - std::ostringstream msg; - msg << "[take_along_axis] Indices of shape " << indices.shape() - << " do not broadcast to array of shape " << a.shape() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } + + auto out_shape = inputs[1].shape(); return array( - inputs[1].shape(), + std::move(out_shape), a.dtype(), std::make_shared(to_stream(s), axis), std::move(inputs)); } -array put_along_axis( +array scatter_axis( const array& a, const array& indices, const array& values, int axis, - StreamOrDevice s /* = {} */) { + ScatterAxis::ReduceType mode, + StreamOrDevice s) { + std::string prefix = + (mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]"; if (axis + a.ndim() < 0 || axis >= static_cast(a.ndim())) { std::ostringstream msg; - msg << "[put_along_axis] Received invalid axis " << " for array with " - << a.ndim() << " dimensions."; + msg << prefix << " Received invalid axis " << " for array with " << a.ndim() + << " dimensions."; throw std::invalid_argument(msg.str()); } if (indices.ndim() != a.ndim()) { std::ostringstream msg; - msg << "[put_along_axis] Indices of dimension " << indices.ndim() + msg << prefix << " Indices of dimension " << indices.ndim() << " does not match array of dimension " << a.ndim() << "."; throw std::invalid_argument(msg.str()); } - // Allow negative axis - axis = axis < 0 ? a.ndim() + axis : axis; + auto upd = astype(values, a.dtype(), s); + + // Squeeze leading singletons out of update + if (upd.ndim() > indices.ndim()) { + std::vector sq_ax(upd.ndim() - indices.ndim()); + std::iota(sq_ax.begin(), sq_ax.end(), 0); + upd = squeeze(upd, sq_ax, s); + } - auto inputs = broadcast_arrays({indices, values}, s); + auto inputs = broadcast_arrays({indices, upd}, s); inputs.insert(inputs.begin(), a); - // Broadcast indices, values to src shape ignoring the take axis + // Allow negative axis + axis = axis < 0 ? a.ndim() + axis : axis; + + // Broadcast src, indices, values while ignoring the take axis inputs = broadcast_arrays(inputs, {axis - int(a.ndim())}, s); - if (inputs[0].shape() != a.shape()) { - std::ostringstream msg; - msg << "[take_along_axis] Indices of shape " << indices.shape() - << " do not broadcast to array of shape " << a.shape() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } - inputs[2] = astype(inputs[2], a.dtype(), s); + + auto out_shape = inputs[0].shape(); return array( - inputs[0].shape(), + std::move(out_shape), a.dtype(), - std::make_shared(to_stream(s), ScatterAxis::None, axis), + std::make_shared(to_stream(s), mode, axis), std::move(inputs)); } +array put_along_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s /* = {} */) { + return scatter_axis(a, indices, values, axis, ScatterAxis::None, s); +} + +array scatter_add_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s /* = {} */) { + return scatter_axis(a, indices, values, axis, ScatterAxis::Sum, s); +} + /** Scatter updates to given indices */ array scatter( const array& a, const std::vector& indices, const array& updates, const std::vector& axes, - Scatter::ReduceType mode /*= Scatter::ReduceType::None*/, - StreamOrDevice s /*= {}*/) { + Scatter::ReduceType mode, + StreamOrDevice s) { // Checks that indices, dimensions, and slice_sizes are all valid if (indices.size() > a.ndim()) { std::ostringstream msg; diff --git a/mlx/ops.h b/mlx/ops.h index 141cfde709..c0cbc2780b 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -968,6 +968,14 @@ array put_along_axis( int axis, StreamOrDevice s = {}); +/** Add the values into the array at the given indices along the axis */ +array scatter_add_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s = {}); + /** Scatter updates to the given indices. * * The parameters ``indices`` and ``axes`` determine the locations of ``a`` diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 29a88efe7b..4cde848318 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2101,7 +2101,29 @@ bool Gather::is_equivalent(const Primitive& other) const { std::pair, std::vector> GatherAxis::vmap( const std::vector& inputs, const std::vector& axes) { - return {{inputs[0]}, axes}; + bool vmap_in = axes[0] >= 0; + bool vmap_idx = axes[1] >= 0; + + auto in = inputs[0]; + auto idx = inputs[1]; + int out_ax; + if (vmap_in && vmap_idx) { + // reorder the vmap axes to the same location + idx = moveaxis(idx, axes[1], axes[0], stream()); + out_ax = axes[0]; + } else if (vmap_in) { + // expand just the indices dimension + idx = expand_dims(idx, axes[0], stream()); + out_ax = axes[0]; + } else if (vmap_idx) { + // expand just the input dimension + in = expand_dims(in, axes[1], stream()); + out_ax = axes[1]; + } else { + out_ax = -1; + } + int axis = (out_ax >= 0 && axis_ >= out_ax) ? axis_ + 1 : axis_; + return {{take_along_axis(in, idx, axis, stream())}, {out_ax}}; } std::vector GatherAxis::vjp( @@ -2117,8 +2139,11 @@ std::vector GatherAxis::vjp( zeros(primals[argnum].shape(), primals[argnum].dtype(), stream())); } else { auto src = zeros_like(primals[0], stream()); - vjps.push_back( - put_along_axis(src, primals[1], cotangents[0], axis_, stream())); + vjps.push_back(array( + src.shape(), + src.dtype(), + std::make_shared(stream(), ScatterAxis::Sum, axis_), + {src, primals[1], cotangents[0]})); } } return vjps; @@ -2152,6 +2177,7 @@ std::vector Gather::output_shapes(const std::vector& inputs) { out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end()); return {std::move(out_shape)}; } + std::pair, std::vector> Greater::vmap( const std::vector& inputs, const std::vector& axes) { @@ -3666,6 +3692,107 @@ std::pair, std::vector> Scatter::vmap( return {{out}, {src_ax}}; } +std::vector ScatterAxis::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + const auto& indices = primals[1]; + const auto& updates = primals[2]; + + std::vector vjps; + for (auto num : argnums) { + // Gradient wrt to the input array + if (num == 0) { + if (reduce_type_ == ScatterAxis::None) { + // Scatter 0s to the locations that were updated with the updates + vjps.push_back(put_along_axis( + cotangents[0], + indices, + zeros_like(updates, stream()), + axis_, + stream())); + } else { + // The input array values are kept so they all get gradients + vjps.push_back(cotangents[0]); + } + } else if (num == 2) { + vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream())); + } else { + throw std::invalid_argument( + "[scatter_axis] Cannot calculate VJP with respect to indices."); + } + } + return vjps; +} + +std::vector ScatterAxis::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + for (auto arg : argnums) { + if (arg == 1) { + throw std::invalid_argument( + "[scatter_axis] Cannot calculate JVP with respect to indices."); + } + } + if (argnums.size() == 2) { + return {array( + primals[0].shape(), + primals[0].dtype(), + std::make_shared(stream(), reduce_type_, axis_), + {tangents[0], primals[1], tangents[1]})}; + } else { + auto tan_a = + argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream()); + auto tan_b = + argnums[0] == 2 ? tangents[0] : zeros_like(primals[2], stream()); + return {array( + primals[0].shape(), + primals[0].dtype(), + std::make_shared(stream(), reduce_type_, axis_), + {tan_a, primals[1], tan_b})}; + } +} + +std::pair, std::vector> ScatterAxis::vmap( + const std::vector& inputs, + const std::vector& axes) { + // Find the first vmap axis + int out_ax = -1; + for (auto ax : axes) { + if (ax >= 0) { + out_ax = ax; + break; + } + } + + if (out_ax < 0) { + return { + {array( + inputs[0].shape(), + inputs[0].dtype(), + std::make_shared(stream(), reduce_type_, axis_), + inputs)}, + {-1}}; + } + + auto v_in = inputs; + for (int i = 0; i < axes.size(); ++i) { + if (axes[i] >= 0) { + // if out_ax >= 0 move axis o/w set out_ax + if (out_ax != axes[i]) { + v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream()); + } + } else { + v_in[i] = expand_dims(v_in[i], out_ax, stream()); + } + } + int axis = axis_ >= out_ax ? axis_ + 1 : axis_; + auto fn = reduce_type_ == Sum ? scatter_add_axis : put_along_axis; + return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}}; +} + std::vector ScatterAxis::output_shapes( const std::vector& inputs) { return {inputs[0].shape()}; diff --git a/mlx/primitives.h b/mlx/primitives.h index db5219e0a9..782ed7e275 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1819,8 +1819,8 @@ class ScatterAxis : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - // DEFINE_VMAP() - // DEFINE_GRADS() + DEFINE_VMAP() + DEFINE_GRADS() void print(std::ostream& os) override { os << "ScatterAxis"; diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 4ab7fb922a..b3281bd3e4 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -669,6 +669,37 @@ def test_matmul_jvps(self): _, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,)) self.assertTrue(mx.allclose(tangent, expected)) + def test_put_along_axis_grads(self): + a = mx.zeros((5, 1)) + b = mx.ones((2, 1)) + + def fun(a, b): + idx = mx.array([[0], [3]]) + return mx.put_along_axis(a, idx, b, axis=0) + + # Test VJP + cotan = mx.full((5, 1), 2.0) + _, (da, db) = mx.vjp(fun, (a, b), (cotan,)) + expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None] + expected_db = mx.array([2.0, 2.0])[:, None] + self.assertTrue(mx.allclose(expected_da, da)) + self.assertTrue(mx.allclose(expected_db, db)) + + # Test JVP + tan_a = mx.full((5, 1), 2.0) + tan_b = mx.full((2, 1), 3.0) + _, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b)) + expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None] + self.assertTrue(mx.allclose(expected, jout)) + + def fun(a): + idx = mx.array([[0], [3]]) + return mx.put_along_axis(a, idx, b, axis=0) + + _, (jout,) = mx.jvp(fun, (a,), (tan_a,)) + expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None] + self.assertTrue(mx.allclose(expected, jout)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 0789593c5d..b98bdb0fcb 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -549,6 +549,53 @@ def cat_constant(x): target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2) self.assertTrue(mx.array_equal(out, target)) + def test_vmap_take_along_axis(self): + a = mx.zeros((4, 5, 1)) + idx = mx.zeros((2, 4, 1), mx.int32) + + def fun(a, idx): + return mx.take_along_axis(a, idx, axis=0) + + out = mx.vmap(fun, in_axes=(0, 1))(a, idx) + self.assertEqual(out.shape, (4, 2, 1)) + + idx = mx.zeros((2, 1), mx.int32) + + out = mx.vmap(fun, in_axes=(0, None))(a, idx) + self.assertEqual(out.shape, (4, 2, 1)) + + a = mx.zeros((5, 1)) + idx = mx.zeros((4, 2, 1), mx.int32) + + out = mx.vmap(fun, in_axes=(None, 0))(a, idx) + self.assertEqual(out.shape, (4, 2, 1)) + + def test_vmap_put_along_axis(self): + a = mx.zeros((4, 5, 1)) + idx = mx.ones((2, 4, 1), mx.int32) + upd = mx.ones((2, 4, 1)) + + def fun(a, idx, upd): + return mx.put_along_axis(a, idx, upd, axis=0) + + out = mx.vmap(fun, in_axes=(0, 1, 1))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + + upd = mx.ones((2, 1)) + out = mx.vmap(fun, in_axes=(0, 1, None))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + + idx = mx.ones((2, 1), mx.int32) + upd = mx.ones((2, 1)) + out = mx.vmap(fun, in_axes=(0, None, None))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + + a = mx.zeros((5, 1)) + idx = mx.ones((2, 4, 1), mx.int32) + upd = mx.ones((2, 4, 1)) + out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd) + self.assertEqual(out.shape, (4, 5, 1)) + if __name__ == "__main__": unittest.main()