Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid interfaces #91

Merged
merged 2 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.dummy_interface import DummyInterface
from wafl.parsing.rules_parser import get_facts_and_rules_from_text

wafl_example = """
rules:
Expand All @@ -15,14 +16,12 @@
- reply casually to the conversation"

- the user wants to buy coffee:
- the bot says the coffee prices
- the bot says the coffee prices:
- decaf is 1.50
- regular is 1.00
- espresso is 2.00
- ask for which price range
- tell them the right coffee

- the bot says the coffee prices:
- decaf is 1.50
- regular is 1.00
- espresso is 2.00
"""


Expand Down Expand Up @@ -75,3 +74,23 @@ def test__rules_can_nest(self):
self.assertIn("decaf", interface.get_facts_and_timestamp()[0][1])
self.assertIn("regular", interface.get_facts_and_timestamp()[0][1])
self.assertIn("espresso", interface.get_facts_and_timestamp()[0][1])

def test__nested_rules_are_printed_correctly(self):
rule_text = """
rules:
- the user wants to know the time:
- output "The time is <execute>get_time()</execute>":
- if the time is before 12:00 say "Good morning"
- if the time is after 12:00 say "Good afternoon"
""".strip()

facts_and_rules = get_facts_and_rules_from_text(rule_text)
rule = facts_and_rules["rules"][0]
expected = """
the user wants to know the time
-output "The time is <execute>get_time()</execute>"
-if the time is before 12:00 say "Good morning"
-if the time is after 12:00 say "Good afternoon"
""".strip()
self.assertEqual(expected, str(rule).strip())

32 changes: 4 additions & 28 deletions wafl/answerer/rule_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def __init__(
interface,
max_num_rules,
delete_current_rule,
max_recursion=1,
max_recursion=3,
):
self._knowledge = knowledge
self._config = config
Expand All @@ -18,39 +18,15 @@ def __init__(
else:
self._max_indentation = config.get_value("max_recursion")

self._indent_str = " "

async def create_from_query(self, query):
rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.92)
rules = rules[: self._max_num_rules]
rules_texts = []
for rule in rules:
rules_text = f"- If {rule.effect.text} go through the following points:\n"
for cause_index, cause in enumerate(rule.causes):
rules_text += f"{self._indent_str}{cause_index + 1}) {cause.text}\n"
rules_text += await self.recursively_add_rules(cause)

rules_text += f'{self._indent_str}{len(rule.causes) + 1}) After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n'

rules_text = rule.get_string_using_template(
"- If {effect} go through the following points:"
)
rules_texts.append(rules_text)
await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")

return "\n".join(rules_texts)

async def recursively_add_rules(self, query, depth=2):
if depth > self._max_indentation:
return ""

rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.95)
rules = rules[:1]
rules_texts = []
for rule in rules:
rules_text = ""
for cause_index, causes in enumerate(rule.causes):
indentation = self._indent_str * depth
rules_text += f"{indentation}{cause_index + 1}) {causes.text}\n"
rules_text += await self.recursively_add_rules(causes, depth + 1)

rules_texts.append(rules_text)

return "\n".join(rules_texts)
3 changes: 2 additions & 1 deletion wafl/facts.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Union


@dataclass
class Fact:
text: str
text: Union[str, dict]
is_question: bool = False
variable: str = None
is_interruption: bool = False
Expand Down
1 change: 0 additions & 1 deletion wafl/frontend/selector.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
<html lang="">
<head>
<title> WAFL frontend </title>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
<script src="https://unpkg.com/[email protected]"></script>
<script src="https://cdn.tailwindcss.com"></script>
<script type="text/javascript" src="/wafl.js"></script>
Expand Down
33 changes: 31 additions & 2 deletions wafl/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,24 @@
class Rule:
effect: "Fact"
causes: List["Fact"]
_max_indentation = 3
_indent_str = " "

def toJSON(self):
return str(self)

def get_string_using_template(self, effect_template: str) -> str:
rule_str = effect_template.replace("{effect}", self.effect.text) + "\n"
return self._add_clauses(rule_str)

def __str__(self):
rule_str = self.effect.text
rule_str = self.effect.text + "\n"
return self._add_clauses(rule_str)

def _add_clauses(self, rule_str: str) -> str:
for cause in self.causes:
try:
rule_str += "\n " + cause.text
rule_str += self._recursively_add_clauses(cause)

except TypeError as e:
print(f"Error in rule:'''\n{rule_str}'''")
Expand All @@ -23,3 +32,23 @@ def __str__(self):
raise e

return rule_str


def _recursively_add_clauses(self, query: str, depth: int=1) -> str:
indentation = self._indent_str * depth
if type(query) == str:
return f"{indentation}-{query}\n"

if type(query.text) == str:
return f"{indentation}-{query.text}\n"

if depth > self._max_indentation:
return ""

clause = list(query.text.keys())[0]
rules_text = f"{indentation}-{clause}\n"
for clauses in query.text.values():
for cause_index, clause in enumerate(clauses):
rules_text += self._recursively_add_clauses(clause, depth + 1)

return rules_text
Loading