-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
57bccd4
commit 57f6b7a
Showing
3 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import { LmScript } from "@lmscript/client"; | ||
import { kitchenSink } from "./tasks/kitchen-sink"; | ||
import { RtBackend } from "@lmscript/client/backends/rt"; | ||
|
||
const bench = async () => { | ||
let promptTokens = 0; | ||
let completionTokens = 0; | ||
const backend = new RtBackend({ | ||
url: `http://localhost:1234`, | ||
model: "TheBloke/Mistral-7B-Instruct-v0.2-AWQ", | ||
reportUsage: ({ promptTokens: pt, completionTokens: ct }) => { | ||
promptTokens += pt; | ||
completionTokens += ct; | ||
}, | ||
template: "mistral", | ||
}); | ||
const model = new LmScript(backend, { | ||
temperature: 0.1, | ||
}); | ||
// const { captured } = await model | ||
// .user((c) => c.push("What is 1+1")) | ||
// .assistant( | ||
// (c) => c.gen("answer"), | ||
// ).run(); | ||
// console.log(captured); | ||
// throw new Error("not implemented"); | ||
const batch = Array.from({ length: 1 }, (_, _i) => | ||
kitchenSink(model).catch((e) => { | ||
console.error(e); | ||
}), | ||
); | ||
|
||
const start = Date.now(); | ||
await Promise.all(batch); | ||
const duration = Date.now() - start; | ||
console.log(`Duration: ${duration}ms`); | ||
console.log(`Prompt tokens: ${promptTokens}`); | ||
console.log(`Completion tokens: ${completionTokens}`); | ||
}; | ||
|
||
bench().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
/** | ||
* This module contains the backend for the vLLM using OpenAI compatible API. | ||
* @module | ||
*/ | ||
|
||
import { ChatTemplate } from "../chat-template"; | ||
import { NOOP } from "../utils"; | ||
import { | ||
AbstractBackend, | ||
ClientState, | ||
ExecutionCallbacks, | ||
GenerateTask, | ||
GenerationThread, | ||
ReportUsage, | ||
SelectTask, | ||
} from "./abstract"; | ||
import { BaseExecutor } from "./executor"; | ||
|
||
class RtBackendExecutor extends BaseExecutor { | ||
readonly #url: string; | ||
readonly #model: string; | ||
readonly #reportUsage: ReportUsage; | ||
readonly #auth: string | undefined; | ||
|
||
constructor(options: { | ||
url: string; | ||
auth?: string; | ||
model: string; | ||
reportUsage?: ReportUsage; | ||
data: GenerationThread; | ||
callbacks: ExecutionCallbacks; | ||
template: ChatTemplate; | ||
}) { | ||
super(options.data, options.callbacks, options.template); | ||
this.#url = options.url; | ||
this.#model = options.model; | ||
this.#reportUsage = options?.reportUsage ?? NOOP; | ||
this.#auth = options.auth; | ||
} | ||
async #fetchJSON<T>(body: object): Promise<T> { | ||
const headers: Record<string, string> = { | ||
"Content-Type": "application/json", | ||
}; | ||
if (this.#auth != null) { | ||
headers["Authorization"] = `Bearer ${this.#auth}`; | ||
} | ||
return this.fetchJSONWithTimeout(`${this.#url}/v1/completions`, { | ||
method: "POST", | ||
headers, | ||
body: JSON.stringify(body), | ||
}); | ||
} | ||
|
||
override async doGeneration(task: GenerateTask): Promise<string> { | ||
const json = await this.#fetchJSON<any>({ | ||
model: this.#model, | ||
prompt: this.state.text, | ||
max_tokens: task.max_tokens, | ||
stop: task.stop, | ||
grammar: | ||
task.regex === undefined | ||
? undefined | ||
: { | ||
type: "regex", | ||
value: task.regex, | ||
}, | ||
temperature: this.data.sampling_params.temperature, | ||
top_p: this.data.sampling_params.top_p, | ||
top_k: this.data.sampling_params.top_k, | ||
frequency_penalty: this.data.sampling_params.frequency_penalty, | ||
presence_penalty: this.data.sampling_params.presence_penalty, | ||
}); | ||
this.#reportUsage({ | ||
promptTokens: json.usage.prompt_tokens, | ||
completionTokens: json.usage.completion_tokens, | ||
}); | ||
const captured = json.choices[0].text; | ||
this.state.text += captured; | ||
return captured; | ||
} | ||
override async doSelect(task: SelectTask): Promise<string> { | ||
const json = await this.#fetchJSON<any>({ | ||
model: this.#model, | ||
prompt: this.state.text, | ||
grammar: { | ||
type: "regex", | ||
value: `(${task.choices | ||
// .map((it) => { | ||
// // escape regex | ||
// return it.replace(/[-\/\\^$*+?.()|[\]{}]/g, "\\$&"); | ||
// }) | ||
.join("|")})`, | ||
}, | ||
temperature: this.data.sampling_params.temperature, | ||
top_p: this.data.sampling_params.top_p, | ||
top_k: this.data.sampling_params.top_k, | ||
frequency_penalty: this.data.sampling_params.frequency_penalty, | ||
presence_penalty: this.data.sampling_params.presence_penalty, | ||
}); | ||
this.#reportUsage({ | ||
promptTokens: json.usage.prompt_tokens, | ||
completionTokens: json.usage.completion_tokens, | ||
}); | ||
const captured = json.choices[0].text; | ||
this.state.text += captured; | ||
return captured; | ||
} | ||
} | ||
|
||
/** | ||
* Backend for the VLLM OpenAI API. | ||
*/ | ||
export class RtBackend implements AbstractBackend { | ||
readonly #url: string; | ||
readonly #model: string; | ||
readonly #reportUsage: ReportUsage; | ||
readonly #auth: string | undefined; | ||
readonly #template: ChatTemplate; | ||
constructor(options: { | ||
url: string; | ||
template: ChatTemplate; | ||
auth?: string; | ||
model: string; | ||
reportUsage?: ReportUsage; | ||
}) { | ||
this.#url = options.url; | ||
this.#model = options.model; | ||
this.#reportUsage = options?.reportUsage ?? NOOP; | ||
this.#auth = options.auth; | ||
this.#template = options.template; | ||
} | ||
async executeJSON(data: GenerationThread, callbacks: ExecutionCallbacks): Promise<ClientState> { | ||
return new RtBackendExecutor({ | ||
url: this.#url, | ||
model: this.#model, | ||
reportUsage: this.#reportUsage, | ||
auth: this.#auth, | ||
data, | ||
callbacks, | ||
template: this.#template, | ||
}).executeJSON(); | ||
} | ||
} |