From 3788aa73eba0239246b6244639cd5b2a05516524 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov <rgerganov@gmail.com>
Date: Thu, 19 Jun 2025 11:04:23 +0300
Subject: [PATCH 01/18] ggml : add ggml_set_rows

Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using
indices from 'c'.

ref: #8366
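
A rough usage sketch, not part of the patch; the shapes and tensor names below
are made up for illustration. In this first version the destination must be
F16 and the indices I32 (both are changed by later patches in the series):

    // write 4 F32 rows into selected rows of an F16 destination
    struct ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 128); // 128 rows of 64
    struct ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,   4); // 4 rows to write
    struct ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);       // destination row indices
    struct ggml_tensor * out  = ggml_set_rows(ctx, dst, src, idxs);              // returns a view of 'dst'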
---
 ggml/include/ggml.h          |  7 +++++
 ggml/src/ggml-cpu/ggml-cpu.c |  5 +++
 ggml/src/ggml-cpu/ops.cpp    | 59 ++++++++++++++++++++++++++++++++++++
 ggml/src/ggml-cpu/ops.h      |  1 +
 ggml/src/ggml.c              | 28 +++++++++++++++--
 5 files changed, 98 insertions(+), 2 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 9c4e24023b5ad..5376696d4cb01 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -470,6 +470,7 @@ extern "C" {
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
         GGML_OP_GET_ROWS_BACK,
+        GGML_OP_SET_ROWS,
         GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
@@ -1375,6 +1376,12 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
+    GGML_API struct ggml_tensor * ggml_set_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // destination
+            struct ggml_tensor  * b,  // source
+            struct ggml_tensor  * c); // row indices
+
     GGML_API struct ggml_tensor * ggml_diag(
         struct ggml_context     * ctx,
         struct ggml_tensor      * a);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 1d3cd009affc6..1afafe10d50f1 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1814,6 +1814,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2167,6 +2171,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
             {
                 // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                 // decreases performance with GPU offloading
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index eff4a53e3442b..2bbaa5a5dd416 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(src0) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne1);
+
+        ggml_cpu_fp32_to_fp16(
+            (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
+            (ggml_fp16_t *) ((char *)  dst->data + i01*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back
 
 static void ggml_compute_forward_get_rows_back_f32_f16(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 2d8544d7d3d43..3a395fdcd9f04 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
 void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f8e7c595bce15..8e99c55237afc 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TRANSPOSE",
     "GET_ROWS",
     "GET_ROWS_BACK",
+    "SET_ROWS",
     "DIAG",
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "transpose(x)",
     "get_rows(x)",
     "get_rows_back(x)",
+    "set_rows(x)",
     "diag(x)",
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
@@ -1082,7 +1084,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3395,6 +3397,28 @@ struct ggml_tensor * ggml_get_rows_back(
     return result;
 }
 
+// ggml_set_rows
+
+struct ggml_tensor * ggml_set_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(b->ne[2] == c->ne[1]);
+    GGML_ASSERT(c->ne[3] == 1);
+    GGML_ASSERT(a->type == GGML_TYPE_F16);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+    GGML_ASSERT(c->type == GGML_TYPE_I32);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_SET_ROWS;
+    result->src[0] = b;
+    result->src[1] = c;
+
+    return result;
+}
+
 // ggml_diag
 
 struct ggml_tensor * ggml_diag(

From b2bd0a7e445292228e29883d7c774970104c89b4 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov <rgerganov@gmail.com>
Date: Fri, 20 Jun 2025 11:37:43 +0300
Subject: [PATCH 02/18] use I64 for indices

---
 ggml/src/ggml-cpu/ops.cpp | 2 +-
 ggml/src/ggml.c           | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 2bbaa5a5dd416..c618fd5b9d712 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4501,7 +4501,7 @@ static void ggml_compute_forward_set_rows_f32(
         const int64_t i12 = i/(ne11*ne10);
         const int64_t i11 = (i - i12*ne11*ne10)/ne10;
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+        const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
         GGML_ASSERT(i01 >= 0 && i01 < ne1);
 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 8e99c55237afc..13405a7d3b535 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3408,7 +3408,7 @@ struct ggml_tensor * ggml_set_rows(
     GGML_ASSERT(c->ne[3] == 1);
     GGML_ASSERT(a->type == GGML_TYPE_F16);
     GGML_ASSERT(b->type == GGML_TYPE_F32);
-    GGML_ASSERT(c->type == GGML_TYPE_I32);
+    GGML_ASSERT(c->type == GGML_TYPE_I64);
 
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 

From 788fb0ff31cc61bbabe0cff7868437a862f90dfe Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sat, 21 Jun 2025 09:07:25 +0300
Subject: [PATCH 03/18] ggml : add repeat impl for i64

---
 examples/eval-callback/eval-callback.cpp |  2 +
 ggml/src/ggml-cpu/ops.cpp                | 50 ++++++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index fb188f5a9e132..bbbec6a01a175 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
                         v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
                     } else if (type == GGML_TYPE_F32) {
                         v = *(float *) &data[i];
+                    } else if (type == GGML_TYPE_I64) {
+                        v = (float) *(int64_t *) &data[i];
                     } else if (type == GGML_TYPE_I32) {
                         v = (float) *(int32_t *) &data[i];
                     } else if (type == GGML_TYPE_I16) {
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index c618fd5b9d712..3dd57ccefcf6f 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
     }
 }
 
+static void ggml_compute_forward_repeat_i64(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(int64_t));
+    GGML_ASSERT(nb00 == sizeof(int64_t));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                int64_t * y = (int64_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
+                                int64_t * x = (int64_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i]  = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 void ggml_compute_forward_repeat(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
+        case GGML_TYPE_I64:
+            {
+                ggml_compute_forward_repeat_i64(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");

From 2f3a43d065ebfeea91a51db6b43903a4f00b8104 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 10:27:31 +0300
Subject: [PATCH 04/18] ggml : add ggml_is_contiguous_rows
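
An illustration of what the new helper reports (not from the patch). For a
non-quantized type the block size is 1 element, so the check reduces to
"the stride along dim 0 equals the element size":

    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    bool ok  = ggml_is_contiguous_rows(t);                       // true:  nb[0] == sizeof(float)
    bool bad = ggml_is_contiguous_rows(ggml_transpose(ctx, t));  // false: nb[0] is now the old row stride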

---
 ggml/include/ggml.h | 3 +++
 ggml/src/ggml.c     | 6 ++++++
 2 files changed, 9 insertions(+)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5376696d4cb01..2c6de09ab9f93 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -688,6 +688,9 @@ extern "C" {
     // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
     GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
 
+    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 13405a7d3b535..4779565110e1f 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1353,6 +1353,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
         tensor->nb[2] == ggml_type_size(tensor->type);
 }
 
+bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+    return
+        tensor->ne[0] == ggml_blck_size(tensor->type) ||
+        tensor->nb[0] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 

From 93a568cdd2ca9414a42a09f640b149bc5093ab66 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 10:28:07 +0300
Subject: [PATCH 05/18] ggml : ggml_set_rows support broadcast
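
Illustrative shapes, not from the patch: a single index vector is reused
(broadcast) across every batch of the source and destination, i.e. the
ne02 % ne11 == 0 and ne03 % ne12 == 0 case described in the new header comment:

    // dst has 4 batches of 128 rows; src supplies 8 rows per batch;
    // one shared index vector says where those rows land in each batch
    struct ggml_tensor * dst  = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 64, 128, 4);
    struct ggml_tensor * src  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64,   8, 4);
    struct ggml_tensor * idxs = ggml_new_tensor_2d(ctx, GGML_TYPE_I64,  8,   1);   // broadcast over the 4 batches
    struct ggml_tensor * out  = ggml_set_rows(ctx, dst, src, idxs);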

---
 ggml/include/ggml.h       |  9 +++++++++
 ggml/src/ggml-cpu/ops.cpp | 33 ++++++++++++++++++++-------------
 ggml/src/ggml.c           | 12 ++++++++++--
 3 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2c6de09ab9f93..b9a8d37715624 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1379,6 +1379,15 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
+    // a TD  [n_embd, ne1,    ne2,    ne3]
+    // b TS  [n_embd, n_rows, ne02,   ne03] | ne02 == ne2, ne03 == ne3
+    // c I64 [n_rows, ne11,   ne12,   1]    | c[i] in [0, ne1)
+    //
+    // broadcast:
+    //   ne2 % ne11 == 0
+    //   ne3 % ne12 == 0
+    //
+    // return view(a)
     GGML_API struct ggml_tensor * ggml_set_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,  // destination
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 3dd57ccefcf6f..1c8bdfbcef91a 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
+    const int64_t nr = ne01;
 
     assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(float));
-    assert(ggml_nrows(src0) == nr);
+    assert(ne2  == ne02);
+    assert(ne3  == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        GGML_ASSERT(i01 >= 0 && i01 < ne1);
+                GGML_ASSERT(i01 >= 0 && i01 < ne1);
 
-        ggml_cpu_fp32_to_fp16(
-            (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
-            (ggml_fp16_t *) ((char *)  dst->data + i01*nb1  + i11*nb2  + i12*nb3), nc);
+                ggml_cpu_fp32_to_fp16(
+                        (const float *) ((char *) src0->data +   i*nb01 + i02*nb02 + i03*nb03),
+                        (ggml_fp16_t *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3), nc);
+            }
+        }
     }
 }
 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4779565110e1f..12d9ad70ac972 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3410,12 +3410,20 @@ struct ggml_tensor * ggml_set_rows(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         struct ggml_tensor  * c) {
-    GGML_ASSERT(b->ne[2] == c->ne[1]);
+    GGML_ASSERT(a->ne[0] == b->ne[0]);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(a->ne[3] == b->ne[3]);
+    GGML_ASSERT(b->ne[1] == c->ne[0]);
+    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
     GGML_ASSERT(c->ne[3] == 1);
-    GGML_ASSERT(a->type == GGML_TYPE_F16);
+    GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax
     GGML_ASSERT(b->type == GGML_TYPE_F32);
     GGML_ASSERT(c->type == GGML_TYPE_I64);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(a));
+    GGML_ASSERT(ggml_is_contiguous_rows(b));
+
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
     result->op     = GGML_OP_SET_ROWS;

From 6b9b86d34b1fb0cb1953bb1c5b5c6bdc118ee759 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 11:10:42 +0300
Subject: [PATCH 06/18] ggml : ggml_set_rows support quantized dst

ggml-ci
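
An illustrative example of what this enables, not from the patch: the
destination may now be a quantized tensor, with the conversion taken from the
CPU type traits of dst->type. The row length must be a multiple of the block
size of the destination type (32 for Q8_0):

    // write F32 rows directly into a Q8_0 tensor; rows are quantized on write
    struct ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, 256, 128);
    struct ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  256,   4);
    struct ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64,  4);
    struct ggml_tensor * out  = ggml_set_rows(ctx, dst, src, idxs);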
---
 ggml/src/ggml-cpu/ops.cpp | 6 ++++--
 ggml/src/ggml.c           | 1 -
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 1c8bdfbcef91a..e9832f8dee467 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4549,6 +4549,8 @@ static void ggml_compute_forward_set_rows_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
     for (int64_t i03 = 0; i03 < ne03; ++i03) {
         for (int64_t i02 = 0; i02 < ne02; ++i02) {
             for (int64_t i = ir0; i < ir1; ++i) {
@@ -4560,9 +4562,9 @@ static void ggml_compute_forward_set_rows_f32(
 
                 GGML_ASSERT(i01 >= 0 && i01 < ne1);
 
-                ggml_cpu_fp32_to_fp16(
+                from_float(
                         (const float *) ((char *) src0->data +   i*nb01 + i02*nb02 + i03*nb03),
-                        (ggml_fp16_t *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3), nc);
+                                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3), nc);
             }
         }
     }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 12d9ad70ac972..2b69f70cff1cf 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3417,7 +3417,6 @@ struct ggml_tensor * ggml_set_rows(
     GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
     GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
     GGML_ASSERT(c->ne[3] == 1);
-    GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax
     GGML_ASSERT(b->type == GGML_TYPE_F32);
     GGML_ASSERT(c->type == GGML_TYPE_I64);
 

From b13524cb912820895690520256d58e2c746ec85f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 18:44:42 +0300
Subject: [PATCH 07/18] ggml : support GGML_TYPE_F32 ".from_float" trait
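
A sketch of why this entry is needed, not from the patch: code that converts
rows through the CPU type traits (ggml_set_rows above, and the dup path
simplified later in this series) no longer has to special-case F32
destinations. The names dst, src_row, dst_row and n below are hypothetical:

    // the generic conversion path now also covers F32 destinations
    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
    from_float(src_row, dst_row, n); // for GGML_TYPE_F32 this is ggml_cpu_fp32_to_fp32, i.e. a memcpy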

---
 ggml/include/ggml-cpu.h      | 1 +
 ggml/src/ggml-cpu/ggml-cpu.c | 5 +++++
 2 files changed, 6 insertions(+)

diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h
index de77a875ec533..1a78935aa05cf 100644
--- a/ggml/include/ggml-cpu.h
+++ b/ggml/include/ggml-cpu.h
@@ -133,6 +133,7 @@ extern "C" {
 
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 1afafe10d50f1..768065b99f2ef 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -192,6 +192,7 @@ typedef pthread_t ggml_thread_t;
 
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
+        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type             = GGML_TYPE_F32,
         .nrows                    = 1,
@@ -3126,6 +3127,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }
 
+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+    memcpy(y, x, n * sizeof(float));
+}
+
 void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
     int64_t i = 0;
 #if defined(__F16C__)

From b597019fb69375905a84df23d456e0f00f279b25 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 18:45:07 +0300
Subject: [PATCH 08/18] ggml : ggml_set_rows update comment + better index name

---
 ggml/include/ggml.h       | 2 ++
 ggml/src/ggml-cpu/ops.cpp | 8 ++++----
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index b9a8d37715624..2b1bd6e0f48b9 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1383,6 +1383,8 @@ extern "C" {
     // b TS  [n_embd, n_rows, ne02,   ne03] | ne02 == ne2, ne03 == ne3
     // c I64 [n_rows, ne11,   ne12,   1]    | c[i] in [0, ne1)
     //
+    // undefined behavior if destination rows overlap
+    //
     // broadcast:
     //   ne2 % ne11 == 0
     //   ne3 % ne12 == 0
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index e9832f8dee467..8a51e4cf3afd7 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4558,13 +4558,13 @@ static void ggml_compute_forward_set_rows_f32(
                 const int64_t i11 = i02%ne11;
                 const int64_t i10 = i;
 
-                const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-                GGML_ASSERT(i01 >= 0 && i01 < ne1);
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
 
                 from_float(
-                        (const float *) ((char *) src0->data +   i*nb01 + i02*nb02 + i03*nb03),
-                                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3), nc);
+                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
+                                        ((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);
             }
         }
     }

From f7d0aab7a9e6c4cc33c34e3feb5e829bd39b5fca Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 18:45:30 +0300
Subject: [PATCH 09/18] tests : add ggml_set_rows

---
 tests/test-backend-ops.cpp | 81 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 7be7f2205fa04..f5cd5b695bb16 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1213,6 +1213,78 @@ struct test_get_rows_back : public test_case {
     }
 };
 
+// GGML_OP_SET_ROWS
+struct test_set_rows : public test_case {
+    const ggml_type type;
+    const int n; // cols
+    const int m; // rows
+    const int r; // rows to set
+    const int b0; // batch size (dim 2)
+    const int b1; // batch size (dim 3)
+    const int bs; // batch size src (for testing broadcast)
+    const bool v; // view (non-contiguous src1)
+
+    std::string vars() override {
+        return VARS_TO_STR7(type, n, m, r, b0, bs, v);
+    }
+
+    test_set_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, int bs = 1, bool v = false)
+        : type(type), n(n), m(m), r(r), b0(b), b1(3), bs(bs), v(v) {
+            GGML_ASSERT(b0 % bs == 0 && "b0 must be a multiple of bs");
+            GGML_ASSERT(r <= m && "r must be less than or equal to m");
+        }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, n, m, b0, b1);
+        ggml_set_name(dst, "dst");
+
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, r, b0, b1);
+        ggml_set_name(src, "src");
+
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, bs, b1);
+        ggml_set_name(row_idxs, "row_idxs");
+
+        if (v) {
+            src      = ggml_view_4d(ctx, src, n, r/2, b0, b1, src->nb[1], src->nb[2], src->nb[3], 0);
+            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, bs, b1, row_idxs->nb[1], row_idxs->nb[2], 0);
+            ggml_set_name(row_idxs, "view_of_rows");
+        }
+
+        ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                for (int i2 = 0; i2 < t->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < t->ne[1]; i1++) {
+                        std::vector<int64_t> data(m);
+                        for (int i = 0; i < m; i++) {
+                            data[i] = i;
+                        }
+                        std::shuffle(data.begin(), data.end(), rng);
+                        data.resize(t->ne[0]);
+
+                        const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+                        ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+                    }
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -3984,6 +4056,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
     }
 
+    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
+    for (ggml_type type : all_types) {
+        for (int b : {1, 7}) {
+            for (bool v : {false, true}) {
+                test_cases.emplace_back(new test_set_rows(type, 256, 5, 4, b, 1, v));
+            }
+        }
+    }
+
     for (ggml_type type_input : {GGML_TYPE_F32}) {
         for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
             for (int k0 : {1, 3}) {

From a881dc273d05adfcb91d1d82b507d0033477c6bc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 22 Jun 2025 18:45:52 +0300
Subject: [PATCH 10/18] metal : add ggml_set_rows implementation

ggml-ci
---
 ggml/src/ggml-metal/ggml-metal-impl.h |  16 +
 ggml/src/ggml-metal/ggml-metal.m      | 110 +++++-
 ggml/src/ggml-metal/ggml-metal.metal  | 469 ++++++++++++++++----------
 3 files changed, 411 insertions(+), 184 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 17eab976f3ad1..260440aedde00 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -521,6 +521,22 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_get_rows;
 
+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 19f4d59e59747..fd138ef0e7c0c 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1175,15 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,                 get_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,                 get_rows_iq4_xs,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                    get_rows_i32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,                    set_rows_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,                    set_rows_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,                   set_rows_bf16,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,                   set_rows_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,                   set_rows_q4_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,                   set_rows_q4_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,                   set_rows_q5_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
@@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
 
     if (!use_bfloat) {
         for (size_t i = 0, n = 3; i < n; ++i) {
-            if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
+            if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
                 return false;
             }
         }
@@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             {
                 return op->ne[3] == 1;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                if (op->src[0]->type != GGML_TYPE_F32) {
+                    return false;
+                }
+
+                switch (op->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_IQ4_NL:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         default:
             return false;
     }
@@ -3757,13 +3796,74 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                [encoder setBytes:&args length:sizeof(args) atIndex:3];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->type) {
+                    case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32   ].pipeline; break;
+                    case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16   ].pipeline; break;
+                    case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16  ].pipeline; break;
+                    case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0  ].pipeline; break;
+                    case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0  ].pipeline; break;
+                    case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1  ].pipeline; break;
+                    case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0  ].pipeline; break;
+                    case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1  ].pipeline; break;
+                    case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                }
+
+                const int32_t nk0 = ne0/ggml_blck_size(dst->type);
+
+                int nth = 32; // SIMD width
+
+                while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                int nrptg = 1;
+                if (nth > nk0) {
+                    nrptg = (nth + nk0 - 1)/nk0;
+                    nth   = nk0;
+
+                    if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nrptg--;
+                    }
+                }
+
+                nth = MIN(nth, nk0);
+
+                ggml_metal_kargs_set_rows args = {
+                    /*.nk0  =*/ nk0,
+                    /*.ne01 =*/ ne01,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
+            } break;
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 3da19879b4b36..d2dec8d296708 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
     }
 }
 
+void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d = max / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = src[0       + j]*id;
+        const float x1 = src[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+        dst.qs[j]  = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; j++) {
+        const float v = src[j];
+        if (min > v) min = v;
+        if (max < v) max = v;
+    }
+
+    const float d = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (src[0       + j] - min)*id;
+        const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+        dst.qs[j]  = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK5_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d = max / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = src[0       + j]*id;
+        const float x1 = src[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
+    float max = src[0];
+    float min = src[0];
+
+    for (int j = 1; j < QK5_1; j++) {
+        const float v = src[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (src[0       + j] - min)*id;
+        const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_NL; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d = max / kvalues_iq4nl_f[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = src[0        + j]*id;
+        const float x1 = src[QK4_NL/2 + j]*id;
+
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+        dst.qs[j] = xi0 | (xi1 << 4);
+
+        const float v0 = kvalues_iq4nl_f[xi0];
+        const float v1 = kvalues_iq4nl_f[xi1];
+        const float w0 = src[0        + j]*src[0        + j];
+        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+
+    }
+
+    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -4341,6 +4539,7 @@ template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bf
 template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 #endif
 
+// TODO: templatize these kernels
 kernel void kernel_cpy_f32_q8_0(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
@@ -4364,23 +4563,7 @@ kernel void kernel_cpy_f32_q8_0(
     for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = src[j];
-            amax = MAX(amax, fabs(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK8_0].d = d;
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = src[j]*id;
-
-            dst_data[i00/QK8_0].qs[j] = round(x0);
-        }
+        quantize_q8_0(src, dst_data[i00/QK8_0]);
     }
 }
 
@@ -4407,32 +4590,7 @@ kernel void kernel_cpy_f32_q4_0(
     for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_0].d = d;
-
-        for (int j = 0; j < QK4_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK4_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            dst_data[i00/QK4_0].qs[j]  = xi0;
-            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_0(src, dst_data[i00/QK4_0]);
     }
 }
 
@@ -4459,31 +4617,7 @@ kernel void kernel_cpy_f32_q4_1(
     for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < QK4_1; j++) {
-            const float v = src[j];
-            if (min > v) min = v;
-            if (max < v) max = v;
-        }
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_1].d = d;
-        dst_data[i00/QK4_1].m = min;
-
-        for (int j = 0; j < QK4_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK4_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            dst_data[i00/QK4_1].qs[j]  = xi0;
-            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_1(src, dst_data[i00/QK4_1]);
     }
 }
 
@@ -4510,38 +4644,7 @@ kernel void kernel_cpy_f32_q5_0(
     for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK5_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / -16;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_0].d = d;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK5_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
-            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
-            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_0].qh[j] = qh8[j];
-        }
+        quantize_q5_0(src, dst_data[i00/QK5_0]);
     }
 }
 
@@ -4568,49 +4671,8 @@ kernel void kernel_cpy_f32_q5_1(
     for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float max = src[0];
-        float min = src[0];
-
-        for (int j = 1; j < QK5_1; j++) {
-            const float v = src[j];
-            min = v < min ? v : min;
-            max = v > max ? v : max;
-        }
-
-        const float d = (max - min) / 31;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_1].d = d;
-        dst_data[i00/QK5_1].m = min;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK5_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_1].qh[j] = qh8[j];
-        }
-    }
-}
-
-static inline int best_index_int8(int n, constant float * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
+        quantize_q5_1(src, dst_data[i00/QK5_1]);
     }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
 kernel void kernel_cpy_f32_iq4_nl(
@@ -4636,40 +4698,7 @@ kernel void kernel_cpy_f32_iq4_nl(
     for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_NL; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / kvalues_iq4nl_f[0];
-        const float id = d ? 1.0f/d : 0.0f;
-
-        float sumqx = 0, sumq2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            const float x0 = src[0        + j]*id;
-            const float x1 = src[QK4_NL/2 + j]*id;
-
-            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
-            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
-            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
-            const float v0 = kvalues_iq4nl_f[xi0];
-            const float v1 = kvalues_iq4nl_f[xi1];
-            const float w0 = src[0        + j]*src[0        + j];
-            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
-            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
-            sumq2 += w0*v0*v0 + w1*v1*v1;
-
-        }
-
-        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
     }
 }
 
@@ -6350,10 +6379,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
+        constant ggml_metal_kargs_get_rows & args,
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
@@ -6373,10 +6402,10 @@ kernel void kernel_get_rows_q(
 
 template<typename T>
 kernel void kernel_get_rows_f(
+        constant ggml_metal_kargs_get_rows & args,
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
@@ -6394,10 +6423,10 @@ kernel void kernel_get_rows_f(
 }
 
 kernel void kernel_get_rows_i32(
+        constant ggml_metal_kargs_get_rows & args,
         device const  void * src0,
         device const  void * src1,
         device     int32_t * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
@@ -6414,6 +6443,67 @@ kernel void kernel_get_rows_i32(
     }
 }
 
+template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_set_rows_q32(
+        constant ggml_metal_kargs_set_rows & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+          device block_q * dst_row = (      device block_q *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
+    const device float   * src_row = (const device float   *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        quantize_func(src_row + 32*ind, dst_row[ind]);
+    }
+}
+
+template<typename T>
+kernel void kernel_set_rows_f(
+        constant ggml_metal_kargs_set_rows & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device T     * dst_row = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        dst_row[ind] = (T) src_row[ind];
+    }
+}
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6837,6 +6927,27 @@ template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
 template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
+//
+// set rows
+//
+
+typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+
+template [[host_name("kernel_set_rows_f32")]]  kernel set_rows_f_t kernel_set_rows_f<float>;
+template [[host_name("kernel_set_rows_f16")]]  kernel set_rows_f_t kernel_set_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+
 //
 // matrix-matrix multiplication
 //

From a92973eff1e41e1cd6c3f049f15d381da520e897 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov <rgerganov@gmail.com>
Date: Mon, 23 Jun 2025 11:16:54 +0300
Subject: [PATCH 11/18] ggml : simplify forward_dup_f32

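The dedicated F32 branch can be dropped because the generic conversion path
already covers it, assuming the CPU type traits provide a from_float for F32
that amounts to a plain row copy (which is what the guard on from_float
suggests). A minimal standalone C++ sketch of the single resulting code path
(hypothetical names, not the ggml API):

    #include <cstdint>
    #include <cstring>

    // per-row conversion callback; for an F32 destination it is just a copy
    typedef void (*from_float_t)(const float * src, void * dst, int64_t n);

    static void from_float_f32(const float * src, void * dst, int64_t n) {
        memcpy(dst, src, n*sizeof(float));
    }

    // convert nrows rows of ncols floats into rows of dst_row_size bytes each
    static void copy_rows(const float * src, void * dst,
                          int64_t nrows, int64_t ncols, size_t dst_row_size,
                          from_float_t from_float) {
        for (int64_t r = 0; r < nrows; ++r) {
            from_float(src + r*ncols, (char *) dst + r*dst_row_size, ncols);
        }
    }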
---
 ggml/src/ggml-cpu/ops.cpp | 22 +++-------------------
 1 file changed, 3 insertions(+), 19 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 8a51e4cf3afd7..c05ab91189c47 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            from_float(src0_ptr, dst_ptr + id, ne00);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);

From 1a5b2a16a6789122eb9d43364f70632b01d1b246 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov <rgerganov@gmail.com>
Date: Mon, 23 Jun 2025 11:25:16 +0300
Subject: [PATCH 12/18] ggml : fix supports_op

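GGML_OP_SET_ROWS converts float rows into the destination type through the
same from_float paths as GGML_OP_CPY, so it presumably needs the same
destination-type exclusions; adding it to the existing CPY case keeps the two
lists in sync.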
---
 ggml/src/ggml-cpu/ggml-cpu.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
index 735ef3f015c13..cc9b922fac640 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 
     switch (op->op) {
         case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return
                 op->type != GGML_TYPE_IQ3_XXS &&
                 op->type != GGML_TYPE_IQ3_S   &&

From 3c331241b8a0251a73733e79774f11d006b4145c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 26 Jun 2025 10:42:37 +0300
Subject: [PATCH 13/18] tests : add comment to set_rows

---
 tests/test-backend-ops.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f5cd5b695bb16..8de3ec315e9ac 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1267,6 +1267,7 @@ struct test_set_rows : public test_case {
 
                 for (int i2 = 0; i2 < t->ne[2]; i2++) {
                     for (int i1 = 0; i1 < t->ne[1]; i1++) {
+                        // generate a shuffled subset of row indices
                         std::vector<int64_t> data(m);
                         for (int i = 0; i < m; i++) {
                             data[i] = i;

From 929e11844c24f45ce3ebbb20d53ae9b38636485a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 27 Jun 2025 11:08:14 +0300
Subject: [PATCH 14/18] ggml : leave the repeat_i64 for a separate PR

ggml-ci
---
 ggml/src/ggml-cpu/ops.cpp | 56 +++++----------------------------------
 1 file changed, 6 insertions(+), 50 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index c05ab91189c47..cfd0b5fbbdf71 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -2266,52 +2266,6 @@ static void ggml_compute_forward_repeat_f16(
     }
 }
 
-static void ggml_compute_forward_repeat_i64(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(int64_t));
-    GGML_ASSERT(nb00 == sizeof(int64_t));
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                int64_t * y = (int64_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
-                                int64_t * x = (int64_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
-                                for (int i = 0; i < ne00; ++i) {
-                                    y[i]  = x[i];
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
 void ggml_compute_forward_repeat(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -2330,10 +2284,12 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
-        case GGML_TYPE_I64:
-            {
-                ggml_compute_forward_repeat_i64(params, dst);
-            } break;
+        // TODO: templateify the implementation and add support for I64
+        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+        //case GGML_TYPE_I64:
+        //    {
+        //        ggml_compute_forward_repeat_i64(params, dst);
+        //    } break;
         default:
             {
                 GGML_ABORT("fatal error");

From 56d691401080f603f22343499fa8053390def525 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 27 Jun 2025 11:11:10 +0300
Subject: [PATCH 15/18] ggml : set_rows use std::min instead of MIN

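std::min requires both arguments to have the same type, so the per-thread
range values are promoted to int64_t; this also avoids narrowing the 64-bit
row count that the MIN macro silently allowed. A hedged sketch of the
partitioning logic, with illustrative names only:

    #include <algorithm>
    #include <cstdint>

    // split nr rows across nth threads; thread ith gets the half-open range [ir0, ir1)
    static void thread_row_range(int64_t nr, int nth, int ith,
                                 int64_t & ir0, int64_t & ir1) {
        const int64_t dr = (nr + nth - 1)/nth; // rows per thread, rounded up
        ir0 = dr*ith;
        ir1 = std::min(ir0 + dr, nr);          // the last thread may get fewer rows
    }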
---
 ggml/src/ggml-cpu/ops.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index cfd0b5fbbdf71..048456eac0fb7 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4483,11 +4483,11 @@ static void ggml_compute_forward_set_rows_f32(
     const int nth = params->nth;
 
     // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int64_t dr = (nr + nth - 1)/nth;
 
     // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = std::min(ir0 + dr, nr);
 
     ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 

From 8f0f615071d17e3c937e7de6d4cb8c6d5a3bd7c3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 27 Jun 2025 11:13:18 +0300
Subject: [PATCH 16/18] ggml : better error message for set_rows unsupported
 type

---
 ggml/src/ggml-cpu/ops.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 048456eac0fb7..0f214dc3c3202 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4523,7 +4523,7 @@ void ggml_compute_forward_set_rows(
             } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
             }
     }
 }

From 838e89d0386c0f8b9b65ed4ec95dbecd9730ab13 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 27 Jun 2025 11:14:11 +0300
Subject: [PATCH 17/18] metal : perform op->type check only once

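The destination-type check does not depend on the loop index, so it is
performed once before iterating over op->src instead of being re-evaluated
for every source tensor.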
---
 ggml/src/ggml-metal/ggml-metal.m | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index fd138ef0e7c0c..8d19bd502667d 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1647,8 +1647,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     const bool use_bfloat              = ctx_dev->use_bfloat;
 
     if (!use_bfloat) {
+        if (op->type == GGML_TYPE_BF16) {
+            return false;
+        }
+
         for (size_t i = 0, n = 3; i < n; ++i) {
-            if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
+            if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
                 return false;
             }
         }

From f46ddbae7afc064e911bf2565a2ce95f3b09aa02 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 27 Jun 2025 11:41:15 +0300
Subject: [PATCH 18/18] tests : more consistent implementation + more tests

ggml-ci
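
The test is now parameterized like the other cases: ne is the base shape,
nr23 broadcasts dims 2 and 3 of dst/src relative to the row indices, r is the
number of rows written per batch, and v exercises non-contiguous views of src
and the indices. A hypothetical extra case, shown only to illustrate the
constructor (it is not part of the list below):

    // dst is 256 x 5 x 8 x 3, src provides 3 rows per batch, indices are 3 x 4 x 1
    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 256, 5, 4, 1 }, { 2, 3 }, 3, false));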
---
 tests/test-backend-ops.cpp | 45 +++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 20 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 8de3ec315e9ac..07b67b184a805 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1216,37 +1216,34 @@ struct test_get_rows_back : public test_case {
 // GGML_OP_SET_ROWS
 struct test_set_rows : public test_case {
     const ggml_type type;
-    const int n; // cols
-    const int m; // rows
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 2> nr23; // broadcast only dims 2 and 3
     const int r; // rows to set
-    const int b0; // batch size
-    const int b1; // batch size
-    const int bs; // batch size src (for testing broadcast)
     const bool v; // view (non-contiguous src1)
 
     std::string vars() override {
-        return VARS_TO_STR7(type, n, m, r, b0, bs, v);
+        return VARS_TO_STR5(type, ne, nr23, r, v);
     }
 
-    test_set_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, int bs = 1, bool v = false)
-        : type(type), n(n), m(m), r(r), b0(b), b1(3), bs(bs), v(v) {
-            GGML_ASSERT(b0 % bs == 0 && "b0 must be a multiple of bs");
-            GGML_ASSERT(r <= m && "r must be less than or equal to m");
-        }
+    test_set_rows(ggml_type type,
+            std::array<int64_t, 4> ne,
+            std::array<int, 2> nr23,
+            int r, bool v = false)
+        : type(type), ne(ne), nr23(nr23), r(r), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, n, m, b0, b1);
+        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type,          ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
         ggml_set_name(dst, "dst");
 
-        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, r, b0, b1);
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r,     ne[2]*nr23[0], ne[3]*nr23[1]);
         ggml_set_name(src, "src");
 
-        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, bs, b1);
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
         ggml_set_name(row_idxs, "row_idxs");
 
         if (v) {
-            src      = ggml_view_4d(ctx, src, n, r/2, b0, b1, src->nb[1], src->nb[2], src->nb[3], 0);
-            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, bs, b1, row_idxs->nb[1], row_idxs->nb[2], 0);
+            src      = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
+            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
             ggml_set_name(row_idxs, "view_of_rows");
         }
 
@@ -1268,8 +1265,8 @@ struct test_set_rows : public test_case {
                 for (int i2 = 0; i2 < t->ne[2]; i2++) {
                     for (int i1 = 0; i1 < t->ne[1]; i1++) {
                         // generate a shuffled subset of row indices
-                        std::vector<int64_t> data(m);
-                        for (int i = 0; i < m; i++) {
+                        std::vector<int64_t> data(ne[1]);
+                        for (int i = 0; i < ne[1]; i++) {
                             data[i] = i;
                         }
                         std::shuffle(data.begin(), data.end(), rng);
@@ -4057,11 +4054,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
     }
 
-    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
+    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
             for (bool v : {false, true}) {
-                test_cases.emplace_back(new test_set_rows(type, 256, 5, 4, b, 1, v));
+                test_cases.emplace_back(new test_set_rows(type, { 256, 5,  b, 3 }, { 1, 1, }, 1, v));
+                test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
+
+                test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
+
+                if (ggml_blck_size(type) == 1) {
+                    test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
+                    test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
+                }
             }
         }
     }