diff --git a/scaler/io/utility.py b/scaler/io/utility.py index 59e02ebb..682909c4 100644 --- a/scaler/io/utility.py +++ b/scaler/io/utility.py @@ -11,6 +11,11 @@ from scaler.protocol.python.message import PROTOCOL from scaler.protocol.python.mixins import Message +try: + from collections.abc import Buffer # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Buffer + def get_scaler_network_backend_from_env(): backend_str = os.environ.get("SCALER_NETWORK_BACKEND", "tcp_zmq") # Default to tcp_zmq @@ -51,7 +56,7 @@ def create_sync_object_storage_connector(*args, **kwargs) -> SyncObjectStorageCo ) -def deserialize(data: bytes) -> Optional[Message]: +def deserialize(data: Buffer) -> Optional[Message]: with _message.Message.from_bytes(data, traversal_limit_in_words=CAPNP_MESSAGE_SIZE_LIMIT) as payload: if not hasattr(payload, payload.which()): logging.error(f"unknown message type: {payload.which()}") diff --git a/scaler/io/ymq/_ymq.pyi b/scaler/io/ymq/_ymq.pyi index f27e9b45..4887589f 100644 --- a/scaler/io/ymq/_ymq.pyi +++ b/scaler/io/ymq/_ymq.pyi @@ -1,13 +1,12 @@ # NOTE: NOT IMPLEMENTATION, TYPE INFORMATION ONLY # This file contains type stubs for the Ymq Python C Extension module -import sys from enum import IntEnum from typing import Callable, Optional, SupportsBytes, Union -if sys.version_info >= (3, 12): - from collections.abc import Buffer -else: - Buffer = object +try: + from collections.abc import Buffer # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Buffer class Bytes(Buffer): data: bytes | None diff --git a/scaler/io/ymq/pymod_ymq/io_socket.h b/scaler/io/ymq/pymod_ymq/io_socket.h index ab02f0cb..d2ec8e35 100644 --- a/scaler/io/ymq/pymod_ymq/io_socket.h +++ b/scaler/io/ymq/pymod_ymq/io_socket.h @@ -38,9 +38,13 @@ extern "C" { static void PyIOSocket_dealloc(PyIOSocket* self) { try { - self->ioContext->removeIOSocket(self->socket); - self->ioContext.~shared_ptr(); - self->socket.~shared_ptr(); + std::shared_ptr ioSocket = std::move(self->socket); + std::shared_ptr ioContext = std::move(self->ioContext); + + // this function blocks so it's important to release the GIL + Py_BEGIN_ALLOW_THREADS; + ioContext->removeIOSocket(ioSocket); + Py_END_ALLOW_THREADS; } catch (...) { PyErr_SetString(PyExc_RuntimeError, "Failed to deallocate IOSocket"); PyErr_WriteUnraisable((PyObject*)self); diff --git a/scaler/io/ymq/tcp_client.cpp b/scaler/io/ymq/tcp_client.cpp index f90d191c..caabdee2 100644 --- a/scaler/io/ymq/tcp_client.cpp +++ b/scaler/io/ymq/tcp_client.cpp @@ -70,6 +70,7 @@ void TcpClient::onCreated() sock->onConnectionCreated(setNoDelay(sockfd), getLocalAddr(sockfd), getRemoteAddr(sockfd), responsibleForRetry); if (_retryTimes == 0) { _onConnectReturn({}); + _onConnectReturn = {}; } return; } @@ -78,6 +79,7 @@ void TcpClient::onCreated() _eventLoopThread->_eventLoop.addFdToLoop(sockfd, EPOLLOUT | EPOLLET, this->_eventManager.get()); if (_retryTimes == 0) { _onConnectReturn(std::unexpected {Error::ErrorCode::InitialConnectFailedWithInProgress}); + _onConnectReturn = {}; } return; } @@ -202,6 +204,7 @@ void TcpClient::onCreated() if (myErrno == ERROR_IO_PENDING) { if (_retryTimes == 0) { _onConnectReturn(std::unexpected {Error::ErrorCode::InitialConnectFailedWithInProgress}); + _onConnectReturn = {}; } return; } @@ -361,6 +364,10 @@ TcpClient::~TcpClient() noexcept if (_retryTimes > 0) { _eventLoopThread->_eventLoop.cancelExecution(_retryIdentifier); } + // TODO: Do we think this is an error? See TcpServer::~TcpServer for detail. + if (_onConnectReturn) { + _onConnectReturn({}); + } } } // namespace ymq diff --git a/scaler/io/ymq/tcp_server.cpp b/scaler/io/ymq/tcp_server.cpp index fc49bdaf..4eae0f32 100644 --- a/scaler/io/ymq/tcp_server.cpp +++ b/scaler/io/ymq/tcp_server.cpp @@ -158,6 +158,7 @@ int TcpServer::createAndBindSocket() ); CloseAndZeroSocket(server_fd); _onBindReturn(std::unexpected(Error {Error::ErrorCode::SetSockOptNonFatalFailure})); + _onBindReturn = {}; return -1; } @@ -257,6 +258,7 @@ void TcpServer::onCreated() #endif // _WIN32 _onBindReturn({}); + _onBindReturn = {}; } void TcpServer::onRead() @@ -404,6 +406,13 @@ TcpServer::~TcpServer() noexcept _eventLoopThread->_eventLoop.removeFdFromLoop(_serverFd); CloseAndZeroSocket(_serverFd); } + // TODO: Do we think this is an error? In extreme cases: + // bindTo(...); + // removeIOSocket(...); + // Below callback may not be called. + if (_onBindReturn) { + _onBindReturn({}); + } } } // namespace ymq diff --git a/tests/pymod_ymq/__init__.py b/tests/pymod_ymq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pymod_ymq/test_pymod_ymq.py b/tests/pymod_ymq/test_pymod_ymq.py new file mode 100644 index 00000000..f1804d52 --- /dev/null +++ b/tests/pymod_ymq/test_pymod_ymq.py @@ -0,0 +1,147 @@ +import asyncio +import unittest +from scaler.io.ymq import ymq +from scaler.io.utility import serialize, deserialize +from scaler.protocol.python.message import TaskCancel +from scaler.utility.identifiers import TaskID + + +class TestPymodYMQ(unittest.IsolatedAsyncioTestCase): + async def test_basic(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + self.assertEqual(binder.identity, "binder") + self.assertEqual(binder.socket_type, ymq.IOSocketType.Binder) + + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + self.assertEqual(connector.identity, "connector") + self.assertEqual(connector.socket_type, ymq.IOSocketType.Connector) + + address = "tcp://127.0.0.1:35793" + await binder.bind(address) + await connector.connect(address) + + await connector.send(ymq.Message(address=None, payload=b"payload")) + msg = await binder.recv() + + assert msg.address is not None + self.assertEqual(msg.address.data, b"connector") + self.assertEqual(msg.payload.data, b"payload") + + async def test_no_address(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + + address = "tcp://127.0.0.1:35794" + await binder.bind(address) + await connector.connect(address) + + with self.assertRaises(ymq.YMQException) as exc: + await binder.send(ymq.Message(address=None, payload=b"payload")) + self.assertEqual(exc.exception.code, ymq.ErrorCode.BinderSendMessageWithNoAddress) + + async def test_routing(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + connector1 = await ctx.createIOSocket("connector1", ymq.IOSocketType.Connector) + connector2 = await ctx.createIOSocket("connector2", ymq.IOSocketType.Connector) + + address = "tcp://127.0.0.1:35795" + await binder.bind(address) + await connector1.connect(address) + await connector2.connect(address) + + await binder.send(ymq.Message(b"connector2", b"2")) + await binder.send(ymq.Message(b"connector1", b"1")) + + msg1 = await connector1.recv() + self.assertEqual(msg1.payload.data, b"1") + + msg2 = await connector2.recv() + self.assertEqual(msg2.payload.data, b"2") + + async def test_pingpong(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + + address = "tcp://127.0.0.1:35791" + await binder.bind(address) + await connector.connect(address) + + async def binder_routine(binder: ymq.IOSocket, limit: int) -> bool: + i = 0 + while i < limit: + await binder.send(ymq.Message(address=b"connector", payload=f"{i}".encode())) + msg = await binder.recv() + assert msg.payload.data is not None + + recv_i = int(msg.payload.data.decode()) + if recv_i - i > 1: + return False + i = recv_i + 1 + return True + + async def connector_routine(connector: ymq.IOSocket, limit: int) -> bool: + i = 0 + while True: + msg = await connector.recv() + assert msg.payload.data is not None + recv_i = int(msg.payload.data.decode()) + if recv_i - i > 1: + return False + i = recv_i + 1 + await connector.send(ymq.Message(address=None, payload=f"{i}".encode())) + + # when the connector sends `limit - 1`, we're done + if i >= limit - 1: + break + return True + + binder_success, connector_success = await asyncio.gather( + binder_routine(binder, 100), connector_routine(connector, 100) + ) + + if not binder_success: + self.fail("binder failed") + + if not connector_success: + self.fail("connector failed") + + async def test_big_message(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + self.assertEqual(binder.identity, "binder") + self.assertEqual(binder.socket_type, ymq.IOSocketType.Binder) + + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + self.assertEqual(connector.identity, "connector") + self.assertEqual(connector.socket_type, ymq.IOSocketType.Connector) + + address = "tcp://127.0.0.1:35792" + await binder.bind(address) + await connector.connect(address) + + for _ in range(10): + await connector.send(ymq.Message(address=None, payload=b"." * 500_000_000)) + msg = await binder.recv() + + assert msg.address is not None + self.assertEqual(msg.address.data, b"connector") + self.assertEqual(msg.payload.data, b"." * 500_000_000) + + async def test_buffer_interface(self): + msg = TaskCancel.new_msg(TaskID.generate_task_id()) + data = serialize(msg) + + # verify that capnp can deserialize this data + _ = deserialize(data) + + # this creates a copy of the data + copy = ymq.Bytes(data) + + # this should deserialize without creating a copy + # because ymq.Bytes uses the buffer protocol + deserialized: TaskCancel = deserialize(copy) # type: ignore + self.assertEqual(deserialized.task_id, msg.task_id) diff --git a/tests/pymod_ymq/test_types.py b/tests/pymod_ymq/test_types.py new file mode 100644 index 00000000..996110b0 --- /dev/null +++ b/tests/pymod_ymq/test_types.py @@ -0,0 +1,88 @@ +import unittest +from enum import IntEnum +from scaler.io.ymq import ymq +import array + + +class TestTypes(unittest.TestCase): + def test_exception(self): + # type checkers misidentify this as "unnecessary" due to the type hints file + self.assertTrue(issubclass(ymq.YMQException, Exception)) # type: ignore + + exc = ymq.YMQException(ymq.ErrorCode.CoreBug, "oh no") + self.assertEqual(exc.args, (ymq.ErrorCode.CoreBug, "oh no")) + self.assertEqual(exc.code, ymq.ErrorCode.CoreBug) + self.assertEqual(exc.message, "oh no") + + def test_error_code(self): + self.assertTrue(issubclass(ymq.ErrorCode, IntEnum)) # type: ignore + self.assertEqual( + ymq.ErrorCode.ConfigurationError.explanation(), + "An error generated by system call that's likely due to mis-configuration", + ) + + def test_bytes(self): + b = ymq.Bytes(b"data") + self.assertEqual(b.len, len(b)) + self.assertEqual(b.len, 4) + self.assertEqual(b.data, b"data") + + # would raise an exception if ymq.Bytes didn't support the buffer interface + m = memoryview(b) + self.assertTrue(m.obj is b) + self.assertEqual(m.tobytes(), b"data") + + b = ymq.Bytes() + self.assertEqual(b.len, 0) + self.assertTrue(b.data is None) + + b = ymq.Bytes(b"") + self.assertEqual(b.len, 0) + self.assertEqual(b.data, b"") + + b = ymq.Bytes(array.array("B", [115, 99, 97, 108, 101, 114])) + assert b.len == 6 + assert b.data == b"scaler" + + def test_message(self): + m = ymq.Message(b"address", b"payload") + assert m.address is not None + self.assertEqual(m.address.data, b"address") + self.assertEqual(m.payload.data, b"payload") + + m = ymq.Message(address=None, payload=ymq.Bytes(b"scaler")) + self.assertTrue(m.address is None) + self.assertEqual(m.payload.data, b"scaler") + + m = ymq.Message(b"address", payload=b"payload") + assert m.address is not None + self.assertEqual(m.address.data, b"address") + self.assertEqual(m.payload.data, b"payload") + + def test_io_context(self): + ctx = ymq.IOContext() + self.assertEqual(ctx.num_threads, 1) + + ctx = ymq.IOContext(2) + self.assertEqual(ctx.num_threads, 2) + + ctx = ymq.IOContext(num_threads=3) + self.assertEqual(ctx.num_threads, 3) + + # TODO: backporting to 3.8 broke this somehow + # it causes a segmentation fault + # re-enable once fixed + @unittest.skip("causes segmentation fault") + def test_io_socket(self): + # check that we can't create io socket instances directly + self.assertRaises(TypeError, lambda: ymq.IOSocket()) # type: ignore + + def test_io_socket_type(self): + self.assertTrue(issubclass(ymq.IOSocketType, IntEnum)) # type: ignore + + def test_bad_socket_type(self): + ctx = ymq.IOContext() + + # TODO: should the core reject this? + socket = ctx.createIOSocket_sync("identity", ymq.IOSocketType.Uninit) + self.assertEqual(socket.socket_type, ymq.IOSocketType.Uninit)