
Commit

fix
hipudding committed Jun 21, 2024
1 parent 00b8598 commit 8cb2abc
Showing 3 changed files with 19 additions and 6 deletions.
ggml-cann.cpp: 12 changes (7 additions & 5 deletions)
@@ -288,10 +288,13 @@ GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
     ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
     if (tensor->view_src != NULL && tensor->view_offs == 0) {
         GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+        tensor->backend = tensor->view_src->backend;
         set_tensor_extra(buffer, tensor);
         return;
     }
 
+    tensor->backend = GGML_BACKEND_TYPE_GPU;
+
     // TODO: the CANN backend doesn't support quantized tensors yet; keep the
     // code here for now.
     if (ggml_is_quantized(tensor->type)) {
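Context for the view-tensor change above: a view shares its source tensor's storage, so it must also report the source's backend, while a fresh tensor placed in a device buffer is tagged GPU. A minimal standalone sketch of that rule, using simplified stand-in types rather than the real ggml structs:

    #include <cassert>
    #include <cstddef>

    enum backend_type { BACKEND_TYPE_CPU, BACKEND_TYPE_GPU };

    struct tensor {
        tensor*      view_src  = nullptr;           // non-null if this tensor is a view
        size_t       view_offs = 0;                 // byte offset into the source
        backend_type backend   = BACKEND_TYPE_CPU;
    };

    // Mirrors the logic in ggml_backend_cann_buffer_init_tensor: a view at
    // offset 0 inherits the source's backend; everything else is device-resident.
    void init_tensor(tensor* t) {
        if (t->view_src != nullptr && t->view_offs == 0) {
            t->backend = t->view_src->backend;
            return;
        }
        t->backend = BACKEND_TYPE_GPU;
    }

    int main() {
        tensor base;
        init_tensor(&base);                 // non-view: tagged GPU

        tensor view;
        view.view_src = &base;
        init_tensor(&view);                 // view: inherits the source's backend
        assert(view.backend == base.backend);
        return 0;
    }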
@@ -686,6 +689,7 @@ GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) {
 GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
+    ggml_cann_set_device(cann_ctx->device);
     ACL_CHECK(aclrtSynchronizeDevice());
     cann_ctx->free_device_buffers();
     ACL_CHECK(aclrtResetDevice(cann_ctx->device));
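The added ggml_cann_set_device call matters because ACL runtime calls such as aclrtSynchronizeDevice() act on the calling thread's currently selected device, so the backend's own device must be selected before it is synchronized and reset. A small sketch of the hazard with a mock single-threaded runtime (set_device/sync_device are stand-ins, not ACL APIs):

    #include <cassert>

    static int g_current_device = 0;   // per-thread in a real runtime

    void set_device(int dev) { g_current_device = dev; }
    int  sync_device()       { return g_current_device; }  // reports which device it synced

    int main() {
        const int ctx_device = 1;      // device owned by the backend context
        set_device(0);                 // some earlier code selected another device

        set_device(ctx_device);        // mirrors ggml_cann_set_device(cann_ctx->device)
        assert(sync_device() == ctx_device);  // the right device is now synchronized
        return 0;
    }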
@@ -708,23 +712,21 @@ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
                                                         size_t size) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-    ggml_backend_buffer_t buf =
-        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync(tensor->data, size, (char*)data + offset,
+        ACL_CHECK(aclrtMemcpyAsync(tensor->data, size, (const char*)data + offset,
                                    size, ACL_MEMCPY_HOST_TO_DEVICE,
                                    cann_ctx->stream()));
     } else {
         void* transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, (char*)data + offset,
+        ggml_backend_cann_transform(tensor, (const char*)data + offset,
                                     transform_buffer);
 
 #ifndef NDEBUG
         void* check_buffer = malloc(size);
         ggml_backend_cann_transform_back(tensor, transform_buffer,
                                          check_buffer);
-        GGML_ASSERT(memcmp((char*)data + offset, check_buffer, size));
+        GGML_ASSERT(memcmp((const char*)data + offset, check_buffer, size) == 0);
         free(check_buffer);
 #endif
         ACL_CHECK(aclrtMemcpyAsync(tensor->data, size, transform_buffer, size,
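The NDEBUG block above guards the host-side layout transform with a round-trip check: transform the data, transform it back, and require the result to be byte-identical to the input (memcmp returning 0 on equality, hence the == 0 form of the assert). A self-contained sketch of that pattern, with a trivial byte reversal standing in for the real CANN layout transform:

    #include <cassert>
    #include <cstdlib>
    #include <cstring>

    // Stand-ins for ggml_backend_cann_transform / _transform_back: any pair
    // of functions that are exact inverses of each other on a byte buffer.
    void transform(const unsigned char* src, unsigned char* dst, size_t n) {
        for (size_t i = 0; i < n; i++) dst[i] = src[n - 1 - i];  // reverse bytes
    }
    void transform_back(const unsigned char* src, unsigned char* dst, size_t n) {
        for (size_t i = 0; i < n; i++) dst[i] = src[n - 1 - i];  // reverse again
    }

    int main() {
        const unsigned char data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        const size_t size = sizeof(data);

        unsigned char* transform_buffer = (unsigned char*)malloc(size);
        transform(data, transform_buffer, size);

    #ifndef NDEBUG
        // Round-trip check, as in the diff: equal buffers make memcmp return 0.
        unsigned char* check_buffer = (unsigned char*)malloc(size);
        transform_back(transform_buffer, check_buffer, size);
        assert(memcmp(data, check_buffer, size) == 0);
        free(check_buffer);
    #endif
        free(transform_buffer);
        return 0;
    }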
ggml-cann/aclnn_ops.cpp: 1 change (0 additions & 1 deletion)
@@ -1643,7 +1643,6 @@ static void aclnn_inplace_add(ggml_backend_cann_context& ctx, aclTensor* acl_src
 void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
     ggml_tensor* src1 = dst->src[1];  // mask
-    ggml_tensor* src2 = dst->src[2];  // pos
 
     aclTensor* acl_src0 = create_acl_tensor(src0);
     aclTensor* acl_dst = create_acl_tensor(dst);
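With the unused pos tensor gone, the softmax op consumes only src0 (the logits) and an optional src1 mask. A minimal scalar sketch of the shape of that operation, assuming ggml's soft_max_ext convention of softmax(x * scale + mask) with an additive mask:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Scalar softmax with an optional additive mask: dst[i] is proportional
    // to exp(logits[i] * scale + mask[i]), normalized to sum to 1.
    std::vector<float> softmax(const std::vector<float>& logits,
                               const float* mask, float scale) {
        std::vector<float> out(logits.size());
        float max_v = -INFINITY;
        for (size_t i = 0; i < logits.size(); i++) {
            out[i] = logits[i] * scale + (mask ? mask[i] : 0.0f);
            if (out[i] > max_v) max_v = out[i];
        }
        float sum = 0.0f;
        for (float& v : out) { v = std::exp(v - max_v); sum += v; }
        for (float& v : out) v /= sum;
        return out;
    }

    int main() {
        std::vector<float> logits = {1.0f, 2.0f, 3.0f};
        const float mask[3] = {0.0f, 0.0f, -INFINITY};  // mask out position 2
        for (float p : softmax(logits, mask, 1.0f)) std::printf("%.3f ", p);
        std::printf("\n");  // the masked entry gets ~0 probability
        return 0;
    }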
llama.cpp: 12 changes (12 additions & 0 deletions)
@@ -2519,6 +2519,8 @@ static size_t llama_get_device_count(const llama_model & model) {
     count = ggml_backend_sycl_get_device_count();
 #elif defined(GGML_USE_VULKAN)
     count = ggml_backend_vk_get_device_count();
+#elif defined(GGML_USE_CANN)
+    count = ggml_backend_cann_get_device_count();
 #endif
 #if defined(GGML_USE_RPC)
     count += model.rpc_servers.size();
@@ -2551,6 +2553,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     if (buft == nullptr) {
         LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu);
     }
+#elif defined(GGML_USE_CANN)
+    buft = ggml_backend_cann_buffer_type(gpu);
 #endif
 
     if (buft == nullptr) {
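Both llama.cpp hunks above follow the same pattern: exactly one GPU backend is selected at compile time through a preprocessor #elif chain, and a new backend hooks in by adding one branch per dispatch site. A toy sketch of the pattern, using hypothetical USE_BACKEND_A/USE_BACKEND_B macros and stub device-count functions in place of the real ggml_backend_*_get_device_count calls:

    #include <cstdio>

    // Exactly one branch is compiled in, chosen by the build configuration.
    #if defined(USE_BACKEND_A)
    static int backend_device_count() { return 2; }
    #elif defined(USE_BACKEND_B)
    static int backend_device_count() { return 4; }
    #else
    static int backend_device_count() { return 0; }  // CPU-only build
    #endif

    int main() {
        std::printf("devices: %d\n", backend_device_count());
        return 0;
    }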
@@ -16219,6 +16223,10 @@ void llama_backend_init(void) {
         struct ggml_context * ctx = ggml_init(params);
         ggml_free(ctx);
     }
+
+#if defined(GGML_USE_CANN)
+    ggml_cann_backend_init();
+#endif
 }
 
 void llama_numa_init(enum ggml_numa_strategy numa) {
@@ -16228,6 +16236,10 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
     }
 }
 
 void llama_backend_free(void) {
+#if defined(GGML_USE_CANN)
+    ggml_cann_backend_free();
+#endif
+
     ggml_quantize_free();
 }
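llama_backend_init and llama_backend_free bracket the process lifetime, so backend-global setup (here ggml_cann_backend_init) gets a matching teardown in llama_backend_free. A sketch of that paired-lifecycle pattern, with a hypothetical fake_backend namespace and USE_FAKE_BACKEND macro standing in for the CANN hooks:

    #include <cstdio>

    namespace fake_backend {
        static bool initialized = false;
        void init() { initialized = true;  std::puts("backend up");   }
        void free() { initialized = false; std::puts("backend down"); }
    }

    // Mirrors how ggml_cann_backend_init()/ggml_cann_backend_free() are
    // hooked into llama_backend_init()/llama_backend_free().
    void backend_init() {
    #if defined(USE_FAKE_BACKEND)
        fake_backend::init();
    #endif
    }

    void backend_free() {
    #if defined(USE_FAKE_BACKEND)
        fake_backend::free();
    #endif
    }

    int main() {
        backend_init();   // once, before any model work
        // ... load models, run inference ...
        backend_free();   // once, at process shutdown
        return 0;
    }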

