Skip to content

Commit

Permalink
Merge pull request #78 from fractalego/llm-webserving
Browse files Browse the repository at this point in the history
Llm webserving
  • Loading branch information
fractalego committed Dec 21, 2023
2 parents 3c3f80c + 19f9b95 commit 973788c
Show file tree
Hide file tree
Showing 23 changed files with 143 additions and 145 deletions.
4 changes: 2 additions & 2 deletions tests/test_closing_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from wafl.events.conversation_events import ConversationEvents
from wafl.exceptions import CloseConversation
from wafl.interface.dummy_interface import DummyInterface
from wafl.knowledge.single_file_knowledge import SingleFileKnowledge

wafl_example = """
rules:
Expand All @@ -24,8 +23,9 @@ def test__thank_you_closes_conversation(self):
]
)
config = Configuration.load_local_config()
config.set_value("rules", wafl_example)
conversation_events = ConversationEvents(
SingleFileKnowledge(config, wafl_example),
config=config,
interface=interface,
)
try:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ def test__listener_accepts_silence_timeout(self):
config = Configuration.load_local_config()
interface = VoiceInterface(config)
self.assertEqual(
interface._listener._timeout, config.get_value("listener_model")["listener_silence_timeout"]
interface._listener._timeout,
config.get_value("listener_model")["listener_silence_timeout"],
)
4 changes: 2 additions & 2 deletions tests/test_facts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.dummy_interface import DummyInterface
from wafl.knowledge.single_file_knowledge import SingleFileKnowledge

wafl_example = """
facts:
Expand All @@ -21,8 +20,9 @@ def test__facts_are_retrieved(self):
]
)
config = Configuration.load_local_config()
config.set_value("rules", wafl_example)
conversation_events = ConversationEvents(
SingleFileKnowledge(config, wafl_example),
config=config,
interface=interface,
)
asyncio.run(conversation_events.process_next())
Expand Down
9 changes: 4 additions & 5 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.exceptions import CloseConversation
from wafl.interface.dummy_interface import DummyInterface
from wafl.knowledge.single_file_knowledge import SingleFileKnowledge

wafl_example = """
rules:
Expand All @@ -26,8 +24,9 @@ def test__rules_can_be_triggered(self):
]
)
config = Configuration.load_local_config()
config.set_value("rules", wafl_example)
conversation_events = ConversationEvents(
SingleFileKnowledge(config, wafl_example),
config=config,
interface=interface,
)
asyncio.run(conversation_events.process_next())
Expand All @@ -41,11 +40,11 @@ def test__rules_are_not_always_triggered(self):
]
)
config = Configuration.load_local_config()
config.set_value("rules", wafl_example)
conversation_events = ConversationEvents(
SingleFileKnowledge(config, wafl_example),
config=config,
interface=interface,
)
asyncio.run(conversation_events.process_next())
print(interface.get_utterances_list())
unexpected = "bot: the horse is tall"
self.assertNotEqual(unexpected, interface.get_utterances_list()[-1])
19 changes: 8 additions & 11 deletions tests/test_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from wafl.interface.voice_interface import VoiceInterface
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.dummy_interface import DummyInterface
from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
from wafl.listener.whisper_listener import WhisperListener

_wafl_example = """
Expand All @@ -27,9 +26,8 @@ class TestVoice(TestCase):
def test__activation(self):
interface = DummyInterface(to_utter=["computer", "my name is Jane"])
config = Configuration.load_local_config()
conversation_events = ConversationEvents(
SingleFileKnowledge(config, _wafl_example), interface=interface
)
config.set_value("rules", _wafl_example)
conversation_events = ConversationEvents(config=config, interface=interface)
interface.activate()
asyncio.run(conversation_events.process_next(activation_word="computer"))
asyncio.run(conversation_events.process_next(activation_word="computer"))
Expand All @@ -38,19 +36,15 @@ def test__activation(self):
def test__no_activation(self):
interface = DummyInterface(to_utter=["my name is bob"])
config = Configuration.load_local_config()
conversation_events = ConversationEvents(
SingleFileKnowledge(config, _wafl_example), interface=interface
)
conversation_events = ConversationEvents(config=config, interface=interface)
interface.deactivate()
asyncio.run(conversation_events.process_next(activation_word="computer"))
assert len(interface.get_utterances_list()) == 1

def test__computer_name_is_removed_after_activation(self):
interface = DummyInterface(to_utter=["[computer] computer my name is bob"])
config = Configuration.load_local_config()
conversation_events = ConversationEvents(
SingleFileKnowledge(config, _wafl_example), interface=interface
)
conversation_events = ConversationEvents(config=config, interface=interface)
interface.deactivate()
asyncio.run(conversation_events.process_next(activation_word="computer"))
assert interface.get_utterances_list()[-1].count("computer") == 0
Expand All @@ -77,7 +71,10 @@ def test__random_sounds_are_excluded(self):
def test__voice_interface_receives_config(self):
config = Configuration.load_local_config()
interface = VoiceInterface(config)
assert interface.listener_model_name == config.get_value("listener_model")["local_model"]
assert (
interface.listener_model_name
== config.get_value("listener_model")["local_model"]
)

def test__hotword_listener_activated_using_recording_of_hotword(self):
f = wave.open(os.path.join(_path, "data/computer.wav"), "rb")
Expand Down
19 changes: 12 additions & 7 deletions todo.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
### TODO

* re-train the whisper model using the distilled version
* add system to find out threshold for activation word
* make rules reloadable
* nicer UI?
* New icons
* darker left bar
* update tests
/* re-train the whisper model using the distilled version
/* make rules reloadable
/* nicer UI?
/ * New icons
/ * darker left bar
/* update tests


* lots of duplicates in facts! Avoid that
* use timestamp for facts (or an index in terms of conversation item)
* select only most n recent timestamps
* do not add facts that are already in the list (before cluster_facts)
* update the docs

* redeploy locally and on the server
Expand Down
2 changes: 1 addition & 1 deletion wafl/answerer/base_answerer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class BaseAnswerer:
async def answer(self, query_text, policy):
async def answer(self, query_text: str) -> "Answer":
raise NotImplementedError
86 changes: 35 additions & 51 deletions wafl/answerer/dialogue_answerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def __init__(self, config, knowledge, interface, code_path, logger):
self._max_num_past_utterances = 5
self._max_num_past_utterances_for_facts = 5
self._max_num_past_utterances_for_rules = 0
self._prior_facts = []
self._prior_rules = []
self._prior_facts_with_timestamp = []
self._init_python_module(code_path.replace(".py", ""))
self._max_predictions = 3

Expand All @@ -37,7 +36,6 @@ async def answer(self, query_text):

query = Query.create_from_text(query_text)
rules_texts = await self._get_relevant_rules(query)
facts = await self._get_relevant_facts(query, has_prior_rules=bool(rules_texts))

dialogue = self._interface.get_utterances_list_with_timestamp()[
-self._max_num_past_utterances :
Expand All @@ -54,6 +52,12 @@ async def answer(self, query_text):
last_bot_utterances = get_last_bot_utterances(dialogue_items, num_utterances=3)
last_user_utterance = get_last_user_utterance(dialogue_items)
dialogue_items = [item[1] for item in dialogue_items if item[0] >= start_time]
conversational_timestamp = len(dialogue_items)
facts = await self._get_relevant_facts(
query,
has_prior_rules=bool(rules_texts),
conversational_timestamp=conversational_timestamp,
)
dialogue_items = "\n".join(dialogue_items)

for _ in range(self._max_predictions):
Expand All @@ -74,54 +78,47 @@ async def answer(self, query_text):
continue

if not memories:
if "<execute>" in answer_text:
self._remove_last_rule()

break

facts += "\n" + "\n".join(memories)
self._prior_facts.append("\n".join(memories))
dialogue_items += f"\nbot: {original_answer_text}"

if self._logger:
self._logger.write(f"Answer within dialogue: The answer is {answer_text}")

return Answer.create_from_text(answer_text)

async def _get_relevant_facts(self, query, has_prior_rules):
async def _get_relevant_facts(
self, query, has_prior_rules, conversational_timestamp
):
memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
self._prior_facts_with_timestamp = [
item
for item in self._prior_facts_with_timestamp
if item[1]
> conversational_timestamp - self._max_num_past_utterances_for_facts
]
facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
query, is_from_user=True, knowledge_name="/", threshold=0.7
query, is_from_user=True, knowledge_name="/", threshold=0.8
)
texts = cluster_facts(facts_and_thresholds)
for text in texts[::-1]:
await self._interface.add_fact(f"The bot remembers: {text}")

if texts:
self._prior_facts = self._prior_facts[
-self._max_num_past_utterances_for_facts :
]
self._prior_facts.append("\n".join(texts))
facts = "\n".join(self._prior_facts)
if facts_and_thresholds:
facts = [item[0].text for item in facts_and_thresholds if item[0].text not in memory]
self._prior_facts_with_timestamp.extend(
(item, conversational_timestamp) for item in facts
)
memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])

else:
self._prior_facts = self._prior_facts[-self._max_num_past_utterances :]
if is_question(query.text) and not has_prior_rules:
to_add = [
f"The answer to {query.text} is not in the knowledge base."
memory += (
f"\nThe answer to {query.text} is not in the knowledge base."
"The bot can answer the question while informing the user that the answer was not retrieved"
]
)

elif has_prior_rules:
to_add = [
f"The bot tries to answer {query.text} following the rules from the user."
]
if has_prior_rules:
memory += f"\nThe bot tries to answer {query.text} following the rules from the user."

else:
to_add = []

facts = "\n".join(self._prior_facts + to_add)

return facts
return memory

async def _get_relevant_rules(self, query, max_num_rules=1):
rules = await self._knowledge.ask_for_rule_backward(
Expand All @@ -135,19 +132,10 @@ async def _get_relevant_rules(self, query, max_num_rules=1):
for cause_index, causes in enumerate(rule.causes):
rules_text += f" {cause_index + 1}) {causes.text}\n"

if any(rules_text in item for item in self._prior_rules):
continue

rules_texts.append(rules_text)
await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")

if rules_texts:
self._prior_rules.append("".join(rules_texts))

self._prior_rules = self._prior_rules[
-self._max_num_past_utterances_for_rules :
]
return "".join(self._prior_rules)
return "\n".join(rules_texts)

def _init_python_module(self, module_name):
self._module = import_module(module_name)
Expand Down Expand Up @@ -190,7 +178,7 @@ async def _run_code(self, to_execute):
break

except NameError as e:
match = re.search(r'\'(\w+)\' is not defined', str(e))
match = re.search(r"\'(\w+)\' is not defined", str(e))
if match:
to_import = match.group(1)
to_execute = f"import {to_import}\n{to_execute}"
Expand All @@ -199,17 +187,13 @@ async def _run_code(self, to_execute):
raise e

except Exception as e:
result = f'Error while executing\n\n"""python\n{to_execute}\n"""\n\n{str(e)}'
result = (
f'Error while executing\n\n"""python\n{to_execute}\n"""\n\n{str(e)}'
)
traceback.print_exc()
break

if not result:
result = f'\n"""python\n{to_execute}\n"""'

return result

def _remove_last_rule(self):
"""
remove the last rule from memory if it was executed during the dialogue
"""
self._prior_rules = self._prior_rules[:-1]
2 changes: 1 addition & 1 deletion wafl/connectors/local/local_entailment_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ async def predict(self, premise: str, hypothesis: str) -> Dict[str, float]:
prediction = torch.softmax(output["logits"], -1)[0]
label_names = ["entailment", "neutral", "contradiction"]
answer = {name: float(pred) for pred, name in zip(prediction, label_names)}
return answer
return answer
6 changes: 5 additions & 1 deletion wafl/connectors/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
def select_best_answer(answers, last_strings):
special_words = last_strings + ["</remember>", "</execute>", "result ="] + ["<execute>", "<remember>", "<execute>", "<remember>"]
special_words = (
last_strings
+ ["</remember>", "</execute>", "result ="]
+ ["<execute>", "<remember>", "<execute>", "<remember>"]
)
return sorted(
answers, key=lambda x: sum([x.count(word) for word in special_words])
)[-1]
2 changes: 1 addition & 1 deletion wafl/entailment/entailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,4 @@ async def is_neutral(
def _add_presuppositions_to_premise(self, premise):
premise = premise.replace("user says:", "user says to this bot:")
premise = premise.replace("user asks:", "user asks to this bot:")
return premise
return premise
16 changes: 8 additions & 8 deletions wafl/events/conversation_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from wafl.events.answerer_creator import create_answerer
from wafl.simple_text_processing.normalize import normalized
from wafl.config import Configuration
from wafl.events.utils import input_is_valid, remove_text_between_brackets
from wafl.events.utils import input_is_valid, load_knowledge
from wafl.simple_text_processing.questions import is_question
from wafl.exceptions import InterruptTask

Expand All @@ -14,16 +14,13 @@
class ConversationEvents:
def __init__(
self,
knowledge: "BaseKnowledge",
config: "Configuration",
interface: "BaseInterface",
config=None,
logger=None,
):
if not config:
config = Configuration.load_local_config()

self._answerer = create_answerer(config, knowledge, interface, logger)
self._knowledge = knowledge
self._config = config
self._knowledge = load_knowledge(config, logger)
self._answerer = create_answerer(config, self._knowledge, interface, logger)
self._interface = interface
self._logger = logger
self._is_computing = False
Expand Down Expand Up @@ -105,6 +102,9 @@ async def process_next(self, activation_word: str = "") -> bool:
def is_computing(self):
return self._is_computing

def reload_knowledge(self):
self._knowledge = load_knowledge(self._config, self._logger)

def _activation_word_in_text(self, activation_word, text):
if f"[{normalized(activation_word)}]" in normalized(text):
return True
Expand Down
Loading

0 comments on commit 973788c

Please sign in to comment.