Skip to content

Commit

Permalink
Correctly handle postgres arrays of Enums multidimensional arrays (#654)
Browse files Browse the repository at this point in the history
* Correctly handle postgres arrays of Enums multidimensional arrays

* Update changelog

* Fix py39 compat

* Fix types

* Also handle standard sa.ARRAY columns
  • Loading branch information
sloria authored Jan 19, 2025
1 parent 8f92b07 commit d42a076
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Changelog

Bug fixes:

* Fix handling of arrays of enums and multidimensional arrays (:issue:`653`).
Thanks :user:`carterjc` for reporting and investigating the fix.
* Fix handling of `sqlalchemy.PickleType` columns (:issue:`394`)
Thanks :user:`Eyon42` for reporting.

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ classifiers = [
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.9"
dependencies = ["marshmallow>=3.18.0", "SQLAlchemy>=1.4.40,<3.0"]
dependencies = [
"marshmallow>=3.18.0",
"SQLAlchemy>=1.4.40,<3.0",
"typing-extensions; python_version < '3.10'",
]

[project.urls]
Changelog = "https://marshmallow-sqlalchemy.readthedocs.io/en/latest/changelog.html"
Expand Down
68 changes: 52 additions & 16 deletions src/marshmallow_sqlalchemy/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,46 @@
import inspect
import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Union,
cast,
overload,
)

# Remove when dropping Python 3.9
try:
from typing import TypeAlias, TypeGuard
except ImportError:
from typing_extensions import TypeAlias, TypeGuard

import marshmallow as ma
import sqlalchemy as sa
from marshmallow import fields, validate
from sqlalchemy.dialects import mssql, mysql, postgresql
from sqlalchemy.orm import SynonymProperty
from sqlalchemy.types import TypeEngine

from .exceptions import ModelConversionError
from .fields import Related, RelatedList

if TYPE_CHECKING:
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import MapperProperty
from sqlalchemy.types import TypeEngine

PropertyOrColumn = MapperProperty | sa.Column

_FieldClassFactory = Callable[[Any, Any], type[fields.Field]]
_FieldPartial: TypeAlias = Callable[[], fields.Field]
# TODO: Use more specific type for second argument
_FieldClassFactory: TypeAlias = Callable[
["ModelConverter", Any], Union[type[fields.Field], _FieldPartial]
]


def _is_field(value) -> bool:
def _is_field(value: Any) -> TypeGuard[type[fields.Field]]:
return isinstance(value, type) and issubclass(value, fields.Field)


Expand All @@ -49,17 +67,30 @@ def _is_auto_increment(column) -> bool:
return column.table is not None and column is column.table._autoincrement_column


def _postgres_array_factory(converter: ModelConverter, data_type: postgresql.ARRAY):
return functools.partial(
fields.List,
converter._get_field_class_for_data_type(data_type.item_type),
)
def _list_field_factory(
converter: ModelConverter, data_type: postgresql.ARRAY
) -> Callable[[], fields.List]:
FieldClass = converter._get_field_class_for_data_type(data_type.item_type)
inner = FieldClass()
if not data_type.dimensions or data_type.dimensions == 1:
return functools.partial(fields.List, inner)

# For multi-dimensional arrays, nest the Lists
dimensions = data_type.dimensions
for _ in range(dimensions - 1):
inner = fields.List(inner)

return functools.partial(fields.List, inner)


def _enum_field_factory(
converter: ModelConverter, data_type: sa.Enum
) -> type[fields.Field]:
return fields.Enum if data_type.enum_class else fields.Raw
) -> Callable[[], fields.Field]:
return (
functools.partial(fields.Enum, enum=data_type.enum_class)
if data_type.enum_class
else fields.Raw
)


class ModelConverter:
Expand All @@ -72,6 +103,7 @@ class ModelConverter:
] = {
sa.Enum: _enum_field_factory,
sa.JSON: fields.Raw,
sa.ARRAY: _list_field_factory,
sa.PickleType: fields.Raw,
postgresql.BIT: fields.Integer,
postgresql.OID: fields.Integer,
Expand All @@ -82,7 +114,7 @@ class ModelConverter:
postgresql.JSON: fields.Raw,
postgresql.JSONB: fields.Raw,
postgresql.HSTORE: fields.Raw,
postgresql.ARRAY: _postgres_array_factory,
postgresql.ARRAY: _list_field_factory,
postgresql.MONEY: fields.Decimal,
postgresql.DATE: fields.Date,
postgresql.TIME: fields.Time,
Expand Down Expand Up @@ -335,14 +367,18 @@ def _get_field_class_for_column(self, column: sa.Column) -> type[fields.Field]:
def _get_field_class_for_data_type(
self, data_type: TypeEngine
) -> type[fields.Field]:
field_cls = None
field_cls: type[fields.Field] | _FieldPartial | None = None
types = inspect.getmro(type(data_type))
# First search for a field class from self.SQLA_TYPE_MAPPING
for col_type in types:
if col_type in self.SQLA_TYPE_MAPPING:
field_cls = self.SQLA_TYPE_MAPPING[col_type]
if callable(field_cls) and not _is_field(field_cls):
field_cls = cast(_FieldClassFactory, field_cls)(self, data_type)
field_or_factory = self.SQLA_TYPE_MAPPING[col_type]
if _is_field(field_or_factory):
field_cls = field_or_factory
else:
field_cls = cast(_FieldClassFactory, field_or_factory)(
self, data_type
)
break
else:
# Try to find a field class based on the column's python_type
Expand Down
40 changes: 40 additions & 0 deletions tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,46 @@ def test_convert_ARRAY_Integer(self, converter):
inner_field = getattr(field, "inner", getattr(field, "container", None))
assert type(inner_field) is fields.Int

@pytest.mark.parametrize(
"array_property",
(
pytest.param(make_property(sa.ARRAY(sa.Enum(CourseLevel))), id="sa.ARRAY"),
pytest.param(
make_property(postgresql.ARRAY(sa.Enum(CourseLevel))),
id="postgresql.ARRAY",
),
),
)
def test_convert_ARRAY_Enum(self, converter, array_property):
field = converter.property2field(array_property)
assert type(field) is fields.List
inner_field = field.inner
assert type(inner_field) is fields.Enum

@pytest.mark.parametrize(
"array_property",
(
pytest.param(
make_property(sa.ARRAY(sa.Float, dimensions=2)), id="sa.ARRAY"
),
pytest.param(
make_property(postgresql.ARRAY(sa.Float, dimensions=2)),
id="postgresql.ARRAY",
),
),
)
def test_convert_multidimensional_ARRAY(self, converter, array_property):
field = converter.property2field(array_property)
assert type(field) is fields.List
assert type(field.inner) is fields.List
assert type(field.inner.inner) is fields.Float

def test_convert_one_dimensional_ARRAY(self, converter):
prop = make_property(postgresql.ARRAY(sa.Float, dimensions=1))
field = converter.property2field(prop)
assert type(field) is fields.List
assert type(field.inner) is fields.Float

def test_convert_TSVECTOR(self, converter):
prop = make_property(postgresql.TSVECTOR)
with pytest.raises(ModelConversionError):
Expand Down

0 comments on commit d42a076

Please sign in to comment.