Added dynamic context size. This is perfect for servers running llama models as a service. #13295

Open · wants to merge 8 commits into master
1 change: 1 addition & 0 deletions include/llama.h
@@ -474,6 +474,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API void llama_mod_n_ctx (struct llama_context * ctx, uint32_t new_ctx, struct llama_context_params params, const char* dump_file_path);

DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
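For reference, a minimal caller-side sketch of the new API, assuming a build with this patch applied and the current llama.cpp C API (llama_model_load_from_file, llama_init_from_model); the model path, context sizes, and dump file name below are illustrative only:

// usage_sketch.cpp -- grow the context of a running llama_context
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // illustrative path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 4096;                                    // start small
    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... serve requests until a larger context is needed ...

    // Grow the context to 8192 tokens at runtime: the current state is dumped
    // to the given file, the memory is re-created, and the state is reloaded.
    llama_mod_n_ctx(ctx, 8192, cparams, "dump_state.bin");
    printf("n_ctx is now %u\n", llama_n_ctx(ctx));

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}

The params argument is only read for its type_k/type_v fields when the memory is re-created, so passing the original context params is sufficient.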
51 changes: 51 additions & 0 deletions src/llama-context.cpp
@@ -406,6 +406,53 @@ ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}

// Serialize the full context state to a file so it can be restored after the
// memory has been re-created with the new context size.
void dump_state(llama_context * ctx, const char * dump_file_path) {
    std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

    FILE * fp_write = fopen(dump_file_path, "wb");
    fwrite(state_mem.data(), 1, written, fp_write);
    fclose(fp_write);
}

// Read a previously dumped state file back into the context.
void load_state(llama_context * ctx, const char * dump_file_path) {
    std::vector<uint8_t> state_mem;

    FILE * fp_read = fopen(dump_file_path, "rb");
    fseek(fp_read, 0, SEEK_END);
    state_mem.resize(ftell(fp_read));
    fseek(fp_read, 0, SEEK_SET);
    const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
    fclose(fp_read);

    if (read != llama_state_set_data(ctx, state_mem.data(), state_mem.size())) {
        fprintf(stderr, "\n%s : failed to read state\n", __func__);

        // Free up resources
        llama_free(ctx);
        llama_free_model(const_cast<llama_model *>(&ctx->get_model()));
    }
}

void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params, const char * dump_file_path = "dump_state.bin") {
    // Only allow the context size to grow.
    if (cparams.n_ctx < new_n_ctx) {
        cparams.n_ctx = new_n_ctx;
        llama_memory_params params_mem = {
            /*.type_k =*/ params.type_k,
            /*.type_v =*/ params.type_v,
        };

        // Dump the current state, re-create the memory with the updated cparams, then restore the state.
        dump_state(this, dump_file_path);
        memory.reset(model.create_memory(params_mem, cparams));
        load_state(this, dump_file_path);
    } else {
        LLAMA_LOG_ERROR("%s: the new context size must be greater than the current one\n", __func__);
    }
}

uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
@@ -1929,6 +1976,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}

void llama_mod_n_ctx(struct llama_context * ctx, uint32_t new_n_ctx, llama_context_params params, const char* dump_file_path){
ctx->mod_n_ctx(new_n_ctx, params, dump_file_path);
}

uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}
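Note that dump_state()/load_state() above are thin wrappers around the existing public state API (llama_state_get_size, llama_state_get_data, llama_state_set_data), with the buffer round-tripped through a file on disk. A minimal in-memory sketch of the same round-trip, for illustration only:

#include "llama.h"
#include <cstdio>
#include <vector>

// Snapshot the context state into RAM and restore it afterwards; this is the
// same save/restore cycle that mod_n_ctx() performs through a file on disk.
static bool snapshot_and_restore(llama_context * ctx) {
    std::vector<uint8_t> state(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, state.data(), state.size());

    // ... the memory could be re-created here, as mod_n_ctx() does ...

    const size_t read = llama_state_set_data(ctx, state.data(), written);
    if (read != written) {
        fprintf(stderr, "state restore failed (%zu of %zu bytes)\n", read, written);
        return false;
    }
    return true;
}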
2 changes: 2 additions & 0 deletions src/llama-context.h
@@ -34,6 +34,8 @@ struct llama_context {

ggml_context * get_ctx_compute() const;

void mod_n_ctx(uint32_t new_ctx, llama_context_params params, const char* dump_file_path);

uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;