diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 99af0672..b48327f7 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -617,30 +617,22 @@ async def _handle_ack(self, eio_sid, namespace, id, data): async def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" # first see if we have an explicit handler for the event - if namespace in self.handlers: - handler = None - if event in self.handlers[namespace]: - handler = self.handlers[namespace][event] - elif event not in self.reserved_events and \ - '*' in self.handlers[namespace]: - handler = self.handlers[namespace]['*'] - args = (event, *args) - if handler: - if asyncio.iscoroutinefunction(handler): - try: - ret = await handler(*args) - except asyncio.CancelledError: # pragma: no cover - ret = None - else: - ret = handler(*args) - return ret + handler, args = self._get_event_handler(event, namespace, *args) + if handler: + if asyncio.iscoroutinefunction(handler): + try: + ret = await handler(*args) + except asyncio.CancelledError: # pragma: no cover + ret = None else: - return self.not_handled - - # or else, forward the event to a namepsace handler if one exists - elif namespace in self.namespace_handlers: # pragma: no branch - return await self.namespace_handlers[namespace].trigger_event( - event, *args) + ret = handler(*args) + return ret + # or else, forward the event to a namespace handler if one exists + handler, args = self._get_namespace_handler(namespace, *args) + if handler: + return await handler.trigger_event(event, *args) + else: + return self.not_handled async def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index f8c90003..88e59f25 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -196,6 +196,54 @@ def get_environ(self, sid, namespace=None): eio_sid = self.manager.eio_sid_from_sid(sid, namespace or '/') return self.environ.get(eio_sid) + def _get_event_handler(self, event, namespace, *args): + """Return the appropriate application event handler. + + Resolution priority: + - self.handlers[namespace][event] + - self.handlers[namespace]["*"] + - self.handlers["*"][event] + - self.handlers["*"]["*"] + """ + handler = None + if namespace in self.handlers and \ + event in self.handlers[namespace]: + handler = self.handlers[namespace][event] + elif namespace in self.handlers and \ + event not in self.reserved_events and \ + '*' in self.handlers[namespace]: + handler = self.handlers[namespace]['*'] + args = (event, *args) + elif '*' in self.handlers and \ + event in self.handlers['*']: + handler = self.handlers['*'][event] + args = (namespace, *args) + elif '*' in self.handlers and \ + event not in self.reserved_events and \ + '*' in self.handlers['*']: + handler = self.handlers['*']['*'] + args = (event, namespace, *args) + else: + handler = None + return handler, args + + def _get_namespace_handler(self, namespace, *args): + """Return the appropriate application event handler. + + Resolution priority: + - self.namespace_handlers[namespace] + - self.namespace_handlers["*"] + """ + handler = None + if namespace in self.namespace_handlers: + handler = self.namespace_handlers[namespace] + elif '*' in self.namespace_handlers: + handler = self.namespace_handlers['*'] + args = (namespace, *args) + else: + handler = None + return handler, args + def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/server.py b/src/socketio/server.py index 20813374..afc1cb81 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -604,19 +604,15 @@ def _handle_ack(self, eio_sid, namespace, id, data): def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" # first see if we have an explicit handler for the event - if namespace in self.handlers: - if event in self.handlers[namespace]: - return self.handlers[namespace][event](*args) - elif event not in self.reserved_events and \ - '*' in self.handlers[namespace]: - return self.handlers[namespace]['*'](event, *args) - else: - return self.not_handled - + handler, args = self._get_event_handler(event, namespace, *args) + if handler: + return handler(*args) # or else, forward the event to a namespace handler if one exists - elif namespace in self.namespace_handlers: # pragma: no branch - return self.namespace_handlers[namespace].trigger_event( - event, *args) + handler, args = self._get_namespace_handler(namespace, *args) + if handler: + return handler.trigger_event(event, *args) + else: + return self.not_handled def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 2f84b5ff..3a79cdea 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -621,6 +621,30 @@ def test_handle_event_with_namespace(self, eio): catchall_handler.assert_called_once_with( 'my message', sid, 'a', 'b', 'c') + def test_handle_event_with_catchall_namespace(self, eio): + eio.return_value.send = AsyncMock() + s = async_server.AsyncServer(async_handlers=False) + sid_foo = _run(s.manager.connect('123', '/foo')) + sid_bar = _run(s.manager.connect('123', '/bar')) + msg_foo_handler = mock.MagicMock() + msg_star_handler = mock.MagicMock() + star_foo_handler = mock.MagicMock() + star_star_handler = mock.MagicMock() + s.on('msg', msg_foo_handler, namespace='/foo') + s.on('msg', msg_star_handler, namespace='*') + s.on('*', star_foo_handler, namespace='/foo') + s.on('*', star_star_handler, namespace='*') + _run(s._handle_eio_message('123', '2/foo,["msg","a","b"]')) + _run(s._handle_eio_message('123', '2/bar,["msg","a","b"]')) + _run(s._handle_eio_message('123', '2/foo,["my message","a","b","c"]')) + _run(s._handle_eio_message('123', '2/bar,["my message","a","b","c"]')) + msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') + msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') + star_foo_handler.assert_called_once_with( + 'my message', sid_foo, 'a', 'b', 'c') + star_star_handler.assert_called_once_with( + 'my message', '/bar', sid_bar, 'a', 'b', 'c') + def test_handle_event_with_disconnected_namespace(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) @@ -904,6 +928,40 @@ async def on_baz(self, sid, data1, data2): _run(s.disconnect('1', '/foo')) assert result['result'] == ('disconnect', '1') + def test_catchall_namespace_handler(self, eio): + eio.return_value.send = AsyncMock() + result = {} + + class MyNamespace(async_namespace.AsyncNamespace): + def on_connect(self, ns, sid, environ): + result['result'] = (sid, ns, environ) + + async def on_disconnect(self, ns, sid): + result['result'] = ('disconnect', sid, ns) + + async def on_foo(self, ns, sid, data): + result['result'] = (sid, ns, data) + + def on_bar(self, ns, sid): + result['result'] = 'bar' + ns + + async def on_baz(self, ns, sid, data1, data2): + result['result'] = (ns, data1, data2) + + s = async_server.AsyncServer(async_handlers=False, namespaces='*') + s.register_namespace(MyNamespace('*')) + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo,')) + assert result['result'] == ('1', '/foo', 'environ') + _run(s._handle_eio_message('123', '2/foo,["foo","a"]')) + assert result['result'] == ('1', '/foo', 'a') + _run(s._handle_eio_message('123', '2/foo,["bar"]')) + assert result['result'] == 'bar/foo' + _run(s._handle_eio_message('123', '2/foo,["baz","a","b"]')) + assert result['result'] == ('/foo', 'a', 'b') + _run(s.disconnect('1', '/foo')) + assert result['result'] == ('disconnect', '1', '/foo') + def test_bad_namespace_handler(self, eio): class Dummy(object): pass diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 08c59ac8..7dc7d21a 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -574,6 +574,29 @@ def test_handle_event_with_namespace(self, eio): catchall_handler.assert_called_once_with( 'my message', '1', 'a', 'b', 'c') + def test_handle_event_with_catchall_namespace(self, eio): + s = server.Server(async_handlers=False) + sid_foo = s.manager.connect('123', '/foo') + sid_bar = s.manager.connect('123', '/bar') + msg_foo_handler = mock.MagicMock() + msg_star_handler = mock.MagicMock() + star_foo_handler = mock.MagicMock() + star_star_handler = mock.MagicMock() + s.on('msg', msg_foo_handler, namespace='/foo') + s.on('msg', msg_star_handler, namespace='*') + s.on('*', star_foo_handler, namespace='/foo') + s.on('*', star_star_handler, namespace='*') + s._handle_eio_message('123', '2/foo,["msg","a","b"]') + s._handle_eio_message('123', '2/bar,["msg","a","b"]') + s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') + s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') + msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') + msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') + star_foo_handler.assert_called_once_with( + 'my message', sid_foo, 'a', 'b', 'c') + star_star_handler.assert_called_once_with( + 'my message', '/bar', sid_bar, 'a', 'b', 'c') + def test_handle_event_with_disconnected_namespace(self, eio): s = server.Server(async_handlers=False) s.manager.connect('123', '/foo') @@ -815,6 +838,39 @@ def on_baz(self, sid, data1, data2): s.disconnect('1', '/foo') assert result['result'] == ('disconnect', '1') + def test_catchall_namespace_handler(self, eio): + result = {} + + class MyNamespace(namespace.Namespace): + def on_connect(self, ns, sid, environ): + result['result'] = (sid, ns, environ) + + def on_disconnect(self, ns, sid): + result['result'] = ('disconnect', sid, ns) + + def on_foo(self, ns, sid, data): + result['result'] = (sid, ns, data) + + def on_bar(self, ns, sid): + result['result'] = 'bar' + ns + + def on_baz(self, ns, sid, data1, data2): + result['result'] = (ns, data1, data2) + + s = server.Server(async_handlers=False, namespaces='*') + s.register_namespace(MyNamespace('*')) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo,') + assert result['result'] == ('1', '/foo', 'environ') + s._handle_eio_message('123', '2/foo,["foo","a"]') + assert result['result'] == ('1', '/foo', 'a') + s._handle_eio_message('123', '2/foo,["bar"]') + assert result['result'] == 'bar/foo' + s._handle_eio_message('123', '2/foo,["baz","a","b"]') + assert result['result'] == ('/foo', 'a', 'b') + s.disconnect('1', '/foo') + assert result['result'] == ('disconnect', '1', '/foo') + def test_bad_namespace_handler(self, eio): class Dummy(object): pass