From 906e560fe2031015022caf33b10e5e3e4cefbc88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maik=20Fr=C3=B6be?= Date: Mon, 21 Aug 2023 15:16:05 +0200 Subject: [PATCH] Make chatnoir-chat more object oriented and suitable to run in TIRA. --- chatnoir_api/chat.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/chatnoir_api/chat.py b/chatnoir_api/chat.py index 8335e94..f3e53c7 100644 --- a/chatnoir_api/chat.py +++ b/chatnoir_api/chat.py @@ -14,11 +14,8 @@ def default_from_tira_environment(key): continue for file_name in ['metadata.json', '.chatnoir-settings.json']: - try: - ret = load(open(Path(os.environ.get(potential_key)) / file_name)) - return ret[key] - except: - pass + ret = load(open(Path(os.environ.get(potential_key)) / file_name)) + return ret[key] def default_from_environment(key, default=None): @@ -35,15 +32,18 @@ def __init__(self, api_key=default_from_environment('chatnoir_chat_api_key'), mo self.api_key = api_key self.model = model self.endpoint = endpoint - + if not self.api_key: - raise ValueError('Please provide a proper api_key, got: ' + str(self.api_key)) + raise ValueError('Please provide an api_key, got: ' + + str(self.api_key)) if not self.model: - raise ValueError('Please provide a proper model, got: ' + str(self.api_key)) + raise ValueError('Please provide an model, got: ' + + str(self.model)) if not self.endpoint: - raise ValueError('Please provide a proper model, got: ' + str(self.api_key)) + raise ValueError('Please provide an endpoint, got: ' + + str(self.endpoint)) def chat(self, input_sentence: str) -> str: url = urljoin(BASE_URL_CHAT, f"generate/{self.model}") @@ -59,6 +59,6 @@ def chat(self, input_sentence: str) -> str: response_json = response.json() if "response" not in response_json: - raise ValueError(f"Invalid ChatNoir Chat response: {response_json}") + raise ValueError(f"Invalid ChatNoir response: {response_json}") return response_json["response"]