-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
5_advanced_retrieval.py
60 lines (48 loc) · 2.01 KB
/
5_advanced_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Example from Advanced Usage docs.
This example shows how kani's function calling functionality can be used to retrieve information from an external
database, e.g. Wikipedia.
"""
import json
import os
from typing import Annotated
import httpx
from kani import AIParam, Kani, ai_function, chat_in_terminal
from kani.engines.openai import OpenAIEngine
# The OpenAI credential comes from the environment; the lookup yields None when
# OPENAI_API_KEY is not set.
api_key = os.environ.get("OPENAI_API_KEY")
# Single engine instance, created at import time and handed to the kani below.
engine = OpenAIEngine(api_key, model="gpt-4o-mini")
class WikipediaRetrievalKani(Kani):
    """A kani that can look up and search English Wikipedia articles.

    Two AI functions are exposed: ``wikipedia`` returns the plain-text extract
    of an article, and ``search`` returns article titles matching a query.
    Both go through one shared ``httpx.AsyncClient`` pointed at the MediaWiki
    API endpoint.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Kani`` and create the shared HTTP client."""
        super().__init__(*args, **kwargs)
        self.wikipedia_client = httpx.AsyncClient(
            base_url="https://en.wikipedia.org/w/api.php",
            follow_redirects=True,
        )

    # NOTE: the docstrings of the @ai_function methods below are part of the
    # tool description kani presents to the model, so they are kept verbatim;
    # implementation notes live in comments instead.

    @ai_function()
    async def wikipedia(
        self,
        title: Annotated[str, AIParam(desc='The article title on Wikipedia, e.g. "Train_station".')],
    ):
        """Get additional information about a topic from Wikipedia."""
        # Returns the article's plain-text extract, or a "does not exist"
        # message when no extract is available for ``title``.
        # https://en.wikipedia.org/w/api.php?action=query&format=json&prop=extracts&titles=Train&explaintext=1&formatversion=2
        resp = await self.wikipedia_client.get(
            "/",
            params={
                "action": "query",
                "format": "json",
                "prop": "extracts",
                "titles": title,
                "explaintext": 1,
                "formatversion": 2,
            },
        )
        data = resp.json()
        # The reply may carry no "query"/"pages" at all (e.g. an empty title or
        # an API error object). Look the keys up defensively so that case falls
        # through to the "does not exist" message instead of raising
        # KeyError/IndexError.
        pages = data.get("query", {}).get("pages") or []
        page = pages[0] if pages else {}
        if extract := page.get("extract"):
            return extract
        return f"The page {title!r} does not exist on Wikipedia."

    @ai_function()
    async def search(self, query: str):
        """Find titles of Wikipedia articles similar to the given query."""
        # Returns a JSON-encoded list of matching article titles.
        # https://en.wikipedia.org/w/api.php?action=opensearch&format=json&search=Train
        resp = await self.wikipedia_client.get(
            "/",
            params={"action": "opensearch", "format": "json", "search": query},
        )
        # Element 1 of the opensearch reply holds the list of titles.
        return json.dumps(resp.json()[1])
# Build the kani at module level so it exists whether this file is imported or
# executed directly.
ai = WikipediaRetrievalKani(engine)

# Start the interactive terminal chat only when run as a script.
if __name__ == "__main__":
    chat_in_terminal(ai)