ubatch : new splitting logic #14217

Open · wants to merge 17 commits into master from gg/ubatch-rework

Conversation

ggerganov (Member) commented Jun 16, 2025

  • Remove llama_sbatch
  • llama_batch_allocr now handles ubatch splitting
  • llama_batch_allocr precomputes various index maps and guarantees the inputs are consistent
  • llama_ubatch can now iterate over unique sequence ids
  • Change the notion of llama_ubatch.n_seqs from "number of sequences" to "number of sequence sets" (illustrated in the sketch after this list)
  • Enable pooling for n_tokens <= seq_id. Remove padding hack from llama-server
  • Detailed batch debug output
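
To illustrate the new notion of n_seqs, a toy sketch (the struct below is a simplified stand-in, not the real llama_ubatch):

#include <cstdint>
#include <cstdio>
#include <vector>

// simplified stand-in for a ubatch's per-set sequence-id lists
struct seq_set_t {
    std::vector<int32_t> seq_ids;
};

int main() {
    // a ubatch with tokens shared by sequences {0, 1} plus tokens owned by
    // sequence {2}: three sequence ids, but two sequence *sets*
    const std::vector<seq_set_t> sets = { { {0, 1} }, { {2} } };

    const uint32_t n_seqs = (uint32_t) sets.size(); // == 2 under the new notion

    for (uint32_t s = 0; s < n_seqs; ++s) {
        printf("seq set %u:", s);
        for (const int32_t id : sets[s].seq_ids) {
            printf(" %d", id);
        }
        printf("\n");
    }
}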

TODO:

  • Fix this:
    make -j && LLAMA_BATCH_DEBUG=2 ./bin/llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q8_0 --image ~/Downloads/rects.png -p "Please first output bbox coordinates and colors of every rectangle in this image in JSON format, and then answer how many rectangles are there in the image." --seed 1 -ngl 99 --temp 0.0 -c 20000 -b 1

ggerganov marked this pull request as ready for review June 17, 2025 12:36
ggerganov requested a review from ngxson as a code owner June 17, 2025 12:36
ggerganov requested a review from compilade June 17, 2025 12:36
compilade (Collaborator) commented Jun 17, 2025

This breaks shuffled batches for equal splits.

When running test-model-random (from #14139) with this PR, I get:

Comparing output for 'Mamba', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=1: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=2: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=512: OK
init: sequence 0 does not start from the last position stored in the memory
decode: failed to initialize batch
llama_decode: failed to decode, ret = -1
get_logits_ith: invalid logits id 0, reason: batch.logits[0] != true
/path/to/llama.cpp/tests/test-model-random.cpp:841: GGML_ASSERT(out) failed

But there's also something else which did not happen before:

Comparing output for 'Llama4', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=1: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=2: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=512: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=1: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=2: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=512: OK
Error for seq_id 3 is 0.008624 at n_past=525
Error for seq_id 4 is 0.005619 at n_past=487
Error for seq_id 4 is 0.133501 at n_past=590
Comparing output for 'Llama4', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=1: (40%) FAILED
Error for seq_id 3 is 0.008624 at n_past=525
Error for seq_id 4 is 0.005619 at n_past=487
Error for seq_id 4 is 0.133501 at n_past=590
Comparing output for 'Llama4', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=2: (40%) FAILED
Error for seq_id 3 is 0.008624 at n_past=525
Error for seq_id 4 is 0.005619 at n_past=487
Error for seq_id 4 is 0.133501 at n_past=590
Comparing output for 'Llama4', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=512: (40%) FAILED

Seems like multiple sequences with chunked SWA have some inconsistency.

ggerganov (Member, Author) commented:

> But there's also something else which did not happen before:

Does it trigger consistently? It passes on my end:

$ git show
commit c0df4490c4d6e04ec8e2421fdba2655cbc3d5b44 (HEAD -> gg/ubatch-rework)
Merge: cc7952b42 04b8f5143
Author: Georgi Gerganov <[email protected]>
Date:   Tue Jun 17 18:44:41 2025 +0300
    Merge remote-tracking branch 'origin/compilade/test-model-random' into gg/ubatch-rework
$ git diff
diff --git a/tests/test-model-random.cpp b/tests/test-model-random.cpp
index 218cfcb82..b5c1d7248 100644
--- a/tests/test-model-random.cpp
+++ b/tests/test-model-random.cpp
@@ -1004,7 +1004,7 @@ int main(int argc, char ** argv) {
                     llama_free(ref_ctx);
                 }
 
-                for (bool shuffle : { false, true }) {
+                for (bool shuffle : { false, }) {
 
                     // skip shuffling the batch for non-recurrent models
                     // (simple splits don't handle shuffled batches correctly)
$ a=$(make -j > /dev/null) && ./bin/test-model-random
..............
Comparing output for 'Llama2', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=1: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=2: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=512: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=1: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=2: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=512: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=1: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=2: OK
Comparing output for 'Llama2', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=512: OK
.............................
Comparing output for 'Llama4', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=1: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=2: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=512: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=1: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=2: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=512: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=1: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=2: OK
Comparing output for 'Llama4', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=512: OK
................
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=1: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=2: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=512: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=1: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=2: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=512: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=1: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=2: OK
Comparing output for 'Gemma2', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=512: OK
............
Comparing output for 'Mamba', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=1: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=2: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=1, n_ctx=643, n_ubatch=512: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=1: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=2: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=2, n_ctx=1286, n_ubatch=512: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=1: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=2: OK
Comparing output for 'Mamba', with shuffle=0, n_seq_max=5, n_ctx=3215, n_ubatch=512: OK

compilade (Collaborator) commented Jun 17, 2025

> Does it trigger consistently?

It does on a Pixel 9 Pro in Termux. But it seems like this might not be a regression from this PR, since it also happens on #14139 (sorry, I didn't test that branch on this hardware before).

-- ARM feature DOTPROD enabled
-- ARM feature SVE enabled
-- ARM feature MATMUL_INT8 enabled
-- ARM feature FMA enabled
-- ARM feature FP16_VECTOR_ARITHMETIC enabled
-- Adding CPU backend variant ggml-cpu: -mcpu=native+dotprod+i8mm+sve+nosme

Reproducing commands:

$ git switch compilade/test-model-random
$ mkdir build
$ cd build
$ cmake .. --fresh
$ make -j6 test-model-random
$ ./bin/test-model-random

So it's not a problem caused by this PR, sorry for misreporting.

(The shuffled-batch regression, however, is.)

ggerganov (Member, Author) commented:

Yes, I can reproduce it on my Mac as well, when I disable Metal or force ngl = 0. So it's very likely a bug in one of the CPU kernels.

ggerganov (Member, Author) commented Jun 17, 2025

My best guess is that the summation here overflows FP16:

#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
        }
    }

    // reduce sum0..sum3 to sum0
    GGML_F16_VEC_REDUCE(sumf, sum);
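
For intuition, here is a standalone sketch of the failure mode (it only emulates FP16 precision by rounding a float's mantissa to 10 bits, ignoring FP16's narrower exponent range; it is not the ggml code above, and it demonstrates precision loss rather than a true overflow past 65504):

#include <cstdint>
#include <cstdio>
#include <cstring>

// round a float to FP16-like precision by keeping 10 mantissa bits
static float round_to_fp16_precision(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    bits = (bits + (1u << 12)) & ~((1u << 13) - 1); // round half up, drop 13 bits
    float y;
    memcpy(&y, &bits, sizeof(y));
    return y;
}

int main() {
    const int   n = 1 << 16;
    const float v = 0.05f;

    float  sum16 = 0.0f; // accumulator limited to FP16-like precision
    double sum32 = 0.0;  // wide accumulator, like ggml_float in the leftovers loop

    for (int i = 0; i < n; ++i) {
        sum16 = round_to_fp16_precision(sum16 + v);
        sum32 += v;
    }

    printf("fp16-like accumulation: %f\n", sum16); // stalls at 128.0
    printf("wide accumulation:      %f\n", sum32); // ~3276.8
}

The narrow accumulator stops changing once half of the FP16 spacing exceeds the addend, so a long dot product silently loses mass well before any value approaches FP16's maximum.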

Applying this patch to make the accumulation use F32 (via the leftovers loop) fixes the issue:

diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index f7614568e..03044d382 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -198,7 +198,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
     ggml_float sumf = 0.0;
 
 #if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
+    const int np = 0;
 
     GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
 

The best fix for now is probably to set the KV cache types that the test uses to F32; this also works.
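
For reference, a minimal sketch of what that could look like (type_k and type_v are existing fields of llama_context_params, defaulting to F16; exactly where the test would set them is an assumption):

// force an F32 KV cache so the reference comparison is not affected by
// F16 accumulation error; `model` is assumed to be loaded elsewhere
llama_context_params cparams = llama_context_default_params();

cparams.type_k = GGML_TYPE_F32;
cparams.type_v = GGML_TYPE_F32;

llama_context * ctx = llama_init_from_model(model, cparams);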

> This breaks shuffled batches for equal splits.

I'll take a look at whether this can be handled cleanly, but I'm wondering if this use case is really needed. Do you have any specific applications in mind that require shuffled positions in the input batch?

compilade (Collaborator) commented:

> The best fix for now is probably to set the KV cache types that the test uses to F32; this also works.

That's very likely what I'll end up doing, thanks. (although it's less representative of actual use)

> Do you have any specific applications in mind that require shuffled positions in the input batch?

The main benefit is that it makes it really easy to test that sequence aggregation works correctly for proper splitting. If it works with shuffled batches, then it can work with pretty much anything.

For an actual use case, I'm not really sure.

I'll see how the test can be changed to not affect the relative order within the sequences but still shuffle the relative order of tokens of different sequences. This makes the test a bit harder to implement, though it would be more representative of the expected possible batch orderings (and should probably make the test viable for simple splits too).
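
One shape such a shuffle could take (a sketch of the idea, not the test's actual code): randomly interleave the sequences' token streams while keeping each stream's internal order.

#include <cstdint>
#include <random>
#include <vector>

// randomly interleave per-sequence token streams, preserving the relative
// order of tokens within each sequence
static std::vector<int32_t> interleave_sequences(
        const std::vector<std::vector<int32_t>> & seqs, std::mt19937 & rng) {
    std::vector<size_t> next(seqs.size(), 0); // next unconsumed token per sequence
    std::vector<size_t> active;               // sequences with tokens left

    for (size_t s = 0; s < seqs.size(); ++s) {
        if (!seqs[s].empty()) {
            active.push_back(s);
        }
    }

    std::vector<int32_t> out;

    while (!active.empty()) {
        std::uniform_int_distribution<size_t> pick(0, active.size() - 1);

        const size_t k = pick(rng);
        const size_t s = active[k];

        out.push_back(seqs[s][next[s]++]);

        if (next[s] == seqs[s].size()) {
            active.erase(active.begin() + k); // sequence s is exhausted
        }
    }

    return out;
}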

ggerganov (Member, Author) commented:

> I'll see how the test can be changed to not affect the relative order within the sequences but still shuffle the relative order of tokens of different sequences.

Ok, that would be useful. Regarding the fully shuffled batches, I will add checks for such inputs and raise an error.
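
Such a check might look roughly like this (a sketch assuming one seq_id per token for simplicity; the real validation would live in llama_batch_allocr and would also handle multi-sequence tokens):

#include <cstdint>
#include <vector>

// reject batches where a sequence's tokens appear with non-increasing
// positions, i.e. the batch was shuffled within a sequence
static bool batch_pos_increasing(const std::vector<int32_t> & seq_id,
                                 const std::vector<int32_t> & pos,
                                 uint32_t n_seq_max) {
    std::vector<int32_t> last(n_seq_max, -1);

    for (size_t i = 0; i < seq_id.size(); ++i) {
        if (pos[i] <= last[seq_id[i]]) {
            return false; // out-of-order token for this sequence
        }
        last[seq_id[i]] = pos[i];
    }

    return true;
}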

Review thread on the following diff lines:

    return ubatch_add(idxs, idxs.size(), false);
}

llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
compilade (Collaborator) commented Jun 18, 2025

Note that for equal splits, some sequence sets are not compatible (i.e. they can't be put in the same ubatch). For example, a sequence set containing multiple seq_ids cannot be mixed with one having a seq_id in the multi-sequence set.

For example, tokens with seq_ids = { 0, 1, 2, 3 } are not compatible with tokens in seq_ids = { 1 }.

The reason is that the recurrent states are only copied to the target sequences on ubatch boundaries, and so dependent tokens cannot be mixed with a shared trunk.

Is this handled here?

Basically the main constraint to check would be that the sequence sets in a ubatch are independent (at least, I think that would be sufficient?).

(Before this PR, it was handled by splitting multi-sequence token groups on their own before the single-sequence tokens)

(I did not implement multi-sequence tests yet in #14139, but that should also be able to answer this question once implemented)

ggerganov (Member, Author) commented:

> For example, a sequence set containing multiple seq_ids cannot be mixed with one having a seq_id in the multi-sequence set.

Yes, this logic here at the beginning of the function determines the unique non-overlapping sequence sets that will be contained in this ubatch:

// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
    if (used[i]) {
        continue;
    }

    bool add = true;

    for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
        // no overlap with existing sequence sets:
        if (!(cur_seq_set[s] & seq_set[i]).none()) {
            add = false;

            break;
        }
    }

    if (add) {
        cur_seq_set.push_back(seq_set[i]);

        if (cur_seq_set.size() > n_ubatch) {
            break;
        }
    }
}

const uint32_t n_seqs = cur_seq_set.size();
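
For intuition, the overlap test in isolation, applied to the { 0, 1, 2, 3 } vs { 1 } example above (the bitset width here is arbitrary; llama.cpp sizes its sequence sets by the maximum number of sequences):

#include <bitset>
#include <cassert>

int main() {
    std::bitset<64> multi;  // tokens shared by sequences {0, 1, 2, 3}
    std::bitset<64> single; // tokens owned by sequence {1}

    for (int id : {0, 1, 2, 3}) {
        multi.set(id);
    }
    single.set(1);

    // the sets overlap, so these tokens cannot share an equal-split ubatch
    assert(!(multi & single).none());
}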

ggerganov (Member, Author) commented:

@compilade FYI, the tentative plan is to first merge #13979 and after that merge this PR (unless you spot more issues). ETA probably tomorrow.
