Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generation of Enum field #640

Merged
merged 1 commit into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ Features:

* Typing: Add type annotations to `fields <marshmallow_sqlalchemy.fields>`.

Bug fixes:

* Fix auto-generation of `marshmallow.fields.Enum` field from `sqlalchemy.Enum` columns (:issue:`615`).
Thanks :user:`joaquimvl` for reporting.

Other changes:

* Docs: Add more documentation for `marshmallow_sqlalchemy.fields.Related` (:issue:`162`).
Expand Down
5 changes: 2 additions & 3 deletions src/marshmallow_sqlalchemy/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,10 @@ def _add_column_kwargs(self, kwargs: dict[str, Any], column: sa.Column) -> None:
else:
kwargs["dump_only"] = True

if hasattr(column.type, "enums") and not kwargs.get("dump_only"):
kwargs["validate"].append(validate.OneOf(choices=column.type.enums))

if hasattr(column.type, "enum_class") and column.type.enum_class is not None:
kwargs["enum"] = column.type.enum_class
elif hasattr(column.type, "enums") and not kwargs.get("dump_only"):
kwargs["validate"].append(validate.OneOf(choices=column.type.enums))

# Add a length validator if a max length is set on the column
# Skip UUID columns
Expand Down
13 changes: 7 additions & 6 deletions tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from marshmallow_sqlalchemy.fields import Related, RelatedList

from .conftest import mapped_column
from .conftest import CourseLevel, mapped_column


def contains_validator(field, v_type):
Expand Down Expand Up @@ -71,17 +71,18 @@ def test_sets_allow_none_for_nullable_fields(self, models):
fields_ = fields_for_model(models.Student)
assert fields_["dob"].allow_none is True

def test_sets_enum_choices(self, models):
def test_enum_with_choices_converted_to_field_with_validator(self, models):
fields_ = fields_for_model(models.Course)
validator = contains_validator(fields_["level"], validate.OneOf)
assert validator
assert list(validator.choices) == ["Primary", "Secondary"]

def test_sets_enum_with_class_choices(self, models):
def test_enum_with_class_converted_to_enum_field(self, models):
fields_ = fields_for_model(models.Course)
validator = contains_validator(fields_["level_with_enum_class"], validate.OneOf)
assert validator
assert list(validator.choices) == ["PRIMARY", "SECONDARY"]
field = fields_["level_with_enum_class"]
assert type(field) is fields.Enum
assert contains_validator(field, validate.OneOf) is False
assert field.enum is CourseLevel

def test_many_to_many_relationship(self, models):
student_fields = fields_for_model(models.Student, include_relationships=True)
Expand Down
Loading