diff --git a/tests/acceptor_test.py b/tests/acceptor_test.py index c4907c7..056d976 100644 --- a/tests/acceptor_test.py +++ b/tests/acceptor_test.py @@ -28,7 +28,7 @@ def test_accept(): j = WebsocketsJoiner() client_base_session = j.join(f"ws://localhost:{port}/ws", "realm1") - client_base_session.ws.close() + client_base_session.close() thread.join() server_base_session = result[0] diff --git a/xconn/async_session.py b/xconn/async_session.py index 5119447..bb5d8c3 100644 --- a/xconn/async_session.py +++ b/xconn/async_session.py @@ -4,7 +4,6 @@ from asyncio import Future, get_event_loop from typing import Callable, Union, Awaitable, Any -from websockets.protocol import State from wampproto import messages, idgen, session from xconn import types, uris as xconn_uris, exception @@ -80,7 +79,7 @@ def __init__(self, base_session: types.IAsyncBaseSession): self.wait_task = loop.create_task(self.wait()) async def wait(self): - while self.base_session.ws.state == State.OPEN: + while await self.base_session.is_connected(): try: data = await self.base_session.receive() except Exception as e: @@ -268,7 +267,7 @@ async def leave(self) -> None: except asyncio.CancelledError: pass - if self.base_session.ws.state != State.CLOSED: + if self.base_session.is_connected(): await self.base_session.close() async def ping(self) -> None: diff --git a/xconn/joiner.py b/xconn/joiner.py index 0c90f65..7f176ba 100644 --- a/xconn/joiner.py +++ b/xconn/joiner.py @@ -1,9 +1,8 @@ -from websockets.sync.client import connect -from websockets.asyncio.client import connect as async_connect from wampproto import joiner, serializers, auth from wampproto.joiner import Joiner from xconn import types, helpers +from xconn.transports import WebSocketTransport, AsyncWebSocketTransport class WebsocketsJoiner: @@ -18,25 +17,22 @@ def __init__( self._ws_config = ws_config def join(self, uri: str, realm: str) -> types.BaseSession: - ws = connect( + transport = WebSocketTransport.connect( uri, subprotocols=[helpers.get_ws_subprotocol(serializer=self._serializer)], - open_timeout=self._ws_config.open_timeout, - ping_interval=self._ws_config.ping_interval, - ping_timeout=self._ws_config.ping_timeout, - close_timeout=self._ws_config.close_timeout, + config=self._ws_config, ) j: Joiner = joiner.Joiner(realm, serializer=self._serializer, authenticator=self._authenticator) - ws.send(j.send_hello()) + transport.write(j.send_hello()) while True: - data = ws.recv() + data = transport.read() to_send = j.receive(data) if to_send is None: - return types.BaseSession(ws, j.get_session_details(), self._serializer) + return types.BaseSession(transport, j.get_session_details(), self._serializer) - ws.send(to_send) + transport.write(to_send) class AsyncWebsocketsJoiner: @@ -51,22 +47,19 @@ def __init__( self._serializer = serializer async def join(self, uri: str, realm: str) -> types.AsyncBaseSession: - ws = await async_connect( + transport = await AsyncWebSocketTransport.connect( uri, subprotocols=[helpers.get_ws_subprotocol(serializer=self._serializer)], - open_timeout=self._ws_config.open_timeout, - ping_interval=self._ws_config.ping_interval, - ping_timeout=self._ws_config.ping_timeout, - close_timeout=self._ws_config.close_timeout, + config=self._ws_config, ) j: Joiner = joiner.Joiner(realm, serializer=self._serializer, authenticator=self._authenticator) - await ws.send(j.send_hello()) + await transport.write(j.send_hello()) while True: - data = await ws.recv() + data = await transport.read() to_send = j.receive(data) if to_send is None: - return types.AsyncBaseSession(ws, j.get_session_details(), self._serializer) + return types.AsyncBaseSession(transport, j.get_session_details(), self._serializer) - await ws.send(to_send) + await transport.write(to_send) diff --git a/xconn/session.py b/xconn/session.py index 657b736..d468d1f 100644 --- a/xconn/session.py +++ b/xconn/session.py @@ -3,7 +3,6 @@ from threading import Thread from typing import Callable, Any -from websockets.protocol import State from wampproto import messages, idgen, session, uris from xconn import types, exception, uris as xconn_uris @@ -76,7 +75,7 @@ def __init__(self, base_session: types.BaseSession): thread.start() def wait(self): - while self.base_session.ws.state == State.OPEN: + while self.base_session.is_connected(): try: data = self.base_session.receive() except Exception: diff --git a/xconn/transports.py b/xconn/transports.py index 2b750d0..2448ed3 100644 --- a/xconn/transports.py +++ b/xconn/transports.py @@ -1,6 +1,7 @@ import asyncio import socket from asyncio import StreamReader, StreamWriter +from typing import Sequence from wampproto.transports.rawsocket import ( Handshake, @@ -9,8 +10,13 @@ SERIALIZER_TYPE_CBOR, MSG_TYPE_WAMP, ) +from websockets import State, Subprotocol +from websockets.sync.client import connect +from websockets.sync.connection import Connection +from websockets.asyncio.client import connect as async_connect +from websockets.asyncio.client import ClientConnection -from xconn.types import IAsyncTransport, ITransport +from xconn.types import IAsyncTransport, ITransport, WebsocketConfig # Applies to handshake and message itself. RAW_SOCKET_HEADER_LENGTH = 4 @@ -114,3 +120,67 @@ async def is_connected(self) -> bool: return True except (BrokenPipeError, ConnectionResetError, OSError): return False + + +class WebSocketTransport(ITransport): + def __init__(self, websocket: Connection): + super().__init__() + self._websocket = websocket + + @staticmethod + def connect(uri: str, subprotocols: Sequence[Subprotocol], config: WebsocketConfig) -> "WebSocketTransport": + ws = connect( + uri, + subprotocols=subprotocols, + open_timeout=config.open_timeout, + ping_interval=config.ping_interval, + ping_timeout=config.ping_timeout, + close_timeout=config.close_timeout, + ) + + return WebSocketTransport(ws) + + def read(self) -> str | bytes: + return self._websocket.recv() + + def write(self, data: str | bytes): + self._websocket.send(data) + + def close(self): + self._websocket.close() + + def is_connected(self) -> bool: + return self._websocket.state == State.OPEN + + +class AsyncWebSocketTransport(IAsyncTransport): + def __init__(self, websocket: ClientConnection): + super().__init__() + self._websocket = websocket + + @staticmethod + async def connect( + uri: str, subprotocols: Sequence[Subprotocol], config: WebsocketConfig + ) -> "AsyncWebSocketTransport": + ws = await async_connect( + uri, + subprotocols=subprotocols, + open_timeout=config.open_timeout, + ping_interval=config.ping_interval, + ping_timeout=config.ping_timeout, + close_timeout=config.close_timeout, + ) + + return AsyncWebSocketTransport(ws) + + async def read(self) -> str | bytes: + return await self._websocket.recv() + + async def write(self, data: str | bytes): + await self._websocket.send(data) + + async def close(self): + await self._websocket.close() + + async def is_connected(self) -> bool: + return self._websocket.state == State.OPEN diff --git a/xconn/types.py b/xconn/types.py index a031305..906577e 100644 --- a/xconn/types.py +++ b/xconn/types.py @@ -10,9 +10,7 @@ from typing import Callable, Awaitable from aiohttp import web -from websockets.sync.connection import Connection from wampproto import messages, joiner, serializers -from websockets.asyncio.client import ClientConnection @dataclass @@ -193,11 +191,16 @@ def receive_message(self) -> messages.Message: def close(self): raise NotImplementedError() + def is_connected(self) -> bool: + raise NotImplementedError() + class BaseSession(IBaseSession): - def __init__(self, ws: Connection, session_details: joiner.SessionDetails, serializer: serializers.Serializer): + def __init__( + self, transport: ITransport, session_details: joiner.SessionDetails, serializer: serializers.Serializer + ): super().__init__() - self.ws = ws + self._transport = transport self.session_details = session_details self.serializer = serializer @@ -218,19 +221,22 @@ def authrole(self) -> str: return self.session_details.authrole def send(self, data: bytes): - self.ws.send(data) + self._transport.write(data) def receive(self) -> bytes: - return self.ws.recv() + return self._transport.read() def send_message(self, msg: messages.Message): - self.ws.send(self.serializer.serialize(msg)) + self.send(self.serializer.serialize(msg)) def receive_message(self) -> messages.Message: return self.serializer.deserialize(self.receive()) def close(self): - self.ws.close() + self._transport.close() + + def is_connected(self) -> bool: + return self._transport.is_connected() class IAsyncBaseSession: @@ -269,13 +275,16 @@ async def receive_message(self) -> messages.Message: async def close(self): raise NotImplementedError() + async def is_connected(self) -> bool: + raise NotImplementedError() + class AsyncBaseSession(IAsyncBaseSession): def __init__( - self, ws: ClientConnection, session_details: joiner.SessionDetails, serializer: serializers.Serializer + self, transport: IAsyncTransport, session_details: joiner.SessionDetails, serializer: serializers.Serializer ): super().__init__() - self.ws = ws + self._transport = transport self.session_details = session_details self._serializer = serializer @@ -300,19 +309,22 @@ def serializer(self) -> serializers.Serializer: return self._serializer async def send(self, data: bytes): - return await self.ws.send(data) + return await self._transport.write(data) async def receive(self) -> bytes: - return await self.ws.recv() + return await self._transport.read() async def send_message(self, msg: messages.Message): - await self.ws.send(self.serializer.serialize(msg)) + await self.send(self.serializer.serialize(msg)) async def receive_message(self) -> messages.Message: return self.serializer.deserialize(await self.receive()) async def close(self): - await self.ws.close() + await self._transport.close() + + async def is_connected(self) -> bool: + return await self._transport.is_connected() class AIOHttpBaseSession(IAsyncBaseSession):