Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mixin for int enum field custom #1781

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ Changelog
0.22
====

0.22.1
------

Added
^^^^^
- Add IntEnumFieldMixin to make it easy to Custom IntEnumField (#1781)

0.22.0
------
Fixed
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ docs: deps
_style:
isort -src $(checkfiles)
black $(checkfiles)
style: _style deps
style: deps _style

build: deps
rm -fR dist/
Expand Down
26 changes: 23 additions & 3 deletions examples/enum_fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from enum import Enum, IntEnum
from typing import Type, cast

from tortoise import Tortoise, fields, run_async
from tortoise.fields.data import IntEnumFieldMixin, IntEnumType, IntField
from tortoise.models import Model


Expand All @@ -16,18 +18,36 @@ class Currency(str, Enum):
USD = "USD"


class Protocol(IntEnum):
A = 10000
B = 80000 # >32767, beyond the value range of fields.IntEnumField


class IntEnumInstance(IntEnumFieldMixin, IntField):
pass


def IntEnumField(enum_type: Type[IntEnumType], **kwargs) -> IntEnumType:
return cast(IntEnumType, IntEnumInstance(enum_type, **kwargs))


class EnumFields(Model):
service: Service = fields.IntEnumField(Service)
currency: Currency = fields.CharEnumField(Currency, default=Currency.HUF)
# When each value of the enum_type is between [-32768, 32767], use the fields.IntEnumField
service: Service = fields.IntEnumField(Service)
# Else, you can use a custom Field
protocol: Protocol = IntEnumField(Protocol)


async def run():
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
await Tortoise.generate_schemas()

obj0 = await EnumFields.create(service=Service.python_programming, currency=Currency.USD)
obj0 = await EnumFields.create(
service=Service.python_programming, currency=Currency.USD, protocol=Protocol.A
)
# also you can use valid int and str value directly
await EnumFields.create(service=1, currency="USD")
await EnumFields.create(service=1, currency="USD", protocol=Protocol.B.value)

try:
# invalid enum value will raise ValueError
Expand Down
14 changes: 13 additions & 1 deletion tests/fields/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from enum import IntEnum

from examples.enum_fields import IntEnumField as CustomIntEnumField
from tests import testmodels
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError, IntegrityError
from tortoise.fields import CharEnumField, IntEnumField
from tortoise.fields import CharEnumField, IntEnumField, IntField


class BadIntEnum1(IntEnum):
Expand All @@ -24,6 +25,12 @@ class BadIntEnumIfGenerated(IntEnum):
system_administration = 3


class BadCustomIntEnum(IntEnum):
python_programming = IntField.VALUE_RANGE[1] + 1
database_design = 2
system_administration = 3


class TestIntEnumFields(test.TestCase):
async def test_empty(self):
with self.assertRaises(IntegrityError):
Expand Down Expand Up @@ -118,6 +125,11 @@ def test_manual_description(self):
fld = IntEnumField(testmodels.Service, description="foo")
self.assertEqual(fld.description, "foo")

def test_custom_int_enum_field(self):
assert CustomIntEnumField(BadIntEnum1)
with self.assertRaises(ConfigurationError):
CustomIntEnumField(BadCustomIntEnum)


class TestCharEnumFields(test.TestCase):
async def test_create(self):
Expand Down
Loading