Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add udp #327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions libp2p/network/connection/raw_connection.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import asyncio

from libp2p.transport.stream_interface import IStreamReader, IStreamWriter

from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection


class RawConnection(IRawConnection):
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
reader: IStreamReader
writer: IStreamWriter
initiator: bool

_drain_lock: asyncio.Lock

def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
initiator: bool,
self, reader: IStreamReader, writer: IStreamWriter, initiator: bool
) -> None:
self.reader = reader
self.writer = writer
Expand All @@ -27,16 +26,14 @@ async def write(self, data: bytes) -> None:
"""
Raise `RawConnError` if the underlying connection breaks
"""
try:
self.writer.write(data)
except ConnectionResetError as error:
raise RawConnError(error)
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930
async with self._drain_lock:
try:
await self.writer.drain()
await self.writer.write(
data
) # We call it inside the drain lock, because write() calls drain
except ConnectionResetError as error:
raise RawConnError(error)

Expand All @@ -53,5 +50,4 @@ async def read(self, n: int = -1) -> bytes:
raise RawConnError(error)

async def close(self) -> None:
self.writer.close()
await self.writer.wait_closed()
await self.writer.close()
3 changes: 2 additions & 1 deletion libp2p/network/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SecurityUpgradeFailure,
)
from libp2p.transport.listener_interface import IListener
from libp2p.transport.stream_interface import IStreamReader, IStreamWriter
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import StreamHandlerFn
Expand Down Expand Up @@ -174,7 +175,7 @@ async def listen(self, *multiaddrs: Multiaddr) -> bool:
return True

async def conn_handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
reader: IStreamReader, writer: IStreamWriter
) -> None:
connection_info = writer.get_extra_info("peername")
# TODO make a proper multiaddr
Expand Down
32 changes: 32 additions & 0 deletions libp2p/transport/stream_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import Any


class IStream(ABC):
"""
This is the common interface for both IStreamReader and IStreamWriter.
In asyncio, just StreamWriter implements 'get_extra_info', however,
it is more intuitive to be able to close it from any of them.

This interface is not intended to be implemented directly by any class.
"""

@abstractmethod
def get_extra_info(self, field: str) -> Any:
pass

@abstractmethod
async def close(self) -> None:
pass


class IStreamReader(IStream):
@abstractmethod
async def read(self, n: int = -1) -> bytes:
pass


class IStreamWriter(IStream):
@abstractmethod
async def write(self, data: bytes) -> None:
pass
22 changes: 19 additions & 3 deletions libp2p/transport/tcp/tcp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
import asyncio
from socket import socket
from typing import List
from typing import Awaitable, Callable, List

from multiaddr import Multiaddr

from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.tcp.tcp_stream import TCPStream
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler


# function needed for because asyncio.start_server accepts handlers that receive just TCP Stream,
# instead we use a generic inteface (IStreamReader, IStreamWriter)
def streams_handler_wrapper(
handler: THandler
) -> Callable[[asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]:
async def wrapper(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
return await handler(*TCPStream.from_asyncio_streams(reader, writer))

return wrapper


class TCPListener(IListener):
multiaddrs: List[Multiaddr]
server = None
Expand All @@ -28,7 +42,7 @@ async def listen(self, maddr: Multiaddr) -> bool:
:return: return True if successful
"""
self.server = await asyncio.start_server(
self.handler,
streams_handler_wrapper(self.handler),
maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("tcp"),
)
Expand Down Expand Up @@ -73,7 +87,9 @@ async def dial(self, maddr: Multiaddr) -> IRawConnection:
except (ConnectionAbortedError, ConnectionRefusedError) as error:
raise OpenConnectionError(error)

return RawConnection(reader, writer, True)
stream_reader, stream_writer = TCPStream.from_asyncio_streams(reader, writer)

return RawConnection(stream_reader, stream_writer, True)

def create_listener(self, handler_function: THandler) -> TCPListener:
"""
Expand Down
42 changes: 42 additions & 0 deletions libp2p/transport/tcp/tcp_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from asyncio import StreamReader, StreamWriter
from typing import Any, Tuple

from libp2p.transport.stream_interface import IStream, IStreamReader, IStreamWriter


class TCPStream(IStream):
_write_stream: StreamWriter

def __init__(self, write_stream: StreamWriter):
self._write_stream = write_stream

def get_extra_info(self, field: str) -> Any:
return self._write_stream.get_extra_info(field)

async def close(self) -> None:
if not self._write_stream.is_closing():
self._write_stream.close()
await self._write_stream.wait_closed()

@classmethod
def from_asyncio_streams(
cls, read_stream: StreamReader, write_stream: StreamWriter
) -> Tuple["TCPStreamReader", "TCPStreamWriter"]:
return TCPStreamReader(read_stream, write_stream), TCPStreamWriter(write_stream)


class TCPStreamReader(TCPStream, IStreamReader):
_read_stream: StreamReader

def __init__(self, read_stream: StreamReader, write_stream: StreamWriter):
TCPStream.__init__(self, write_stream)
self._read_stream = read_stream

async def read(self, n: int = -1) -> bytes:
return await self._read_stream.read(n)


class TCPStreamWriter(TCPStream, IStreamWriter):
async def write(self, data: bytes) -> None:
self._write_stream.write(data)
await self._write_stream.drain()
4 changes: 2 additions & 2 deletions libp2p/transport/typing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable, Mapping, Type

from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.stream_interface import IStreamReader, IStreamWriter
from libp2p.typing import TProtocol

THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
THandler = Callable[[IStreamReader, IStreamWriter], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
Empty file.
173 changes: 173 additions & 0 deletions libp2p/transport/udp/udp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import asyncio
from typing import Any, Dict, List, Tuple

from multiaddr import Multiaddr

import asyncio_dgram
from asyncio_dgram.aio import DatagramStream
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.stream_interface import IStreamReader, IStreamWriter
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
from libp2p.transport.udp.udp_stream import (
UDPServerStreamReader,
UDPServerStreamWriter,
UDPStreamReader,
UDPStreamWriter,
)


async def open_datagram_connection(
host: str, port: int
) -> Tuple[IStreamReader, IStreamWriter]:
stream = await asyncio_dgram.connect((host, port))
reader, writer = UDPStreamReader(stream), UDPStreamWriter(stream)
return reader, writer


def start_udp_server(handler: THandler, host: str, port: int) -> "UDPServer":
server = UDPServer(handler, host, port)
asyncio.create_task(server.listen())
return server


class UDPServer:
_host: str
_port: int
_stream: DatagramStream
_handler_queues: Dict[Any, asyncio.Queue[bytes]]
_running_handlers: List[asyncio.Task[Any]]

def __init__(self, handler: THandler, host: str, port: int):
_handler: THandler # in class-level, MyPy raises error, because it treats like a method
self._handler = handler
self._host = host
self._port = port
self._handler_queues = {}
self._running_handlers = []

async def listen(self) -> None:
self._stream = await asyncio_dgram.bind((self._host, self._port))
while True:
data, addr = await self._stream.recv()
await self._deliver_data(addr, data)

async def close(self) -> None:
self._stream.close()
# TODO: Correct way to destroy all handlers?
for task in self._running_handlers:
task.cancel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proper pattern here according to the asyncio docs is:

task.cancel()
try:
    await task
except asyncio.CancelledError:
    pass

This gives the loop the opportunity to inject the cancellation exception into the task so that it doesn't end up being reported as an unretrieved exception.

Copy link
Author

@aratz-lasa aratz-lasa Oct 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I missed it. And probably there are more bugs, I wrote it without much revising.

await asyncio.gather(task for task in self._running_handlers)

def get_extra_info(self, field: str) -> Any:
return self._stream.__getattribute__(
field
) # it is compatible with some TCP fields

def close_handler_stream(self, addr: Any) -> None:
"""
Method called just by Streams used by handlers
:param addr: socket address is connected to
:return:
"""
if addr in self._running_handlers:
del self._running_handlers[addr]

async def _deliver_data(self, addr: Any, data: bytes) -> None:
if addr in self._handler_queues:
await self._handler_queues[addr].put(data)
else:
queue = await self._init_and_get_new_queue(addr, data)
reader, writer = (
UDPServerStreamReader(self._stream, self, addr, queue),
UDPServerStreamWriter(self._stream, self, addr),
)
await self._handler(reader, writer)

async def _init_and_get_new_queue(
self, addr: Any, data: bytes
) -> asyncio.Queue[bytes]:
queue: asyncio.Queue[bytes]
queue = asyncio.Queue()
self._handler_queues[addr] = queue
await queue.put(data)
return queue


class UDPListener(IListener):
multiaddrs: List[Multiaddr]
server = None

def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = []
self.server = None
self.handler = handler_function

async def listen(self, maddr: Multiaddr) -> bool:
"""
put listener in listening mode and wait for incoming connections
:param maddr: maddr of peer
:return: return True if successful
"""
self.server = start_udp_server(
self.handler,
maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("udp"),
)
self.multiaddrs.append(
_multiaddr_from_socketname(self.server.get_extra_info("sockname"))
)
return True

def get_addrs(self) -> List[Multiaddr]:
"""
retrieve list of addresses the listener is listening on
:return: return list of addrs
"""
# TODO check if server is listening
return self.multiaddrs

async def close(self) -> None:
"""
close the listener such that no more connections
can be open on this transport instance
"""
if self.server is None:
return
await self.server.close()
self.server = None


class UDP(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
"""
dial a transport to peer listening on multiaddr
:param maddr: multiaddr of peer
:return: `RawConnection` if successful
:raise OpenConnectionError: raised when failed to open connection
"""
self.host = maddr.value_for_protocol("ip4")
self.port = int(maddr.value_for_protocol("tcp"))

try:
reader, writer = await open_datagram_connection(self.host, self.port)
except (ConnectionAbortedError, ConnectionRefusedError) as error:
raise OpenConnectionError(error)

return RawConnection(reader, writer, True)

def create_listener(self, handler_function: THandler) -> UDPListener:
"""
create listener on transport
:param handler_function: a function called when a new connection is received
that takes a connection as argument which implements interface-connection
:return: a listener object that implements listener_interface.py
"""
return UDPListener(handler_function)


def _multiaddr_from_socketname(socketname: Tuple[str, int]) -> Multiaddr:
return Multiaddr(f"/ip4/{socketname[0]}/udp/{socketname[1]}")
Loading