From c0cf801b7ff40ab004c37fefe04d6b8362d7193e Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 10:57:30 +0900 Subject: [PATCH 1/7] Add `in_trio_run` and `in_trio_task` --- docs/source/reference-lowlevel.rst | 27 ++++++++++ src/trio/_core/__init__.py | 2 + src/trio/_core/_run.py | 10 +++- src/trio/_core/_tests/test_guest_mode.py | 20 ++++++++ src/trio/_core/_tests/test_instrumentation.py | 49 +++++++++++++++++++ src/trio/_core/_tests/test_run.py | 31 ++++++++++++ src/trio/lowlevel.py | 2 + 7 files changed, 140 insertions(+), 1 deletion(-) diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 46c8b4d48..c72b60606 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -56,6 +56,33 @@ Global statistics .. autoclass:: RunStatistics() +The current Trio context +------------------------ + +There are two different types of contexts in :mod:`trio`. Here are the +semantics presented as a handy table. Choose the right function for +your needs. + ++---------------------------------+-----------------------------------+------------------------------------+ +| situation | :func:`trio.lowlevel.in_trio_run` | :func:`trio.lowlevel.in_trio_task` | ++=================================+===================================+====================================+ +| inside a running async function | `True` | `True` | ++---------------------------------+-----------------------------------+------------------------------------+ +| without a running Trio loop | `False` | `False` | ++---------------------------------+-----------------------------------+------------------------------------+ +| in a guest run's host loop | `True` | `False` | ++---------------------------------+-----------------------------------+------------------------------------+ +| inside an instrument call | depends | depends | ++---------------------------------+-----------------------------------+------------------------------------+ +| :func:`trio.to_thread.run_sync` | `False` | `False` | ++---------------------------------+-----------------------------------+------------------------------------+ +| inside an abort function | `True` | `True` | ++---------------------------------+-----------------------------------+------------------------------------+ + +.. function:: in_trio_run + +.. function:: in_trio_task + The current clock ----------------- diff --git a/src/trio/_core/__init__.py b/src/trio/_core/__init__.py index fdef90292..d21aefb3e 100644 --- a/src/trio/_core/__init__.py +++ b/src/trio/_core/__init__.py @@ -45,6 +45,8 @@ current_task, current_time, current_trio_token, + in_trio_run, + in_trio_task, notify_closing, open_nursery, remove_instrument, diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 0dc3ced5d..8cd646ef0 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -2283,7 +2283,7 @@ def setup_runner( # It wouldn't be *hard* to support nested calls to run(), but I can't # think of a single good reason for it, so let's be conservative for # now: - if hasattr(GLOBAL_RUN_CONTEXT, "runner"): + if in_trio_run(): raise RuntimeError("Attempted to call run() from inside a run()") if clock is None: @@ -2952,6 +2952,14 @@ async def checkpoint_if_cancelled() -> None: task._cancel_points += 1 +def in_trio_run() -> bool: + return hasattr(GLOBAL_RUN_CONTEXT, "runner") + + +def in_trio_task() -> bool: + return hasattr(GLOBAL_RUN_CONTEXT, "task") + + if sys.platform == "win32": from ._generated_io_windows import * from ._io_windows import ( diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index b455175f4..81b7a07d8 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -264,6 +264,26 @@ async def synchronize() -> None: sniffio_library.name = None +def test_guest_mode_trio_context_detection() -> None: + def check(thing: bool) -> None: + assert thing + + assert not trio.lowlevel.in_trio_run() + assert not trio.lowlevel.in_trio_task() + + async def trio_main(in_host: InHost) -> None: + for _ in range(2): + assert trio.lowlevel.in_trio_run() + assert trio.lowlevel.in_trio_task() + + in_host(lambda: check(trio.lowlevel.in_trio_run())) + in_host(lambda: check(not trio.lowlevel.in_trio_task())) + + trivial_guest_run(trio_main) + assert not trio.lowlevel.in_trio_run() + assert not trio.lowlevel.in_trio_task() + + def test_warn_set_wakeup_fd_overwrite() -> None: assert signal.set_wakeup_fd(-1) == -1 diff --git a/src/trio/_core/_tests/test_instrumentation.py b/src/trio/_core/_tests/test_instrumentation.py index 220ac9314..7918685bf 100644 --- a/src/trio/_core/_tests/test_instrumentation.py +++ b/src/trio/_core/_tests/test_instrumentation.py @@ -266,3 +266,52 @@ async def main() -> None: assert "task_exited" not in runner.instruments _core.run(main) + + +def test_instrument_call_trio_context() -> None: + called = set() + + class Instrument(_abc.Instrument): + pass + + hooks = { + # category 1 + "after_io_wait": (True, False), + "before_io_wait": (True, False), + "before_run": (True, False), + # category 2 + "after_run": (False, False), + # category 3 + "before_task_step": (True, True), + "after_task_step": (True, True), + "task_exited": (True, True), + # category 4 + "task_scheduled": (True, None), + "task_spawned": (True, None), + } + for hook, val in hooks.items(): + + def h( + self: Instrument, + *args: object, + hook: str = hook, + val: tuple[bool | None, bool | None] = val, + ) -> None: + fail_str = f"failed in {hook}" + + if val[0] is not None: + assert _core.in_trio_run() == val[0], fail_str + if val[1] is not None: + assert _core.in_trio_task() == val[1], fail_str + called.add(hook) + + setattr(Instrument, hook, h) + + async def main() -> None: + await _core.checkpoint() + + async with _core.open_nursery() as nursery: + nursery.start_soon(_core.checkpoint) + + _core.run(main, instruments=[Instrument()]) + assert called == set(hooks) diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 75e5457d7..576b807f9 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -2855,3 +2855,34 @@ def run(self, fn: Callable[[], object]) -> object: with mock.patch("trio._core._run.copy_context", return_value=Context()): assert _count_context_run_tb_frames() == 1 + + +@restore_unraisablehook() +def test_trio_context_detection() -> None: + assert not _core.in_trio_run() + assert not _core.in_trio_task() + + def inner() -> None: + assert _core.in_trio_run() + assert _core.in_trio_task() + + def sync_inner() -> None: + assert not _core.in_trio_run() + assert not _core.in_trio_task() + + def inner_abort(_: object) -> _core.Abort: + assert _core.in_trio_run() + assert _core.in_trio_task() + return _core.Abort.SUCCEEDED + + async def main() -> None: + assert _core.in_trio_run() + assert _core.in_trio_task() + + inner() + + await to_thread_run_sync(sync_inner) + with _core.CancelScope(deadline=_core.current_time() - 1): + await _core.wait_task_rescheduled(inner_abort) + + _core.run(main) diff --git a/src/trio/lowlevel.py b/src/trio/lowlevel.py index 9e385a004..bbeab6af1 100644 --- a/src/trio/lowlevel.py +++ b/src/trio/lowlevel.py @@ -37,6 +37,8 @@ currently_ki_protected as currently_ki_protected, disable_ki_protection as disable_ki_protection, enable_ki_protection as enable_ki_protection, + in_trio_run as in_trio_run, + in_trio_task as in_trio_task, notify_closing as notify_closing, permanently_detach_coroutine_object as permanently_detach_coroutine_object, reattach_detached_coroutine_object as reattach_detached_coroutine_object, From 4463395ac530438b98d2543f1a3d55995a0ce0dc Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 11:09:05 +0900 Subject: [PATCH 2/7] Fixes for CI --- docs/source/reference-lowlevel.rst | 2 ++ newsfragments/2757.feature.rst | 1 + src/trio/_core/_run.py | 8 ++++++++ 3 files changed, 11 insertions(+) create mode 100644 newsfragments/2757.feature.rst diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index c72b60606..4e2b03b2f 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -56,6 +56,8 @@ Global statistics .. autoclass:: RunStatistics() +.. _trio_contexts: + The current Trio context ------------------------ diff --git a/newsfragments/2757.feature.rst b/newsfragments/2757.feature.rst new file mode 100644 index 000000000..317299561 --- /dev/null +++ b/newsfragments/2757.feature.rst @@ -0,0 +1 @@ +Add :func:`trio.lowlevel.in_trio_run` and :func:`trio.lowlevel.in_trio_task` and document the semantics (and differences) thereof. See :ref:`the documentation `. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 8cd646ef0..143ea3f0a 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -2953,10 +2953,18 @@ async def checkpoint_if_cancelled() -> None: def in_trio_run() -> bool: + """Check whether we are in a Trio run. + + See also :ref:`the different types of contexts `. + """ return hasattr(GLOBAL_RUN_CONTEXT, "runner") def in_trio_task() -> bool: + """Check whether we are in a Trio task. + + See also :ref:`the different types of contexts `. + """ return hasattr(GLOBAL_RUN_CONTEXT, "task") From 5e372a63dffe1f4f18e8a46cbb783cc863a952c3 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 11:11:13 +0900 Subject: [PATCH 3/7] Use `autofunction` for Sphinx to pull in the docstring --- docs/source/reference-lowlevel.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 4e2b03b2f..fcf04b80d 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -81,9 +81,9 @@ your needs. | inside an abort function | `True` | `True` | +---------------------------------+-----------------------------------+------------------------------------+ -.. function:: in_trio_run +.. autofunction:: in_trio_run -.. function:: in_trio_task +.. autofunction:: in_trio_task The current clock ----------------- From 5a40ef6bd5cc131269e6af18d4002f6b982e8ace Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 11:24:39 +0900 Subject: [PATCH 4/7] Don't have unnecessary branches in tests --- src/trio/_core/_tests/test_instrumentation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/trio/_core/_tests/test_instrumentation.py b/src/trio/_core/_tests/test_instrumentation.py index 7918685bf..af4a625e4 100644 --- a/src/trio/_core/_tests/test_instrumentation.py +++ b/src/trio/_core/_tests/test_instrumentation.py @@ -295,12 +295,11 @@ def h( self: Instrument, *args: object, hook: str = hook, - val: tuple[bool | None, bool | None] = val, + val: tuple[bool, bool | None] = val, ) -> None: fail_str = f"failed in {hook}" - if val[0] is not None: - assert _core.in_trio_run() == val[0], fail_str + assert _core.in_trio_run() == val[0], fail_str if val[1] is not None: assert _core.in_trio_task() == val[1], fail_str called.add(hook) From ac9e6b4ee2774a7a53095122dcbf3707ca594faa Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 16:23:24 +0900 Subject: [PATCH 5/7] Address PR review Co-authored-by: Joshua Oreman <4316136+oremanj@users.noreply.github.com> --- docs/source/reference-lowlevel.rst | 64 ++++++++++++------- src/trio/_core/_run.py | 9 ++- src/trio/_core/_tests/test_instrumentation.py | 9 ++- 3 files changed, 52 insertions(+), 30 deletions(-) diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index fcf04b80d..8afccb5c1 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -58,28 +58,48 @@ Global statistics .. _trio_contexts: -The current Trio context ------------------------- - -There are two different types of contexts in :mod:`trio`. Here are the -semantics presented as a handy table. Choose the right function for -your needs. - -+---------------------------------+-----------------------------------+------------------------------------+ -| situation | :func:`trio.lowlevel.in_trio_run` | :func:`trio.lowlevel.in_trio_task` | -+=================================+===================================+====================================+ -| inside a running async function | `True` | `True` | -+---------------------------------+-----------------------------------+------------------------------------+ -| without a running Trio loop | `False` | `False` | -+---------------------------------+-----------------------------------+------------------------------------+ -| in a guest run's host loop | `True` | `False` | -+---------------------------------+-----------------------------------+------------------------------------+ -| inside an instrument call | depends | depends | -+---------------------------------+-----------------------------------+------------------------------------+ -| :func:`trio.to_thread.run_sync` | `False` | `False` | -+---------------------------------+-----------------------------------+------------------------------------+ -| inside an abort function | `True` | `True` | -+---------------------------------+-----------------------------------+------------------------------------+ +Checking for Trio +----------------- + +If you want to interact with an active Trio run -- perhaps you need to +know the :func:`~trio.current_time` or the +:func:`~trio.lowlevel.current_task` -- then Trio needs to have certain +state available to it or else you will get a +``RuntimeError("must be called from async context")``. +This requires that you either be: + +* indirectly inside (and on the same thread as) a call to + :func:`trio.run`, for run-level information such as the + :func:`~trio.current_time` or :func:`~trio.lowlevel.current_clock`. + +* indirectly inside a Trio task, for task-level information such as + the :func:`~trio.lowlevel.current_task` or + :func:`~trio.current_effective_deadline`. + +Internally, this state is provided by thread-local variables tracking +the current run and the current task. Sometimes, it's useful to know +in advance whether a call will fail or to have dynamic information for +safeguards against running something inside or outside Trio. To do so, +call :func:`trio.lowlevel.in_trio_run` or +:func:`trio.lowlevel.in_trio_task`, which will provide answers +according to the following table. + + ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| situation | :func:`trio.lowlevel.in_trio_run` | :func:`trio.lowlevel.in_trio_task` | ++========================================================+===================================+====================================+ +| inside a Trio-flavored async function | `True` | `True` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| in a thread without an active call to :func:`trio.run` | `False` | `False` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| in a guest run's host loop | `True` | `False` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| inside an instrument call | depends | depends | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| in a thread created by :func:`trio.to_thread.run_sync` | `False` | `False` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| inside an abort function | `True` | `True` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ .. autofunction:: in_trio_run diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 143ea3f0a..eedb99644 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -2832,8 +2832,9 @@ def unrolled_run( except BaseException as exc: raise TrioInternalError("internal error in Trio - please file a bug!") from exc finally: - GLOBAL_RUN_CONTEXT.__dict__.clear() runner.close() + GLOBAL_RUN_CONTEXT.__dict__.clear() + # Have to do this after runner.close() has disabled KI protection, # because otherwise there's a race where ki_pending could get set # after we check it. @@ -2954,16 +2955,18 @@ async def checkpoint_if_cancelled() -> None: def in_trio_run() -> bool: """Check whether we are in a Trio run. + This returns `True` if and only if :func:`~trio.current_time` will succeed. - See also :ref:`the different types of contexts `. + See also the discussion of differing ways of :ref:`detecting Trio `. """ return hasattr(GLOBAL_RUN_CONTEXT, "runner") def in_trio_task() -> bool: """Check whether we are in a Trio task. + This returns `True` if and only if :func:`~trio.lowlevel.current_task` will succeed. - See also :ref:`the different types of contexts `. + See also the discussion of differing ways of :ref:`detecting Trio `. """ return hasattr(GLOBAL_RUN_CONTEXT, "task") diff --git a/src/trio/_core/_tests/test_instrumentation.py b/src/trio/_core/_tests/test_instrumentation.py index af4a625e4..60c54307e 100644 --- a/src/trio/_core/_tests/test_instrumentation.py +++ b/src/trio/_core/_tests/test_instrumentation.py @@ -275,17 +275,16 @@ class Instrument(_abc.Instrument): pass hooks = { - # category 1 + # not run in task context "after_io_wait": (True, False), "before_io_wait": (True, False), "before_run": (True, False), - # category 2 - "after_run": (False, False), - # category 3 + "after_run": (True, False), + # run in task context "before_task_step": (True, True), "after_task_step": (True, True), "task_exited": (True, True), - # category 4 + # depends "task_scheduled": (True, None), "task_spawned": (True, None), } From 353a8c8f571c712e23b87989a11fa77963fc4e7a Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 16:25:12 +0900 Subject: [PATCH 6/7] Be more assertive about instrument calls --- docs/source/reference-lowlevel.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 8afccb5c1..0b115b415 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -94,7 +94,7 @@ according to the following table. +--------------------------------------------------------+-----------------------------------+------------------------------------+ | in a guest run's host loop | `True` | `False` | +--------------------------------------------------------+-----------------------------------+------------------------------------+ -| inside an instrument call | depends | depends | +| inside an instrument call | `True` | depends | +--------------------------------------------------------+-----------------------------------+------------------------------------+ | in a thread created by :func:`trio.to_thread.run_sync` | `False` | `False` | +--------------------------------------------------------+-----------------------------------+------------------------------------+ From ce97b8a76550f9885d7459ff45c6d556a7540093 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Thu, 30 Jan 2025 17:18:20 +0900 Subject: [PATCH 7/7] Address PR review Co-authored-by: Joshua Oreman <4316136+oremanj@users.noreply.github.com> --- docs/source/reference-lowlevel.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 0b115b415..82bd8537d 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -70,7 +70,8 @@ This requires that you either be: * indirectly inside (and on the same thread as) a call to :func:`trio.run`, for run-level information such as the - :func:`~trio.current_time` or :func:`~trio.lowlevel.current_clock`. + :func:`~trio.current_time` or :func:`~trio.lowlevel.current_clock`; + or * indirectly inside a Trio task, for task-level information such as the :func:`~trio.lowlevel.current_task` or