Prompt as list #92

Merged 7 commits on May 11, 2024
7 changes: 6 additions & 1 deletion documentation/source/configuration.rst
@@ -13,9 +13,11 @@ A typical configuration file looks like this:
"deactivate_sound": true,
"rules": "rules.yaml",
"functions": "functions.py",
"frontend_port": 8081,
"llm_model": {
"model_host": "localhost",
"model_port": 8080
"model_port": 8080,
"temperature": 0.4
},
"listener_model": {
"model_host": "localhost",
@@ -45,7 +47,10 @@ These settings regulate the following:

* "functions" is the file containing the functions that can be used in the rules. The default is "functions.py".

* "frontend_port" is the port where the web frontend is running. The default is 8090.

* "llm_model" is the configuration to connect to the LLM model in the backend. The default is "localhost:8080".
The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4.

* "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080".

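For reference, a minimal sketch of how these settings might be read from the JSON file; the helper name and fallback defaults are illustrative, not the project's actual Configuration API:

```python
import json

def load_settings(path: str = "config.json"):
    # Illustrative helper only: reads the options documented above.
    with open(path) as f:
        config = json.load(f)
    frontend_port = config.get("frontend_port", 8090)   # web frontend port
    llm = config["llm_model"]
    host, port = llm["model_host"], llm["model_port"]    # backend LLM address
    temperature = llm.get("temperature", 0.4)            # sampling temperature
    return frontend_port, host, port, temperature
```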
7 changes: 3 additions & 4 deletions tests/test_rules.py
@@ -88,9 +88,8 @@ def test__nested_rules_are_printed_correctly(self):
rule = facts_and_rules["rules"][0]
expected = """
the user wants to know the time
-output "The time is <execute>get_time()</execute>"
-if the time is before 12:00 say "Good morning"
-if the time is after 12:00 say "Good afternoon"
- output "The time is <execute>get_time()</execute>"
- if the time is before 12:00 say "Good morning"
- if the time is after 12:00 say "Good afternoon"
""".strip()
self.assertEqual(expected, str(rule).strip())

14 changes: 14 additions & 0 deletions todo.txt
@@ -1,3 +1,17 @@
* add config file for model names
- llm model name
- whisper model name

/* add version name
/* let user decide port for frontend
/* update docs about port
/* push new version
* update pypi with wafl and wafl-llm
* clean code for llm eval and make it public
* update huggingface readme
* read overleaf paper


* on wafl_llm make it so only some LLMs are supported
* change speaker model with newer one

21 changes: 19 additions & 2 deletions wafl/answerer/dialogue_answerer.py
@@ -160,9 +160,19 @@ def _init_python_module(self, module_name):
self._functions = [item[0] for item in getmembers(self._module, isfunction)]

async def _substitute_results_in_answer(self, answer_text):
matches = re.finditer(r"<execute>(.*?)</execute>", answer_text, re.DOTALL)
matches = re.finditer(r"<execute>(.*?)</execute>|<execute>(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
for match in matches:
to_execute = match.group(1)
if not to_execute:
continue
result = await self._run_code(to_execute)
answer_text = answer_text.replace(match.group(0), result)

matches = re.finditer(r"<execute>(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
for match in matches:
to_execute = match.group(1)
if not to_execute:
continue
result = await self._run_code(to_execute)
answer_text = answer_text.replace(match.group(0), result)

@@ -178,6 +188,13 @@ async def _substitute_memory_in_answer_and_get_memories_if_present(
answer_text = answer_text.replace(match.group(0), "")
memories.append(to_execute)

matches = re.finditer(r"<remember>(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
memories = []
for match in matches:
to_execute = match.group(1)
answer_text = answer_text.replace(match.group(0), "")
memories.append(to_execute)

return answer_text, memories

async def _run_code(self, to_execute):
@@ -206,7 +223,7 @@ async def _run_code(self, to_execute):

except Exception as e:
result = (
f'Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}'
f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}"
)
traceback.print_exc()
break
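To make the regex change above easier to follow, here is a standalone sketch of the substitution it performs: well-formed <execute>...</execute> tags are replaced first, then an unterminated <execute>... whose code ends with a closing parenthesis at the end of a line. The function name and the run_code callable are illustrative; the real method is async and delegates to _run_code:

```python
import re

def substitute_results(answer_text: str, run_code) -> str:
    # Pass 1: closed <execute>...</execute> tags.
    for match in re.finditer(r"<execute>(.*?)</execute>", answer_text, re.DOTALL):
        answer_text = answer_text.replace(match.group(0), run_code(match.group(1)))
    # Pass 2: an unterminated tag whose code ends with ")" at the end of a line.
    for match in re.finditer(r"<execute>(.*?\))$", answer_text, re.DOTALL | re.MULTILINE):
        answer_text = answer_text.replace(match.group(0), run_code(match.group(1)))
    return answer_text

# substitute_results('The time is <execute>get_time()</execute>', lambda code: "12:00")
# -> 'The time is 12:00'
```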
14 changes: 8 additions & 6 deletions wafl/answerer/rule_maker.py
@@ -1,12 +1,12 @@
class RuleMaker:
def __init__(
self,
knowledge,
config,
interface,
max_num_rules,
delete_current_rule,
max_recursion=3,
knowledge: "Knowledge",
config: "BaseConfig",
interface: "BaseInterface",
max_num_rules: int,
delete_current_rule: str,
max_recursion: int = 3,
):
self._knowledge = knowledge
self._config = config
@@ -26,6 +26,8 @@ async def create_from_query(self, query):
rules_text = rule.get_string_using_template(
"- If {effect} go through the following points:"
)
rules_text += f'{rule.indent_str}- After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n'

rules_texts.append(rules_text)
await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")

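As an illustration of the line appended above, the generated rule text ends up looking roughly like this; the clause text and the delete-rule token below are made-up examples, not the values the bot actually uses:

```python
# Made-up values for illustration only.
indent_str = "    "
delete_current_rule = "<rule_completed>"

rules_text = "- If the user asks for the time go through the following points:\n"
rules_text += f'{indent_str}- output "The time is <execute>get_time()</execute>"\n'
rules_text += (
    f'{indent_str}- After you completed all the steps output '
    f'"{delete_current_rule}" and continue the conversation.\n'
)
print(rules_text)
```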
7 changes: 1 addition & 6 deletions wafl/connectors/base_llm_connector.py
@@ -1,7 +1,6 @@
import logging
import re


from wafl.connectors.utils import select_best_answer

_system_logger = logging.getLogger(__file__)
@@ -39,11 +38,7 @@ async def generate(self, prompt: str) -> str:

text = prompt
start = len(text)
while (
all(item not in text[start:] for item in self._last_strings)
and len(text) < start + self._max_reply_length
):
text += select_best_answer(await self.predict(text), self._last_strings)
text += select_best_answer(await self.predict(text), self._last_strings)

end_set = set()
for item in self._last_strings:
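With the loop removed above, generate() now calls predict() once and then trims the reply at the first stop string. A rough sketch of that trimming step, with illustrative names (the class's actual implementation may differ):

```python
def trim_at_stop_strings(text: str, start: int, last_strings) -> str:
    # Keep only the newly generated part, cut at the earliest stop string if present.
    end_positions = [
        text.find(stop, start) for stop in last_strings if text.find(stop, start) != -1
    ]
    if not end_positions:
        return text[start:]
    return text[start:min(end_positions)]
```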
26 changes: 21 additions & 5 deletions wafl/connectors/remote/remote_llm_connector.py
@@ -1,7 +1,10 @@
import json

import aiohttp
import asyncio

from wafl.connectors.base_llm_connector import BaseLLMConnector
from wafl.variables import is_supported


class RemoteLLMConnector(BaseLLMConnector):
@@ -14,6 +17,7 @@ def __init__(self, config, last_strings=None, num_replicas=3):
super().__init__(last_strings)
host = config["model_host"]
port = config["model_port"]
self._default_temperature = config["temperature"]
self._server_url = f"https://{host}:{port}/predictions/bot"
self._num_replicas = num_replicas

@@ -28,9 +32,11 @@ def __init__(self, config, last_strings=None, num_replicas=3):
):
raise RuntimeError("Cannot connect a running LLM.")

async def predict(self, prompt: str, temperature=None, num_tokens=None, num_replicas=None) -> [str]:
async def predict(
self, prompt: str, temperature=None, num_tokens=None, num_replicas=None
) -> [str]:
if not temperature:
temperature = 0.5
temperature = self._default_temperature

if not num_tokens:
num_tokens = self._num_prediction_tokens
@@ -52,8 +58,11 @@ async def predict(self, prompt: str, temperature=None, num_tokens=None, num_repl
connector=aiohttp.TCPConnector(ssl=False),
) as session:
async with session.post(self._server_url, json=payload) as response:
answer = await response.text()
return answer.split("<||>")
answer = json.loads(await response.text())
status = answer["status"]
if status != "success":
raise RuntimeError(f"Error in prediction: {answer}")
return answer["prediction"].split("<||>")

return [""]

@@ -70,7 +79,14 @@ async def check_connection(self):
conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
) as session:
async with session.post(self._server_url, json=payload) as response:
await response.text()
answer = json.loads(await response.text())
wafl_llm_version = answer["version"]
print(f"Connected to wafl-llm v{wafl_llm_version}.")
if not is_supported(wafl_llm_version):
print("This version of wafl-llm is not supported.")
print("Please update wafl-llm.")
raise aiohttp.client.InvalidURL

return True

except aiohttp.client.InvalidURL:
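The server response is now expected to be JSON carrying at least "status", "prediction", and "version" fields, as read above. A self-contained sketch of one prediction call under that contract; the payload shape in the usage comment is an assumption and may not match the real request body:

```python
import asyncio
import json

import aiohttp

async def predict_once(server_url: str, payload: dict) -> list:
    # Sketch of the new response handling: fail loudly unless status == "success",
    # then split the prediction string on the "<||>" separator.
    async with aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=False)
    ) as session:
        async with session.post(server_url, json=payload) as response:
            answer = json.loads(await response.text())
            if answer["status"] != "success":
                raise RuntimeError(f"Error in prediction: {answer}")
            return answer["prediction"].split("<||>")

# Assumed usage:
# asyncio.run(predict_once("https://localhost:8080/predictions/bot", {"data": "hello"}))
```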
4 changes: 3 additions & 1 deletion wafl/connectors/remote/remote_whisper_connector.py
@@ -39,7 +39,9 @@ async def predict(self, waveform, hotword=None) -> Dict[str, float]:
data = await response.text()
prediction = json.loads(data)
if "transcription" not in prediction:
raise RuntimeError("No transcription found in prediction. Is your microphone working?")
raise RuntimeError(
"No transcription found in prediction. Is your microphone working?"
)

transcription = prediction["transcription"]
score = prediction["score"]
4 changes: 3 additions & 1 deletion wafl/interface/base_interface.py
@@ -70,5 +70,7 @@ def _decorate_reply(self, text: str) -> str:
return self._decorator.extract(text, self._utterances)

def _insert_utterance(self, speaker, text: str):
if self._utterances == [] or text != self._utterances[-1][1].replace(f"{speaker}: ", ""):
if self._utterances == [] or text != self._utterances[-1][1].replace(
f"{speaker}: ", ""
):
self._utterances.append((time.time(), f"{speaker}: {text}"))
9 changes: 4 additions & 5 deletions wafl/interface/list_interface.py
@@ -13,20 +13,20 @@ def __init__(self, interfaces_list: List[BaseInterface]):
async def output(self, text: str, silent: bool = False):
await asyncio.wait(
[interface.output(text, silent) for interface in self._interfaces_list],
return_when=asyncio.ALL_COMPLETED
return_when=asyncio.ALL_COMPLETED,
)

async def input(self) -> str:
done, pending = await asyncio.wait(
[interface.input() for interface in self._interfaces_list],
return_when=asyncio.FIRST_COMPLETED
return_when=asyncio.FIRST_COMPLETED,
)
return done.pop().result()

async def insert_input(self, text: str):
await asyncio.wait(
[interface.insert_input(text) for interface in self._interfaces_list],
return_when=asyncio.ALL_COMPLETED
return_when=asyncio.ALL_COMPLETED,
)

def bot_has_spoken(self, to_set: bool = None):
@@ -44,11 +44,10 @@ def deactivate(self):
super().deactivate()
self._synchronize_interfaces()


def add_hotwords(self, hotwords):
for interface in self._interfaces_list:
interface.add_hotwords(hotwords)

def _synchronize_interfaces(self):
for interface in self._interfaces_list:
interface._utterances = self._utterances
interface._utterances = self._utterances
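The input() method above races all interfaces and returns the first result. A standalone sketch of that pattern; the coroutines are wrapped in tasks, which newer Python versions require for asyncio.wait:

```python
import asyncio

async def first_input(interfaces) -> str:
    # Whichever interface produces input first wins; the rest are cancelled.
    tasks = [asyncio.create_task(interface.input()) for interface in interfaces]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    return done.pop().result()
```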
15 changes: 7 additions & 8 deletions wafl/rules.py
@@ -7,7 +7,7 @@ class Rule:
effect: "Fact"
causes: List["Fact"]
_max_indentation = 3
_indent_str = " "
indent_str = " "

def toJSON(self):
return str(self)
@@ -33,22 +33,21 @@ def _add_clauses(self, rule_str: str) -> str:

return rule_str


def _recursively_add_clauses(self, query: str, depth: int=1) -> str:
indentation = self._indent_str * depth
def _recursively_add_clauses(self, query: str, depth: int = 1) -> str:
indentation = self.indent_str * depth
if type(query) == str:
return f"{indentation}-{query}\n"
return f"{indentation}- {query}\n"

if type(query.text) == str:
return f"{indentation}-{query.text}\n"
return f"{indentation}- {query.text}\n"

if depth > self._max_indentation:
return ""

clause = list(query.text.keys())[0]
rules_text = f"{indentation}-{clause}\n"
rules_text = f"{indentation}- {clause}\n"
for clauses in query.text.values():
for cause_index, clause in enumerate(clauses):
rules_text += self._recursively_add_clauses(clause, depth + 1)

return rules_text
return rules_text
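The net effect of renaming _indent_str to indent_str and adding the space after the dash is that each clause is rendered as "- <clause>" per indentation level, which is what the updated expectation in tests/test_rules.py checks. A tiny illustration; the indent width here is only an example:

```python
indent_str = "    "  # example width; the class defines its own indent_str

def render_clause(text: str, depth: int = 1) -> str:
    return f"{indent_str * depth}- {text}\n"

print(render_clause('output "The time is <execute>get_time()</execute>"'), end="")
```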
2 changes: 1 addition & 1 deletion wafl/runners/routes.py
@@ -77,4 +77,4 @@ def add_new_rules(app: Flask, conversation_id: int, web_server_loop: "WebLoop"):
f"toggle_logs_{conversation_id}",
web_server_loop.toggle_logs,
methods=["POST"],
)
)
3 changes: 1 addition & 2 deletions wafl/runners/run_web_and_audio_interface.py
@@ -1,4 +1,3 @@
import asyncio
import random
import sys
import threading
@@ -56,7 +55,7 @@ def create_scheduler_and_webserver_loop(conversation_id):
"web_server_loop": web_loop,
}

app.run(host="0.0.0.0", port=8889)
app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port"))


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion wafl/runners/run_web_interface.py
@@ -57,7 +57,7 @@ def create_scheduler_and_webserver_loop(conversation_id):
"web_server_loop": web_loop,
}

app.run(host="0.0.0.0", port=8889)
app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port"))


if __name__ == "__main__":
4 changes: 3 additions & 1 deletion wafl/templates/config.json
@@ -4,9 +4,11 @@
"deactivate_sound": true,
"rules": "rules.yaml",
"functions": "functions.py",
"frontend_port": 8090,
"llm_model": {
"model_host": "localhost",
"model_port": 8080
"model_port": 8080,
"temperature": 0.4
},
"listener_model": {
"model_host": "localhost",
10 changes: 9 additions & 1 deletion wafl/templates/rules.yaml
@@ -45,4 +45,12 @@ rules:

- the user asks something about cities, capitals, countries, buildings, famous people, bars, restaurants, rivers, mountains, lakes, seas, oceans, planets, stars, galaxies:
- say that you are just improvising the answer
- say what you think answer the question
- say what you think answer the question

- the user thanks the bot:
- The intention of the user is to close the conversation
- You must answer the user by writing "<execute>close_conversation()</execute>"

- the user says "end this conversation":
- The intention of the user is to close the conversation
- You must answer the user by writing "<execute>close_conversation()</execute>"
7 changes: 6 additions & 1 deletion wafl/variables.py
@@ -1,4 +1,9 @@
def get_variables():
return {
"version": "0.0.82",
"version": "0.0.83",
}


def is_supported(wafl_llm_version):
supported_versions = ["0.0.82", "0.0.83"]
return wafl_llm_version in supported_versions
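A short usage sketch: check_connection() in remote_llm_connector.py (above) reads the "version" field reported by the server and rejects the connection unless is_supported() returns True.

```python
from wafl.variables import get_variables, is_supported

print(get_variables()["version"])  # "0.0.83"
print(is_supported("0.0.82"))      # True
print(is_supported("0.0.50"))      # False
```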