diff --git a/nat-lab/tests/conftest.py b/nat-lab/tests/conftest.py index 6d7382669..d02b6eb6a 100644 --- a/nat-lab/tests/conftest.py +++ b/nat-lab/tests/conftest.py @@ -9,7 +9,9 @@ from helpers import SetupParameters from interderp_cli import InterDerpClient from itertools import combinations +from typing import Optional from utils.bindings import TelioAdapterType +from utils.connection import DockerConnection from utils.connection_util import ConnectionTag, LAN_ADDR_MAP, new_connection_raw from utils.process import ProcessExecError from utils.router import IPStack @@ -25,10 +27,9 @@ SETUP_CHECK_TIMEOUT_S = 30 SETUP_CHECK_RETRIES = 5 -# pylint: disable=unnecessary-dunder-call -TEST_SCOPE_ASYNC_EXIT_STACK = asyncio.run(AsyncExitStack().__aenter__()) -# pylint: disable=unnecessary-dunder-call -SESSION_ASYNC_EXIT_STACK = asyncio.run(AsyncExitStack().__aenter__()) +RUNNER = asyncio.runners.Runner() +TEST_SCOPE_EXIT_STACK: Optional[AsyncExitStack] = None +SESSION_SCOPE_EXIT_STACK: Optional[AsyncExitStack] = None def _cancel_all_tasks(loop: asyncio.AbstractEventLoop): @@ -208,6 +209,9 @@ async def setup_check_interderp(): ] ] + if not isinstance(connections[0], DockerConnection): + raise Exception("Not docker connection") + async with make_tcpdump(connections): for idx, (server1, server2) in enumerate( combinations( @@ -407,9 +411,13 @@ def pytest_runtest_call(item): if os.environ.get("NATLAB_SAVE_LOGS") is None: return - async def async_code(): + async def async_context(): + global TEST_SCOPE_EXIT_STACK + if not TEST_SCOPE_EXIT_STACK: + TEST_SCOPE_EXIT_STACK = AsyncExitStack() + connections = [ - await TEST_SCOPE_ASYNC_EXIT_STACK.enter_async_context( + await TEST_SCOPE_EXIT_STACK.enter_async_context( new_connection_raw(conn_tag) ) for conn_tag in [ @@ -417,9 +425,10 @@ async def async_code(): ConnectionTag.DOCKER_DNS_SERVER_2, ] ] - await TEST_SCOPE_ASYNC_EXIT_STACK.enter_async_context(make_tcpdump(connections)) - asyncio.run(async_code()) + await TEST_SCOPE_EXIT_STACK.enter_async_context(make_tcpdump(connections)) + + RUNNER.run(async_context()) # pylint: disable=unused-argument @@ -427,14 +436,14 @@ def pytest_runtest_makereport(item, call): if os.environ.get("NATLAB_SAVE_LOGS") is None: return - async def async_code(): - global TEST_SCOPE_ASYNC_EXIT_STACK - await TEST_SCOPE_ASYNC_EXIT_STACK.aclose() - # pylint: disable=unnecessary-dunder-call - TEST_SCOPE_ASYNC_EXIT_STACK = await AsyncExitStack().__aenter__() + async def async_context(): + global TEST_SCOPE_EXIT_STACK + if TEST_SCOPE_EXIT_STACK: + await TEST_SCOPE_EXIT_STACK.aclose() + TEST_SCOPE_EXIT_STACK = None if call.when == "call": - asyncio.run(async_code()) + RUNNER.run(async_context()) # pylint: disable=unused-argument @@ -442,18 +451,22 @@ def pytest_sessionstart(session): if os.environ.get("NATLAB_SAVE_LOGS") is None: return - async def async_code(): + async def async_context(): + global SESSION_SCOPE_EXIT_STACK + if not SESSION_SCOPE_EXIT_STACK: + SESSION_SCOPE_EXIT_STACK = AsyncExitStack() + connections = [ - await SESSION_ASYNC_EXIT_STACK.enter_async_context( + await SESSION_SCOPE_EXIT_STACK.enter_async_context( new_connection_raw(gw_tag) ) for gw_tag in ConnectionTag if "_GW" in gw_tag.name ] - await SESSION_ASYNC_EXIT_STACK.enter_async_context(make_tcpdump(connections)) - if not session.config.option.collectonly: - asyncio.run(async_code()) + await SESSION_SCOPE_EXIT_STACK.enter_async_context(make_tcpdump(connections)) + + RUNNER.run(async_context()) # pylint: disable=unused-argument @@ -461,8 +474,14 @@ def pytest_sessionfinish(session, exitstatus): if os.environ.get("NATLAB_SAVE_LOGS") is None: return + async def async_context(): + global SESSION_SCOPE_EXIT_STACK + if SESSION_SCOPE_EXIT_STACK: + await SESSION_SCOPE_EXIT_STACK.aclose() + SESSION_SCOPE_EXIT_STACK = None + if not session.config.option.collectonly: - asyncio.run(SESSION_ASYNC_EXIT_STACK.aclose()) + RUNNER.run(async_context()) collect_nordderper_logs() collect_dns_server_logs() asyncio.run(collect_kernel_logs(session.items, "after_tests"))