diff --git a/documentation/source/configuration.rst b/documentation/source/configuration.rst
index b8b0f9e2..b2d4a730 100644
--- a/documentation/source/configuration.rst
+++ b/documentation/source/configuration.rst
@@ -13,9 +13,11 @@ A typical configuration file looks like this:
"deactivate_sound": true,
"rules": "rules.yaml",
"functions": "functions.py",
+ "frontend_port": 8081,
"llm_model": {
"model_host": "localhost",
- "model_port": 8080
+ "model_port": 8080,
+ "temperature": 0.4
},
"listener_model": {
"model_host": "localhost",
@@ -45,7 +47,10 @@ These settings regulate the following:
* "functions" is the file containing the functions that can be used in the rules. The default is "functions.py".
+ * "frontend_port" is the port where the web frontend is running. The default is 8090.
+
* "llm_model" is the configuration to connect to the LLM model in the backend. The default is "localhost:8080".
+ The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4.
* "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080".
diff --git a/tests/test_rules.py b/tests/test_rules.py
index 496262b3..7be70da6 100644
--- a/tests/test_rules.py
+++ b/tests/test_rules.py
@@ -88,9 +88,8 @@ def test__nested_rules_are_printed_correctly(self):
rule = facts_and_rules["rules"][0]
expected = """
the user wants to know the time
- -output "The time is get_time()"
- -if the time is before 12:00 say "Good morning"
- -if the time is after 12:00 say "Good afternoon"
+ - output "The time is get_time()"
+ - if the time is before 12:00 say "Good morning"
+ - if the time is after 12:00 say "Good afternoon"
""".strip()
self.assertEqual(expected, str(rule).strip())
-
diff --git a/todo.txt b/todo.txt
index 0beacad2..3d726a32 100644
--- a/todo.txt
+++ b/todo.txt
@@ -1,3 +1,17 @@
+* add config file for model names
+ - llm model name
+ - whisper model name
+
+/* add version name
+/* let user decide port for frontend
+/* update docs about port
+/* push new version
+* update pypi with wafl and wafl-llm
+* clean code for llm eval and make it public
+* update huggingface readme
+* read overleaf paper
+
+
* on wafl_llm make it so only some LLMs are supported
* change speaker model with newer one
diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py
index a115d77c..e4b1a160 100644
--- a/wafl/answerer/dialogue_answerer.py
+++ b/wafl/answerer/dialogue_answerer.py
@@ -160,9 +160,19 @@ def _init_python_module(self, module_name):
self._functions = [item[0] for item in getmembers(self._module, isfunction)]
async def _substitute_results_in_answer(self, answer_text):
- matches = re.finditer(r"(.*?)", answer_text, re.DOTALL)
+ matches = re.finditer(r"(.*?)|(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
for match in matches:
to_execute = match.group(1)
+ if not to_execute:
+ continue
+ result = await self._run_code(to_execute)
+ answer_text = answer_text.replace(match.group(0), result)
+
+ matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
+ for match in matches:
+ to_execute = match.group(1)
+ if not to_execute:
+ continue
result = await self._run_code(to_execute)
answer_text = answer_text.replace(match.group(0), result)
@@ -178,6 +188,13 @@ async def _substitute_memory_in_answer_and_get_memories_if_present(
answer_text = answer_text.replace(match.group(0), "")
memories.append(to_execute)
+ matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
+ memories = []
+ for match in matches:
+ to_execute = match.group(1)
+ answer_text = answer_text.replace(match.group(0), "")
+ memories.append(to_execute)
+
return answer_text, memories
async def _run_code(self, to_execute):
@@ -206,7 +223,7 @@ async def _run_code(self, to_execute):
except Exception as e:
result = (
- f'Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}'
+ f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}"
)
traceback.print_exc()
break
diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py
index fb8dd604..f9775d21 100644
--- a/wafl/answerer/rule_maker.py
+++ b/wafl/answerer/rule_maker.py
@@ -1,12 +1,12 @@
class RuleMaker:
def __init__(
self,
- knowledge,
- config,
- interface,
- max_num_rules,
- delete_current_rule,
- max_recursion=3,
+ knowledge: "Knowledge",
+ config: "BaseConfig",
+ interface: "BaseInterface",
+ max_num_rules: int,
+ delete_current_rule: str,
+ max_recursion: int = 3,
):
self._knowledge = knowledge
self._config = config
@@ -26,6 +26,8 @@ async def create_from_query(self, query):
rules_text = rule.get_string_using_template(
"- If {effect} go through the following points:"
)
+ rules_text += f'{rule.indent_str}- After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n'
+
rules_texts.append(rules_text)
await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")
diff --git a/wafl/connectors/base_llm_connector.py b/wafl/connectors/base_llm_connector.py
index 2c800c20..1bd57a63 100644
--- a/wafl/connectors/base_llm_connector.py
+++ b/wafl/connectors/base_llm_connector.py
@@ -1,7 +1,6 @@
import logging
import re
-
from wafl.connectors.utils import select_best_answer
_system_logger = logging.getLogger(__file__)
@@ -39,11 +38,7 @@ async def generate(self, prompt: str) -> str:
text = prompt
start = len(text)
- while (
- all(item not in text[start:] for item in self._last_strings)
- and len(text) < start + self._max_reply_length
- ):
- text += select_best_answer(await self.predict(text), self._last_strings)
+ text += select_best_answer(await self.predict(text), self._last_strings)
end_set = set()
for item in self._last_strings:
diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py
index 42fa0da7..e9e3c36e 100644
--- a/wafl/connectors/remote/remote_llm_connector.py
+++ b/wafl/connectors/remote/remote_llm_connector.py
@@ -1,7 +1,10 @@
+import json
+
import aiohttp
import asyncio
from wafl.connectors.base_llm_connector import BaseLLMConnector
+from wafl.variables import is_supported
class RemoteLLMConnector(BaseLLMConnector):
@@ -14,6 +17,7 @@ def __init__(self, config, last_strings=None, num_replicas=3):
super().__init__(last_strings)
host = config["model_host"]
port = config["model_port"]
+ self._default_temperature = config["temperature"]
self._server_url = f"https://{host}:{port}/predictions/bot"
self._num_replicas = num_replicas
@@ -28,9 +32,11 @@ def __init__(self, config, last_strings=None, num_replicas=3):
):
raise RuntimeError("Cannot connect a running LLM.")
- async def predict(self, prompt: str, temperature=None, num_tokens=None, num_replicas=None) -> [str]:
+ async def predict(
+ self, prompt: str, temperature=None, num_tokens=None, num_replicas=None
+ ) -> [str]:
if not temperature:
- temperature = 0.5
+ temperature = self._default_temperature
if not num_tokens:
num_tokens = self._num_prediction_tokens
@@ -52,8 +58,11 @@ async def predict(self, prompt: str, temperature=None, num_tokens=None, num_repl
connector=aiohttp.TCPConnector(ssl=False),
) as session:
async with session.post(self._server_url, json=payload) as response:
- answer = await response.text()
- return answer.split("<||>")
+ answer = json.loads(await response.text())
+ status = answer["status"]
+ if status != "success":
+ raise RuntimeError(f"Error in prediction: {answer}")
+ return answer["prediction"].split("<||>")
return [""]
@@ -70,7 +79,14 @@ async def check_connection(self):
conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
) as session:
async with session.post(self._server_url, json=payload) as response:
- await response.text()
+ answer = json.loads(await response.text())
+ wafl_llm_version = answer["version"]
+ print(f"Connected to wafl-llm v{wafl_llm_version}.")
+ if not is_supported(wafl_llm_version):
+ print("This version of wafl-llm is not supported.")
+ print("Please update wafl-llm.")
+ raise aiohttp.client.InvalidURL
+
return True
except aiohttp.client.InvalidURL:
diff --git a/wafl/connectors/remote/remote_whisper_connector.py b/wafl/connectors/remote/remote_whisper_connector.py
index 8a1a8465..d9498a83 100644
--- a/wafl/connectors/remote/remote_whisper_connector.py
+++ b/wafl/connectors/remote/remote_whisper_connector.py
@@ -39,7 +39,9 @@ async def predict(self, waveform, hotword=None) -> Dict[str, float]:
data = await response.text()
prediction = json.loads(data)
if "transcription" not in prediction:
- raise RuntimeError("No transcription found in prediction. Is your microphone working?")
+ raise RuntimeError(
+ "No transcription found in prediction. Is your microphone working?"
+ )
transcription = prediction["transcription"]
score = prediction["score"]
diff --git a/wafl/interface/base_interface.py b/wafl/interface/base_interface.py
index 07570ec7..cb54f27c 100644
--- a/wafl/interface/base_interface.py
+++ b/wafl/interface/base_interface.py
@@ -70,5 +70,7 @@ def _decorate_reply(self, text: str) -> str:
return self._decorator.extract(text, self._utterances)
def _insert_utterance(self, speaker, text: str):
- if self._utterances == [] or text != self._utterances[-1][1].replace(f"{speaker}: ", ""):
+ if self._utterances == [] or text != self._utterances[-1][1].replace(
+ f"{speaker}: ", ""
+ ):
self._utterances.append((time.time(), f"{speaker}: {text}"))
diff --git a/wafl/interface/list_interface.py b/wafl/interface/list_interface.py
index 1230da36..c7416d2e 100644
--- a/wafl/interface/list_interface.py
+++ b/wafl/interface/list_interface.py
@@ -13,20 +13,20 @@ def __init__(self, interfaces_list: List[BaseInterface]):
async def output(self, text: str, silent: bool = False):
await asyncio.wait(
[interface.output(text, silent) for interface in self._interfaces_list],
- return_when=asyncio.ALL_COMPLETED
+ return_when=asyncio.ALL_COMPLETED,
)
async def input(self) -> str:
done, pending = await asyncio.wait(
[interface.input() for interface in self._interfaces_list],
- return_when=asyncio.FIRST_COMPLETED
+ return_when=asyncio.FIRST_COMPLETED,
)
return done.pop().result()
async def insert_input(self, text: str):
await asyncio.wait(
[interface.insert_input(text) for interface in self._interfaces_list],
- return_when=asyncio.ALL_COMPLETED
+ return_when=asyncio.ALL_COMPLETED,
)
def bot_has_spoken(self, to_set: bool = None):
@@ -44,11 +44,10 @@ def deactivate(self):
super().deactivate()
self._synchronize_interfaces()
-
def add_hotwords(self, hotwords):
for interface in self._interfaces_list:
interface.add_hotwords(hotwords)
def _synchronize_interfaces(self):
for interface in self._interfaces_list:
- interface._utterances = self._utterances
\ No newline at end of file
+ interface._utterances = self._utterances
diff --git a/wafl/rules.py b/wafl/rules.py
index c54f6efa..fc4837cf 100644
--- a/wafl/rules.py
+++ b/wafl/rules.py
@@ -7,7 +7,7 @@ class Rule:
effect: "Fact"
causes: List["Fact"]
_max_indentation = 3
- _indent_str = " "
+ indent_str = " "
def toJSON(self):
return str(self)
@@ -33,22 +33,21 @@ def _add_clauses(self, rule_str: str) -> str:
return rule_str
-
- def _recursively_add_clauses(self, query: str, depth: int=1) -> str:
- indentation = self._indent_str * depth
+ def _recursively_add_clauses(self, query: str, depth: int = 1) -> str:
+ indentation = self.indent_str * depth
if type(query) == str:
- return f"{indentation}-{query}\n"
+ return f"{indentation}- {query}\n"
if type(query.text) == str:
- return f"{indentation}-{query.text}\n"
+ return f"{indentation}- {query.text}\n"
if depth > self._max_indentation:
return ""
clause = list(query.text.keys())[0]
- rules_text = f"{indentation}-{clause}\n"
+ rules_text = f"{indentation}- {clause}\n"
for clauses in query.text.values():
for cause_index, clause in enumerate(clauses):
rules_text += self._recursively_add_clauses(clause, depth + 1)
- return rules_text
\ No newline at end of file
+ return rules_text
diff --git a/wafl/runners/routes.py b/wafl/runners/routes.py
index 6ea8e691..78b182c0 100644
--- a/wafl/runners/routes.py
+++ b/wafl/runners/routes.py
@@ -77,4 +77,4 @@ def add_new_rules(app: Flask, conversation_id: int, web_server_loop: "WebLoop"):
f"toggle_logs_{conversation_id}",
web_server_loop.toggle_logs,
methods=["POST"],
- )
\ No newline at end of file
+ )
diff --git a/wafl/runners/run_web_and_audio_interface.py b/wafl/runners/run_web_and_audio_interface.py
index 391d151e..7f705db9 100644
--- a/wafl/runners/run_web_and_audio_interface.py
+++ b/wafl/runners/run_web_and_audio_interface.py
@@ -1,4 +1,3 @@
-import asyncio
import random
import sys
import threading
@@ -56,7 +55,7 @@ def create_scheduler_and_webserver_loop(conversation_id):
"web_server_loop": web_loop,
}
- app.run(host="0.0.0.0", port=8889)
+ app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port"))
if __name__ == "__main__":
diff --git a/wafl/runners/run_web_interface.py b/wafl/runners/run_web_interface.py
index a9087101..ac285396 100644
--- a/wafl/runners/run_web_interface.py
+++ b/wafl/runners/run_web_interface.py
@@ -57,7 +57,7 @@ def create_scheduler_and_webserver_loop(conversation_id):
"web_server_loop": web_loop,
}
- app.run(host="0.0.0.0", port=8889)
+ app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port"))
if __name__ == "__main__":
diff --git a/wafl/templates/config.json b/wafl/templates/config.json
index 4bb4fa90..0f0eccf3 100644
--- a/wafl/templates/config.json
+++ b/wafl/templates/config.json
@@ -4,9 +4,11 @@
"deactivate_sound": true,
"rules": "rules.yaml",
"functions": "functions.py",
+ "frontend_port": 8090,
"llm_model": {
"model_host": "localhost",
- "model_port": 8080
+ "model_port": 8080,
+ "temperature": 0.4
},
"listener_model": {
"model_host": "localhost",
diff --git a/wafl/templates/rules.yaml b/wafl/templates/rules.yaml
index 78eee512..78938a1d 100644
--- a/wafl/templates/rules.yaml
+++ b/wafl/templates/rules.yaml
@@ -45,4 +45,12 @@ rules:
- the user asks something about cities, capitals, countries, buildings, famous people, bars, restaurants, rivers, mountains, lakes, seas, oceans, planets, stars, galaxies:
- say that you are just improvising the answer
- - say what you think answer the question
\ No newline at end of file
+ - say what you think answer the question
+
+ - the user thanks the bot:
+ - The intention of the user is to close the conversation
+ - You must answer the user by writing "close_conversation()"
+
+ - the user says "end this conversation":
+ - The intention of the user is to close the conversation
+ - You must answer the user by writing "close_conversation()"
\ No newline at end of file
diff --git a/wafl/variables.py b/wafl/variables.py
index 84ebf205..1802c460 100644
--- a/wafl/variables.py
+++ b/wafl/variables.py
@@ -1,4 +1,9 @@
def get_variables():
return {
- "version": "0.0.82",
+ "version": "0.0.83",
}
+
+
+def is_supported(wafl_llm_version):
+ supported_versions = ["0.0.82", "0.0.83"]
+ return wafl_llm_version in supported_versions
\ No newline at end of file