Add Intel Advanced Matrix Extensions (AMX) support to ggml #8998

Merged: 1 commit merged into master on Oct 18, 2024

Conversation

@mingfeima (Collaborator)

Replacement of #7707, opened to trigger ggml-ci on AMX.

@mingfeima marked this pull request as draft on August 12, 2024.
The github-actions bot added the labels build (Compilation issues) and ggml (changes relating to the ggml tensor library for machine learning) on Aug 12, 2024.
@mingfeima changed the title from "Pr add intel amx support" to "Add Intel Advanced Matrix Extensions (AMX) support to ggml" on Aug 12, 2024.
@ggerganov (Owner)

To trigger ggml-ci you need to include the string "ggml-ci" somewhere in the commit message. For example: 5ef07e2

@mingfeima (Collaborator, Author)

@ggerganov could you please take a look at this one? I have moved the AMX init code from ggml.c to ggml-amx/mmq.cpp, per the previous review comments.
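For context, using AMX on Linux requires requesting permission for the XTILEDATA state from the kernel before executing any tile instruction; this is the kind of setup an AMX init routine performs. A minimal sketch under that assumption (the amx_init name is illustrative, not the function from the PR):

#include <sys/syscall.h>
#include <unistd.h>

// Linux arch_prctl constants for dynamically enabled XSTATE features
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA  18

// request permission to use the AMX TILEDATA state; without this the first
// tile instruction raises SIGILL on Linux
static bool amx_init(void) {
#if defined(__linux__) && defined(__x86_64__)
    return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
#else
    return false;
#endif
}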

@slaren (Collaborator) left a review comment

Currently, GGML_AMX is always enabled with llama.cpp, and whether AMX is actually used depends entirely on using -march=native with GGML_NATIVE, and building on a machine with AMX support. This is not ideal because it is a departure from the way the other x86 instruction sets are handled, and it doesn't allow for cross-compilation. I think it would be better to handle this in the same way as any other x86 instruction set, and add an explicit compiler option to enable this architecture (-mamx-int8?) when using GGML_AMX without GGML_NATIVE, or let it be enabled automatically by -march=native when using GGML_NATIVE.

@mingfeima (Collaborator, Author)

@slaren just updated the CMake compiler options: -mamx-tile, -mamx-int8 and -mamx-bf16!
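For reference, these -mamx-* options (or -march=native on an AMX machine) make GCC and Clang predefine __AMX_TILE__, __AMX_INT8__ and __AMX_BF16__, so the AMX kernels can be guarded at compile time. A sketch of such a guard; GGML_AMX_AVAILABLE is an illustrative macro name, not one from the PR:

#if defined(__AMX_TILE__) && defined(__AMX_INT8__)
    // the compiler is allowed to emit AMX tile and int8 instructions
    #define GGML_AMX_AVAILABLE 1
#else
    #define GGML_AMX_AVAILABLE 0
#endif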

@ggerganov requested a review from slaren on August 16, 2024.
Review thread on ggml/src/ggml-amx/mmq.cpp (outdated, resolved)
@ggerganov (Owner)

Failure 124 means that the run timed out. There is currently a 30-minute limit for these runs, and in this case it was exceeded.

It's not related to this PR: the CUDA build time has recently increased, so these runs time out from time to time. I just restarted it to see if it passes.

@slaren (Collaborator) left a review comment

Other than the issues with is_host pointed out below, the ggml-backend interface implementation looks good. Unfortunately, fixing these issues may result in a small performance hit, since it may cause additional copies when data needs to be moved between the CPU and AMX backends; fixing that will require changes to the ggml-backend interface.

The llama.cpp side will probably need some changes. I expect that the current implementation won't work with KV quantization. I cannot test this, but I think changing it so that the AMX buffer type is only used for the weights may work better, while also avoiding the need to use -ngl:

diff --git a/src/llama.cpp b/src/llama.cpp
index b85e8acd..13d70ec1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3462,8 +3462,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     }
 #elif defined(GGML_USE_CANN)
     buft = ggml_backend_cann_buffer_type(local_gpu);
-#elif defined(GGML_USE_AMX)
-    buft = ggml_backend_amx_buffer_type();
 #endif

     if (buft == nullptr) {
@@ -6865,7 +6863,14 @@ static bool llm_load_tensors(

     // assign cpu layers
     for (int i = 0; i < i_gpu_start; ++i) {
+    #ifdef GGML_USE_AMX
+        model.buft_layer[i] = {
+            llama_default_buffer_type_cpu(true),
+            ggml_backend_amx_buffer_type()
+        };
+    #else
         model.buft_layer[i] = llama_default_buffer_type_cpu(true);
+    #endif
     }

     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
@@ -18587,11 +18592,6 @@ struct llama_model_params llama_model_default_params() {
     result.n_gpu_layers = 999;
 #endif

-#ifdef GGML_USE_AMX
-    // by default offload all layers to AMX
-    result.n_gpu_layers = 999;
-#endif
-
     return result;
 }

I also expect that this implementation will have issues when built with a GPU backend such as CUDA that allows the weights to be copied to VRAM when evaluating large batches (>32 tokens), although that could be fixed by implementing conversion back to standard ggml format in ggml_backend_amx_buffer_get_tensor.
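For illustration, a rough sketch of the conversion described above: a get_tensor implementation for the AMX buffer that undoes the weight prepacking so other backends receive data in the standard ggml layout. The signature follows the ggml_backend_buffer_i interface of this period, and ggml_backend_amx_unpack_tensor is a hypothetical helper standing in for the inverse of the packing done in set_tensor:

static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer,
        const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    // convert from the prepacked AMX tile layout back to the plain ggml block
    // format expected by the other backends (hypothetical helper)
    ggml_backend_amx_unpack_tensor(tensor, data, offset, size);

    GGML_UNUSED(buffer);
}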

Review thread on ggml/src/ggml-amx.cpp (outdated, resolved)
Review thread on ggml/src/ggml-amx.cpp (outdated, resolved)
Review thread on src/llama.cpp (outdated, resolved)
@mingfeima (Collaborator, Author)

@slaren changing is_host to return false for the AMX backend leads to an abort in ggml_backend_sched_backend_id_from_cur (log attached below). Do you have any insight into how to fix it?

llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        AMX KV buffer size =  1024.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
ggml/src/ggml-backend.c:1204: pre-allocated tensor in a backend that cannot run the operation
[New LWP 2746117]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007f7d1a7205a2 in waitpid () from /lib64/libpthread.so.0
#0  0x00007f7d1a7205a2 in waitpid () from /lib64/libpthread.so.0
#1  0x000000000048a648 in ggml_print_backtrace () at ggml/src/ggml.c:282
282             waitpid(pid, &wstatus, 0);
#2  ggml_abort (file=file@entry=0x6788d4 "ggml/src/ggml-backend.c", line=line@entry=1204, fmt=fmt@entry=0x678c50 "pre-allocated tensor in a backend that cannot run the operation") at ggml/src/ggml.c:309
309         ggml_print_backtrace();
#3  0x00000000004cf025 in ggml_backend_sched_backend_id_from_cur (sched=0x172fd20, tensor=0x5305e10) at ggml/src/ggml-backend.c:1204
1204            GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
#4  0x00000000004d127c in ggml_backend_sched_split_graph (sched=sched@entry=0x172fd20, graph=graph@entry=0x1ada190) at ggml/src/ggml-backend.c:1337
1337                *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
#5  0x00000000004d2dde in ggml_backend_sched_split_graph (graph=0x1ada190, sched=<optimized out>) at ggml/src/ggml-backend.c:1327
1327        if (sched->ctx == NULL) {
#6  ggml_backend_sched_reserve (sched=0x172fd20, measure_graph=0x1ada190) at ggml/src/ggml-backend.c:1992
1992        ggml_backend_sched_split_graph(sched, measure_graph);
#7  0x000000000053204b in llama_new_context_with_model (model=0x1729f30, params=...) at src/llama.cpp:19176
19176               if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
#8  0x000000000060a48d in llama_init_from_gpt_params (params=...) at common/common.cpp:843
843         llama_context * lctx = llama_new_context_with_model(model, cparams);
#9  0x000000000043894d in main (argc=<optimized out>, argv=0x7fffb9a8de18) at examples/main/main.cpp:200
200         llama_init_result llama_init = llama_init_from_gpt_params(params);
[Inferior 1 (process 2746116) detached]
./run_generate_cpu.sh: line 26: 2746116 Aborted                 (core dumped) $PREFIX $main -m ./models/Meta-Llama-3-8B-Instruct-GGUF/$MODEL -t $CORES -n 5 -p "$prompt" --no-mmap

@slaren (Collaborator) commented on Sep 30, 2024

I think that may be related to the KV cache operations, which should be fixed by the change I suggested before: making llama_default_buffer_type_offload return the AMX buffer type causes the KV cache to be allocated in an AMX buffer, which is not good. If that doesn't fix it, please add some prints to show the tensor that is causing the error.
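One possible way to add such prints, placed just before the abort in ggml_backend_sched_backend_id_from_cur (ggml/src/ggml-backend.c:1204); a debugging sketch, not part of the PR:

// print the tensor that triggers the abort, including its op and buffer
fprintf(stderr, "%s: offending tensor: %s, op = %s, buffer = %s\n", __func__,
        tensor->name, ggml_op_name(tensor->op),
        tensor->buffer ? ggml_backend_buffer_name(tensor->buffer) : "(none)");
GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");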

@mingfeima (Collaborator, Author)

@slaren could you please review this one again? I just changed ggml_backend_buft_is_host to return false for the AMX backend.
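A minimal sketch of the change being described, assuming the ggml_backend_buffer_type_i callback shape of this period (the function name is illustrative):

static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    // AMX buffers hold weights repacked into a tile-friendly layout, so they
    // must not be treated as plain host memory readable by other backends
    return false;

    GGML_UNUSED(buft);
}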

Review thread on ggml/src/ggml-backend.cpp (outdated, resolved)
Review thread on src/llama.cpp (outdated, resolved)
Review thread on src/llama.cpp (resolved)
Review thread on ggml/src/ggml-amx.cpp (outdated, resolved)
Review thread on ggml/src/ggml-amx.cpp (outdated, resolved)
@slaren (Collaborator) commented on Oct 17, 2024

Looks good to me, feel free to merge this at any point.

@ggerganov (Owner) commented on Oct 17, 2024

@slaren Thank you for the detailed review.

@mingfeima Remember to squash the commits when merging as explained in the contributing guidelines. Btw, I just restarted the ggml-ci node with AMX instruction set support, so we might want to wait for ggml-ci to run before merging. Will run it on this branch shortly.

Edit: the AMX CI has passed

@ggerganov (Owner)

Would recommend using 4 spaces for indentation for conformance with the rest of the codebase.

add intel amx isa detection

add vnni kernel for gemv cases

add vnni and amx kernel support for block_q8_0

code cleanup

fix packing B issue

enable openmp

fine tune amx kernel

switch to aten parallel pattern

add error message for nested parallelism

code cleanup

add f16 support in ggml-amx

add amx kernels for QK_K quant formats: Q4_K, Q5_K, Q6_K and IQ4_XS

update CMakeList

update README

fix some compilation warning

fix compiler warning when amx is not enabled

minor change

ggml-ci

move ggml_amx_init from ggml.c to ggml-amx/mmq.cpp

ggml-ci

update CMakeLists with -mamx-tile, -mamx-int8 and -mamx-bf16

ggml-ci

add amx as an ggml-backend

update header file, the old path for immintrin.h has changed to ggml-cpu-impl.h

minor change

update CMakeLists.txt

minor change

apply weight prepacking in set_tensor method in ggml-backend

fix compile error

ggml-ci

minor change

ggml-ci

update CMakeLists.txt

ggml-ci

add march dependency

minor change

ggml-ci

change ggml_backend_buffer_is_host to return false for amx backend

ggml-ci

fix supports_op

use device reg for AMX backend

ggml-ci

minor change

ggml-ci

minor change

fix rebase

set .buffer_from_host_ptr to be false for AMX backend
@mingfeima (Collaborator, Author)

@ggerganov changed the indentation to 4 spaces. The branch has also been rebased and squashed into a single commit.

@ggerganov (Owner)

Nice, great job! Feel free to merge this - you should have access to do so.

@mingfeima merged commit 60ce97c into master on Oct 18, 2024. 60 checks passed.
@mingfeima (Collaborator, Author)

@slaren thanks a lot for your review!

drollings pushed a commit to drollings/llama.cpp that referenced this pull request on Oct 18, 2024.