From 7262b481f08800680a660375af4e150a499805c0 Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Wed, 8 Nov 2023 13:02:32 +0000 Subject: [PATCH] feat: strict typing (#168) - add strict mode mypy, fix typing errors - add absolufy-imports, ditch relative imports - add `type: ignore` comment for decorator usage on a property - add `type: ignore` comments for untyped dependencies --- .github/workflows/pull-request.yml | 12 ++- .pre-commit-config.yaml | 11 +++ flag_engine/engine.py | 14 +-- flag_engine/environments/models.py | 10 +-- flag_engine/features/models.py | 49 ++++++----- flag_engine/identities/models.py | 4 +- flag_engine/identities/traits/types.py | 12 +-- flag_engine/organisations/models.py | 2 +- flag_engine/py.typed | 0 flag_engine/segments/constants.py | 55 +++++------- flag_engine/segments/evaluator.py | 2 +- flag_engine/segments/models.py | 10 +-- flag_engine/utils/hashing.py | 4 +- flag_engine/utils/json/encoders.py | 2 +- flag_engine/utils/semver.py | 13 ++- flag_engine/utils/types.py | 39 +++++++-- mypy.ini | 6 ++ requirements-dev.in | 6 +- requirements-dev.txt | 14 ++- requirements.in | 1 + requirements.txt | 6 +- setup.py | 1 + tests/engine_tests/test_engine.py | 10 ++- tests/unit/conftest.py | 59 ++++++++----- .../test_environments_builders.py | 13 +-- .../environments/test_environments_models.py | 31 ++++--- tests/unit/features/test_features_models.py | 27 ++++-- .../identities/test_identities_builders.py | 12 +-- .../unit/identities/test_identities_models.py | 65 ++++++++++---- tests/unit/organisation/test_models.py | 2 +- .../unit/segments/test_segments_evaluator.py | 2 + tests/unit/test_engine.py | 85 ++++++++++++------- tests/unit/utils/json/test_encoders.py | 2 +- tests/unit/utils/test_utils_datetime.py | 2 +- tests/unit/utils/test_utils_hashing.py | 17 ++-- tox.ini | 6 +- 36 files changed, 390 insertions(+), 216 deletions(-) create mode 100644 flag_engine/py.typed create mode 100644 mypy.ini diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 36f17a58..97744de6 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -40,14 +40,22 @@ jobs: - name: Check Formatting run: black --check . + - name: Check Imports + run: | + git ls-files | grep '\.py$' | xargs absolufy-imports + isort . --check + - name: Check flake8 linting run: flake8 . + - name: Check Typing + run: mypy --strict . + - name: Run Tests run: pytest -p no:warnings - name: Check Coverage uses: 5monkeys/cobertura-action@v13 with: - minimum_coverage: 100 - fail_below_threshold: true + minimum_coverage: 100 + fail_below_threshold: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 097f497a..f71fbfca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,15 @@ repos: + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.5.1 + hooks: + - id: mypy + args: [--strict] + additional_dependencies: + [pydantic, pytest, pytest_mock, types-pytest-lazy-fixture, types-setuptools, semver] + - repo: https://github.com/MarcoGorelli/absolufy-imports + rev: v0.3.1 + hooks: + - id: absolufy-imports - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: diff --git a/flag_engine/engine.py b/flag_engine/engine.py index 92eae648..07334fd4 100644 --- a/flag_engine/engine.py +++ b/flag_engine/engine.py @@ -8,7 +8,9 @@ from flag_engine.utils.exceptions import FeatureStateNotFound -def get_environment_feature_states(environment: EnvironmentModel): +def get_environment_feature_states( + environment: EnvironmentModel, +) -> typing.List[FeatureStateModel]: """ Get a list of feature states for a given environment @@ -19,7 +21,9 @@ def get_environment_feature_states(environment: EnvironmentModel): return environment.feature_states -def get_environment_feature_state(environment: EnvironmentModel, feature_name: str): +def get_environment_feature_state( + environment: EnvironmentModel, feature_name: str +) -> FeatureStateModel: """ Get a specific feature state for a given feature_name in a given environment @@ -38,7 +42,7 @@ def get_environment_feature_state(environment: EnvironmentModel, feature_name: s def get_identity_feature_states( environment: EnvironmentModel, identity: IdentityModel, - override_traits: typing.List[TraitModel] = None, + override_traits: typing.Optional[typing.List[TraitModel]] = None, ) -> typing.List[FeatureStateModel]: """ Get a list of feature states for a given identity in a given environment. @@ -63,8 +67,8 @@ def get_identity_feature_state( environment: EnvironmentModel, identity: IdentityModel, feature_name: str, - override_traits: typing.List[TraitModel] = None, -): + override_traits: typing.Optional[typing.List[TraitModel]] = None, +) -> FeatureStateModel: """ Get a specific feature state for a given identity in a given environment. diff --git a/flag_engine/environments/models.py b/flag_engine/environments/models.py index d0e80179..bba3989b 100644 --- a/flag_engine/environments/models.py +++ b/flag_engine/environments/models.py @@ -19,7 +19,7 @@ class EnvironmentAPIKeyModel(BaseModel): active: bool = True @property - def is_valid(self): + def is_valid(self) -> bool: return self.active and ( not self.expires_at or self.expires_at > utcnow_with_tz() ) @@ -52,7 +52,7 @@ class EnvironmentModel(BaseModel): webhook_config: typing.Optional[WebhookModel] = None - _INTEGRATION_ATTS = [ + _INTEGRATION_ATTRS = [ "amplitude_config", "heap_config", "mixpanel_config", @@ -76,9 +76,9 @@ def integrations_data(self) -> typing.Dict[str, typing.Dict[str, str]]: """ integrations_data = {} - for integration_attr in self._INTEGRATION_ATTS: - integration_config: IntegrationModel = getattr(self, integration_attr, None) - if integration_config: + for integration_attr in self._INTEGRATION_ATTRS: + integration_config: typing.Optional[IntegrationModel] + if integration_config := getattr(self, integration_attr, None): integrations_data[integration_attr] = { "base_url": integration_config.base_url, "api_key": integration_config.api_key, diff --git a/flag_engine/features/models.py b/flag_engine/features/models.py index de213d3d..48bed138 100644 --- a/flag_engine/features/models.py +++ b/flag_engine/features/models.py @@ -2,7 +2,7 @@ import typing import uuid -from annotated_types import Ge, Le +from annotated_types import Ge, Le, SupportsLt from pydantic import UUID4, BaseModel, Field, model_validator from pydantic_collections import BaseCollectionModel from typing_extensions import Annotated @@ -16,16 +16,16 @@ class FeatureModel(BaseModel): name: str type: str - def __eq__(self, other): - return self.id == other.id + def __eq__(self, other: object) -> bool: + return isinstance(other, FeatureModel) and self.id == other.id - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) class MultivariateFeatureOptionModel(BaseModel): value: typing.Any - id: int = None + id: typing.Optional[int] = None class MultivariateFeatureStateValueModel(BaseModel): @@ -40,7 +40,7 @@ class FeatureSegmentModel(BaseModel): class MultivariateFeatureStateValueList( - BaseCollectionModel[MultivariateFeatureStateValueModel] + BaseCollectionModel[MultivariateFeatureStateValueModel] # type: ignore[misc,no-any-unimported] ): @staticmethod def _ensure_correct_percentage_allocations( @@ -75,15 +75,15 @@ def append( class FeatureStateModel(BaseModel, validate_assignment=True): feature: FeatureModel enabled: bool - django_id: int = None - feature_segment: FeatureSegmentModel = None + django_id: typing.Optional[int] = None + feature_segment: typing.Optional[FeatureSegmentModel] = None featurestate_uuid: UUID4 = Field(default_factory=uuid.uuid4) feature_state_value: typing.Any = None multivariate_feature_state_values: MultivariateFeatureStateValueList = Field( default_factory=MultivariateFeatureStateValueList ) - def set_value(self, value: typing.Any): + def set_value(self, value: typing.Any) -> None: self.feature_state_value = value def get_value(self, identity_id: typing.Union[None, int, str] = None) -> typing.Any: @@ -113,18 +113,19 @@ def is_higher_segment_priority(self, other: "FeatureStateModel") -> bool: """ - try: - return ( - getattr( - self.feature_segment, - "priority", - math.inf, + if other_feature_segment := other.feature_segment: + if ( + other_feature_segment_priority := other_feature_segment.priority + ) is not None: + return ( + getattr( + self.feature_segment, + "priority", + math.inf, + ) + < other_feature_segment_priority ) - < other.feature_segment.priority - ) - - except (TypeError, AttributeError): - return False + return False def _get_multivariate_value( self, identity_id: typing.Union[int, str] @@ -138,10 +139,14 @@ def _get_multivariate_value( # the percentage allocations of the multivariate options. This gives us a # way to ensure that the same value is returned every time we use the same # percentage value. - start_percentage = 0 + start_percentage = 0.0 + + def _mv_fs_sort_key(mv_value: MultivariateFeatureStateValueModel) -> SupportsLt: + return mv_value.id or mv_value.mv_fs_value_uuid + for mv_value in sorted( self.multivariate_feature_state_values, - key=lambda v: v.id or v.mv_fs_value_uuid, + key=_mv_fs_sort_key, ): limit = mv_value.percentage_allocation + start_percentage if start_percentage <= percentage_value < limit: diff --git a/flag_engine/identities/models.py b/flag_engine/identities/models.py index 33bd6858..7dcdace2 100644 --- a/flag_engine/identities/models.py +++ b/flag_engine/identities/models.py @@ -11,7 +11,7 @@ from flag_engine.utils.exceptions import DuplicateFeatureState -class IdentityFeaturesList(BaseCollectionModel[FeatureStateModel]): +class IdentityFeaturesList(BaseCollectionModel[FeatureStateModel]): # type: ignore[misc,no-any-unimported] @staticmethod def _ensure_unique_feature_ids( value: typing.MutableSequence[FeatureStateModel], @@ -45,7 +45,7 @@ class IdentityModel(BaseModel): identity_uuid: UUID4 = Field(default_factory=uuid.uuid4) django_id: typing.Optional[int] = None - @computed_field + @computed_field # type: ignore[misc] @property def composite_key(self) -> str: return self.generate_composite_key(self.environment_api_key, self.identifier) diff --git a/flag_engine/identities/traits/types.py b/flag_engine/identities/traits/types.py index fa069fc0..66304458 100644 --- a/flag_engine/identities/traits/types.py +++ b/flag_engine/identities/traits/types.py @@ -1,16 +1,10 @@ import re from decimal import Decimal +from typing import Any, Union, get_args -from typing import Union, Any, get_args -from typing_extensions import TypeGuard - -from pydantic.types import ( - AllowInfNan, - StringConstraints, - StrictBool, -) from pydantic import BeforeValidator -from typing_extensions import Annotated +from pydantic.types import AllowInfNan, StrictBool, StringConstraints +from typing_extensions import Annotated, TypeGuard from flag_engine.identities.traits.constants import TRAIT_STRING_VALUE_MAX_LENGTH diff --git a/flag_engine/organisations/models.py b/flag_engine/organisations/models.py index 51b27b3c..787b3b0a 100644 --- a/flag_engine/organisations/models.py +++ b/flag_engine/organisations/models.py @@ -9,5 +9,5 @@ class OrganisationModel(BaseModel): persist_trait_data: bool @property - def unique_slug(self): + def unique_slug(self) -> str: return str(self.id) + "-" + self.name diff --git a/flag_engine/py.typed b/flag_engine/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/flag_engine/segments/constants.py b/flag_engine/segments/constants.py index 8fc0d1f6..98fddd11 100644 --- a/flag_engine/segments/constants.py +++ b/flag_engine/segments/constants.py @@ -1,39 +1,22 @@ -# Segment Rules -ALL_RULE = "ALL" -ANY_RULE = "ANY" -NONE_RULE = "NONE" +from flag_engine.segments.types import ConditionOperator, RuleType -RULE_TYPES = [ALL_RULE, ANY_RULE, NONE_RULE] +# Segment Rules +ALL_RULE: RuleType = "ALL" +ANY_RULE: RuleType = "ANY" +NONE_RULE: RuleType = "NONE" # Segment Condition Operators -EQUAL = "EQUAL" -GREATER_THAN = "GREATER_THAN" -LESS_THAN = "LESS_THAN" -LESS_THAN_INCLUSIVE = "LESS_THAN_INCLUSIVE" -CONTAINS = "CONTAINS" -GREATER_THAN_INCLUSIVE = "GREATER_THAN_INCLUSIVE" -NOT_CONTAINS = "NOT_CONTAINS" -NOT_EQUAL = "NOT_EQUAL" -REGEX = "REGEX" -PERCENTAGE_SPLIT = "PERCENTAGE_SPLIT" -MODULO = "MODULO" -IS_SET = "IS_SET" -IS_NOT_SET = "IS_NOT_SET" -IN = "IN" - -CONDITION_OPERATORS = [ - EQUAL, - GREATER_THAN, - LESS_THAN, - LESS_THAN_INCLUSIVE, - CONTAINS, - GREATER_THAN_INCLUSIVE, - NOT_CONTAINS, - NOT_EQUAL, - REGEX, - PERCENTAGE_SPLIT, - MODULO, - IS_SET, - IS_NOT_SET, - IN, -] +EQUAL: ConditionOperator = "EQUAL" +GREATER_THAN: ConditionOperator = "GREATER_THAN" +LESS_THAN: ConditionOperator = "LESS_THAN" +LESS_THAN_INCLUSIVE: ConditionOperator = "LESS_THAN_INCLUSIVE" +CONTAINS: ConditionOperator = "CONTAINS" +GREATER_THAN_INCLUSIVE: ConditionOperator = "GREATER_THAN_INCLUSIVE" +NOT_CONTAINS: ConditionOperator = "NOT_CONTAINS" +NOT_EQUAL: ConditionOperator = "NOT_EQUAL" +REGEX: ConditionOperator = "REGEX" +PERCENTAGE_SPLIT: ConditionOperator = "PERCENTAGE_SPLIT" +MODULO: ConditionOperator = "MODULO" +IS_SET: ConditionOperator = "IS_SET" +IS_NOT_SET: ConditionOperator = "IS_NOT_SET" +IN: ConditionOperator = "IN" diff --git a/flag_engine/segments/evaluator.py b/flag_engine/segments/evaluator.py index bffeea5d..13f3a18a 100644 --- a/flag_engine/segments/evaluator.py +++ b/flag_engine/segments/evaluator.py @@ -176,7 +176,7 @@ def _trait_value_typed( @wraps(func) def inner( segment_value: typing.Optional[str], - trait_value: TraitValue, + trait_value: typing.Union[TraitValue, semver.Version], ) -> bool: with suppress(TypeError, ValueError): if isinstance(trait_value, str) and is_semver(segment_value): diff --git a/flag_engine/segments/models.py b/flag_engine/segments/models.py index e50d15f4..dcfdc11a 100644 --- a/flag_engine/segments/models.py +++ b/flag_engine/segments/models.py @@ -1,7 +1,7 @@ import typing -from typing_extensions import Annotated -from pydantic import BaseModel, Field, BeforeValidator +from pydantic import BaseModel, BeforeValidator, Field +from typing_extensions import Annotated from flag_engine.features.models import FeatureStateModel from flag_engine.segments import constants @@ -22,16 +22,16 @@ class SegmentRuleModel(BaseModel): conditions: typing.List[SegmentConditionModel] = Field(default_factory=list) @staticmethod - def none(iterable: typing.Iterable) -> bool: + def none(iterable: typing.Iterable[object]) -> bool: return not any(iterable) @property - def matching_function(self) -> callable: + def matching_function(self) -> typing.Callable[[typing.Iterable[object]], bool]: return { constants.ANY_RULE: any, constants.ALL_RULE: all, constants.NONE_RULE: SegmentRuleModel.none, - }.get(self.type) + }[self.type] class SegmentModel(BaseModel): diff --git a/flag_engine/utils/hashing.py b/flag_engine/utils/hashing.py index d192c0a6..c4618e1e 100644 --- a/flag_engine/utils/hashing.py +++ b/flag_engine/utils/hashing.py @@ -1,9 +1,11 @@ import hashlib import typing +from flag_engine.utils.types import SupportsStr + def get_hashed_percentage_for_object_ids( - object_ids: typing.Iterable[typing.Any], iterations: int = 1 + object_ids: typing.Iterable[SupportsStr], iterations: int = 1 ) -> float: """ Given a list of object ids, get a floating point number between 0 (inclusive) and diff --git a/flag_engine/utils/json/encoders.py b/flag_engine/utils/json/encoders.py index 9fc73f54..fba8421f 100644 --- a/flag_engine/utils/json/encoders.py +++ b/flag_engine/utils/json/encoders.py @@ -8,7 +8,7 @@ class DecimalEncoder(json.JSONEncoder): int/float(for us) converted to decimal by boto3/dynamodb. """ - def default(self, obj): + def default(self, obj: object) -> object: if isinstance(obj, decimal.Decimal): if obj % 1 == 0: return int(obj) diff --git a/flag_engine/utils/semver.py b/flag_engine/utils/semver.py index 8c36425f..9d9f1c70 100644 --- a/flag_engine/utils/semver.py +++ b/flag_engine/utils/semver.py @@ -1,4 +1,9 @@ -def is_semver(value: str) -> bool: +from typing import Optional + +import semver + + +def is_semver(value: Optional[str]) -> bool: """ Checks if the given string have `:semver` suffix or not >>> is_semver("2.1.41-beta:semver") @@ -7,10 +12,10 @@ def is_semver(value: str) -> bool: False """ - return value[-7:] == ":semver" + return value is not None and value[-7:] == ":semver" -def remove_semver_suffix(value: str) -> str: +def remove_semver_suffix(value: semver.Version) -> str: """ Remove the semver suffix(i.e: last 7 characters) from the given value >>> remove_semver_suffix("2.1.41-beta:semver") @@ -18,4 +23,4 @@ def remove_semver_suffix(value: str) -> str: >>> remove_semver_suffix("2.1.41:semver") '2.1.41' """ - return value[:-7] + return str(value)[:-7] diff --git a/flag_engine/utils/types.py b/flag_engine/utils/types.py index 06c9eee5..f98d6cd9 100644 --- a/flag_engine/utils/types.py +++ b/flag_engine/utils/types.py @@ -1,11 +1,21 @@ import typing +from functools import singledispatch import semver +from flag_engine.identities.traits.types import TraitValue from flag_engine.utils.semver import remove_semver_suffix -def get_casting_function(input_: typing.Any) -> typing.Callable: +class SupportsStr(typing.Protocol): + def __str__(self) -> str: # pragma: no cover + ... + + +@singledispatch +def get_casting_function( + input_: object, +) -> typing.Callable[..., TraitValue]: """ This function returns a callable to cast a value to the same type as input_ >>> assert get_casting_function("a string") == str @@ -13,11 +23,24 @@ def get_casting_function(input_: typing.Any) -> typing.Callable: >>> assert get_casting_function(1.2) == float >>> assert get_casting_function(semver.Version.parse("3.4.5")) == remove_semver_suffix """ + return str + + +@get_casting_function.register +def _(input_: bool) -> typing.Callable[..., bool]: + return lambda v: v not in ("False", "false") + + +@get_casting_function.register +def _(input_: int) -> typing.Callable[..., int]: + return int + + +@get_casting_function.register +def _(input_: float) -> typing.Callable[..., float]: + return float + - type_ = type(input_) - return { - bool: lambda v: v not in ("False", "false"), - int: int, - float: float, - semver.Version: remove_semver_suffix, - }.get(type_, str) +@get_casting_function.register +def _(input_: semver.Version) -> typing.Callable[..., str]: + return remove_semver_suffix diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..1f3216cc --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +plugins = pydantic.mypy +disallow_any_unimported = True + +[mypy-pydantic_collections.*] +ignore_missing_imports = True diff --git a/requirements-dev.in b/requirements-dev.in index ef473841..61f43754 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -5,4 +5,8 @@ isort pytest-mock pytest-lazy-fixture pytest-cov -pip-tools \ No newline at end of file +pip-tools +types-pytest-lazy-fixture +types-setuptools +mypy +absolufy-imports \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 23d2ba42..48f36487 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,6 +4,8 @@ # # pip-compile --output-file=requirements-dev.txt requirements-dev.in # +absolufy-imports==0.3.1 + # via -r requirements-dev.in black==23.7.0 # via -r requirements-dev.in build==0.10.0 @@ -22,8 +24,12 @@ isort==5.12.0 # via -r requirements-dev.in mccabe==0.7.0 # via flake8 +mypy==1.5.1 + # via -r requirements-dev.in mypy-extensions==1.0.0 - # via black + # via + # black + # mypy packaging==23.1 # via # black @@ -55,6 +61,12 @@ pytest-lazy-fixture==0.6.3 # via -r requirements-dev.in pytest-mock==3.11.1 # via -r requirements-dev.in +types-pytest-lazy-fixture==0.6.3.4 + # via -r requirements-dev.in +types-setuptools==68.2.0.0 + # via -r requirements-dev.in +typing-extensions==4.8.0 + # via mypy wheel==0.41.2 # via pip-tools diff --git a/requirements.in b/requirements.in index ab1a1c4e..cc7f4bab 100644 --- a/requirements.in +++ b/requirements.in @@ -1,3 +1,4 @@ +annotated-types semver pydantic pydantic-collections diff --git a/requirements.txt b/requirements.txt index 93ffd758..516db053 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,9 @@ # pip-compile # annotated-types==0.5.0 - # via pydantic + # via + # -r requirements.in + # pydantic pydantic==2.3.0 # via # -r requirements.in @@ -16,7 +18,7 @@ pydantic-core==2.6.3 # via pydantic semver==3.0.1 # via -r requirements.in -typing-extensions==4.7.1 +typing-extensions==4.8.0 # via # -r requirements.in # pydantic diff --git a/setup.py b/setup.py index 8e3baa47..10647dd7 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ author="Flagsmith", author_email="support@flagsmith.com", packages=find_packages(include=["flag_engine", "flag_engine.*"]), + package_data={"flag_engine": ["py.typed"]}, url="https://github.com/Flagsmith/flagsmith-engine", license="BSD3", description="Flag engine for the Flagsmith API.", diff --git a/tests/engine_tests/test_engine.py b/tests/engine_tests/test_engine.py index e8e03894..0797acbf 100644 --- a/tests/engine_tests/test_engine.py +++ b/tests/engine_tests/test_engine.py @@ -15,7 +15,9 @@ def _extract_test_cases( file_path: Path, -) -> typing.Iterable[typing.Tuple[EnvironmentModel, IdentityModel, dict]]: +) -> typing.Iterable[ + typing.Tuple[EnvironmentModel, IdentityModel, typing.Dict[str, typing.Any]], +]: """ Extract the test cases from the json data file which should be in the following format. @@ -52,7 +54,11 @@ def _extract_test_cases( MODULE_PATH / "engine-test-data/data/environment_n9fbf9h3v4fFgH3U3ngWhb.json" ), ) -def test_engine(environment_model, identity_model, api_response): +def test_engine( + environment_model: EnvironmentModel, + identity_model: IdentityModel, + api_response: typing.Dict[str, typing.Any], +) -> None: # When # we get the feature states from the engine engine_response = get_identity_feature_states(environment_model, identity_model) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4d47d325..5476fe9e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -23,17 +23,20 @@ @pytest.fixture() -def segment_condition_property(): +def segment_condition_property() -> str: return "foo" @pytest.fixture() -def segment_condition_string_value(): +def segment_condition_string_value() -> str: return "bar" @pytest.fixture() -def segment_condition(segment_condition_property, segment_condition_string_value): +def segment_condition( + segment_condition_property: str, + segment_condition_string_value: str, +) -> SegmentConditionModel: return SegmentConditionModel( operator=constants.EQUAL, property_=segment_condition_property, @@ -42,17 +45,17 @@ def segment_condition(segment_condition_property, segment_condition_string_value @pytest.fixture() -def segment_rule(segment_condition): +def segment_rule(segment_condition: SegmentConditionModel) -> SegmentRuleModel: return SegmentRuleModel(type=constants.ALL_RULE, conditions=[segment_condition]) @pytest.fixture() -def segment(segment_rule): +def segment(segment_rule: SegmentRuleModel) -> SegmentModel: return SegmentModel(id=1, name="my_segment", rules=[segment_rule]) @pytest.fixture() -def organisation(): +def organisation() -> OrganisationModel: return OrganisationModel( id=1, name="test Org", @@ -63,7 +66,10 @@ def organisation(): @pytest.fixture() -def project(organisation, segment): +def project( + organisation: OrganisationModel, + segment: SegmentModel, +) -> ProjectModel: return ProjectModel( id=1, name="Test Project", @@ -74,27 +80,31 @@ def project(organisation, segment): @pytest.fixture() -def feature_1(): +def feature_1() -> FeatureModel: return FeatureModel(id=1, name="feature_1", type=STANDARD) @pytest.fixture() -def feature_2(): +def feature_2() -> FeatureModel: return FeatureModel(id=2, name="feature_2", type=STANDARD) @pytest.fixture() -def feature_state_1(feature_1): +def feature_state_1(feature_1: FeatureModel) -> FeatureStateModel: return FeatureStateModel(feature=feature_1, enabled=True) @pytest.fixture() -def feature_state_2(feature_2): +def feature_state_2(feature_2: FeatureModel) -> FeatureStateModel: return FeatureStateModel(feature=feature_2, enabled=True) @pytest.fixture() -def environment(feature_1, feature_2, project): +def environment( + feature_1: FeatureModel, + feature_2: FeatureModel, + project: ProjectModel, +) -> EnvironmentModel: return EnvironmentModel( id=1, api_key="api-key", @@ -107,7 +117,7 @@ def environment(feature_1, feature_2, project): @pytest.fixture() -def identity(environment): +def identity(environment: EnvironmentModel) -> IdentityModel: return IdentityModel( identifier="identity_1", environment_api_key=environment.api_key, @@ -116,14 +126,18 @@ def identity(environment): @pytest.fixture() -def trait_matching_segment(segment_condition): +def trait_matching_segment(segment_condition: SegmentConditionModel) -> TraitModel: return TraitModel( - trait_key=segment_condition.property_, trait_value=segment_condition.value + trait_key=segment_condition.property_, + trait_value=segment_condition.value, ) @pytest.fixture() -def identity_in_segment(trait_matching_segment, environment): +def identity_in_segment( + trait_matching_segment: TraitModel, + environment: EnvironmentModel, +) -> IdentityModel: return IdentityModel( identifier="identity_2", environment_api_key=environment.api_key, @@ -132,7 +146,10 @@ def identity_in_segment(trait_matching_segment, environment): @pytest.fixture() -def segment_override_fs(segment, feature_1): +def segment_override_fs( + segment: SegmentModel, + feature_1: FeatureModel, +) -> FeatureStateModel: fs = FeatureStateModel( django_id=4, feature=feature_1, @@ -143,7 +160,7 @@ def segment_override_fs(segment, feature_1): @pytest.fixture() -def mv_feature_state_value(): +def mv_feature_state_value() -> MultivariateFeatureStateValueModel: return MultivariateFeatureStateValueModel( id=1, multivariate_feature_option=MultivariateFeatureOptionModel( @@ -154,7 +171,11 @@ def mv_feature_state_value(): @pytest.fixture() -def environment_with_segment_override(environment, segment_override_fs, segment): +def environment_with_segment_override( + environment: EnvironmentModel, + segment_override_fs: FeatureStateModel, + segment: SegmentModel, +) -> EnvironmentModel: segment.feature_states.append(segment_override_fs) environment.project.segments.append(segment) return environment diff --git a/tests/unit/environments/test_environments_builders.py b/tests/unit/environments/test_environments_builders.py index 19c486e4..2235b8a2 100644 --- a/tests/unit/environments/test_environments_builders.py +++ b/tests/unit/environments/test_environments_builders.py @@ -11,7 +11,7 @@ from tests.unit.helpers import get_environment_feature_state_for_feature_by_name -def test_build_environment_model(): +def test_build_environment_model() -> None: """Test to exercise the basic fields on the schema.""" # Given webhook_url = "https://my.webhook.com/hook" @@ -51,10 +51,11 @@ def test_build_environment_model(): assert environment_model assert len(environment_model.feature_states) == 1 + assert environment_model.webhook_config assert environment_model.webhook_config.url == webhook_url -def test_build_environment_model_with_name(): +def test_build_environment_model_with_name() -> None: # Given environment_name = "some_environment" environment_dict = { @@ -119,7 +120,9 @@ def test_build_environment_model__project_has_server_key_only_feature_ids__retur ) -def test_get_flags_for_environment_returns_feature_states_for_environment_dictionary(): +def test_get_flags_for_environment_returns_feature_states_for_environment_dictionary() -> ( + None +): # Given # some variables for use later string_value = "foo" @@ -190,7 +193,7 @@ def test_get_flags_for_environment_returns_feature_states_for_environment_dictio ) -def test_build_environment_model_with_multivariate_flag(): +def test_build_environment_model_with_multivariate_flag() -> None: # Given variate_1_value = "value-1" variate_2_value = "value-2" @@ -255,7 +258,7 @@ def test_build_environment_model_with_multivariate_flag(): ) -def test_build_environment_api_key_model(): +def test_build_environment_api_key_model() -> None: # Given environment_key_dict = { "key": "ser.7duQYrsasJXqdGsdaagyfU", diff --git a/tests/unit/environments/test_environments_models.py b/tests/unit/environments/test_environments_models.py index 1b2eb1f1..d98d912e 100644 --- a/tests/unit/environments/test_environments_models.py +++ b/tests/unit/environments/test_environments_models.py @@ -3,10 +3,12 @@ import pytest from flag_engine.environments.integrations.models import IntegrationModel -from flag_engine.environments.models import EnvironmentAPIKeyModel +from flag_engine.environments.models import EnvironmentAPIKeyModel, EnvironmentModel -def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key(): +def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key() -> ( + None +): assert ( EnvironmentAPIKeyModel( id=1, @@ -19,7 +21,9 @@ def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key() ) -def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key_with_expired_date_in_future(): +def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key_with_expired_date_in_future() -> ( + None +): assert ( EnvironmentAPIKeyModel( id=1, @@ -33,7 +37,7 @@ def test_environment_api_key_model_is_valid_is_true_for_non_expired_active_key_w ) -def test_environment_api_key_model_is_valid_is_false_for_expired_active_key(): +def test_environment_api_key_model_is_valid_is_false_for_expired_active_key() -> None: assert ( EnvironmentAPIKeyModel( id=1, @@ -47,7 +51,9 @@ def test_environment_api_key_model_is_valid_is_false_for_expired_active_key(): ) -def test_environment_api_key_model_is_valid_is_false_for_non_expired_inactive_key(): +def test_environment_api_key_model_is_valid_is_false_for_non_expired_inactive_key() -> ( + None +): assert ( EnvironmentAPIKeyModel( id=1, @@ -62,14 +68,14 @@ def test_environment_api_key_model_is_valid_is_false_for_non_expired_inactive_ke def test_environment_integrations_data_returns_empty_dict_when_no_integrations( - environment, -): + environment: EnvironmentModel, +) -> None: assert environment.integrations_data == {} def test_environment_integrations_data_returns_correct_data_when_multiple_integrations( - environment, -): + environment: EnvironmentModel, +) -> None: # Given example_key = "some-key" base_url = "https://some-integration-url" @@ -108,8 +114,11 @@ def test_environment_integrations_data_returns_correct_data_when_multiple_integr ), ) def test_environment_get_hide_disabled_flags( - environment, environment_value, project_value, expected_result -): + environment: EnvironmentModel, + environment_value: bool, + project_value: bool, + expected_result: bool, +) -> None: # Given environment.hide_disabled_flags = environment_value environment.project.hide_disabled_flags = project_value diff --git a/tests/unit/features/test_features_models.py b/tests/unit/features/test_features_models.py index 2e2297fc..5f0d8f09 100644 --- a/tests/unit/features/test_features_models.py +++ b/tests/unit/features/test_features_models.py @@ -2,6 +2,7 @@ import pytest from pydantic import ValidationError +from pytest_mock import MockerFixture from flag_engine.features.constants import STANDARD from flag_engine.features.models import ( @@ -13,12 +14,14 @@ from flag_engine.utils.exceptions import InvalidPercentageAllocation -def test_initializing_feature_state_creates_default_feature_state_uuid(feature_1): +def test_initializing_feature_state_creates_default_feature_state_uuid( + feature_1: FeatureModel, +) -> None: feature_state = FeatureStateModel(django_id=1, feature=feature_1, enabled=True) assert feature_state.featurestate_uuid is not None -def test_initializing_multivariate_feature_state_value_creates_default_uuid(): +def test_initializing_multivariate_feature_state_value_creates_default_uuid() -> None: mv_feature_option = MultivariateFeatureOptionModel(value="value") mv_fs_value_model = MultivariateFeatureStateValueModel( multivariate_feature_option=mv_feature_option, id=1, percentage_allocation=10 @@ -123,7 +126,7 @@ def test_feature_state_model__multivariate_feature_state_values__append__expecte ] -def test_feature_state_get_value_no_mv_values(feature_1): +def test_feature_state_get_value_no_mv_values(feature_1: FeatureModel) -> None: # Given value = "foo" feature_state = FeatureStateModel(django_id=1, feature=feature_1, enabled=True) @@ -149,8 +152,10 @@ def test_feature_state_get_value_no_mv_values(feature_1): ) @mock.patch("flag_engine.features.models.get_hashed_percentage_for_object_ids") def test_feature_state_get_value_mv_values( - mock_get_hashed_percentage, percentage_value, expected_value -): + mock_get_hashed_percentage: mock.Mock, + percentage_value: int, + expected_value: str, +) -> None: # Given # a feature my_feature = FeatureModel(id=1, name="mv_feature", type=STANDARD) @@ -189,8 +194,10 @@ def test_feature_state_get_value_mv_values( def test_get_value_uses_django_id_for_multivariate_value_calculation_if_not_none( - feature_1, mv_feature_state_value, mocker -): + feature_1: FeatureModel, + mv_feature_state_value: MultivariateFeatureStateValueModel, + mocker: MockerFixture, +) -> None: # Given mocked_get_hashed_percentage = mocker.patch( "flag_engine.features.models.get_hashed_percentage_for_object_ids", @@ -212,8 +219,10 @@ def test_get_value_uses_django_id_for_multivariate_value_calculation_if_not_none def test_get_value_uses_featuestate_uuid_for_multivariate_value_calculation_if_django_id_is_not_present( - feature_1, mv_feature_state_value, mocker -): + feature_1: FeatureModel, + mv_feature_state_value: MultivariateFeatureStateValueModel, + mocker: MockerFixture, +) -> None: # Given mocked_get_hashed_percentage = mocker.patch( "flag_engine.features.models.get_hashed_percentage_for_object_ids", diff --git a/tests/unit/identities/test_identities_builders.py b/tests/unit/identities/test_identities_builders.py index bb11559a..e7994486 100644 --- a/tests/unit/identities/test_identities_builders.py +++ b/tests/unit/identities/test_identities_builders.py @@ -4,7 +4,7 @@ from flag_engine.identities.models import IdentityFeaturesList, IdentityModel -def test_build_identity_model_from_dictionary_no_feature_states(): +def test_build_identity_model_from_dictionary_no_feature_states() -> None: # Given identity = { "id": 1, @@ -23,7 +23,9 @@ def test_build_identity_model_from_dictionary_no_feature_states(): assert len(identity_model.identity_traits) == 1 -def test_build_identity_model_from_dictionary_uses_identity_feature_list_for_identity_features(): +def test_build_identity_model_from_dictionary_uses_identity_feature_list_for_identity_features() -> ( + None +): # Given identity_dict = { "id": 1, @@ -50,14 +52,14 @@ def test_build_identity_model_from_dictionary_uses_identity_feature_list_for_ide assert isinstance(identity_model.identity_features, IdentityFeaturesList) -def test_build_build_identity_model_from_dict_creates_identity_uuid(): +def test_build_build_identity_model_from_dict_creates_identity_uuid() -> None: identity_model = build_identity_model( {"identifier": "test_user", "environment_api_key": "some_key"} ) assert identity_model.identity_uuid is not None -def test_build_identity_model_from_dictionary_with_feature_states(): +def test_build_identity_model_from_dictionary_with_feature_states() -> None: # Given identity_dict = { "id": 1, @@ -87,7 +89,7 @@ def test_build_identity_model_from_dictionary_with_feature_states(): assert isinstance(identity_model.identity_features[0], FeatureStateModel) -def test_identity_dict_created_using_model_can_convert_back_to_model(): +def test_identity_dict_created_using_model_can_convert_back_to_model() -> None: # Given identity_model = IdentityModel( environment_api_key="some_key", identifier="test_identifier" diff --git a/tests/unit/identities/test_identities_models.py b/tests/unit/identities/test_identities_models.py index dc6ed536..e8d624d0 100644 --- a/tests/unit/identities/test_identities_models.py +++ b/tests/unit/identities/test_identities_models.py @@ -8,7 +8,7 @@ from flag_engine.utils.exceptions import DuplicateFeatureState -def test_composite_key(): +def test_composite_key() -> None: # Given environment_api_key = "abc123" identifier = "identity" @@ -21,12 +21,12 @@ def test_composite_key(): assert identity_model.composite_key == f"{environment_api_key}_{identifier}" -def test_identiy_model_creates_default_identity_uuid(): +def test_identiy_model_creates_default_identity_uuid() -> None: identity_model = IdentityModel(identifier="test", environment_api_key="some_key") assert identity_model.identity_uuid is not None -def test_generate_composite_key(): +def test_generate_composite_key() -> None: # Given environment_api_key = "abc123" identifier = "identity" @@ -40,7 +40,9 @@ def test_generate_composite_key(): ) -def test_update_traits_remove_traits_with_none_value(identity_in_segment): +def test_update_traits_remove_traits_with_none_value( + identity_in_segment: IdentityModel, +) -> None: # Given trait_key = identity_in_segment.identity_traits[0].trait_key trait_to_remove = TraitModel(trait_key=trait_key, trait_value=None) @@ -55,7 +57,9 @@ def test_update_traits_remove_traits_with_none_value(identity_in_segment): assert traits_updated is True -def test_update_identity_traits_updates_trait_value(identity_in_segment): +def test_update_identity_traits_updates_trait_value( + identity_in_segment: IdentityModel, +) -> None: # Given trait_key = identity_in_segment.identity_traits[0].trait_key trait_value = "updated_trait_value" @@ -73,7 +77,9 @@ def test_update_identity_traits_updates_trait_value(identity_in_segment): assert traits_updated is True -def test_update_traits_adds_new_traits(identity_in_segment): +def test_update_traits_adds_new_traits( + identity_in_segment: IdentityModel, +) -> None: # Given new_trait = TraitModel(trait_key="new_key", trait_value="foobar") @@ -87,7 +93,9 @@ def test_update_traits_adds_new_traits(identity_in_segment): assert traits_updated is True -def test_update_traits_returns_false_if_traits_are_not_updated(identity_in_segment): +def test_update_traits_returns_false_if_traits_are_not_updated( + identity_in_segment: IdentityModel, +) -> None: # Given trait_key = identity_in_segment.identity_traits[0].trait_key trait_value = identity_in_segment.identity_traits[0].trait_value @@ -107,8 +115,9 @@ def test_update_traits_returns_false_if_traits_are_not_updated(identity_in_segme def test_appending_feature_states_raises_duplicate_feature_state_if_fs_for_the_feature_already_exists( - identity, feature_1 -): + identity: IdentityModel, + feature_1: FeatureModel, +) -> None: # Given fs_1 = FeatureStateModel(feature=feature_1, enabled=False) fs_2 = FeatureStateModel(feature=feature_1, enabled=True) @@ -139,7 +148,7 @@ def test_identity_model__identity_features__append__expected_result( ] -def test_append_feature_state(identity, feature_1): +def test_append_feature_state(identity: IdentityModel, feature_1: FeatureModel) -> None: # Given fs_1 = FeatureStateModel(feature=feature_1, enabled=False) # When @@ -149,8 +158,10 @@ def test_append_feature_state(identity, feature_1): def test_prune_features_only_keeps_valid_features( - identity, feature_state_1, feature_state_2 -): + identity: IdentityModel, + feature_state_1: FeatureStateModel, + feature_state_2: FeatureStateModel, +) -> None: # Given identity.identity_features.append(feature_state_1) identity.identity_features.append(feature_state_2) @@ -164,20 +175,44 @@ def test_prune_features_only_keeps_valid_features( assert list(identity.identity_features) == [feature_state_1] -def test_get_hash_key_with_use_identity_composite_key_for_hashing_enabled(identity): +def test_get_hash_key_with_use_identity_composite_key_for_hashing_enabled( + identity: IdentityModel, +) -> None: assert ( identity.get_hash_key(use_identity_composite_key_for_hashing=True) == identity.composite_key ) -def test_get_hash_key_with_use_identity_composite_key_for_hashing_disabled(identity): +def test_get_hash_key_with_use_identity_composite_key_for_hashing_disabled( + identity: IdentityModel, +) -> None: assert ( identity.get_hash_key(use_identity_composite_key_for_hashing=False) == identity.identifier ) +def test_get_hash_key_with_use_mv_v2_evaluation_enabled( + identity: IdentityModel, +) -> None: + # Given + use_mv_v2_evaluations = True + + # When/ Then + assert identity.get_hash_key(use_mv_v2_evaluations) == identity.composite_key + + +def test_get_hash_key_with_use_mv_v2_evaluation_disabled( + identity: IdentityModel, +) -> None: + # Given + use_mv_v2_evaluations = False + + # When/ Then + assert identity.get_hash_key(use_mv_v2_evaluations) == identity.identifier + + @pytest.mark.parametrize( "trait_value, expected_result", [ @@ -206,7 +241,7 @@ def test_trait_model__deserialize__expected_trait_value( assert result.trait_value == expected_result -def test_identity_model__deserialize__handles_nan(): +def test_identity_model__deserialize__handles_nan() -> None: # When result = IdentityModel.model_validate( { diff --git a/tests/unit/organisation/test_models.py b/tests/unit/organisation/test_models.py index adc45fa2..4cfe31db 100644 --- a/tests/unit/organisation/test_models.py +++ b/tests/unit/organisation/test_models.py @@ -1,7 +1,7 @@ from flag_engine.organisations.models import OrganisationModel -def test_unique_slug_property(): +def test_unique_slug_property() -> None: # Given org_id = 1 org_name = "test" diff --git a/tests/unit/segments/test_segments_evaluator.py b/tests/unit/segments/test_segments_evaluator.py index eb08d7c8..3097510d 100644 --- a/tests/unit/segments/test_segments_evaluator.py +++ b/tests/unit/segments/test_segments_evaluator.py @@ -243,6 +243,7 @@ def test_identity_in_segment_is_set_and_is_not_set( (constants.CONTAINS, "bar", "bar", True), (constants.CONTAINS, "bar", "baz", False), (constants.CONTAINS, "bar", 1, False), + (constants.CONTAINS, 1, "1", False), (constants.NOT_CONTAINS, "bar", "b", False), (constants.NOT_CONTAINS, "bar", "bar", False), (constants.NOT_CONTAINS, "bar", "baz", True), @@ -252,6 +253,7 @@ def test_identity_in_segment_is_set_and_is_not_set( (constants.REGEX, 1, r"\d", True), (constants.REGEX, None, r"[a-z]", False), (constants.REGEX, "foo", 12, False), + (constants.REGEX, 1, "1", True), (constants.IN, "foo", "", False), (constants.IN, "foo", "foo,bar", True), (constants.IN, "bar", "foo,bar", True), diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py index a306c37c..dd6b23f3 100644 --- a/tests/unit/test_engine.py +++ b/tests/unit/test_engine.py @@ -9,15 +9,18 @@ from flag_engine.environments.models import EnvironmentModel from flag_engine.features.constants import STANDARD from flag_engine.features.models import FeatureModel, FeatureStateModel -from flag_engine.identities.models import IdentityModel +from flag_engine.identities.models import IdentityFeaturesList, IdentityModel from flag_engine.identities.traits.models import TraitModel +from flag_engine.segments.models import SegmentModel from flag_engine.utils.exceptions import FeatureStateNotFound from tests.unit.helpers import get_environment_feature_state_for_feature def test_identity_get_feature_state_without_any_override( - environment, identity, feature_1 -): + environment: EnvironmentModel, + identity: IdentityModel, + feature_1: FeatureModel, +) -> None: # When feature_state = get_identity_feature_state(environment, identity, feature_1.name) # Then @@ -34,8 +37,11 @@ def test_identity_get_feature_state__nonexistent_feature__raise_expected( def test_identity_get_all_feature_states_no_segments( - feature_1, feature_2, environment, identity -): + feature_1: FeatureModel, + feature_2: FeatureModel, + environment: EnvironmentModel, + identity: IdentityModel, +) -> None: # Given overridden_feature = FeatureModel(id=3, name="overridden_feature", type=STANDARD) @@ -45,9 +51,9 @@ def test_identity_get_all_feature_states_no_segments( ) # but True for the identity - identity.identity_features = [ - FeatureStateModel(django_id=4, feature=overridden_feature, enabled=True) - ] + identity.identity_features = IdentityFeaturesList( + [FeatureStateModel(django_id=4, feature=overridden_feature, enabled=True)] + ) # When all_feature_states = get_identity_feature_states( @@ -81,19 +87,21 @@ def test_identity_get_all_feature_states_no_segments( ), ) def test_get_identity_feature_states_hides_disabled_flags( - environment, - identity, - feature_1, - feature_2, - environment_value, - project_value, - disabled_flag_returned, -): + environment: EnvironmentModel, + identity: IdentityModel, + feature_1: FeatureModel, + feature_2: FeatureModel, + environment_value: bool, + project_value: bool, + disabled_flag_returned: bool, +) -> None: # Given - two identity overrides - identity.identity_features = [ - FeatureStateModel(django_id=1, feature=feature_1, enabled=True), - FeatureStateModel(django_id=2, feature=feature_2, enabled=False), - ] + identity.identity_features = IdentityFeaturesList( + [ + FeatureStateModel(django_id=1, feature=feature_1, enabled=True), + FeatureStateModel(django_id=2, feature=feature_2, enabled=False), + ] + ) environment.hide_disabled_flags = environment_value environment.project.hide_disabled_flags = project_value @@ -108,8 +116,12 @@ def test_get_identity_feature_states_hides_disabled_flags( def test_identity_get_all_feature_states_segments_only( - feature_1, feature_2, environment, segment, identity_in_segment -): + feature_1: FeatureModel, + feature_2: FeatureModel, + environment: EnvironmentModel, + segment: SegmentModel, + identity_in_segment: IdentityModel, +) -> None: # Given # a feature which we can override overridden_feature = FeatureModel(id=3, name="overridden_feature", type=STANDARD) @@ -145,12 +157,12 @@ def test_identity_get_all_feature_states_segments_only( def test_identity_get_all_feature_states_with_traits( - environment_with_segment_override, - identity_in_segment, - identity, - segment_condition_string_value, - segment_condition_property, -): + environment_with_segment_override: EnvironmentModel, + identity_in_segment: IdentityModel, + identity: IdentityModel, + segment_condition_string_value: str, + segment_condition_property: str, +) -> None: # Given trait_models = TraitModel( trait_key=segment_condition_property, trait_value=segment_condition_string_value @@ -167,7 +179,7 @@ def test_identity_get_all_feature_states_with_traits( assert all_feature_states[0].get_value() == "segment_override" -def test_environment_get_all_feature_states(environment): +def test_environment_get_all_feature_states(environment: EnvironmentModel) -> None: # When feature_states = get_environment_feature_states(environment) @@ -187,8 +199,11 @@ def test_environment_get_all_feature_states(environment): ), ) def test_environment_get_feature_states_hide_disabled_flags( - environment, environment_value, project_value, disabled_flag_returned -): + environment: EnvironmentModel, + environment_value: bool, + project_value: bool, + disabled_flag_returned: bool, +) -> None: # Given environment.hide_disabled_flags = environment_value environment.project.hide_disabled_flags = project_value @@ -200,7 +215,9 @@ def test_environment_get_feature_states_hide_disabled_flags( assert len(feature_states) == (2 if disabled_flag_returned else 1) -def test_environment_get_feature_state(environment, feature_1): +def test_environment_get_feature_state( + environment: EnvironmentModel, feature_1: FeatureModel +) -> None: # When feature_state = get_environment_feature_state(environment, feature_1.name) @@ -208,6 +225,8 @@ def test_environment_get_feature_state(environment, feature_1): assert feature_state.feature == feature_1 -def test_environment_get_feature_state_raises_feature_state_not_found(environment): +def test_environment_get_feature_state_raises_feature_state_not_found( + environment: EnvironmentModel, +) -> None: with pytest.raises(FeatureStateNotFound): get_environment_feature_state(environment, "not_a_feature_name") diff --git a/tests/unit/utils/json/test_encoders.py b/tests/unit/utils/json/test_encoders.py index 11909617..b1f1b4c3 100644 --- a/tests/unit/utils/json/test_encoders.py +++ b/tests/unit/utils/json/test_encoders.py @@ -7,7 +7,7 @@ from flag_engine.utils.json.encoders import DecimalEncoder -def test_decimal_encoder_converts_decimal(): +def test_decimal_encoder_converts_decimal() -> None: # Given data = { "int_decimal": Decimal(1), diff --git a/tests/unit/utils/test_utils_datetime.py b/tests/unit/utils/test_utils_datetime.py index cee32b32..932df6b6 100644 --- a/tests/unit/utils/test_utils_datetime.py +++ b/tests/unit/utils/test_utils_datetime.py @@ -3,7 +3,7 @@ from flag_engine.utils.datetime import utcnow_with_tz -def test_utcnow_with_tz_returns_time_with_utc_timezone(): +def test_utcnow_with_tz_returns_time_with_utc_timezone() -> None: # When now = utcnow_with_tz() diff --git a/tests/unit/utils/test_utils_hashing.py b/tests/unit/utils/test_utils_hashing.py index 7a310936..20e3f39d 100644 --- a/tests/unit/utils/test_utils_hashing.py +++ b/tests/unit/utils/test_utils_hashing.py @@ -1,4 +1,5 @@ import itertools +import typing import uuid from unittest import mock @@ -17,8 +18,8 @@ ), ) def test_get_hashed_percentage_for_object_ids_is_number_between_0_inc_and_100_exc( - object_ids, -): + object_ids: typing.List[typing.Union[uuid.UUID, int, str]], +) -> None: assert 100 > get_hashed_percentage_for_object_ids(object_ids) >= 0 @@ -31,7 +32,9 @@ def test_get_hashed_percentage_for_object_ids_is_number_between_0_inc_and_100_ex [str(uuid.uuid4), str(uuid.uuid4())], ), ) -def test_get_hashed_percentage_for_object_ids_is_the_same_each_time(object_ids): +def test_get_hashed_percentage_for_object_ids_is_the_same_each_time( + object_ids: typing.List[typing.Union[uuid.UUID, int, str]], +) -> None: # When result_1 = get_hashed_percentage_for_object_ids(object_ids) result_2 = get_hashed_percentage_for_object_ids(object_ids) @@ -40,7 +43,7 @@ def test_get_hashed_percentage_for_object_ids_is_the_same_each_time(object_ids): assert result_1 == result_2 -def test_percentage_value_is_unique_for_different_identities(): +def test_percentage_value_is_unique_for_different_identities() -> None: # Given first_object_ids = [14, 106] second_object_ids = [53, 200] @@ -53,7 +56,7 @@ def test_percentage_value_is_unique_for_different_identities(): assert result_1 != result_2 -def test_get_hashed_percentage_for_object_ids_should_be_evenly_distributed(): +def test_get_hashed_percentage_for_object_ids_should_be_evenly_distributed() -> None: """ This test checks if the percentage value returned by the helper function returns evenly distributed values. @@ -93,7 +96,7 @@ def test_get_hashed_percentage_for_object_ids_should_be_evenly_distributed(): @mock.patch("flag_engine.utils.hashing.hashlib") -def test_get_hashed_percentage_does_not_return_1(mock_hashlib): +def test_get_hashed_percentage_does_not_return_1(mock_hashlib: mock.Mock) -> None: """ Quite complex test to ensure that the function will never return 1. @@ -113,7 +116,7 @@ def test_get_hashed_percentage_does_not_return_1(mock_hashlib): hash_string_to_return_0 = "270f" hashed_values = [hash_string_to_return_0, hash_string_to_return_1] - def hexdigest_side_effect(): + def hexdigest_side_effect() -> str: return hashed_values.pop() mock_hash = mock.MagicMock() diff --git a/tox.ini b/tox.ini index 7e4f25ba..687dbf64 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py36, py37, py38, py39, black, flake8 +envlist = py38, py39, black, flake8 skip_missing_interpreters = True [testenv] @@ -19,3 +19,7 @@ commands = black --check flag_engine/ tests/ [testenv:flake8] deps = flake8 commands = flake8 tests/ flag_engine/ + +[testenv:mypy] +deps = mypy +commands = mypy . --strict