Skip to content

Commit

Permalink
More type coverage (#634)
Browse files Browse the repository at this point in the history
* More type coverage

* Fix typing for model param/opt; use mapped_columns in tests

* Add overloads to handle instance argument

* Fix compat with sqlalchemy 1.4

* Update changelog

* Update import in recipes
  • Loading branch information
sloria authored Jan 9, 2025
1 parent 38f8230 commit ffe1f19
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 69 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Changelog

Features:

* Typing: Improve type coverage (:pr:`631`).
* Typing: Improve type coverage (:pr:`631`, :pr:`632`, :pr:`634`).

Other changes:

Expand Down
16 changes: 13 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,22 @@ Declare your models
.. code-block:: python
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker, relationship, backref
from sqlalchemy.orm import (
DeclarativeBase,
backref,
relationship,
scoped_session,
sessionmaker,
)
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
engine = sa.create_engine("sqlite:///:memory:")
session = scoped_session(sessionmaker(bind=engine))
Base = declarative_base()
class Base(DeclarativeBase):
pass
class Author(Base):
Expand Down
16 changes: 13 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,22 @@ Declare your models
.. code-block:: python
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker, relationship, backref
from sqlalchemy.orm import (
DeclarativeBase,
backref,
relationship,
scoped_session,
sessionmaker,
)
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
engine = sa.create_engine("sqlite:///:memory:")
session = scoped_session(sessionmaker(bind=engine))
Base = declarative_base()
class Base(DeclarativeBase):
pass
class Author(Base):
Expand Down
3 changes: 1 addition & 2 deletions docs/recipes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ An example of then using this:
.. code-block:: python
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy import event
from sqlalchemy.orm import mapper
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,7 @@ warn_unreachable = true
warn_unused_ignores = true
warn_redundant_casts = true
no_implicit_optional = true

[[tool.mypy.overrides]]
module = "tests.*"
check_untyped_defs = true
83 changes: 74 additions & 9 deletions src/marshmallow_sqlalchemy/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload

import marshmallow as ma
import sqlalchemy as sa
Expand All @@ -17,9 +17,10 @@

if TYPE_CHECKING:
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import MapperProperty
from sqlalchemy.types import TypeEngine

PropertyOrColumn = sa.orm.MapperProperty | sa.Column
PropertyOrColumn = MapperProperty | sa.Column

_FieldClassFactory = Callable[[Any, Any], type[fields.Field]]

Expand Down Expand Up @@ -122,7 +123,7 @@ class ModelConverter:
}
DIRECTION_MAPPING = {"MANYTOONE": False, "MANYTOMANY": True, "ONETOMANY": True}

def __init__(self, schema_cls: ma.Schema | None = None):
def __init__(self, schema_cls: type[ma.Schema] | None = None):
self.schema_cls = schema_cls

@property
Expand All @@ -134,7 +135,7 @@ def type_mapping(self) -> dict[type, type[fields.Field]]:

def fields_for_model(
self,
model: DeclarativeMeta,
model: type[DeclarativeMeta],
*,
include_fk: bool = False,
include_relationships: bool = False,
Expand All @@ -146,7 +147,7 @@ def fields_for_model(
result = dict_cls()
base_fields = base_fields or {}

for prop in sa.inspect(model).attrs:
for prop in sa.inspect(model).attrs: # type: ignore[union-attr]
key = self._get_field_name(prop)
if self._should_exclude_field(prop, fields=fields, exclude=exclude):
# Allow marshmallow to validate and exclude the field key.
Expand Down Expand Up @@ -197,9 +198,29 @@ def fields_for_table(
result[key] = field
return result

@overload
def property2field(
self,
prop,
prop: MapperProperty,
*,
instance: Literal[True] = ...,
field_class: type[fields.Field] | None = ...,
**kwargs,
) -> fields.Field: ...

@overload
def property2field(
self,
prop: MapperProperty,
*,
instance: Literal[False] = ...,
field_class: type[fields.Field] | None = ...,
**kwargs,
) -> type[fields.Field]: ...

def property2field(
self,
prop: MapperProperty,
*,
instance: bool = True,
field_class: type[fields.Field] | None = None,
Expand Down Expand Up @@ -228,6 +249,16 @@ def property2field(
ret = RelatedList(ret, **related_list_kwargs)
return ret

@overload
def column2field(
self, column, *, instance: Literal[True] = ..., **kwargs
) -> fields.Field: ...

@overload
def column2field(
self, column, *, instance: Literal[False] = ..., **kwargs
) -> type[fields.Field]: ...

def column2field(
self, column, *, instance: bool = True, **kwargs
) -> fields.Field | type[fields.Field]:
Expand All @@ -239,8 +270,36 @@ def column2field(
_field_update_kwargs(field_class, field_kwargs, kwargs)
return field_class(**field_kwargs)

@overload
def field_for(
self,
model: type[DeclarativeMeta],
property_name: str,
*,
instance: Literal[True] = ...,
field_class: type[fields.Field] | None = ...,
**kwargs,
) -> fields.Field: ...

@overload
def field_for(
self, model: DeclarativeMeta, property_name: str, **kwargs
self,
model: type[DeclarativeMeta],
property_name: str,
*,
instance: Literal[False] = ...,
field_class: type[fields.Field] | None = None,
**kwargs,
) -> type[fields.Field]: ...

def field_for(
self,
model: type[DeclarativeMeta],
property_name: str,
*,
instance: bool = True,
field_class: type[fields.Field] | None = None,
**kwargs,
) -> fields.Field | type[fields.Field]:
target_model = model
prop_name = property_name
Expand All @@ -250,8 +309,14 @@ def field_for(
target_model = attr.target_class
prop_name = attr.value_attr
remote_with_local_multiplicity = attr.local_attr.prop.uselist
prop = sa.inspect(target_model).attrs.get(prop_name)
converted_prop = self.property2field(prop, **kwargs)
prop: MapperProperty = sa.inspect(target_model).attrs.get(prop_name) # type: ignore[union-attr]
converted_prop = self.property2field(
prop,
# To satisfy type checking, need to pass a literal bool
instance=True if instance else False,
field_class=field_class,
**kwargs,
)
if remote_with_local_multiplicity:
related_list_kwargs = _field_update_kwargs(
RelatedList, self.get_base_kwargs(), kwargs
Expand Down
55 changes: 44 additions & 11 deletions src/marshmallow_sqlalchemy/load_instance_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,59 @@
Users should not need to use this module directly.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

import marshmallow as ma
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm.exc import ObjectDeletedError

from .fields import get_primary_keys

if TYPE_CHECKING:
from sqlalchemy.orm import Session

_ModelType = TypeVar("_ModelType", bound=DeclarativeMeta)


class LoadInstanceMixin:
class Opts:
model: type[DeclarativeMeta] | None
sqla_session: Session | None
load_instance: bool
transient: bool

def __init__(self, meta, *args, **kwargs):
super().__init__(meta, *args, **kwargs)
self.model = getattr(meta, "model", None)
self.sqla_session = getattr(meta, "sqla_session", None)
self.load_instance = getattr(meta, "load_instance", False)
self.transient = getattr(meta, "transient", False)

class Schema:
class Schema(Generic[_ModelType]):
opts: LoadInstanceMixin.Opts
instance: _ModelType | None
_session: Session | None
_transient: bool | None
_load_instance: bool

@property
def session(self):
def session(self) -> Session | None:
return self._session or self.opts.sqla_session

@session.setter
def session(self, session):
def session(self, session: Session) -> None:
self._session = session

@property
def transient(self):
def transient(self) -> bool:
if self._transient is not None:
return self._transient
return self.opts.transient

@transient.setter
def transient(self, transient):
def transient(self, transient: bool) -> None:
self._transient = transient

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -64,7 +86,7 @@ def get_instance(self, data):
return None

@ma.post_load
def make_instance(self, data, **kwargs):
def make_instance(self, data, **kwargs) -> _ModelType:
"""Deserialize data to an instance of the model if self.load_instance is True.
Update an existing row if specified in `self.instance` or loaded by primary
Expand All @@ -80,12 +102,21 @@ def make_instance(self, data, **kwargs):
setattr(instance, key, value)
return instance
kwargs, association_attrs = self._split_model_kwargs_association(data)
instance = self.opts.model(**kwargs)
ModelClass = cast(DeclarativeMeta, self.opts.model)
instance = ModelClass(**kwargs)
for attr, value in association_attrs.items():
setattr(instance, attr, value)
return instance

def load(self, data, *, session=None, instance=None, transient=False, **kwargs):
def load(
self,
data,
*,
session: Session | None = None,
instance: _ModelType | None = None,
transient: bool = False,
**kwargs,
) -> Any:
"""Deserialize data to internal representation.
:param session: Optional SQLAlchemy session.
Expand All @@ -98,15 +129,17 @@ def load(self, data, *, session=None, instance=None, transient=False, **kwargs):
raise ValueError("Deserialization requires a session")
self.instance = instance or self.instance
try:
return super().load(data, **kwargs)
return cast(ma.Schema, super()).load(data, **kwargs)
finally:
self.instance = None

def validate(self, data, *, session=None, **kwargs):
def validate(
self, data, *, session: Session | None = None, **kwargs
) -> dict[str, list[str]]:
self._session = session or self._session
if not (self.transient or self.session):
raise ValueError("Validation requires a session")
return super().validate(data, **kwargs)
return cast(ma.Schema, super()).validate(data, **kwargs)

def _split_model_kwargs_association(self, data):
"""Split serialized attrs to ensure association proxies are passed separately.
Expand Down
Loading

0 comments on commit ffe1f19

Please sign in to comment.