diff --git a/include/llama.h b/include/llama.h
index 7d5f9d559816d..40bf25dcfd765 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -474,6 +474,7 @@ extern "C" {
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
+    LLAMA_API void     llama_mod_n_ctx  (struct llama_context * ctx, uint32_t new_ctx, struct llama_context_params params, const char * dump_file_path);
 
     DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
     DEPRECATED(LLAMA_API int32_t llama_n_embd     (const struct llama_model * model), "use llama_model_n_embd instead");
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index fd64622b8e02d..0a9d9d63fe543 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -406,6 +406,53 @@ ggml_context * llama_context::get_ctx_compute() const {
     return ctx_compute.get();
 }
 
+void dump_state(llama_context * ctx, const char * dump_file_path) {
+    std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
+    const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
+
+    FILE * fp_write = fopen(dump_file_path, "wb");
+    fwrite(state_mem.data(), 1, written, fp_write);
+    fclose(fp_write);
+}
+
+void load_state(llama_context * ctx, const char * dump_file_path) {
+    std::vector<uint8_t> state_mem;
+
+    FILE * fp_read = fopen(dump_file_path, "rb");
+    fseek(fp_read, 0, SEEK_END);
+    state_mem.resize(ftell(fp_read));
+    fseek(fp_read, 0, SEEK_SET);
+    const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+    fclose(fp_read);
+
+    if (read != llama_state_set_data(ctx, state_mem.data(), state_mem.size())) {
+        fprintf(stderr, "\n%s : failed to read state\n", __func__);
+
+        // free up resources
+        llama_free(ctx);
+        llama_free_model(const_cast<llama_model *>(&ctx->get_model()));
+    }
+}
+
+void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params, const char * dump_file_path = "dump_state.bin") {
+    // only allow increasing the context size
+    if (cparams.n_ctx < new_n_ctx) {
+        cparams.n_ctx = new_n_ctx;
+        llama_memory_params params_mem = {
+            /*.type_k =*/ params.type_k,
+            /*.type_v =*/ params.type_v,
+        };
+
+        // reset the memory with the new memory params and the modified cparams
+        dump_state(this, dump_file_path); // dump the state here
+        memory.reset(model.create_memory(params_mem, cparams));
+        load_state(this, dump_file_path); // load the state back
+    }
+    else {
+        LLAMA_LOG_ERROR("%s: cannot decrease the context size\n", __func__);
+    }
+}
+
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -1929,6 +1976,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+void llama_mod_n_ctx(struct llama_context * ctx, uint32_t new_n_ctx, llama_context_params params, const char * dump_file_path) {
+    ctx->mod_n_ctx(new_n_ctx, params, dump_file_path);
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 5a080e67fcc4b..0b69d8c9d2db8 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -34,6 +34,8 @@ struct llama_context {
 
     ggml_context * get_ctx_compute() const;
 
+    void mod_n_ctx(uint32_t new_ctx, llama_context_params params, const char * dump_file_path);
+
     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;
     uint32_t n_batch() const;
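
For context, a minimal caller-side sketch of how the proposed llama_mod_n_ctx API could be exercised. This is an illustration, not part of the patch: the model path, context sizes, and dump file name are placeholders, error checking is omitted, and all other calls are from the existing public llama.h API.

#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    // load a model and create a context with a small initial window
    // ("model.gguf" and the sizes below are placeholders)
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 2048;
    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... decode some tokens ...

    // grow the window to 8192 tokens; the current state is round-tripped
    // through the given dump file while the memory is re-created.
    // requesting a smaller size is rejected and only logs an error.
    llama_mod_n_ctx(ctx, 8192, cparams, "dump_state.bin");
    printf("n_ctx is now %u\n", llama_n_ctx(ctx));

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}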