diff --git a/src/socketio/asyncio_server.py b/src/socketio/asyncio_server.py index a9f440fa..ccd55fb8 100644 --- a/src/socketio/asyncio_server.py +++ b/src/socketio/asyncio_server.py @@ -40,6 +40,11 @@ class AsyncServer(server.Server): connect handler and your client is confused when it receives events before the connection acceptance. In any other case use the default of ``False``. + :param namespaces: a list of namespaces that are accepted, in addition to + any namespaces for which handlers have been defined. The + default is `['/']`, which always accepts connections to + the default namespace. Set to `'*'` to accept all + namespaces. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -97,11 +102,12 @@ class AsyncServer(server.Server): ``engineio_logger`` is ``False``. """ def __init__(self, client_manager=None, logger=False, json=None, - async_handlers=True, **kwargs): + async_handlers=True, namespaces=None, **kwargs): if client_manager is None: client_manager = asyncio_manager.AsyncManager() super().__init__(client_manager=client_manager, logger=logger, - json=json, async_handlers=async_handlers, **kwargs) + json=json, async_handlers=async_handlers, + namespaces=namespaces, **kwargs) def is_asyncio_based(self): return True @@ -443,7 +449,8 @@ async def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' sid = None - if namespace in self.handlers or namespace in self.namespace_handlers: + if namespace in self.handlers or namespace in self.namespace_handlers \ + or self.namespaces == '*' or namespace in self.namespaces: sid = self.manager.connect(eio_sid, namespace) if sid is None: await self._send_packet(eio_sid, self.packet_class( diff --git a/src/socketio/server.py b/src/socketio/server.py index b2220db5..0456ed61 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -49,6 +49,11 @@ class Server(object): connect handler and your client is confused when it receives events before the connection acceptance. In any other case use the default of ``False``. + :param namespaces: a list of namespaces that are accepted, in addition to + any namespaces for which handlers have been defined. The + default is `['/']`, which always accepts connections to + the default namespace. Set to `'*'` to accept all + namespaces. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -110,7 +115,7 @@ class Server(object): def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, - **kwargs): + namespaces=None, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -157,6 +162,7 @@ def __init__(self, client_manager=None, logger=False, serializer='default', self.async_handlers = async_handlers self.always_connect = always_connect + self.namespaces = namespaces or ['/'] self.async_mode = self.eio.async_mode @@ -650,7 +656,8 @@ def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' sid = None - if namespace in self.handlers or namespace in self.namespace_handlers: + if namespace in self.handlers or namespace in self.namespace_handlers \ + or self.namespaces == '*' or namespace in self.namespaces: sid = self.manager.connect(eio_sid, namespace) if sid is None: self._send_packet(eio_sid, self.packet_class( diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index 01244ca8..eec531c7 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -425,12 +425,32 @@ def test_handle_connect_async(self, eio): _run(s._handle_eio_message('456', '0')) assert s.manager.initialize.call_count == 1 - def test_handle_connect_with_bad_namespace(self, eio): + def test_handle_connect_with_default_implied_namespaces(self, eio): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0')) + _run(s._handle_eio_message('123', '0/foo,')) + assert s.manager.is_connected('1', '/') + assert not s.manager.is_connected('2', '/foo') + + def test_handle_connect_with_implied_namespaces(self, eio): + eio.return_value.send = AsyncMock() + s = asyncio_server.AsyncServer(namespaces=['/foo']) + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0')) + _run(s._handle_eio_message('123', '0/foo,')) assert not s.manager.is_connected('1', '/') + assert s.manager.is_connected('1', '/foo') + + def test_handle_connect_with_all_implied_namespaces(self, eio): + eio.return_value.send = AsyncMock() + s = asyncio_server.AsyncServer(namespaces='*') + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0')) + _run(s._handle_eio_message('123', '0/foo,')) + assert s.manager.is_connected('1', '/') + assert s.manager.is_connected('2', '/foo') def test_handle_connect_namespace(self, eio): eio.return_value.send = AsyncMock() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 8730dd8d..583f05b0 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -356,11 +356,29 @@ def test_handle_connect_with_auth_none(self, eio): s._handle_eio_connect('456', 'environ') assert s.manager.initialize.call_count == 1 - def test_handle_connect_with_bad_namespace(self, eio): + def test_handle_connect_with_default_implied_namespaces(self, eio): s = server.Server() s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') + s._handle_eio_message('123', '0/foo,') + assert s.manager.is_connected('1', '/') + assert not s.manager.is_connected('2', '/foo') + + def test_handle_connect_with_implied_namespaces(self, eio): + s = server.Server(namespaces=['/foo']) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + s._handle_eio_message('123', '0/foo,') assert not s.manager.is_connected('1', '/') + assert s.manager.is_connected('1', '/foo') + + def test_handle_connect_with_all_implied_namespaces(self, eio): + s = server.Server(namespaces='*') + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + s._handle_eio_message('123', '0/foo,') + assert s.manager.is_connected('1', '/') + assert s.manager.is_connected('2', '/foo') def test_handle_connect_namespace(self, eio): s = server.Server()