Commit

rt

lucasavila00 committed May 17, 2024
1 parent 57bccd4 commit 57f6b7a
Showing 3 changed files with 190 additions and 0 deletions.
41 changes: 41 additions & 0 deletions examples/client/src/rt.ts
@@ -0,0 +1,41 @@
import { LmScript } from "@lmscript/client";
import { kitchenSink } from "./tasks/kitchen-sink";
import { RtBackend } from "@lmscript/client/backends/rt";

const bench = async () => {
  let promptTokens = 0;
  let completionTokens = 0;
  // Accumulate token usage reported by the backend across all requests.
  const backend = new RtBackend({
    url: `http://localhost:1234`,
    model: "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
    reportUsage: ({ promptTokens: pt, completionTokens: ct }) => {
      promptTokens += pt;
      completionTokens += ct;
    },
    template: "mistral",
  });
  const model = new LmScript(backend, {
    temperature: 0.1,
  });
  // Run a single kitchen-sink task; raise `length` to benchmark concurrency.
  const batch = Array.from({ length: 1 }, () =>
    kitchenSink(model).catch((e) => {
      console.error(e);
    }),
  );

  const start = Date.now();
  await Promise.all(batch);
  const duration = Date.now() - start;
  console.log(`Duration: ${duration}ms`);
  console.log(`Prompt tokens: ${promptTokens}`);
  console.log(`Completion tokens: ${completionTokens}`);
};

bench().catch(console.error);
6 changes: 6 additions & 0 deletions packages/client/package.json
@@ -42,6 +42,12 @@
"import": "./dist/esm/backends/vllm.js",
"default": "./dist/esm/backends/vllm.js"
},
"./backends/rt": {
"types": "./dist/types/backends/rt.d.ts",
"require": "./dist/cjs/backends/rt.js",
"import": "./dist/esm/backends/rt.js",
"default": "./dist/esm/backends/rt.js"
},
"./backends/tgi": {
"types": "./dist/types/backends/tgi.d.ts",
"require": "./dist/cjs/backends/tgi.js",
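Note: the new "./backends/rt" subpath export above is what makes the example file's import resolve; a minimal sketch of how a consumer uses it (package name taken from the example file):

// Resolves through the "./backends/rt" entry added above:
// types -> dist/types, CJS -> dist/cjs, ESM -> dist/esm.
import { RtBackend } from "@lmscript/client/backends/rt";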
143 changes: 143 additions & 0 deletions packages/client/src/backends/rt.ts
@@ -0,0 +1,143 @@
/**
 * This module contains the `rt` backend, which targets an
 * OpenAI-compatible completions API.
 * @module
 */

import { ChatTemplate } from "../chat-template";
import { NOOP } from "../utils";
import {
  AbstractBackend,
  ClientState,
  ExecutionCallbacks,
  GenerateTask,
  GenerationThread,
  ReportUsage,
  SelectTask,
} from "./abstract";
import { BaseExecutor } from "./executor";

class RtBackendExecutor extends BaseExecutor {
  readonly #url: string;
  readonly #model: string;
  readonly #reportUsage: ReportUsage;
  readonly #auth: string | undefined;

  constructor(options: {
    url: string;
    auth?: string;
    model: string;
    reportUsage?: ReportUsage;
    data: GenerationThread;
    callbacks: ExecutionCallbacks;
    template: ChatTemplate;
  }) {
    super(options.data, options.callbacks, options.template);
    this.#url = options.url;
    this.#model = options.model;
    this.#reportUsage = options.reportUsage ?? NOOP;
    this.#auth = options.auth;
  }

  // POSTs to the OpenAI-compatible completions endpoint, attaching a Bearer
  // token when `auth` is configured.
  async #fetchJSON<T>(body: object): Promise<T> {
    const headers: Record<string, string> = {
      "Content-Type": "application/json",
    };
    if (this.#auth != null) {
      headers["Authorization"] = `Bearer ${this.#auth}`;
    }
    return this.fetchJSONWithTimeout(`${this.#url}/v1/completions`, {
      method: "POST",
      headers,
      body: JSON.stringify(body),
    });
  }

  // Generates text, optionally constrained by a regex grammar, and appends
  // the completion to the running prompt text.
  override async doGeneration(task: GenerateTask): Promise<string> {
    const json = await this.#fetchJSON<any>({
      model: this.#model,
      prompt: this.state.text,
      max_tokens: task.max_tokens,
      stop: task.stop,
      grammar:
        task.regex === undefined
          ? undefined
          : {
              type: "regex",
              value: task.regex,
            },
      temperature: this.data.sampling_params.temperature,
      top_p: this.data.sampling_params.top_p,
      top_k: this.data.sampling_params.top_k,
      frequency_penalty: this.data.sampling_params.frequency_penalty,
      presence_penalty: this.data.sampling_params.presence_penalty,
    });
    this.#reportUsage({
      promptTokens: json.usage.prompt_tokens,
      completionTokens: json.usage.completion_tokens,
    });
    const captured = json.choices[0].text;
    this.state.text += captured;
    return captured;
  }

  // Forces the model to pick one of `task.choices` by constraining generation
  // with a regex alternation over the choices.
  override async doSelect(task: SelectTask): Promise<string> {
    const json = await this.#fetchJSON<any>({
      model: this.#model,
      prompt: this.state.text,
      grammar: {
        type: "regex",
        value: `(${task.choices
          // Escape regex metacharacters so literal choices cannot corrupt
          // the alternation.
          .map((it) => it.replace(/[-\/\\^$*+?.()|[\]{}]/g, "\\$&"))
          .join("|")})`,
      },
      temperature: this.data.sampling_params.temperature,
      top_p: this.data.sampling_params.top_p,
      top_k: this.data.sampling_params.top_k,
      frequency_penalty: this.data.sampling_params.frequency_penalty,
      presence_penalty: this.data.sampling_params.presence_penalty,
    });
    this.#reportUsage({
      promptTokens: json.usage.prompt_tokens,
      completionTokens: json.usage.completion_tokens,
    });
    const captured = json.choices[0].text;
    this.state.text += captured;
    return captured;
  }
}

/**
 * Backend for the `rt` OpenAI-compatible completions API.
 */
export class RtBackend implements AbstractBackend {
  readonly #url: string;
  readonly #model: string;
  readonly #reportUsage: ReportUsage;
  readonly #auth: string | undefined;
  readonly #template: ChatTemplate;

  constructor(options: {
    url: string;
    template: ChatTemplate;
    auth?: string;
    model: string;
    reportUsage?: ReportUsage;
  }) {
    this.#url = options.url;
    this.#model = options.model;
    this.#reportUsage = options.reportUsage ?? NOOP;
    this.#auth = options.auth;
    this.#template = options.template;
  }

  async executeJSON(data: GenerationThread, callbacks: ExecutionCallbacks): Promise<ClientState> {
    return new RtBackendExecutor({
      url: this.#url,
      model: this.#model,
      reportUsage: this.#reportUsage,
      auth: this.#auth,
      data,
      callbacks,
      template: this.#template,
    }).executeJSON();
  }
}
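
For context, a minimal usage sketch of the new backend. The URL, model, and template values are taken from the example file above; the `.user(...).assistant(...).run()` chain mirrors the LmScript client API as used in this repository, and the exact shape of `captured` should be treated as an assumption:

import { LmScript } from "@lmscript/client";
import { RtBackend } from "@lmscript/client/backends/rt";

const backend = new RtBackend({
  url: "http://localhost:1234",
  model: "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
  template: "mistral",
});
const model = new LmScript(backend, { temperature: 0.1 });

const main = async () => {
  // Ask a question and capture the constrained generation under "answer".
  const { captured } = await model
    .user((c) => c.push("What is 1+1?"))
    .assistant((c) => c.gen("answer"))
    .run();
  console.log(captured);
};
main().catch(console.error);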
