From 9886bab813e19938970a76c33ab7b4a1fa4b7925 Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Fri, 10 May 2024 16:38:09 +0100 Subject: [PATCH 1/7] adding at the end of the steps --- tests/test_rules.py | 7 +++---- wafl/answerer/dialogue_answerer.py | 2 +- wafl/answerer/rule_maker.py | 14 ++++++++------ wafl/connectors/remote/remote_llm_connector.py | 4 +++- .../connectors/remote/remote_whisper_connector.py | 4 +++- wafl/interface/base_interface.py | 4 +++- wafl/interface/list_interface.py | 9 ++++----- wafl/rules.py | 15 +++++++-------- wafl/runners/routes.py | 2 +- 9 files changed, 33 insertions(+), 28 deletions(-) diff --git a/tests/test_rules.py b/tests/test_rules.py index 496262b3..7be70da6 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -88,9 +88,8 @@ def test__nested_rules_are_printed_correctly(self): rule = facts_and_rules["rules"][0] expected = """ the user wants to know the time - -output "The time is get_time()" - -if the time is before 12:00 say "Good morning" - -if the time is after 12:00 say "Good afternoon" + - output "The time is get_time()" + - if the time is before 12:00 say "Good morning" + - if the time is after 12:00 say "Good afternoon" """.strip() self.assertEqual(expected, str(rule).strip()) - diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py index a115d77c..8ca45467 100644 --- a/wafl/answerer/dialogue_answerer.py +++ b/wafl/answerer/dialogue_answerer.py @@ -206,7 +206,7 @@ async def _run_code(self, to_execute): except Exception as e: result = ( - f'Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}' + f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}" ) traceback.print_exc() break diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py index fb8dd604..f9775d21 100644 --- a/wafl/answerer/rule_maker.py +++ b/wafl/answerer/rule_maker.py @@ -1,12 +1,12 @@ class RuleMaker: def __init__( self, - knowledge, - config, - interface, - max_num_rules, - delete_current_rule, - max_recursion=3, + knowledge: "Knowledge", + config: "BaseConfig", + interface: "BaseInterface", + max_num_rules: int, + delete_current_rule: str, + max_recursion: int = 3, ): self._knowledge = knowledge self._config = config @@ -26,6 +26,8 @@ async def create_from_query(self, query): rules_text = rule.get_string_using_template( "- If {effect} go through the following points:" ) + rules_text += f'{rule.indent_str}- After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n' + rules_texts.append(rules_text) await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}") diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py index 42fa0da7..7edbd1f7 100644 --- a/wafl/connectors/remote/remote_llm_connector.py +++ b/wafl/connectors/remote/remote_llm_connector.py @@ -28,7 +28,9 @@ def __init__(self, config, last_strings=None, num_replicas=3): ): raise RuntimeError("Cannot connect a running LLM.") - async def predict(self, prompt: str, temperature=None, num_tokens=None, num_replicas=None) -> [str]: + async def predict( + self, prompt: str, temperature=None, num_tokens=None, num_replicas=None + ) -> [str]: if not temperature: temperature = 0.5 diff --git a/wafl/connectors/remote/remote_whisper_connector.py b/wafl/connectors/remote/remote_whisper_connector.py index 8a1a8465..d9498a83 100644 --- a/wafl/connectors/remote/remote_whisper_connector.py +++ b/wafl/connectors/remote/remote_whisper_connector.py @@ 
-39,7 +39,9 @@ async def predict(self, waveform, hotword=None) -> Dict[str, float]: data = await response.text() prediction = json.loads(data) if "transcription" not in prediction: - raise RuntimeError("No transcription found in prediction. Is your microphone working?") + raise RuntimeError( + "No transcription found in prediction. Is your microphone working?" + ) transcription = prediction["transcription"] score = prediction["score"] diff --git a/wafl/interface/base_interface.py b/wafl/interface/base_interface.py index 07570ec7..cb54f27c 100644 --- a/wafl/interface/base_interface.py +++ b/wafl/interface/base_interface.py @@ -70,5 +70,7 @@ def _decorate_reply(self, text: str) -> str: return self._decorator.extract(text, self._utterances) def _insert_utterance(self, speaker, text: str): - if self._utterances == [] or text != self._utterances[-1][1].replace(f"{speaker}: ", ""): + if self._utterances == [] or text != self._utterances[-1][1].replace( + f"{speaker}: ", "" + ): self._utterances.append((time.time(), f"{speaker}: {text}")) diff --git a/wafl/interface/list_interface.py b/wafl/interface/list_interface.py index 1230da36..c7416d2e 100644 --- a/wafl/interface/list_interface.py +++ b/wafl/interface/list_interface.py @@ -13,20 +13,20 @@ def __init__(self, interfaces_list: List[BaseInterface]): async def output(self, text: str, silent: bool = False): await asyncio.wait( [interface.output(text, silent) for interface in self._interfaces_list], - return_when=asyncio.ALL_COMPLETED + return_when=asyncio.ALL_COMPLETED, ) async def input(self) -> str: done, pending = await asyncio.wait( [interface.input() for interface in self._interfaces_list], - return_when=asyncio.FIRST_COMPLETED + return_when=asyncio.FIRST_COMPLETED, ) return done.pop().result() async def insert_input(self, text: str): await asyncio.wait( [interface.insert_input(text) for interface in self._interfaces_list], - return_when=asyncio.ALL_COMPLETED + return_when=asyncio.ALL_COMPLETED, ) def bot_has_spoken(self, to_set: bool = None): @@ -44,11 +44,10 @@ def deactivate(self): super().deactivate() self._synchronize_interfaces() - def add_hotwords(self, hotwords): for interface in self._interfaces_list: interface.add_hotwords(hotwords) def _synchronize_interfaces(self): for interface in self._interfaces_list: - interface._utterances = self._utterances \ No newline at end of file + interface._utterances = self._utterances diff --git a/wafl/rules.py b/wafl/rules.py index c54f6efa..fc4837cf 100644 --- a/wafl/rules.py +++ b/wafl/rules.py @@ -7,7 +7,7 @@ class Rule: effect: "Fact" causes: List["Fact"] _max_indentation = 3 - _indent_str = " " + indent_str = " " def toJSON(self): return str(self) @@ -33,22 +33,21 @@ def _add_clauses(self, rule_str: str) -> str: return rule_str - - def _recursively_add_clauses(self, query: str, depth: int=1) -> str: - indentation = self._indent_str * depth + def _recursively_add_clauses(self, query: str, depth: int = 1) -> str: + indentation = self.indent_str * depth if type(query) == str: - return f"{indentation}-{query}\n" + return f"{indentation}- {query}\n" if type(query.text) == str: - return f"{indentation}-{query.text}\n" + return f"{indentation}- {query.text}\n" if depth > self._max_indentation: return "" clause = list(query.text.keys())[0] - rules_text = f"{indentation}-{clause}\n" + rules_text = f"{indentation}- {clause}\n" for clauses in query.text.values(): for cause_index, clause in enumerate(clauses): rules_text += self._recursively_add_clauses(clause, depth + 1) - return rules_text \ No 
newline at end of file + return rules_text diff --git a/wafl/runners/routes.py b/wafl/runners/routes.py index 6ea8e691..78b182c0 100644 --- a/wafl/runners/routes.py +++ b/wafl/runners/routes.py @@ -77,4 +77,4 @@ def add_new_rules(app: Flask, conversation_id: int, web_server_loop: "WebLoop"): f"toggle_logs_{conversation_id}", web_server_loop.toggle_logs, methods=["POST"], - ) \ No newline at end of file + ) From def04738af8f77b20e28e4bdefa19fd1fad27eef Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sat, 11 May 2024 11:26:14 +0100 Subject: [PATCH 2/7] integrating with vllm --- wafl/answerer/dialogue_answerer.py | 12 +++++++++++- wafl/connectors/base_llm_connector.py | 6 +----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py index 8ca45467..e253807f 100644 --- a/wafl/answerer/dialogue_answerer.py +++ b/wafl/answerer/dialogue_answerer.py @@ -160,9 +160,19 @@ def _init_python_module(self, module_name): self._functions = [item[0] for item in getmembers(self._module, isfunction)] async def _substitute_results_in_answer(self, answer_text): - matches = re.finditer(r"(.*?)", answer_text, re.DOTALL) + matches = re.finditer(r"(.*?)|(.*?\))$", answer_text, re.DOTALL|re.MULTILINE) for match in matches: to_execute = match.group(1) + if not to_execute: + continue + result = await self._run_code(to_execute) + answer_text = answer_text.replace(match.group(0), result) + + matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL|re.MULTILINE) + for match in matches: + to_execute = match.group(1) + if not to_execute: + continue result = await self._run_code(to_execute) answer_text = answer_text.replace(match.group(0), result) diff --git a/wafl/connectors/base_llm_connector.py b/wafl/connectors/base_llm_connector.py index 2c800c20..61a699a5 100644 --- a/wafl/connectors/base_llm_connector.py +++ b/wafl/connectors/base_llm_connector.py @@ -39,11 +39,7 @@ async def generate(self, prompt: str) -> str: text = prompt start = len(text) - while ( - all(item not in text[start:] for item in self._last_strings) - and len(text) < start + self._max_reply_length - ): - text += select_best_answer(await self.predict(text), self._last_strings) + text += select_best_answer(await self.predict(text), self._last_strings) end_set = set() for item in self._last_strings: From c594775c809beb769d3dd42d1785bd0e46a7ccf3 Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sat, 11 May 2024 11:31:26 +0100 Subject: [PATCH 3/7] added new pattern for tags --- wafl/answerer/dialogue_answerer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py index e253807f..e4b1a160 100644 --- a/wafl/answerer/dialogue_answerer.py +++ b/wafl/answerer/dialogue_answerer.py @@ -188,6 +188,13 @@ async def _substitute_memory_in_answer_and_get_memories_if_present( answer_text = answer_text.replace(match.group(0), "") memories.append(to_execute) + matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL|re.MULTILINE) + memories = [] + for match in matches: + to_execute = match.group(1) + answer_text = answer_text.replace(match.group(0), "") + memories.append(to_execute) + return answer_text, memories async def _run_code(self, to_execute): From 4c4b099ac1e9128bff967dd732154d90e9d6d630 Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sat, 11 May 2024 11:57:30 +0100 Subject: [PATCH 4/7] added version check on wafl-llm from the client side --- wafl/connectors/base_llm_connector.py | 1 - 
.../connectors/remote/remote_llm_connector.py | 21 +++++++++++++++---- wafl/variables.py | 5 +++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/wafl/connectors/base_llm_connector.py b/wafl/connectors/base_llm_connector.py index 61a699a5..1bd57a63 100644 --- a/wafl/connectors/base_llm_connector.py +++ b/wafl/connectors/base_llm_connector.py @@ -1,7 +1,6 @@ import logging import re - from wafl.connectors.utils import select_best_answer _system_logger = logging.getLogger(__file__) diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py index 7edbd1f7..3df0fe8b 100644 --- a/wafl/connectors/remote/remote_llm_connector.py +++ b/wafl/connectors/remote/remote_llm_connector.py @@ -1,7 +1,10 @@ +import json + import aiohttp import asyncio from wafl.connectors.base_llm_connector import BaseLLMConnector +from wafl.variables import is_supported class RemoteLLMConnector(BaseLLMConnector): @@ -32,7 +35,7 @@ async def predict( self, prompt: str, temperature=None, num_tokens=None, num_replicas=None ) -> [str]: if not temperature: - temperature = 0.5 + temperature = 0.1 if not num_tokens: num_tokens = self._num_prediction_tokens @@ -54,8 +57,11 @@ async def predict( connector=aiohttp.TCPConnector(ssl=False), ) as session: async with session.post(self._server_url, json=payload) as response: - answer = await response.text() - return answer.split("<||>") + answer = json.loads(await response.text()) + status = answer["status"] + if status != "success": + raise RuntimeError(f"Error in prediction: {answer}") + return answer["prediction"].split("<||>") return [""] @@ -72,7 +78,14 @@ async def check_connection(self): conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False) ) as session: async with session.post(self._server_url, json=payload) as response: - await response.text() + answer = json.loads(await response.text()) + wafl_llm_version = answer["version"] + print(f"Connected to wafl-llm v{wafl_llm_version}.") + if not is_supported(wafl_llm_version): + print("This version of wafl-llm is not supported.") + print("Please update wafl-llm.") + raise aiohttp.client.InvalidURL + return True except aiohttp.client.InvalidURL: diff --git a/wafl/variables.py b/wafl/variables.py index 84ebf205..f38a524f 100644 --- a/wafl/variables.py +++ b/wafl/variables.py @@ -2,3 +2,8 @@ def get_variables(): return { "version": "0.0.82", } + + +def is_supported(wafl_llm_version): + supported_versions = ["0.0.82"] + return wafl_llm_version in supported_versions \ No newline at end of file From 119c99bee71be3f420eb07c37f6b410d6d448bab Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sat, 11 May 2024 12:06:21 +0100 Subject: [PATCH 5/7] added port config for frontent --- documentation/source/configuration.rst | 3 +++ wafl/runners/run_web_and_audio_interface.py | 3 +-- wafl/runners/run_web_interface.py | 2 +- wafl/templates/config.json | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/documentation/source/configuration.rst b/documentation/source/configuration.rst index b8b0f9e2..32b93f96 100644 --- a/documentation/source/configuration.rst +++ b/documentation/source/configuration.rst @@ -13,6 +13,7 @@ A typical configuration file looks like this: "deactivate_sound": true, "rules": "rules.yaml", "functions": "functions.py", + "frontend_port": 8081, "llm_model": { "model_host": "localhost", "model_port": 8080 @@ -45,6 +46,8 @@ These settings regulate the following: * "functions" is the file containing the functions that can be used in the rules. 
The default is "functions.py". + * "frontend_port" is the port where the web frontend is running. The default is 8090. + * "llm_model" is the configuration to connect to the LLM model in the backend. The default is "localhost:8080". * "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080". diff --git a/wafl/runners/run_web_and_audio_interface.py b/wafl/runners/run_web_and_audio_interface.py index 391d151e..7f705db9 100644 --- a/wafl/runners/run_web_and_audio_interface.py +++ b/wafl/runners/run_web_and_audio_interface.py @@ -1,4 +1,3 @@ -import asyncio import random import sys import threading @@ -56,7 +55,7 @@ def create_scheduler_and_webserver_loop(conversation_id): "web_server_loop": web_loop, } - app.run(host="0.0.0.0", port=8889) + app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port")) if __name__ == "__main__": diff --git a/wafl/runners/run_web_interface.py b/wafl/runners/run_web_interface.py index a9087101..ac285396 100644 --- a/wafl/runners/run_web_interface.py +++ b/wafl/runners/run_web_interface.py @@ -57,7 +57,7 @@ def create_scheduler_and_webserver_loop(conversation_id): "web_server_loop": web_loop, } - app.run(host="0.0.0.0", port=8889) + app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port")) if __name__ == "__main__": diff --git a/wafl/templates/config.json b/wafl/templates/config.json index 4bb4fa90..67d9dcfc 100644 --- a/wafl/templates/config.json +++ b/wafl/templates/config.json @@ -4,6 +4,7 @@ "deactivate_sound": true, "rules": "rules.yaml", "functions": "functions.py", + "frontend_port": 8090, "llm_model": { "model_host": "localhost", "model_port": 8080 From 0bafd5e4310693450edc7ffa8320c39011c0d68c Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sat, 11 May 2024 12:15:02 +0100 Subject: [PATCH 6/7] added temperature config --- documentation/source/configuration.rst | 4 +++- wafl/connectors/remote/remote_llm_connector.py | 3 ++- wafl/templates/config.json | 3 ++- wafl/templates/rules.yaml | 10 +++++++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/documentation/source/configuration.rst b/documentation/source/configuration.rst index 32b93f96..b2d4a730 100644 --- a/documentation/source/configuration.rst +++ b/documentation/source/configuration.rst @@ -16,7 +16,8 @@ A typical configuration file looks like this: "frontend_port": 8081, "llm_model": { "model_host": "localhost", - "model_port": 8080 + "model_port": 8080, + "temperature": 0.4 }, "listener_model": { "model_host": "localhost", @@ -49,6 +50,7 @@ These settings regulate the following: * "frontend_port" is the port where the web frontend is running. The default is 8090. * "llm_model" is the configuration to connect to the LLM model in the backend. The default is "localhost:8080". + The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4. * "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080". 
diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py index 3df0fe8b..e9e3c36e 100644 --- a/wafl/connectors/remote/remote_llm_connector.py +++ b/wafl/connectors/remote/remote_llm_connector.py @@ -17,6 +17,7 @@ def __init__(self, config, last_strings=None, num_replicas=3): super().__init__(last_strings) host = config["model_host"] port = config["model_port"] + self._default_temperature = config["temperature"] self._server_url = f"https://{host}:{port}/predictions/bot" self._num_replicas = num_replicas @@ -35,7 +36,7 @@ async def predict( self, prompt: str, temperature=None, num_tokens=None, num_replicas=None ) -> [str]: if not temperature: - temperature = 0.1 + temperature = self._default_temperature if not num_tokens: num_tokens = self._num_prediction_tokens diff --git a/wafl/templates/config.json b/wafl/templates/config.json index 67d9dcfc..0f0eccf3 100644 --- a/wafl/templates/config.json +++ b/wafl/templates/config.json @@ -7,7 +7,8 @@ "frontend_port": 8090, "llm_model": { "model_host": "localhost", - "model_port": 8080 + "model_port": 8080, + "temperature": 0.4 }, "listener_model": { "model_host": "localhost", diff --git a/wafl/templates/rules.yaml b/wafl/templates/rules.yaml index 78eee512..78938a1d 100644 --- a/wafl/templates/rules.yaml +++ b/wafl/templates/rules.yaml @@ -45,4 +45,12 @@ rules: - the user asks something about cities, capitals, countries, buildings, famous people, bars, restaurants, rivers, mountains, lakes, seas, oceans, planets, stars, galaxies: - say that you are just improvising the answer - - say what you think answer the question \ No newline at end of file + - say what you think answer the question + + - the user thanks the bot: + - The intention of the user is to close the conversation + - You must answer the user by writing "close_conversation()" + + - the user says "end this conversation": + - The intention of the user is to close the conversation + - You must answer the user by writing "close_conversation()" \ No newline at end of file From 556ac1455cbe940a0a5e215faddfe0f21e936fb9 Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sat, 11 May 2024 12:16:40 +0100 Subject: [PATCH 7/7] changed version --- todo.txt | 14 ++++++++++++++ wafl/variables.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/todo.txt b/todo.txt index 0beacad2..3d726a32 100644 --- a/todo.txt +++ b/todo.txt @@ -1,3 +1,17 @@ +* add config file for model names + - llm model name + - whisper model name + +/* add version name +/* let user decide port for frontend +/* update docs about port +/* push new version +* update pypi with wafl and wafl-llm +* clean code for llm eval and make it public +* update huggingface readme +* read overleaf paper + + * on wafl_llm make it so only some LLMs are supported * change speaker model with newer one diff --git a/wafl/variables.py b/wafl/variables.py index f38a524f..1802c460 100644 --- a/wafl/variables.py +++ b/wafl/variables.py @@ -1,9 +1,9 @@ def get_variables(): return { - "version": "0.0.82", + "version": "0.0.83", } def is_supported(wafl_llm_version): - supported_versions = ["0.0.82"] + supported_versions = ["0.0.82", "0.0.83"] return wafl_llm_version in supported_versions \ No newline at end of file
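To illustrate how the client-side version gate from PATCH 4/7 combines with the version bump in PATCH 7/7, the short sketch below exercises wafl/variables.py directly. This is a hypothetical usage example rather than code from the patches, and it assumes the wafl package is installed and importable:

    from wafl.variables import get_variables, is_supported

    # After PATCH 7/7 the client identifies itself as 0.0.83 ...
    print(get_variables()["version"])  # -> "0.0.83"

    # ... and accepts a wafl-llm server reporting either supported version.
    print(is_supported("0.0.83"))  # -> True
    print(is_supported("0.0.82"))  # -> True

    # Any other reported server version makes RemoteLLMConnector.check_connection()
    # print a warning and raise aiohttp.client.InvalidURL.
    print(is_supported("0.0.81"))  # -> False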