Validate and normalise tag_filter in _filter_by_tags; add lookup tests.

_filter_by_tags now converts a str or a list/tuple of str into a dict and
rejects any other type *before* the get_class_tag check, so a malformed
filter raises instead of silently returning False. All three raises share
one descriptive message (accepted types plus the offending value) rather
than a bare "filter_tags" string. The invalid-input tests accept both
exception types raised by the helper: TypeError for an unsupported
container type, ValueError for non-str elements or keys.

NOTE(review): whitespace of context lines was reconstructed from a
newline-collapsed copy of this patch; confirm against the target files.
NOTE(review): the "index" lines below still carry the blob ids of the
original patch; regenerate with git diff if exact ids are required.

diff --git a/skbase/lookup/_lookup.py b/skbase/lookup/_lookup.py
index be274d27..b1ca3751 100644
--- a/skbase/lookup/_lookup.py
+++ b/skbase/lookup/_lookup.py
@@ -171,7 +171,7 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
     Parameters
     ----------
     obj : BaseObject, an sktime estimator
-    tag_filter : dict of (str or list of str), default=None
+    tag_filter : str, list[str] or dict of (str or list of str), default=None
         subsets the returned estimators as follows:
             each key/value pair is statement in "and"/conjunction
 
@@ -190,34 +190,28 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
     if tag_filter is None:
         return True
 
     type_msg = (
-        "filter_tags argument of all_objects must be "
-        "a dict with str or re.Pattern keys, "
-        "str, or iterable of str, "
+        "filter_tags must be a str, list of str, or dict with str keys, "
         "but found"
     )
 
-    if not isinstance(tag_filter, (str, Iterable, dict)):
-        raise TypeError(f"{type_msg} type {type(tag_filter)}")
-
-    if not hasattr(obj, "get_class_tag"):
-        return False
-
-    # case: tag_filter is string
+    # Handle backward compatibility - convert str/list/tuple to dict
     if isinstance(tag_filter, str):
         tag_filter = {tag_filter: True}
-
-    # case: tag_filter is iterable of str but not dict
-    # If a iterable of strings is provided, check that all are in the returned tag_dict
-    if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
-        if not all(isinstance(t, str) for t in tag_filter):
+    elif isinstance(tag_filter, (list, tuple)):
+        # Check if all elements are strings (original error handling)
+        if not all(isinstance(tag, str) for tag in tag_filter):
             raise ValueError(f"{type_msg} {tag_filter}")
         tag_filter = dict.fromkeys(tag_filter, True)
+    elif not isinstance(tag_filter, dict):
+        raise TypeError(f"{type_msg} type {type(tag_filter)}")
+
+    if not hasattr(obj, "get_class_tag"):
+        return False
 
-    # case: tag_filter is dict
     # check that all keys are str
     if not all(isinstance(t, str) for t in tag_filter.keys()):
         raise ValueError(f"{type_msg} {tag_filter}")
 
     cond_sat = True
 
diff --git a/skbase/lookup/tests/test_lookup.py b/skbase/lookup/tests/test_lookup.py
index 8bec321b..befc0933 100644
--- a/skbase/lookup/tests/test_lookup.py
+++ b/skbase/lookup/tests/test_lookup.py
@@ -1071,3 +1071,218 @@ def test_all_object_class_lookup_invalid_object_types_raises(
             object_types=class_filter,
             class_lookup=class_lookup,
         )
+
+
+# ==============================================================================
+# ADDITIONAL TESTS FOR EDGE CASES AND ERROR HANDLING
+# ==============================================================================
+
+
+def test_all_objects_filter_tags_string_preprocessing():
+    """Test all_objects converts string filter_tags to dict correctly."""
+    # Test string input conversion
+    objs_str = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags="A",
+    )
+
+    objs_dict = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags={"A": True},
+    )
+
+    # Results should be identical
+    assert objs_str.equals(
+        objs_dict
+    ), "String and dict filter should return same results"
+
+
+def test_all_objects_filter_tags_list_preprocessing():
+    """Test all_objects converts list filter_tags to dict correctly."""
+    # Test list of strings input conversion
+    objs_list = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags=["A", "B"],
+    )
+
+    objs_dict = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags={"A": True, "B": True},
+    )
+
+    # Results should be identical
+    assert objs_list.equals(
+        objs_dict
+    ), "List and dict filter should return same results"
+
+
+def test_all_objects_filter_tags_tuple_preprocessing():
+    """Test all_objects converts tuple filter_tags to dict correctly."""
+    # Test tuple of strings input conversion
+    objs_tuple = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags=("A", "B"),
+    )
+
+    objs_dict = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags={"A": True, "B": True},
+    )
+
+    # Results should be identical
+    assert objs_tuple.equals(
+        objs_dict
+    ), "Tuple and dict filter should return same results"
+
+
+def test_get_package_metadata_filter_tags_string_preprocessing():
+    """Test get_package_metadata converts string tag_filter to dict correctly."""
+    result_str = get_package_metadata(
+        "skbase",
+        modules_to_ignore="skbase",
+        tag_filter="A",
+        classes_to_exclude=TagAliaserMixin,
+    )
+
+    result_dict = get_package_metadata(
+        "skbase",
+        modules_to_ignore="skbase",
+        tag_filter={"A": True},
+        classes_to_exclude=TagAliaserMixin,
+    )
+
+    # Results should be identical
+    assert result_str.keys() == result_dict.keys()
+
+
+def test_get_package_metadata_filter_tags_list_preprocessing():
+    """Test get_package_metadata converts list tag_filter to dict correctly."""
+    result_list = get_package_metadata(
+        "skbase",
+        modules_to_ignore="skbase",
+        tag_filter=["A", "B"],
+        classes_to_exclude=TagAliaserMixin,
+    )
+
+    result_dict = get_package_metadata(
+        "skbase",
+        modules_to_ignore="skbase",
+        tag_filter={"A": True, "B": True},
+        classes_to_exclude=TagAliaserMixin,
+    )
+
+    # Results should be identical
+    assert result_list.keys() == result_dict.keys()
+
+
+@pytest.mark.parametrize(
+    "invalid_filter",
+    [
+        123,  # int -> TypeError
+        12.5,  # float -> TypeError
+        object(),  # object -> TypeError
+        ["A", 123],  # list with non-string -> ValueError
+        ("A", 123),  # tuple with non-string -> ValueError
+    ],
+)
+def test_all_objects_filter_tags_invalid_types_preprocessing(invalid_filter):
+    """Test that invalid filter_tags raise TypeError or ValueError in all_objects."""
+    with pytest.raises(
+        (TypeError, ValueError),
+        match="filter_tags must be a str, list of str, or dict",
+    ):
+        all_objects(
+            package_name="skbase",
+            filter_tags=invalid_filter,
+        )
+
+
+@pytest.mark.parametrize(
+    "invalid_filter",
+    [
+        123,  # int -> TypeError
+        12.5,  # float -> TypeError
+        object(),  # object -> TypeError
+        ["A", 123],  # list with non-string -> ValueError
+        ("A", 123),  # tuple with non-string -> ValueError
+    ],
+)
+def test_get_package_metadata_filter_tags_invalid_types_preprocessing(invalid_filter):
+    """Test that an invalid tag_filter raises in get_package_metadata."""
+    with pytest.raises(
+        (TypeError, ValueError),
+        match="must be a str, list of str, or dict",
+    ):
+        get_package_metadata(
+            "skbase",
+            tag_filter=invalid_filter,
+        )
+
+
+def test_all_objects_filter_tags_empty_list():
+    """Test all_objects handles empty list filter_tags correctly."""
+    objs_empty_list = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags=[],
+    )
+
+    objs_empty_dict = all_objects(
+        package_name="skbase",
+        return_names=True,
+        as_dataframe=True,
+        filter_tags={},
+    )
+
+    # Results should be identical
+    assert objs_empty_list.equals(
+        objs_empty_dict
+    ), "Empty list and empty dict should return same results"
+
+
+def test_filter_by_tags_dict_not_modified():
+    """Test that _filter_by_tags doesn't modify the original dict in place."""
+    original_filter = {"A": "1"}
+    original_copy = original_filter.copy()
+
+    # Call all_objects with the filter
+    all_objects(
+        package_name="skbase",
+        filter_tags=original_filter,
+    )
+
+    # Original dict should be unchanged
+    assert (
+        original_filter == original_copy
+    ), "Original filter_tags dict should not be modified"
+
+
+def test_get_package_metadata_filter_tags_dict_copy_behavior():
+    """Test that tag_filter dict is copied and not modified in place."""
+    original_filter = {"A": "1"}
+    original_copy = original_filter.copy()
+
+    # Call get_package_metadata with the filter
+    get_package_metadata(
+        "skbase",
+        tag_filter=original_filter,
+        classes_to_exclude=TagAliaserMixin,
+    )
+
+    # Original dict should be unchanged
+    assert (
+        original_filter == original_copy
+    ), "Original tag_filter dict should not be modified"