Skip to content

Commit

Permalink
feat: rerank entity and relationship (#137)
Browse files Browse the repository at this point in the history
* feat: rerank entity and relationship

* tracing more detail

* rename

* make kg reranker configurable
  • Loading branch information
Mini256 authored May 17, 2024
1 parent 9a0d506 commit e680846
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 72 deletions.
3 changes: 3 additions & 0 deletions src/components/chat/message-annotation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ export function MessageAnnotation ({ group }: { group: ConversationMessageGroupP
case 'CREATING':
text = assistantAnnotation.stateMessage || 'Preparing chat...';
break;
case 'KG_RETRIEVING':
text = assistantAnnotation.stateMessage || 'Retrieving knowledge...';
break;
case 'SEARCHING':
text = assistantAnnotation.stateMessage || 'Searching...';
break;
Expand Down
2 changes: 2 additions & 0 deletions src/core/schema/chat_engines/condense_question/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ export const RetrieverOptionsSchema = z.object({

// Options for the knowledge graph retrieval step of a chat engine.
// Every field is optional.
export const GraphRetrieverOptionsSchema = z.object({
// Coerced to boolean. The chat service only runs knowledge graph retrieval when this is truthy.
enable: z.coerce.boolean().optional(),
// Reranker settings. The chat service reranks knowledge graph results only when `reranker.provider` is set.
reranker: RerankerConfigSchema.optional(),
// Coerced to an integer.
// NOTE(review): the visible chat service call passes `retrieverConfig.top_k` (the vector retriever's value)
// to the knowledge graph reranker instead of this field — confirm which one is intended.
top_k: z.coerce.number().int().optional()
});

export const BaseChatEngineOptionsSchema = z.object({
Expand Down
2 changes: 0 additions & 2 deletions src/core/services/knowledge-graph/indexing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import {KnowledgeGraphClient} from "@/lib/knowledge-graph/client";
import {fromFlowReaders} from "@/lib/llamaindex/builders/reader";
import {baseRegistry} from "@/rag-spec/base";
import {getFlow} from "@/rag-spec/createFlow";
import {SentenceSplitter, SimpleNodeParser} from "llamaindex";


export class KnowledgeGraphIndexProvider extends DocumentIndexProvider {
async process (task: DocumentIndexTask, document: Document, index: Index, mutableInfo: DocumentIndexTaskInfo): Promise<DocumentIndexTaskResult> {
Expand Down
250 changes: 215 additions & 35 deletions src/core/services/llamaindex/chating.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ import {AppChatService, type ChatOptions, type ChatStreamEvent} from '@/core/ser
import {LlamaindexRetrieverWrapper, LlamaindexRetrieveService} from '@/core/services/llamaindex/retrieving';
import type {RetrieveOptions} from "@/core/services/retrieving";
import {type AppChatStreamSource, AppChatStreamState} from '@/lib/ai/AppChatStream';
import {KnowledgeGraphClient} from "@/lib/knowledge-graph/client";
import {DocumentChunk, Entity, KnowledgeGraphClient, Relationship, SearchResult} from "@/lib/knowledge-graph/client";
import {uuidToBin} from '@/lib/kysely';
import {buildEmbedding} from '@/lib/llamaindex/builders/embedding';
import {buildLLM} from "@/lib/llamaindex/builders/llm";
import {buildReranker} from "@/lib/llamaindex/builders/reranker";
import {LLMConfig, LLMProvider} from "@/lib/llamaindex/config/llm";
import {RerankerProvider} from "@/lib/llamaindex/config/reranker";
import {ManagedAsyncIterable} from '@/lib/ManagedAsyncIterable';
import {LangfuseTraceClient} from "langfuse";
import {Liquid} from 'liquidjs';
import {
CompactAndRefine,
Expand All @@ -24,15 +27,26 @@ import {
RetrieverQueryEngine,
serviceContextFromDefaults,
SimplePrompt,
PromptHelper
PromptHelper, ServiceContext, TextNode
} from 'llamaindex';
import {DateTime} from 'luxon';
import {randomUUID, UUID} from 'node:crypto';
import {UUID} from 'node:crypto';

/**
 * A chat stream source that additionally carries an `id`.
 * NOTE(review): presumably the id of the document chunk node the source was resolved from — confirm
 * against `getSourcesByChunkIds`.
 */
interface SourceWithNodeId extends AppChatStreamSource {
id: string;
}

/**
 * Relationships and entities from a knowledge graph search, grouped by the document they come from.
 */
interface DocumentRelationship {
// Grouping key: the relationship's `meta.doc_id`, or the literal 'default' when that is missing.
docId: string;
// Relationships attributed to this document.
relationships: Relationship[];
// Source/target entities referenced by those relationships.
entities: Entity[];
// Only populated after reranking; absent on freshly grouped results.
relevance_score?: number;
}

/**
 * A knowledge graph search result extended with the per-document grouping of its
 * relationships and entities.
 */
interface KGRetrievalResult extends SearchResult {
// Filled in after the search by grouping (and optionally reranking) the result's relationships.
document_relationships?: DocumentRelationship[];
}

const DEFAULT_CHAT_ENGINE_OPTIONS = {
index_id: 0,
llm: {
Expand All @@ -52,11 +66,11 @@ const DEFAULT_CHAT_ENGINE_OPTIONS = {
export class LlamaindexChatService extends AppChatService {
private liquid = new Liquid();

getPrompt<Tmpl extends SimplePrompt> (template: string | undefined, fallback: Tmpl, partialContext: Record<string, any> = {}): (ctx: Parameters<Tmpl>[0]) => string {
getPrompt<Tmpl extends SimplePrompt> (template: string | undefined, fallback: Tmpl, partialContext?: Record<string, any>): (ctx: Parameters<Tmpl>[0]) => string {
if (!template) return fallback;
const tmpl = this.liquid.parse(template);
return context => this.liquid.renderSync(tmpl, {
...partialContext,
...partialContext ?? {},
...context
});
}
Expand All @@ -79,8 +93,16 @@ export class LlamaindexChatService extends AppChatService {
// Init tracing.
const trace = this.langfuse?.trace({
name: 'chatting',
input: options,
metadata: engineOptions,
input: {
history: options.history,
userInput: options.userInput
},
metadata: {
chat_id: chat.id,
chat_slug: chat.url_key,
chat_engine_type: chat.engine,
chat_engine_options: engineOptions,
},
});

yield {
Expand Down Expand Up @@ -113,48 +135,74 @@ export class LlamaindexChatService extends AppChatService {
// Document sources.
const allSources = new Map<string, AppChatStreamSource>();

// Knowledge graph retriever.
let additionalContext: Record<string, any> = {};
// Build knowledge graph based retriever.
// TODO: refactor this part to KnowledgeGraphRetrieveService in the services/knowledge-graph/retrieving.ts
let kgContext: Record<string, any> | undefined;
if (graphRetrieverConfig?.enable) {
console.log('[KG-Retrieving] Start knowledge graph retrieving ...');

yield {
status: AppChatStreamState.SEARCHING,
status: AppChatStreamState.KG_RETRIEVING,
traceURL: trace?.getTraceUrl(),
sources: Array.from(allSources.values()),
statusMessage: 'Start graph RAG searching ...',
statusMessage: 'Start knowledge graph searching ...',
content: '',
};

const kgClient = new KnowledgeGraphClient();
const kgSearchSpan = trace?.span({
name: "knowledge-graph-retrieval",
input: {
userInput: options.userInput,
},
});
additionalContext = await kgClient.search(options.userInput);
kgSearchSpan?.end({
output: additionalContext,
const kgRetrievalSpan = trace?.span({
name: 'knowledge-graph-retrieval',
input: options.userInput,
metadata: graphRetrieverConfig
});

// Found the document name by link.
const sourceLinks = (additionalContext['chunks'] ?? []).map((chunk: any) => chunk.link);
const documents = await getDocumentsBySourceUris(sourceLinks);
const documentMap = new Map(documents.map(document => [document.source_uri, document]));
for (let sourceLink of sourceLinks) {
const document = documentMap.get(sourceLink);
allSources.set(randomUUID(), { title: document?.name || 'Document from Graph RAG', uri: sourceLink });
// Knowledge graph searching.
const result: KGRetrievalResult = await this.searchKnowledgeGraph(kgClient, options.userInput, kgRetrievalSpan);

// Grouping entities and relationships.
result.document_relationships = await this.groupDocumentRelationships(result.relationships, result.entities);

if (graphRetrieverConfig.reranker?.provider) {
// Knowledge graph reranking.
result.document_relationships = await this.rerankDocumentRelationships(
result.document_relationships,
options.userInput,
retrieverConfig.top_k,
serviceContext,
kgRetrievalSpan
);

// Flatten relationships and entities.
result.relationships = result.document_relationships.map(dr => dr.relationships).flat();
result.entities = result.document_relationships.map(dr => dr.entities).flat();
}

kgContext = result;
// Notice: Limit to avoid exceeding the maximum size of the trace.
kgRetrievalSpan?.end({
output: {
entities: result.entities,
relationships: result.relationships,
chunks: result.chunks
}
});

yield {
status: AppChatStreamState.SEARCHING,
status: AppChatStreamState.KG_RETRIEVING,
traceURL: trace?.getTraceUrl(),
sources: Array.from(allSources.values()),
statusMessage: 'Graph RAG searching completed.',
statusMessage: 'Knowledge graph retrieving completed.',
content: '',
};

console.log('[KG-Retrieving] Finish knowledge graph retrieving.');

// Append sources from knowledge graph retrieving.
const links = result.chunks.map((chunk: DocumentChunk) => chunk.link);
await this.appendSourceByLinks(allSources, links);
}

// Build Retriever.
// Build vector search based retriever.
// FIXME: This method only support a single retrieve call currently.
const retrieveService = new LlamaindexRetrieveService({
flow: this.flow,
Expand Down Expand Up @@ -199,7 +247,7 @@ export class LlamaindexChatService extends AppChatService {
onRetrieved: async (id, chunks) => {
try {
const chunkIds = chunks.map(chunk => chunk.document_chunk_node_id);
await this.appendSource(allSources, chunkIds);
await this.appendSourceByChunkIds(allSources, chunkIds);
next({
done: false,
value: {
Expand Down Expand Up @@ -227,8 +275,8 @@ export class LlamaindexChatService extends AppChatService {

// Build Query Engine.
const { textQa, refine } = prompts;
const textQaPrompt = this.getPrompt(textQa, defaultTextQaPrompt, additionalContext);
const refinePrompt = this.getPrompt(refine, defaultRefinePrompt, additionalContext);
const textQaPrompt = this.getPrompt(textQa, defaultTextQaPrompt, kgContext);
const refinePrompt = this.getPrompt(refine, defaultRefinePrompt, kgContext);
const responseBuilder = new CompactAndRefine(serviceContext, textQaPrompt, refinePrompt);
const responseSynthesizer = new ResponseSynthesizer({
serviceContext,
Expand All @@ -243,7 +291,7 @@ export class LlamaindexChatService extends AppChatService {
content: message.content,
additionalKwargs: {},
}));
const condenseMessagePrompt = this.getPrompt(prompts?.condenseQuestion, defaultCondenseQuestionPrompt, additionalContext);
const condenseMessagePrompt = this.getPrompt(prompts?.condenseQuestion, defaultCondenseQuestionPrompt, kgContext);
const chatEngine = new CondenseQuestionChatEngine({
serviceContext,
queryEngine,
Expand Down Expand Up @@ -297,6 +345,123 @@ export class LlamaindexChatService extends AppChatService {
await this.langfuse?.flushAsync();
}

/**
 * Queries the knowledge graph and records the call as a "knowledge-graph-search" span.
 *
 * @param kgClient - Client used to run the search.
 * @param query - Text to search the knowledge graph for.
 * @param trace - Optional Langfuse client; when present, a child span captures the query as
 *   input and the raw search result as output.
 * @returns The search result exactly as returned by `kgClient.search`.
 */
async searchKnowledgeGraph (kgClient: KnowledgeGraphClient, query: string, trace?: LangfuseTraceClient): Promise<SearchResult> {
  console.log(`[KG-Retrieving] Start knowledge graph searching for query "${query}".`);
  const span = trace?.span({ name: "knowledge-graph-search", input: query });

  // Time only the client call itself.
  const startedAt = DateTime.now();
  // NOTE(review): the meaning of the `[]` and `true` arguments is defined by
  // KnowledgeGraphClient.search, which is not visible here.
  const found = await kgClient.search(query, [], true);
  const elapsedMs = DateTime.now().diff(startedAt, 'milliseconds').milliseconds;

  span?.end({ output: found });
  console.log(`[KG-Retrieving] Finish knowledge graph searching, take ${elapsedMs} ms.`);
  return found;
}

/**
 * Buckets relationships (and the entities they reference) by source document.
 *
 * The bucket key is the relationship's `meta.doc_id`; relationships without one land in a
 * bucket named 'default'. Within a bucket, relationships and entities are de-duplicated by id.
 * Entities that no relationship references are not included in any bucket.
 *
 * @param relationships - Relationships to group; defaults to an empty list.
 * @param entities - Entities used to resolve each relationship's source/target ids; defaults to an empty list.
 * @returns One `{ docId, relationships, entities }` object per document, in first-seen order.
 */
async groupDocumentRelationships (
  relationships: Relationship[] = [],
  entities: Entity[] = []
) {
  const entitiesById = new Map(entities.map(entity => [entity.id, entity]));
  const groups = new Map<string, {
    docId: string;
    relationships: Map<number, Relationship>;
    entities: Map<number, Entity>;
  }>();

  for (const relationship of relationships) {
    const docId = relationship?.meta?.doc_id || 'default';

    // Get-or-create the bucket for this document.
    let group = groups.get(docId);
    if (!group) {
      group = {
        docId,
        relationships: new Map<number, Relationship>(),
        entities: new Map<number, Entity>()
      };
      groups.set(docId, group);
    }

    group.relationships.set(relationship.id, relationship);

    // Attach both endpoints of the relationship (source first, then target) when they are known.
    for (const entityId of [relationship.source_entity_id, relationship.target_entity_id]) {
      const entity = entitiesById.get(entityId);
      if (entity) {
        group.entities.set(entityId, entity);
      }
    }
  }

  // Convert the id-keyed maps back to plain arrays.
  return Array.from(groups.values(), group => ({
    docId: group.docId,
    relationships: Array.from(group.relationships.values()),
    entities: Array.from(group.entities.values())
  }));
}

/**
 * Reranks per-document relationship groups against the user query.
 *
 * Each group is rendered to a text node (document id, relationship descriptions, entity
 * names/descriptions), passed through the reranker, and returned in the reranker's order with a
 * `relevance_score` attached.
 *
 * @param documentRelationships - Groups to rerank. An empty list returns `[]` without building a reranker.
 * @param query - The user query the groups are scored against.
 * @param topK - Passed to `buildReranker`; also used in the fallback score. Defaults to 10.
 * @param serviceContext - Service context handed to `buildReranker`.
 * @param trace - Optional Langfuse client; when given, a "knowledge-graph-rerank" span is recorded.
 * @param rerankerConfig - Reranker configuration handed to `buildReranker`. Defaults to the JinaAI
 *   provider, which is what this method always used before the parameter existed.
 * @returns The groups the reranker kept, each with `relevance_score` set.
 * @throws Rethrows any error raised while reranking, after closing the span as failed.
 */
async rerankDocumentRelationships (
  documentRelationships: DocumentRelationship[],
  query: string,
  topK: number = 10,
  serviceContext: ServiceContext,
  trace?: LangfuseTraceClient,
  rerankerConfig: Parameters<typeof buildReranker>[1] = { provider: RerankerProvider.JINAAI }
): Promise<DocumentRelationship[]> {
  console.log(`[KG-Retrieving] Start knowledge graph reranking for query "${query}".`, { documentRelationship: documentRelationships.length, topK: topK });

  // Nothing to rank: avoid building a reranker and invoking it with an empty node list.
  if (documentRelationships.length === 0) {
    console.log('[KG-Retrieving] Skip knowledge graph reranking, no document relationships to rerank.');
    return [];
  }

  // Build reranker.
  const reranker = await buildReranker(serviceContext, rerankerConfig, topK);

  // Transform document relationships to TextNode.
  const groupsByDocId = new Map(documentRelationships.map(dr => [dr.docId, dr]));
  const nodes = documentRelationships.map(dr => ({
    node: new TextNode({
      id_: dr.docId,
      // The leading/trailing empty entries reproduce the blank first and last lines of the prompt text.
      text: [
        '',
        `Document URL: ${dr.docId}`,
        'Relationships extract from the document:',
        dr.relationships.map(rel => rel.description).join('\n'),
        'Entities extract from the document:',
        dr.entities.map(ent => `- ${ent.name}: ${ent.description}`).join('\n'),
        ''
      ].join('\n'),
    })
  }));

  const rerankSpan = trace?.span({
    name: 'knowledge-graph-rerank',
    input: {
      query,
      document_relationships_size: documentRelationships.length
    }
  });

  // Reranking.
  const start = DateTime.now();
  try {
    const nodesWithScore = await reranker.postprocessNodes(nodes, query);
    const result = nodesWithScore.flatMap((nodeWithScore, index): DocumentRelationship[] => {
      const group = groupsByDocId.get(nodeWithScore.node.id_);
      if (!group) {
        // A node id we did not send cannot be mapped back to a group; drop it instead of
        // emitting an entry that lacks docId/relationships/entities.
        console.warn(`[KG-Retrieving] Reranker returned unknown node id "${nodeWithScore.node.id_}", ignored.`);
        return [];
      }
      return [{
        ...group,
        // When the reranker gives no score, synthesize one that still decreases with rank.
        relevance_score: nodeWithScore.score ?? (nodesWithScore.length - index + topK * 10),
      }];
    });
    const duration = DateTime.now().diff(start, 'milliseconds').milliseconds;

    rerankSpan?.end({
      output: result
    });

    console.log(`[KG-Retrieving] Finish knowledge graph reranking, take ${duration} ms.`);
    return result;
  } catch (error: unknown) {
    // Close the span so a failed rerank does not leave an open observation in the trace.
    rerankSpan?.end({
      level: 'ERROR',
      statusMessage: error instanceof Error ? error.message : String(error)
    });
    throw error;
  }
}

async getSourcesByChunkIds (chunkIds: string[]): Promise<SourceWithNodeId[]> {
if (chunkIds.length === 0) {
return [];
Expand All @@ -317,7 +482,22 @@ export class LlamaindexChatService extends AppChatService {
}));
}

async appendSource (sources: Map<string, AppChatStreamSource>, chunkIds?: string[]) {
/**
 * Adds one source entry per link to `sources`, keyed by the link itself.
 *
 * The title is the name of the document whose `source_uri` equals the link, or
 * 'Document from Graph RAG' when no such document is found.
 *
 * @param sources - Map to extend in place.
 * @param links - Source URIs to add; duplicates collapse to a single entry.
 * @returns The same `sources` map that was passed in.
 */
async appendSourceByLinks(sources: Map<string, AppChatStreamSource>, links: string[]) {
  // Same empty-input guard as getSourcesByChunkIds / appendSourceByChunkIds: do not look up
  // documents for an empty list.
  if (!Array.isArray(links) || links.length === 0) {
    return sources;
  }

  // Entries are keyed by link, so duplicates would only repeat the lookup and the write.
  const uniqueLinks = Array.from(new Set(links));
  const documents = await getDocumentsBySourceUris(uniqueLinks);
  const linkDocumentMap = new Map(documents.map(document => [document.source_uri, document]));

  for (const link of uniqueLinks) {
    const document = linkDocumentMap.get(link);
    sources.set(link, {
      title: document?.name || 'Document from Graph RAG',
      uri: link
    });
  }

  return sources;
}

async appendSourceByChunkIds (sources: Map<string, AppChatStreamSource>, chunkIds?: string[]) {
if (!Array.isArray(chunkIds) || chunkIds.length === 0) {
return;
}
Expand Down
Loading

0 comments on commit e680846

Please sign in to comment.