Skip to content

Commit 567e6e1

Browse files
committed
Working with async graphql db sessions; Tests broken
1 parent 0d6b31c commit 567e6e1

File tree

15 files changed

+593
-543
lines changed

15 files changed

+593
-543
lines changed

orchestrator/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from orchestrator.api.api_v1.api import api_router
4444
from orchestrator.api.error_handling import ProblemDetailException
4545
from orchestrator.cli.main import app as cli_app
46-
from orchestrator.db import db, init_database
46+
from orchestrator.db import db, init_async_database, init_database
4747
from orchestrator.db.database import DBSessionMiddleware
4848
from orchestrator.db.listeners import monitor_sqlalchemy_queries
4949
from orchestrator.db.loaders import init_model_loaders
@@ -128,6 +128,7 @@ def __init__(
128128
self.include_router(api_router, prefix="/api")
129129

130130
init_database(base_settings)
131+
init_async_database(base_settings)
131132

132133
self.add_middleware(ClearStructlogContextASGIMiddleware)
133134
self.add_middleware(SessionMiddleware, secret_key=base_settings.SESSION_SECRET)

orchestrator/db/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
from structlog import get_logger
1616

17+
from orchestrator.db.database import AsyncDatabase, Database, transactional
1718
from orchestrator.db.database import BaseModel as DbBaseModel
18-
from orchestrator.db.database import Database, transactional
1919
from orchestrator.db.models import ( # noqa: F401
2020
EngineSettingsTable,
2121
FixedInputTable,
@@ -42,15 +42,15 @@
4242

4343

4444
class WrappedDatabase:
45-
def __init__(self, wrappee: Database | None = None) -> None:
45+
def __init__(self, wrappee: Database | AsyncDatabase | None = None) -> None:
4646
self.wrapped_database = wrappee
4747

48-
def update(self, wrappee: Database) -> None:
48+
def update(self, wrappee: Database | AsyncDatabase) -> None:
4949
self.wrapped_database = wrappee
5050
logger.info("Database object configured, all methods referencing `db` should work.")
5151

5252
def __getattr__(self, attr: str) -> Any:
53-
if not isinstance(self.wrapped_database, Database):
53+
if not isinstance(self.wrapped_database, Database | AsyncDatabase):
5454
if "_" in attr:
5555
logger.warning("No database configured, but attempting to access class methods")
5656
return None
@@ -72,6 +72,15 @@ def init_database(settings: AppSettings) -> Database:
7272
return db
7373

7474

75+
async_wrapped_db = WrappedDatabase()
76+
async_db = cast(AsyncDatabase, async_wrapped_db)
77+
78+
79+
def init_async_database(settings: AppSettings) -> AsyncDatabase:
80+
async_wrapped_db.update(AsyncDatabase(str(settings.DATABASE_URI)))
81+
return async_db
82+
83+
7584
__all__ = [
7685
"transactional",
7786
"SubscriptionTable",
@@ -94,6 +103,8 @@ def init_database(settings: AppSettings) -> Database:
94103
"UtcTimestampError",
95104
"db",
96105
"init_database",
106+
"async_db",
107+
"init_async_database",
97108
]
98109

99110
ALL_DB_MODELS: list[type[DbBaseModel]] = [

orchestrator/db/database.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
# limitations under the License.
1313

1414
from collections.abc import Callable, Generator, Iterator
15-
from contextlib import contextmanager
15+
from contextlib import asynccontextmanager, contextmanager
1616
from contextvars import ContextVar
17-
from typing import Any, ClassVar, cast
17+
from typing import Any, AsyncIterator, ClassVar, Mapping, cast
1818
from uuid import uuid4
1919

2020
import structlog
21+
from pydantic import PostgresDsn
2122
from sqlalchemy import create_engine
2223
from sqlalchemy import inspect as sa_inspect
24+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
2325
from sqlalchemy.ext.declarative import DeclarativeMeta
2426
from sqlalchemy.orm import Query, Session, as_declarative, scoped_session, sessionmaker
2527
from sqlalchemy.sql.schema import MetaData
@@ -171,6 +173,7 @@ def commit(self) -> None:
171173
"json_deserializer": json_loads,
172174
}
173175
SESSION_ARGUMENTS = {"class_": WrappedSession, "autocommit": False, "autoflush": True, "query_cls": SearchQuery}
176+
ASYNC_SESSION_ARGUMENTS = {"class_": AsyncSession, "autocommit": False, "autoflush": True, "query_cls": SearchQuery}
174177

175178

176179
class Database:
@@ -220,10 +223,23 @@ def database_scope(self, **kwargs: Any) -> Generator["Database", None, None]:
220223
self.request_context.reset(token)
221224

222225

226+
class AsyncDatabase:
227+
def __init__(self, db_url: PostgresDsn | str):
228+
self.engine = create_async_engine(str(db_url), **ENGINE_ARGUMENTS)
229+
self.session_factory = async_sessionmaker(bind=self.engine, expire_on_commit=False, **ASYNC_SESSION_ARGUMENTS) # type: ignore[call-overload]
230+
231+
@asynccontextmanager
232+
async def session(self, **kwargs: Mapping) -> AsyncIterator[AsyncSession]:
233+
async with self.session_factory() as session:
234+
yield session
235+
236+
async def dispose(self) -> None:
237+
await self.engine.dispose()
238+
239+
223240
class DBSessionMiddleware(BaseHTTPMiddleware):
224241
def __init__(self, app: ASGIApp, database: Database, commit_on_exit: bool = False):
225242
super().__init__(app)
226-
self.commit_on_exit = commit_on_exit
227243
self.database = database
228244

229245
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:

orchestrator/graphql/resolvers/process.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlalchemy import func, select
1818
from sqlalchemy.orm import selectinload
1919

20-
from orchestrator.db import ProcessTable, db
20+
from orchestrator.db import ProcessTable
2121
from orchestrator.db.filters import Filter
2222
from orchestrator.db.filters.process import PROCESS_TABLE_COLUMN_CLAUSES, filter_processes, process_filter_fields
2323
from orchestrator.db.models import ProcessSubscriptionTable, SubscriptionTable
@@ -42,7 +42,6 @@
4242

4343
logger = structlog.get_logger(__name__)
4444

45-
4645
detailed_props = ("steps", "form", "current_state")
4746
simple_props = tuple([to_lower_camel(key) for key in ProcessType.__annotations__ if key not in detailed_props])
4847

@@ -56,9 +55,10 @@ def _enrich_process(process: ProcessTable, with_details: bool = False) -> Proces
5655

5756

5857
async def resolve_process(info: OrchestratorInfo, process_id: UUID) -> ProcessType | None:
58+
session = info.context.db_session
5959
query_loaders = get_query_loaders_for_gql_fields(ProcessTable, info)
6060
stmt = select(ProcessTable).options(*query_loaders).where(ProcessTable.process_id == process_id)
61-
if process := db.session.scalar(stmt):
61+
if process := await session.scalar(stmt):
6262
is_detailed = _is_process_detailed(info)
6363
return ProcessType.from_pydantic(_enrich_process(process, is_detailed))
6464
return None
@@ -72,6 +72,7 @@ async def resolve_processes(
7272
after: int = 0,
7373
query: str | None = None,
7474
) -> Connection[ProcessType]:
75+
session = info.context.db_session
7576
_error_handler = create_resolver_error_handler(info)
7677
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
7778
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
@@ -98,7 +99,7 @@ async def resolve_processes(
9899
stmt = select_stmt
99100

100101
stmt = sort_processes(stmt, pydantic_sort_by, _error_handler)
101-
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
102+
total = await session.scalar(select(func.count()).select_from(stmt.subquery()))
102103
stmt = apply_range_to_statement(stmt, after, after + first + 1)
103104

104105
graphql_processes = []

orchestrator/graphql/resolvers/product.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import structlog
22
from sqlalchemy import func, select
33

4-
from orchestrator.db import db
54
from orchestrator.db.filters import Filter
65
from orchestrator.db.filters.product import PRODUCT_TABLE_COLUMN_CLAUSES, filter_products, product_filter_fields
76
from orchestrator.db.models import ProductTable
@@ -28,6 +27,7 @@ async def resolve_products(
2827
query: str | None = None,
2928
) -> Connection[ProductType]:
3029
_error_handler = create_resolver_error_handler(info)
30+
session = info.context.db_session
3131

3232
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
3333
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
@@ -49,7 +49,7 @@ async def resolve_products(
4949
stmt = select_stmt
5050

5151
stmt = sort_products(stmt, pydantic_sort_by, _error_handler)
52-
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
52+
total = await session.scalar(select(func.count()).select_from(stmt.subquery()))
5353
stmt = apply_range_to_statement(stmt, after, after + first + 1)
5454

5555
graphql_products = []

orchestrator/graphql/resolvers/product_block.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import structlog
22
from sqlalchemy import func, select
33

4-
from orchestrator.db import db
54
from orchestrator.db.filters import Filter
65
from orchestrator.db.filters.product_block import (
76
PRODUCT_BLOCK_TABLE_COLUMN_CLAUSES,
@@ -32,6 +31,7 @@ async def resolve_product_blocks(
3231
query: str | None = None,
3332
) -> Connection[ProductBlock]:
3433
_error_handler = create_resolver_error_handler(info)
34+
session = info.context.db_session
3535

3636
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
3737
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
@@ -55,7 +55,7 @@ async def resolve_product_blocks(
5555
stmt = select_stmt
5656

5757
stmt = sort_product_blocks(stmt, pydantic_sort_by, _error_handler)
58-
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
58+
total = await session.scalar(select(func.count()).select_from(stmt.subquery()))
5959
stmt = apply_range_to_statement(stmt, after, after + first + 1)
6060

6161
graphql_product_blocks = []

orchestrator/graphql/resolvers/resource_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import structlog
22
from sqlalchemy import func, select
33

4-
from orchestrator.db import db
54
from orchestrator.db.filters import Filter
65
from orchestrator.db.filters.resource_type import (
76
RESOURCE_TYPE_TABLE_COLUMN_CLAUSES,
@@ -32,6 +31,7 @@ async def resolve_resource_types(
3231
query: str | None = None,
3332
) -> Connection[ResourceType]:
3433
_error_handler = create_resolver_error_handler(info)
34+
session = info.context.db_session
3535

3636
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
3737
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
@@ -54,7 +54,7 @@ async def resolve_resource_types(
5454
stmt = select_stmt
5555

5656
stmt = sort_resource_types(stmt, pydantic_sort_by, _error_handler)
57-
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
57+
total = await session.scalar(select(func.count()).select_from(stmt.subquery()))
5858
stmt = apply_range_to_statement(stmt, after, after + first + 1)
5959

6060
graphql_resource_types = []

orchestrator/graphql/resolvers/subscription.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic
2222

2323
from nwastdlib.asyncio import gather_nice
24-
from orchestrator.db import ProductTable, SubscriptionTable, db
24+
from orchestrator.db import ProductTable, SubscriptionTable
2525
from orchestrator.db.filters import Filter
2626
from orchestrator.db.filters.subscription import (
2727
filter_by_query_string,
@@ -99,9 +99,10 @@ async def format_subscription(info: OrchestratorInfo, subscription: Subscription
9999

100100

101101
async def resolve_subscription(info: OrchestratorInfo, id: UUID) -> SubscriptionInterface | None:
102+
session = info.context.db_session
102103
stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == id)
103104

104-
if subscription := db.session.scalar(stmt):
105+
if subscription := await session.scalar(stmt):
105106
return await format_subscription(info, subscription)
106107
return None
107108

@@ -115,6 +116,7 @@ async def resolve_subscriptions(
115116
query: str | None = None,
116117
) -> Connection[SubscriptionInterface]:
117118
_error_handler = create_resolver_error_handler(info)
119+
session = info.context.db_session
118120

119121
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
120122
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
@@ -141,12 +143,13 @@ async def resolve_subscriptions(
141143
stmt = filter_by_query_string(stmt, query)
142144

143145
stmt = cast(Select, sort_subscriptions(stmt, pydantic_sort_by, _error_handler))
144-
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
146+
total = await session.scalar(select(func.count()).select_from(stmt.subquery()))
145147
stmt = apply_range_to_statement(stmt, after, after + first + 1)
146148

147149
graphql_subscriptions: list[SubscriptionInterface] = []
148150
if is_querying_page_data(info):
149-
subscriptions = db.session.scalars(stmt).all()
151+
query_result = await session.scalars(stmt)
152+
subscriptions = query_result.all()
150153
graphql_subscriptions = list(await gather_nice((format_subscription(info, p) for p in subscriptions))) # type: ignore
151154
logger.info("Resolve subscriptions", filter_by=filter_by, total=total)
152155

orchestrator/graphql/resolvers/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import structlog
22
from sqlalchemy import func, select
33

4-
from orchestrator.db import db
54
from orchestrator.db.filters import Filter
65
from orchestrator.db.filters.workflow import WORKFLOW_TABLE_COLUMN_CLAUSES, filter_workflows, workflow_filter_fields
76
from orchestrator.db.models import WorkflowTable
@@ -28,6 +27,7 @@ async def resolve_workflows(
2827
query: str | None = None,
2928
) -> Connection[Workflow]:
3029
_error_handler = create_resolver_error_handler(info)
30+
session = info.context.db_session
3131

3232
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
3333
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
@@ -49,7 +49,7 @@ async def resolve_workflows(
4949
stmt = select_stmt
5050

5151
stmt = sort_workflows(stmt, pydantic_sort_by, _error_handler)
52-
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
52+
total = await session.scalar(select(func.count()).select_from(stmt.subquery()))
5353

5454
stmt = apply_range_to_statement(stmt, after, after + first + 1)
5555

orchestrator/graphql/schema.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import Callable, Iterable
1414
from http import HTTPStatus
1515
from pathlib import Path
16-
from typing import Any, Coroutine, Protocol
16+
from typing import Any, AsyncIterator, Coroutine, Protocol
1717

1818
import strawberry
1919
import structlog
@@ -30,6 +30,7 @@
3030
from nwastdlib.graphql.extensions.error_handler_extension import ErrorHandlerExtension, ErrorType
3131
from oauth2_lib.fastapi import AuthManager
3232
from oauth2_lib.strawberry import authenticated_field
33+
from orchestrator.db import async_db
3334
from orchestrator.domain.base import SubscriptionModel
3435
from orchestrator.graphql.autoregistration import create_subscription_strawberry_type, register_domain_models
3536
from orchestrator.graphql.extensions.model_cache import ModelCacheExtension
@@ -156,11 +157,15 @@ def default_context_getter(
156157
auth_manager: AuthManager,
157158
graphql_models: StrawberryModelType,
158159
broadcast_thread: ProcessDataBroadcastThread | None = None,
159-
) -> Callable[[], Coroutine[Any, Any, OrchestratorContext]]:
160-
async def context_getter() -> OrchestratorContext:
161-
return OrchestratorContext(
162-
auth_manager=auth_manager, graphql_models=graphql_models, broadcast_thread=broadcast_thread
163-
)
160+
) -> Callable[[], AsyncIterator[OrchestratorContext]]:
161+
async def context_getter() -> AsyncIterator[OrchestratorContext]:
162+
async with async_db.session() as session:
163+
yield OrchestratorContext(
164+
auth_manager=auth_manager,
165+
graphql_models=graphql_models,
166+
broadcast_thread=broadcast_thread,
167+
db_session=session,
168+
)
164169

165170
return context_getter
166171

0 commit comments

Comments
 (0)