diff --git a/todo.txt b/todo.txt index 49ae4284..74dca50c 100644 --- a/todo.txt +++ b/todo.txt @@ -1,3 +1,19 @@ +wafl: +- create indices + +training: +- retrain phi3 +- add tokens and to the training data +- add some prior conversation to the training data, taken from other examples +- add more unused rules in the prompt + + + +* after a rule is deleted, you should also prune the conversation above. + The system can get confused if the conversation becomes too long + re-train the system with prior conversations before calling the rule + + * substitute utterances in base_interface with the conversation class * add config file for model names diff --git a/wafl/answerer/answerer_implementation.py b/wafl/answerer/answerer_implementation.py new file mode 100644 index 00000000..3a7b99a0 --- /dev/null +++ b/wafl/answerer/answerer_implementation.py @@ -0,0 +1,140 @@ +import re +import traceback + +from typing import List, Tuple + +from wafl.exceptions import CloseConversation +from wafl.facts import Fact +from wafl.interface.conversation import Conversation, Utterance + + +def is_executable(text: str) -> bool: + return "" in text + + +def create_one_liner(query_text): + return Conversation( + [ + Utterance( + query_text, + "user", + ) + ] + ) + + +async def substitute_memory_in_answer_and_get_memories_if_present( + answer_text: str, +) -> Tuple[str, List[str]]: + matches = re.finditer( + r"(.*?)|(.*?)$", + answer_text, + re.DOTALL | re.MULTILINE, + ) + memories = [] + for match in matches: + to_substitute = match.group(1) + if not to_substitute: + continue + answer_text = answer_text.replace(match.group(0), "[Output in memory]") + memories.append(to_substitute) + + answer_text = answer_text.replace("
", "\n") + matches = re.finditer(r"(.*?)$", answer_text, re.DOTALL | re.MULTILINE) + memories = [] + for match in matches: + to_substitute = match.group(1) + if not to_substitute: + continue + answer_text = answer_text.replace(match.group(0), "[Output in memory]") + memories.append(to_substitute) + + return answer_text, memories + + +async def execute_results_in_answer(answer_text: str, module, functions) -> str: + matches = re.finditer( + r"(.*?)|(.*?\))$", + answer_text, + re.DOTALL | re.MULTILINE, + ) + for match in matches: + to_execute = match.group(1) + if not to_execute: + continue + result = await _run_code(to_execute, module, functions) + answer_text = answer_text.replace(match.group(0), result) + + matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL | re.MULTILINE) + for match in matches: + to_execute = match.group(1) + if not to_execute: + continue + result = await _run_code(to_execute, module, functions) + answer_text = answer_text.replace(match.group(0), result) + + return answer_text + + +async def _run_code(to_execute: str, module, functions) -> str: + result = None + for _ in range(3): + try: + if any(item + "(" in to_execute for item in functions): + result = eval(f"module.{to_execute}") + break + + else: + ldict = {} + exec(to_execute, globals(), ldict) + if "result" in ldict: + result = str(ldict["result"]) + break + + except NameError as e: + match = re.search(r"\'(\w+)\' is not defined", str(e)) + if match: + to_import = match.group(1) + to_execute = f"import {to_import}\n{to_execute}" + + except CloseConversation as e: + raise e + + except Exception as e: + result = ( + f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}" + ) + traceback.print_exc() + break + + if not result: + result = f"\n```python\n{to_execute}\n```" + + return result + + +def get_text_from_facts_and_thresholds( + facts_and_thresholds: List[Tuple[Fact, float]], memory: str +) -> List[str]: + return [item[0].text for item in facts_and_thresholds if item[0].text not in memory] + + +def add_dummy_utterances_to_continue_generation( + conversation: Conversation, answer_text: str +): + conversation.add_utterance( + Utterance( + answer_text, + "bot", + ) + ) + conversation.add_utterance( + Utterance( + "Continue", + "user", + ) + ) + + +def add_memories_to_facts(facts: str, memories: List[str]) -> str: + return facts + "\n" + "\n".join(memories) diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py index 0ee73def..37b042b4 100644 --- a/wafl/answerer/dialogue_answerer.py +++ b/wafl/answerer/dialogue_answerer.py @@ -1,12 +1,18 @@ -import re -import traceback - from importlib import import_module from inspect import getmembers, isfunction +from typing import List, Tuple +from wafl.answerer.answerer_implementation import ( + is_executable, + substitute_memory_in_answer_and_get_memories_if_present, + create_one_liner, + get_text_from_facts_and_thresholds, + add_dummy_utterances_to_continue_generation, + add_memories_to_facts, + execute_results_in_answer, +) from wafl.answerer.base_answerer import BaseAnswerer from wafl.answerer.rule_maker import RuleMaker from wafl.connectors.clients.llm_chitchat_answer_client import LLMChitChatAnswerClient -from wafl.exceptions import CloseConversation from wafl.extractors.dataclasses import Query, Answer from wafl.interface.conversation import Conversation, Utterance from wafl.simple_text_processing.questions import is_question @@ -14,6 +20,7 @@ class DialogueAnswerer(BaseAnswerer): def __init__(self, config, knowledge, interface, code_path, logger): + self._threshold_for_facts = 0.85 self._delete_current_rule = "[delete_rule]" self._client = LLMChitChatAnswerClient(config) self._knowledge = knowledge @@ -21,7 +28,7 @@ def __init__(self, config, knowledge, interface, code_path, logger): self._interface = interface self._max_num_past_utterances = 5 self._max_num_past_utterances_for_facts = 5 - self._max_num_past_utterances_for_rules = 0 + self._max_num_past_utterances_for_rules = 2 self._prior_facts_with_timestamp = [] self._init_python_module(code_path.replace(".py", "")) self._prior_rules = [] @@ -34,30 +41,20 @@ def __init__(self, config, knowledge, interface, code_path, logger): delete_current_rule=self._delete_current_rule, ) - async def answer(self, query_text): + async def answer(self, query_text: str) -> Answer: if self._logger: self._logger.write(f"Dialogue Answerer: the query is {query_text}") - query = Query.create_from_text("The user says: " + query_text) - rules_text = await self._get_relevant_rules(query) conversation = self._interface.get_utterances_list_with_timestamp().get_last_n( self._max_num_past_utterances ) + rules_text = await self._get_relevant_rules(conversation) if not conversation: - conversation = Conversation( - [ - Utterance( - query_text, - "user", - ) - ] - ) - + conversation = create_one_liner(query_text) last_bot_utterances = conversation.get_last_speaker_utterances("bot", 3) last_user_utterance = conversation.get_last_speaker_utterances("user", 1) if not last_user_utterance: last_user_utterance = query_text - conversational_timestamp = len(conversation) facts = await self._get_relevant_facts( query, @@ -73,21 +70,13 @@ async def answer(self, query_text): dialogue=conversation, ) await self._interface.add_fact(f"The bot predicts: {original_answer_text}") - ( - answer_text, - memories, - ) = await self._substitute_memory_in_answer_and_get_memories_if_present( - await self._substitute_results_in_answer(original_answer_text) + answer_text, memories = await self._apply_substitutions( + original_answer_text ) - if answer_text in last_bot_utterances: - conversation = Conversation( - [ - Utterance( - last_user_utterance[-1], - "user", - ) - ] - ) + if answer_text in last_bot_utterances and not is_executable( + original_answer_text + ): + conversation = create_one_liner(last_user_utterance[-1]) continue if self._delete_current_rule in answer_text: @@ -99,20 +88,8 @@ async def answer(self, query_text): if not memories: break - facts += "\n" + "\n".join(memories) - - conversation.add_utterance( - Utterance( - answer_text, - "bot", - ) - ) - conversation.add_utterance( - Utterance( - "Continue", - "user", - ) - ) + facts = add_memories_to_facts(facts, memories) + add_dummy_utterances_to_continue_generation(conversation, answer_text) if self._logger: self._logger.write( @@ -122,24 +99,17 @@ async def answer(self, query_text): return Answer.create_from_text(final_answer_text) async def _get_relevant_facts( - self, query, has_prior_rules, conversational_timestamp - ): + self, query: Query, has_prior_rules: bool, conversational_timestamp: int + ) -> str: memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp]) - self._prior_facts_with_timestamp = [ - item - for item in self._prior_facts_with_timestamp - if item[1] - > conversational_timestamp - self._max_num_past_utterances_for_facts - ] + self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp( + conversational_timestamp + ) facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold( - query, is_from_user=True, threshold=0.85 + query, is_from_user=True, threshold=self._threshold_for_facts ) if facts_and_thresholds: - facts = [ - item[0].text - for item in facts_and_thresholds - if item[0].text not in memory - ] + facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory) self._prior_facts_with_timestamp.extend( (item, conversational_timestamp) for item in facts ) @@ -155,104 +125,33 @@ async def _get_relevant_facts( return memory - async def _get_relevant_rules(self, query): - rules = await self._rule_creator.create_from_query(query) + async def _get_relevant_rules(self, conversation: Conversation) -> List[str]: + rules = await self._rule_creator.create_from_query(conversation) for rule in rules: if rule not in self._prior_rules: - self._prior_rules.append(rule) + self._prior_rules.insert(0, rule) + self._prior_rules = self._prior_rules[:self._max_num_past_utterances_for_rules] return self._prior_rules def _init_python_module(self, module_name): self._module = import_module(module_name) self._functions = [item[0] for item in getmembers(self._module, isfunction)] - async def _substitute_results_in_answer(self, answer_text): - matches = re.finditer( - r"(.*?)|(.*?\))$", - answer_text, - re.DOTALL | re.MULTILINE, - ) - for match in matches: - to_execute = match.group(1) - if not to_execute: - continue - result = await self._run_code(to_execute) - answer_text = answer_text.replace(match.group(0), result) - - matches = re.finditer( - r"(.*?\))$", answer_text, re.DOTALL | re.MULTILINE - ) - for match in matches: - to_execute = match.group(1) - if not to_execute: - continue - result = await self._run_code(to_execute) - answer_text = answer_text.replace(match.group(0), result) - - return answer_text - - async def _substitute_memory_in_answer_and_get_memories_if_present( - self, answer_text - ): - matches = re.finditer( - r"(.*?)|(.*?)$", - answer_text, - re.DOTALL | re.MULTILINE, - ) - memories = [] - for match in matches: - to_substitute = match.group(1) - if not to_substitute: - continue - answer_text = answer_text.replace(match.group(0), "[Output in memory]") - memories.append(to_substitute) - - answer_text = answer_text.replace("
", "\n") - matches = re.finditer( - r"(.*?)$", answer_text, re.DOTALL | re.MULTILINE + async def _apply_substitutions(self, original_answer_text): + return await substitute_memory_in_answer_and_get_memories_if_present( + await execute_results_in_answer( + original_answer_text, + self._module, + self._functions, + ) ) - memories = [] - for match in matches: - to_substitute = match.group(1) - if not to_substitute: - continue - answer_text = answer_text.replace(match.group(0), "[Output in memory]") - memories.append(to_substitute) - - return answer_text, memories - - async def _run_code(self, to_execute): - result = None - for _ in range(3): - try: - if any(item + "(" in to_execute for item in self._functions): - result = eval(f"self._module.{to_execute}") - break - - else: - ldict = {} - exec(to_execute, globals(), ldict) - if "result" in ldict: - result = str(ldict["result"]) - break - - except NameError as e: - match = re.search(r"\'(\w+)\' is not defined", str(e)) - if match: - to_import = match.group(1) - to_execute = f"import {to_import}\n{to_execute}" - except CloseConversation as e: - raise e - - except Exception as e: - result = ( - f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}" - ) - traceback.print_exc() - break - - if not result: - result = f"\n```python\n{to_execute}\n```" - - return result + def _get_prior_facts_with_timestamp( + self, conversational_timestamp: int + ) -> List[Tuple[str, int]]: + return [ + item + for item in self._prior_facts_with_timestamp + if item[1] + > conversational_timestamp - self._max_num_past_utterances_for_facts + ] diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py index d181b6a5..e6fd87df 100644 --- a/wafl/answerer/rule_maker.py +++ b/wafl/answerer/rule_maker.py @@ -1,5 +1,8 @@ from typing import List +from wafl.extractors.dataclasses import Query +from wafl.rules import Rule + class RuleMaker: def __init__( @@ -21,9 +24,8 @@ def __init__( else: self._max_indentation = config.get_value("max_recursion") - async def create_from_query(self, query: "Query") -> List[str]: - rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.92) - rules = rules[: self._max_num_rules] + async def create_from_query(self, conversation: "Conversation") -> List[str]: + rules = await self._get_rules(conversation) rules_texts = [] for rule in rules: rules_text = rule.get_string_using_template( @@ -35,3 +37,18 @@ async def create_from_query(self, query: "Query") -> List[str]: await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}") return rules_texts + + async def _get_rules(self, conversation: "Conversation") -> List[Rule]: + utterances = conversation.get_last_speaker_utterances("user", 1) + rules = await self._knowledge.ask_for_rule_backward( + Query.create_from_list(utterances), threshold=0.9 + ) + + utterances = conversation.get_last_speaker_utterances("user", 2) + rules.extend( + await self._knowledge.ask_for_rule_backward( + Query.create_from_list(utterances), threshold=0.8 + ) + ) + + return rules[: self._max_num_rules] diff --git a/wafl/events/conversation_events.py b/wafl/events/conversation_events.py index d18e9a9b..0ef99469 100644 --- a/wafl/events/conversation_events.py +++ b/wafl/events/conversation_events.py @@ -101,12 +101,13 @@ async def process_next(self, activation_word: str = "") -> bool: return False + async def reload_knowledge(self): + self._knowledge = load_knowledge(self._config, self._logger) + await self._knowledge.initialize_retrievers() + def is_computing(self): return self._is_computing - def reload_knowledge(self): - self._knowledge = load_knowledge(self._config, self._logger) - def reset_discourse_memory(self): self._answerer = create_answerer( self._config, self._knowledge, self._interface, self._logger diff --git a/wafl/extractors/dataclasses.py b/wafl/extractors/dataclasses.py index 17830601..24b967fa 100644 --- a/wafl/extractors/dataclasses.py +++ b/wafl/extractors/dataclasses.py @@ -16,6 +16,13 @@ def is_neutral(self) -> bool: def create_from_text(text): return Query(text=text, is_question=is_question(text)) + @staticmethod + def create_from_list(utterances): + text = "\n".join(utterances) + return Query( + text=text, is_question=any(is_question(item) for item in utterances) + ) + @dataclass class Answer: diff --git a/wafl/frontend/index.html b/wafl/frontend/index.html index 8137ca99..f2ba40e9 100644 --- a/wafl/frontend/index.html +++ b/wafl/frontend/index.html @@ -73,9 +73,7 @@ -
- +
     
WAFL frontend - -
- Creating a new instance. This may take a few seconds... + hx-trigger="load" +
+ + Creating a new instance. This may take a few seconds... +
\ No newline at end of file diff --git a/wafl/frontend/wafl.css b/wafl/frontend/wafl.css index a3a69e78..7c6c4959 100644 --- a/wafl/frontend/wafl.css +++ b/wafl/frontend/wafl.css @@ -1,5 +1,5 @@ body { - background: #f0f8ff; + background: #fffffa; color: black; font-family: monospace; font-size: 25px; @@ -30,7 +30,6 @@ pre { } pre .dialogue { - width: 50vw; height: 90vh; display: flex; padding: 10px; @@ -57,11 +56,6 @@ pre .logs { border-top: 1px solid #4e4a4a; } -#banner { - width: 100%; - padding: 10px 10px; -} - #default-sidebar a{ cursor: pointer; text-align: left; diff --git a/wafl/interface/base_interface.py b/wafl/interface/base_interface.py index dc73f964..1c785651 100644 --- a/wafl/interface/base_interface.py +++ b/wafl/interface/base_interface.py @@ -81,7 +81,10 @@ def _decorate_reply(self, text: str) -> str: return self._decorator.extract(text, self._utterances) def _insert_utterance(self, speaker, text: str): - text = re.sub(r"\[.*?\]", "", text) + clean_text = re.sub(r"\[.*?]", "", text) + if not clean_text.strip(): + clean_text = text + self._utterances.add_utterance( - Utterance(text=text, speaker=speaker, timestamp=time.time()) + Utterance(text=clean_text, speaker=speaker, timestamp=time.time()) ) diff --git a/wafl/interface/conversation.py b/wafl/interface/conversation.py index e3a0112e..68687eb1 100644 --- a/wafl/interface/conversation.py +++ b/wafl/interface/conversation.py @@ -82,8 +82,24 @@ def insert_utterance(self, new_utterance: Utterance): self.utterances = new_utterances - def get_last_n(self, n: int) -> "Conversation": - return Conversation(self.utterances[-n:]) if self.utterances else Conversation() + def get_last_n(self, n: int, stop_at_string: str = None) -> "Conversation": + if not self.utterances: + return Conversation() + + utterances_to_return: List[Utterance] = [] + stop_at_next = False + for utterance in reversed(self.utterances): + utterances_to_return.append(utterance) + if stop_at_next: + break + if stop_at_string and stop_at_string in utterance.text: + stop_at_next = True + continue + + if len(utterances_to_return) == n: + break + utterances_to_return.reverse() + return Conversation(utterances_to_return) def get_last_speaker_utterances(self, speaker: str, n: int) -> List[str]: if not self.utterances: diff --git a/wafl/knowledge/single_file_knowledge.py b/wafl/knowledge/single_file_knowledge.py index e2cc0621..747c5a15 100644 --- a/wafl/knowledge/single_file_knowledge.py +++ b/wafl/knowledge/single_file_knowledge.py @@ -59,7 +59,7 @@ def __init__(self, config, rules_text=None, logger=None): loop = None if not loop or not loop.is_running(): - asyncio.run(self._initialize_retrievers()) + asyncio.run(self.initialize_retrievers()) async def add(self, text): fact_index = f"F{len(self._facts_dict)}" @@ -158,7 +158,7 @@ def get_facts_and_rule_as_text(self): return text - async def _initialize_retrievers(self): + async def initialize_retrievers(self): for index, fact in self._facts_dict.items(): if text_is_exact_string(fact.text): continue diff --git a/wafl/scheduler/web_handler.py b/wafl/scheduler/web_handler.py index f28b6d82..064a6d83 100644 --- a/wafl/scheduler/web_handler.py +++ b/wafl/scheduler/web_handler.py @@ -45,7 +45,7 @@ async def reset_conversation(self): self._interface.reset_history() self._interface.deactivate() self._interface.activate() - self._conversation_events.reload_knowledge() + await self._conversation_events.reload_knowledge() self._conversation_events.reset_discourse_memory() await self._interface.output("Hello. How may I help you?") conversation = await self._messages_creator.get_messages_window() diff --git a/wafl/testcases.py b/wafl/testcases.py index 017ad0f8..bcc49f4d 100644 --- a/wafl/testcases.py +++ b/wafl/testcases.py @@ -26,7 +26,7 @@ async def test_single_case(self, name): is_negated = self._testcase_data[name]["negated"] interface = DummyInterface(user_lines) conversation_events = ConversationEvents(self._config, interface=interface) - await conversation_events._knowledge._initialize_retrievers() + await conversation_events._knowledge.initialize_retrievers() print(self.BLUE_COLOR_START + f"\nRunning test '{name}'." + self.COLOR_END) continue_conversations = True