Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor channel search functionality in tdm_loader #28

Merged
merged 10 commits into from
Mar 28, 2024
45 changes: 21 additions & 24 deletions tdm_loader/tdm_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import os
import zipfile
import re
from functools import cache

from xml.etree import ElementTree
import warnings
Expand Down Expand Up @@ -168,6 +169,11 @@ def _get_usi_from_txt(txt):
return []
return re.findall(r'id\("(.+?)"\)', txt)

@cache
def _get_channels(self, group_id):
group = self._xml_chgs[group_id]
return {v: i for i, v in enumerate(re.findall(r'id\("(.+?)"\)', group.find("channels").text))}

def channel_group_search(self, search_term):
"""Returns a list of channel group names that contain ``search term``.
Results are independent of case and spaces in the channel name.
Expand Down Expand Up @@ -228,31 +234,22 @@ def channel_search(self, search_term):
"""
search_term = str(search_term).upper().replace(" ", "")

ind_chg_ch = []
for j in range(len(self._xml_chgs)):
chs = self._channels_xml(j)
matched_channels = []
channel_group_ids = {v: i for i, v in enumerate(x.get("id") for x in self._xml_chgs)}

if search_term == "":
found_terms = [
ch.findtext("name") for ch in chs if ch.findtext("name") is None
]
else:
found_terms = [
ch.findtext("name")
for ch in chs
if ch.findtext("name") is not None
and ch.findtext("name")
.upper()
.replace(" ", "")
.find(str(search_term))
>= 0
]

for name in found_terms:
i = [ch.findtext("name") for ch in chs].index(name)
ind_chg_ch.append((name, j, i))

return ind_chg_ch
for channel in self._root.findall(".//tdm_channel"):
channel_name = channel.find("name").text
if channel_name:
group_uri = re.findall(r'id\("(.+?)"\)', channel.find("group").text)
group_id = channel_group_ids.get(group_uri[0])
channels = self._get_channels(group_id)

channel_id = channels.get(channel.get("id"))

if channel_name.upper().replace(" ", "").find(search_term) >= 0:
matched_channels.append((channel_name, group_id, channel_id))

return matched_channels

def channel(self, channel_group, channel, occurrence=0, ch_occurrence=0):
"""Returns a data channel by its channel group and channel index.
Expand Down
4 changes: 3 additions & 1 deletion tdm_loader/tests/test_non_zip_tdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def test_channel_search(tdm_file):
("Float as Float", 0, 1),
]

assert tdm_file.channel_search("") == []
assert tdm_file.channel_search("") == [('Float_4_Integers', 0, 0),
('Float as Float', 0, 1),
('Integer32_with_max_min', 0, 2)]


# pylint: disable=redefined-outer-name
Expand Down
Loading