diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index 1625b13e..c3b8591e 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -187,15 +187,18 @@ def rooms(self, sid, namespace=None): namespace = namespace or '/' return self.manager.get_rooms(sid, namespace) - def transport(self, sid): + def transport(self, sid, namespace=None): """Return the name of the transport used by the client. The two possible values returned by this function are ``'polling'`` and ``'websocket'``. :param sid: The session of the client. + :param namespace: The Socket.IO namespace. If this argument is omitted + the default namespace is used. """ - return self.eio.transport(sid) + eio_sid = self.manager.eio_sid_from_sid(sid, namespace or '/') + return self.eio.transport(eio_sid) def get_environ(self, sid, namespace=None): """Return the WSGI environ dictionary for a client. diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 256545eb..471e562a 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -307,9 +307,9 @@ def test_transport(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer() s.eio.transport = mock.MagicMock(return_value='polling') - _run(s._handle_eio_connect('foo', 'environ')) - assert s.transport('foo') == 'polling' - s.eio.transport.assert_called_once_with('foo') + sid_foo = _run(s.manager.connect('123', '/foo')) + assert s.transport(sid_foo, '/foo') == 'polling' + s.eio.transport.assert_called_once_with('123') def test_handle_connect(self, eio): eio.return_value.send = AsyncMock() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 8f3a356b..b568c683 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -292,9 +292,9 @@ def test_send_eio_packet(self, eio): def test_transport(self, eio): s = server.Server() s.eio.transport = mock.MagicMock(return_value='polling') - s._handle_eio_connect('foo', 'environ') - assert s.transport('foo') == 'polling' - s.eio.transport.assert_called_once_with('foo') + sid_foo = s.manager.connect('123', '/foo') + assert s.transport(sid_foo, '/foo') == 'polling' + s.eio.transport.assert_called_once_with('123') def test_handle_connect(self, eio): s = server.Server()