Skip to content

Commit cccede2

Browse files
feat(py): add retriever reference support matching JS SDK (#3936)
Co-authored-by: Mengqin Shen <[email protected]>
1 parent 06a3907 commit cccede2

File tree

5 files changed

+476
-67
lines changed

5 files changed

+476
-67
lines changed

py/packages/genkit/src/genkit/ai/_aio.py

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class while customizing it with any plugins.
3737
ModelMiddleware,
3838
)
3939
from genkit.blocks.prompt import PromptConfig, to_generate_action_options
40+
from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef
4041
from genkit.core.action import ActionRunContext
4142
from genkit.core.action.types import ActionKind
4243
from genkit.core.typing import EmbedRequest, EmbedResponse
@@ -294,63 +295,75 @@ def generate_stream(
294295

295296
return stream, stream.closed
296297

297-
async def embed(
298+
async def retrieve(
298299
self,
299-
embedder: str | EmbedderRef | None = None,
300-
documents: list[Document] | None = None,
300+
retriever: str | RetrieverRef | None = None,
301+
query: str | DocumentData | None = None,
301302
options: dict[str, Any] | None = None,
302-
) -> EmbedResponse:
303-
embedder_name: str
304-
embedder_config: dict[str, Any] = {}
305-
"""Calculates embeddings for documents.
303+
) -> RetrieverResponse:
304+
"""Retrieves documents based on query.
306305
307306
Args:
308-
embedder: Optional embedder model name to use.
309-
documents: Texts to embed.
310-
options: embedding options
307+
retriever: Optional retriever name or reference to use.
308+
query: Text query or a DocumentData containing query text.
309+
options: retriever options
311310
312311
Returns:
313-
The generated response with embeddings.
312+
The generated response with documents.
314313
"""
315-
if isinstance(embedder, EmbedderRef):
316-
embedder_name = embedder.name
317-
embedder_config = embedder.config or {}
318-
if embedder.version:
319-
embedder_config['version'] = embedder.version # Handle version from ref
320-
elif isinstance(embedder, str):
321-
embedder_name = embedder
314+
retriever_name: str
315+
retriever_config: dict[str, Any] = {}
316+
317+
if isinstance(retriever, RetrieverRef):
318+
retriever_name = retriever.name
319+
retriever_config = retriever.config or {}
320+
if retriever.version:
321+
retriever_config['version'] = retriever.version
322+
elif isinstance(retriever, str):
323+
retriever_name = retriever
322324
else:
323-
# Handle case where embedder is None
324-
raise ValueError('Embedder must be specified as a string name or an EmbedderRef.')
325+
raise ValueError('Retriever must be specified as a string name or a RetrieverRef.')
325326

326-
# Merge options passed to embed() with config from EmbedderRef
327-
final_options = {**(embedder_config or {}), **(options or {})}
328-
embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)
327+
if isinstance(query, str):
328+
query = Document.from_text(query)
329329

330-
return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response
330+
final_options = {**(retriever_config or {}), **(options or {})}
331331

332-
async def retrieve(
332+
retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever_name)
333+
334+
return (await retrieve_action.arun(RetrieverRequest(query=query, options=final_options))).response
335+
336+
async def index(
333337
self,
334-
retriever: str | None = None,
335-
query: str | DocumentData | None = None,
338+
indexer: str | IndexerRef | None = None,
339+
documents: list[Document] | None = None,
336340
options: dict[str, Any] | None = None,
337-
) -> RetrieverResponse:
338-
"""Retrieves documents based on query.
341+
) -> None:
342+
"""Indexes documents.
339343
340344
Args:
341-
retriever: Optional retriever name to use.
342-
query: Text query or a DocumentData containing query text.
343-
options: retriever options
344-
345-
Returns:
346-
The generated response with embeddings.
345+
indexer: Optional indexer name or reference to use.
346+
documents: Documents to index.
347+
options: indexer options
347348
"""
348-
if isinstance(query, str):
349-
query = Document.from_text(query)
349+
indexer_name: str
350+
indexer_config: dict[str, Any] = {}
351+
352+
if isinstance(indexer, IndexerRef):
353+
indexer_name = indexer.name
354+
indexer_config = indexer.config or {}
355+
if indexer.version:
356+
indexer_config['version'] = indexer.version
357+
elif isinstance(indexer, str):
358+
indexer_name = indexer
359+
else:
360+
raise ValueError('Indexer must be specified as a string name or an IndexerRef.')
361+
362+
final_options = {**(indexer_config or {}), **(options or {})}
350363

351-
retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever)
364+
index_action = self.registry.lookup_action(ActionKind.INDEXER, indexer_name)
352365

353-
return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response
366+
await index_action.arun(IndexerRequest(documents=documents, options=final_options))
354367

355368
async def embed(
356369
self,

py/packages/genkit/src/genkit/ai/_registry.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from genkit.blocks.formats.types import FormatDef
5353
from genkit.blocks.model import ModelFn, ModelMiddleware
5454
from genkit.blocks.prompt import define_prompt
55-
from genkit.blocks.retriever import RetrieverFn
55+
from genkit.blocks.retriever import IndexerFn, RetrieverFn
5656
from genkit.blocks.tools import ToolRunContext
5757
from genkit.codec import dump_dict
5858
from genkit.core.action import Action
@@ -278,6 +278,40 @@ def define_retriever(
278278
description=retriever_description,
279279
)
280280

281+
def define_indexer(
282+
self,
283+
name: str,
284+
fn: IndexerFn,
285+
config_schema: BaseModel | dict[str, Any] | None = None,
286+
metadata: dict[str, Any] | None = None,
287+
description: str | None = None,
288+
) -> Callable[[Callable], Callable]:
289+
"""Define an indexer action.
290+
291+
Args:
292+
name: Name of the indexer.
293+
fn: Function implementing the indexer behavior.
294+
config_schema: Optional schema for indexer configuration.
295+
metadata: Optional metadata for the indexer.
296+
description: Optional description for the indexer.
297+
"""
298+
indexer_meta = metadata if metadata else {}
299+
if 'indexer' not in indexer_meta:
300+
indexer_meta['indexer'] = {}
301+
if 'label' not in indexer_meta['indexer'] or not indexer_meta['indexer']['label']:
302+
indexer_meta['indexer']['label'] = name
303+
if config_schema:
304+
indexer_meta['indexer']['customOptions'] = to_json_schema(config_schema)
305+
306+
indexer_description = get_func_description(fn, description)
307+
return self.registry.register_action(
308+
name=name,
309+
kind=ActionKind.INDEXER,
310+
fn=fn,
311+
metadata=indexer_meta,
312+
description=indexer_description,
313+
)
314+
281315
def define_evaluator(
282316
self,
283317
name: str,

py/packages/genkit/src/genkit/blocks/retriever.py

Lines changed: 205 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,15 @@
2323
"""
2424

2525
from collections.abc import Callable
26-
from typing import Generic, TypeVar
26+
from typing import Any, Generic, TypeVar
27+
28+
from pydantic import BaseModel, ConfigDict, Field
2729

2830
from genkit.blocks.document import Document
29-
from genkit.core.typing import RetrieverResponse
31+
from genkit.core.action import ActionMetadata
32+
from genkit.core.action.types import ActionKind
33+
from genkit.core.schema import to_json_schema
34+
from genkit.core.typing import DocumentData, RetrieverResponse
3035

3136
T = TypeVar('T')
3237
# type RetrieverFn[T] = Callable[[Document, T], RetrieverResponse]
@@ -39,3 +44,201 @@ def __init__(
3944
retriever_fn: RetrieverFn[T],
4045
):
4146
self.retriever_fn = retriever_fn
47+
48+
49+
class RetrieverRequest(BaseModel):
50+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
51+
52+
query: DocumentData
53+
options: Any | None = None
54+
55+
56+
class RetrieverSupports(BaseModel):
57+
"""Retriever capability support."""
58+
59+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
60+
61+
media: bool | None = None
62+
63+
64+
class RetrieverInfo(BaseModel):
65+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
66+
67+
label: str | None = None
68+
supports: RetrieverSupports | None = None
69+
70+
71+
class RetrieverOptions(BaseModel):
72+
"""Configuration options for a retriever."""
73+
74+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
75+
76+
config_schema: dict[str, Any] | None = Field(None, alias='configSchema')
77+
label: str | None = None
78+
supports: RetrieverSupports | None = None
79+
80+
81+
class RetrieverRef(BaseModel):
82+
"""Reference to a retriever with configuration."""
83+
84+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
85+
86+
name: str
87+
config: Any | None = None
88+
version: str | None = None
89+
info: RetrieverInfo | None = None
90+
91+
92+
def retriever_action_metadata(
93+
name: str,
94+
options: RetrieverOptions | None = None,
95+
) -> ActionMetadata:
96+
"""Creates action metadata for a retriever."""
97+
options = options if options is not None else RetrieverOptions()
98+
retriever_metadata_dict = {'retriever': {}}
99+
100+
if options.label:
101+
retriever_metadata_dict['retriever']['label'] = options.label
102+
103+
if options.supports:
104+
retriever_metadata_dict['retriever']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
105+
106+
retriever_metadata_dict['retriever']['customOptions'] = options.config_schema if options.config_schema else None
107+
108+
return ActionMetadata(
109+
kind=ActionKind.RETRIEVER,
110+
name=name,
111+
input_json_schema=to_json_schema(RetrieverRequest),
112+
output_json_schema=to_json_schema(RetrieverResponse),
113+
metadata=retriever_metadata_dict,
114+
)
115+
116+
117+
def create_retriever_ref(
118+
name: str,
119+
config: dict[str, Any] | None = None,
120+
version: str | None = None,
121+
info: RetrieverInfo | None = None,
122+
) -> RetrieverRef:
123+
"""Creates a RetrieverRef instance."""
124+
return RetrieverRef(name=name, config=config, version=version, info=info)
125+
126+
127+
class IndexerRequest(BaseModel):
128+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
129+
130+
documents: list[DocumentData]
131+
options: Any | None = None
132+
133+
134+
class IndexerInfo(BaseModel):
135+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
136+
137+
label: str | None = None
138+
supports: RetrieverSupports | None = None
139+
140+
141+
class IndexerOptions(BaseModel):
142+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
143+
144+
config_schema: dict[str, Any] | None = Field(None, alias='configSchema')
145+
label: str | None = None
146+
supports: RetrieverSupports | None = None
147+
148+
149+
class IndexerRef(BaseModel):
150+
"""Reference to an indexer with configuration."""
151+
152+
model_config = ConfigDict(extra='forbid', populate_by_name=True)
153+
154+
name: str
155+
config: Any | None = None
156+
version: str | None = None
157+
info: IndexerInfo | None = None
158+
159+
160+
def indexer_action_metadata(
161+
name: str,
162+
options: IndexerOptions | None = None,
163+
) -> ActionMetadata:
164+
"""Creates action metadata for an indexer."""
165+
options = options if options is not None else IndexerOptions()
166+
indexer_metadata_dict = {'indexer': {}}
167+
168+
if options.label:
169+
indexer_metadata_dict['indexer']['label'] = options.label
170+
171+
if options.supports:
172+
indexer_metadata_dict['indexer']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
173+
174+
indexer_metadata_dict['indexer']['customOptions'] = options.config_schema if options.config_schema else None
175+
176+
return ActionMetadata(
177+
kind=ActionKind.INDEXER,
178+
name=name,
179+
input_json_schema=to_json_schema(IndexerRequest),
180+
output_json_schema=to_json_schema(None),
181+
metadata=indexer_metadata_dict,
182+
)
183+
184+
185+
def create_indexer_ref(
186+
name: str,
187+
config: dict[str, Any] | None = None,
188+
version: str | None = None,
189+
info: IndexerInfo | None = None,
190+
) -> IndexerRef:
191+
"""Creates a IndexerRef instance."""
192+
return IndexerRef(name=name, config=config, version=version, info=info)
193+
194+
195+
def define_retriever(
196+
registry: Any,
197+
name: str,
198+
fn: RetrieverFn,
199+
options: RetrieverOptions | None = None,
200+
) -> None:
201+
"""Defines and registers a retriever action."""
202+
metadata = retriever_action_metadata(name, options)
203+
204+
async def wrapper(
205+
request: RetrieverRequest,
206+
ctx: Any,
207+
) -> RetrieverResponse:
208+
return await fn(request.query, request.options)
209+
210+
registry.register_action(
211+
kind=ActionKind.RETRIEVER,
212+
name=name,
213+
fn=wrapper,
214+
metadata=metadata.metadata,
215+
span_metadata=metadata.metadata,
216+
)
217+
218+
219+
IndexerFn = Callable[[list[Document], T], None]
220+
221+
222+
def define_indexer(
223+
registry: Any,
224+
name: str,
225+
fn: IndexerFn,
226+
options: IndexerOptions | None = None,
227+
) -> None:
228+
"""Defines and registers an indexer action."""
229+
metadata = indexer_action_metadata(name, options)
230+
231+
async def wrapper(
232+
request: IndexerRequest,
233+
ctx: Any,
234+
) -> None:
235+
docs = [Document.from_data(d) for d in request.documents]
236+
await fn(docs, request.options)
237+
238+
registry.register_action(
239+
kind=ActionKind.INDEXER,
240+
name=name,
241+
fn=wrapper,
242+
metadata=metadata.metadata,
243+
span_metadata=metadata.metadata,
244+
)

0 commit comments

Comments
 (0)