Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions skbase/lookup/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
Parameters
----------
obj : BaseObject, an sktime estimator
tag_filter : dict of (str or list of str), default=None
tag_filter : str, list[str] or dict of (str or list of str), default=None
subsets the returned estimators as follows:
each key/value pair is a statement in "and"/conjunction

Expand All @@ -190,34 +190,23 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
if tag_filter is None:
return True

type_msg = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like the function already allowed for iterable of str or str

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, tho! The function already had full support for both single strings and iterables of strings.

I did add tests that are redundant but more granular, covering edge cases of the existing functionality, which is already implemented and working.

Should I remove those new tests I added?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like you are still changing the logic itself, making it less general, and removing the useful error message.

"filter_tags argument of all_objects must be "
"a dict with str or re.Pattern keys, "
"str, or iterable of str, "
"but found"
)

if not isinstance(tag_filter, (str, Iterable, dict)):
raise TypeError(f"{type_msg} type {type(tag_filter)}")

if not hasattr(obj, "get_class_tag"):
return False

# case: tag_filter is string
# Handle backward compatibility - convert str/list/tuple to dict
if isinstance(tag_filter, str):
tag_filter = {tag_filter: True}

# case: tag_filter is iterable of str but not dict
# If an iterable of strings is provided, check that all are in the returned tag_dict
if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
if not all(isinstance(t, str) for t in tag_filter):
raise ValueError(f"{type_msg} {tag_filter}")
elif isinstance(tag_filter, (list, tuple)):
# Check if all elements are strings (original error handling)
if not all(isinstance(tag, str) for tag in tag_filter):
raise ValueError("filter_tags")
tag_filter = dict.fromkeys(tag_filter, True)
elif not isinstance(tag_filter, dict):
raise TypeError("filter_tags")

if not hasattr(obj, "get_class_tag"):
return False

# case: tag_filter is dict
# check that all keys are str
if not all(isinstance(t, str) for t in tag_filter.keys()):
raise ValueError(f"{type_msg} {tag_filter}")
raise ValueError("filter_tags")

cond_sat = True

Expand Down
213 changes: 213 additions & 0 deletions skbase/lookup/tests/test_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,3 +1071,216 @@ def test_all_object_class_lookup_invalid_object_types_raises(
object_types=class_filter,
class_lookup=class_lookup,
)


# ==============================================================================
# ADDITIONAL TESTS FOR EDGE CASES AND ERROR HANDLING
# ==============================================================================


def test_all_objects_filter_tags_string_preprocessing():
    """Check all_objects treats a str ``filter_tags`` like the equivalent dict."""
    shared_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
    }

    # a bare tag name is compared against its explicit ``{tag: True}`` spelling
    from_str = all_objects(filter_tags="A", **shared_kwargs)
    from_dict = all_objects(filter_tags={"A": True}, **shared_kwargs)

    same_result = from_str.equals(from_dict)
    assert same_result, "String and dict filter should return same results"


def test_all_objects_filter_tags_list_preprocessing():
    """Check all_objects treats a list ``filter_tags`` like the equivalent dict."""
    shared_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
    }

    # a list of tag names is compared against a dict mapping each name to True
    from_list = all_objects(filter_tags=["A", "B"], **shared_kwargs)
    from_dict = all_objects(filter_tags={"A": True, "B": True}, **shared_kwargs)

    same_result = from_list.equals(from_dict)
    assert same_result, "List and dict filter should return same results"


def test_all_objects_filter_tags_tuple_preprocessing():
    """Check all_objects treats a tuple ``filter_tags`` like the equivalent dict."""
    shared_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
    }

    # a tuple of tag names is compared against a dict mapping each name to True
    from_tuple = all_objects(filter_tags=("A", "B"), **shared_kwargs)
    from_dict = all_objects(filter_tags={"A": True, "B": True}, **shared_kwargs)

    same_result = from_tuple.equals(from_dict)
    assert same_result, "Tuple and dict filter should return same results"


def test_get_package_metadata_filter_tags_string_preprocessing():
    """Check str and dict ``tag_filter`` give the same get_package_metadata keys."""

    def _metadata_for(tag_filter):
        # every call shares the same package and exclusions; only the filter varies
        return get_package_metadata(
            "skbase",
            modules_to_ignore="skbase",
            tag_filter=tag_filter,
            classes_to_exclude=TagAliaserMixin,
        )

    keys_from_str = _metadata_for("A").keys()
    keys_from_dict = _metadata_for({"A": True}).keys()

    # only the returned keys are compared, not the metadata stored under them
    assert keys_from_str == keys_from_dict


def test_get_package_metadata_filter_tags_list_preprocessing():
    """Check list and dict ``tag_filter`` give the same get_package_metadata keys."""

    def _metadata_for(tag_filter):
        # every call shares the same package and exclusions; only the filter varies
        return get_package_metadata(
            "skbase",
            modules_to_ignore="skbase",
            tag_filter=tag_filter,
            classes_to_exclude=TagAliaserMixin,
        )

    keys_from_list = _metadata_for(["A", "B"]).keys()
    keys_from_dict = _metadata_for({"A": True, "B": True}).keys()

    # only the returned keys are compared, not the metadata stored under them
    assert keys_from_list == keys_from_dict


@pytest.mark.parametrize(
    "invalid_filter",
    [
        123,  # int
        12.5,  # float
        object(),  # object
        ["A", 123],  # list with non-string
        ("A", 123),  # tuple with non-string
    ],
)
def test_all_objects_filter_tags_invalid_types_preprocessing(invalid_filter):
    """Test that invalid filter_tags values are rejected by all_objects.

    The tag validation in ``_filter_by_tags`` raises ``TypeError`` for
    unsupported types but ``ValueError`` for a list or tuple containing
    non-str elements, and its messages only mention ``filter_tags``.
    The test therefore accepts either exception type and matches on the
    argument name rather than on one exact message wording.

    Parameters
    ----------
    invalid_filter : object
        A value that is not a str, an iterable of str, or a dict.
    """
    # NOTE(review): assumes all_objects forwards filter_tags to that validation
    # unchanged - confirm against all_objects.
    with pytest.raises((TypeError, ValueError), match="filter_tags"):
        all_objects(
            package_name="skbase",
            filter_tags=invalid_filter,
        )


@pytest.mark.parametrize(
    "invalid_filter",
    [
        123,  # int
        12.5,  # float
        object(),  # object
        ["A", 123],  # list with non-string
        ("A", 123),  # tuple with non-string
    ],
)
def test_get_package_metadata_filter_tags_invalid_types_preprocessing(invalid_filter):
    """Test that invalid tag_filter values are rejected by get_package_metadata.

    The tag validation in ``_filter_by_tags`` raises ``TypeError`` for
    unsupported types but ``ValueError`` for a list or tuple containing
    non-str elements, and its messages name the argument ``filter_tags``
    rather than ``tag_filter``. The test therefore accepts either exception
    type and either argument name in the message.

    Parameters
    ----------
    invalid_filter : object
        A value that is not a str, an iterable of str, or a dict.
    """
    # NOTE(review): assumes get_package_metadata forwards tag_filter to that
    # validation unchanged - confirm against get_package_metadata.
    with pytest.raises((TypeError, ValueError), match=r"filter_tags|tag_filter"):
        get_package_metadata(
            "skbase",
            tag_filter=invalid_filter,
        )


def test_all_objects_filter_tags_empty_list():
    """Check all_objects gives the same result for ``[]`` and ``{}`` filter_tags."""
    shared_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
    }

    # the two empty filters are expected to be interchangeable
    from_empty_list = all_objects(filter_tags=[], **shared_kwargs)
    from_empty_dict = all_objects(filter_tags={}, **shared_kwargs)

    same_result = from_empty_list.equals(from_empty_dict)
    assert same_result, "Empty list and empty dict should return same results"


def test_filter_by_tags_dict_not_modified():
    """Check all_objects leaves a dict passed as ``filter_tags`` unmodified."""
    user_filter = {"A": "1"}
    # snapshot taken before the call, to compare against afterwards
    snapshot = dict(user_filter)

    all_objects(package_name="skbase", filter_tags=user_filter)

    unchanged = user_filter == snapshot
    assert unchanged, "Original filter_tags dict should not be modified"


def test_get_package_metadata_filter_tags_dict_copy_behavior():
    """Check get_package_metadata leaves a dict ``tag_filter`` unmodified."""
    user_filter = {"A": "1"}
    # snapshot taken before the call, to compare against afterwards
    snapshot = dict(user_filter)

    get_package_metadata(
        "skbase",
        tag_filter=user_filter,
        classes_to_exclude=TagAliaserMixin,
    )

    unchanged = user_filter == snapshot
    assert unchanged, "Original tag_filter dict should not be modified"