Skip to content

Commit

Permalink
feat: support bitdeer embedding and chat api (#107)
Browse files Browse the repository at this point in the history
* feat: support bitdeer embedding and chat api

* test
  • Loading branch information
Mini256 committed Apr 16, 2024
1 parent b22deab commit 744d90e
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 22 deletions.
32 changes: 32 additions & 0 deletions src/app/api/test/bitdeer-chat/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import type {ChatResponse} from "@/core/services/chating";
import {getLLM} from "@/lib/llamaindex/converters/llm";
import {baseRegistry} from "@/rag-spec/base";
import {getFlow} from "@/rag-spec/createFlow";
import {NextRequest, NextResponse} from 'next/server';

// Build the RAG flow from the base registry once at module load (top-level await).
const flow = await getFlow(baseRegistry);

/**
 * GET /api/test/bitdeer-chat — smoke test for the Bitdeer chat model.
 *
 * Answers the `query` search parameter (or a default question when absent
 * or empty) with the Bitdeer-hosted `mistral` model and returns the
 * assistant message as JSON.
 */
export async function GET (req: NextRequest) {
  const searchParams = new URL(req.url).searchParams;
  // `||` (not `??`) so an empty `?query=` also falls back to the default.
  const question = searchParams.get('query') || 'Why TiDB is the advanced MySQL alternative?';

  const chatModel = getLLM(flow, 'bitdeer', {
    model: 'mistral',
    apiSecretAccessKey: process.env.BITDEER_API_SECRET_ACCESS_KEY!
  });

  const { message } = await chatModel.chat({
    messages: [
      { role: 'system', content: 'Keep your answers brief, less than 100 words, emoji is fine.' },
      { role: 'user', content: question },
    ],
    stream: false,
  });

  return NextResponse.json(message);
}

// Next.js route segment config: force dynamic rendering for this test route.
export const dynamic = 'force-dynamic';
22 changes: 22 additions & 0 deletions src/app/api/test/bitdeer-embedding/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import {getEmbedding} from "@/lib/llamaindex/converters/embedding";
import {BitdeerEmbedding, BitdeerEmbeddingModelType} from "@/lib/llamaindex/embeddings/BitdeerEmbedding";
import {baseRegistry} from "@/rag-spec/base";
import {getFlow} from "@/rag-spec/createFlow";
import {NextRequest, NextResponse} from 'next/server';

// Build the RAG flow from the base registry once at module load (top-level await).
const flow = await getFlow(baseRegistry);

export async function GET (req: NextRequest) {
const url = new URL(req.url);
const input = url.searchParams.get('input') || 'I want a database to replace MySQL.';
const bitdeerEmbedding = getEmbedding(flow, 'bitdeer', {
model: BitdeerEmbeddingModelType.MISTRAL_EMBED_LARGE,
apiSecretAccessKey: process.env.BITDEER_API_SECRET_ACCESS_KEY!
});

const embedding = await bitdeerEmbedding.getQueryEmbedding(input);

return NextResponse.json(embedding);
}

// Next.js route segment config: force dynamic rendering for this test route.
export const dynamic = 'force-dynamic';
File renamed without changes.
3 changes: 3 additions & 0 deletions src/lib/llamaindex/converters/embedding.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { Flow } from '@/core';
import { rag } from '@/core/interface';
import {BitdeerEmbedding} from "@/lib/llamaindex/embeddings/BitdeerEmbedding";
import { BaseEmbedding, OpenAIEmbedding } from 'llamaindex';
import ExtensionType = rag.ExtensionType;

/**
 * Resolve an embedding implementation for the given provider name.
 *
 * 'openai' and 'bitdeer' (bare, or prefixed with the embeddings extension
 * type) are constructed directly; any other provider is looked up in the
 * flow registry and adapted via `fromAppEmbedding`.
 */
export function getEmbedding (flow: Flow, provider: string, options: any) {
  const prefix = rag.ExtensionType.Embeddings + '.';
  if (provider === 'openai' || provider === prefix + 'openai') {
    return new OpenAIEmbedding(options);
  }
  if (provider === 'bitdeer' || provider === prefix + 'bitdeer') {
    return new BitdeerEmbedding(options);
  }
  return fromAppEmbedding(flow.getRequired(ExtensionType.Embeddings, provider).withOptions(options));
}
Expand Down
3 changes: 3 additions & 0 deletions src/lib/llamaindex/converters/llm.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { Flow } from '@/core';
import { rag } from '@/core/interface';
import {Bitdeer} from "@/lib/llamaindex/llm/bitdeer";
import { type LLM, OpenAI } from 'llamaindex';
import ExtensionType = rag.ExtensionType;

/**
 * Resolve a chat LLM implementation for the given provider name.
 *
 * 'openai' and 'bitdeer' (bare, or prefixed with the chat-model extension
 * type) are constructed directly; any other provider is looked up in the
 * flow registry and adapted via `fromAppChatModel`.
 */
export function getLLM (flow: Flow, provider: string, config: any) {
// Providers may be named bare ('openai') or fully qualified with the extension type.
if (provider === 'openai' || provider === rag.ExtensionType.ChatModel + '.openai') {
return new OpenAI(config);
} else if (provider === 'bitdeer' || provider === rag.ExtensionType.ChatModel + '.bitdeer') {
return new Bitdeer(config);
}

// Fall back to a chat model registered in the flow, adapted to the llamaindex interface.
return fromAppChatModel(flow.getRequired(ExtensionType.ChatModel, provider).withOptions(config));
Expand Down
60 changes: 60 additions & 0 deletions src/lib/llamaindex/embeddings/BitdeerEmbedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import {BaseEmbedding} from "llamaindex";
import util from "node:util";


export enum BitdeerEmbeddingModelType {
MISTRAL_EMBED_LARGE = "mxbai-embed-large",
}

/**
 * Embedding backed by the Bitdeer public inference API.
 *
 * Sends each input string to `{baseURL}/models/{model}/generate` with the
 * API key in the `X-Api-Key` header and returns `data.embedding` from the
 * JSON response.
 */
export class BitdeerEmbedding extends BaseEmbedding {
  /** Base URL of the Bitdeer public inference API. */
  baseURL: string = "https://www.bitdeer.ai/public/v1";
  /** Embedding model to use; defaults to mxbai-embed-large. */
  model: BitdeerEmbeddingModelType = BitdeerEmbeddingModelType.MISTRAL_EMBED_LARGE;

  apiSecretAccessKey: string;
  /** Per-request timeout in milliseconds. Default is 60 seconds. */
  requestTimeout: number = 60 * 1000;

  /**
   * @param init partial configuration; `apiSecretAccessKey` is required.
   * @throws Error when no API secret access key is provided.
   */
  constructor(init?: Partial<BitdeerEmbedding>) {
    super();
    if (typeof init?.apiSecretAccessKey !== "string") {
      throw new Error("Bitdeer API secret access key is required.");
    }
    this.apiSecretAccessKey = init.apiSecretAccessKey;
    // Honor optional overrides — the previous constructor silently ignored
    // `model`, `baseURL`, and `requestTimeout` passed by the caller.
    if (init.model !== undefined) this.model = init.model;
    if (init.baseURL !== undefined) this.baseURL = init.baseURL;
    if (init.requestTimeout !== undefined) this.requestTimeout = init.requestTimeout;
  }

  /**
   * Call the Bitdeer generate endpoint for a single input string.
   * @returns the embedding vector from `data.embedding` of the response.
   * @throws Error when the HTTP response is not ok or the request times out.
   */
  private async getBitdeerEmbedding(input: string): Promise<number[]> {
    const payload = {
      model: this.model,
      prompt: input
    };
    const url = `${this.baseURL}/models/${this.model}/generate`;
    const response = await fetch(url, {
      body: JSON.stringify(payload),
      method: "POST",
      // Enforce the configured timeout; previously `requestTimeout` was
      // declared but never applied because this line was commented out.
      signal: AbortSignal.timeout(this.requestTimeout),
      headers: {
        "Content-Type": "application/json",
        "X-Api-Key": this.apiSecretAccessKey,
      },
    });

    if (!response.ok) {
      throw new Error(util.format(
        'Failed to call Bitdeer embedding API (status: %d, statusText: %s).',
        response.status,
        response.statusText
      ));
    }

    const raw = await response.json();
    // NOTE(review): assumes the response shape { data: { embedding: number[] } } —
    // confirm against the Bitdeer API documentation.
    return raw.data.embedding;
  }

  async getTextEmbedding(text: string): Promise<number[]> {
    return this.getBitdeerEmbedding(text);
  }

  async getQueryEmbedding(query: string): Promise<number[]> {
    return this.getBitdeerEmbedding(query);
  }
}
78 changes: 56 additions & 22 deletions src/lib/llamaindex/llm/bitdeer.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import {BaseLLM} from "llamaindex/llm/base";
import {ok} from "node:assert";
import {
BaseEmbedding, ChatResponse,
ChatResponse,
ChatResponseChunk,
CompletionResponse,
LLM, LLMChatParamsNonStreaming,
LLMChatParamsNonStreaming,
LLMChatParamsStreaming, LLMCompletionParamsNonStreaming, LLMCompletionParamsStreaming,
LLMMetadata
} from "llamaindex";
Expand Down Expand Up @@ -45,16 +46,16 @@ export type BitdeerAdditionalChatOptions = BitdeerLlama2Options;
*
* Website: https://www.bitdeer.com/
*/
export class Bitdeer extends BaseEmbedding implements LLM {
readonly hasStreaming = true;
export class Bitdeer implements BaseLLM<BitdeerAdditionalChatOptions> {
readonly hasStreaming = false;

model: string;
baseURL: string = "https://www.bitdeer.ai/public/v1";
temperature: number = 0.7;
topP: number = 0.9;
contextWindow: number = 4096;
requestTimeout: number = 60 * 1000; // Default is 60 seconds
additionalChatOptions?: Record<string, unknown>;
additionalChatOptions?: BitdeerAdditionalChatOptions;

private apiSecretAccessKey: string;

Expand All @@ -68,10 +69,8 @@ export class Bitdeer extends BaseEmbedding implements LLM {
apiSecretAccessKey: string;
},
) {
super();

if (!init.apiSecretAccessKey) {
throw new Error("API secret access key is required.");
throw new Error("Bitdeer API secret access key is required.");
}

this.model = init.model ?? 'llama2';
Expand Down Expand Up @@ -99,8 +98,55 @@ export class Bitdeer extends BaseEmbedding implements LLM {
async chat(
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
// Notice: Bitdeer does not support chat API for now.
throw new Error("Method not implemented.");
const { messages, stream } = params;
const payload = {
model: this.model,
messages: messages.map((message) => ({
role: message.role,
content: message.content,
})),
stream: !!stream,
options: {
temperature: this.temperature,
num_ctx: this.contextWindow,
top_p: this.topP,
...this.additionalChatOptions,
},
};

const url = `${this.baseURL}/models/${this.model}/generate`;
const response = await fetch(url, {
body: JSON.stringify(payload),
method: "POST",
headers: {
"Content-Type": "application/json",
"X-Api-Key": this.apiSecretAccessKey,
},
});
if (!stream) {
if (!response.ok) {
throw new Error(util.format(
'Failed to call Bitdeer chat completion API (status: %d, statusText: %s).',
response.status,
response.statusText
));
}

const raw = await response.json();
const { message } = raw.data;
return {
message: {
role: "assistant",
content: message.content,
},
raw,
};
} else {
const stream = response.body;
ok(stream, "stream is null");
ok(stream instanceof ReadableStream, "stream is not readable");
return this.streamChat(stream, messageAccessor);
}
}

private async *streamChat<T>(
Expand Down Expand Up @@ -183,16 +229,4 @@ export class Bitdeer extends BaseEmbedding implements LLM {
}
}

private async getEmbedding(prompt: string): Promise<number[]> {
// Notice: Bitdeer does not support embedding API for now.
throw new Error("Method not implemented.");
}

async getTextEmbedding(text: string): Promise<number[]> {
return this.getEmbedding(text);
}

async getQueryEmbedding(query: string): Promise<number[]> {
return this.getEmbedding(query);
}
}

0 comments on commit 744d90e

Please sign in to comment.