Skip to content

Commit

Permalink
Make chatnoir-chat more object oriented and suitable to run in TIRA.
Browse files Browse the repository at this point in the history
  • Loading branch information
mam10eks committed Aug 21, 2023
1 parent 906e560 commit 190d03f
Showing 1 changed file with 35 additions and 22 deletions.
57 changes: 35 additions & 22 deletions chatnoir_api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,55 @@


def default_from_tira_environment(key):
for potential_key in ['inputDataset', 'TIRA_INPUT_DATASET', 'CHATNOIR_CHAT_CONFIGURATION_DIR']:
ENVIRONMENT_LOOKUP_KEYS = ['inputDataset', 'TIRA_INPUT_DATASET',
'CHATNOIR_CHAT_CONFIGURATION_DIR']

for potential_key in ENVIRONMENT_LOOKUP_KEYS:
if not os.environ.get(potential_key, None):
continue

for file_name in ['metadata.json', '.chatnoir-settings.json']:
ret = load(open(Path(os.environ.get(potential_key)) / file_name))
return ret[key]
return ret[key], f'Configuration file "{file_name}"'


def default_from_environment(key, default=None):
def default_config(key, default=None):
default_from_tira_env = default_from_tira_environment(key)

if default_from_tira_env:
return default_from_tira_env

return os.environ.get(key, default)


class ChatNoirChatClient():
def __init__(self, api_key=default_from_environment('chatnoir_chat_api_key'), model=default_from_environment('chatnoir_chat_model', 'alpaca-en-7b'), endpoint=default_from_environment('chatnoir_chat_endpoint', BASE_URL_CHAT)):
self.api_key = api_key
self.model = model
self.endpoint = endpoint
return os.environ.get(key, default), "Environment variable"

if not self.api_key:
raise ValueError('Please provide an api_key, got: '
+ str(self.api_key))
default_api_key =

if not self.model:
raise ValueError('Please provide an model, got: '
+ str(self.model))

if not self.endpoint:
raise ValueError('Please provide an endpoint, got: '
+ str(self.endpoint))
class ChatNoirChatClient():
def __init__(self,
api_key=default_config('chatnoir_chat_api_key'),
model=default_config('chatnoir_chat_model', 'alpaca-en-7b'),
endpoint=default_config('chatnoir_chat_endpoint', BASE_URL_CHAT)):

if type(api_key) == tuple:
print(f"ChatNoir Chat uses API key from {api_key[1]}")
self.api_key = api_key[0]
else:
print(f"ChatNoir Chat uses API key from from parameters")
self.api_key = api_key

if type(self.model) == tuple:
print(f"ChatNoir Chat uses model '{model[0]}' from {self.model[1]}")
self.model = model[0]
else:
print(f"ChatNoir Chat uses model '{model}' from from parameters")
self.model = model

if type(self.endpoint) == tuple:
print(f"ChatNoir Chat uses endpoint '{endpoint[0]}' from {self.endpoint[1]}")
self.endpoint = endpoint[0]
else:
print(f"ChatNoir Chat uses model '{endpoint}' from from parameters")
self.endpoint = endpoint

def chat(self, input_sentence: str) -> str:
url = urljoin(BASE_URL_CHAT, f"generate/{self.model}")
Expand All @@ -53,12 +67,11 @@ def chat(self, input_sentence: str) -> str:
"Accept": "application/json",
"Api-Key": self.api_key
}

response = post(url, data=data, headers=headers)
response.raise_for_status()
response_json = response.json()

if "response" not in response_json:
raise ValueError(f"Invalid ChatNoir response: {response_json}")
return response_json["response"]

0 comments on commit 190d03f

Please sign in to comment.