|
12 | 12 | # limitations under the License. |
13 | 13 |
|
14 | 14 | from collections.abc import Callable, Generator, Iterator |
15 | | -from contextlib import contextmanager |
| 15 | +from contextlib import asynccontextmanager, contextmanager |
16 | 16 | from contextvars import ContextVar |
17 | | -from typing import Any, ClassVar, cast |
| 17 | +from typing import Any, AsyncIterator, ClassVar, Mapping, cast |
18 | 18 | from uuid import uuid4 |
19 | 19 |
|
20 | 20 | import structlog |
| 21 | +from pydantic import PostgresDsn |
21 | 22 | from sqlalchemy import create_engine |
22 | 23 | from sqlalchemy import inspect as sa_inspect |
| 24 | +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine |
23 | 25 | from sqlalchemy.ext.declarative import DeclarativeMeta |
24 | 26 | from sqlalchemy.orm import Query, Session, as_declarative, scoped_session, sessionmaker |
25 | 27 | from sqlalchemy.sql.schema import MetaData |
@@ -171,6 +173,7 @@ def commit(self) -> None: |
171 | 173 | "json_deserializer": json_loads, |
172 | 174 | } |
173 | 175 | SESSION_ARGUMENTS = {"class_": WrappedSession, "autocommit": False, "autoflush": True, "query_cls": SearchQuery} |
| 176 | +ASYNC_SESSION_ARGUMENTS = {"class_": AsyncSession, "autocommit": False, "autoflush": True, "query_cls": SearchQuery} |
174 | 177 |
|
175 | 178 |
|
176 | 179 | class Database: |
@@ -220,10 +223,23 @@ def database_scope(self, **kwargs: Any) -> Generator["Database", None, None]: |
220 | 223 | self.request_context.reset(token) |
221 | 224 |
|
222 | 225 |
|
| 226 | +class AsyncDatabase: |
| 227 | + def __init__(self, db_url: PostgresDsn | str): |
| 228 | + self.engine = create_async_engine(str(db_url), **ENGINE_ARGUMENTS) |
| 229 | + self.session_factory = async_sessionmaker(bind=self.engine, expire_on_commit=False, **ASYNC_SESSION_ARGUMENTS) # type: ignore[call-overload] |
| 230 | + |
| 231 | + @asynccontextmanager |
| 232 | + async def session(self, **kwargs: Mapping) -> AsyncIterator[AsyncSession]: |
| 233 | + async with self.session_factory() as session: |
| 234 | + yield session |
| 235 | + |
| 236 | + async def dispose(self) -> None: |
| 237 | + await self.engine.dispose() |
| 238 | + |
| 239 | + |
223 | 240 | class DBSessionMiddleware(BaseHTTPMiddleware): |
224 | 241 | def __init__(self, app: ASGIApp, database: Database, commit_on_exit: bool = False): |
225 | 242 | super().__init__(app) |
226 | | - self.commit_on_exit = commit_on_exit |
227 | 243 | self.database = database |
228 | 244 |
|
229 | 245 | async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: |
|
0 commit comments