Skip to content

make websocket joiners use the transport interface #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/acceptor_test.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ def test_accept():

j = WebsocketsJoiner()
client_base_session = j.join(f"ws://localhost:{port}/ws", "realm1")
client_base_session.ws.close()
client_base_session.close()

thread.join()
server_base_session = result[0]
5 changes: 2 additions & 3 deletions xconn/async_session.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
from asyncio import Future, get_event_loop
from typing import Callable, Union, Awaitable, Any

from websockets.protocol import State
from wampproto import messages, idgen, session

from xconn import types, uris as xconn_uris, exception
@@ -80,7 +79,7 @@ def __init__(self, base_session: types.IAsyncBaseSession):
self.wait_task = loop.create_task(self.wait())

async def wait(self):
while self.base_session.ws.state == State.OPEN:
while await self.base_session.is_connected():
try:
data = await self.base_session.receive()
except Exception as e:
@@ -268,7 +267,7 @@ async def leave(self) -> None:
except asyncio.CancelledError:
pass

if self.base_session.ws.state != State.CLOSED:
if self.base_session.is_connected():
await self.base_session.close()

async def ping(self) -> None:
33 changes: 13 additions & 20 deletions xconn/joiner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from websockets.sync.client import connect
from websockets.asyncio.client import connect as async_connect
from wampproto import joiner, serializers, auth
from wampproto.joiner import Joiner

from xconn import types, helpers
from xconn.transports import WebSocketTransport, AsyncWebSocketTransport


class WebsocketsJoiner:
@@ -18,25 +17,22 @@ def __init__(
self._ws_config = ws_config

def join(self, uri: str, realm: str) -> types.BaseSession:
ws = connect(
transport = WebSocketTransport.connect(
uri,
subprotocols=[helpers.get_ws_subprotocol(serializer=self._serializer)],
open_timeout=self._ws_config.open_timeout,
ping_interval=self._ws_config.ping_interval,
ping_timeout=self._ws_config.ping_timeout,
close_timeout=self._ws_config.close_timeout,
config=self._ws_config,
)

j: Joiner = joiner.Joiner(realm, serializer=self._serializer, authenticator=self._authenticator)
ws.send(j.send_hello())
transport.write(j.send_hello())

while True:
data = ws.recv()
data = transport.read()
to_send = j.receive(data)
if to_send is None:
return types.BaseSession(ws, j.get_session_details(), self._serializer)
return types.BaseSession(transport, j.get_session_details(), self._serializer)

ws.send(to_send)
transport.write(to_send)


class AsyncWebsocketsJoiner:
@@ -51,22 +47,19 @@ def __init__(
self._serializer = serializer

async def join(self, uri: str, realm: str) -> types.AsyncBaseSession:
ws = await async_connect(
transport = await AsyncWebSocketTransport.connect(
uri,
subprotocols=[helpers.get_ws_subprotocol(serializer=self._serializer)],
open_timeout=self._ws_config.open_timeout,
ping_interval=self._ws_config.ping_interval,
ping_timeout=self._ws_config.ping_timeout,
close_timeout=self._ws_config.close_timeout,
config=self._ws_config,
)

j: Joiner = joiner.Joiner(realm, serializer=self._serializer, authenticator=self._authenticator)
await ws.send(j.send_hello())
await transport.write(j.send_hello())

while True:
data = await ws.recv()
data = await transport.read()
to_send = j.receive(data)
if to_send is None:
return types.AsyncBaseSession(ws, j.get_session_details(), self._serializer)
return types.AsyncBaseSession(transport, j.get_session_details(), self._serializer)

await ws.send(to_send)
await transport.write(to_send)
3 changes: 1 addition & 2 deletions xconn/session.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@
from threading import Thread
from typing import Callable, Any

from websockets.protocol import State
from wampproto import messages, idgen, session, uris

from xconn import types, exception, uris as xconn_uris
@@ -76,7 +75,7 @@ def __init__(self, base_session: types.BaseSession):
thread.start()

def wait(self):
while self.base_session.ws.state == State.OPEN:
while self.base_session.is_connected():
try:
data = self.base_session.receive()
except Exception:
72 changes: 71 additions & 1 deletion xconn/transports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import socket
from asyncio import StreamReader, StreamWriter
from typing import Sequence

from wampproto.transports.rawsocket import (
Handshake,
@@ -9,8 +10,13 @@
SERIALIZER_TYPE_CBOR,
MSG_TYPE_WAMP,
)
from websockets import State, Subprotocol
from websockets.sync.client import connect
from websockets.sync.connection import Connection
from websockets.asyncio.client import connect as async_connect
from websockets.asyncio.client import ClientConnection

from xconn.types import IAsyncTransport, ITransport
from xconn.types import IAsyncTransport, ITransport, WebsocketConfig

# Applies to handshake and message itself.
RAW_SOCKET_HEADER_LENGTH = 4
@@ -114,3 +120,67 @@ async def is_connected(self) -> bool:
return True
except (BrokenPipeError, ConnectionResetError, OSError):
return False


class WebSocketTransport(ITransport):
def __init__(self, websocket: Connection):
super().__init__()
self._websocket = websocket

@staticmethod
def connect(uri: str, subprotocols: Sequence[Subprotocol], config: WebsocketConfig) -> "WebSocketTransport":
ws = connect(
uri,
subprotocols=subprotocols,
open_timeout=config.open_timeout,
ping_interval=config.ping_interval,
ping_timeout=config.ping_timeout,
close_timeout=config.close_timeout,
)

return WebSocketTransport(ws)

def read(self) -> str | bytes:
return self._websocket.recv()

def write(self, data: str | bytes):
self._websocket.send(data)

def close(self):
self._websocket.close()

def is_connected(self) -> bool:
return self._websocket.state == State.OPEN


class AsyncWebSocketTransport(IAsyncTransport):
def __init__(self, websocket: ClientConnection):
super().__init__()
self._websocket = websocket

@staticmethod
async def connect(
uri: str, subprotocols: Sequence[Subprotocol], config: WebsocketConfig
) -> "AsyncWebSocketTransport":
ws = await async_connect(
uri,
subprotocols=subprotocols,
open_timeout=config.open_timeout,
ping_interval=config.ping_interval,
ping_timeout=config.ping_timeout,
close_timeout=config.close_timeout,
)

return AsyncWebSocketTransport(ws)

async def read(self) -> str | bytes:
return await self._websocket.recv()

async def write(self, data: str | bytes):
await self._websocket.send(data)

async def close(self):
await self._websocket.close()

async def is_connected(self) -> bool:
return self._websocket.state == State.OPEN
40 changes: 26 additions & 14 deletions xconn/types.py
Original file line number Diff line number Diff line change
@@ -10,9 +10,7 @@
from typing import Callable, Awaitable

from aiohttp import web
from websockets.sync.connection import Connection
from wampproto import messages, joiner, serializers
from websockets.asyncio.client import ClientConnection


@dataclass
@@ -193,11 +191,16 @@ def receive_message(self) -> messages.Message:
def close(self):
raise NotImplementedError()

def is_connected(self) -> bool:
raise NotImplementedError()


class BaseSession(IBaseSession):
def __init__(self, ws: Connection, session_details: joiner.SessionDetails, serializer: serializers.Serializer):
def __init__(
self, transport: ITransport, session_details: joiner.SessionDetails, serializer: serializers.Serializer
):
super().__init__()
self.ws = ws
self._transport = transport
self.session_details = session_details
self.serializer = serializer

@@ -218,19 +221,22 @@ def authrole(self) -> str:
return self.session_details.authrole

def send(self, data: bytes):
self.ws.send(data)
self._transport.write(data)

def receive(self) -> bytes:
return self.ws.recv()
return self._transport.read()

def send_message(self, msg: messages.Message):
self.ws.send(self.serializer.serialize(msg))
self.send(self.serializer.serialize(msg))

def receive_message(self) -> messages.Message:
return self.serializer.deserialize(self.receive())

def close(self):
self.ws.close()
self._transport.close()

def is_connected(self) -> bool:
return self._transport.is_connected()


class IAsyncBaseSession:
@@ -269,13 +275,16 @@ async def receive_message(self) -> messages.Message:
async def close(self):
raise NotImplementedError()

async def is_connected(self) -> bool:
raise NotImplementedError()


class AsyncBaseSession(IAsyncBaseSession):
def __init__(
self, ws: ClientConnection, session_details: joiner.SessionDetails, serializer: serializers.Serializer
self, transport: IAsyncTransport, session_details: joiner.SessionDetails, serializer: serializers.Serializer
):
super().__init__()
self.ws = ws
self._transport = transport
self.session_details = session_details
self._serializer = serializer

@@ -300,19 +309,22 @@ def serializer(self) -> serializers.Serializer:
return self._serializer

async def send(self, data: bytes):
return await self.ws.send(data)
return await self._transport.write(data)

async def receive(self) -> bytes:
return await self.ws.recv()
return await self._transport.read()

async def send_message(self, msg: messages.Message):
await self.ws.send(self.serializer.serialize(msg))
await self.send(self.serializer.serialize(msg))

async def receive_message(self) -> messages.Message:
return self.serializer.deserialize(await self.receive())

async def close(self):
await self.ws.close()
await self._transport.close()

async def is_connected(self) -> bool:
return await self._transport.is_connected()


class AIOHttpBaseSession(IAsyncBaseSession):