
Merge pull request #92 from fractalego/prompt-as-list
Prompt as list
fractalego committed May 11, 2024
2 parents 13e3d73 + 556ac14 commit 97a98ac
Showing 17 changed files with 110 additions and 46 deletions.
7 changes: 6 additions & 1 deletion documentation/source/configuration.rst
@@ -13,9 +13,11 @@ A typical configuration file looks like this:
   "deactivate_sound": true,
   "rules": "rules.yaml",
   "functions": "functions.py",
+  "frontend_port": 8081,
   "llm_model": {
     "model_host": "localhost",
-    "model_port": 8080
+    "model_port": 8080,
+    "temperature": 0.4
   },
   "listener_model": {
     "model_host": "localhost",
@@ -45,7 +47,10 @@ These settings regulate the following:
 
 * "functions" is the file containing the functions that can be used in the rules. The default is "functions.py".
 
+* "frontend_port" is the port where the web frontend is running. The default is 8090.
+
 * "llm_model" is the configuration to connect to the LLM model in the backend. The default is "localhost:8080".
+  The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4.
 
 * "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080".
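As a quick check that the two new keys are wired through, a minimal sketch using only the standard library (the file name, key names, and defaults mirror the documentation above and the template in `wafl/templates/config.json` below):

```python
import json

# Minimal sketch: read the two new settings from a local config file.
with open("config.json") as f:
    config = json.load(f)

frontend_port = config.get("frontend_port", 8090)          # documented default: 8090
temperature = config["llm_model"].get("temperature", 0.4)  # documented default: 0.4
print(frontend_port, temperature)
```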
7 changes: 3 additions & 4 deletions tests/test_rules.py
@@ -88,9 +88,8 @@ def test__nested_rules_are_printed_correctly(self):
         rule = facts_and_rules["rules"][0]
         expected = """
 the user wants to know the time
--output "The time is <execute>get_time()</execute>"
--if the time is before 12:00 say "Good morning"
--if the time is after 12:00 say "Good afternoon"
+- output "The time is <execute>get_time()</execute>"
+- if the time is before 12:00 say "Good morning"
+- if the time is after 12:00 say "Good afternoon"
 """.strip()
         self.assertEqual(expected, str(rule).strip())

14 changes: 14 additions & 0 deletions todo.txt
@@ -1,3 +1,17 @@
 * add config file for model names
     - llm model name
     - whisper model name
+
+/* add version name
+/* let user decide port for frontend
+/* update docs about port
+/* push new version
+* update pypi with wafl and wafl-llm
+* clean code for llm eval and make it public
+* update huggingface readme
+* read overleaf paper
+
+
+* on wafl_llm make it so only some LLMs are supported
+* change speaker model with newer one
+
21 changes: 19 additions & 2 deletions wafl/answerer/dialogue_answerer.py
@@ -160,9 +160,19 @@ def _init_python_module(self, module_name):
         self._functions = [item[0] for item in getmembers(self._module, isfunction)]
 
     async def _substitute_results_in_answer(self, answer_text):
-        matches = re.finditer(r"<execute>(.*?)</execute>", answer_text, re.DOTALL)
+        matches = re.finditer(r"<execute>(.*?)</execute>|<execute>(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
         for match in matches:
             to_execute = match.group(1)
+            if not to_execute:
+                continue
             result = await self._run_code(to_execute)
             answer_text = answer_text.replace(match.group(0), result)
+
+        matches = re.finditer(r"<execute>(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
+        for match in matches:
+            to_execute = match.group(1)
+            if not to_execute:
+                continue
+            result = await self._run_code(to_execute)
+            answer_text = answer_text.replace(match.group(0), result)
 
@@ -178,6 +188,13 @@ async def _substitute_memory_in_answer_and_get_memories_if_present(
             answer_text = answer_text.replace(match.group(0), "")
             memories.append(to_execute)
 
+        matches = re.finditer(r"<remember>(.*?\))$", answer_text, re.DOTALL|re.MULTILINE)
+        memories = []
+        for match in matches:
+            to_execute = match.group(1)
+            answer_text = answer_text.replace(match.group(0), "")
+            memories.append(to_execute)
+
         return answer_text, memories
 
     async def _run_code(self, to_execute):
@@ -206,7 +223,7 @@ async def _run_code(self, to_execute):
 
         except Exception as e:
             result = (
-                f'Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}'
+                f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}"
             )
             traceback.print_exc()
             break
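The second pattern exists for replies where generation stopped before the closing tag was emitted. A simplified, synchronous sketch of the two-pass substitution (`run_code` is a stub standing in for the async `_run_code`, and the first pass drops the alternation used above, whose non-matching branch is what the `if not to_execute: continue` guard skips):

```python
import re

def run_code(expression: str) -> str:
    # Stub for the real async _run_code: just echo the call.
    return f"<result of {expression}>"

def substitute_results(answer_text: str) -> str:
    # First pass: well-formed <execute>...</execute> tags.
    for match in re.finditer(r"<execute>(.*?)</execute>", answer_text, re.DOTALL):
        answer_text = answer_text.replace(match.group(0), run_code(match.group(1)))

    # Second pass: a trailing "<execute>...)" with no closing tag, as happens
    # when the model stops early; "$" needs re.MULTILINE to anchor at line ends.
    for match in re.finditer(r"<execute>(.*?\))$", answer_text, re.DOTALL | re.MULTILINE):
        answer_text = answer_text.replace(match.group(0), run_code(match.group(1)))

    return answer_text

print(substitute_results("The time is <execute>get_time()</execute>."))
print(substitute_results("The time is <execute>get_time()"))
```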
14 changes: 8 additions & 6 deletions wafl/answerer/rule_maker.py
@@ -1,12 +1,12 @@
 class RuleMaker:
     def __init__(
         self,
-        knowledge,
-        config,
-        interface,
-        max_num_rules,
-        delete_current_rule,
-        max_recursion=3,
+        knowledge: "Knowledge",
+        config: "BaseConfig",
+        interface: "BaseInterface",
+        max_num_rules: int,
+        delete_current_rule: str,
+        max_recursion: int = 3,
     ):
         self._knowledge = knowledge
         self._config = config
@@ -26,6 +26,8 @@ async def create_from_query(self, query):
             rules_text = rule.get_string_using_template(
                 "- If {effect} go through the following points:"
             )
+            rules_text += f'{rule.indent_str}- After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n'
+
             rules_texts.append(rules_text)
             await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")
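The appended sentence asks the model to emit a deletion marker once every step of a rule has run. A sketch of the composed prompt text, with an illustrative rule body and marker value (neither is taken from this codebase):

```python
indent_str = " "                      # mirrors Rule.indent_str in wafl/rules.py below
delete_current_rule = "_delete_rule"  # illustrative placeholder value

rules_text = (
    "- If the user wants to know the time go through the following points:\n"
    f'{indent_str}- output "The time is <execute>get_time()</execute>"\n'
    f'{indent_str}- After you completed all the steps output '
    f'"{delete_current_rule}" and continue the conversation.\n'
)
print(rules_text)
```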
7 changes: 1 addition & 6 deletions wafl/connectors/base_llm_connector.py
@@ -1,7 +1,6 @@
 import logging
 import re
 
-
 from wafl.connectors.utils import select_best_answer
 
 _system_logger = logging.getLogger(__file__)
@@ -39,11 +38,7 @@ async def generate(self, prompt: str) -> str:
 
         text = prompt
         start = len(text)
-        while (
-            all(item not in text[start:] for item in self._last_strings)
-            and len(text) < start + self._max_reply_length
-        ):
-            text += select_best_answer(await self.predict(text), self._last_strings)
+        text += select_best_answer(await self.predict(text), self._last_strings)
 
         end_set = set()
         for item in self._last_strings:
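`generate` now calls `predict` once instead of looping until a stop string appears or `_max_reply_length` is reached. The implementation of `select_best_answer` (from `wafl.connectors.utils`) is not part of this diff; purely as an illustration of its role, a plausible sketch:

```python
def select_best_answer(answers, last_strings):
    # Illustrative sketch only: prefer a candidate that already contains one
    # of the expected stop strings, so the caller can truncate it cleanly.
    for answer in answers:
        if any(stop in answer for stop in last_strings):
            return answer
    return answers[0]

print(select_best_answer(["partial reply", "full reply\nuser:"], ["\nuser:"]))
```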
26 changes: 21 additions & 5 deletions wafl/connectors/remote/remote_llm_connector.py
@@ -1,7 +1,10 @@
+import json
+
 import aiohttp
 import asyncio
 
 from wafl.connectors.base_llm_connector import BaseLLMConnector
+from wafl.variables import is_supported
 
 
 class RemoteLLMConnector(BaseLLMConnector):
@@ -14,6 +17,7 @@ def __init__(self, config, last_strings=None, num_replicas=3):
         super().__init__(last_strings)
         host = config["model_host"]
         port = config["model_port"]
+        self._default_temperature = config["temperature"]
         self._server_url = f"https://{host}:{port}/predictions/bot"
         self._num_replicas = num_replicas
@@ -28,9 +32,11 @@ def __init__(self, config, last_strings=None, num_replicas=3):
     ):
         raise RuntimeError("Cannot connect a running LLM.")
 
-    async def predict(self, prompt: str, temperature=None, num_tokens=None, num_replicas=None) -> [str]:
+    async def predict(
+        self, prompt: str, temperature=None, num_tokens=None, num_replicas=None
+    ) -> [str]:
         if not temperature:
-            temperature = 0.5
+            temperature = self._default_temperature
 
         if not num_tokens:
             num_tokens = self._num_prediction_tokens
@@ -52,8 +58,11 @@ async def predict(self, prompt: str, temperature=None, num_tokens=None, num_replicas=None) -> [str]:
                 connector=aiohttp.TCPConnector(ssl=False),
             ) as session:
                 async with session.post(self._server_url, json=payload) as response:
-                    answer = await response.text()
-                    return answer.split("<||>")
+                    answer = json.loads(await response.text())
+                    status = answer["status"]
+                    if status != "success":
+                        raise RuntimeError(f"Error in prediction: {answer}")
+                    return answer["prediction"].split("<||>")
 
         return [""]

@@ -70,7 +79,14 @@ async def check_connection(self):
                 conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
             ) as session:
                 async with session.post(self._server_url, json=payload) as response:
-                    await response.text()
+                    answer = json.loads(await response.text())
+                    wafl_llm_version = answer["version"]
+                    print(f"Connected to wafl-llm v{wafl_llm_version}.")
+                    if not is_supported(wafl_llm_version):
+                        print("This version of wafl-llm is not supported.")
+                        print("Please update wafl-llm.")
+                        raise aiohttp.client.InvalidURL
+
                     return True
 
         except aiohttp.client.InvalidURL:
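The server is now expected to return JSON rather than bare text. A minimal sketch of the new response handling, with a canned payload standing in for the HTTP round-trip (the `status`, `prediction`, and `version` keys and the `<||>` separator come from the diff above):

```python
import json

def parse_prediction(raw: str) -> list:
    # Mirrors the new predict() handling: fail loudly unless the server
    # reports success, then split the candidate answers on "<||>".
    answer = json.loads(raw)
    if answer["status"] != "success":
        raise RuntimeError(f"Error in prediction: {answer}")
    return answer["prediction"].split("<||>")

raw = json.dumps({"status": "success", "prediction": "Good morning<||>Good day"})
print(parse_prediction(raw))  # ['Good morning', 'Good day']
```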
4 changes: 3 additions & 1 deletion wafl/connectors/remote/remote_whisper_connector.py
@@ -39,7 +39,9 @@ async def predict(self, waveform, hotword=None) -> Dict[str, float]:
                 data = await response.text()
                 prediction = json.loads(data)
                 if "transcription" not in prediction:
-                    raise RuntimeError("No transcription found in prediction. Is your microphone working?")
+                    raise RuntimeError(
+                        "No transcription found in prediction. Is your microphone working?"
+                    )
 
                 transcription = prediction["transcription"]
                 score = prediction["score"]
4 changes: 3 additions & 1 deletion wafl/interface/base_interface.py
@@ -70,5 +70,7 @@ def _decorate_reply(self, text: str) -> str:
         return self._decorator.extract(text, self._utterances)
 
     def _insert_utterance(self, speaker, text: str):
-        if self._utterances == [] or text != self._utterances[-1][1].replace(f"{speaker}: ", ""):
+        if self._utterances == [] or text != self._utterances[-1][1].replace(
+            f"{speaker}: ", ""
+        ):
             self._utterances.append((time.time(), f"{speaker}: {text}"))
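The reflowed condition is a consecutive-duplicate guard. A standalone sketch of the same check with a plain list (names simplified from the class above):

```python
import time

utterances = []

def insert_utterance(speaker: str, text: str):
    # Skip consecutive duplicates: compare against the last stored line with
    # its "speaker: " prefix stripped, as in BaseInterface._insert_utterance.
    if utterances == [] or text != utterances[-1][1].replace(f"{speaker}: ", ""):
        utterances.append((time.time(), f"{speaker}: {text}"))

insert_utterance("bot", "hello")
insert_utterance("bot", "hello")  # ignored: repeats the previous utterance
print(len(utterances))  # 1
```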
9 changes: 4 additions & 5 deletions wafl/interface/list_interface.py
@@ -13,20 +13,20 @@ def __init__(self, interfaces_list: List[BaseInterface]):
     async def output(self, text: str, silent: bool = False):
         await asyncio.wait(
             [interface.output(text, silent) for interface in self._interfaces_list],
-            return_when=asyncio.ALL_COMPLETED
+            return_when=asyncio.ALL_COMPLETED,
         )
 
     async def input(self) -> str:
         done, pending = await asyncio.wait(
             [interface.input() for interface in self._interfaces_list],
-            return_when=asyncio.FIRST_COMPLETED
+            return_when=asyncio.FIRST_COMPLETED,
         )
         return done.pop().result()
 
     async def insert_input(self, text: str):
         await asyncio.wait(
             [interface.insert_input(text) for interface in self._interfaces_list],
-            return_when=asyncio.ALL_COMPLETED
+            return_when=asyncio.ALL_COMPLETED,
         )
 
     def bot_has_spoken(self, to_set: bool = None):
@@ -44,11 +44,10 @@ def deactivate(self):
         super().deactivate()
         self._synchronize_interfaces()
 
-
     def add_hotwords(self, hotwords):
         for interface in self._interfaces_list:
             interface.add_hotwords(hotwords)
 
     def _synchronize_interfaces(self):
         for interface in self._interfaces_list:
-            interface._utterances = self._utterances
+            interface._utterances = self._utterances
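A runnable sketch of the fan-out pattern used by `output` above, with a stand-in interface class (note that recent Python versions require `asyncio.wait` to receive tasks rather than bare coroutines, so the sketch wraps each call):

```python
import asyncio

class EchoInterface:
    # Stand-in for a BaseInterface implementation (illustrative only).
    def __init__(self, name: str):
        self._name = name

    async def output(self, text: str, silent: bool = False):
        if not silent:
            print(f"[{self._name}] {text}")

async def broadcast(interfaces, text):
    # Send the same text to every interface and wait for all of them,
    # mirroring ListInterface.output above.
    await asyncio.wait(
        [asyncio.ensure_future(i.output(text)) for i in interfaces],
        return_when=asyncio.ALL_COMPLETED,
    )

asyncio.run(broadcast([EchoInterface("web"), EchoInterface("audio")], "hello"))
```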
15 changes: 7 additions & 8 deletions wafl/rules.py
@@ -7,7 +7,7 @@ class Rule:
     effect: "Fact"
     causes: List["Fact"]
     _max_indentation = 3
-    _indent_str = " "
+    indent_str = " "
 
     def toJSON(self):
         return str(self)
@@ -33,22 +33,21 @@ def _add_clauses(self, rule_str: str) -> str:
 
         return rule_str
 
-
-    def _recursively_add_clauses(self, query: str, depth: int=1) -> str:
-        indentation = self._indent_str * depth
+    def _recursively_add_clauses(self, query: str, depth: int = 1) -> str:
+        indentation = self.indent_str * depth
         if type(query) == str:
-            return f"{indentation}-{query}\n"
+            return f"{indentation}- {query}\n"
 
         if type(query.text) == str:
-            return f"{indentation}-{query.text}\n"
+            return f"{indentation}- {query.text}\n"
 
         if depth > self._max_indentation:
             return ""
 
         clause = list(query.text.keys())[0]
-        rules_text = f"{indentation}-{clause}\n"
+        rules_text = f"{indentation}- {clause}\n"
         for clauses in query.text.values():
             for cause_index, clause in enumerate(clauses):
                 rules_text += self._recursively_add_clauses(clause, depth + 1)
 
-        return rules_text
+        return rules_text
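The visible change is the space after the dash: clauses now render as `- clause` instead of `-clause`, matching the updated expectations in `tests/test_rules.py` above. A self-contained sketch of the renderer, with a plain dict standing in for the `query.text` structure (simplified; the depth cap and indent mirror the values in the diff):

```python
indent_str = " "      # as in Rule.indent_str
max_indentation = 3   # as in Rule._max_indentation

def render(query, depth: int = 1) -> str:
    # Strings are leaf clauses; dicts map a clause to its sub-clauses.
    indentation = indent_str * depth
    if isinstance(query, str):
        return f"{indentation}- {query}\n"
    if depth > max_indentation:
        return ""
    clause = list(query.keys())[0]
    text = f"{indentation}- {clause}\n"
    for clauses in query.values():
        for item in clauses:
            text += render(item, depth + 1)
    return text

print(render({"the time is before 12:00": ['say "Good morning"']}), end="")
```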
2 changes: 1 addition & 1 deletion wafl/runners/routes.py
@@ -77,4 +77,4 @@ def add_new_rules(app: Flask, conversation_id: int, web_server_loop: "WebLoop"):
         f"toggle_logs_{conversation_id}",
         web_server_loop.toggle_logs,
         methods=["POST"],
-    )
+    )
3 changes: 1 addition & 2 deletions wafl/runners/run_web_and_audio_interface.py
@@ -1,4 +1,3 @@
-import asyncio
 import random
 import sys
 import threading
@@ -56,7 +55,7 @@ def create_scheduler_and_webserver_loop(conversation_id):
         "web_server_loop": web_loop,
     }
 
-    app.run(host="0.0.0.0", port=8889)
+    app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port"))
 
 
 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion wafl/runners/run_web_interface.py
@@ -57,7 +57,7 @@ def create_scheduler_and_webserver_loop(conversation_id):
         "web_server_loop": web_loop,
     }
 
-    app.run(host="0.0.0.0", port=8889)
+    app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port"))
 
 
 if __name__ == "__main__":
4 changes: 3 additions & 1 deletion wafl/templates/config.json
@@ -4,9 +4,11 @@
   "deactivate_sound": true,
   "rules": "rules.yaml",
   "functions": "functions.py",
+  "frontend_port": 8090,
   "llm_model": {
     "model_host": "localhost",
-    "model_port": 8080
+    "model_port": 8080,
+    "temperature": 0.4
   },
   "listener_model": {
     "model_host": "localhost",
10 changes: 9 additions & 1 deletion wafl/templates/rules.yaml
@@ -45,4 +45,12 @@ rules:
 
   - the user asks something about cities, capitals, countries, buildings, famous people, bars, restaurants, rivers, mountains, lakes, seas, oceans, planets, stars, galaxies:
     - say that you are just improvising the answer
-    - say what you think answer the question
+    - say what you think answer the question
+
+  - the user thanks the bot:
+    - The intention of the user is to close the conversation
+    - You must answer the user by writing "<execute>close_conversation()</execute>"
+
+  - the user says "end this conversation":
+    - The intention of the user is to close the conversation
+    - You must answer the user by writing "<execute>close_conversation()</execute>"
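Both new rules make the model emit `<execute>close_conversation()</execute>` verbatim; the reply then passes through `_substitute_results_in_answer` in `wafl/answerer/dialogue_answerer.py` above, whose new fallback pattern also catches the tag when the closing `</execute>` is cut off.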
7 changes: 6 additions & 1 deletion wafl/variables.py
@@ -1,4 +1,9 @@
 def get_variables():
     return {
-        "version": "0.0.82",
+        "version": "0.0.83",
     }
+
+
+def is_supported(wafl_llm_version):
+    supported_versions = ["0.0.82", "0.0.83"]
+    return wafl_llm_version in supported_versions
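`check_connection` in `remote_llm_connector.py` above feeds the server-reported version to this helper; for example:

```python
from wafl.variables import is_supported

print(is_supported("0.0.83"))  # True: listed as supported
print(is_supported("0.0.81"))  # False: triggers the "please update" path
```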
