Skip to content

Commit

Permalink
Make chatnoir-chat more object oriented and suitable to run in TIRA.
Browse files Browse the repository at this point in the history
  • Loading branch information
mam10eks committed Aug 21, 2023
1 parent 93a6812 commit d247fb1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 36 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ To generate text with the ChatNoir Chat API you need to request an API key from
With your API key, you can chat with the cat, like this:

```python
from chatnoir_api.chat import chat
from chatnoir_api.chat import ChatNoirChatClient

api_key: str = "<API_KEY>"
response = chat(api_key, "how are you?")
chat_client = ChatNoirChatClient(api_key="<API_KEY>")
response = chat_client.chat("how are you?")
```

### Retrieve Document Contents
Expand Down
91 changes: 58 additions & 33 deletions chatnoir_api/chat.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,64 @@
from json import dumps
from typing_extensions import Literal
from json import dumps, load
from urllib.parse import urljoin

from requests import post
from pathlib import Path

from chatnoir_api.constants import BASE_URL_CHAT
import os


def default_from_tira_environment(key):
for potential_key in ['inputDataset', 'TIRA_INPUT_DATASET', 'CHATNOIR_CHAT_CONFIGURATION_DIR']:
if not os.environ.get(potential_key, None):
continue

for file_name in ['metadata.json', '.chatnoir-settings.json']:
try:
ret = load(open(Path(os.environ.get(potential_key)) / file_name))
return ret[key]
except:
pass


def default_from_environment(key, default=None):
default_from_tira_env = default_from_tira_environment(key)

if default_from_tira_env:
return default_from_tira_env

return os.environ.get(key, default)


class ChatNoirChatClient():
def __init__(self, api_key=default_from_environment('chatnoir_chat_api_key'), model=default_from_environment('chatnoir_chat_model', 'alpaca-en-7b'), endpoint=default_from_environment('chatnoir_chat_endpoint', BASE_URL_CHAT)):
self.api_key = api_key
self.model = model
self.endpoint = endpoint

if not self.api_key:
raise ValueError('Please provide a proper api_key, got: ' + str(self.api_key))

if not self.model:
raise ValueError('Please provide a proper model, got: ' + str(self.api_key))

if not self.endpoint:
raise ValueError('Please provide a proper model, got: ' + str(self.api_key))

def chat(self, input_sentence: str) -> str:
url = urljoin(BASE_URL_CHAT, f"generate/{self.model}")
data = dumps({"input_sentence": input_sentence})
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Api-Key": self.api_key
}

response = post(url, data=data, headers=headers)
response.raise_for_status()
response_json = response.json()

if "response" not in response_json:
raise ValueError(f"Invalid ChatNoir Chat response: {response_json}")
return response_json["response"]

ModelType = Literal[
"alpaca-en-7b",
"gpt2-base",
"gpt2-large",
"gpt2-xl",
"alpaca-en-7b-retrieve-clueweb22",
"alpaca-en-7b-prompt-retrieve-rewritten-clueweb22",
"gpt2-xl-rewrite-with-clueweb22",
]


def chat(
api_key: str,
input_sentence: str,
model: ModelType = "alpaca-en-7b"
) -> str:
url = urljoin(BASE_URL_CHAT, f"generate/{model}")
data = dumps({
"input_sentence": input_sentence,
})
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Api-Key": api_key
}
response = post(url, data=data, headers=headers)
response.raise_for_status()
response_json = response.json()
if "response" not in response_json:
raise ValueError(f"Invalid ChatNoir Chat response: {response_json}")
return response_json["response"]

0 comments on commit d247fb1

Please sign in to comment.