Skip to content

Commit 8ce773b

Browse files
committed
fix(tag_search): do ordering before applying limit
1 parent 1d7a267 commit 8ce773b

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

src/tagstudio/core/library/alchemy/library.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,41 +1067,61 @@ def search_library(
10671067

10681068
return res
10691069

1070-
def search_tags(self, name: str | None, limit: int = 100) -> list[set[Tag]]:
1070+
def search_tags(self, name: str | None, limit: int = 100) -> tuple[list[Tag], list[Tag]]:
10711071
"""Return a list of Tag records matching the query."""
1072+
name = name or ""
1073+
name = name.lower()
1074+
1075+
def sort_key(text: str):
1076+
return (not text.startswith(name), text)
1077+
10721078
with Session(self.engine) as session:
1073-
query = select(Tag.id, Tag.name).outerjoin(TagAlias)
1079+
query = select(Tag.id, Tag.name)
10741080

1075-
if limit > 0 and (not name or len(name) == 1):
1081+
if limit > 0 and not name:
10761082
query = query.limit(limit).order_by(func.lower(Tag.name))
10771083

10781084
if name:
10791085
query = query.where(
10801086
or_(
10811087
Tag.name.icontains(name),
10821088
Tag.shorthand.icontains(name),
1083-
TagAlias.name.icontains(name),
10841089
)
10851090
)
10861091

10871092
tags = list(session.execute(query))
1093+
1094+
if name:
1095+
query = select(TagAlias.tag_id, TagAlias.name).where(TagAlias.name.icontains(name))
1096+
tags.extend(session.execute(query))
1097+
1098+
tags.sort(key=lambda t: sort_key(t[1]))
1099+
seen_ids = set()
1100+
tag_ids = []
1101+
for row in tags:
1102+
id = row[0]
1103+
if id in seen_ids:
1104+
continue
1105+
tag_ids.append(id)
1106+
seen_ids.add(id)
1107+
10881108
logger.info(
10891109
"searching tags",
10901110
search=name,
10911111
limit=limit,
10921112
statement=str(query),
1093-
results=len(tags),
1113+
results=len(tag_ids),
10941114
)
10951115

1096-
tags.sort(key=lambda t: t[1].lower())
10971116
if limit <= 0:
1098-
limit = len(tags)
1099-
tag_ids = [t[0] for t in tags[:limit]]
1117+
limit = len(tag_ids)
1118+
tag_ids = tag_ids[:limit]
11001119

11011120
hierarchy = self.get_tag_hierarchy(tag_ids)
1102-
direct_tags = {hierarchy.pop(id) for id in tag_ids}
1103-
ancestor_tags = set(hierarchy.values())
1104-
return [direct_tags, ancestor_tags]
1121+
direct_tags = [hierarchy.pop(id) for id in tag_ids]
1122+
ancestor_tags = list(hierarchy.values())
1123+
ancestor_tags.sort(key=lambda t: sort_key(t.name))
1124+
return direct_tags, ancestor_tags
11051125

11061126
def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool:
11071127
"""Set the path field of an entry.

src/tagstudio/qt/mixed/tag_search.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -218,32 +218,15 @@ def update_tags(self, query: str | None = None):
218218
self.scroll_layout.takeAt(self.scroll_layout.count() - 1).widget().deleteLater()
219219
self.create_button_in_layout = False
220220

221-
# Get results for the search query
222-
query_lower = "" if not query else query.lower()
223221
# Only use the tag limit if it's an actual number (aka not "All Tags")
224222
tag_limit = TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1
225-
tag_results: list[set[Tag]] = self.lib.search_tags(name=query, limit=tag_limit)
226-
if self.exclude:
227-
tag_results[0] = {t for t in tag_results[0] if t.id not in self.exclude}
228-
tag_results[1] = {t for t in tag_results[1] if t.id not in self.exclude}
229-
230-
# Sort and prioritize the results
231-
results_0 = list(tag_results[0])
232-
results_0.sort(key=lambda tag: tag.name.lower())
233-
results_1 = list(tag_results[1])
234-
results_1.sort(key=lambda tag: tag.name.lower())
235-
raw_results = list(results_0 + results_1)
236-
priority_results: set[Tag] = set()
237-
all_results: list[Tag] = []
223+
direct_tags, ancestor_tags = self.lib.search_tags(name=query, limit=tag_limit)
238224

239-
if query and query.strip():
240-
for tag in raw_results:
241-
if tag.name.lower().startswith(query_lower):
242-
priority_results.add(tag)
225+
all_results = [t for t in direct_tags if t.id not in self.exclude]
226+
for tag in ancestor_tags:
227+
if tag.id not in self.exclude:
228+
all_results.append(tag)
243229

244-
all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
245-
r for r in raw_results if r not in priority_results
246-
]
247230
if tag_limit > 0:
248231
all_results = all_results[:tag_limit]
249232

0 commit comments

Comments
 (0)