Skip to content

Commit

Permalink
add mulmat for q8_0
Browse files (browse the repository at this point in the history)
  • Loading branch information
wangshuai09 committed May 8, 2024
1 parent 9461051 commit 9f46a12
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 5 deletions.
13 changes: 8 additions & 5 deletions ggml-cann.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -524,7 +524,7 @@ ggml_backend_cann_buffer_type(int32_t device) {

static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
struct ggml_tensor* dst) {
std::cout<<"CANN OP = "<<ggml_op_name(dst->op)<<std::endl;
// std::cout<<"CANN OP = "<<ggml_op_name(dst->op)<<std::endl;
switch (dst->op) {
case GGML_OP_REPEAT:
ggml_cann_repeat(ctx, dst);
Expand Down Expand Up @@ -609,6 +609,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_rms_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT:
ggml_cann_mul_mat(ctx, dst);
break;
case GGML_OP_MUL_MAT_ID:
return false;
case GGML_OP_SCALE:
Expand Down Expand Up @@ -829,9 +831,9 @@ GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute(

ggml_cann_set_device(cann_ctx->device);

for (int i = 0; i < cgraph->n_nodes; i++) {
std::cout<<"OP: "<<ggml_op_name(cgraph->nodes[i]->op)<<std::endl;
}
// for (int i = 0; i < cgraph->n_nodes; i++) {
// std::cout<<"OP: "<<ggml_op_name(cgraph->nodes[i]->op)<<std::endl;
// }

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor* node = cgraph->nodes[i];
Expand Down Expand Up @@ -873,6 +875,7 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
return false;
}
case GGML_OP_MUL_MAT:
return true;
case GGML_OP_MUL_MAT_ID:
// embedding
case GGML_OP_GET_ROWS:
Expand Down Expand Up @@ -1065,7 +1068,7 @@ GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params,
extern "C" GGML_CALL int ggml_backend_cann_reg_devices();

GGML_CALL int ggml_backend_cann_reg_devices() {

aclInit(nullptr);
uint32_t device_count = ggml_backend_cann_get_device_count();
// initialization
for (uint32_t i = 0; i < device_count; i++) {
Expand Down
Loading

0 comments on commit 9f46a12

Please sign in to comment.