From b197f6d01a4ec53e8c13d500c210e061e640dd72 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 10 Dec 2024 08:54:47 -0800 Subject: [PATCH 1/2] Add support for nested action scopes --- chia/_tests/util/test_action_scope.py | 22 ++++++++++++++---- chia/util/action_scope.py | 33 ++++++++++++++++++--------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/chia/_tests/util/test_action_scope.py b/chia/_tests/util/test_action_scope.py index 789c34f503d0..66f4deeed0fa 100644 --- a/chia/_tests/util/test_action_scope.py +++ b/chia/_tests/util/test_action_scope.py @@ -135,10 +135,22 @@ async def callback2(interface: StateInterface[TestSideEffects]) -> None: raise RuntimeError("This should prevent the callbacks from being called") -# TODO: add support, change this test to test it and add a test for nested transactionality @pytest.mark.anyio -async def test_nested_use_banned(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None: - async with action_scope.use(): - with pytest.raises(RuntimeError, match="cannot currently support nested transactions"): +async def test_nested_use(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None: + async with action_scope.use() as interface: + with pytest.raises(RuntimeError, match="Must pass `current_interface` when doing nested transactions"): async with action_scope.use(): - pass + raise NotImplementedError("Should not get here") # pragma: no cover + + assert interface.side_effects.buf == b"" + async with action_scope.use(interface) as nested_interface: + nested_interface.side_effects.buf = b"foo" + + assert interface.side_effects.buf == b"foo" + + with pytest.raises(RuntimeError, match="deliberate raise"): + async with action_scope.use(interface) as nested_interface: + nested_interface.side_effects.buf = b"bar" + raise RuntimeError("deliberate raise") + + assert interface.side_effects.buf == b"foo" diff --git a/chia/util/action_scope.py b/chia/util/action_scope.py index 5d48fb4ee964..612edfd486b0 100644 --- a/chia/util/action_scope.py +++ b/chia/util/action_scope.py @@ -3,7 +3,7 @@ import contextlib from collections.abc import AsyncIterator, Awaitable from dataclasses import dataclass, field -from typing import Callable, Generic, Optional, Protocol, TypeVar, final +from typing import Any, Callable, Generic, Optional, Protocol, TypeVar, final import aiosqlite @@ -18,7 +18,9 @@ async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceM yield # type: ignore[misc] @contextlib.asynccontextmanager - async def use(self) -> AsyncIterator[None]: # pragma: no cover + async def use( + self, current_interface: Optional[StateInterface[Any]] = None + ) -> AsyncIterator[None]: # pragma: no cover # yield included to make this a generator as expected by @contextlib.asynccontextmanager yield @@ -52,18 +54,25 @@ async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceM yield self @contextlib.asynccontextmanager - async def use(self) -> AsyncIterator[None]: - if self._active_writer is not None: - raise RuntimeError("SQLiteResourceManager cannot currently support nested transactions") + async def use(self, current_interface: Optional[StateInterface[Any]] = None) -> AsyncIterator[None]: async with self._db.writer() as conn: + if self._active_writer is not None and current_interface is None: + raise RuntimeError("Must pass `current_interface` when doing nested transactions") + previous_writer = self._active_writer self._active_writer = conn try: yield finally: - self._active_writer = None - - async def get_resource(self, resource_type: type[_T_SideEffects]) -> _T_SideEffects: - row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects") + self._active_writer = previous_writer + if previous_writer is not None and current_interface is not None: + current_interface.side_effects = await self.get_resource(type(current_interface.side_effects), conn) + + async def get_resource( + self, resource_type: type[_T_SideEffects], _active_writer: Optional[aiosqlite.Connection] = None + ) -> _T_SideEffects: + row = await execute_fetchone( + self.get_active_writer() if _active_writer is None else _active_writer, "SELECT total FROM side_effects" + ) assert row is not None side_effects = resource_type.from_bytes(row[0]) return side_effects @@ -140,8 +149,10 @@ async def new_scope( self._final_side_effects = interface.side_effects @contextlib.asynccontextmanager - async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInterface[_T_SideEffects]]: - async with self._resource_manager.use(): + async def use( + self, current_interface: Optional[StateInterface[_T_SideEffects]] = None, _callbacks_allowed: bool = True + ) -> AsyncIterator[StateInterface[_T_SideEffects]]: + async with self._resource_manager.use(current_interface): side_effects = await self._resource_manager.get_resource(self._side_effects_format) interface = StateInterface(side_effects, _callbacks_allowed, self._callback) From 216fc44de29745f491385881b01993dce8ed494a Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 10 Dec 2024 10:06:06 -0800 Subject: [PATCH 2/2] Make sure parent edits are propogated to children --- chia/_tests/util/test_action_scope.py | 2 ++ chia/util/action_scope.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/chia/_tests/util/test_action_scope.py b/chia/_tests/util/test_action_scope.py index 66f4deeed0fa..99be79ca5c5d 100644 --- a/chia/_tests/util/test_action_scope.py +++ b/chia/_tests/util/test_action_scope.py @@ -143,7 +143,9 @@ async def test_nested_use(action_scope: ActionScope[TestSideEffects, TestConfig] raise NotImplementedError("Should not get here") # pragma: no cover assert interface.side_effects.buf == b"" + interface.side_effects.buf = b"qat" async with action_scope.use(interface) as nested_interface: + assert nested_interface.side_effects.buf == b"qat" nested_interface.side_effects.buf = b"foo" assert interface.side_effects.buf == b"foo" diff --git a/chia/util/action_scope.py b/chia/util/action_scope.py index 612edfd486b0..398dcab24225 100644 --- a/chia/util/action_scope.py +++ b/chia/util/action_scope.py @@ -58,6 +58,8 @@ async def use(self, current_interface: Optional[StateInterface[Any]] = None) -> async with self._db.writer() as conn: if self._active_writer is not None and current_interface is None: raise RuntimeError("Must pass `current_interface` when doing nested transactions") + elif current_interface is not None: + await self.save_resource(current_interface.side_effects) previous_writer = self._active_writer self._active_writer = conn try: