Commit aadf66e: add transforms
awni committed Jan 31, 2025
1 parent 6fb1fff

Showing 6 changed files with 267 additions and 35 deletions.
79 changes: 49 additions & 30 deletions mlx/ops.cpp
@@ -3080,73 +3080,92 @@ array take_along_axis(
   // Allow negative axis
   axis = axis < 0 ? a.ndim() + axis : axis;
 
-  // Broadcast indices to input shape ignoring the take axis
+  // Broadcast indices and input ignoring the take axis
   auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s);
   if (inputs[0].shape() != a.shape()) {
     std::ostringstream msg;
     msg << "[take_along_axis] Indices of shape " << indices.shape()
         << " do not broadcast to array of shape " << a.shape() << "."
         << std::endl;
     throw std::invalid_argument(msg.str());
   }
 
+  auto out_shape = inputs[1].shape();
   return array(
-      inputs[1].shape(),
+      std::move(out_shape),
       a.dtype(),
       std::make_shared<GatherAxis>(to_stream(s), axis),
       std::move(inputs));
 }
 
-array put_along_axis(
+array scatter_axis(
     const array& a,
     const array& indices,
     const array& values,
     int axis,
-    StreamOrDevice s /* = {} */) {
+    ScatterAxis::ReduceType mode,
+    StreamOrDevice s) {
+  std::string prefix =
+      (mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]";
   if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
     std::ostringstream msg;
-    msg << "[put_along_axis] Received invalid axis " << " for array with "
-        << a.ndim() << " dimensions.";
+    msg << prefix << " Received invalid axis " << " for array with " << a.ndim()
+        << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
 
   if (indices.ndim() != a.ndim()) {
     std::ostringstream msg;
-    msg << "[put_along_axis] Indices of dimension " << indices.ndim()
+    msg << prefix << " Indices of dimension " << indices.ndim()
         << " does not match array of dimension " << a.ndim() << ".";
     throw std::invalid_argument(msg.str());
   }
 
-  // Allow negative axis
-  axis = axis < 0 ? a.ndim() + axis : axis;
+  auto upd = astype(values, a.dtype(), s);
+
+  // Squeeze leading singletons out of update
+  if (upd.ndim() > indices.ndim()) {
+    std::vector<int> sq_ax(upd.ndim() - indices.ndim());
+    std::iota(sq_ax.begin(), sq_ax.end(), 0);
+    upd = squeeze(upd, sq_ax, s);
+  }
 
-  auto inputs = broadcast_arrays({indices, values}, s);
+  auto inputs = broadcast_arrays({indices, upd}, s);
   inputs.insert(inputs.begin(), a);
 
-  // Broadcast indices, values to src shape ignoring the take axis
+  // Allow negative axis
+  axis = axis < 0 ? a.ndim() + axis : axis;
+
+  // Broadcast src, indices, values while ignoring the take axis
   inputs = broadcast_arrays(inputs, {axis - int(a.ndim())}, s);
   if (inputs[0].shape() != a.shape()) {
     std::ostringstream msg;
     msg << "[take_along_axis] Indices of shape " << indices.shape()
         << " do not broadcast to array of shape " << a.shape() << "."
         << std::endl;
     throw std::invalid_argument(msg.str());
   }
-  inputs[2] = astype(inputs[2], a.dtype(), s);
 
+  auto out_shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      std::move(out_shape),
       a.dtype(),
-      std::make_shared<ScatterAxis>(to_stream(s), ScatterAxis::None, axis),
+      std::make_shared<ScatterAxis>(to_stream(s), mode, axis),
       std::move(inputs));
 }
 
+array put_along_axis(
+    const array& a,
+    const array& indices,
+    const array& values,
+    int axis,
+    StreamOrDevice s /* = {} */) {
+  return scatter_axis(a, indices, values, axis, ScatterAxis::None, s);
+}
+
+array scatter_add_axis(
+    const array& a,
+    const array& indices,
+    const array& values,
+    int axis,
+    StreamOrDevice s /* = {} */) {
+  return scatter_axis(a, indices, values, axis, ScatterAxis::Sum, s);
+}
+
 /** Scatter updates to given indices */
 array scatter(
     const array& a,
     const std::vector<array>& indices,
     const array& updates,
     const std::vector<int>& axes,
-    Scatter::ReduceType mode /*= Scatter::ReduceType::None*/,
-    StreamOrDevice s /*= {}*/) {
+    Scatter::ReduceType mode,
+    StreamOrDevice s) {
   // Checks that indices, dimensions, and slice_sizes are all valid
   if (indices.size() > a.ndim()) {
     std::ostringstream msg;
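
Note: the refactor above routes put_along_axis and the new scatter_add_axis through one scatter_axis helper, and the only behavioral difference is the reduce mode (overwrite vs. accumulate). As a semantic reference only, here is a minimal NumPy sketch of the two modes; the *_ref names are illustrative stand-ins, not MLX API, and the MLX kernels are not implemented this way.

import numpy as np

def put_along_axis_ref(a, indices, values, axis):
    # ScatterAxis::None: overwrite entries of `a` at `indices`.
    out = a.copy()
    np.put_along_axis(out, indices, values, axis=axis)
    return out

def scatter_add_axis_ref(a, indices, values, axis):
    # ScatterAxis::Sum: accumulate, so duplicate indices add up.
    out = a.copy()
    grids = list(np.indices(indices.shape))
    grids[axis] = indices
    np.add.at(out, tuple(grids), values)
    return out

a = np.zeros((5, 1))
idx = np.array([[0], [3], [3]])
vals = np.ones((3, 1))
print(put_along_axis_ref(a, idx, vals, axis=0).ravel())    # [1. 0. 0. 1. 0.]
print(scatter_add_axis_ref(a, idx, vals, axis=0).ravel())  # [1. 0. 0. 2. 0.]
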
8 changes: 8 additions & 0 deletions mlx/ops.h
@@ -968,6 +968,14 @@ array put_along_axis(
     int axis,
     StreamOrDevice s = {});
 
+/** Add the values into the array at the given indices along the axis */
+array scatter_add_axis(
+    const array& a,
+    const array& indices,
+    const array& values,
+    int axis,
+    StreamOrDevice s = {});
+
 /** Scatter updates to the given indices.
  *
  * The parameters ``indices`` and ``axes`` determine the locations of ``a``
133 changes: 130 additions & 3 deletions mlx/primitives.cpp
@@ -2101,7 +2101,29 @@ bool Gather::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> GatherAxis::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {{inputs[0]}, axes};
+  bool vmap_in = axes[0] >= 0;
+  bool vmap_idx = axes[1] >= 0;
+
+  auto in = inputs[0];
+  auto idx = inputs[1];
+  int out_ax;
+  if (vmap_in && vmap_idx) {
+    // reorder the vmap axes to the same location
+    idx = moveaxis(idx, axes[1], axes[0], stream());
+    out_ax = axes[0];
+  } else if (vmap_in) {
+    // expand just the indices dimension
+    idx = expand_dims(idx, axes[0], stream());
+    out_ax = axes[0];
+  } else if (vmap_idx) {
+    // expand just the input dimension
+    in = expand_dims(in, axes[1], stream());
+    out_ax = axes[1];
+  } else {
+    out_ax = -1;
+  }
+  int axis = (out_ax >= 0 && axis_ >= out_ax) ? axis_ + 1 : axis_;
+  return {{take_along_axis(in, idx, axis, stream())}, {out_ax}};
 }
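
Note: GatherAxis::vmap previously returned its input unchanged; the rule above batches take_along_axis for any combination of batched input and batched indices. A quick loop-equivalence check, assuming the standard mlx.core Python API (this snippet is not part of the commit):

import mlx.core as mx

xs = mx.arange(24).reshape(4, 3, 2)        # batch of four (3, 2) arrays
idx = mx.zeros((4, 1, 2), dtype=mx.int32)  # batch of index arrays

batched = mx.vmap(lambda x, i: mx.take_along_axis(x, i, axis=0))(xs, idx)
looped = mx.stack(
    [mx.take_along_axis(xs[j], idx[j], axis=0) for j in range(4)]
)
assert mx.array_equal(batched, looped)
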

@@ -2117,8 +2139,11 @@ std::vector<array> GatherAxis::vjp(
           zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
     } else {
       auto src = zeros_like(primals[0], stream());
-      vjps.push_back(
-          put_along_axis(src, primals[1], cotangents[0], axis_, stream()));
+      vjps.push_back(array(
+          src.shape(),
+          src.dtype(),
+          std::make_shared<ScatterAxis>(stream(), ScatterAxis::Sum, axis_),
+          {src, primals[1], cotangents[0]}));
     }
   }
   return vjps;
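
Note: the VJP now builds the ScatterAxis primitive with ScatterAxis::Sum rather than calling put_along_axis, so cotangents for duplicate gather indices accumulate instead of overwriting one another. A sanity check through the Python bindings, again assuming the standard mlx.core API (not part of the commit):

import mlx.core as mx

x = mx.zeros((4,))
idx = mx.array([1, 1, 2])  # index 1 is gathered twice

def f(x):
    return mx.take_along_axis(x, idx, axis=0).sum()

# Each gathered copy receives a cotangent of 1, so the gradient at
# index 1 must be 2; that requires sum-scatter, not overwrite.
grad = mx.grad(f)(x)
assert mx.array_equal(grad, mx.array([0.0, 2.0, 1.0, 0.0]))
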
@@ -2152,6 +2177,7 @@ std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
   out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
   return {std::move(out_shape)};
 }
+
 std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -3666,6 +3692,107 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
   return {{out}, {src_ax}};
 }
 
+std::vector<array> ScatterAxis::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  const auto& indices = primals[1];
+  const auto& updates = primals[2];
+
+  std::vector<array> vjps;
+  for (auto num : argnums) {
+    // Gradient wrt to the input array
+    if (num == 0) {
+      if (reduce_type_ == ScatterAxis::None) {
+        // Scatter 0s to the locations that were updated with the updates
+        vjps.push_back(put_along_axis(
+            cotangents[0],
+            indices,
+            zeros_like(updates, stream()),
+            axis_,
+            stream()));
+      } else {
+        // The input array values are kept so they all get gradients
+        vjps.push_back(cotangents[0]);
+      }
+    } else if (num == 2) {
+      vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream()));
+    } else {
+      throw std::invalid_argument(
+          "[scatter_axis] Cannot calculate VJP with respect to indices.");
+    }
+  }
+  return vjps;
+}
+
+std::vector<array> ScatterAxis::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  for (auto arg : argnums) {
+    if (arg == 1) {
+      throw std::invalid_argument(
+          "[scatter_axis] Cannot calculate JVP with respect to indices.");
+    }
+  }
+  if (argnums.size() == 2) {
+    return {array(
+        primals[0].shape(),
+        primals[0].dtype(),
+        std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
+        {tangents[0], primals[1], tangents[1]})};
+  } else {
+    auto tan_a =
+        argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream());
+    auto tan_b =
+        argnums[0] == 2 ? tangents[0] : zeros_like(primals[2], stream());
+    return {array(
+        primals[0].shape(),
+        primals[0].dtype(),
+        std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
+        {tan_a, primals[1], tan_b})};
+  }
+}
+
+std::pair<std::vector<array>, std::vector<int>> ScatterAxis::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  // Find the first vmap axis
+  int out_ax = -1;
+  for (auto ax : axes) {
+    if (ax >= 0) {
+      out_ax = ax;
+      break;
+    }
+  }
+
+  if (out_ax < 0) {
+    return {
+        {array(
+            inputs[0].shape(),
+            inputs[0].dtype(),
+            std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
+            inputs)},
+        {-1}};
+  }
+
+  auto v_in = inputs;
+  for (int i = 0; i < axes.size(); ++i) {
+    if (axes[i] >= 0) {
+      // if out_ax >= 0 move axis o/w set out_ax
+      if (out_ax != axes[i]) {
+        v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream());
+      }
+    } else {
+      v_in[i] = expand_dims(v_in[i], out_ax, stream());
+    }
+  }
+  int axis = axis_ >= out_ax ? axis_ + 1 : axis_;
+  auto fn = reduce_type_ == Sum ? scatter_add_axis : put_along_axis;
+  return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}};
+}
+
 std::vector<Shape> ScatterAxis::output_shapes(
     const std::vector<array>& inputs) {
   return {inputs[0].shape()};
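
Note: ScatterAxis::vmap moves every batched argument's vmap axis to a common position, expands the unbatched ones, and dispatches to put_along_axis or scatter_add_axis. A loop-equivalence check mixing batched and unbatched arguments, assuming the standard mlx.core Python API (not part of the commit):

import mlx.core as mx

a = mx.zeros((4, 5))    # batched over axis 0
idx = mx.array([1, 3])  # shared across the batch (unbatched)
vals = mx.ones((4, 2))  # batched over axis 0

fn = mx.vmap(lambda a, v: mx.put_along_axis(a, idx, v, axis=0))
batched = fn(a, vals)
looped = mx.stack(
    [mx.put_along_axis(a[j], idx, vals[j], axis=0) for j in range(4)]
)
assert mx.array_equal(batched, looped)
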
4 changes: 2 additions & 2 deletions mlx/primitives.h
@@ -1819,8 +1819,8 @@ class ScatterAxis : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
-  // DEFINE_VMAP()
-  // DEFINE_GRADS()
+  DEFINE_VMAP()
+  DEFINE_GRADS()
 
   void print(std::ostream& os) override {
     os << "ScatterAxis";
31 changes: 31 additions & 0 deletions python/tests/test_autograd.py
@@ -669,6 +669,37 @@ def test_matmul_jvps(self):
         _, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
         self.assertTrue(mx.allclose(tangent, expected))
 
+    def test_put_along_axis_grads(self):
+        a = mx.zeros((5, 1))
+        b = mx.ones((2, 1))
+
+        def fun(a, b):
+            idx = mx.array([[0], [3]])
+            return mx.put_along_axis(a, idx, b, axis=0)
+
+        # Test VJP
+        cotan = mx.full((5, 1), 2.0)
+        _, (da, db) = mx.vjp(fun, (a, b), (cotan,))
+        expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
+        expected_db = mx.array([2.0, 2.0])[:, None]
+        self.assertTrue(mx.allclose(expected_da, da))
+        self.assertTrue(mx.allclose(expected_db, db))
+
+        # Test JVP
+        tan_a = mx.full((5, 1), 2.0)
+        tan_b = mx.full((2, 1), 3.0)
+        _, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
+        expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None]
+        self.assertTrue(mx.allclose(expected, jout))
+
+        def fun(a):
+            idx = mx.array([[0], [3]])
+            return mx.put_along_axis(a, idx, b, axis=0)
+
+        _, (jout,) = mx.jvp(fun, (a,), (tan_a,))
+        expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
+        self.assertTrue(mx.allclose(expected, jout))
+
 
 if __name__ == "__main__":
     unittest.main()