diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp index 29828447e..cd9ac5022 100644 --- a/mlx/backend/common/indexing.cpp +++ b/mlx/backend/common/indexing.cpp @@ -461,7 +461,8 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { auto& updates = inputs.back(); // Copy src into out (copy allocates memory for out) - copy(src, out, CopyType::General); + auto ctype = src.flags().row_contigous ? CopyType::Vector : CopyType::General; + copy(src, out, ctype); switch (src.dtype()) { case bool_: @@ -621,7 +622,8 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { auto& updates = inputs[2]; // Copy src into out (copy allocates memory for out) - copy(src, out, CopyType::General); + auto ctype = src.flags().row_contigous ? CopyType::Vector : CopyType::General; + copy(src, out, ctype); switch (src.dtype()) { case bool_: