Merge pull request #128 from fractalego/indexing-with-large-corpus

Indexing with large corpus

fractalego committed Aug 4, 2024
2 parents 09df1e3 + 94c160e commit 64f365d
Showing 38 changed files with 344 additions and 158 deletions.
4 changes: 4 additions & 0 deletions tests/config.json
@@ -27,5 +27,9 @@
"text_embedding_model": {
"model_host": "localhost",
"model_port": 8080
},
"entailer_model": {
"model_host": "localhost",
"model_port": 8080
}
}
34 changes: 34 additions & 0 deletions tests/test_entailer.py
@@ -0,0 +1,34 @@
import asyncio
import os

from unittest import TestCase
from wafl.config import Configuration
from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
from wafl.connectors.clients.entailer_client import EntailerClient

_path = os.path.dirname(__file__)


class TestConnection(TestCase):
def test__entailer_connector(self):
config = Configuration.load_local_config()
connector = RemoteEntailerConnector(config.get_value("entailer_model"))
prediction = asyncio.run(
connector.predict(
"The first contact is a romance novel set in the middle ages.",
"The first contact is a science fiction novel about the first contact between humans and aliens.",
)
)
assert prediction["score"] < 0.5

def test__entailment_client(self):

config = Configuration.load_local_config()
client = EntailerClient(config)
prediction = asyncio.run(
client.get_entailment_score(
"The first contact is a romance novel set in the middle ages.",
"The first contact is a science fiction novel about the first contact between humans and aliens.",
)
)
assert prediction < 0.5
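
For the positive direction, the same client should score a pair where the first sentence does imply the second above the 0.5 threshold. A minimal sketch, assuming the entailer model from tests/config.json is being served locally:

import asyncio

from wafl.config import Configuration
from wafl.connectors.clients.entailer_client import EntailerClient

# Assumes an entailer model on localhost:8080, as in tests/config.json above.
config = Configuration.load_local_config()
client = EntailerClient(config)

# The score reads as "lhs entails rhs": high when the first sentence implies
# the second, low for the contradictory pair used in the tests above.
score = asyncio.run(
    client.get_entailment_score(
        "The first contact is a science fiction novel about aliens.",
        "The first contact is a novel about aliens.",
    )
)
assert score > 0.5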
2 changes: 1 addition & 1 deletion tests/test_indexing.py
@@ -5,7 +5,7 @@
from unittest import TestCase

from wafl.config import Configuration
from wafl.dataclasses.dataclasses import Query
from wafl.data_objects.dataclasses import Query
from wafl.knowledge.indexing_implementation import add_to_index, load_knowledge

_path = os.path.dirname(__file__)
4 changes: 2 additions & 2 deletions tests/test_voice.py
@@ -15,15 +15,15 @@
rules:
- the user's name is Jane:
- write "I hear you"
- reply with "I hear you" and nothing else
""".strip()

_path = os.path.dirname(__file__)


class TestVoice(TestCase):
def test__activation(self):
interface = DummyInterface(to_utter=["computer", "my name is Jane"])
interface = DummyInterface(to_utter=["computer my name is Jane"])
config = Configuration.load_local_config()
config.set_value("rules", _wafl_example)
conversation_events = ConversationEvents(config=config, interface=interface)
20 changes: 19 additions & 1 deletion todo.txt
@@ -1,5 +1,23 @@
* why do I need to re-initialise the retrievers after unpickling the knowledge?
* apply entailer to rule retrieval:
if more than one rule is retrieved, then the one
that is entailed by the query should be chosen

* the answer from the indexed files should be directed by a rule.
- facts and rules should live at the highest level of the retrieval


/* Add tqdm to indexing.
/* Make it index when wafl starts, not at the first use/login

/* The prior items with timestamps might not be necessary.
/ - Just implement a queue with a fixed size (see the deque sketch after this list)

* add entailer to wafl_llm


/* why do I need to re-initialise the retrievers after unpickling the knowledge?
- maybe you should save the retrievers in the knowledge object separately?
- It was gensim that was not serializable. Took it out

/* knowledge cache does not cache the rules or facts
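
The fixed-size queue noted in the item above can be sketched with a bounded deque. Illustrative only: the diff to dialogue_answerer.py below implements the same idea by keeping a plain list and slicing it.

from collections import deque

# A fixed-size store for prior rules: newest first, oldest dropped.
# The commit instead slices a list: self._prior_rules[: self._max_num_rules].
prior_rules = deque(maxlen=2)
for rule in ["rule one", "rule two", "rule three"]:
    prior_rules.appendleft(rule)

print(list(prior_rules))  # ['rule three', 'rule two']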

39 changes: 28 additions & 11 deletions wafl/answerer/answerer_implementation.py
@@ -3,8 +3,9 @@

from typing import List, Tuple

from wafl.answerer.entailer import Entailer
from wafl.exceptions import CloseConversation
from wafl.dataclasses.facts import Fact
from wafl.data_objects.facts import Fact, Sources
from wafl.interface.conversation import Conversation, Utterance


@@ -113,22 +114,32 @@ async def _run_code(to_execute: str, module, functions) -> str:
return result


def get_text_from_facts_and_thresholds(
def create_memory_from_fact_list(facts: List[Fact], max_num_facts: int) -> str:
text_fact_list = [
"\n\n- " + "<item> " + fact.text + " </item>"
for fact in facts
if fact.source == Sources.FROM_TEXT
][:max_num_facts]
rule_fact_list = [
"\n\n- " + "<item> " + fact.text + " </item>"
for fact in facts
if fact.source in [None, Sources.FROM_RULES]
]
return "".join(text_fact_list + rule_fact_list)


def get_facts_with_metadata_from_facts_and_thresholds(
facts_and_thresholds: List[Tuple[Fact, float]], memory: str
) -> List[str]:
text_list = []
fact_list = []
for item in facts_and_thresholds:
if item[0].text not in memory:
text = item[0].text
new_fact = item[0].copy()
if item[0].metadata:
text = (
f"Metadata for the following text: {str(item[0].metadata)}"
+ "\n"
+ text
)
text_list.append(text)
new_fact.text = new_fact.text
fact_list.append(new_fact)

return text_list
return fact_list


def add_dummy_utterances_to_continue_generation(
Expand All @@ -150,3 +161,9 @@ def add_dummy_utterances_to_continue_generation(

def add_memories_to_facts(facts: str, memories: List[str]) -> str:
return facts + "\n" + "\n".join(memories)


async def select_best_rules_using_entailer(
    conversation: Conversation, rules_as_strings: List[str], entailer: Entailer, num_rules: int
) -> List[str]:
    # Rank rules by how strongly the last user utterance entails them.
    query_text = conversation.get_last_speaker_utterance("user")
    scores = [await entailer.get_score(query_text, rule) for rule in rules_as_strings]
    ranked = [rule for _, rule in sorted(zip(scores, rules_as_strings), reverse=True)]
    return ranked[:num_rules]
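
For illustration, a sketch of the memory string create_memory_from_fact_list builds. The Fact construction here is hypothetical (only the text and source fields are visible in this diff), but the <item> wrapping and the cap on text-sourced facts follow the function above:

from wafl.data_objects.facts import Fact, Sources

# Hypothetical constructor call; field names taken from the diff above.
facts = [
    Fact(text="the sky is blue", source=Sources.FROM_TEXT),
    Fact(text="the user's name is Jane", source=Sources.FROM_RULES),
]

# Text-sourced facts are capped at max_num_facts; rule-sourced facts
# (source None or FROM_RULES) are always appended.
memory = create_memory_from_fact_list(facts, max_num_facts=5)
# memory == "\n\n- <item> the sky is blue </item>\n\n- <item> the user's name is Jane </item>"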
58 changes: 24 additions & 34 deletions wafl/answerer/dialogue_answerer.py
@@ -1,18 +1,21 @@
from importlib import import_module
from inspect import getmembers, isfunction
from typing import List, Tuple
from typing import List

from wafl.answerer.entailer import Entailer
from wafl.answerer.answerer_implementation import (
substitute_memory_in_answer_and_get_memories_if_present,
create_one_liner,
get_text_from_facts_and_thresholds,
get_facts_with_metadata_from_facts_and_thresholds,
add_dummy_utterances_to_continue_generation,
add_memories_to_facts,
execute_results_in_answer,
create_memory_from_fact_list,
select_best_rules_using_entailer,
)
from wafl.answerer.base_answerer import BaseAnswerer
from wafl.answerer.rule_maker import RuleMaker
from wafl.connectors.clients.llm_chat_client import LLMChatClient
from wafl.dataclasses.dataclasses import Query, Answer
from wafl.data_objects.dataclasses import Query, Answer
from wafl.interface.conversation import Conversation
from wafl.simple_text_processing.questions import is_question

@@ -21,13 +24,14 @@ class DialogueAnswerer(BaseAnswerer):
def __init__(self, config, knowledge, interface, code_path, logger):
self._threshold_for_facts = 0.85
self._client = LLMChatClient(config)
self._entailer = Entailer(config)
self._knowledge = knowledge
self._logger = logger
self._interface = interface
self._max_num_past_utterances = 5
self._max_num_past_utterances_for_facts = 5
self._max_num_past_utterances_for_rules = 2
self._prior_facts_with_timestamp = []
self._max_num_facts = 5
self._max_num_rules = 2
self._prior_facts = []
self._init_python_module(code_path.replace(".py", ""))
self._prior_rules = []
self._max_predictions = 3
@@ -48,17 +52,15 @@ async def answer(self, query_text: str) -> Answer:
rules_text = await self._get_relevant_rules(conversation)
if not conversation:
conversation = create_one_liner(query_text)
conversational_timestamp = len(conversation)
facts = await self._get_relevant_facts(
memory = await self._get_relevant_facts(
query,
has_prior_rules=bool(rules_text),
conversational_timestamp=conversational_timestamp,
)

final_answer_text = ""
for _ in range(self._max_predictions):
original_answer_text = await self._client.get_answer(
text=facts,
text=memory,
rules_text=rules_text,
dialogue=conversation,
)
@@ -82,22 +84,19 @@

return Answer.create_from_text(final_answer_text)

async def _get_relevant_facts(
self, query: Query, has_prior_rules: bool, conversational_timestamp: int
) -> str:
memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp(
conversational_timestamp
)
async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
memory = create_memory_from_fact_list(self._prior_facts, self._max_num_facts)
facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
query, is_from_user=True, threshold=self._threshold_for_facts
)
if facts_and_thresholds:
facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory)
self._prior_facts_with_timestamp.extend(
(item, conversational_timestamp) for item in facts
facts = get_facts_with_metadata_from_facts_and_thresholds(
facts_and_thresholds, memory
)
self._prior_facts.extend(facts)
memory = create_memory_from_fact_list(
self._prior_facts, self._max_num_facts
)
memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
await self._interface.add_fact(f"The bot remembers the facts:\n{memory}")

else:
@@ -110,11 +109,12 @@ async def _get_relevant_facts(
return memory

async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
rules = await self._rule_creator.create_from_query(conversation)
for rule in rules:
rules_as_strings = await self._rule_creator.create_from_query(conversation)
rules_as_strings = await select_best_rules_using_entailer(
conversation, rules_as_strings, self._entailer, num_rules=1
)
for rule in rules_as_strings:
if rule not in self._prior_rules:
self._prior_rules.insert(0, rule)
self._prior_rules = self._prior_rules[: self._max_num_past_utterances_for_rules]
self._prior_rules = self._prior_rules[: self._max_num_rules]
return self._prior_rules

def _init_python_module(self, module_name):
@@ -129,13 +129,3 @@ async def _apply_substitutions(self, original_answer_text):
self._functions,
)
)

def _get_prior_facts_with_timestamp(
self, conversational_timestamp: int
) -> List[Tuple[str, int]]:
return [
item
for item in self._prior_facts_with_timestamp
if item[1]
> conversational_timestamp - self._max_num_past_utterances_for_facts
]
41 changes: 7 additions & 34 deletions wafl/answerer/entailer.py
@@ -1,41 +1,14 @@
import os
import textwrap

from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
from wafl.connectors.prompt_template import PromptTemplate
from wafl.interface.conversation import Utterance, Conversation

_path = os.path.dirname(__file__)
from wafl.connectors.clients.entailer_client import EntailerClient


class Entailer:
def __init__(self, config):
self._connector = LLMConnectorFactory.get_connector(config)
self.entailer_client = EntailerClient(config)
self._config = config

async def left_entails_right(self, lhs: str, rhs: str, dialogue) -> str:
prompt = await self._get_answer_prompt(lhs, rhs, dialogue)
result = await self._connector.generate(prompt)
result = self._clean_result(result)
return result == "yes"

async def _get_answer_prompt(self, lhs, rhs, dialogue):
return PromptTemplate(
system_prompt="",
conversation=self._get_dialogue_prompt(lhs, rhs, dialogue),
)

def _clean_result(self, result):
result = result.replace("</task>", "")
result = result.split("\n")[0]
result = result.strip()
return result.lower()
async def left_entails_right(self, lhs: str, rhs: str) -> bool:
prediction = await self.entailer_client.get_entailment_score(lhs, rhs)
return prediction > 0.5

def _get_dialogue_prompt(self, dialogue, lhs, rhs):
text = f"""
Your task is to determine whether two sentences are similar.
1) {lhs.lower()}
2) {rhs.lower()}
Please answer "yes" if the two sentences are similar or "no" if not:
""".strip()
return Conversation([Utterance(speaker="user", text=text)])
async def get_score(self, lhs: str, rhs: str) -> float:
return await self.entailer_client.get_entailment_score(lhs, rhs)
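
With the prompt-based path removed, Entailer is now a thin wrapper over EntailerClient. A minimal usage sketch, assuming the entailer model configured above is reachable:

import asyncio

from wafl.answerer.entailer import Entailer
from wafl.config import Configuration

config = Configuration.load_local_config()
entailer = Entailer(config)

# Boolean check, thresholded at 0.5 inside left_entails_right.
entails = asyncio.run(
    entailer.left_entails_right(
        "The user asks about the weather in Rome.",
        "The user asks about the weather.",
    )
)

# Raw score, as used by select_best_rules_using_entailer to rank rules.
score = asyncio.run(
    entailer.get_score("what is the weather like?", "the user asks about the weather")
)
print(entails, score)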
4 changes: 2 additions & 2 deletions wafl/answerer/rule_maker.py
@@ -1,7 +1,7 @@
from typing import List

from wafl.dataclasses.dataclasses import Query
from wafl.dataclasses.rules import Rule
from wafl.data_objects.dataclasses import Query
from wafl.data_objects.rules import Rule


class RuleMaker:
7 changes: 7 additions & 0 deletions wafl/changelog.txt
@@ -0,0 +1,7 @@
- version 0.1.3
* added multi-threaded support for multiple files indexing
* TODO: ADD support for multiple knowledge bases.
It needs to index the rules and the files separately!
* the interface should show where the facts come from in the web interface
* add support for wafl studio where you can concatenate actions (and create corresponding yaml files)
* use <> tags for concatenation
6 changes: 6 additions & 0 deletions wafl/command_line.py
@@ -9,6 +9,7 @@
run_testcases,
print_incipit,
download_models,
load_indices,
)
from wafl.runners.run_from_actions import run_action

@@ -52,26 +53,31 @@ def process_cli():
elif command == "run":
from wafl.runners.run_web_and_audio_interface import run_app

load_indices()
run_app()
remove_preprocessed("/")

elif command == "run-cli":
load_indices()
run_from_command_line()
remove_preprocessed("/")

elif command == "run-audio":
from wafl.runners.run_from_audio import run_from_audio

load_indices()
run_from_audio()
remove_preprocessed("/")

elif command == "run-server":
from wafl.runners.run_web_interface import run_server_only_app

load_indices()
run_server_only_app()
remove_preprocessed("/")

elif command == "run-tests":
load_indices()
run_testcases()
remove_preprocessed("/")
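
load_indices itself is not shown in this diff; per todo.txt above ("Make it index when wafl starts, not at the first use"), the point is to pay the indexing cost once at startup rather than on the first query. A hypothetical sketch of that idea, not the actual implementation:

def load_indices() -> None:
    # Hypothetical sketch: eagerly load (or rebuild) the knowledge index once
    # at startup, so the first user query does not trigger a slow indexing pass.
    from wafl.config import Configuration
    from wafl.knowledge.indexing_implementation import load_knowledge

    config = Configuration.load_local_config()
    load_knowledge(config)  # assumed call shape; see tests/test_indexing.py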
