diff --git a/src/core/schema/chat_engines/condense_question/index.ts b/src/core/schema/chat_engines/condense_question/index.ts index 9dd670cc..26361562 100644 --- a/src/core/schema/chat_engines/condense_question/index.ts +++ b/src/core/schema/chat_engines/condense_question/index.ts @@ -8,10 +8,19 @@ export const RetrieverOptionsSchema = z.object({ top_k: z.coerce.number().int().optional() }); +export const GraphRetrieverSearchOptionsSchema = z.object({ + with_degree: z.coerce.boolean().optional(), + depth: z.coerce.number().int().optional(), + include_meta: z.coerce.boolean().optional(), +}); + +export type GraphRetrieverSearchOptions = z.infer; + export const GraphRetrieverOptionsSchema = z.object({ enable: z.coerce.boolean().optional(), reranker: RerankerConfigSchema.optional(), - top_k: z.coerce.number().int().optional() + top_k: z.coerce.number().int().optional(), + search: GraphRetrieverSearchOptionsSchema.optional(), }); export const BaseChatEngineOptionsSchema = z.object({ diff --git a/src/core/services/llamaindex/chating.ts b/src/core/services/llamaindex/chating.ts index 629a134f..9d4cb6b6 100644 --- a/src/core/services/llamaindex/chating.ts +++ b/src/core/services/llamaindex/chating.ts @@ -2,12 +2,20 @@ import {getDb} from '@/core/db'; import {type Chat, listChatMessages, updateChatMessage} from '@/core/repositories/chat'; import {ChatEngineRequiredOptions} from '@/core/repositories/chat_engine'; import {getDocumentsBySourceUris} from "@/core/repositories/document"; +import {GraphRetrieverSearchOptions} from "@/core/schema/chat_engines/condense_question"; import {AppChatService, type ChatOptions, type ChatStreamEvent} from '@/core/services/chating'; import {LlamaindexRetrieverWrapper, LlamaindexRetrieveService} from '@/core/services/llamaindex/retrieving'; import type {RetrieveOptions} from "@/core/services/retrieving"; import {type AppChatStreamSource, AppChatStreamState} from '@/lib/ai/AppChatStream'; import { deduplicateItems } from '@/lib/array-filters'; -import {DocumentChunk, Entity, KnowledgeGraphClient, Relationship, SearchResult} from "@/lib/knowledge-graph/client"; +import { + DocumentChunk, + Entity, + KnowledgeGraphClient, + Relationship, + SearchOptions, + SearchResult +} from "@/lib/knowledge-graph/client"; import {uuidToBin} from '@/lib/kysely'; import {buildEmbedding} from '@/lib/llamaindex/builders/embedding'; import {buildLLM} from "@/lib/llamaindex/builders/llm"; @@ -59,7 +67,11 @@ const DEFAULT_CHAT_ENGINE_OPTIONS: ChatEngineRequiredOptions = { top_k: 5, } as RetrieveOptions, graph_retriever: { - enable: false + enable: false, + search: { + with_degree: false, + depth: 1, + } }, prompts: {}, reverse_context: true, @@ -160,11 +172,15 @@ export class LlamaindexChatService extends AppChatService { }); // Knowledge graph searching. - const result: KGRetrievalResult = await this.searchKnowledgeGraph(kgClient, options.userInput, kgRetrievalSpan); + const result: KGRetrievalResult = await this.searchKnowledgeGraph( + kgClient, + options.userInput, + graphRetrieverConfig.search, + kgRetrievalSpan + ); // Grouping entities and relationships. result.document_relationships = await this.groupDocumentRelationships(result.relationships, result.entities); - if (graphRetrieverConfig.reranker?.provider) { // Knowledge graph reranking. result.document_relationships = await this.rerankDocumentRelationships( @@ -358,15 +374,25 @@ export class LlamaindexChatService extends AppChatService { await this.langfuse?.flushAsync(); } - async searchKnowledgeGraph (kgClient: KnowledgeGraphClient, query: string, trace?: LangfuseTraceClient): Promise { + async searchKnowledgeGraph ( + kgClient: KnowledgeGraphClient, + query: string, + searchOptions: GraphRetrieverSearchOptions = {}, + trace?: LangfuseTraceClient + ): Promise { console.log(`[KG-Retrieving] Start knowledge graph searching for query "${query}".`); const kgSearchSpan = trace?.span({ name: "knowledge-graph-search", input: query, + metadata: searchOptions }); const start = DateTime.now(); - const searchResult = await kgClient.search(query, [], true); + const searchResult = await kgClient.search({ + query, + embedding: [], + ...searchOptions + }); const duration = DateTime.now().diff(start, 'milliseconds').milliseconds; kgSearchSpan?.end({ diff --git a/src/lib/knowledge-graph/client.ts b/src/lib/knowledge-graph/client.ts index 0de9bae3..78353162 100644 --- a/src/lib/knowledge-graph/client.ts +++ b/src/lib/knowledge-graph/client.ts @@ -32,6 +32,14 @@ export interface DocumentInfo { text: string } +export interface SearchOptions { + query: string, + embedding?: number[] + include_meta?: boolean; + depth?: number; + with_degree?: boolean; +} + export class KnowledgeGraphClient { baseURL: string; @@ -43,18 +51,14 @@ export class KnowledgeGraphClient { this.baseURL = baseURL; } - async search(query: string, embedding?: number[], include_meta: boolean = false): Promise { + async search(options?: SearchOptions): Promise { const url = `${this.baseURL}/api/search`; const res = await fetch(url, { method: 'POST', headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify({ - query, - embedding, - include_meta - }) + body: JSON.stringify(options) }); if (!res.ok) { throw new Error(`Failed to call knowledge graph search API: ${res.statusText}`);