Skip to content

Commit

Permalink
feat(bitdeer): mock stream mode chat API (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mini256 authored Apr 18, 2024
1 parent 41e6333 commit 7151148
Showing 1 changed file with 79 additions and 45 deletions.
124 changes: 79 additions & 45 deletions src/lib/llamaindex/llm/bitdeer.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import {getEnv} from "@/lib/env";
import {BaseLLM} from "llamaindex/llm/base";
import {ok} from "node:assert";
import {
ChatResponse,
ChatResponseChunk,
CompletionResponse,
LLMChatParamsNonStreaming,
LLMChatParamsStreaming, LLMCompletionParamsNonStreaming, LLMCompletionParamsStreaming,
LLMChatParamsStreaming,
LLMCompletionParamsNonStreaming,
LLMCompletionParamsStreaming,
LLMMetadata
} from "llamaindex";
import {BaseLLM} from "llamaindex/llm/base";
import {ok} from "node:assert";
import * as util from "node:util";

const messageAccessor = (data: any): ChatResponseChunk => {
Expand All @@ -27,8 +29,8 @@ const completionAccessor = (data: any): CompletionResponse => {

export type BitdeerModel = "llama2" | "mistral";

export interface BitdeerLlama2Options extends Record<string, unknown> {
microstat: number;
export interface BitdeerAdditionalChatOptions extends Record<string, unknown> {
mirostat: number;
mirostat_eta: number;
mirostat_tau: number;
num_ctx: number;
Expand All @@ -42,7 +44,20 @@ export interface BitdeerLlama2Options extends Record<string, unknown> {
top_p: number;
}

export type BitdeerAdditionalChatOptions = BitdeerLlama2Options;
/**
 * Default generation options sent to the Bitdeer API (spread into the
 * `options` field of the request payload by `chat`/`complete`).
 *
 * NOTE(review): the field names mirror Ollama-style sampling parameters;
 * the per-option semantics below are inferred from the names — confirm
 * against Bitdeer's API documentation.
 */
export const defaultChatOptions = {
  mirostat: 0,          // presumably 0 = Mirostat sampling disabled — confirm
  mirostat_eta: 0.1,    // Mirostat learning rate (only relevant if mirostat > 0)
  mirostat_tau: 5,      // Mirostat target entropy (only relevant if mirostat > 0)
  num_ctx: 2048,        // context window size, in tokens
  repeat_last_n: 64,    // how far back the repetition penalty looks
  repeat_penalty: 1.1,  // > 1 penalizes repeated tokens
  temperature: 0,       // deterministic-leaning sampling by default
  num_predict: -1,      // presumably -1 = no cap on generated tokens — confirm
  seed: 42,             // fixed seed for reproducible output
  tfs_z: 1,             // presumably 1 = tail-free sampling disabled — confirm
  top_k: 40,            // sample from the 40 most likely tokens
  top_p: 0.9            // nucleus sampling cutoff
};

/**
* Bitdeer is a cloud computing platform that provides computing power for cryptocurrency mining.
Expand All @@ -52,15 +67,13 @@ export type BitdeerAdditionalChatOptions = BitdeerLlama2Options;
* Website: https://www.bitdeer.com/
*/
export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
readonly hasStreaming = false;

model: string = "llama2";
model: string = "mistral";
baseURL: string = "https://www.bitdeer.ai/public/v1";
temperature: number = 0.7;
temperature: number = 0;
topP: number = 0.9;
contextWindow: number = 4096;
contextWindow: number = 2048;
requestTimeout: number = 60 * 1000; // Default is 60 seconds
additionalChatOptions?: BitdeerAdditionalChatOptions;
additionalChatOptions?: BitdeerAdditionalChatOptions = defaultChatOptions;

private apiSecretAccessKey: string = getEnv("BITDEER_API_SECRET_ACCESS_KEY");

Expand Down Expand Up @@ -101,17 +114,14 @@ export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
const { messages, stream } = params;

if (stream) {
throw new Error("Bitdeer chat completion API does not support streaming mode.");
}

const payload = {
let payload = {
model: this.model,
messages: messages.map((message) => ({
role: message.role,
content: message.content,
})),
stream: !!stream,
// Notice: Bitdeer chat completion API does not support streaming mode.
stream: false,
options: {
temperature: this.temperature,
num_ctx: this.contextWindow,
Expand All @@ -121,24 +131,27 @@ export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
};

const url = `${this.baseURL}/models/${this.model}/generate`;
const response = await fetch(url, {
const res = await fetch(url, {
body: JSON.stringify(payload),
method: "POST",
signal: AbortSignal.timeout(this.requestTimeout),
headers: {
"Content-Type": "application/json",
"X-Api-Key": this.apiSecretAccessKey,
},
});
if (!stream) {
if (!response.ok) {
throw new Error(util.format(
'Failed to call Bitdeer chat completion API (status: %d, statusText: %s).',
response.status,
response.statusText
));
}

const raw = await response.json();
if (!res.ok) {
throw new Error(util.format(
'Failed to call Bitdeer chat completion API (status: %d, statusText: %s, stream: %s).',
res.status,
res.statusText,
stream
));
}

if (!stream) {
const raw = await res.json();
const { message } = raw.data;
return {
message: {
Expand All @@ -148,9 +161,16 @@ export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
raw,
};
} else {
const stream = response.body;
const stream = this.mockResponseStream(async () => {
const raw = await res.json();
return {
message: raw.data.message
};
});

ok(stream, "stream is null");
ok(stream instanceof ReadableStream, "stream is not readable");

return this.streamChat(stream, messageAccessor);
}
}
Expand Down Expand Up @@ -182,21 +202,26 @@ export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
}
}

private mockResponseStream(getResponse: () => Promise<any>) {
return new ReadableStream({
async pull(controller) {
controller.enqueue(Buffer.from(JSON.stringify(await getResponse())));
controller.close();
}
});
}

complete(
params: LLMCompletionParamsStreaming,
params: LLMCompletionParamsStreaming
): Promise<AsyncIterable<CompletionResponse>>;
complete(
params: LLMCompletionParamsNonStreaming,
params: LLMCompletionParamsNonStreaming
): Promise<CompletionResponse>;
async complete(
params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
const { prompt, stream } = params;

if (stream) {
throw new Error("Bitdeer completion API does not support streaming mode.");
}

const payload = {
model: this.model,
prompt: prompt,
Expand All @@ -209,32 +234,41 @@ export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
},
};
const url = `${this.baseURL}/models/${this.model}/generate`;
const response = await fetch(url, {
const res = await fetch(url, {
body: JSON.stringify(payload),
method: "POST",
signal: AbortSignal.timeout(this.requestTimeout),
headers: {
"Content-Type": "application/json",
"X-Api-Key": this.apiSecretAccessKey,
},
});

if (!res.ok) {
throw new Error(util.format(
'Failed to call Bitdeer completion API (status: %d, statusText: %s).',
res.status,
res.statusText
));
}

if (!stream) {
if (!response.ok) {
throw new Error(util.format(
'Failed to call Bitdeer completion API (status: %d, statusText: %s).',
response.status,
response.statusText
));
}
const raw = await response.json();
const raw = await res.json();
return {
text: raw.data.response,
raw,
};
} else {
const stream = response.body;
const stream = this.mockResponseStream(async () => {
const raw = await res.json();
return {
response: raw.data.response
};
});

ok(stream, "stream is null");
ok(stream instanceof ReadableStream, "stream is not readable");

return this.streamChat(stream, completionAccessor);
}
}
Expand Down

0 comments on commit 7151148

Please sign in to comment.