Skip to content

Commit

Permalink
feat: metadata-filter support custom llm (#148)
Browse files Browse the repository at this point in the history
* feat: metadata-filter support custom llm

* fix
  • Loading branch information
Mini256 committed Jun 2, 2024
1 parent 446afb1 commit 3fcb6c2
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/app/api/test/reranker/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export const POST = defineHandler({
body
}) => {
const { query, config, llmConfig, top_k } = body;
const llm = llmConfig ? await buildLLM(llmConfig) : undefined;
const llm = llmConfig ? buildLLM(llmConfig) : undefined;
const serviceContext = serviceContextFromDefaults({
llm: llm
})
Expand Down
2 changes: 1 addition & 1 deletion src/app/api/v1/indexes/[name]/retrieve/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export const POST = defineHandler({

const flow = await getFlow(baseRegistry);
const serviceContext = serviceContextFromDefaults({
llm: await buildLLM(llmConfig),
llm: buildLLM(llmConfig),
embedModel: await buildEmbedding(index.config.embedding),
});

Expand Down
2 changes: 1 addition & 1 deletion src/core/services/llamaindex/chating.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ export class LlamaindexChatService extends AppChatService {
});

// Service context.
const llm = await buildLLM(llmConfig!, trace);
const llm = buildLLM(llmConfig!, trace);
const promptHelper = new PromptHelper(llm.metadata.contextWindow);
const embedModel = await buildEmbedding(this.index.config.embedding);
const serviceContext = serviceContextFromDefaults({
Expand Down
2 changes: 1 addition & 1 deletion src/core/services/llamaindex/indexing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export class LlamaindexIndexProvider extends DocumentIndexProvider {
});

// Select and configure the LLM for indexing (metadata extractor).
const llm = await buildLLM(index.config.llm);
const llm = buildLLM(index.config.llm);
llm.metadata.model = index.config.llm.options?.model!;

// Select and configure the embedding (important and immutable)
Expand Down
2 changes: 1 addition & 1 deletion src/lib/llamaindex/builders/indices.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export async function createVectorStoreIndex (id: number) {
dimensions: DEFAULT_TIDB_VECTOR_DIMENSIONS,
}),
serviceContext: serviceContextFromDefaults({
llm: await buildLLM(index.config.llm),
llm: buildLLM(index.config.llm),
embedModel: await buildEmbedding(index.config.embedding),
}),
});
Expand Down
2 changes: 1 addition & 1 deletion src/lib/llamaindex/builders/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {LangfuseTraceClient} from "langfuse";
import {OpenAI, Ollama} from "llamaindex";
import {Bitdeer} from "@/lib/llamaindex/llm/bitdeer";

export async function buildLLM ({ provider, options}: LLMConfig, trace?: LangfuseTraceClient) {
export function buildLLM ({ provider, options}: LLMConfig, trace?: LangfuseTraceClient) {
let baseLLM;
switch (provider) {
case LLMProvider.OPENAI:
Expand Down
15 changes: 10 additions & 5 deletions src/lib/llamaindex/builders/metadata-filter.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import {buildLLM} from "@/lib/llamaindex/builders/llm";
import {MetadataFilterConfig} from "@/lib/llamaindex/config/metadata-filter";
import {
MetadataPostFilter
} from "@/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter";
import {ServiceContext} from 'llamaindex';

export function buildMetadataFilter (serviceContext: ServiceContext, { provider, options }: MetadataFilterConfig) {
switch (provider) {
export function buildMetadataFilter (serviceContext: ServiceContext, config: MetadataFilterConfig) {
switch (config.provider) {
case 'default':
let llm = serviceContext.llm;
if (config.options?.llm) {
llm = buildLLM(config.options.llm);
}
return new MetadataPostFilter({
...options,
serviceContext,
llm,

});
default:
throw new Error(`Unknown metadata filter provider: ${provider}`)
throw new Error(`Unknown metadata filter provider: ${config.provider}`)
}
}
2 changes: 2 additions & 0 deletions src/lib/llamaindex/config/metadata-filter.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import {LLMConfigSchema} from "@/lib/llamaindex/config/llm";
import {z} from "zod";

export const metadataFilterSchema = z.object({
Expand Down Expand Up @@ -25,6 +26,7 @@ export enum MetadataFilterProvider {
}

export const DefaultMetadataFilterOptions = z.object({
llm: LLMConfigSchema.optional(),
metadata_fields: z.array(metadataFieldSchema).optional(),
filters: z.array(metadataFilterSchema).optional()
});
Expand Down
40 changes: 24 additions & 16 deletions src/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {MetadataField, MetadataFieldFilter} from "@/lib/llamaindex/config/metadata-filter";
import {BaseNodePostprocessor, NodeWithScore, ServiceContext, serviceContextFromDefaults} from "llamaindex";
import {BaseLLM} from "llamaindex/llm/base";
import {DateTime} from "luxon";

export const defaultMetadataFilterChoicePrompt = ({metadataFields, query}: {
Expand Down Expand Up @@ -48,41 +49,49 @@ export type MetadataFilterChoicePrompt = typeof defaultMetadataFilterChoicePromp
export type MetadataPostFilterOptions = Partial<MetadataPostFilter>;

export class MetadataPostFilter implements BaseNodePostprocessor {
serviceContext: ServiceContext = serviceContextFromDefaults();
metadataFilterChoicePrompt: MetadataFilterChoicePrompt = defaultMetadataFilterChoicePrompt;

serviceContext: ServiceContext;
/**
* The LLM used to generate the metadata filters.
*/
llm: BaseLLM;
/**
* The prompt for metadata filter choice.
*/
metadataFilterChoicePrompt: MetadataFilterChoicePrompt;
/**
* The definition of metadata fields.
*/
metadata_fields: MetadataField[] = [];
metadata_fields: MetadataField[];
/**
* Provide the filters to apply to the search.
*/
filters: MetadataFieldFilter[] | null = null;
filters: MetadataFieldFilter[] | null;

constructor(init?: MetadataPostFilterOptions) {
Object.assign(this, init);
this.serviceContext = init?.serviceContext ?? serviceContextFromDefaults();
this.llm = init?.llm ?? this.serviceContext.llm;
this.metadata_fields = init?.metadata_fields ?? [];
this.filters = init?.filters ?? null;
this.metadataFilterChoicePrompt = init?.metadataFilterChoicePrompt || defaultMetadataFilterChoicePrompt;
}

async postprocessNodes(nodes: NodeWithScore[], query: string): Promise<NodeWithScore[]> {
let filters;
if (this.filters) {
filters = this.filters;
console.info('Apply provided filters:', filters);
console.info('[Metadata Filter] Provided filters: ', filters);
} else {
const start = DateTime.now();
filters = await this.generateFilters(query);
const end = DateTime.now();
console.info('Generate filters took:', end.diff(start).as('seconds'), 's');
console.info('Apply generated filters:', filters);
const duration = DateTime.now().diff(start).as('milliseconds')
console.info(`[Metadata Filter] Generate filters (took: ${duration} ms): `, filters);
}

console.log('Nodes before filter:', nodes.length, 'nodes');
let filteredNodes = await this.filterNodes(nodes, filters);
console.log('Nodes after filter:', filteredNodes.length, 'nodes');
const filteredNodes = await this.filterNodes(nodes, filters);
console.log(`[Metadata Filter] Applied provided/generated filter (before: ${nodes.length} nodes, after: ${filteredNodes.length} nodes).`);

if (filteredNodes.length === 0) {
console.warn('No nodes left after filtering, fallback to using all nodes.');
console.warn('[Metadata Filter] No nodes left after filtering, fallback to using all nodes.');
return nodes;
}

Expand All @@ -91,12 +100,11 @@ export class MetadataPostFilter implements BaseNodePostprocessor {

async generateFilters(query: string): Promise<MetadataFieldFilter[]> {
try {
const llm = this.serviceContext.llm;
const prompt = this.metadataFilterChoicePrompt({
metadataFields: this.metadata_fields,
query
});
const raw = await llm.chat({
const raw = await this.llm.chat({
messages: [
{
role: 'system',
Expand Down

0 comments on commit 3fcb6c2

Please sign in to comment.