diff --git a/tests/config.json b/tests/config.json
index 121d77f0..dfbdcbe8 100644
--- a/tests/config.json
+++ b/tests/config.json
@@ -27,5 +27,9 @@
   "text_embedding_model": {
     "model_host": "localhost",
     "model_port": 8080
+  },
+  "entailer_model": {
+    "model_host": "localhost",
+    "model_port": 8080
   }
 }
diff --git a/tests/test_entailer.py b/tests/test_entailer.py
new file mode 100644
index 00000000..c1b5a451
--- /dev/null
+++ b/tests/test_entailer.py
@@ -0,0 +1,33 @@
+import asyncio
+import os
+
+from unittest import TestCase
+from wafl.config import Configuration
+from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
+from wafl.connectors.clients.entailer_client import EntailerClient
+
+_path = os.path.dirname(__file__)
+
+
+class TestConnection(TestCase):
+    def test__entailer_connector(self):
+        config = Configuration.load_local_config()
+        connector = RemoteEntailerConnector(config.get_value("entailer_model"))
+        prediction = asyncio.run(
+            connector.predict(
+                "The first contact is a romance novel set in the middle ages.",
+                "The first contact is a science fiction novel about the first contact between humans and aliens.",
+            )
+        )
+        assert prediction["score"] < 0.5
+
+    def test__entailment_client(self):
+        config = Configuration.load_local_config()
+        client = EntailerClient(config)
+        prediction = asyncio.run(
+            client.get_entailment_score(
+                "The first contact is a romance novel set in the middle ages.",
+                "The first contact is a science fiction novel about the first contact between humans and aliens.",
+            )
+        )
+        assert prediction < 0.5
diff --git a/tests/test_indexing.py b/tests/test_indexing.py
index 1e64b4a4..0ea51557 100644
--- a/tests/test_indexing.py
+++ b/tests/test_indexing.py
@@ -5,7 +5,7 @@
 from unittest import TestCase
 
 from wafl.config import Configuration
-from wafl.dataclasses.dataclasses import Query
+from wafl.data_objects.dataclasses import Query
 from wafl.knowledge.indexing_implementation import add_to_index, load_knowledge
 
 _path = os.path.dirname(__file__)
diff --git a/tests/test_voice.py b/tests/test_voice.py
index 6ecd785b..7b90ef83 100644
--- a/tests/test_voice.py
+++ b/tests/test_voice.py
@@ -15,7 +15,7 @@
 rules:
   - the user's name is Jane:
-    - write "I hear you"
+    - reply with "I hear you" and nothing else
 """.strip()
 
 _path = os.path.dirname(__file__)
@@ -23,7 +23,7 @@
 class TestVoice(TestCase):
     def test__activation(self):
-        interface = DummyInterface(to_utter=["computer", "my name is Jane"])
+        interface = DummyInterface(to_utter=["computer my name is Jane"])
         config = Configuration.load_local_config()
         config.set_value("rules", _wafl_example)
         conversation_events = ConversationEvents(config=config, interface=interface)
diff --git a/todo.txt b/todo.txt
index 31c05da1..17ec607e 100644
--- a/todo.txt
+++ b/todo.txt
@@ -1,5 +1,23 @@
-* why do I need to re-initialise the retrievers after unpickling the knowledge?
+* apply entailer to rule retrieval:
+  if more than one rule is retrieved, then the one
+  that is entailed by the query should be chosen
+
+* the answer from the indexed files should be directed by a rule.
+  - facts and rules should live at the highest level of the retrieval
+
+
+/* Add tqdm to indexing.
+/* Make it index when wafl starts, not at the first use/login
+
+/* The prior items with timestamps might not be necessary.
+/  - Just implement a queue with a fixed size
+
+* add entailer to wafl_llm
+
+
+/* why do I need to re-initialise the retrievers after unpickling the knowledge?
   - maybe you should save the retrievers in the knowledge object separately?
+  - It was gensim that was not serializable. Took it out
 
 /* knowledge cache does not cache the rules or facts
 
diff --git a/wafl/answerer/answerer_implementation.py b/wafl/answerer/answerer_implementation.py
index 3c8b36e8..9a2c25df 100644
--- a/wafl/answerer/answerer_implementation.py
+++ b/wafl/answerer/answerer_implementation.py
@@ -3,8 +3,9 @@
 from typing import List, Tuple
 
+from wafl.answerer.entailer import Entailer
 from wafl.exceptions import CloseConversation
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
 from wafl.interface.conversation import Conversation, Utterance
 
 
@@ -113,22 +114,30 @@ async def _run_code(to_execute: str, module, functions) -> str:
     return result
 
 
-def get_text_from_facts_and_thresholds(
+def create_memory_from_fact_list(facts: List[Fact], max_num_facts: int) -> str:
+    text_fact_list = [
+        "\n\n- " + " " + fact.text + " "
+        for fact in facts
+        if fact.source == Sources.FROM_TEXT
+    ][:max_num_facts]
+    rule_fact_list = [
+        "\n\n- " + " " + fact.text + " "
+        for fact in facts
+        if fact.source in [None, Sources.FROM_RULES]
+    ]
+    return "".join(text_fact_list + rule_fact_list)
+
+
+def get_facts_with_metadata_from_facts_and_thresholds(
     facts_and_thresholds: List[Tuple[Fact, float]], memory: str
-) -> List[str]:
-    text_list = []
+) -> List[Fact]:
+    fact_list = []
     for item in facts_and_thresholds:
         if item[0].text not in memory:
-            text = item[0].text
-            if item[0].metadata:
-                text = (
-                    f"Metadata for the following text: {str(item[0].metadata)}"
-                    + "\n"
-                    + text
-                )
-            text_list.append(text)
+            # The metadata now travels with the Fact object itself.
+            fact_list.append(item[0].copy())
 
-    return text_list
+    return fact_list
 
 
 def add_dummy_utterances_to_continue_generation(
@@ -150,3 +161,13 @@
 
 def add_memories_to_facts(facts: str, memories: List[str]) -> str:
     return facts + "\n" + "\n".join(memories)
+
+
+async def select_best_rules_using_entailer(
+    conversation: Conversation, rules_as_strings: List[str], entailer: Entailer, num_rules: int
+) -> List[str]:
+    # Rank the retrieved rules by how strongly the user's last utterance entails them.
+    query_text = conversation.get_last_speaker_utterance("user")
+    scores = [await entailer.get_score(query_text, rule) for rule in rules_as_strings]
+    ranked = sorted(zip(rules_as_strings, scores), key=lambda pair: pair[1], reverse=True)
+    return [rule for rule, _ in ranked[:num_rules]]
diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py
index f12be579..5377233a 100644
--- a/wafl/answerer/dialogue_answerer.py
+++ b/wafl/answerer/dialogue_answerer.py
@@ -1,18 +1,21 @@
 from importlib import import_module
 from inspect import getmembers, isfunction
-from typing import List, Tuple
+from typing import List
+
+from wafl.answerer.entailer import Entailer
 from wafl.answerer.answerer_implementation import (
     substitute_memory_in_answer_and_get_memories_if_present,
     create_one_liner,
-    get_text_from_facts_and_thresholds,
+    get_facts_with_metadata_from_facts_and_thresholds,
     add_dummy_utterances_to_continue_generation,
     add_memories_to_facts,
     execute_results_in_answer,
+    create_memory_from_fact_list, select_best_rules_using_entailer,
 )
 from wafl.answerer.base_answerer import BaseAnswerer
 from wafl.answerer.rule_maker import RuleMaker
 from wafl.connectors.clients.llm_chat_client import LLMChatClient
-from wafl.dataclasses.dataclasses import Query, Answer
+from wafl.data_objects.dataclasses import Query, Answer
 from wafl.interface.conversation import Conversation
 from wafl.simple_text_processing.questions import is_question
 
 
@@ -21,13 +24,14 @@ class DialogueAnswerer(BaseAnswerer):
     def __init__(self, config, knowledge, interface, code_path, logger):
         self._threshold_for_facts = 0.85
         self._client = LLMChatClient(config)
+        self._entailer = Entailer(config)
         self._knowledge = knowledge
         self._logger = logger
         self._interface = interface
         self._max_num_past_utterances = 5
-        self._max_num_past_utterances_for_facts = 5
-        self._max_num_past_utterances_for_rules = 2
-        self._prior_facts_with_timestamp = []
+        self._max_num_facts = 5
+        self._max_num_rules = 2
+        self._prior_facts = []
         self._init_python_module(code_path.replace(".py", ""))
         self._prior_rules = []
         self._max_predictions = 3
@@ -48,17 +52,15 @@ async def answer(self, query_text: str) -> Answer:
         rules_text = await self._get_relevant_rules(conversation)
         if not conversation:
             conversation = create_one_liner(query_text)
-        conversational_timestamp = len(conversation)
-        facts = await self._get_relevant_facts(
+        memory = await self._get_relevant_facts(
             query,
             has_prior_rules=bool(rules_text),
-            conversational_timestamp=conversational_timestamp,
         )
 
         final_answer_text = ""
         for _ in range(self._max_predictions):
             original_answer_text = await self._client.get_answer(
-                text=facts,
+                text=memory,
                 rules_text=rules_text,
                 dialogue=conversation,
             )
@@ -82,22 +84,19 @@ async def answer(self, query_text: str) -> Answer:
 
         return Answer.create_from_text(final_answer_text)
 
-    async def _get_relevant_facts(
-        self, query: Query, has_prior_rules: bool, conversational_timestamp: int
-    ) -> str:
-        memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
-        self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp(
-            conversational_timestamp
-        )
+    async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
+        memory = create_memory_from_fact_list(self._prior_facts, self._max_num_facts)
         facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
             query, is_from_user=True, threshold=self._threshold_for_facts
         )
         if facts_and_thresholds:
-            facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory)
-            self._prior_facts_with_timestamp.extend(
-                (item, conversational_timestamp) for item in facts
+            facts = get_facts_with_metadata_from_facts_and_thresholds(
+                facts_and_thresholds, memory
+            )
+            self._prior_facts.extend(facts)
+            memory = create_memory_from_fact_list(
+                self._prior_facts, self._max_num_facts
             )
-            memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
 
             await self._interface.add_fact(f"The bot remembers the facts:\n{memory}")
         else:
@@ -110,11 +109,14 @@
         return memory
 
     async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
-        rules = await self._rule_creator.create_from_query(conversation)
-        for rule in rules:
+        rules_as_strings = await self._rule_creator.create_from_query(conversation)
+        rules_as_strings = await select_best_rules_using_entailer(
+            conversation, rules_as_strings, self._entailer, num_rules=1
+        )
+        for rule in rules_as_strings:
             if rule not in self._prior_rules:
                 self._prior_rules.insert(0, rule)
-        self._prior_rules = self._prior_rules[: self._max_num_past_utterances_for_rules]
+        self._prior_rules = self._prior_rules[: self._max_num_rules]
         return self._prior_rules
 
     def _init_python_module(self, module_name):
@@ -129,13 +129,3 @@ async def _apply_substitutions(self, original_answer_text):
                 self._functions,
             )
         )
-
-    def _get_prior_facts_with_timestamp(
-        self, conversational_timestamp: int
-    ) -> List[Tuple[str, int]]:
-        return [
-            item
-            for item in self._prior_facts_with_timestamp
-            if item[1]
-            > conversational_timestamp - self._max_num_past_utterances_for_facts
-        ]
diff --git a/wafl/answerer/entailer.py b/wafl/answerer/entailer.py
index 54e4e3e2..3f3c2ab9 100644
--- a/wafl/answerer/entailer.py
+++ b/wafl/answerer/entailer.py
@@ -1,41 +1,14 @@
-import os
-import textwrap
-
-from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
-from wafl.interface.conversation import Utterance, Conversation
-
-_path = os.path.dirname(__file__)
+from wafl.connectors.clients.entailer_client import EntailerClient
 
 
 class Entailer:
     def __init__(self, config):
-        self._connector = LLMConnectorFactory.get_connector(config)
+        self.entailer_client = EntailerClient(config)
         self._config = config
 
-    async def left_entails_right(self, lhs: str, rhs: str, dialogue) -> str:
-        prompt = await self._get_answer_prompt(lhs, rhs, dialogue)
-        result = await self._connector.generate(prompt)
-        result = self._clean_result(result)
-        return result == "yes"
-
-    async def _get_answer_prompt(self, lhs, rhs, dialogue):
-        return PromptTemplate(
-            system_prompt="",
-            conversation=self._get_dialogue_prompt(lhs, rhs, dialogue),
-        )
-
-    def _clean_result(self, result):
-        result = result.replace("", "")
-        result = result.split("\n")[0]
-        result = result.strip()
-        return result.lower()
+    async def left_entails_right(self, lhs: str, rhs: str) -> bool:
+        prediction = await self.entailer_client.get_entailment_score(lhs, rhs)
+        return prediction > 0.5
 
-    def _get_dialogue_prompt(self, dialogue, lhs, rhs):
-        text = f"""
-Your task is to determine whether two sentences are similar.
-1) {lhs.lower()}
-2) {rhs.lower()}
-Please answer "yes" if the two sentences are similar or "no" if not:
-        """.strip()
-        return Conversation([Utterance(speaker="user", text=text)])
+    async def get_score(self, lhs: str, rhs: str) -> float:
+        return await self.entailer_client.get_entailment_score(lhs, rhs)
diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py
index 115dfcfc..7454fe73 100644
--- a/wafl/answerer/rule_maker.py
+++ b/wafl/answerer/rule_maker.py
@@ -1,7 +1,7 @@
 from typing import List
 
-from wafl.dataclasses.dataclasses import Query
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.dataclasses import Query
+from wafl.data_objects.rules import Rule
 
 
 class RuleMaker:
diff --git a/wafl/changelog.txt b/wafl/changelog.txt
new file mode 100644
index 00000000..bfd13914
--- /dev/null
+++ b/wafl/changelog.txt
@@ -0,0 +1,7 @@
+- version 0.1.3
+* added multi-threaded support for indexing multiple files
+* TODO: add support for multiple knowledge bases.
+  It needs to index the rules and the files separately!
+* the web interface should show where the facts come from
+* add support for wafl studio where you can concatenate actions (and create corresponding yaml files)
+* use <> tags for concatenation
\ No newline at end of file
diff --git a/wafl/command_line.py b/wafl/command_line.py
index 4ebb6643..114f86be 100644
--- a/wafl/command_line.py
+++ b/wafl/command_line.py
@@ -9,6 +9,7 @@
     run_testcases,
     print_incipit,
     download_models,
+    load_indices,
 )
 from wafl.runners.run_from_actions import run_action
 
@@ -52,26 +53,31 @@ def process_cli():
     elif command == "run":
         from wafl.runners.run_web_and_audio_interface import run_app
 
+        load_indices()
         run_app()
         remove_preprocessed("/")
 
     elif command == "run-cli":
+        load_indices()
         run_from_command_line()
         remove_preprocessed("/")
 
     elif command == "run-audio":
         from wafl.runners.run_from_audio import run_from_audio
 
+        load_indices()
         run_from_audio()
         remove_preprocessed("/")
 
     elif command == "run-server":
         from wafl.runners.run_web_interface import run_server_only_app
 
+        load_indices()
         run_server_only_app()
         remove_preprocessed("/")
 
     elif command == "run-tests":
+        load_indices()
         run_testcases()
         remove_preprocessed("/")
diff --git a/wafl/connectors/clients/clients_implementation.py b/wafl/connectors/clients/clients_implementation.py
deleted file mode 100644
index ec463c56..00000000
--- a/wafl/connectors/clients/clients_implementation.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import csv
-import os
-import joblib
-
-from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
-
-_path = os.path.dirname(__file__)
-
-
-async def load_knowledge_from_file(filename, config):
-    items_list = []
-    with open(os.path.join(_path, "../../data/", filename + ".csv")) as file:
-        csvreader = csv.reader(file)
-        for row in csvreader:
-            items_list.append(row[0].strip())
-
-    knowledge = await SingleFileKnowledge.create_from_list(items_list, config)
-    joblib.dump(knowledge, os.path.join(_path, f"../../data/{filename}.knowledge"))
-    return knowledge
diff --git a/wafl/connectors/clients/entailer_client.py b/wafl/connectors/clients/entailer_client.py
new file mode 100644
index 00000000..a5c189be
--- /dev/null
+++ b/wafl/connectors/clients/entailer_client.py
@@ -0,0 +1,21 @@
+import os
+
+from wafl.connectors.factories.entailer_connector_factory import (
+    EntailerConnectorFactory,
+)
+
+_path = os.path.dirname(__file__)
+
+
+class EntailerClient:
+    def __init__(self, config):
+        self._connector = EntailerConnectorFactory.get_connector(
+            "entailer_model", config
+        )
+        self._config = config
+
+    async def get_entailment_score(self, lhs: str, rhs: str) -> float:
+        prediction = await self._connector.predict(lhs, rhs)
+        if "score" not in prediction:
+            raise ValueError("The entailment prediction does not contain a score.")
+        return prediction["score"]
diff --git a/wafl/connectors/clients/information_client.py b/wafl/connectors/clients/information_client.py
index 772afb00..533fc902 100644
--- a/wafl/connectors/clients/information_client.py
+++ b/wafl/connectors/clients/information_client.py
@@ -1,9 +1,6 @@
 import os
-import textwrap
-from typing import List
 
 from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
 
 _path = os.path.dirname(__file__)
diff --git a/wafl/connectors/factories/entailer_connector_factory.py b/wafl/connectors/factories/entailer_connector_factory.py
new file mode 100644
index 00000000..017fbbdf
--- /dev/null
+++ b/wafl/connectors/factories/entailer_connector_factory.py
@@ -0,0 +1,7 @@
+from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
+
+
+class EntailerConnectorFactory:
+    @staticmethod
+    def get_connector(model_name, config):
+        return RemoteEntailerConnector(config.get_value(model_name))
diff --git a/wafl/connectors/remote/remote_entailer_connector.py b/wafl/connectors/remote/remote_entailer_connector.py
new file mode 100644
index 00000000..f1230c86
--- /dev/null
+++ b/wafl/connectors/remote/remote_entailer_connector.py
@@ -0,0 +1,60 @@
+import aiohttp
+import asyncio
+import json
+
+from typing import Dict
+
+
+class RemoteEntailerConnector:
+    _max_tries = 3
+
+    def __init__(self, config):
+        host = config["model_host"]
+        port = config["model_port"]
+
+        self._server_url = f"https://{host}:{port}/predictions/entailer"
+        try:
+            loop = asyncio.get_running_loop()
+
+        except RuntimeError:
+            loop = None
+
+        if (not loop or (loop and not loop.is_running())) and not asyncio.run(
+            self.check_connection()
+        ):
+            raise RuntimeError("Cannot connect to a running entailer model.")
+
+    async def predict(self, lhs: str, rhs: str) -> Dict[str, float]:
+        payload = {"lhs": lhs, "rhs": rhs}
+        for _ in range(self._max_tries):
+            try:
+                async with aiohttp.ClientSession(
+                    connector=aiohttp.TCPConnector(ssl=False)
+                ) as session:
+                    async with session.post(self._server_url, json=payload) as response:
+                        data = await response.text()
+                        prediction = json.loads(data)
+                        return {"score": float(prediction["score"])}
+
+            except (aiohttp.ClientError, KeyError, ValueError):
+                continue
+
+        return {"score": -1.0}
+
+    async def check_connection(self):
+        payload = {"lhs": "test", "rhs": "test"}
+        try:
+            async with aiohttp.ClientSession(
+                conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
+            ) as session:
+                async with session.post(self._server_url, json=payload) as response:
+                    await response.text()
+                    return True
+
+        except (aiohttp.client.InvalidURL, aiohttp.ClientError):
+            print()
+            print("Is the entailer server running?")
+            print("Please run 'bash start-llm.sh' (see docs for explanation).")
+            print()
+
+        return False
diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py
index a2b3d7f7..ed0af1c8 100644
--- a/wafl/connectors/remote/remote_llm_connector.py
+++ b/wafl/connectors/remote/remote_llm_connector.py
@@ -10,7 +10,7 @@ class RemoteLLMConnector(BaseLLMConnector):
     _max_tries = 3
-    _max_reply_length = 1024
+    _max_reply_length = 2048
     _num_prediction_tokens = 200
     _cache = {}
diff --git a/wafl/dataclasses/__init__.py b/wafl/data_objects/__init__.py
similarity index 100%
rename from wafl/dataclasses/__init__.py
rename to wafl/data_objects/__init__.py
diff --git a/wafl/dataclasses/dataclasses.py b/wafl/data_objects/dataclasses.py
similarity index 100%
rename from wafl/dataclasses/dataclasses.py
rename to wafl/data_objects/dataclasses.py
diff --git a/wafl/data_objects/facts.py b/wafl/data_objects/facts.py
new file mode 100644
index 00000000..88926ab0
--- /dev/null
+++ b/wafl/data_objects/facts.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Union
+
+
+class Sources(Enum):
+    FROM_TEXT = 1
+    FROM_RULES = 2
+
+
+@dataclass
+class Fact:
+    text: Union[str, dict]
+    is_question: bool = False
+    variable: str = None
+    is_interruption: bool = False
+    destination: str = None
+    metadata: Union[str, dict] = None
+    source: Sources = Sources.FROM_RULES
+
+    def toJSON(self):
+        return str(self)
+
+    def copy(self):
+        return Fact(
+            self.text,
+            self.is_question,
+            self.variable,
+            self.is_interruption,
+            self.destination,
+            self.metadata,
+            self.source,
+        )
diff --git a/wafl/dataclasses/rules.py b/wafl/data_objects/rules.py
similarity index 100%
rename from wafl/dataclasses/rules.py
rename to wafl/data_objects/rules.py
diff --git a/wafl/dataclasses/facts.py b/wafl/dataclasses/facts.py
deleted file mode 100644
index 0445adff..00000000
--- a/wafl/dataclasses/facts.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from dataclasses import dataclass
-from typing import Union
-
-
-@dataclass
-class Fact:
-    text: Union[str, dict]
-    is_question: bool = False
-    variable: str = None
-    is_interruption: bool = False
-    source: str = None
-    destination: str = None
-    metadata: Union[str, dict] = None
-
-    def toJSON(self):
-        return str(self)
diff --git a/wafl/events/conversation_events.py b/wafl/events/conversation_events.py
index d83d52ce..e9e2d20b 100644
--- a/wafl/events/conversation_events.py
+++ b/wafl/events/conversation_events.py
@@ -19,6 +19,7 @@ def __init__(
         config: "Configuration",
         interface: "BaseInterface",
         logger=None,
+        knowledge=None,
     ):
         self._config = config
         try:
@@ -29,6 +30,8 @@ def __init__(
 
         if not loop or not loop.is_running():
            self._knowledge = asyncio.run(load_knowledge(config, logger))
+        else:
+            self._knowledge = knowledge
 
         self._answerer = create_answerer(config, self._knowledge, interface, logger)
         self._answerer._client._connector._cache = {}
diff --git a/wafl/inference/utils.py b/wafl/inference/utils.py
index 2f25bda1..d270814b 100644
--- a/wafl/inference/utils.py
+++ b/wafl/inference/utils.py
@@ -2,7 +2,7 @@
 from typing import List, Dict, Tuple, Any
 from fuzzywuzzy import process
 
-from wafl.dataclasses.dataclasses import Answer
+from wafl.data_objects.dataclasses import Answer
 from wafl.simple_text_processing.normalize import normalized
 from wafl.simple_text_processing.questions import is_question
diff --git a/wafl/interface/conversation.py b/wafl/interface/conversation.py
index 68687eb1..c65f1f5b 100644
--- a/wafl/interface/conversation.py
+++ b/wafl/interface/conversation.py
@@ -111,6 +111,15 @@ def get_last_speaker_utterances(self, speaker: str, n: int) -> List[str]:
             if utterance.speaker == speaker
         ][-n:]
 
+    def get_last_speaker_utterance(self, speaker: str) -> str:
+        if not self.utterances:
+            return ""
+
+        for utterance in reversed(self.utterances):
+            if utterance.speaker == speaker:
+                return utterance.text
+        return ""
+
     def get_first_timestamp(self) -> float:
         return self.utterances[0].timestamp if self.utterances else None
diff --git a/wafl/knowledge/indexing_implementation.py b/wafl/knowledge/indexing_implementation.py
index fc4bdf3b..d4c3a8f8 100644
--- a/wafl/knowledge/indexing_implementation.py
+++ b/wafl/knowledge/indexing_implementation.py
@@ -1,24 +1,54 @@
+import asyncio
 import os
-
 import joblib
 import yaml
+import threading
 
+from tqdm import tqdm
 from wafl.config import Configuration
 from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
 from wafl.readers.reader_factory import ReaderFactory
 
 
+async def add_file_to_knowledge(knowledge, filename):
+    reader = ReaderFactory.get_reader(filename)
+    for chunk in reader.get_chunks(filename):
+        await knowledge.add_fact(chunk)
+
+
 async def _add_indices_to_knowledge(knowledge, text):
     indices = yaml.safe_load(text)
     if "paths" not in indices or not indices["paths"]:
         return knowledge
 
     for path in indices["paths"]:
-        for root, _, files in os.walk(path):
-            for file in files:
-                reader = ReaderFactory.get_reader(file)
-                for chunk in reader.get_chunks(os.path.join(root, file)):
-                    await knowledge.add_fact(chunk)
+        print(f"Indexing path: {path}")
+        file_count = sum(len(files) for _, _, files in os.walk(path))
+        with tqdm(total=file_count) as pbar:
+            for root, _, files in os.walk(path):
+                threads = []
+                for file in files:
+                    threads.append(
+                        threading.Thread(
+                            target=asyncio.run,
+                            args=(
+                                add_file_to_knowledge(
+                                    knowledge, os.path.join(root, file)
+                                ),
+                            ),
+                        )
+                    )
+                if not threads:
+                    continue
+
+                num_threads = min(10, len(threads))
+                for i in range(0, len(threads), num_threads):
+                    # Index the files in batches of at most ten threads at a time.
+                    batch = threads[i : i + num_threads]
+                    for thread in batch:
+                        thread.start()
+                    for thread in batch:
+                        thread.join()
+                    pbar.update(len(batch))
 
     return knowledge
 
@@ -27,10 +57,12 @@ async def load_knowledge(config, logger=None):
     if ".yaml" in config.get_value("rules") and not any(
         item in config.get_value("rules") for item in [" ", "\n"]
     ):
+        rules_filename = config.get_value("rules")
         with open(config.get_value("rules")) as file:
             rules_txt = file.read()
 
     else:
+        rules_filename = None
         rules_txt = config.get_value("rules")
 
     index_filename = config.get_value("index")
@@ -41,10 +73,12 @@ async def load_knowledge(config, logger=None):
     cache_filename = config.get_value("cache_filename")
     if os.path.exists(cache_filename):
-        knowledge = joblib.load(cache_filename)
-        if knowledge.hash == hash(rules_txt) and os.path.getmtime(
-            cache_filename
-        ) > os.path.getmtime(index_filename):
+        if (
+            rules_filename
+            and os.path.getmtime(cache_filename) > os.path.getmtime(rules_filename)
+            and os.path.getmtime(cache_filename) > os.path.getmtime(index_filename)
+        ):
+            knowledge = joblib.load(cache_filename)
             return knowledge
 
     knowledge = SingleFileKnowledge(config, rules_txt, logger=logger)
diff --git a/wafl/knowledge/single_file_knowledge.py b/wafl/knowledge/single_file_knowledge.py
index 8c9c5a25..2fc15c39 100644
--- a/wafl/knowledge/single_file_knowledge.py
+++ b/wafl/knowledge/single_file_knowledge.py
@@ -4,9 +4,10 @@
 from typing import List
 
 import nltk
+from tqdm import tqdm
 
 from wafl.config import Configuration
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
 from wafl.knowledge.base_knowledge import BaseKnowledge
 from wafl.knowledge.utils import (
     text_is_exact_string,
@@ -169,7 +170,8 @@ def get_facts_and_rule_as_text(self):
         return text
 
     async def initialize_retrievers(self):
-        for index, fact in self._facts_dict.items():
+        print("Initializing fact retrievers")
+        for index, fact in tqdm(self._facts_dict.items()):
             if text_is_exact_string(fact.text):
                 continue
 
@@ -181,7 +183,8 @@ async def initialize_retrievers(self):
                 clean_text_for_retrieval(fact.text), index
             )
 
-        for index, rule in self._rules_dict.items():
+        print("Initializing rule retrievers")
+        for index, rule in tqdm(self._rules_dict.items()):
             if text_is_exact_string(rule.effect.text):
                 continue
 
@@ -189,10 +192,6 @@ async def initialize_retrievers(self):
                 clean_text_for_retrieval(rule.effect.text), index
             )
 
-        for index, rule in self._rules_dict.items():
-            if not text_is_exact_string(rule.effect.text):
-                continue
-
             await self._rules_string_retriever.add_text_and_index(
                 rule.effect.text, index
             )
diff --git a/wafl/parsing/line_rules_parser.py b/wafl/parsing/line_rules_parser.py
index 73371f3e..b6bfa3ee 100644
--- a/wafl/parsing/line_rules_parser.py
+++ b/wafl/parsing/line_rules_parser.py
@@ -1,6 +1,6 @@
 from wafl.simple_text_processing.questions import is_question
-from wafl.dataclasses.facts import Fact
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.facts import Fact
+from wafl.data_objects.rules import Rule
 
 
 def parse_rule_from_single_line(text):
diff --git a/wafl/parsing/rules_parser.py b/wafl/parsing/rules_parser.py
index 70d3b5f1..bb813e93 100644
--- a/wafl/parsing/rules_parser.py
+++ b/wafl/parsing/rules_parser.py
@@ -1,7 +1,7 @@
 import yaml
 
-from wafl.dataclasses.facts import Fact
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.facts import Fact
+from wafl.data_objects.rules import Rule
 from wafl.simple_text_processing.deixis import from_user_to_bot
diff --git a/wafl/readers/base_reader.py b/wafl/readers/base_reader.py
index ea995601..5f0aaef2 100644
--- a/wafl/readers/base_reader.py
+++ b/wafl/readers/base_reader.py
@@ -1,6 +1,6 @@
 from typing import List
 
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
 
 
 class BaseReader:
diff --git a/wafl/readers/pdf_reader.py b/wafl/readers/pdf_reader.py
index 4f610616..dc94f664 100644
--- a/wafl/readers/pdf_reader.py
+++ b/wafl/readers/pdf_reader.py
@@ -2,7 +2,7 @@
 from logging import getLogger
 from typing import List
 
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
 from wafl.readers.base_reader import BaseReader
 
 _logger = getLogger(__name__)
@@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]:
             Fact(
                 text=page.get_text(),
                 metadata={"filename": filename, "page_number": i},
+                source=Sources.FROM_TEXT,
             )
             for i, page in enumerate(doc)
         ]
diff --git a/wafl/readers/reader_factory.py b/wafl/readers/reader_factory.py
index 14ccb70c..6fc33bfc 100644
--- a/wafl/readers/reader_factory.py
+++ b/wafl/readers/reader_factory.py
@@ -4,7 +4,7 @@
 class ReaderFactory:
     _chunk_size = 10000
-    _overlap = 100
+    _overlap = 500
     _extension_to_reader_dict = {".pdf": PdfReader, ".txt": TextReader}
 
     @staticmethod
@@ -13,7 +13,4 @@ def get_reader(filename):
             if extension in filename.lower():
                 return reader(ReaderFactory._chunk_size, ReaderFactory._overlap)
 
-        ### add pdf reader
-        ### add metadata and show in the UI
-
         return TextReader(ReaderFactory._chunk_size, ReaderFactory._overlap)
diff --git a/wafl/readers/text_reader.py b/wafl/readers/text_reader.py
index b22c4ffe..8457ee04 100644
--- a/wafl/readers/text_reader.py
+++ b/wafl/readers/text_reader.py
@@ -1,7 +1,7 @@
 from logging import getLogger
 from typing import List
 
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
 from wafl.readers.base_reader import BaseReader
 
 _logger = getLogger(__name__)
@@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]:
             Fact(
                 text=chunk,
                 metadata={"filename": filename, "chunk_number": i},
+                source=Sources.FROM_TEXT,
             )
             for i, chunk in enumerate(chunks)
         ]
diff --git a/wafl/run.py b/wafl/run.py
index b0397e84..4138ac48 100644
--- a/wafl/run.py
+++ b/wafl/run.py
@@ -4,6 +4,7 @@
 from wafl.exceptions import CloseConversation
 from wafl.events.conversation_events import ConversationEvents
 from wafl.interface.command_line_interface import CommandLineInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
 from wafl.logger.local_file_logger import LocalFileLogger
 from wafl.testcases import ConversationTestCases
 from wafl.variables import get_variables
@@ -17,6 +18,12 @@ def print_incipit():
     print()
 
 
+def load_indices():
+    print("Loading knowledge indices...")
+    config = Configuration.load_local_config()
+    asyncio.run(load_knowledge(config, _logger))
+
+
 def run_from_command_line():
     interface = CommandLineInterface()
     config = Configuration.load_local_config()
diff --git a/wafl/runners/run_from_audio.py b/wafl/runners/run_from_audio.py
index 7b523687..7a1d8f35 100644
--- a/wafl/runners/run_from_audio.py
+++ b/wafl/runners/run_from_audio.py
@@ -1,6 +1,9 @@
+import asyncio
+
 from wafl.config import Configuration
 from wafl.events.conversation_events import ConversationEvents
 from wafl.interface.voice_interface import VoiceInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
 from wafl.logger.local_file_logger import LocalFileLogger
 from wafl.handlers.conversation_handler import ConversationHandler
 from wafl.scheduler.scheduler import Scheduler
@@ -10,6 +13,7 @@
 
 def run_from_audio():
     config = Configuration.load_local_config()
+    asyncio.run(load_knowledge(config, _logger))
     interface = VoiceInterface(config)
     conversation_events = ConversationEvents(
         config=config,
diff --git a/wafl/runners/run_web_and_audio_interface.py b/wafl/runners/run_web_and_audio_interface.py
index 9d5f833c..177f4501 100644
--- a/wafl/runners/run_web_and_audio_interface.py
+++ b/wafl/runners/run_web_and_audio_interface.py
@@ -1,3 +1,4 @@
+import asyncio
 import random
 import sys
 import threading
@@ -6,6 +7,7 @@
 
 from wafl.interface.list_interface import ListInterface
 from wafl.interface.voice_interface import VoiceInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
 from wafl.scheduler.scheduler import Scheduler
 from wafl.handlers.conversation_handler import ConversationHandler
 from wafl.logger.local_file_logger import LocalFileLogger
diff --git a/wafl/runners/run_web_interface.py b/wafl/runners/run_web_interface.py
index b835473c..4eae9ce0 100644
--- a/wafl/runners/run_web_interface.py
+++ b/wafl/runners/run_web_interface.py
@@ -5,6 +5,7 @@
 
 from flask import render_template, redirect
 
+from wafl.knowledge.indexing_implementation import load_knowledge
 from wafl.scheduler.scheduler import Scheduler
 from wafl.handlers.web_handler import WebHandler
 from wafl.handlers.conversation_handler import ConversationHandler
diff --git a/wafl/testcases.py b/wafl/testcases.py
index bcc49f4d..cb97af4f 100644
--- a/wafl/testcases.py
+++ b/wafl/testcases.py
@@ -1,4 +1,6 @@
 from wafl.answerer.entailer import Entailer
+from wafl.knowledge.indexing_implementation import load_knowledge
+
 from wafl.simple_text_processing.deixis import from_user_to_bot, from_bot_to_user
 from wafl.exceptions import CloseConversation
 from wafl.events.conversation_events import ConversationEvents
@@ -25,8 +27,10 @@ async def test_single_case(self, name):
         test_lines = self._testcase_data[name]["lines"]
         is_negated = self._testcase_data[name]["negated"]
         interface = DummyInterface(user_lines)
-        conversation_events = ConversationEvents(self._config, interface=interface)
-        await conversation_events._knowledge.initialize_retrievers()
+        knowledge = await load_knowledge(self._config)
+        conversation_events = ConversationEvents(
+            self._config, interface=interface, knowledge=knowledge
+        )
 
         print(self.BLUE_COLOR_START + f"\nRunning test '{name}'." + self.COLOR_END)
         continue_conversations = True
@@ -77,9 +81,7 @@ async def _lhs_is_similar_to(self, lhs, rhs, prior_dialogue):
         if lhs_name != rhs_name:
             return False
 
-        return await self._entailer.left_entails_right(
-            lhs, rhs, "\n".join(prior_dialogue)
-        )
+        return await self._entailer.left_entails_right(lhs, rhs)
 
     def _apply_deixis(self, line):
         name = line.split(":")[0].strip()
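
A minimal sketch of how the new entailer stack is meant to be driven, mirroring tests/test_entailer.py above. It assumes a running entailer server and a local config that defines "entailer_model" (as in tests/config.json); the example sentences are hypothetical.

import asyncio

from wafl.config import Configuration
from wafl.connectors.clients.entailer_client import EntailerClient


async def main():
    # EntailerClient resolves its connector through EntailerConnectorFactory,
    # which reads the "entailer_model" host and port from the configuration.
    config = Configuration.load_local_config()
    client = EntailerClient(config)
    score = await client.get_entailment_score(
        "The user talks about the weather.",
        "The user wants to know today's forecast.",
    )
    # Entailer.left_entails_right treats scores above 0.5 as entailment.
    print(f"entailment score: {score:.2f}")


if __name__ == "__main__":
    asyncio.run(main())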