diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index f193f435..c720dfc8 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,21 +1,20 @@ import asyncio +from libp2p.transport.stream_interface import IStreamReader, IStreamWriter + from .exceptions import RawConnError from .raw_connection_interface import IRawConnection class RawConnection(IRawConnection): - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + reader: IStreamReader + writer: IStreamWriter initiator: bool _drain_lock: asyncio.Lock def __init__( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - initiator: bool, + self, reader: IStreamReader, writer: IStreamWriter, initiator: bool ) -> None: self.reader = reader self.writer = writer @@ -27,16 +26,14 @@ async def write(self, data: bytes) -> None: """ Raise `RawConnError` if the underlying connection breaks """ - try: - self.writer.write(data) - except ConnectionResetError as error: - raise RawConnError(error) # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 # Use a lock to serialize drain() calls. Circumvents this bug: # https://bugs.python.org/issue29930 async with self._drain_lock: try: - await self.writer.drain() + await self.writer.write( + data + ) # We call it inside the drain lock, because write() calls drain except ConnectionResetError as error: raise RawConnError(error) @@ -53,5 +50,4 @@ async def read(self, n: int = -1) -> bytes: raise RawConnError(error) async def close(self) -> None: - self.writer.close() - await self.writer.wait_closed() + await self.writer.close() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index b32e46f9..11439b28 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -16,6 +16,7 @@ SecurityUpgradeFailure, ) from libp2p.transport.listener_interface import IListener +from libp2p.transport.stream_interface import IStreamReader, IStreamWriter from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import StreamHandlerFn @@ -174,7 +175,7 @@ async def listen(self, *multiaddrs: Multiaddr) -> bool: return True async def conn_handler( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: IStreamReader, writer: IStreamWriter ) -> None: connection_info = writer.get_extra_info("peername") # TODO make a proper multiaddr diff --git a/libp2p/transport/stream_interface.py b/libp2p/transport/stream_interface.py new file mode 100644 index 00000000..9b2363a0 --- /dev/null +++ b/libp2p/transport/stream_interface.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class IStream(ABC): + """ + This is the common interface for both IStreamReader and IStreamWriter. + In asyncio, just StreamWriter implements 'get_extra_info', however, + it is more intuitive to be able to close it from any of them. + + This interface is not intended to be implemented directly by any class. + """ + + @abstractmethod + def get_extra_info(self, field: str) -> Any: + pass + + @abstractmethod + async def close(self) -> None: + pass + + +class IStreamReader(IStream): + @abstractmethod + async def read(self, n: int = -1) -> bytes: + pass + + +class IStreamWriter(IStream): + @abstractmethod + async def write(self, data: bytes) -> None: + pass diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 5ee24283..73c9b65d 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,6 +1,6 @@ import asyncio from socket import socket -from typing import List +from typing import Awaitable, Callable, List from multiaddr import Multiaddr @@ -8,10 +8,24 @@ from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener +from libp2p.transport.tcp.tcp_stream import TCPStream from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler +# function needed for because asyncio.start_server accepts handlers that receive just TCP Stream, +# instead we use a generic inteface (IStreamReader, IStreamWriter) +def streams_handler_wrapper( + handler: THandler +) -> Callable[[asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]: + async def wrapper( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + return await handler(*TCPStream.from_asyncio_streams(reader, writer)) + + return wrapper + + class TCPListener(IListener): multiaddrs: List[Multiaddr] server = None @@ -28,7 +42,7 @@ async def listen(self, maddr: Multiaddr) -> bool: :return: return True if successful """ self.server = await asyncio.start_server( - self.handler, + streams_handler_wrapper(self.handler), maddr.value_for_protocol("ip4"), maddr.value_for_protocol("tcp"), ) @@ -73,7 +87,9 @@ async def dial(self, maddr: Multiaddr) -> IRawConnection: except (ConnectionAbortedError, ConnectionRefusedError) as error: raise OpenConnectionError(error) - return RawConnection(reader, writer, True) + stream_reader, stream_writer = TCPStream.from_asyncio_streams(reader, writer) + + return RawConnection(stream_reader, stream_writer, True) def create_listener(self, handler_function: THandler) -> TCPListener: """ diff --git a/libp2p/transport/tcp/tcp_stream.py b/libp2p/transport/tcp/tcp_stream.py new file mode 100644 index 00000000..1a3e9e0f --- /dev/null +++ b/libp2p/transport/tcp/tcp_stream.py @@ -0,0 +1,42 @@ +from asyncio import StreamReader, StreamWriter +from typing import Any, Tuple + +from libp2p.transport.stream_interface import IStream, IStreamReader, IStreamWriter + + +class TCPStream(IStream): + _write_stream: StreamWriter + + def __init__(self, write_stream: StreamWriter): + self._write_stream = write_stream + + def get_extra_info(self, field: str) -> Any: + return self._write_stream.get_extra_info(field) + + async def close(self) -> None: + if not self._write_stream.is_closing(): + self._write_stream.close() + await self._write_stream.wait_closed() + + @classmethod + def from_asyncio_streams( + cls, read_stream: StreamReader, write_stream: StreamWriter + ) -> Tuple["TCPStreamReader", "TCPStreamWriter"]: + return TCPStreamReader(read_stream, write_stream), TCPStreamWriter(write_stream) + + +class TCPStreamReader(TCPStream, IStreamReader): + _read_stream: StreamReader + + def __init__(self, read_stream: StreamReader, write_stream: StreamWriter): + TCPStream.__init__(self, write_stream) + self._read_stream = read_stream + + async def read(self, n: int = -1) -> bytes: + return await self._read_stream.read(n) + + +class TCPStreamWriter(TCPStream, IStreamWriter): + async def write(self, data: bytes) -> None: + self._write_stream.write(data) + await self._write_stream.drain() diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index f9b31dcb..a6d159f7 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,11 +1,11 @@ -from asyncio import StreamReader, StreamWriter from typing import Awaitable, Callable, Mapping, Type from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.abc import IMuxedConn +from libp2p.transport.stream_interface import IStreamReader, IStreamWriter from libp2p.typing import TProtocol -THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +THandler = Callable[[IStreamReader, IStreamWriter], Awaitable[None]] TSecurityOptions = Mapping[TProtocol, ISecureTransport] TMuxerClass = Type[IMuxedConn] TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/libp2p/transport/udp/__init__.py b/libp2p/transport/udp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libp2p/transport/udp/udp.py b/libp2p/transport/udp/udp.py new file mode 100644 index 00000000..f01b2d5b --- /dev/null +++ b/libp2p/transport/udp/udp.py @@ -0,0 +1,173 @@ +import asyncio +from typing import Any, Dict, List, Tuple + +from multiaddr import Multiaddr + +import asyncio_dgram +from asyncio_dgram.aio import DatagramStream +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.transport.exceptions import OpenConnectionError +from libp2p.transport.listener_interface import IListener +from libp2p.transport.stream_interface import IStreamReader, IStreamWriter +from libp2p.transport.transport_interface import ITransport +from libp2p.transport.typing import THandler +from libp2p.transport.udp.udp_stream import ( + UDPServerStreamReader, + UDPServerStreamWriter, + UDPStreamReader, + UDPStreamWriter, +) + + +async def open_datagram_connection( + host: str, port: int +) -> Tuple[IStreamReader, IStreamWriter]: + stream = await asyncio_dgram.connect((host, port)) + reader, writer = UDPStreamReader(stream), UDPStreamWriter(stream) + return reader, writer + + +def start_udp_server(handler: THandler, host: str, port: int) -> "UDPServer": + server = UDPServer(handler, host, port) + asyncio.create_task(server.listen()) + return server + + +class UDPServer: + _host: str + _port: int + _stream: DatagramStream + _handler_queues: Dict[Any, asyncio.Queue[bytes]] + _running_handlers: List[asyncio.Task[Any]] + + def __init__(self, handler: THandler, host: str, port: int): + _handler: THandler # in class-level, MyPy raises error, because it treats like a method + self._handler = handler + self._host = host + self._port = port + self._handler_queues = {} + self._running_handlers = [] + + async def listen(self) -> None: + self._stream = await asyncio_dgram.bind((self._host, self._port)) + while True: + data, addr = await self._stream.recv() + await self._deliver_data(addr, data) + + async def close(self) -> None: + self._stream.close() + # TODO: Correct way to destroy all handlers? + for task in self._running_handlers: + task.cancel() + await asyncio.gather(task for task in self._running_handlers) + + def get_extra_info(self, field: str) -> Any: + return self._stream.__getattribute__( + field + ) # it is compatible with some TCP fields + + def close_handler_stream(self, addr: Any) -> None: + """ + Method called just by Streams used by handlers + :param addr: socket address is connected to + :return: + """ + if addr in self._running_handlers: + del self._running_handlers[addr] + + async def _deliver_data(self, addr: Any, data: bytes) -> None: + if addr in self._handler_queues: + await self._handler_queues[addr].put(data) + else: + queue = await self._init_and_get_new_queue(addr, data) + reader, writer = ( + UDPServerStreamReader(self._stream, self, addr, queue), + UDPServerStreamWriter(self._stream, self, addr), + ) + await self._handler(reader, writer) + + async def _init_and_get_new_queue( + self, addr: Any, data: bytes + ) -> asyncio.Queue[bytes]: + queue: asyncio.Queue[bytes] + queue = asyncio.Queue() + self._handler_queues[addr] = queue + await queue.put(data) + return queue + + +class UDPListener(IListener): + multiaddrs: List[Multiaddr] + server = None + + def __init__(self, handler_function: THandler) -> None: + self.multiaddrs = [] + self.server = None + self.handler = handler_function + + async def listen(self, maddr: Multiaddr) -> bool: + """ + put listener in listening mode and wait for incoming connections + :param maddr: maddr of peer + :return: return True if successful + """ + self.server = start_udp_server( + self.handler, + maddr.value_for_protocol("ip4"), + maddr.value_for_protocol("udp"), + ) + self.multiaddrs.append( + _multiaddr_from_socketname(self.server.get_extra_info("sockname")) + ) + return True + + def get_addrs(self) -> List[Multiaddr]: + """ + retrieve list of addresses the listener is listening on + :return: return list of addrs + """ + # TODO check if server is listening + return self.multiaddrs + + async def close(self) -> None: + """ + close the listener such that no more connections + can be open on this transport instance + """ + if self.server is None: + return + await self.server.close() + self.server = None + + +class UDP(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + """ + dial a transport to peer listening on multiaddr + :param maddr: multiaddr of peer + :return: `RawConnection` if successful + :raise OpenConnectionError: raised when failed to open connection + """ + self.host = maddr.value_for_protocol("ip4") + self.port = int(maddr.value_for_protocol("tcp")) + + try: + reader, writer = await open_datagram_connection(self.host, self.port) + except (ConnectionAbortedError, ConnectionRefusedError) as error: + raise OpenConnectionError(error) + + return RawConnection(reader, writer, True) + + def create_listener(self, handler_function: THandler) -> UDPListener: + """ + create listener on transport + :param handler_function: a function called when a new connection is received + that takes a connection as argument which implements interface-connection + :return: a listener object that implements listener_interface.py + """ + return UDPListener(handler_function) + + +def _multiaddr_from_socketname(socketname: Tuple[str, int]) -> Multiaddr: + return Multiaddr(f"/ip4/{socketname[0]}/udp/{socketname[1]}") diff --git a/libp2p/transport/udp/udp_stream.py b/libp2p/transport/udp/udp_stream.py new file mode 100644 index 00000000..863ec624 --- /dev/null +++ b/libp2p/transport/udp/udp_stream.py @@ -0,0 +1,95 @@ +import asyncio +from io import BytesIO +from typing import Any + +from asyncio_dgram.aio import DatagramStream +from libp2p.transport.stream_interface import IStream, IStreamReader, IStreamWriter +from libp2p.transport.udp.udp import UDPServer + + +class UDPStream(IStream): + _stream: DatagramStream + + def __init__(self, stream: DatagramStream): + self._stream = stream + + def get_extra_info(self, field: str) -> Any: + return self._stream.__getattribute__( + field + ) # it is compatible with some TCP fields + + async def close(self) -> None: + self._stream.close() # it is safe to call close even though it is closed + await asyncio.sleep(-1) + + +class UDPStreamReader(UDPStream, IStreamReader): + _read_stream: BytesIO + + def __init__(self, stream: DatagramStream): + UDPStream.__init__(self, stream) + self._read_stream = BytesIO() + + async def read(self, n: int = -1) -> bytes: + data = self._read_stream.read(n) + while len(data) < n: + await self._fill_read_stream() + data_left = n - len(data) + data += await self.read(data_left) + self._read_stream.seek(0) + return data + + async def _fill_read_stream(self) -> None: + data, addr = await self._stream.recv() + self._read_stream.write(data) + self._read_stream.seek(0) + + +class UDPStreamWriter(UDPStream, IStreamWriter): + _addr: Any + + def __init__(self, stream: DatagramStream, addr: Any = None): + UDPStream.__init__(self, stream) + self._addr = addr + + async def write(self, data: bytes) -> None: + await self._stream.send(data, addr=self._addr) + + +class UDPServerStream: + _server: UDPServer + _addr: Any + + def __init__(self, server: UDPServer, addr: Any): + self._server = server + self._addr = addr + + async def close(self) -> None: + self._server.close_handler_stream(self._addr) + await asyncio.sleep(-1) + + +class UDPServerStreamReader(UDPServerStream, UDPStreamReader): + _queue: asyncio.Queue[bytes] + + def __init__( + self, + stream: DatagramStream, + server: UDPServer, + addr: Any, + queue: asyncio.Queue[bytes], + ): + UDPStreamReader.__init__(self, stream) + UDPServerStream.__init__(self, server, addr) + self._queue = queue + + async def _fill_read_stream(self) -> None: + data = await self._queue.get() + self._read_stream.write(data) + self._read_stream.seek(0) + + +class UDPServerStreamWriter(UDPServerStream, UDPStreamWriter): + def __init__(self, stream: DatagramStream, server: UDPServer, addr: Any): + UDPStreamWriter.__init__(self, stream, addr) + UDPServerStream.__init__(self, server, addr)