feat: adapt to non-streaming mode LLM (#110)
Mini256 committed Apr 17, 2024
1 parent e7411cc commit 2e67c91
Showing 6 changed files with 201 additions and 86 deletions.
src/core/services/llamaindex/chating.ts (112 additions & 65 deletions)
@@ -3,45 +3,43 @@ import { LlamaindexRetrieverWrapper, LlamaindexRetrieveService } from '@/core/se
 import { type Chat, listChatMessages } from '@/core/repositories/chat';
 import type { ChatEngineOptions } from '@/core/repositories/chat_engine';
 import { getDb } from '@/core/db';
+import {SortedSet} from "@/lib/collection";
 import { uuidToBin } from '@/lib/kysely';
 import { getEmbedding } from '@/lib/llamaindex/converters/embedding';
 import { getLLM } from '@/lib/llamaindex/converters/llm';
+import {defaultRefinePrompt, defaultTextQaPrompt} from "@/lib/llamaindex/prompts/defaultPrompts";
 import { Liquid } from 'liquidjs';
-import { CompactAndRefine, CondenseQuestionChatEngine, defaultCondenseQuestionPrompt, defaultRefinePrompt, defaultTextQaPrompt, MetadataMode, type Response, ResponseSynthesizer, RetrieverQueryEngine, serviceContextFromDefaults } from 'llamaindex';
+import {
+  CompactAndRefine,
+  CondenseQuestionChatEngine, defaultCondenseQuestionPrompt,
+  MetadataMode,
+  type Response,
+  ResponseSynthesizer,
+  RetrieverQueryEngine,
+  serviceContextFromDefaults,
+  SimplePrompt,
+} from 'llamaindex';
+import {DateTime} from "luxon";
 import type { UUID } from 'node:crypto';
 
+interface Source {
+  title: string;
+  uri: string;
+}
+
+interface SourceWithNodeId extends Source {
+  id: string;
+}
+
 export class LlamaindexChatService extends AppChatService {
   private liquid = new Liquid();
 
-  getPrompt<Tmpl extends (ctx: any) => string> (template: string | undefined, fallback: Tmpl): (ctx: Parameters<Tmpl>[0]) => string {
+  getPrompt<Tmpl extends SimplePrompt> (template: string | undefined, fallback: Tmpl): (ctx: Parameters<Tmpl>[0]) => string {
     if (!template) return fallback;
     const tmpl = this.liquid.parse(template);
     return context => this.liquid.renderSync(tmpl, context);
   }
 
-  getTextQAPrompt (): typeof defaultTextQaPrompt {
-    return ({ query, context }) => `Context information is below.
----------------------
-${context}
----------------------
-Given the context information and not prior knowledge, answer the query use markdown format. Add links reference to original contexts if necessary.
-Query: ${query}
-Answer:`;
-  }
-
-  getRefinePrompt (): typeof defaultRefinePrompt {
-    return ({ context, query, existingAnswer }: any) =>
-      `The original query is as follows: ${query}
-We have provided an existing answer: ${existingAnswer}
-We have the opportunity to refine the existing answer (only if needed) with some more context below.
-------------
-${context}
-------------
-Given the new context, refine the original answer to better answer the query. If the context isn't useful, return the original answer.
-Use markdown format to answer. Add links reference to contexts if necessary.
-Refined Answer:`;
-  }
-
   protected async* run (chat: Chat, options: ChatOptions): AsyncGenerator<ChatResponse> {
     const {
       llm: {
@@ -67,7 +65,7 @@
       embedModel: getEmbedding(this.flow, this.index.config.embedding.provider, this.index.config.embedding.config),
     });
 
-    // build queryEngine
+    // Build Retriever.
     const retrieveService = new LlamaindexRetrieveService({
       reranker,
       flow: this.flow,
@@ -85,93 +83,142 @@
       },
     });
 
-    const responseBuilder = new CompactAndRefine(
-      serviceContext,
-      this.getPrompt(textQa, this.getTextQAPrompt()),
-      this.getPrompt(refine, this.getRefinePrompt()),
-    );
+    // Build Query Engine.
+    const textQaPrompt = this.getPrompt(textQa, defaultTextQaPrompt);
+    const refinePrompt = this.getPrompt(refine, defaultRefinePrompt);
+    const responseBuilder = new CompactAndRefine(serviceContext, textQaPrompt, refinePrompt);
     const responseSynthesizer = new ResponseSynthesizer({
       serviceContext,
       responseBuilder,
       metadataMode: MetadataMode.LLM,
     });
     const queryEngine = new RetrieverQueryEngine(retriever, responseSynthesizer);
 
-    // build chatHistory
+    // Build ChatHistory.
     const history = await listChatMessages(chat.id);
     const chatHistory = history.map(message => ({
       role: message.role as any,
       content: message.content,
       additionalKwargs: {},
     }));
 
+    // Build ChatEngine.
+    const stream = llmConfig?.stream ?? true;
+    const condenseMessagePrompt = this.getPrompt(condenseQuestion, defaultCondenseQuestionPrompt);
     const chatEngine = new CondenseQuestionChatEngine({
       queryEngine,
       chatHistory,
       serviceContext,
-      condenseMessagePrompt: this.getPrompt(condenseQuestion, defaultCondenseQuestionPrompt),
+      condenseMessagePrompt
     });
 
-    const stream = chatEngine.chat({
-      stream: true,
+    // Chatting with LLM via ChatEngine.
+    console.log(`Start chatting for chat <${chat.id}>.`, { stream });
+    const start = DateTime.now();
+    const responses = await chatEngine.chat({
+      stream: stream,
       chatHistory: options.history.map(message => ({
         role: message.role as any,
         content: message.content,
         additionalKwargs: {},
       })),
       message: options.userInput,
     });
+    const end = DateTime.now();
+    const duration = end.diff(start, 'seconds').seconds;
+    console.log(`Finished chatting for chat <${chat.id}>, take ${duration} seconds.`, { stream });
 
-    let last: Response | undefined;
-    let message = '';
-    let sources: Map<string, { title: string, uri: string }> = new Map();
-
-    const getSources = async (chunkIds?: string[]) => {
-      if (!chunkIds?.length) {
-        return [];
-      }
-      const idsToFetch = chunkIds.filter(id => !sources.has(id));
-      if (idsToFetch.length > 0) {
-        const results = await getDb().selectFrom(`llamaindex_document_chunk_node_${this.index.name}`)
-          .innerJoin('document', 'document.id', `llamaindex_document_chunk_node_${this.index.name}.document_id`)
-          .select(eb => eb.fn('bin_to_uuid', [`llamaindex_document_chunk_node_${this.index.name}.id`]).as('id'))
-          .select('document.name')
-          .select('document.source_uri')
-          .where(`llamaindex_document_chunk_node_${this.index.name}.id`, 'in', idsToFetch.map(id => uuidToBin(id as UUID)))
-          .execute();
-
-        results.forEach(result => sources.set(result.id, { title: result.name, uri: result.source_uri }));
-      }
-
-      return chunkIds.map(id => sources.get(id)).filter(Boolean) as { title: string, uri: string }[];
-    };
+    // Notice: Some LLM API doesn't support streaming mode.
+    if (!stream) {
+      const res = responses as unknown as Response;
+      const chunkIds = this.getChunkIdsFromResponse(res) ?? [];
+      const sources = await this.getSourcesByChunkIds(chunkIds);
+      yield {
+        content: res.response,
+        status: 'generating',
+        sources: sources,
+        retrieveId: lastRetrieveId,
+      };
+      yield {
+        content: '',
+        status: 'finished',
+        sources: sources,
+        retrieveId: lastRetrieveId,
+      };
+      return;
+    }
 
     // TODO: yield states
     // yield {
     //   sources: [],
     //   status: 'retrieving',
     //   content: '',
     // }
-    for await (const response of await stream) {
-      message += response.response;
+
+    let lastResponse: Response | undefined;
+    const sources = new SortedSet<string, Source>();
+    for await (const res of responses) {
+      await this.appendSource(sources, this.getChunkIdsFromResponse(res));
       yield {
-        content: response.response,
+        content: res.response,
         status: 'generating',
-        sources: await getSources(response.sourceNodes?.map(node => node.id_)),
+        sources: sources.asList(),
         retrieveId: lastRetrieveId,
       };
-      last = response;
+      lastResponse = res;
     }
-    if (last) {
+
+    if (lastResponse) {
+      await this.appendSource(sources, this.getChunkIdsFromResponse(lastResponse));
       yield {
         content: '',
         status: 'finished',
-        sources: await getSources(last.sourceNodes?.map(node => node.id_)),
+        sources: sources.asList(),
         retrieveId: lastRetrieveId,
       };
     } else {
       // use `return` instead of `await` to avoid goto catch block.
       throw new Error('No response from LLM');
     }
   }
+
+  getChunkIdsFromResponse(res: Response) {
+    return res.sourceNodes?.map(node => node.id_);
+  }
+
+  async getSourcesByChunkIds(chunkIds: string[]): Promise<SourceWithNodeId[]> {
+    if (chunkIds.length === 0) {
+      return [];
+    }
+
+    const results = await getDb().selectFrom(`llamaindex_document_chunk_node_${this.index.name}`)
+      .innerJoin('document', 'document.id', `llamaindex_document_chunk_node_${this.index.name}.document_id`)
+      .select(eb => eb.fn('bin_to_uuid', [`llamaindex_document_chunk_node_${this.index.name}.id`]).as('id'))
+      .select('document.name')
+      .select('document.source_uri')
+      .where(`llamaindex_document_chunk_node_${this.index.name}.id`, 'in', chunkIds.map(id => uuidToBin(id as UUID)))
+      .execute();
+
+    return results.map(result => ({
+      id: result.id,
+      title: result.name,
+      uri: result.source_uri,
+    }));
+  }
+
+  async appendSource(sources: SortedSet<string, Source>, chunkIds?: string[]) {
+    if (!Array.isArray(chunkIds) || chunkIds.length === 0) {
+      return;
+    }
+
+    const idsToFetch = chunkIds.filter(id => !sources.has(id));
+    const sourcesWithId = await this.getSourcesByChunkIds(idsToFetch);
+
+    for (let source of sourcesWithId) {
+      const { id, title, uri } = source;
+      sources.add(id, { title, uri });
+    }
+  }
 }
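
Note on the flow above: both the non-streaming branch and the streaming loop yield one or more 'generating' responses followed by a single 'finished' response, so callers can consume either mode uniformly. A minimal consumer sketch, not part of this commit; the printChat function and its use of process.stdout are illustrative, and the field names are taken from the yields above:

// Hypothetical consumer of the ChatResponse generator (types from the service above).
async function printChat(responses: AsyncGenerator<ChatResponse>) {
  for await (const res of responses) {
    if (res.status === 'generating') {
      // Incremental text when streaming; the whole answer at once when stream = false.
      process.stdout.write(res.content);
    } else if (res.status === 'finished') {
      // The final response carries the deduplicated source list.
      console.log('\nSources:', res.sources.map((s: Source) => s.uri).join(', '));
    }
  }
}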
src/core/services/llamaindex/retrieving.ts (16 additions & 3 deletions)
@@ -6,6 +6,7 @@ import { cosineDistance } from '@/lib/kysely';
 import {getReranker} from "@/lib/llamaindex/converters/reranker";
 import { type BaseRetriever, NodeRelationship, type NodeWithScore, ObjectType, type RetrieveParams, type ServiceContext, TextNode } from 'llamaindex';
 import type { RelatedNodeInfo, RelatedNodeType } from 'llamaindex/Node';
+import {DateTime} from "luxon";
 import type { UUID } from 'node:crypto';
 
 export class LlamaindexRetrieveService extends AppRetrieveService {
@@ -20,7 +21,12 @@
 
     await this.startSearch(retrieve);
 
+    console.log(`Start embedding searching for query "${text}".`, { top_k })
+    const searchStart = DateTime.now();
     let chunks = await this.search(queryEmbedding, search_top_k);
+    const searchEnd = DateTime.now();
+    const searchDuration = searchEnd.diff(searchStart, 'milliseconds').milliseconds;
+    console.log(`Finish embedding searching, take ${searchDuration} ms, found ${chunks.length} chunks.`, { top_k });
 
     // Could support more filters here
@@ -34,7 +40,14 @@
 
     await this.startRerank(retrieve);
 
-    return this.rerank(chunks, text, top_k, this.rerankerOptions);
+    console.log(`Start reranking for query "${text}".`, { top_k });
+    const rerankStart = DateTime.now();
+    const rerankedResult = await this.rerank(chunks, text, top_k, this.rerankerOptions);
+    const rerankEnd = DateTime.now();
+    const rerankDuration = rerankEnd.diff(rerankStart, 'milliseconds').milliseconds;
+    console.log(`Finish reranking, take ${rerankDuration} ms.`);
+
+    return rerankedResult;
   }
 
   asRetriever (options: Omit<RetrieveOptions, 'text'>, serviceContext: ServiceContext, callbacks: RetrieveCallbacks): BaseRetriever {
Expand Down Expand Up @@ -164,9 +177,9 @@ export class LlamaindexRetrieverWrapper implements BaseRetriever {
async retrieve (params: RetrieveParams): Promise<NodeWithScore[]> {
const chunks = await this.retrieveService.retrieve({ ...this.options, text: params.query }, this.callbacks);

const detaildChunks = await this.retrieveService.extendResultDetails(chunks);
const detailedChunks = await this.retrieveService.extendResultDetails(chunks);

return detaildChunks.map(chunk => {
return detailedChunks.map(chunk => {
return {
node: new TextNode({
id_: chunk.document_chunk_node_id,
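
The retrieval changes wrap both the embedding search and the rerank step in the same Luxon timing pattern. A small helper could factor that pattern out; this is a sketch, not part of the commit, and the timed function name is hypothetical:

import { DateTime } from 'luxon';

// Runs an async step and logs its duration, mirroring the DateTime.now()/diff pattern above.
async function timed<T>(label: string, fn: () => Promise<T>): Promise<T> {
  const start = DateTime.now();
  const result = await fn();
  const ms = DateTime.now().diff(start, 'milliseconds').milliseconds;
  console.log(`Finish ${label}, take ${ms} ms.`);
  return result;
}

// Hypothetical usage inside retrieve():
// const chunks = await timed('embedding searching', () => this.search(queryEmbedding, search_top_k));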
src/lib/collection.ts (19 additions, new file)
@@ -0,0 +1,19 @@
+export class SortedSet<Key extends string | number | symbol, Element> {
+  private keySet = new Set<Key>();
+  private elements: Element[] = [];
+
+  has(key: Key) {
+    return this.keySet.has(key);
+  }
+
+  add(key: Key, element: Element) {
+    if (!this.keySet.has(key)) {
+      this.keySet.add(key);
+      this.elements.push(element);
+    }
+  }
+
+  asList(): Element[] {
+    return this.elements;
+  }
+}
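
SortedSet is an insertion-ordered set keyed separately from its elements: add ignores any key that was already seen, and asList returns elements in first-insertion order, which is what appendSource in chating.ts relies on to keep the source list stable and duplicate-free while chunks stream in. A usage sketch; the sample keys, titles, and URIs are made up:

const sources = new SortedSet<string, { title: string, uri: string }>();
sources.add('chunk-1', { title: 'Doc A', uri: 'https://example.com/a' });
sources.add('chunk-2', { title: 'Doc B', uri: 'https://example.com/b' });
sources.add('chunk-1', { title: 'Doc A again', uri: 'https://example.com/a' }); // ignored: key already present

console.log(sources.has('chunk-2')); // true
console.log(sources.asList()); // [{ title: 'Doc A', ... }, { title: 'Doc B', ... }]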