Skip to content

Commit

Permalink
feat: support stream and chat_engine options for chats api (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mini256 committed Jun 3, 2024
1 parent 6ea82c0 commit 71a14ce
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 45 deletions.
4 changes: 2 additions & 2 deletions src/app/api/v1/chat_engines/[id]/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {CreateChatEngineOptionsSchema} from "@/core/schema/chat_engines";
import {
deleteChatEngine,
getChatEngine,
getChatEngineById,
updateChatEngine
} from '@/core/repositories/chat_engine';
import { defineHandler } from '@/lib/next/handler';
Expand All @@ -18,7 +18,7 @@ export const GET = defineHandler({
}, async ({
params,
}) => {
const engine = await getChatEngine(params.id);
const engine = await getChatEngineById(params.id);
if (!engine) {
notFound();
}
Expand Down
29 changes: 18 additions & 11 deletions src/app/api/v1/chats/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import {type Chat, createChat, getChatByUrlKey, listChats} from '@/core/repositories/chat';
import {getChatEngineConfig} from '@/core/repositories/chat_engine';
import {getChatEngineByNameOrDefault} from '@/core/repositories/chat_engine';
import {getIndexByNameOrThrow} from '@/core/repositories/index_';
import {LlamaindexChatService} from '@/core/services/llamaindex/chating';
import {toPageRequest} from '@/lib/database';
Expand All @@ -18,12 +18,13 @@ const ChatRequest = z.object({
role: z.string(),
}).array(),
sessionId: z.string().optional(),
// TODO: use `title` instead of `name`.
name: z.string().optional(),
index: z.string().optional(),
// TODO: using engine name instead.
engine: z.number().int().optional(),
chat_engine: z.string().optional(),
regenerate: z.boolean().optional(),
messageId: z.coerce.number().int().optional(),
stream: z.boolean().default(true),
});

const DEFAULT_CHAT_TITLE = 'Untitled';
Expand All @@ -41,7 +42,7 @@ export const POST = defineHandler({
messages,
} = body;

const [engineId, engine, engineOptions] = await getChatEngineConfig(body.engine);
const engine = await getChatEngineByNameOrDefault(body.chat_engine);

// TODO: need refactor, it is too complex now
// For chat page, create a chat and return the session ID (url_key) first.
Expand All @@ -58,9 +59,10 @@ export const POST = defineHandler({
}

return await createChat({
engine,
engine_id: engineId,
engine_options: JSON.stringify(engineOptions),
engine: engine.engine,
engine_id: engine.id,
engine_name: engine.name,
engine_options: JSON.stringify(engine.engine_options),
created_at: new Date(),
created_by: userId,
title: title,
Expand All @@ -72,8 +74,9 @@ export const POST = defineHandler({
let sessionId = body.sessionId;
if (!sessionId) {
chat = await createChat({
engine,
engine_options: JSON.stringify(engineOptions),
engine: engine.engine,
engine_name: engine.name,
engine_options: JSON.stringify(engine.engine_options),
created_at: new Date(),
created_by: userId,
title: body.name ?? body.messages.findLast(message => message.role === 'user')?.content ?? DEFAULT_CHAT_TITLE,
Expand All @@ -100,9 +103,13 @@ export const POST = defineHandler({
}

const lastUserMessage = messages.findLast(m => m.role === 'user')?.content ?? '';
const chatStream = await chatService.chat(sessionId, userId, lastUserMessage, body.regenerate ?? false);
const chatResult = await chatService.chat(sessionId, userId, lastUserMessage, body.regenerate ?? false, body.stream);

return chatStream.toResponse();
if (body.stream) {
return chatResult.toResponse();
} else {
return chatResult;
}
});

export const GET = defineHandler({
Expand Down
27 changes: 20 additions & 7 deletions src/app/api/v1/indexes/[name]/retrieve/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import {getChatEngineConfig} from "@/core/repositories/chat_engine";
import {
ChatEngine,
getChatEngineByIdOrDefault,
getChatEngineByNameOrDefault,
getDefaultChatEngine
} from "@/core/repositories/chat_engine";
import {getIndexByName} from '@/core/repositories/index_';
import {LlamaindexRetrieveService} from '@/core/services/llamaindex/retrieving';
import {retrieveOptionsSchema} from '@/core/services/retrieving';
Expand Down Expand Up @@ -26,28 +31,36 @@ export const POST = defineHandler({
notFound();
}

const [engineId, engine, engineOptions] = await getChatEngineConfig(body.engine);
let engine: ChatEngine;
if (body.chat_engine) {
engine = await getChatEngineByNameOrDefault(body.chat_engine);
} else if (body.engine) {
// TODO: remove it after migration is finished.
engine = await getChatEngineByIdOrDefault(body.engine)
} else {
engine = await getDefaultChatEngine();
}
const { engine_options } = engine;

const {
llm: llmConfig = {
provider: LLMProvider.OPENAI,
config: {}
},
} = engineOptions;
} = engine_options;

const flow = await getFlow(baseRegistry);
const serviceContext = serviceContextFromDefaults({
llm: buildLLM(llmConfig),
embedModel: await buildEmbedding(index.config.embedding),
});

const retrieveService = new LlamaindexRetrieveService({
metadata_filter: engineOptions.metadata_filter,
reranker: engineOptions.reranker,
metadata_filter: engine_options.metadata_filter,
reranker: engine_options.reranker,
flow,
index,
serviceContext
});

const result = await retrieveService.retrieve(body);

return await retrieveService.extendResultDetails(result);
Expand Down
1 change: 1 addition & 0 deletions src/core/db/schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export interface Chat {
deleted_by: string | null;
engine: string;
engine_id: number | null;
engine_name: string | null;
engine_options: Json;
id: Generated<number>;
title: string;
Expand Down
42 changes: 31 additions & 11 deletions src/core/repositories/chat_engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export type ChatEngine = Rewrite<Selectable<DBv1['chat_engine']>, { engine: Chat
export type CreateChatEngine = Rewrite<Insertable<DBv1['chat_engine']>, { engine: ChatEngineProvider, engine_options: ChatEngineOptions }>;
export type UpdateChatEngine = Rewrite<Updateable<DBv1['chat_engine']>, { engine: ChatEngineProvider, engine_options: ChatEngineOptions }>;

export async function getChatEngine (id: number) {
export async function getChatEngineById (id: number) {
return await getDb()
.selectFrom('chat_engine')
.selectAll()
Expand All @@ -24,6 +24,15 @@ export async function getChatEngine (id: number) {
.executeTakeFirst();
}

/**
 * Fetches the chat engine whose `name` column equals the given name.
 *
 * @param name - Chat engine name to match exactly.
 * @returns The first matching row, typed as `ChatEngine`. Callers in this
 *   module treat a falsy result as "no such engine".
 */
export async function getChatEngineByName (name: string) {
  const lookup = getDb()
    .selectFrom('chat_engine')
    .selectAll()
    .$castTo<ChatEngine>()
    .where('name', '=', name);

  return await lookup.executeTakeFirst();
}

export async function getChatEngineNameByID (id?: number | null) {
if (!id) return undefined;
const res = await getDb()
Expand All @@ -43,16 +52,27 @@ export async function getDefaultChatEngine () {
.executeTakeFirstOrThrow();
}

export async function getChatEngineConfig (engineConfigId?: number): Promise<[number, string, ChatEngineOptions]> {
if (engineConfigId) {
const config = await getChatEngine(engineConfigId);
if (!config) {
throw CHAT_ENGINE_NOT_FOUND_ERROR.format(engineConfigId);
/**
 * Resolves a chat engine by numeric id, falling back to the default engine.
 *
 * @param engineId - Id of the engine to load. Any falsy value (omitted, `0`)
 *   selects the default chat engine instead.
 * @returns The engine with the given id, or the default engine when no id is given.
 * @throws The formatted `CHAT_ENGINE_NOT_FOUND_ERROR` when an id is given but
 *   no engine is found for it.
 */
export async function getChatEngineByIdOrDefault (engineId?: number): Promise<ChatEngine> {
  // No usable id: hand back the default engine.
  if (!engineId) {
    return await getDefaultChatEngine();
  }

  const found = await getChatEngineById(engineId);
  if (found) {
    return found;
  }
  throw CHAT_ENGINE_NOT_FOUND_ERROR.format(engineId);
}

/**
 * Resolves a chat engine by name, falling back to the default engine.
 *
 * @param engineName - Name of the engine to load. Any falsy value (omitted or
 *   an empty string) selects the default chat engine instead.
 * @returns The engine with the given name, or the default engine when no name is given.
 * @throws The formatted `CHAT_ENGINE_NOT_FOUND_ERROR` when a name is given but
 *   no engine is found for it.
 */
export async function getChatEngineByNameOrDefault (engineName?: string): Promise<ChatEngine> {
  // No usable name: hand back the default engine.
  if (!engineName) {
    return await getDefaultChatEngine();
  }

  const engine = await getChatEngineByName(engineName);
  if (!engine) {
    throw CHAT_ENGINE_NOT_FOUND_ERROR.format(engineName);
  }
  // Return the whole engine record; the former `[id, engine, engine_options]`
  // tuple does not satisfy the declared `Promise<ChatEngine>` return type.
  return engine;
}

Expand All @@ -70,7 +90,7 @@ export async function createChatEngine (create: CreateChatEngine) {
.values({ ...create, engine_options: JSON.stringify(create.engine_options) })
.executeTakeFirstOrThrow();

return (await getChatEngine(Number(insertId)))!;
return (await getChatEngineById(Number(insertId)))!;
}

export async function updateChatEngine (id: number, update: UpdateChatEngine) {
Expand All @@ -83,7 +103,7 @@ export async function updateChatEngine (id: number, update: UpdateChatEngine) {

export async function deleteChatEngine (id: number) {
await tx(async () => {
const chatEngine = await getChatEngine(id);
const chatEngine = await getChatEngineById(id);
if (!chatEngine) {
notFound();
}
Expand Down
60 changes: 48 additions & 12 deletions src/core/services/chating.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,70 @@ export type ChatStreamEvent = {
error?: unknown;
}

/**
 * Aggregated result that `AppChatService.chat` builds when it is called with
 * `stream` set to `false`.
 */
export interface ChatNonStreamingResult {
/** Last non-empty trace URL reported by a chunk; stays an empty string if none was reported. */
traceURL: string;
/** Concatenation of the `content` of every chunk, in the order the chunks were produced. */
content: string;
/** Sources carried by the most recently processed chunk (each chunk overwrites the previous value). */
sources: AppChatStreamSource[];
/** Starts as `CREATING`; set to `FINISHED` once the run completes, or to `ERROR` just before a failure is rethrown. */
state: AppChatStreamState;
}

export abstract class AppChatService extends AppIndexBaseService {

/**
 * Runs one chat turn for a session, either as a stream or as a single aggregated result.
 *
 * The session is loaded with `getSessionInfo` and the responding message is created
 * with `startChat` before any chunk is produced. On success the accumulated content
 * and the collected retrieve ids are passed to `finishChat`; on failure
 * `terminateChat` is called and the error is rethrown.
 *
 * @param sessionId - Session to continue; also handed to the returned `AppChatStream`.
 * @param userId - User on whose behalf the chat runs.
 * @param userInput - Text of the user's message for this turn.
 * @param regenerating - Forwarded to `startChat`.
 * @param stream - When `true` (the default) an `AppChatStream` is returned and the run
 *   happens inside its callback. When `false` the run is consumed here and a
 *   `ChatNonStreamingResult` is returned.
 * @returns An `AppChatStream` in streaming mode, otherwise a `ChatNonStreamingResult`.
 *   The declared type stays `any` because callers use the result without narrowing
 *   (e.g. calling `toResponse()` on it).
 * @throws Whatever error the run raises, after `terminateChat` has been awaited.
 */
async chat (sessionId: string, userId: string, userInput: string, regenerating: boolean, stream: boolean = true): Promise<any> {
  const { chat, history } = await this.getSessionInfo(sessionId, userId);
  const respondMessage = await this.startChat(chat, history, userInput, regenerating);

  if (stream) {
    return new AppChatStream(sessionId, respondMessage.id, async controller => {
      try {
        let content = '';
        const retrieveIds = new Set<number>();
        for await (const chunk of this.run(chat, { userInput, history, userId, respondMessage })) {
          controller.appendText(chunk.content, chunk.status === AppChatStreamState.CREATING /* force sends an empty text chunk first, to avoid a dependency BUG */);
          controller.setChatState(chunk.status, chunk.statusMessage);
          controller.setTraceURL(chunk.traceURL);
          controller.setSources(chunk.sources);
          content += chunk.content;
          if (chunk.retrieveId) {
            retrieveIds.add(chunk.retrieveId);
          }
        }
        controller.setChatState(AppChatStreamState.FINISHED);
        await this.finishChat(respondMessage, content, retrieveIds);
      } catch (error) {
        // Surface the failure on the stream first, then persist it and rethrow.
        controller.setChatState(AppChatStreamState.ERROR, getErrorMessage(error));
        await this.terminateChat(respondMessage, error);
        throw error;
      }
    });
  }

  // Non-streaming mode: no stream controller exists here, so every chunk is
  // folded into a plain result object instead.
  const chatResult: ChatNonStreamingResult = {
    traceURL: '',
    content: '',
    sources: [],
    state: AppChatStreamState.CREATING,
  };
  try {
    const retrieveIds = new Set<number>();
    for await (const chunk of this.run(chat, { userInput, history, userId, respondMessage })) {
      chatResult.content += chunk.content;
      // Each chunk replaces the previous sources, so the last chunk wins.
      chatResult.sources = chunk.sources;
      if (chunk.retrieveId) {
        retrieveIds.add(chunk.retrieveId);
      }
      // Keep the last non-empty trace URL; chunks without one must not clear it.
      if (chunk.traceURL && chunk.traceURL.length > 0) {
        chatResult.traceURL = chunk.traceURL;
      }
    }
    chatResult.state = AppChatStreamState.FINISHED;
    await this.finishChat(respondMessage, chatResult.content, retrieveIds);
    return chatResult;
  } catch (error) {
    chatResult.state = AppChatStreamState.ERROR;
    await this.terminateChat(respondMessage, error);
    throw error;
  }
}

async deleteHistoryFromMessage (chat: Chat, messageId: number) {
Expand Down
6 changes: 4 additions & 2 deletions src/core/services/retrieving.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import z from "zod";

export const retrieveOptionsSchema = z.object({
query: z.string().min(1),
// TODO: using engine name instead.
// TODO: using chat engine name instead.
// @Deprecated
engine: z.number().optional(),
chat_engine: z.string().optional(),
filters: z.array(metadataFilterSchema).optional(),
search_top_k: z.number().int().optional(),
top_k: z.number().int().optional(),
Expand Down Expand Up @@ -106,7 +108,7 @@ export abstract class AppRetrieveService extends AppIndexBaseService {
callbacks?.onRetrieved(retrieve.id, results);

if (options.reversed) {
// use cloned array to avoid affecting langfuse output tracing.
// use a cloned array to avoid affecting langfuse output tracing.
return Array.from(results).reverse();
} else {
return results;
Expand Down

0 comments on commit 71a14ce

Please sign in to comment.