feat: optimize type adapter call & multi-config fixes (#229)
* feat: optimize type adapter call

* feat: attempt to reuse TypeAdapters

* feat: ensure additional keys are unique

* feat: ensure additional keys are unique

* fix: reduce duplication

* feat: remove unused placeholder
cofin committed Jul 8, 2024
1 parent 045997a commit 9463ae2
Showing 10 changed files with 245 additions and 172 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -18,12 +18,12 @@ repos:
       - id: mixed-line-ending
       - id: trailing-whitespace
   - repo: https://github.com/provinzkraut/unasyncd
-    rev: "v0.7.2"
+    rev: "v0.7.3"
     hooks:
       - id: unasyncd
         additional_dependencies: ["ruff"]
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: "v0.5.0"
+    rev: "v0.5.1"
     hooks:
       - id: ruff
         args: ["--fix"]
8 changes: 6 additions & 2 deletions advanced_alchemy/config/common.py
@@ -130,8 +130,12 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
     This is a listener that will update ``created_at`` and ``updated_at`` columns on record modification.
     Disable if you plan to bring your own update mechanism for these columns"""
-    _KEY_REGISTRY: ClassVar[set[str]] = field(init=False, default=cast("set[str]", set()))
-    """Internal counter for ensuring unique identification of the class."""
+    _SESSION_SCOPE_KEY_REGISTRY: ClassVar[set[str]] = field(init=False, default=cast("set[str]", set()))
+    """Internal counter for ensuring unique identification of session scope keys in the class."""
+    _ENGINE_APP_STATE_KEY_REGISTRY: ClassVar[set[str]] = field(init=False, default=cast("set[str]", set()))
+    """Internal counter for ensuring unique identification of engine app state keys in the class."""
+    _SESSIONMAKER_APP_STATE_KEY_REGISTRY: ClassVar[set[str]] = field(init=False, default=cast("set[str]", set()))
+    """Internal counter for ensuring unique identification of sessionmaker state keys in the class."""

     def __post_init__(self) -> None:
         if self.connection_string is not None and self.engine_instance is not None:
advanced_alchemy/extensions/litestar/plugins/init/config/asyncio.py
@@ -159,15 +159,23 @@ class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
     The configuration options are documented in the SQLAlchemy documentation.
     """

-    def _ensure_unique_session_scope_key(self, key: str, _iter: int = 0) -> str:
-        if key in self.__class__._KEY_REGISTRY:  # noqa: SLF001
+    def _ensure_unique(self, registry_name: str, key: str, new_key: str | None = None, _iter: int = 0) -> str:
+        new_key = new_key if new_key is not None else key
+        if new_key in getattr(self.__class__, registry_name, {}):
             _iter += 1
-            key = self._ensure_unique_session_scope_key(f"{key}_{_iter}", _iter)
-        return key
+            new_key = self._ensure_unique(registry_name, key, f"{key}_{_iter}", _iter)
+        return new_key

     def __post_init__(self) -> None:
-        self.session_scope_key = self._ensure_unique_session_scope_key(self.session_scope_key)
-        self.__class__._KEY_REGISTRY.add(self.session_scope_key)  # noqa: SLF001
+        self.session_scope_key = self._ensure_unique("_SESSION_SCOPE_KEY_REGISTRY", self.session_scope_key)
+        self.engine_app_state_key = self._ensure_unique("_ENGINE_APP_STATE_KEY_REGISTRY", self.engine_app_state_key)
+        self.session_maker_app_state_key = self._ensure_unique(
+            "_SESSIONMAKER_APP_STATE_KEY_REGISTRY",
+            self.session_maker_app_state_key,
+        )
+        self.__class__._SESSION_SCOPE_KEY_REGISTRY.add(self.session_scope_key)  # noqa: SLF001
+        self.__class__._ENGINE_APP_STATE_KEY_REGISTRY.add(self.engine_app_state_key)  # noqa: SLF001
+        self.__class__._SESSIONMAKER_APP_STATE_KEY_REGISTRY.add(self.session_maker_app_state_key)  # noqa: SLF001
         if self.before_send_handler is None:
             self.before_send_handler = default_handler_maker(session_scope_key=self.session_scope_key)
         if self.before_send_handler == "autocommit":
@@ -258,3 +266,11 @@ def create_app_state_items(self) -> dict[str, Any]:
             self.engine_app_state_key: self.get_engine(),
             self.session_maker_app_state_key: self.create_session_maker(),
         }
+
+    def update_app_state(self, app: Litestar) -> None:
+        """Set the app state with engine and session.
+
+        Args:
+            app: The ``Litestar`` instance.
+        """
+        app.state.update(self.create_app_state_items())
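For illustration, a standalone sketch of the key-suffixing logic introduced above (illustration only; the real method is defined on the config classes and consults the class-level registries):

from __future__ import annotations


def ensure_unique(registry: set[str], key: str, new_key: str | None = None, _iter: int = 0) -> str:
    """Return ``key``, suffixed with a counter if it is already registered."""
    new_key = new_key if new_key is not None else key
    if new_key in registry:
        _iter += 1
        new_key = ensure_unique(registry, key, f"{key}_{_iter}", _iter)
    return new_key


registry: set[str] = set()
for _ in range(3):
    registry.add(ensure_unique(registry, "db_engine"))
print(sorted(registry))  # ['db_engine', 'db_engine_1', 'db_engine_2']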
20 changes: 14 additions & 6 deletions advanced_alchemy/extensions/litestar/plugins/init/config/sync.py
@@ -156,15 +156,23 @@ class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
     The configuration options are documented in the SQLAlchemy documentation.
     """

-    def _ensure_unique_session_scope_key(self, key: str, _iter: int = 0) -> str:
-        if key in self.__class__._KEY_REGISTRY:  # noqa: SLF001
+    def _ensure_unique(self, registry_name: str, key: str, new_key: str | None = None, _iter: int = 0) -> str:
+        new_key = new_key if new_key is not None else key
+        if new_key in getattr(self.__class__, registry_name, {}):
             _iter += 1
-            key = self._ensure_unique_session_scope_key(f"{key}_{_iter}", _iter)
-        return key
+            new_key = self._ensure_unique(registry_name, key, f"{key}_{_iter}", _iter)
+        return new_key

     def __post_init__(self) -> None:
-        self.session_scope_key = self._ensure_unique_session_scope_key(self.session_scope_key)
-        self.__class__._KEY_REGISTRY.add(self.session_scope_key)  # noqa: SLF001
+        self.session_scope_key = self._ensure_unique("_SESSION_SCOPE_KEY_REGISTRY", self.session_scope_key)
+        self.engine_app_state_key = self._ensure_unique("_ENGINE_APP_STATE_KEY_REGISTRY", self.engine_app_state_key)
+        self.session_maker_app_state_key = self._ensure_unique(
+            "_SESSIONMAKER_APP_STATE_KEY_REGISTRY",
+            self.session_maker_app_state_key,
+        )
+        self.__class__._SESSION_SCOPE_KEY_REGISTRY.add(self.session_scope_key)  # noqa: SLF001
+        self.__class__._ENGINE_APP_STATE_KEY_REGISTRY.add(self.engine_app_state_key)  # noqa: SLF001
+        self.__class__._SESSIONMAKER_APP_STATE_KEY_REGISTRY.add(self.session_maker_app_state_key)  # noqa: SLF001
         if self.before_send_handler is None:
             self.before_send_handler = default_handler_maker(session_scope_key=self.session_scope_key)
         if self.before_send_handler == "autocommit":
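A hedged usage sketch of the net effect for multiple configurations; the import path and attribute names are taken from this commit's litestar extension, and the connection strings are made up:

from advanced_alchemy.extensions.litestar import SQLAlchemySyncConfig

first = SQLAlchemySyncConfig(connection_string="sqlite:///first.sqlite")
second = SQLAlchemySyncConfig(connection_string="sqlite:///second.sqlite")

# The second config's keys are suffixed instead of colliding with the first
# config's entries in the application state.
assert first.session_scope_key != second.session_scope_key
assert first.engine_app_state_key != second.engine_app_state_key
assert first.session_maker_app_state_key != second.session_maker_app_state_key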
7 changes: 3 additions & 4 deletions advanced_alchemy/filters.py
@@ -275,10 +275,9 @@ def append_to_lambda_statement(
         model: type[ModelT],
     ) -> StatementLambdaElement:
         field = self._get_instrumented_attr(model, self.field_name)
-        if self.sort_order == "desc":
-            statement += lambda s: s.order_by(field.desc())  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
-            return statement
-        statement += lambda s: s.order_by(field.asc())  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
+        fragment = field.desc() if self.sort_order == "desc" else field.asc()
+        statement += lambda s: s.order_by(fragment)  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]

         return statement

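A small self-contained sketch of the pattern the refactor adopts: compute the ordering clause once with a conditional expression, then apply it in a single place. The model is made up, and a plain select() stands in for the lambda statement the real filter appends to:

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


def ordered_select(sort_order: str):
    # Build the asc()/desc() fragment once, outside the statement construction.
    field = User.name
    fragment = field.desc() if sort_order == "desc" else field.asc()
    return select(User).order_by(fragment)


print(ordered_select("desc"))  # ... ORDER BY user_account.name DESC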
12 changes: 8 additions & 4 deletions advanced_alchemy/service/_util.py
@@ -12,6 +12,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    List,
     Sequence,
     cast,
     overload,
@@ -28,8 +29,8 @@
     BaseModel,
     ModelDTOT,
     Struct,
-    TypeAdapter,
     convert,
+    get_type_adapter,
 )

 if TYPE_CHECKING:
@@ -193,7 +194,7 @@ def to_schema(
         return OffsetPagination[ModelDTOT](
             items=convert(
                 obj=data,
-                type=Sequence[schema_type],  # type: ignore[valid-type]
+                type=List[schema_type],  # type: ignore[valid-type]
                 from_attributes=True,
                 dec_hook=partial(
                     _default_msgspec_deserializer,
@@ -209,12 +210,15 @@ def to_schema(

     if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel):
         if not isinstance(data, Sequence):
-            return cast("ModelDTOT", TypeAdapter(schema_type).validate_python(data, from_attributes=True))  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue]
+            return cast(
+                "ModelDTOT",
+                get_type_adapter(schema_type).validate_python(data, from_attributes=True),
+            )  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue]
         limit_offset = find_filter(LimitOffset, filters=filters)
         total = total if total else len(data)
         limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
         return OffsetPagination[ModelDTOT](
-            items=TypeAdapter(Sequence[schema_type]).validate_python(data, from_attributes=True),  # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
+            items=get_type_adapter(List[schema_type]).validate_python(data, from_attributes=True),  # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
             limit=limit_offset.limit,
             offset=limit_offset.offset,
             total=total,
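The msgspec conversion target switches from Sequence[...] to List[...]; a minimal sketch of that call shape, with a made-up schema:

from typing import List

import msgspec


class UserSchema(msgspec.Struct):
    id: int
    name: str


rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
users = msgspec.convert(rows, type=List[UserSchema])
print(users[0].name)  # "a"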
37 changes: 34 additions & 3 deletions advanced_alchemy/service/typing.py
@@ -6,6 +6,7 @@

 from __future__ import annotations

+from functools import lru_cache
 from typing import (
     Any,
     ClassVar,
@@ -20,11 +21,12 @@
     runtime_checkable,
 )

-from typing_extensions import TypeAlias, TypeGuard
+from typing_extensions import Annotated, TypeAlias, TypeGuard

 from advanced_alchemy.filters import StatementFilter  # noqa: TCH001
 from advanced_alchemy.repository.typing import ModelT

+T = TypeVar("T")  # pragma: nocover
 try:
     from pydantic import BaseModel  # pyright: ignore[reportAssignmentType]
     from pydantic.type_adapter import TypeAdapter  # pyright: ignore[reportUnusedImport, reportAssignmentType]
@@ -42,8 +44,6 @@ def model_dump(*args: Any, **kwargs: Any) -> dict[str, Any]:
         """Placeholder"""
         return {}

-    T = TypeVar("T")  # pragma: nocover
-
     class TypeAdapter(Generic[T]):  # type: ignore[no-redef] # pragma: nocover
         """Placeholder Implementation"""

@@ -56,6 +56,34 @@ def validate_python(self, data: Any, *args: Any, **kwargs: Any) -> T:  # pragma:

     PYDANTIC_INSTALLED: Final[bool] = False  # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003

+try:
+    # this is from pydantic 2.8. We should check for it before using it.
+    from pydantic import FailFast  # pyright: ignore[reportAssignmentType]
+
+    PYDANTIC_USE_FAILFAST: Final[bool] = True
+except ImportError:
+
+    class FailFast:  # type: ignore[no-redef] # pragma: nocover
+        """Placeholder Implementation for FailFast"""
+
+        def __init__(self, *args: Any, **kwargs: Any) -> None:  # pragma: nocover
+            """Init"""
+
+        def __call__(self, *args: Any, **kwargs: Any) -> None:  # pragma: nocover
+            """Placeholder"""
+
+    PYDANTIC_USE_FAILFAST: Final[bool] = False  # type: ignore # pyright: ignore[reportConstantRedefinition,reportGeneralTypeIssues] # noqa: PGH003
+
+
+@lru_cache(typed=True)
+def get_type_adapter(f: type[T]) -> TypeAdapter[T]:
+    """Caches and returns a pydantic type adapter"""
+    if PYDANTIC_USE_FAILFAST:
+        return TypeAdapter(
+            Annotated[f, FailFast()],  # type: ignore[operator]
+        )
+    return TypeAdapter(f)
+
+
 try:
     from msgspec import UNSET, Struct, UnsetType, convert  # pyright: ignore[reportAssignmentType,reportUnusedImport]
@@ -147,8 +175,11 @@ def schema_dump(
     "PydanticOrMsgspecT",
     "PYDANTIC_INSTALLED",
     "MSGSPEC_INSTALLED",
+    "PYDANTIC_USE_FAILFAST",
     "BaseModel",
     "TypeAdapter",
+    "get_type_adapter",
+    "FailFast",
     "Struct",
     "convert",
     "UNSET",
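A hedged sketch of the cached adapter in use: repeated lookups for the same type return the one lru_cache'd TypeAdapter instead of rebuilding it. The schema is made up; the import path is the module this commit adds the helper to:

from typing import List

from pydantic import BaseModel

from advanced_alchemy.service.typing import get_type_adapter


class UserSchema(BaseModel):
    id: int
    name: str


adapter = get_type_adapter(List[UserSchema])
assert adapter is get_type_adapter(List[UserSchema])  # cache hit, no rebuild
users = adapter.validate_python([{"id": 1, "name": "a"}])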