diff --git a/common/common.cpp b/common/common.cpp
index 8661e164ada6b..859e726afb1b8 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             << ", pos " << std::to_string(batch.pos[i])
             << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ", seq_id " << std::to_string(batch.seq_id[i][0])
-            << ", logits " << std::to_string(batch.logits[i]);
+            << ", output " << std::to_string(batch.output[i]);
     }
 
     buf << " ]";
@@ -1617,7 +1617,7 @@ void common_batch_add(
                        llama_token   id,
                          llama_pos   pos,
    const std::vector<llama_seq_id> & seq_ids,
-                              bool   logits) {
+                              bool   output) {
     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
 
     batch.token   [batch.n_tokens] = id;
@@ -1626,7 +1626,7 @@ void common_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits  [batch.n_tokens] = logits;
+    batch.output  [batch.n_tokens] = output;
 
     batch.n_tokens++;
 }
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 0659ab6f119a7..1f1c956274f57 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };
 
             const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
                 common_batch_add(batch, 0, i, { j }, false);
             }
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
 
         const auto t_pp_start = ggml_time_us();
 
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 55c31166ca278..18b6a21d8ca49 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }
 
 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1
 
 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")
@@ -171,7 +171,7 @@ while n_cur <= n_len {
             if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
                 seq_id[0] = Int32(i)
             }
-            batch.logits[Int(batch.n_tokens)] = 1
+            batch.output[Int(batch.n_tokens)] = 1
 
             i_batch[i] = batch.n_tokens
 
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 21b95ef5e4e83..7d2a82b518099 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 38d22c90f82bb..95445b5ef68d3 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }
 
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 2a73983a9832f..1718d6b4f525d 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         common_batch_add(*batch, 0, i, { 0 }, false);
     }
 
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
 
     llama_kv_cache_clear(context);
     const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
 
     return reinterpret_cast<jlong>(batch);
 }
@@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
 
     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index ee7141a663224..7b4a55f2fc9da 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
     batch.n_tokens = 0
 }
 
-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
     batch.token   [Int(batch.n_tokens)] = id
     batch.pos     [Int(batch.n_tokens)] = pos
     batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
     for i in 0..<seq_ids.count {
         batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
-    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
+    batch.output  [Int(batch.n_tokens)] = outputs ? 1 : 0
 
     batch.n_tokens += 1
 }
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -441,13 +441,13 @@ struct llava_embd_batch {
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
     llama_batch batch;
     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
+        outputs .resize(n_tokens);
         seq_id_0.resize(1);
         seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;
@@ -458,13 +458,13 @@ struct llava_embd_batch {
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
+            /*output         =*/ outputs.data(),
         };
         for (int i = 0; i < n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
+            batch.output  [i] = false;
         }
     }
 };
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 7ef43d5e12876..3f87c0a1aa53e 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {
 
         // extract the logits only for the last token
         if (batch.n_tokens > 0) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         client.n_prompt = tokens_prompt.size();
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };
 
             const int ret = llama_decode(ctx, batch_view);
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index 5953928d47d33..15f99bcdd9087 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 9bf6c57433ab2..2b194b8d9bc74 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
                 batch.pos     [idx] = j*n_batch + k;
                 batch.n_seq_id[idx] = 1;
                 batch.seq_id  [idx][0] = seq;
-                batch.logits  [idx] = batch.pos[idx] >= first ? 1 : 0;
+                batch.output  [idx] = batch.pos[idx] >= first ? 1 : 0;
 
-                n_outputs += batch.logits[idx] != 0;
+                n_outputs += batch.output[idx] != 0;
             }
             batch.n_tokens += batch_size;
 
@@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };
 
         const int ret = llama_decode(ctx, batch_view);
@@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
 
         int n_outputs = 0;
         for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+            n_outputs += batch_view.output[i] != 0;
         }
 
         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
             common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;
 
         for (int s = 0; s < 4; ++s) {
@@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         for (size_t i = 0; i < data[i1].common_prefix; ++i) {
             common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
         n_logits += 1;
 
         for (int s = 0; s < 2; ++s) {
@@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
             common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;
 
         for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index 2439022a229b7..2c5b5e4862228 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }
 
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index cf7cbd8159cf8..2e5a2b5181eff 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < tokens.size(); i++) {
         common_batch_add(batch, tokens[i], i, {0}, false);
     }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
+    batch.output[batch.n_tokens - 1] = true; // generate next token
 
     // evaluate prompt
     llama_decode(ctx, batch);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9cdf2058fd037..f6642e5c820da 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2413,7 +2413,7 @@ struct server_context {
             std::vector<float> embd_res(n_embd, 0.0f);
 
             for (int i = 0; i < batch.n_tokens; ++i) {
-                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                     continue;
                 }
 
@@ -2451,7 +2451,7 @@ struct server_context {
             res->n_tokens = slot.n_prompt_tokens;
 
             for (int i = 0; i < batch.n_tokens; ++i) {
-                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                     continue;
                 }
 
@@ -3109,7 +3109,7 @@ struct server_context {
                     }
 
                     // extract the logits only for the last token
-                    batch.logits[batch.n_tokens - 1] = true;
+                    batch.output[batch.n_tokens - 1] = true;
 
                     slot.n_decoded = 0;
                     slot.i_batch   = batch.n_tokens - 1;
@@ -3149,7 +3149,7 @@ struct server_context {
                     batch.pos      + i,
                     batch.n_seq_id + i,
                     batch.seq_id   + i,
-                    batch.logits   + i,
+                    batch.output   + i,
                 };
 
                 const int ret = llama_decode(ctx, batch_view);
diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
index f78f763033a23..f700229853834 100644
--- a/examples/tts/tts.cpp
+++ b/examples/tts/tts.cpp
@@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx_ttc, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);
diff --git a/include/llama.h b/include/llama.h
index 61907ed404dbf..516953a729fb2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -252,7 +252,7 @@ extern "C" {
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
+        int8_t       *  output;
     } llama_batch;
 
     enum llama_model_kv_override_type {
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 01d5ca57fd82b..ba2127be66b6e 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
                 ubatch.output[ubatch.n_tokens + i] = 1;
                 out_ids.push_back(ids[seq.offset + i]);
             }
-        } else if (batch->logits) {
+        } else if (batch->output) {
             if (ubatch.equal_seqs) {
                 for (size_t i = 0; i < length; ++i) {
                     size_t id = ids[seq.offset + i];
-                    int8_t is_output = batch->logits[id];
+                    int8_t is_output = batch->output[id];
                     ubatch.output[ubatch.n_tokens + i] = is_output;
                     if (is_output) { out_ids.push_back(id); }
                 }
             } else {
                 // simple split
-                ubatch.output = batch->logits + seq.offset;
+                ubatch.output = batch->output + seq.offset;
                 for (size_t i = 0; i < length; ++i) {
                     if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                 }
@@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
-    if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+    if (!batch.output) {
+        outputs.resize(batch.n_tokens);
+        outputs[outputs.size() - 1] = true;
+        batch.output = outputs.data();
     }
 }
 
@@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     }
     batch.seq_id[n_tokens_alloc] = nullptr;
 
-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+    batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
 
     return batch;
 }
@@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) {
         }
         free(batch.seq_id);
     }
-    if (batch.logits) free(batch.logits);
+    if (batch.output) free(batch.output);
 }
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 773c3808b770f..002a8a62f844a 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -81,7 +81,7 @@ struct llama_batch_allocr {
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
 
     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
diff --git a/src/llama.cpp b/src/llama.cpp
index aae3c69b5a653..e24c39465c41c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8473,9 +8473,9 @@ static int llama_prepare_sbatch(
     lctx.embd_seq.clear();
 
     // count outputs
-    if (batch.logits && !embd_pooled) {
+    if (batch.output && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
+            n_outputs += batch.output[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -9972,7 +9972,6 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
     return llama_kv_cache_can_shift(ctx->kv_self);
 }
 
-///
 int32_t llama_encode(
         struct llama_context * ctx,
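Usage sketch (illustrative, not part of the patch): client code that previously read or wrote batch.logits now uses batch.output; the semantics are unchanged. Below is a minimal caller under the renamed field, assuming the llama.cpp C API shown above; the helper name decode_prompt is hypothetical.

// Hypothetical helper: evaluate a prompt and request logits only for its
// last token, using the renamed `output` field of llama_batch.
#include "llama.h"

static bool decode_prompt(llama_context * ctx, const llama_token * tokens, int32_t n_tokens) {
    // one sequence, token-id input (embd == 0 means token ids, not embeddings)
    llama_batch batch = llama_batch_init(n_tokens, 0, 1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.output  [i]    = false; // was: batch.logits[i]
    }
    batch.n_tokens = n_tokens;

    // llama_decode will output logits only for the last token of the prompt
    batch.output[batch.n_tokens - 1] = true; // was: batch.logits[...]

    const bool ok = llama_decode(ctx, batch) == 0;

    llama_batch_free(batch);
    return ok;
}

The common_batch_add helper updated above fills the same per-token fields; only its final bool parameter changes name from logits to output.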