Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHIA-2041] Add support for nested action scopes #19013

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions chia/_tests/util/test_action_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,24 @@ async def callback2(interface: StateInterface[TestSideEffects]) -> None:
raise RuntimeError("This should prevent the callbacks from being called")


# TODO: add support, change this test to test it and add a test for nested transactionality
@pytest.mark.anyio
async def test_nested_use_banned(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None:
async with action_scope.use():
with pytest.raises(RuntimeError, match="cannot currently support nested transactions"):
async def test_nested_use(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None:
async with action_scope.use() as interface:
with pytest.raises(RuntimeError, match="Must pass `current_interface` when doing nested transactions"):
async with action_scope.use():
pass
raise NotImplementedError("Should not get here") # pragma: no cover

assert interface.side_effects.buf == b""
interface.side_effects.buf = b"qat"
async with action_scope.use(interface) as nested_interface:
assert nested_interface.side_effects.buf == b"qat"
nested_interface.side_effects.buf = b"foo"

assert interface.side_effects.buf == b"foo"

with pytest.raises(RuntimeError, match="deliberate raise"):
async with action_scope.use(interface) as nested_interface:
nested_interface.side_effects.buf = b"bar"
raise RuntimeError("deliberate raise")

assert interface.side_effects.buf == b"foo"
35 changes: 24 additions & 11 deletions chia/util/action_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
from collections.abc import AsyncIterator, Awaitable
from dataclasses import dataclass, field
from typing import Callable, Generic, Optional, Protocol, TypeVar, final
from typing import Any, Callable, Generic, Optional, Protocol, TypeVar, final

import aiosqlite

Expand All @@ -18,7 +18,9 @@ async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceM
yield # type: ignore[misc]

@contextlib.asynccontextmanager
async def use(self) -> AsyncIterator[None]: # pragma: no cover
async def use(
self, current_interface: Optional[StateInterface[Any]] = None
) -> AsyncIterator[None]: # pragma: no cover
# yield included to make this a generator as expected by @contextlib.asynccontextmanager
yield

Expand Down Expand Up @@ -52,18 +54,27 @@ async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceM
yield self

@contextlib.asynccontextmanager
async def use(self) -> AsyncIterator[None]:
if self._active_writer is not None:
raise RuntimeError("SQLiteResourceManager cannot currently support nested transactions")
async def use(self, current_interface: Optional[StateInterface[Any]] = None) -> AsyncIterator[None]:
async with self._db.writer() as conn:
if self._active_writer is not None and current_interface is None:
raise RuntimeError("Must pass `current_interface` when doing nested transactions")
elif current_interface is not None:
await self.save_resource(current_interface.side_effects)
previous_writer = self._active_writer
self._active_writer = conn
try:
yield
finally:
self._active_writer = None

async def get_resource(self, resource_type: type[_T_SideEffects]) -> _T_SideEffects:
row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects")
self._active_writer = previous_writer
if previous_writer is not None and current_interface is not None:
current_interface.side_effects = await self.get_resource(type(current_interface.side_effects), conn)

async def get_resource(
self, resource_type: type[_T_SideEffects], _active_writer: Optional[aiosqlite.Connection] = None
) -> _T_SideEffects:
row = await execute_fetchone(
self.get_active_writer() if _active_writer is None else _active_writer, "SELECT total FROM side_effects"
)
assert row is not None
side_effects = resource_type.from_bytes(row[0])
return side_effects
Expand Down Expand Up @@ -140,8 +151,10 @@ async def new_scope(
self._final_side_effects = interface.side_effects

@contextlib.asynccontextmanager
async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInterface[_T_SideEffects]]:
async with self._resource_manager.use():
async def use(
self, current_interface: Optional[StateInterface[_T_SideEffects]] = None, _callbacks_allowed: bool = True
) -> AsyncIterator[StateInterface[_T_SideEffects]]:
async with self._resource_manager.use(current_interface):
side_effects = await self._resource_manager.get_resource(self._side_effects_format)
interface = StateInterface(side_effects, _callbacks_allowed, self._callback)

Expand Down
Loading