Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mixin for int enum field custom #1781

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Fixed
- Fix unable to use ManyToManyField if OneToOneField passed as Primary Key (#1783)
- Fix sorting by Term (e.g. RawSQL) (#1788)

Added
^^^^^
- Add IntEnumFieldMixin to make it easy to Custom IntEnumField (#1781)

0.22.0
------
Fixed
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ docs: deps
_style:
isort -src $(checkfiles)
black $(checkfiles)
style: _style deps
style: deps _style

build: deps
rm -fR dist/
Expand Down
26 changes: 23 additions & 3 deletions examples/enum_fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from enum import Enum, IntEnum
from typing import Type, cast

from tortoise import Tortoise, fields, run_async
from tortoise.fields.data import IntEnumFieldMixin, IntEnumType, IntField
from tortoise.models import Model


Expand All @@ -16,18 +18,36 @@ class Currency(str, Enum):
USD = "USD"


class Protocol(IntEnum):
A = 10000
B = 80000 # >32767, beyond the 'le' value in fields.IntEnumFieldInstance.constraints


class IntEnumFieldInstance(IntEnumFieldMixin, IntField):
pass


def IntEnumField(enum_type: Type[IntEnumType], **kwargs) -> IntEnumType:
return cast(IntEnumType, IntEnumFieldInstance(enum_type, **kwargs))


class EnumFields(Model):
service: Service = fields.IntEnumField(Service)
currency: Currency = fields.CharEnumField(Currency, default=Currency.HUF)
# When each value of the enum_type is between [-32768, 32767], use the fields.IntEnumField
service: Service = fields.IntEnumField(Service)
# Else, you can use a custom Field
protocol: Protocol = IntEnumField(Protocol)


async def run():
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
await Tortoise.generate_schemas()

obj0 = await EnumFields.create(service=Service.python_programming, currency=Currency.USD)
obj0 = await EnumFields.create(
service=Service.python_programming, currency=Currency.USD, protocol=Protocol.A
)
# also you can use valid int and str value directly
await EnumFields.create(service=1, currency="USD")
await EnumFields.create(service=1, currency="USD", protocol=Protocol.B.value)

try:
# invalid enum value will raise ValueError
Expand Down
12 changes: 12 additions & 0 deletions tests/fields/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import IntEnum

from examples.enum_fields import IntEnumField as CustomIntEnumField
from tests import testmodels
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError, IntegrityError
Expand All @@ -24,6 +25,12 @@ class BadIntEnumIfGenerated(IntEnum):
system_administration = 3


class BadCustomIntEnum(IntEnum):
python_programming = 2147483648
database_design = 2
system_administration = 3


class TestIntEnumFields(test.TestCase):
async def test_empty(self):
with self.assertRaises(IntegrityError):
Expand Down Expand Up @@ -118,6 +125,11 @@ def test_manual_description(self):
fld = IntEnumField(testmodels.Service, description="foo")
self.assertEqual(fld.description, "foo")

def test_custom_int_enum_field(self):
assert CustomIntEnumField(BadIntEnum1)
with self.assertRaises(ConfigurationError):
CustomIntEnumField(BadCustomIntEnum)


class TestCharEnumFields(test.TestCase):
async def test_create(self):
Expand Down
32 changes: 22 additions & 10 deletions tortoise/fields/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import datetime
import functools
import json
import warnings
from decimal import Decimal
from enum import Enum, IntEnum
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union, cast
from uuid import UUID, uuid4

from pypika import functions
Expand Down Expand Up @@ -668,7 +670,10 @@ class _db_mssql:
SQL_TYPE = "VARBINARY(MAX)"


class IntEnumFieldInstance(SmallIntField):
class IntEnumFieldMixin:
validate: Callable
constraints: dict[str, int]

def __init__(
self,
enum_type: Type[IntEnum],
Expand All @@ -677,27 +682,30 @@ def __init__(
**kwargs: Any,
) -> None:
# Validate values
minimum = 1 if generated else -32768
minimum = 1 if generated else self.constraints["ge"]
maximum = self.constraints["le"]
for item in enum_type:
try:
value = int(item.value)
except ValueError:
raise ConfigurationError("IntEnumField only supports integer enums!")
if not minimum <= value < 32768:
if not minimum <= value <= maximum:
# To extend value range, see: https://tortoise.github.io/examples/basic.html#enumeration-fields
raise ConfigurationError(
"The valid range of IntEnumField's values is {}..32767!".format(minimum)
f"The valid range of IntEnumField's values is {minimum}..{maximum}!"
)

# Automatic description for the field if not specified by the user
if description is None:
description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[:2048]

super().__init__(description=description, **kwargs)
super().__init__(
description=description, generated=generated, **kwargs
) # type:ignore[call-arg]
self.enum_type = enum_type

def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]:
value = self.enum_type(value) if value is not None else None
return value
return self.enum_type(value) if value is not None else None

def to_db_value(
self, value: Union[IntEnum, None, int], instance: "Union[Type[Model], Model]"
Expand All @@ -713,6 +721,10 @@ def to_db_value(
IntEnumType = TypeVar("IntEnumType", bound=IntEnum)


class IntEnumFieldInstance(IntEnumFieldMixin, SmallIntField):
pass


def IntEnumField(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really confusing that IntEnumField actually returns SmallIntField :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but mostly, IntEnumField used in the following case:

class Gender(IntEnum):
      male = 0
      female = 1
      other = 2

class Users(Model):
    gender = fields.IntEnumField(Gender)

And I wish it to be small int in db, not int.

enum_type: Type[IntEnumType],
description: Optional[str] = None,
Expand All @@ -735,7 +747,7 @@ def IntEnumField(
of "name: value" pairs.

"""
return IntEnumFieldInstance(enum_type, description, **kwargs) # type: ignore
return cast("IntEnumType", IntEnumFieldInstance(enum_type, description, **kwargs))


class CharEnumFieldInstance(CharField):
Expand Down Expand Up @@ -805,4 +817,4 @@ def CharEnumField(

"""

return CharEnumFieldInstance(enum_type, description, max_length, **kwargs) # type: ignore
return cast("CharEnumType", CharEnumFieldInstance(enum_type, description, max_length, **kwargs))