diff --git a/osbot_utils/type_safe/Type_Safe__Base.py b/osbot_utils/type_safe/Type_Safe__Base.py index de164858..f5997dfa 100644 --- a/osbot_utils/type_safe/Type_Safe__Base.py +++ b/osbot_utils/type_safe/Type_Safe__Base.py @@ -1,6 +1,5 @@ -from typing import get_args, Union, Optional, Any, ForwardRef - -from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache +from typing import get_args, Union, Optional, Any, ForwardRef +from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache EXACT_TYPE_MATCH = (int, float, str, bytes, bool, complex) diff --git a/osbot_utils/type_safe/Type_Safe__Set.py b/osbot_utils/type_safe/Type_Safe__Set.py new file mode 100644 index 00000000..aa49f934 --- /dev/null +++ b/osbot_utils/type_safe/Type_Safe__Set.py @@ -0,0 +1,32 @@ +from osbot_utils.type_safe.Type_Safe__Base import Type_Safe__Base, type_str + +class Type_Safe__Set(Type_Safe__Base, set): + def __init__(self, expected_type, *args): + super().__init__(*args) + self.expected_type = expected_type + + def __repr__(self): + expected_type_name = type_str(self.expected_type) + return f"set[{expected_type_name}] with {len(self)} elements" + + def add(self, item): + try: + self.is_instance_of_type(item, self.expected_type) + except TypeError as e: + raise TypeError(f"In Type_Safe__Set: Invalid type for item: {e}") + super().add(item) + + def json(self): + from osbot_utils.type_safe.Type_Safe import Type_Safe + + result = [] + for item in self: + if isinstance(item, Type_Safe): + result.append(item.json()) + elif isinstance(item, (list, tuple, set)): + result.append([x.json() if isinstance(x, Type_Safe) else x for x in item]) + elif isinstance(item, dict): + result.append({k: v.json() if isinstance(v, Type_Safe) else v for k, v in item.items()}) + else: + result.append(item) + return result \ No newline at end of file diff --git a/osbot_utils/type_safe/shared/Type_Safe__Validation.py b/osbot_utils/type_safe/shared/Type_Safe__Validation.py index 443e2ca4..1d11e1e7 100644 --- a/osbot_utils/type_safe/shared/Type_Safe__Validation.py +++ b/osbot_utils/type_safe/shared/Type_Safe__Validation.py @@ -159,6 +159,9 @@ def check_if__type_matches__obj_annotation__for_attr(self, target, attr_type = annotations.get(attr_name) if attr_type: origin_attr_type = get_origin(attr_type) # to handle when type definition contains a generic + if origin_attr_type is set: + if type(value) is list: + return True # if the attribute is a set and the value is a list, then they are compatible if origin_attr_type is type: # Add handling for Type[T] type_arg = get_args(attr_type)[0] # Get T from Type[T] if type_arg == value: @@ -220,7 +223,7 @@ def validate_type_compatibility(self, target : Any , ) -> None: # Raises ValueError if invalid direct_type_match = type_safe_validation.check_if__type_matches__obj_annotation__for_attr(target, name, value) - union_type_match = type_safe_validation.check_if__type_matches__obj_annotation__for_union_and_annotated(target, name, value) + union_type_match = type_safe_validation.check_if__type_matches__obj_annotation__for_union_and_annotated(target, name, value) is_invalid = (direct_type_match is False and union_type_match is None) or \ (direct_type_match is None and union_type_match is False) or \ diff --git a/osbot_utils/type_safe/steps/Type_Safe__Step__Default_Value.py b/osbot_utils/type_safe/steps/Type_Safe__Step__Default_Value.py index eae5c8eb..d90cb07e 100644 --- a/osbot_utils/type_safe/steps/Type_Safe__Step__Default_Value.py +++ b/osbot_utils/type_safe/steps/Type_Safe__Step__Default_Value.py @@ -3,6 +3,7 @@ import inspect import typing +from osbot_utils.type_safe.Type_Safe__Set import Type_Safe__Set from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache from osbot_utils.utils.Objects import default_value from osbot_utils.type_safe.Type_Safe__List import Type_Safe__List @@ -40,12 +41,17 @@ def default_value(self, _cls, var_type): return set() if origin is set: - return set() # todo: add Type_Safe__Set + item_type = get_args(var_type)[0] + if isinstance(item_type, ForwardRef): + forward_name = item_type.__forward_arg__ + if forward_name == _cls.__name__: + item_type = _cls + return Type_Safe__Set(expected_type=item_type) if var_type is typing.Dict: return {} - if origin is dict: # e.g. Dict[key_type, value_type] + if origin is dict: # e.g. Dict[key_type, value_type] key_type, value_type = get_args(var_type) if isinstance(key_type, ForwardRef): # Handle forward references on key_type --- forward_name = key_type.__forward_arg__ diff --git a/osbot_utils/type_safe/steps/Type_Safe__Step__From_Json.py b/osbot_utils/type_safe/steps/Type_Safe__Step__From_Json.py index 0f5cf3f0..fb12cdd7 100644 --- a/osbot_utils/type_safe/steps/Type_Safe__Step__From_Json.py +++ b/osbot_utils/type_safe/steps/Type_Safe__Step__From_Json.py @@ -6,9 +6,10 @@ from osbot_utils.type_safe.Type_Safe__List import Type_Safe__List from osbot_utils.helpers.Random_Guid import Random_Guid from osbot_utils.helpers.Random_Guid_Short import Random_Guid_Short +from osbot_utils.type_safe.Type_Safe__Set import Type_Safe__Set from osbot_utils.type_safe.shared.Type_Safe__Annotations import type_safe_annotations from osbot_utils.type_safe.shared.Type_Safe__Cache import type_safe_cache -from osbot_utils.type_safe.shared.Type_Safe__Convert import type_safe_convert +from osbot_utils.type_safe.shared.Type_Safe__Convert import type_safe_convert from osbot_utils.utils.Objects import enum_from_value from osbot_utils.helpers.Safe_Id import Safe_Id from osbot_utils.helpers.Timestamp_Now import Timestamp_Now @@ -54,6 +55,19 @@ def deserialize_from_dict(self, _self, data, raise_on_not_found=False): value = self.deserialize_type__using_value(value) elif type_safe_annotations.obj_is_attribute_annotation_of_type(_self, key, dict): # handle the case when the value is a dict value = self.deserialize_dict__using_key_value_annotations(_self, key, value) + elif type_safe_annotations.obj_is_attribute_annotation_of_type(_self, key, set): # handle the case when the value is a list + attribute_annotation = type_safe_annotations.obj_attribute_annotation(_self, key) # get the annotation for this variable + attribute_annotation_args = get_args(attribute_annotation) + if attribute_annotation_args: + expected_type = get_args(attribute_annotation)[0] # get the first arg (which is the type) + type_safe_set = Type_Safe__Set(expected_type) # create a new instance of Type_Safe__List + for item in value: # next we need to convert all items (to make sure they all match the type) + if type(item) is dict: + new_item = expected_type(**item) # create new object + else: + new_item = expected_type(item) + type_safe_set.add(new_item) # and add it to the new type_safe_list obejct + value = type_safe_set # todo: refactor out this create list code, maybe to an deserialize_from_list method elif type_safe_annotations.obj_is_attribute_annotation_of_type(_self, key, list): # handle the case when the value is a list attribute_annotation = type_safe_annotations.obj_attribute_annotation(_self, key) # get the annotation for this variable attribute_annotation_args = get_args(attribute_annotation) @@ -123,7 +137,7 @@ def deserialize_dict__using_key_value_annotations(self, _self, key, value): if type(dict_value) == value_class: # if the value is already the target, then just use it new__dict_value = dict_value - elif issubclass(value_class, Type_Safe): + elif isinstance(value_class, type) and issubclass(value_class, Type_Safe): if 'node_type' in dict_value: value_class = type_safe_convert.get_class_from_class_name(dict_value['node_type']) diff --git a/osbot_utils/utils/Objects.py b/osbot_utils/utils/Objects.py index 15095546..8e8a8027 100644 --- a/osbot_utils/utils/Objects.py +++ b/osbot_utils/utils/Objects.py @@ -281,6 +281,8 @@ def serialize_to_dict(obj): return f"{obj.__module__}.{obj.__name__}" # save the full type name elif isinstance(obj, list) or isinstance(obj, List): return [serialize_to_dict(item) for item in obj] + elif isinstance(obj, set): + return [serialize_to_dict(item) for item in obj] elif isinstance(obj, dict): return {key: serialize_to_dict(value) for key, value in obj.items()} elif hasattr(obj, "__dict__"): diff --git a/tests/unit/type_safe/_bugs/test_Type_Safe__bugs.py b/tests/unit/type_safe/_bugs/test_Type_Safe__bugs.py index a4cb21e8..54dd9dbe 100644 --- a/tests/unit/type_safe/_bugs/test_Type_Safe__bugs.py +++ b/tests/unit/type_safe/_bugs/test_Type_Safe__bugs.py @@ -1,6 +1,6 @@ import sys import pytest -from typing import Optional, Union, Dict +from typing import Optional, Union, Dict, Set from unittest import TestCase from osbot_utils.type_safe.Type_Safe import Type_Safe from osbot_utils.base_classes.Kwargs_To_Self import Kwargs_To_Self diff --git a/tests/unit/type_safe/_regression/test_Type_Safe__regression.py b/tests/unit/type_safe/_regression/test_Type_Safe__regression.py index c2eaa58f..f5eb06b6 100644 --- a/tests/unit/type_safe/_regression/test_Type_Safe__regression.py +++ b/tests/unit/type_safe/_regression/test_Type_Safe__regression.py @@ -2,7 +2,7 @@ import pytest import sys from decimal import Decimal -from typing import Optional, Union, List, Dict, get_origin, Type, ForwardRef, Any +from typing import Optional, Union, List, Dict, get_origin, Type, ForwardRef, Any, Set from unittest import TestCase from unittest.mock import patch from osbot_utils.helpers.Timestamp_Now import Timestamp_Now @@ -14,6 +14,7 @@ from osbot_utils.type_safe.Type_Safe__List import Type_Safe__List from osbot_utils.decorators.methods.cache_on_self import cache_on_self from osbot_utils.helpers.Random_Guid import Random_Guid +from osbot_utils.type_safe.Type_Safe__Set import Type_Safe__Set from osbot_utils.type_safe.shared.Type_Safe__Annotations import type_safe_annotations from osbot_utils.type_safe.validators.Validator__Min import Min from osbot_utils.utils.Json import json_to_str, str_to_json @@ -22,6 +23,44 @@ class test_Type_Safe__regression(TestCase): + def test__regression__roundtrip_set_support(self): + class An_Class(Type_Safe): + an_set_1: set[str] + an_set_2: Set[str] + + an_class = An_Class() + an_class.an_set_1.add ('a') + an_class.an_set_1.add ('b') + an_class.an_set_1.remove('a') + an_class.an_set_2.add ('a') + assert an_class.json() == {'an_set_1': ['b'], 'an_set_2': ['a']} + assert an_class.obj() == __(an_set_1=['b'], an_set_2=['a']) + assert type(an_class.an_set_1) is Type_Safe__Set + assert type(an_class.an_set_2) is Type_Safe__Set + + expected_message = "In Type_Safe__Set: Invalid type for item: Expected 'str', but got 'int'" + + with pytest.raises(TypeError, match=re.escape(expected_message)): + an_class.an_set_1.add(123) # confirms type safety + with pytest.raises(TypeError, match=re.escape(expected_message)): + an_class.an_set_2.add(123) # confirms type safety + + + #expected_message = "Invalid type for attribute 'an_set_1'. Expected 'set[str]' but got ''" + # with pytest.raises(ValueError, match=re.escape(expected_message)): + # An_Class.from_json(an_class.json()) # Fixed BUG: should not have raised an exception + an_class_round_trip = An_Class.from_json(an_class.json()) + + assert an_class_round_trip.an_set_1 == {'b'} # Fixed + assert an_class_round_trip.an_set_2 == {'a'} # Fixed + assert type(an_class_round_trip.an_set_1) is Type_Safe__Set # Fixed: BUG: it should be a set + assert type(an_class_round_trip.an_set_1) is not list # Fixed: BUG: it should not be a list + assert type(an_class_round_trip.an_set_2) is Type_Safe__Set # Fixed: BUG: it should be a set + assert type(an_class_round_trip.an_set_2) is not list # Fixed BUG: it should be a set + + assert an_class_round_trip.json() == an_class.json() # Fixed: + + def test__regression__forward_ref_type(self): class Base__Type(Type_Safe): ref_type: Type['Base__Type'] diff --git a/tests/unit/type_safe/test_Type_Safe.py b/tests/unit/type_safe/test_Type_Safe.py index 8f65f8aa..3d1db28d 100644 --- a/tests/unit/type_safe/test_Type_Safe.py +++ b/tests/unit/type_safe/test_Type_Safe.py @@ -3,7 +3,7 @@ import types import pytest from enum import Enum, auto -from typing import Union, Optional, Type +from typing import Union, Optional, Type, Set from unittest import TestCase from osbot_utils.helpers.Timestamp_Now import Timestamp_Now from osbot_utils.helpers.Guid import Guid @@ -40,6 +40,19 @@ def __init__(self): def an_extra_instance_method(self): pass + def test__set_support(self): + class An_Class(Type_Safe): + an_set_1 : set[str] + an_set_2 : Set[str] + + an_class = An_Class(an_set_1=set(['a', 'b'])) + an_class.an_set_1.remove('a') + an_class.an_set_2.add ('a') + assert an_class.json() == {'an_set_1': ['b'], 'an_set_2': ['a']} + assert an_class.obj () == __(an_set_1=['b'], an_set_2=['a']) + + + def test___cls_kwargs__(self): if sys.version_info < (3, 9): pytest.skip("Skipping test that doesn't work on 3.8 or lower")