diff --git a/src/loz.ts b/src/loz.ts index 6f12bee..68af7d6 100644 --- a/src/loz.ts +++ b/src/loz.ts @@ -40,7 +40,7 @@ export class Loz { config: Config = new Config(); git: Git = new Git(); - constructor(llmAPI?: string) { + constructor() { this.defaultSettings = { model: DEFAULT_OPENAI_MODEL, prompt: "", @@ -283,6 +283,19 @@ export class Loz { } } + public async setAPI(api: string, model?: string): Promise<void> { + if (api === "ollama" || api === "openai") { + this.config.set("api", api); + } + + if (model) { + this.config.set("model", model); + } + + this.config.save(); + await this.initLLMfromConfig(); + } + public async handlePrompt(prompt: string): Promise { const systemPrompt = "Decide if the following prompt can be translated into Linux commands. " + diff --git a/test/a.loz.test.ts b/test/a.loz.test.ts index 5bf7037..7fcb4f6 100644 --- a/test/a.loz.test.ts +++ b/test/a.loz.test.ts @@ -46,11 +46,13 @@ describe("Test OpenAI API", () => { if (GITHUB_ACTIONS === false) { describe("Loz.ollama", () => { it("should return true", async () => { - let loz = new Loz("ollama"); + let loz = new Loz(); await loz.init(); + await loz.setAPI("ollama", "llama2"); + expect((loz as any).checkAPI()).to.equal("ollama"); const completion = await loz.completeUserPrompt("1+1="); - expect(completion.content).to.equal("2"); + expect(completion.content).contains("2"); }); }); }