diff --git a/tests/config.json b/tests/config.json
index 121d77f0..dfbdcbe8 100644
--- a/tests/config.json
+++ b/tests/config.json
@@ -27,5 +27,9 @@
"text_embedding_model": {
"model_host": "localhost",
"model_port": 8080
+ },
+ "entailer_model": {
+ "model_host": "localhost",
+ "model_port": 8080
}
}
diff --git a/tests/test_entailer.py b/tests/test_entailer.py
new file mode 100644
index 00000000..c1b5a451
--- /dev/null
+++ b/tests/test_entailer.py
@@ -0,0 +1,34 @@
+import asyncio
+import os
+
+from unittest import TestCase
+from wafl.config import Configuration
+from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
+from wafl.connectors.clients.entailer_client import EntailerClient
+
+_path = os.path.dirname(__file__)
+
+
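+# Both tests require the entailer server configured in tests/config.json
+# to be reachable (localhost:8080 here).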
+class TestConnection(TestCase):
+ def test__entailer_connector(self):
+ config = Configuration.load_local_config()
+ connector = RemoteEntailerConnector(config.get_value("entailer_model"))
+ prediction = asyncio.run(
+ connector.predict(
+ "The first contact is a romance novel set in the middle ages.",
+ "The first contact is a science fiction novel about the first contact between humans and aliens.",
+ )
+ )
+ assert prediction["score"] < 0.5
+
+ def test__entailment_client(self):
+ config = Configuration.load_local_config()
+ client = EntailerClient(config)
+ prediction = asyncio.run(
+ client.get_entailment_score(
+ "The first contact is a romance novel set in the middle ages.",
+ "The first contact is a science fiction novel about the first contact between humans and aliens.",
+ )
+ )
+ assert prediction < 0.5
diff --git a/tests/test_indexing.py b/tests/test_indexing.py
index 1e64b4a4..0ea51557 100644
--- a/tests/test_indexing.py
+++ b/tests/test_indexing.py
@@ -5,7 +5,7 @@
from unittest import TestCase
from wafl.config import Configuration
-from wafl.dataclasses.dataclasses import Query
+from wafl.data_objects.dataclasses import Query
from wafl.knowledge.indexing_implementation import add_to_index, load_knowledge
_path = os.path.dirname(__file__)
diff --git a/tests/test_voice.py b/tests/test_voice.py
index 6ecd785b..7b90ef83 100644
--- a/tests/test_voice.py
+++ b/tests/test_voice.py
@@ -15,7 +15,7 @@
rules:
- the user's name is Jane:
- - write "I hear you"
+ - reply with "I hear you" and nothing else
""".strip()
_path = os.path.dirname(__file__)
@@ -23,7 +23,7 @@
class TestVoice(TestCase):
def test__activation(self):
- interface = DummyInterface(to_utter=["computer", "my name is Jane"])
+ interface = DummyInterface(to_utter=["computer my name is Jane"])
config = Configuration.load_local_config()
config.set_value("rules", _wafl_example)
conversation_events = ConversationEvents(config=config, interface=interface)
diff --git a/todo.txt b/todo.txt
index 31c05da1..17ec607e 100644
--- a/todo.txt
+++ b/todo.txt
@@ -1,5 +1,23 @@
-* why do I need to re-initialise the retrievers after unpickling the knowledge?
+* apply entailer to rule retrieval:
+ if more than one rule is retrieved, then the one
+ that is entailed by the query should be chosen
+
+* the answer from the indexed files should be directed by a rule.
+ - facts and rules should live at the highest level of the retrieval
+
+
+/* Add tqdm to indexing.
+/* Make it index when wafl starts, not at the first use/login
+
+/* The prior items with timestamps might not be necessary.
+/ - Just implement a queue with a fixed size
+
+* add entailer to wafl_llm
+
+
+/* why do I need to re-initialise the retrievers after unpickling the knowledge?
- maybe you should save the retrievers in the knowledge object separately?
+ - It was gensim that was not serializable. Took it out
/* knowledge cache does not cache the rules or facts
diff --git a/wafl/answerer/answerer_implementation.py b/wafl/answerer/answerer_implementation.py
index 3c8b36e8..9a2c25df 100644
--- a/wafl/answerer/answerer_implementation.py
+++ b/wafl/answerer/answerer_implementation.py
@@ -3,8 +3,9 @@
from typing import List, Tuple
+from wafl.answerer.entailer import Entailer
from wafl.exceptions import CloseConversation
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.interface.conversation import Conversation, Utterance
@@ -113,22 +114,32 @@ async def _run_code(to_execute: str, module, functions) -> str:
return result
-def get_text_from_facts_and_thresholds(
+def create_memory_from_fact_list(facts: List[Fact], max_num_facts: int) -> str:
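+    # Facts retrieved from indexed text are capped at max_num_facts;
+    # facts coming from rules are always kept.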
+    text_fact_list = [
+        "\n\n- " + fact.text + "\n"
+        for fact in facts
+        if fact.source == Sources.FROM_TEXT
+    ][:max_num_facts]
+    rule_fact_list = [
+        "\n\n- " + fact.text + "\n"
+        for fact in facts
+        if fact.source in [None, Sources.FROM_RULES]
+    ]
+ return "".join(text_fact_list + rule_fact_list)
+
+
+def get_facts_with_metadata_from_facts_and_thresholds(
facts_and_thresholds: List[Tuple[Fact, float]], memory: str
) -> List[str]:
- text_list = []
+ fact_list = []
for item in facts_and_thresholds:
if item[0].text not in memory:
- text = item[0].text
+ new_fact = item[0].copy()
if item[0].metadata:
- text = (
- f"Metadata for the following text: {str(item[0].metadata)}"
- + "\n"
- + text
- )
- text_list.append(text)
+                pass  # metadata stays attached to the copied fact; it is no longer folded into the text
+ fact_list.append(new_fact)
- return text_list
+ return fact_list
def add_dummy_utterances_to_continue_generation(
@@ -150,3 +161,9 @@ def add_dummy_utterances_to_continue_generation(
def add_memories_to_facts(facts: str, memories: List[str]) -> str:
return facts + "\n" + "\n".join(memories)
+
+
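+# Rank the candidate rules by how strongly the user's last utterance entails
+# them, and keep only the top num_rules.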
+async def select_best_rules_using_entailer(
+    conversation: Conversation, rules_as_strings: List[str], entailer: Entailer, num_rules: int
+) -> List[str]:
+    query_text = conversation.get_last_speaker_utterance("user")
+    scores = [await entailer.get_score(query_text, rule) for rule in rules_as_strings]
+    ranked = sorted(zip(rules_as_strings, scores), key=lambda item: item[1], reverse=True)
+    return [rule for rule, _ in ranked[:num_rules]]
diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py
index f12be579..5377233a 100644
--- a/wafl/answerer/dialogue_answerer.py
+++ b/wafl/answerer/dialogue_answerer.py
@@ -1,18 +1,21 @@
from importlib import import_module
from inspect import getmembers, isfunction
-from typing import List, Tuple
+from typing import List
+
+from wafl.answerer.entailer import Entailer
from wafl.answerer.answerer_implementation import (
substitute_memory_in_answer_and_get_memories_if_present,
create_one_liner,
- get_text_from_facts_and_thresholds,
+ get_facts_with_metadata_from_facts_and_thresholds,
add_dummy_utterances_to_continue_generation,
add_memories_to_facts,
execute_results_in_answer,
+        create_memory_from_fact_list,
+        select_best_rules_using_entailer,
)
from wafl.answerer.base_answerer import BaseAnswerer
from wafl.answerer.rule_maker import RuleMaker
from wafl.connectors.clients.llm_chat_client import LLMChatClient
-from wafl.dataclasses.dataclasses import Query, Answer
+from wafl.data_objects.dataclasses import Query, Answer
from wafl.interface.conversation import Conversation
from wafl.simple_text_processing.questions import is_question
@@ -21,13 +24,14 @@ class DialogueAnswerer(BaseAnswerer):
def __init__(self, config, knowledge, interface, code_path, logger):
self._threshold_for_facts = 0.85
self._client = LLMChatClient(config)
+ self._entailer = Entailer(config)
self._knowledge = knowledge
self._logger = logger
self._interface = interface
self._max_num_past_utterances = 5
- self._max_num_past_utterances_for_facts = 5
- self._max_num_past_utterances_for_rules = 2
- self._prior_facts_with_timestamp = []
+ self._max_num_facts = 5
+ self._max_num_rules = 2
+ self._prior_facts = []
self._init_python_module(code_path.replace(".py", ""))
self._prior_rules = []
self._max_predictions = 3
@@ -48,17 +52,15 @@ async def answer(self, query_text: str) -> Answer:
rules_text = await self._get_relevant_rules(conversation)
if not conversation:
conversation = create_one_liner(query_text)
- conversational_timestamp = len(conversation)
- facts = await self._get_relevant_facts(
+ memory = await self._get_relevant_facts(
query,
has_prior_rules=bool(rules_text),
- conversational_timestamp=conversational_timestamp,
)
final_answer_text = ""
for _ in range(self._max_predictions):
original_answer_text = await self._client.get_answer(
- text=facts,
+ text=memory,
rules_text=rules_text,
dialogue=conversation,
)
@@ -82,22 +84,19 @@ async def answer(self, query_text: str) -> Answer:
return Answer.create_from_text(final_answer_text)
- async def _get_relevant_facts(
- self, query: Query, has_prior_rules: bool, conversational_timestamp: int
- ) -> str:
- memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
- self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp(
- conversational_timestamp
- )
+ async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
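+        # Render the current fact memory, then fold any newly retrieved facts into it.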
+ memory = create_memory_from_fact_list(self._prior_facts, self._max_num_facts)
facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
query, is_from_user=True, threshold=self._threshold_for_facts
)
if facts_and_thresholds:
- facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory)
- self._prior_facts_with_timestamp.extend(
- (item, conversational_timestamp) for item in facts
+ facts = get_facts_with_metadata_from_facts_and_thresholds(
+ facts_and_thresholds, memory
+ )
+ self._prior_facts.extend(facts)
+ memory = create_memory_from_fact_list(
+ self._prior_facts, self._max_num_facts
)
- memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
await self._interface.add_fact(f"The bot remembers the facts:\n{memory}")
else:
@@ -110,11 +109,12 @@ async def _get_relevant_facts(
return memory
async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
- rules = await self._rule_creator.create_from_query(conversation)
- for rule in rules:
+        rules_as_strings = await self._rule_creator.create_from_query(conversation)
+        rules_as_strings = await select_best_rules_using_entailer(
+            conversation, rules_as_strings, self._entailer, num_rules=1
+        )
+ for rule in rules_as_strings:
if rule not in self._prior_rules:
self._prior_rules.insert(0, rule)
- self._prior_rules = self._prior_rules[: self._max_num_past_utterances_for_rules]
+ self._prior_rules = self._prior_rules[: self._max_num_rules]
return self._prior_rules
def _init_python_module(self, module_name):
@@ -129,13 +129,3 @@ async def _apply_substitutions(self, original_answer_text):
self._functions,
)
)
-
- def _get_prior_facts_with_timestamp(
- self, conversational_timestamp: int
- ) -> List[Tuple[str, int]]:
- return [
- item
- for item in self._prior_facts_with_timestamp
- if item[1]
- > conversational_timestamp - self._max_num_past_utterances_for_facts
- ]
diff --git a/wafl/answerer/entailer.py b/wafl/answerer/entailer.py
index 54e4e3e2..3f3c2ab9 100644
--- a/wafl/answerer/entailer.py
+++ b/wafl/answerer/entailer.py
@@ -1,41 +1,14 @@
-import os
-import textwrap
-
-from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
-from wafl.interface.conversation import Utterance, Conversation
-
-_path = os.path.dirname(__file__)
+from wafl.connectors.clients.entailer_client import EntailerClient
class Entailer:
def __init__(self, config):
- self._connector = LLMConnectorFactory.get_connector(config)
+ self.entailer_client = EntailerClient(config)
self._config = config
- async def left_entails_right(self, lhs: str, rhs: str, dialogue) -> str:
- prompt = await self._get_answer_prompt(lhs, rhs, dialogue)
- result = await self._connector.generate(prompt)
- result = self._clean_result(result)
- return result == "yes"
-
- async def _get_answer_prompt(self, lhs, rhs, dialogue):
- return PromptTemplate(
- system_prompt="",
- conversation=self._get_dialogue_prompt(lhs, rhs, dialogue),
- )
-
- def _clean_result(self, result):
- result = result.replace("", "")
- result = result.split("\n")[0]
- result = result.strip()
- return result.lower()
+ async def left_entails_right(self, lhs: str, rhs: str) -> bool:
+ prediction = await self.entailer_client.get_entailment_score(lhs, rhs)
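+        # a score above 0.5 is taken to mean that lhs entails rhs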
+ return prediction > 0.5
- def _get_dialogue_prompt(self, dialogue, lhs, rhs):
- text = f"""
-Your task is to determine whether two sentences are similar.
-1) {lhs.lower()}
-2) {rhs.lower()}
-Please answer "yes" if the two sentences are similar or "no" if not:
- """.strip()
- return Conversation([Utterance(speaker="user", text=text)])
+ async def get_score(self, lhs: str, rhs: str) -> float:
+ return await self.entailer_client.get_entailment_score(lhs, rhs)
diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py
index 115dfcfc..7454fe73 100644
--- a/wafl/answerer/rule_maker.py
+++ b/wafl/answerer/rule_maker.py
@@ -1,7 +1,7 @@
from typing import List
-from wafl.dataclasses.dataclasses import Query
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.dataclasses import Query
+from wafl.data_objects.rules import Rule
class RuleMaker:
diff --git a/wafl/changelog.txt b/wafl/changelog.txt
new file mode 100644
index 00000000..bfd13914
--- /dev/null
+++ b/wafl/changelog.txt
@@ -0,0 +1,7 @@
+- version 0.1.3
+* added multi-threaded support for indexing multiple files
+* TODO: ADD support for multiple knowledge bases.
+ It needs to index the rules and the files separately!
+* the web interface should show where the facts come from
+* add support for wafl studio where you can concatenate actions (and create corresponding yaml files)
+* use <> tags for concatenation
\ No newline at end of file
diff --git a/wafl/command_line.py b/wafl/command_line.py
index 4ebb6643..114f86be 100644
--- a/wafl/command_line.py
+++ b/wafl/command_line.py
@@ -9,6 +9,7 @@
run_testcases,
print_incipit,
download_models,
+ load_indices,
)
from wafl.runners.run_from_actions import run_action
@@ -52,26 +53,31 @@ def process_cli():
elif command == "run":
from wafl.runners.run_web_and_audio_interface import run_app
+ load_indices()
run_app()
remove_preprocessed("/")
elif command == "run-cli":
+ load_indices()
run_from_command_line()
remove_preprocessed("/")
elif command == "run-audio":
from wafl.runners.run_from_audio import run_from_audio
+ load_indices()
run_from_audio()
remove_preprocessed("/")
elif command == "run-server":
from wafl.runners.run_web_interface import run_server_only_app
+ load_indices()
run_server_only_app()
remove_preprocessed("/")
elif command == "run-tests":
+ load_indices()
run_testcases()
remove_preprocessed("/")
diff --git a/wafl/connectors/clients/clients_implementation.py b/wafl/connectors/clients/clients_implementation.py
deleted file mode 100644
index ec463c56..00000000
--- a/wafl/connectors/clients/clients_implementation.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import csv
-import os
-import joblib
-
-from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
-
-_path = os.path.dirname(__file__)
-
-
-async def load_knowledge_from_file(filename, config):
- items_list = []
- with open(os.path.join(_path, "../../data/", filename + ".csv")) as file:
- csvreader = csv.reader(file)
- for row in csvreader:
- items_list.append(row[0].strip())
-
- knowledge = await SingleFileKnowledge.create_from_list(items_list, config)
- joblib.dump(knowledge, os.path.join(_path, f"../../data/{filename}.knowledge"))
- return knowledge
diff --git a/wafl/connectors/clients/entailer_client.py b/wafl/connectors/clients/entailer_client.py
new file mode 100644
index 00000000..a5c189be
--- /dev/null
+++ b/wafl/connectors/clients/entailer_client.py
@@ -0,0 +1,21 @@
+import os
+
+from wafl.connectors.factories.entailer_connector_factory import (
+ EntailerConnectorFactory,
+)
+
+_path = os.path.dirname(__file__)
+
+
+class EntailerClient:
+ def __init__(self, config):
+ self._connector = EntailerConnectorFactory.get_connector(
+ "entailer_model", config
+ )
+ self._config = config
+
+ async def get_entailment_score(self, lhs: str, rhs: str) -> float:
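+        # Higher scores mean that lhs is more likely to entail rhs.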
+ prediction = await self._connector.predict(lhs, rhs)
+ if "score" not in prediction:
+            raise ValueError("The entailment prediction does not contain a score.")
+ return prediction["score"]
diff --git a/wafl/connectors/clients/information_client.py b/wafl/connectors/clients/information_client.py
index 772afb00..533fc902 100644
--- a/wafl/connectors/clients/information_client.py
+++ b/wafl/connectors/clients/information_client.py
@@ -1,9 +1,6 @@
import os
-import textwrap
-from typing import List
from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
_path = os.path.dirname(__file__)
diff --git a/wafl/connectors/factories/entailer_connector_factory.py b/wafl/connectors/factories/entailer_connector_factory.py
new file mode 100644
index 00000000..017fbbdf
--- /dev/null
+++ b/wafl/connectors/factories/entailer_connector_factory.py
@@ -0,0 +1,7 @@
+from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
+
+
+class EntailerConnectorFactory:
+ @staticmethod
+ def get_connector(model_name, config):
+ return RemoteEntailerConnector(config.get_value(model_name))
diff --git a/wafl/connectors/remote/remote_entailer_connector.py b/wafl/connectors/remote/remote_entailer_connector.py
new file mode 100644
index 00000000..f1230c86
--- /dev/null
+++ b/wafl/connectors/remote/remote_entailer_connector.py
@@ -0,0 +1,58 @@
+import aiohttp
+import asyncio
+import json
+
+from typing import Dict
+
+
+class RemoteEntailerConnector:
+ _max_tries = 3
+
+ def __init__(self, config):
+ host = config["model_host"]
+ port = config["model_port"]
+
+        self._server_url = f"https://{host}:{port}/predictions/entailer"
+ try:
+ loop = asyncio.get_running_loop()
+
+ except RuntimeError:
+ loop = None
+
+        if (not loop or not loop.is_running()) and not asyncio.run(
+            self.check_connection()
+        ):
+            raise RuntimeError("Cannot connect to a running entailer model.")
+
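+    # Queries the remote entailer; if no valid response is obtained after
+    # _max_tries attempts, a sentinel score of -1.0 is returned.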
+    async def predict(self, lhs: str, rhs: str) -> Dict[str, float]:
+        payload = {"lhs": lhs, "rhs": rhs}
+        for _ in range(self._max_tries):
+            try:
+                async with aiohttp.ClientSession(
+                    connector=aiohttp.TCPConnector(ssl=False)
+                ) as session:
+                    async with session.post(self._server_url, json=payload) as response:
+                        data = await response.text()
+                        prediction = json.loads(data)
+                        return {"score": float(prediction["score"])}
+            except (aiohttp.ClientError, json.JSONDecodeError, KeyError):
+                continue
+
+        return {"score": -1.0}
+
+ async def check_connection(self):
+ payload = {"lhs": "test", "rhs": "test"}
+ try:
+ async with aiohttp.ClientSession(
+ conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
+ ) as session:
+ async with session.post(self._server_url, json=payload) as response:
+ await response.text()
+ return True
+
+        except aiohttp.ClientError:
+ print()
+ print("Is the entailer server running?")
+ print("Please run 'bash start-llm.sh' (see docs for explanation).")
+ print()
+
+ return False
diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py
index a2b3d7f7..ed0af1c8 100644
--- a/wafl/connectors/remote/remote_llm_connector.py
+++ b/wafl/connectors/remote/remote_llm_connector.py
@@ -10,7 +10,7 @@
class RemoteLLMConnector(BaseLLMConnector):
_max_tries = 3
- _max_reply_length = 1024
+ _max_reply_length = 2048
_num_prediction_tokens = 200
_cache = {}
diff --git a/wafl/dataclasses/__init__.py b/wafl/data_objects/__init__.py
similarity index 100%
rename from wafl/dataclasses/__init__.py
rename to wafl/data_objects/__init__.py
diff --git a/wafl/dataclasses/dataclasses.py b/wafl/data_objects/dataclasses.py
similarity index 100%
rename from wafl/dataclasses/dataclasses.py
rename to wafl/data_objects/dataclasses.py
diff --git a/wafl/data_objects/facts.py b/wafl/data_objects/facts.py
new file mode 100644
index 00000000..88926ab0
--- /dev/null
+++ b/wafl/data_objects/facts.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Union
+
+
+class Sources(Enum):
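+    # Provenance of a fact: extracted from indexed text or produced by rules.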
+ FROM_TEXT = 1
+ FROM_RULES = 2
+
+
+@dataclass
+class Fact:
+ text: Union[str, dict]
+ is_question: bool = False
+ variable: str = None
+ is_interruption: bool = False
+ destination: str = None
+ metadata: Union[str, dict] = None
+ source: Sources = Sources.FROM_RULES
+
+ def toJSON(self):
+ return str(self)
+
+ def copy(self):
+ return Fact(
+ self.text,
+ self.is_question,
+ self.variable,
+ self.is_interruption,
+ self.destination,
+ self.metadata,
+ self.source,
+ )
diff --git a/wafl/dataclasses/rules.py b/wafl/data_objects/rules.py
similarity index 100%
rename from wafl/dataclasses/rules.py
rename to wafl/data_objects/rules.py
diff --git a/wafl/dataclasses/facts.py b/wafl/dataclasses/facts.py
deleted file mode 100644
index 0445adff..00000000
--- a/wafl/dataclasses/facts.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from dataclasses import dataclass
-from typing import Union
-
-
-@dataclass
-class Fact:
- text: Union[str, dict]
- is_question: bool = False
- variable: str = None
- is_interruption: bool = False
- source: str = None
- destination: str = None
- metadata: Union[str, dict] = None
-
- def toJSON(self):
- return str(self)
diff --git a/wafl/events/conversation_events.py b/wafl/events/conversation_events.py
index d83d52ce..e9e2d20b 100644
--- a/wafl/events/conversation_events.py
+++ b/wafl/events/conversation_events.py
@@ -19,6 +19,7 @@ def __init__(
config: "Configuration",
interface: "BaseInterface",
logger=None,
+ knowledge=None,
):
self._config = config
try:
@@ -29,6 +30,8 @@ def __init__(
if not loop or not loop.is_running():
self._knowledge = asyncio.run(load_knowledge(config, logger))
+ else:
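+            # inside a running event loop the caller must pass a pre-loaded knowledge object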
+ self._knowledge = knowledge
self._answerer = create_answerer(config, self._knowledge, interface, logger)
self._answerer._client._connector._cache = {}
diff --git a/wafl/inference/utils.py b/wafl/inference/utils.py
index 2f25bda1..d270814b 100644
--- a/wafl/inference/utils.py
+++ b/wafl/inference/utils.py
@@ -2,7 +2,7 @@
from typing import List, Dict, Tuple, Any
from fuzzywuzzy import process
-from wafl.dataclasses.dataclasses import Answer
+from wafl.data_objects.dataclasses import Answer
from wafl.simple_text_processing.normalize import normalized
from wafl.simple_text_processing.questions import is_question
diff --git a/wafl/interface/conversation.py b/wafl/interface/conversation.py
index 68687eb1..c65f1f5b 100644
--- a/wafl/interface/conversation.py
+++ b/wafl/interface/conversation.py
@@ -111,6 +111,15 @@ def get_last_speaker_utterances(self, speaker: str, n: int) -> List[str]:
if utterance.speaker == speaker
][-n:]
+ def get_last_speaker_utterance(self, speaker: str) -> str:
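+        # Return the most recent utterance by the given speaker, or "" if there is none.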
+ if not self.utterances:
+ return ""
+
+ for utterance in reversed(self.utterances):
+ if utterance.speaker == speaker:
+ return utterance.text
+ return ""
+
def get_first_timestamp(self) -> float:
return self.utterances[0].timestamp if self.utterances else None
diff --git a/wafl/knowledge/indexing_implementation.py b/wafl/knowledge/indexing_implementation.py
index fc4bdf3b..d4c3a8f8 100644
--- a/wafl/knowledge/indexing_implementation.py
+++ b/wafl/knowledge/indexing_implementation.py
@@ -1,24 +1,50 @@
+import asyncio
import os
-
import joblib
import yaml
+import threading
+from tqdm import tqdm
from wafl.config import Configuration
from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
from wafl.readers.reader_factory import ReaderFactory
+async def add_file_to_knowledge(knowledge, filename):
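+    # Pick the reader matching the file extension and index the file chunk by chunk.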
+ reader = ReaderFactory.get_reader(filename)
+ for chunk in reader.get_chunks(filename):
+ await knowledge.add_fact(chunk)
+
+
async def _add_indices_to_knowledge(knowledge, text):
indices = yaml.safe_load(text)
if "paths" not in indices or not indices["paths"]:
return knowledge
for path in indices["paths"]:
- for root, _, files in os.walk(path):
- for file in files:
- reader = ReaderFactory.get_reader(file)
- for chunk in reader.get_chunks(os.path.join(root, file)):
- await knowledge.add_fact(chunk)
+ print(f"Indexing path: {path}")
+ file_count = sum(len(files) for _, _, files in os.walk(path))
+ with tqdm(total=file_count) as pbar:
+ for root, _, files in os.walk(path):
+ threads = []
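+                # one thread per file: each thread runs the async indexing
+                # coroutine in its own event loop via asyncio.run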
+ for file in files:
+ threads.append(
+ threading.Thread(
+ target=asyncio.run,
+ args=(
+ add_file_to_knowledge(
+ knowledge, os.path.join(root, file)
+ ),
+ ),
+ )
+ )
+                num_threads = min(10, len(threads))
+                for i in range(0, len(threads), max(num_threads, 1)):
+                    batch = threads[i : i + num_threads]
+                    for thread in batch:
+                        thread.start()
+                    for thread in batch:
+                        thread.join()
+                    pbar.update(len(batch))
return knowledge
@@ -27,10 +53,12 @@ async def load_knowledge(config, logger=None):
if ".yaml" in config.get_value("rules") and not any(
item in config.get_value("rules") for item in [" ", "\n"]
):
+ rules_filename = config.get_value("rules")
with open(config.get_value("rules")) as file:
rules_txt = file.read()
else:
+ rules_filename = None
rules_txt = config.get_value("rules")
index_filename = config.get_value("index")
@@ -41,10 +69,12 @@ async def load_knowledge(config, logger=None):
cache_filename = config.get_value("cache_filename")
if os.path.exists(cache_filename):
- knowledge = joblib.load(cache_filename)
- if knowledge.hash == hash(rules_txt) and os.path.getmtime(
- cache_filename
- ) > os.path.getmtime(index_filename):
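+        # reuse the cache only when it is newer than both the rules file
+        # and the index file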
+ if (
+ rules_filename
+ and os.path.getmtime(cache_filename) > os.path.getmtime(rules_filename)
+ and os.path.getmtime(cache_filename) > os.path.getmtime(index_filename)
+ ):
+ knowledge = joblib.load(cache_filename)
return knowledge
knowledge = SingleFileKnowledge(config, rules_txt, logger=logger)
diff --git a/wafl/knowledge/single_file_knowledge.py b/wafl/knowledge/single_file_knowledge.py
index 8c9c5a25..2fc15c39 100644
--- a/wafl/knowledge/single_file_knowledge.py
+++ b/wafl/knowledge/single_file_knowledge.py
@@ -4,9 +4,10 @@
from typing import List
import nltk
+from tqdm import tqdm
from wafl.config import Configuration
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
from wafl.knowledge.base_knowledge import BaseKnowledge
from wafl.knowledge.utils import (
text_is_exact_string,
@@ -169,7 +170,8 @@ def get_facts_and_rule_as_text(self):
return text
async def initialize_retrievers(self):
- for index, fact in self._facts_dict.items():
+ print("Initializing fact retrievers")
+ for index, fact in tqdm(self._facts_dict.items()):
if text_is_exact_string(fact.text):
continue
@@ -181,7 +183,8 @@ async def initialize_retrievers(self):
clean_text_for_retrieval(fact.text), index
)
- for index, rule in self._rules_dict.items():
+ print("Initializing rule retrievers")
+ for index, rule in tqdm(self._rules_dict.items()):
if text_is_exact_string(rule.effect.text):
continue
@@ -189,10 +192,6 @@ async def initialize_retrievers(self):
clean_text_for_retrieval(rule.effect.text), index
)
- for index, rule in self._rules_dict.items():
- if not text_is_exact_string(rule.effect.text):
- continue
-
await self._rules_string_retriever.add_text_and_index(
rule.effect.text, index
)
diff --git a/wafl/parsing/line_rules_parser.py b/wafl/parsing/line_rules_parser.py
index 73371f3e..b6bfa3ee 100644
--- a/wafl/parsing/line_rules_parser.py
+++ b/wafl/parsing/line_rules_parser.py
@@ -1,6 +1,6 @@
from wafl.simple_text_processing.questions import is_question
-from wafl.dataclasses.facts import Fact
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.facts import Fact
+from wafl.data_objects.rules import Rule
def parse_rule_from_single_line(text):
diff --git a/wafl/parsing/rules_parser.py b/wafl/parsing/rules_parser.py
index 70d3b5f1..bb813e93 100644
--- a/wafl/parsing/rules_parser.py
+++ b/wafl/parsing/rules_parser.py
@@ -1,7 +1,7 @@
import yaml
-from wafl.dataclasses.facts import Fact
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.facts import Fact
+from wafl.data_objects.rules import Rule
from wafl.simple_text_processing.deixis import from_user_to_bot
diff --git a/wafl/readers/base_reader.py b/wafl/readers/base_reader.py
index ea995601..5f0aaef2 100644
--- a/wafl/readers/base_reader.py
+++ b/wafl/readers/base_reader.py
@@ -1,6 +1,6 @@
from typing import List
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
class BaseReader:
diff --git a/wafl/readers/pdf_reader.py b/wafl/readers/pdf_reader.py
index 4f610616..dc94f664 100644
--- a/wafl/readers/pdf_reader.py
+++ b/wafl/readers/pdf_reader.py
@@ -2,7 +2,7 @@
from logging import getLogger
from typing import List
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.readers.base_reader import BaseReader
_logger = getLogger(__name__)
@@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]:
Fact(
text=page.get_text(),
metadata={"filename": filename, "page_number": i},
+ source=Sources.FROM_TEXT,
)
for i, page in enumerate(doc)
]
diff --git a/wafl/readers/reader_factory.py b/wafl/readers/reader_factory.py
index 14ccb70c..6fc33bfc 100644
--- a/wafl/readers/reader_factory.py
+++ b/wafl/readers/reader_factory.py
@@ -4,7 +4,7 @@
class ReaderFactory:
_chunk_size = 10000
- _overlap = 100
+ _overlap = 500
_extension_to_reader_dict = {".pdf": PdfReader, ".txt": TextReader}
@staticmethod
@@ -13,7 +13,4 @@ def get_reader(filename):
if extension in filename.lower():
return reader(ReaderFactory._chunk_size, ReaderFactory._overlap)
- ### add pdf reader
- ### add metadata and show in the UI
-
return TextReader(ReaderFactory._chunk_size, ReaderFactory._overlap)
diff --git a/wafl/readers/text_reader.py b/wafl/readers/text_reader.py
index b22c4ffe..8457ee04 100644
--- a/wafl/readers/text_reader.py
+++ b/wafl/readers/text_reader.py
@@ -1,7 +1,7 @@
from logging import getLogger
from typing import List
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.readers.base_reader import BaseReader
_logger = getLogger(__name__)
@@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]:
Fact(
text=chunk,
metadata={"filename": filename, "chunk_number": i},
+ source=Sources.FROM_TEXT,
)
for i, chunk in enumerate(chunks)
]
diff --git a/wafl/run.py b/wafl/run.py
index b0397e84..4138ac48 100644
--- a/wafl/run.py
+++ b/wafl/run.py
@@ -4,6 +4,7 @@
from wafl.exceptions import CloseConversation
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.command_line_interface import CommandLineInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.logger.local_file_logger import LocalFileLogger
from wafl.testcases import ConversationTestCases
from wafl.variables import get_variables
@@ -17,6 +18,12 @@ def print_incipit():
print()
+def load_indices():
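+    # Build or refresh the knowledge index once at startup, so the first
+    # user interaction does not pay the indexing cost.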
+ print("Loading knowledge indices...")
+ config = Configuration.load_local_config()
+ asyncio.run(load_knowledge(config, _logger))
+
+
def run_from_command_line():
interface = CommandLineInterface()
config = Configuration.load_local_config()
diff --git a/wafl/runners/run_from_audio.py b/wafl/runners/run_from_audio.py
index 7b523687..7a1d8f35 100644
--- a/wafl/runners/run_from_audio.py
+++ b/wafl/runners/run_from_audio.py
@@ -1,6 +1,9 @@
+import asyncio
+
from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.voice_interface import VoiceInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.logger.local_file_logger import LocalFileLogger
from wafl.handlers.conversation_handler import ConversationHandler
from wafl.scheduler.scheduler import Scheduler
@@ -10,6 +13,7 @@
def run_from_audio():
config = Configuration.load_local_config()
+ asyncio.run(load_knowledge(config, _logger))
interface = VoiceInterface(config)
conversation_events = ConversationEvents(
config=config,
diff --git a/wafl/runners/run_web_and_audio_interface.py b/wafl/runners/run_web_and_audio_interface.py
index 9d5f833c..177f4501 100644
--- a/wafl/runners/run_web_and_audio_interface.py
+++ b/wafl/runners/run_web_and_audio_interface.py
@@ -1,3 +1,4 @@
+import asyncio
import random
import sys
import threading
@@ -6,6 +7,7 @@
from wafl.interface.list_interface import ListInterface
from wafl.interface.voice_interface import VoiceInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.scheduler.scheduler import Scheduler
from wafl.handlers.conversation_handler import ConversationHandler
from wafl.logger.local_file_logger import LocalFileLogger
diff --git a/wafl/runners/run_web_interface.py b/wafl/runners/run_web_interface.py
index b835473c..4eae9ce0 100644
--- a/wafl/runners/run_web_interface.py
+++ b/wafl/runners/run_web_interface.py
@@ -5,6 +5,7 @@
from flask import render_template, redirect
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.scheduler.scheduler import Scheduler
from wafl.handlers.web_handler import WebHandler
from wafl.handlers.conversation_handler import ConversationHandler
diff --git a/wafl/testcases.py b/wafl/testcases.py
index bcc49f4d..cb97af4f 100644
--- a/wafl/testcases.py
+++ b/wafl/testcases.py
@@ -1,4 +1,6 @@
from wafl.answerer.entailer import Entailer
+from wafl.knowledge.indexing_implementation import load_knowledge
+
from wafl.simple_text_processing.deixis import from_user_to_bot, from_bot_to_user
from wafl.exceptions import CloseConversation
from wafl.events.conversation_events import ConversationEvents
@@ -25,8 +27,10 @@ async def test_single_case(self, name):
test_lines = self._testcase_data[name]["lines"]
is_negated = self._testcase_data[name]["negated"]
interface = DummyInterface(user_lines)
- conversation_events = ConversationEvents(self._config, interface=interface)
- await conversation_events._knowledge.initialize_retrievers()
+ knowledge = await load_knowledge(self._config)
+ conversation_events = ConversationEvents(
+ self._config, interface=interface, knowledge=knowledge
+ )
print(self.BLUE_COLOR_START + f"\nRunning test '{name}'." + self.COLOR_END)
continue_conversations = True
@@ -77,9 +81,7 @@ async def _lhs_is_similar_to(self, lhs, rhs, prior_dialogue):
if lhs_name != rhs_name:
return False
- return await self._entailer.left_entails_right(
- lhs, rhs, "\n".join(prior_dialogue)
- )
+ return await self._entailer.left_entails_right(lhs, rhs)
def _apply_deixis(self, line):
name = line.split(":")[0].strip()