diff --git a/xconn/transports.py b/xconn/transports.py index 951ca06..ea5ad77 100644 --- a/xconn/transports.py +++ b/xconn/transports.py @@ -12,9 +12,9 @@ MSG_TYPE_WAMP, ) from websockets import State, Subprotocol -from websockets.sync.client import connect +from websockets.sync.client import connect, unix_connect from websockets.sync.connection import Connection -from websockets.asyncio.client import connect as async_connect +from websockets.asyncio.client import connect as async_connect, unix_connect as async_unix_connect from websockets.asyncio.client import ClientConnection from xconn.types import IAsyncTransport, ITransport, WebsocketConfig @@ -160,14 +160,27 @@ def __init__(self, websocket: Connection): @staticmethod def connect(uri: str, subprotocols: Sequence[Subprotocol], config: WebsocketConfig) -> "WebSocketTransport": - ws = connect( - uri, - subprotocols=subprotocols, - open_timeout=config.open_timeout, - ping_interval=config.ping_interval, - ping_timeout=config.ping_timeout, - close_timeout=config.close_timeout, - ) + parsed_url = urlparse(uri) + if parsed_url.scheme == "unix+ws": + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(parsed_url.path) + ws = unix_connect( + parsed_url.path, + subprotocols=subprotocols, + open_timeout=config.open_timeout, + ping_interval=config.ping_interval, + ping_timeout=config.ping_timeout, + close_timeout=config.close_timeout, + ) + else: + ws = connect( + uri, + subprotocols=subprotocols, + open_timeout=config.open_timeout, + ping_interval=config.ping_interval, + ping_timeout=config.ping_timeout, + close_timeout=config.close_timeout, + ) return WebSocketTransport(ws) @@ -196,14 +209,25 @@ def __init__(self, websocket: ClientConnection): async def connect( uri: str, subprotocols: Sequence[Subprotocol], config: WebsocketConfig ) -> "AsyncWebSocketTransport": - ws = await async_connect( - uri, - subprotocols=subprotocols, - open_timeout=config.open_timeout, - ping_interval=config.ping_interval, - ping_timeout=config.ping_timeout, - close_timeout=config.close_timeout, - ) + parsed_url = urlparse(uri) + if parsed_url.scheme == "unix+ws": + ws = await async_unix_connect( + parsed_url.path, + subprotocols=subprotocols, + open_timeout=config.open_timeout, + ping_interval=config.ping_interval, + ping_timeout=config.ping_timeout, + close_timeout=config.close_timeout, + ) + else: + ws = await async_connect( + uri, + subprotocols=subprotocols, + open_timeout=config.open_timeout, + ping_interval=config.ping_interval, + ping_timeout=config.ping_timeout, + close_timeout=config.close_timeout, + ) return AsyncWebSocketTransport(ws)