diff --git a/tests/test_rules.py b/tests/test_rules.py
index 990f5492..496262b3 100644
--- a/tests/test_rules.py
+++ b/tests/test_rules.py
@@ -5,6 +5,7 @@
from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.dummy_interface import DummyInterface
+from wafl.parsing.rules_parser import get_facts_and_rules_from_text
wafl_example = """
rules:
@@ -15,14 +16,12 @@
- reply casually to the conversation"
- the user wants to buy coffee:
- - the bot says the coffee prices
+ - the bot says the coffee prices:
+ - decaf is 1.50
+ - regular is 1.00
+ - espresso is 2.00
- ask for which price range
- tell them the right coffee
-
- - the bot says the coffee prices:
- - decaf is 1.50
- - regular is 1.00
- - espresso is 2.00
"""
@@ -75,3 +74,23 @@ def test__rules_can_nest(self):
self.assertIn("decaf", interface.get_facts_and_timestamp()[0][1])
self.assertIn("regular", interface.get_facts_and_timestamp()[0][1])
self.assertIn("espresso", interface.get_facts_and_timestamp()[0][1])
+
+ def test__nested_rules_are_printed_correctly(self):
+ rule_text = """
+rules:
+ - the user wants to know the time:
+ - output "The time is get_time()":
+ - if the time is before 12:00 say "Good morning"
+ - if the time is after 12:00 say "Good afternoon"
+ """.strip()
+
+ facts_and_rules = get_facts_and_rules_from_text(rule_text)
+ rule = facts_and_rules["rules"][0]
+ expected = """
+the user wants to know the time
+ -output "The time is get_time()"
+ -if the time is before 12:00 say "Good morning"
+ -if the time is after 12:00 say "Good afternoon"
+ """.strip()
+ self.assertEqual(expected, str(rule).strip())
+
diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py
index 7819fab5..fb8dd604 100644
--- a/wafl/answerer/rule_maker.py
+++ b/wafl/answerer/rule_maker.py
@@ -6,7 +6,7 @@ def __init__(
interface,
max_num_rules,
delete_current_rule,
- max_recursion=1,
+ max_recursion=3,
):
self._knowledge = knowledge
self._config = config
@@ -18,39 +18,15 @@ def __init__(
else:
self._max_indentation = config.get_value("max_recursion")
- self._indent_str = " "
-
async def create_from_query(self, query):
rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.92)
rules = rules[: self._max_num_rules]
rules_texts = []
for rule in rules:
- rules_text = f"- If {rule.effect.text} go through the following points:\n"
- for cause_index, cause in enumerate(rule.causes):
- rules_text += f"{self._indent_str}{cause_index + 1}) {cause.text}\n"
- rules_text += await self.recursively_add_rules(cause)
-
- rules_text += f'{self._indent_str}{len(rule.causes) + 1}) After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n'
-
+ rules_text = rule.get_string_using_template(
+ "- If {effect} go through the following points:"
+ )
rules_texts.append(rules_text)
await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")
return "\n".join(rules_texts)
-
- async def recursively_add_rules(self, query, depth=2):
- if depth > self._max_indentation:
- return ""
-
- rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.95)
- rules = rules[:1]
- rules_texts = []
- for rule in rules:
- rules_text = ""
- for cause_index, causes in enumerate(rule.causes):
- indentation = self._indent_str * depth
- rules_text += f"{indentation}{cause_index + 1}) {causes.text}\n"
- rules_text += await self.recursively_add_rules(causes, depth + 1)
-
- rules_texts.append(rules_text)
-
- return "\n".join(rules_texts)
diff --git a/wafl/facts.py b/wafl/facts.py
index 23db2067..c2a3eec2 100644
--- a/wafl/facts.py
+++ b/wafl/facts.py
@@ -1,9 +1,10 @@
from dataclasses import dataclass
+from typing import Union
@dataclass
class Fact:
- text: str
+ text: Union[str, dict]
is_question: bool = False
variable: str = None
is_interruption: bool = False
diff --git a/wafl/frontend/selector.html b/wafl/frontend/selector.html
index 15ad1ce6..0752552e 100644
--- a/wafl/frontend/selector.html
+++ b/wafl/frontend/selector.html
@@ -2,7 +2,6 @@
WAFL frontend
-
diff --git a/wafl/rules.py b/wafl/rules.py
index 3e3d32f6..c54f6efa 100644
--- a/wafl/rules.py
+++ b/wafl/rules.py
@@ -6,15 +6,24 @@
class Rule:
effect: "Fact"
causes: List["Fact"]
+ _max_indentation = 3
+ _indent_str = " "
def toJSON(self):
return str(self)
+ def get_string_using_template(self, effect_template: str) -> str:
+ rule_str = effect_template.replace("{effect}", self.effect.text) + "\n"
+ return self._add_clauses(rule_str)
+
def __str__(self):
- rule_str = self.effect.text
+ rule_str = self.effect.text + "\n"
+ return self._add_clauses(rule_str)
+
+ def _add_clauses(self, rule_str: str) -> str:
for cause in self.causes:
try:
- rule_str += "\n " + cause.text
+ rule_str += self._recursively_add_clauses(cause)
except TypeError as e:
print(f"Error in rule:'''\n{rule_str}'''")
@@ -23,3 +32,23 @@ def __str__(self):
raise e
return rule_str
+
+
+ def _recursively_add_clauses(self, query: str, depth: int=1) -> str:
+ indentation = self._indent_str * depth
+ if type(query) == str:
+ return f"{indentation}-{query}\n"
+
+ if type(query.text) == str:
+ return f"{indentation}-{query.text}\n"
+
+ if depth > self._max_indentation:
+ return ""
+
+ clause = list(query.text.keys())[0]
+ rules_text = f"{indentation}-{clause}\n"
+ for clauses in query.text.values():
+ for cause_index, clause in enumerate(clauses):
+ rules_text += self._recursively_add_clauses(clause, depth + 1)
+
+ return rules_text
\ No newline at end of file