Skip to content

Commit

Permalink
using entailer to pre-filter the retrieved rules
Browse files Browse the repository at this point in the history
  • Loading branch information
fractalego committed Aug 4, 2024
1 parent 902ad0f commit 94c160e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 6 deletions.
12 changes: 10 additions & 2 deletions todo.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
* Add tqdm to indexing.
* Make it index when wafl start first, not at the first use/login
* apply entailer to rule retrieval:
if more than one rule is retrieved, then the one
that is entailed by the query should be chosen

* the answer from the indexed files should be directed from a rule.
- facts and rules should live at the highest level of the retrieval


/* Add tqdm to indexing.
/* Make it index when wafl start first, not at the first use/login

/* The prior items with timestamps might not be necessary.
/ - Just implement a queue with a fixed size
Expand Down
7 changes: 7 additions & 0 deletions wafl/answerer/answerer_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import List, Tuple

from wafl.answerer.entailer import Entailer
from wafl.exceptions import CloseConversation
from wafl.data_objects.facts import Fact, Sources
from wafl.interface.conversation import Conversation, Utterance
Expand Down Expand Up @@ -160,3 +161,9 @@ def add_dummy_utterances_to_continue_generation(

def add_memories_to_facts(facts: str, memories: List[str]) -> str:
return facts + "\n" + "\n".join(memories)


def select_best_rules_using_entailer(conversation: Conversation, rules_as_strings: List[str], entailer: Entailer, num_rules: int) -> str:
query_text = conversation.get_last_speaker_utterance("user")
rules_as_strings = sorted(rules_as_strings, key=lambda x: entailer.get_score(query_text, x), reverse=True)
return rules_as_strings[:num_rules]
10 changes: 6 additions & 4 deletions wafl/answerer/dialogue_answerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
from inspect import getmembers, isfunction
from typing import List

from wafl.answerer.entailer import Entailer
from wafl.answerer.answerer_implementation import (
substitute_memory_in_answer_and_get_memories_if_present,
create_one_liner,
get_facts_with_metadata_from_facts_and_thresholds,
add_dummy_utterances_to_continue_generation,
add_memories_to_facts,
execute_results_in_answer,
create_memory_from_fact_list,
create_memory_from_fact_list, select_best_rules_using_entailer,
)
from wafl.answerer.base_answerer import BaseAnswerer
from wafl.answerer.rule_maker import RuleMaker
from wafl.connectors.clients.llm_chat_client import LLMChatClient
from wafl.data_objects.dataclasses import Query, Answer
from wafl.data_objects.facts import Sources
from wafl.interface.conversation import Conversation
from wafl.simple_text_processing.questions import is_question

Expand All @@ -24,6 +24,7 @@ class DialogueAnswerer(BaseAnswerer):
def __init__(self, config, knowledge, interface, code_path, logger):
self._threshold_for_facts = 0.85
self._client = LLMChatClient(config)
self._entailer = Entailer(config)
self._knowledge = knowledge
self._logger = logger
self._interface = interface
Expand Down Expand Up @@ -108,8 +109,9 @@ async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
return memory

async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
rules = await self._rule_creator.create_from_query(conversation)
for rule in rules:
rules_as_strings = await self._rule_creator.create_from_query(conversation)
rules_as_strings = select_best_rules_using_entailer(conversation, rules_as_strings, self._entailer, num_rules=1)
for rule in rules_as_strings:
if rule not in self._prior_rules:
self._prior_rules.insert(0, rule)
self._prior_rules = self._prior_rules[: self._max_num_rules]
Expand Down
3 changes: 3 additions & 0 deletions wafl/answerer/entailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ def __init__(self, config):
async def left_entails_right(self, lhs: str, rhs: str) -> bool:
prediction = await self.entailer_client.get_entailment_score(lhs, rhs)
return prediction > 0.5

async def get_score(self, lhs: str, rhs: str) -> float:
return await self.entailer_client.get_entailment_score(lhs, rhs)
9 changes: 9 additions & 0 deletions wafl/interface/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ def get_last_speaker_utterances(self, speaker: str, n: int) -> List[str]:
if utterance.speaker == speaker
][-n:]

def get_last_speaker_utterance(self, speaker: str) -> str:
if not self.utterances:
return ""

for utterance in reversed(self.utterances):
if utterance.speaker == speaker:
return utterance.text
return ""

def get_first_timestamp(self) -> float:
return self.utterances[0].timestamp if self.utterances else None

Expand Down

0 comments on commit 94c160e

Please sign in to comment.