Commit 9461051

add q8_t transform
1 parent 0c159d8 commit 9461051

File tree: ggml-cann.cpp, ggml-cann/aclnn_ops.cpp

2 files changed: +75 −3 lines changed

ggml-cann.cpp

Lines changed: 68 additions & 3 deletions
@@ -142,7 +142,6 @@ GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, cons
     GGML_ASSERT(tensor->extra == nullptr);
     GGML_ASSERT(tensor->op == GGML_OP_NONE);

-    void *buffer_host;
     size_t n_bytes = ggml_nbytes(tensor);
     int64_t n_elems = ggml_nelements(tensor);
     int64_t groups = n_elems / QK4_0;
@@ -176,7 +175,6 @@ GGML_CALL static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor* t
     GGML_ASSERT(tensor->extra == nullptr);
     GGML_ASSERT(tensor->op == GGML_OP_NONE);

-    void *buffer_host;
     size_t n_bytes = ggml_nbytes(tensor);
     int64_t n_elems = ggml_nelements(tensor);
     int64_t groups = n_elems / QK4_0;
@@ -206,12 +204,66 @@ GGML_CALL static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor* t
     }
 }

+#define QK8_0 32
+typedef struct {
+    uint16_t d;         // delta
+    int8_t qs[QK8_0];   // quants
+} block_q8_0;
+
+GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, const void *src, void* dst) {
+    GGML_ASSERT(tensor->extra == nullptr);
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
+
+    size_t n_bytes = ggml_nbytes(tensor);
+    int64_t n_elems = ggml_nelements(tensor);
+    int64_t groups = n_elems / QK8_0;
+    size_t quant_bytes = n_elems * sizeof(uint8_t);
+
+    uint8_t* quant_offset = (uint8_t*)dst;
+    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
+
+    for (int i = 0;i<groups; i++) {
+        block_q8_0 *group = (block_q8_0*)((char*)src + i * sizeof(block_q8_0));
+        *scale_offset = group->d;
+        scale_offset++;
+        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
+        memcpy(quant_offset, group->qs, group_quant_size);
+        quant_offset += group_quant_size;
+    }
+}
+
+GGML_CALL static void ggml_backend_cann_transform_back_q8_0(const ggml_tensor* tensor, const void *src, void* dst) {
+    GGML_ASSERT(tensor->extra == nullptr);
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
+
+    size_t n_bytes = ggml_nbytes(tensor);
+    int64_t n_elems = ggml_nelements(tensor);
+    int64_t groups = n_elems / QK8_0;
+    size_t quant_bytes = n_elems * sizeof(uint8_t);
+
+    uint8_t* quant_offset = (uint8_t*)src;
+    uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
+
+    for (int i = 0;i<groups; i++) {
+        block_q8_0 *group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
+        group->d = *scale_offset;
+        scale_offset++;
+        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
+        memcpy(group->qs, quant_offset, group_quant_size);
+        quant_offset += group_quant_size;
+    }
+}
+
+
 GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor, const void* src, void *dst) {
     std::cout<<"Transform tensor:"<<tensor->name<<std::endl;
     switch (tensor->type) {
         case GGML_TYPE_Q4_0:
             ggml_backend_cann_transform_q4_0(tensor, src, dst);
             break;
+        case GGML_TYPE_Q8_0:
+            ggml_backend_cann_transform_q8_0(tensor, src, dst);
+            break;
         default:
             break;
     }
@@ -223,6 +275,9 @@ GGML_CALL static void ggml_backend_cann_transform_back(const ggml_tensor* tensor
         case GGML_TYPE_Q4_0:
             ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
             break;
+        case GGML_TYPE_Q8_0:
+            ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
+            break;
         default:
             break;
     }
@@ -231,6 +286,7 @@ GGML_CALL static void ggml_backend_cann_transform_back(const ggml_tensor* tensor
 GGML_CALL static bool need_transform(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
             return true;
         default:
             return false;
@@ -820,7 +876,16 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
         case GGML_OP_MUL_MAT_ID:
         // embedding
         case GGML_OP_GET_ROWS:
-            return false;
+            {
+                switch (op->src[0]->type) {
+                    //case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            }
+            break;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
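
Note on the layout: the q8_0 transform above repacks ggml's interleaved block_q8_0 groups (an fp16 scale d followed by QK8_0 int8 quants per group) into two contiguous regions in the destination buffer, all quants first and all scales after them; transform_back reverses this. The following standalone sketch of that round trip is illustrative only: the pack_q8_0/unpack_q8_0 helpers and the main() driver are hypothetical stand-ins, not part of the backend.

// Illustrative sketch: repack interleaved block_q8_0 data into
// [all quants][all fp16 scale bits] and back. pack_q8_0/unpack_q8_0
// are hypothetical helpers, not ggml CANN backend functions.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

#define QK8_0 32

typedef struct {
    uint16_t d;         // delta (fp16 bit pattern)
    int8_t qs[QK8_0];   // quants
} block_q8_0;

// dst holds groups * QK8_0 quant bytes followed by groups fp16 scales.
static void pack_q8_0(const block_q8_0* src, uint8_t* dst, int64_t groups) {
    uint8_t*  quant_offset = dst;
    uint16_t* scale_offset = (uint16_t*)(dst + groups * QK8_0);
    for (int64_t i = 0; i < groups; i++) {
        *scale_offset++ = src[i].d;
        memcpy(quant_offset, src[i].qs, QK8_0);
        quant_offset += QK8_0;
    }
}

static void unpack_q8_0(const uint8_t* src, block_q8_0* dst, int64_t groups) {
    const uint8_t*  quant_offset = src;
    const uint16_t* scale_offset = (const uint16_t*)(src + groups * QK8_0);
    for (int64_t i = 0; i < groups; i++) {
        dst[i].d = *scale_offset++;
        memcpy(dst[i].qs, quant_offset, QK8_0);
        quant_offset += QK8_0;
    }
}

int main() {
    const int64_t groups = 4;
    std::vector<block_q8_0> in(groups), out(groups);
    for (int64_t i = 0; i < groups; i++) {
        in[i].d = (uint16_t)(0x3C00 + i);   // arbitrary fp16 bit patterns
        for (int j = 0; j < QK8_0; j++) in[i].qs[j] = (int8_t)(i * QK8_0 + j);
    }
    std::vector<uint8_t> packed(groups * QK8_0 + groups * sizeof(uint16_t));
    pack_q8_0(in.data(), packed.data(), groups);
    unpack_q8_0(packed.data(), out.data(), groups);
    // Round trip must reproduce the original blocks exactly.
    assert(memcmp(in.data(), out.data(), groups * sizeof(block_q8_0)) == 0);
    return 0;
}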

ggml-cann/aclnn_ops.cpp

Lines changed: 7 additions & 0 deletions
@@ -1687,4 +1687,11 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyScalar(acl_scale));
     ACL_CHECK(aclDestroyTensor(temp_tensor));
     ACL_CHECK(aclDestroyTensor(temp_output_tensor));
+}
+
+void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+
 }
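
Note: ggml_cann_get_rows is added here as an empty stub; it only pulls out src0 (the data tensor) and src1 (the row indices) and computes nothing yet. For reference, GGML_OP_GET_ROWS gathers rows of src0 according to the integer indices in src1. The sketch below shows that semantics for a contiguous f32 source on the CPU; gather_rows_f32 is a hypothetical helper for illustration and ignores quantized types, broadcasting, and the ACL calls a real CANN kernel would need.

// Illustrative sketch of GGML_OP_GET_ROWS semantics for a contiguous
// f32 matrix: dst row r is a copy of src row idx[r]. This is not the
// CANN kernel; gather_rows_f32 is a hypothetical helper.
#include <cstdint>
#include <cstring>

static void gather_rows_f32(const float* src, int64_t n_cols,
                            const int32_t* idx, int64_t n_idx,
                            float* dst) {
    for (int64_t r = 0; r < n_idx; r++) {
        // Copy one full source row into the r-th destination row.
        memcpy(dst + r * n_cols, src + (int64_t)idx[r] * n_cols,
               n_cols * sizeof(float));
    }
}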
