diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 2c870dd0..a1da39ae 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -40,7 +40,6 @@ async def stream_handler(stream: INetStream) -> None: class Swarm(Service, INetworkService): - self_id: ID peerstore: IPeerStore upgrader: TransportUpgrader @@ -276,7 +275,9 @@ async def conn_handler(read_write_closer: ReadWriteCloser) -> None: # I/O agnostic, we should change the API. if self.listener_nursery is None: raise SwarmException("swarm instance hasn't been run") - await listener.listen(maddr, self.listener_nursery) + await self.listener_nursery.start( + listener.listen, maddr + ) # type: ignore # Call notifiers since event occurred await self.notify_listen(maddr) diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 4d25c254..b57501ec 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -391,7 +391,7 @@ async def heartbeat(self) -> None: await trio.sleep(self.heartbeat_interval) def mesh_heartbeat( - self + self, ) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]: peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list) peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list) diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 67e26519..0429c294 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -67,7 +67,7 @@ def security_transport_factory( @asynccontextmanager async def raw_conn_factory( - nursery: trio.Nursery + nursery: trio.Nursery, ) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]: conn_0 = None conn_1 = None @@ -81,7 +81,7 @@ async def tcp_stream_handler(stream: ReadWriteCloser) -> None: tcp_transport = TCP() listener = tcp_transport.create_listener(tcp_stream_handler) - await listener.listen(LISTEN_MADDR, nursery) + await nursery.start(listener.listen, LISTEN_MADDR) listening_maddr = listener.get_addrs()[0] conn_0 = await tcp_transport.dial(listening_maddr) await event.wait() @@ -351,7 +351,7 @@ async def swarm_pair_factory( @asynccontextmanager async def host_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: await connect(hosts[0], hosts[1]) @@ -370,7 +370,7 @@ async def swarm_conn_pair_factory( @asynccontextmanager async def mplex_conn_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[Mplex, Mplex]]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: @@ -382,7 +382,7 @@ async def mplex_conn_pair_factory( @asynccontextmanager async def mplex_stream_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[MplexStream, MplexStream]]: async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info @@ -398,7 +398,7 @@ async def mplex_stream_pair_factory( @asynccontextmanager async def net_stream_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[INetStream, INetStream]]: protocol_id = TProtocol("/example/id/1") diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 5a262b3b..b6e94dee 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None: def create_echo_stream_handler( - ack_prefix: str + ack_prefix: str, ) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index d170d1de..2c6ebea4 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Tuple +from typing import Any, Tuple from multiaddr import Multiaddr import trio @@ -7,7 +7,9 @@ class IListener(ABC): @abstractmethod - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + async def listen( + self, maddr: Multiaddr, task_status: Any = trio.TASK_STATUS_IGNORED + ) -> bool: """ put listener in listening mode and wait for incoming connections. diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 1004e288..40fe5bf9 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,5 +1,5 @@ import logging -from typing import Awaitable, Callable, List, Sequence, Tuple +from typing import Any, Awaitable, Callable, List, Sequence, Tuple from multiaddr import Multiaddr import trio @@ -23,8 +23,9 @@ def __init__(self, handler_function: THandler) -> None: self.listeners = [] self.handler = handler_function - # TODO: Get rid of `nursery`? - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: + async def listen( + self, maddr: Multiaddr, task_status: TaskStatus[Any] = trio.TASK_STATUS_IGNORED + ) -> None: """ put listener in listening mode and wait for incoming connections. @@ -46,13 +47,15 @@ async def handler(stream: trio.SocketStream) -> None: tcp_stream = TrioTCPStream(stream) await self.handler(tcp_stream) - listeners = await nursery.start( - serve_tcp, - handler, - int(maddr.value_for_protocol("tcp")), - maddr.value_for_protocol("ip4"), - ) - self.listeners.extend(listeners) + async with trio.open_nursery() as nursery: + listeners = await nursery.start( + serve_tcp, + handler, + int(maddr.value_for_protocol("tcp")), + maddr.value_for_protocol("ip4"), + ) + task_status.started() + self.listeners.extend(listeners) def get_addrs(self) -> Tuple[Multiaddr, ...]: """ diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index 130b3cc4..ae7699e2 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -17,9 +17,9 @@ async def handler(tcp_stream): listener = transport.create_listener(handler) assert len(listener.get_addrs()) == 0 - await listener.listen(LISTEN_MADDR, nursery) + await nursery.start(listener.listen, LISTEN_MADDR) assert len(listener.get_addrs()) == 1 - await listener.listen(LISTEN_MADDR, nursery) + await nursery.start(listener.listen, LISTEN_MADDR) assert len(listener.get_addrs()) == 2 @@ -41,7 +41,7 @@ async def handler(tcp_stream): await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1")) listener = transport.create_listener(handler) - await listener.listen(LISTEN_MADDR, nursery) + await nursery.start(listener.listen, LISTEN_MADDR) addrs = listener.get_addrs() assert len(addrs) == 1 listen_addr = addrs[0]