From d0f0d5e71258868867a59e757600d8a934b40fae Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 20 Nov 2024 11:10:16 -0600 Subject: [PATCH] Make trie more restrictive --- CHANGELOG.md | 4 ++++ hassil/VERSION | 2 +- hassil/trie.py | 10 ++++------ tests/test_trie.py | 13 +++++++++++++ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb06ff9..9b7f137 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 2.0.3 + +- Make trie more restrictive (`two` will not match `t|wo`) + ## 2.0.2 - Require `unicode-rbnf>=2.1` which includes important bugfixes diff --git a/hassil/VERSION b/hassil/VERSION index e9307ca..50ffc5a 100644 --- a/hassil/VERSION +++ b/hassil/VERSION @@ -1 +1 @@ -2.0.2 +2.0.3 diff --git a/hassil/trie.py b/hassil/trie.py index 49b4877..ea35b8d 100644 --- a/hassil/trie.py +++ b/hassil/trie.py @@ -47,9 +47,9 @@ def insert(self, text: str, value: Any) -> None: current_children = current_node.children - def find(self, text: str) -> Iterable[Tuple[int, str, Any]]: + def find(self, text: str, unique: bool = True) -> Iterable[Tuple[int, str, Any]]: """Yield (end_pos, text, value) pairs of all words found in the string.""" - q = deque([(self.roots, 0)]) + q = deque([(self.roots, i) for i in range(len(text))]) visited = set() while q: @@ -60,15 +60,13 @@ def find(self, text: str) -> Iterable[Tuple[int, str, Any]]: current_char = text[current_position] - if current_position < len(text): - q.append((current_children, current_position + 1)) - node = current_children.get(current_char) if (node is not None) and (node.id not in visited): - visited.add(node.id) if node.text is not None: # End is one past the current position + if unique: + visited.add(node.id) yield (current_position + 1, node.text, node.value) if node.children and (current_position < len(text)): diff --git a/tests/test_trie.py b/tests/test_trie.py index 81afa7d..6f039f4 100644 --- a/tests/test_trie.py +++ b/tests/test_trie.py @@ -22,6 +22,19 @@ def test_insert_find() -> None: (45, "twenty two", 22), ] + # Without unique, *[two]* and twenty [two] will return 2 + assert list( + trie.find("set to 1, then *two*, then finally twenty two please!", unique=False) + ) == [ + (8, "1", 1), + (19, "two", 2), + (45, "two", 2), + (45, "twenty two", 22), + ] + + # Test a character in between + assert not list(trie.find("tw|o")) + # Test non-existent value assert not list(trie.find("three"))