Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fb23fc6
Add YMQ Python Tests
magniloquency Oct 10, 2025
e904792
Fix no_address test
magniloquency Oct 10, 2025
a47f50e
Remove interrupted exception test
magniloquency Oct 10, 2025
a3b5d4f
Fix typing, lint
magniloquency Oct 10, 2025
bb81105
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 15, 2025
584fb01
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 16, 2025
931a9a6
Fix linter
magniloquency Oct 16, 2025
7b1f497
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 17, 2025
a5933e2
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 17, 2025
f0448bb
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 17, 2025
442b090
refactor addresses
magniloquency Oct 20, 2025
f385550
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 20, 2025
7ca26e2
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 20, 2025
b83529d
Fix GIL
magniloquency Oct 21, 2025
9ae0a6a
Apply gxu's patch
magniloquency Oct 21, 2025
b5ee3e3
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 21, 2025
0c8aed2
delete assert
magniloquency Oct 21, 2025
897c5a5
Merge branch 'ymq-pymod-tests-3' of https://github.com/magniloquency/…
magniloquency Oct 21, 2025
40846e8
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 21, 2025
885aa3f
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 21, 2025
9668062
remove print include
magniloquency Oct 21, 2025
d515c6d
remove cassert include
magniloquency Oct 21, 2025
d457b03
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 22, 2025
670e34b
Merge branch 'main' into ymq-pymod-tests-3
magniloquency Oct 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion scaler/io/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from scaler.protocol.python.message import PROTOCOL
from scaler.protocol.python.mixins import Message

try:
from collections.abc import Buffer # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Buffer


def get_scaler_network_backend_from_env():
backend_str = os.environ.get("SCALER_NETWORK_BACKEND", "tcp_zmq") # Default to tcp_zmq
Expand Down Expand Up @@ -51,7 +56,7 @@ def create_sync_object_storage_connector(*args, **kwargs) -> SyncObjectStorageCo
)


def deserialize(data: bytes) -> Optional[Message]:
def deserialize(data: Buffer) -> Optional[Message]:
with _message.Message.from_bytes(data, traversal_limit_in_words=CAPNP_MESSAGE_SIZE_LIMIT) as payload:
if not hasattr(payload, payload.which()):
logging.error(f"unknown message type: {payload.which()}")
Expand Down
9 changes: 4 additions & 5 deletions scaler/io/ymq/_ymq.pyi
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# NOTE: NOT IMPLEMENTATION, TYPE INFORMATION ONLY
# This file contains type stubs for the Ymq Python C Extension module
import sys
from enum import IntEnum
from typing import Callable, Optional, SupportsBytes, Union

if sys.version_info >= (3, 12):
from collections.abc import Buffer
else:
Buffer = object
try:
from collections.abc import Buffer # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Buffer

class Bytes(Buffer):
data: bytes | None
Expand Down
10 changes: 7 additions & 3 deletions scaler/io/ymq/pymod_ymq/io_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ extern "C" {
static void PyIOSocket_dealloc(PyIOSocket* self)
{
try {
self->ioContext->removeIOSocket(self->socket);
self->ioContext.~shared_ptr();
self->socket.~shared_ptr();
std::shared_ptr<IOSocket> ioSocket = std::move(self->socket);
std::shared_ptr<IOContext> ioContext = std::move(self->ioContext);

// this function blocks so it's important to release the GIL
Py_BEGIN_ALLOW_THREADS;
ioContext->removeIOSocket(ioSocket);
Py_END_ALLOW_THREADS;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "Failed to deallocate IOSocket");
PyErr_WriteUnraisable((PyObject*)self);
Expand Down
7 changes: 7 additions & 0 deletions scaler/io/ymq/tcp_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ void TcpClient::onCreated()
sock->onConnectionCreated(setNoDelay(sockfd), getLocalAddr(sockfd), getRemoteAddr(sockfd), responsibleForRetry);
if (_retryTimes == 0) {
_onConnectReturn({});
_onConnectReturn = {};
}
return;
}
Expand All @@ -78,6 +79,7 @@ void TcpClient::onCreated()
_eventLoopThread->_eventLoop.addFdToLoop(sockfd, EPOLLOUT | EPOLLET, this->_eventManager.get());
if (_retryTimes == 0) {
_onConnectReturn(std::unexpected {Error::ErrorCode::InitialConnectFailedWithInProgress});
_onConnectReturn = {};
}
return;
}
Expand Down Expand Up @@ -202,6 +204,7 @@ void TcpClient::onCreated()
if (myErrno == ERROR_IO_PENDING) {
if (_retryTimes == 0) {
_onConnectReturn(std::unexpected {Error::ErrorCode::InitialConnectFailedWithInProgress});
_onConnectReturn = {};
}
return;
}
Expand Down Expand Up @@ -361,6 +364,10 @@ TcpClient::~TcpClient() noexcept
if (_retryTimes > 0) {
_eventLoopThread->_eventLoop.cancelExecution(_retryIdentifier);
}
// TODO: Do we think this is an error? See TcpServer::~TcpServer for detail.
if (_onConnectReturn) {
_onConnectReturn({});
}
}

} // namespace ymq
Expand Down
9 changes: 9 additions & 0 deletions scaler/io/ymq/tcp_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ int TcpServer::createAndBindSocket()
);
CloseAndZeroSocket(server_fd);
_onBindReturn(std::unexpected(Error {Error::ErrorCode::SetSockOptNonFatalFailure}));
_onBindReturn = {};
return -1;
}

Expand Down Expand Up @@ -257,6 +258,7 @@ void TcpServer::onCreated()
#endif // _WIN32

_onBindReturn({});
_onBindReturn = {};
}

void TcpServer::onRead()
Expand Down Expand Up @@ -404,6 +406,13 @@ TcpServer::~TcpServer() noexcept
_eventLoopThread->_eventLoop.removeFdFromLoop(_serverFd);
CloseAndZeroSocket(_serverFd);
}
// TODO: Do we think this is an error? In extreme cases:
// bindTo(...);
// removeIOSocket(...);
// Below callback may not be called.
if (_onBindReturn) {
_onBindReturn({});
}
}

} // namespace ymq
Expand Down
Empty file added tests/pymod_ymq/__init__.py
Empty file.
147 changes: 147 additions & 0 deletions tests/pymod_ymq/test_pymod_ymq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import asyncio
import unittest
from scaler.io.ymq import ymq
from scaler.io.utility import serialize, deserialize
from scaler.protocol.python.message import TaskCancel
from scaler.utility.identifiers import TaskID


class TestPymodYMQ(unittest.IsolatedAsyncioTestCase):
async def test_basic(self):
ctx = ymq.IOContext()
binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder)
self.assertEqual(binder.identity, "binder")
self.assertEqual(binder.socket_type, ymq.IOSocketType.Binder)

connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector)
self.assertEqual(connector.identity, "connector")
self.assertEqual(connector.socket_type, ymq.IOSocketType.Connector)

address = "tcp://127.0.0.1:35793"
await binder.bind(address)
await connector.connect(address)

await connector.send(ymq.Message(address=None, payload=b"payload"))
msg = await binder.recv()

assert msg.address is not None
self.assertEqual(msg.address.data, b"connector")
self.assertEqual(msg.payload.data, b"payload")

async def test_no_address(self):
ctx = ymq.IOContext()
binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder)
connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector)

address = "tcp://127.0.0.1:35794"
await binder.bind(address)
await connector.connect(address)

with self.assertRaises(ymq.YMQException) as exc:
await binder.send(ymq.Message(address=None, payload=b"payload"))
self.assertEqual(exc.exception.code, ymq.ErrorCode.BinderSendMessageWithNoAddress)

async def test_routing(self):
ctx = ymq.IOContext()
binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder)
connector1 = await ctx.createIOSocket("connector1", ymq.IOSocketType.Connector)
connector2 = await ctx.createIOSocket("connector2", ymq.IOSocketType.Connector)

address = "tcp://127.0.0.1:35795"
await binder.bind(address)
await connector1.connect(address)
await connector2.connect(address)

await binder.send(ymq.Message(b"connector2", b"2"))
await binder.send(ymq.Message(b"connector1", b"1"))

msg1 = await connector1.recv()
self.assertEqual(msg1.payload.data, b"1")

msg2 = await connector2.recv()
self.assertEqual(msg2.payload.data, b"2")

async def test_pingpong(self):
ctx = ymq.IOContext()
binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder)
connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector)

address = "tcp://127.0.0.1:35791"
await binder.bind(address)
await connector.connect(address)

async def binder_routine(binder: ymq.IOSocket, limit: int) -> bool:
i = 0
while i < limit:
await binder.send(ymq.Message(address=b"connector", payload=f"{i}".encode()))
msg = await binder.recv()
assert msg.payload.data is not None

recv_i = int(msg.payload.data.decode())
if recv_i - i > 1:
return False
i = recv_i + 1
return True

async def connector_routine(connector: ymq.IOSocket, limit: int) -> bool:
i = 0
while True:
msg = await connector.recv()
assert msg.payload.data is not None
recv_i = int(msg.payload.data.decode())
if recv_i - i > 1:
return False
i = recv_i + 1
await connector.send(ymq.Message(address=None, payload=f"{i}".encode()))

# when the connector sends `limit - 1`, we're done
if i >= limit - 1:
break
return True

binder_success, connector_success = await asyncio.gather(
binder_routine(binder, 100), connector_routine(connector, 100)
)

if not binder_success:
self.fail("binder failed")

if not connector_success:
self.fail("connector failed")

async def test_big_message(self):
ctx = ymq.IOContext()
binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder)
self.assertEqual(binder.identity, "binder")
self.assertEqual(binder.socket_type, ymq.IOSocketType.Binder)

connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector)
self.assertEqual(connector.identity, "connector")
self.assertEqual(connector.socket_type, ymq.IOSocketType.Connector)

address = "tcp://127.0.0.1:35792"
await binder.bind(address)
await connector.connect(address)

for _ in range(10):
await connector.send(ymq.Message(address=None, payload=b"." * 500_000_000))
msg = await binder.recv()

assert msg.address is not None
self.assertEqual(msg.address.data, b"connector")
self.assertEqual(msg.payload.data, b"." * 500_000_000)

async def test_buffer_interface(self):
msg = TaskCancel.new_msg(TaskID.generate_task_id())
data = serialize(msg)

# verify that capnp can deserialize this data
_ = deserialize(data)

# this creates a copy of the data
copy = ymq.Bytes(data)

# this should deserialize without creating a copy
# because ymq.Bytes uses the buffer protocol
deserialized: TaskCancel = deserialize(copy) # type: ignore
self.assertEqual(deserialized.task_id, msg.task_id)
88 changes: 88 additions & 0 deletions tests/pymod_ymq/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import unittest
from enum import IntEnum
from scaler.io.ymq import ymq
import array


class TestTypes(unittest.TestCase):
def test_exception(self):
# type checkers misidentify this as "unnecessary" due to the type hints file
self.assertTrue(issubclass(ymq.YMQException, Exception)) # type: ignore

exc = ymq.YMQException(ymq.ErrorCode.CoreBug, "oh no")
self.assertEqual(exc.args, (ymq.ErrorCode.CoreBug, "oh no"))
self.assertEqual(exc.code, ymq.ErrorCode.CoreBug)
self.assertEqual(exc.message, "oh no")

def test_error_code(self):
self.assertTrue(issubclass(ymq.ErrorCode, IntEnum)) # type: ignore
self.assertEqual(
ymq.ErrorCode.ConfigurationError.explanation(),
"An error generated by system call that's likely due to mis-configuration",
)

def test_bytes(self):
b = ymq.Bytes(b"data")
self.assertEqual(b.len, len(b))
self.assertEqual(b.len, 4)
self.assertEqual(b.data, b"data")

# would raise an exception if ymq.Bytes didn't support the buffer interface
m = memoryview(b)
self.assertTrue(m.obj is b)
self.assertEqual(m.tobytes(), b"data")

b = ymq.Bytes()
self.assertEqual(b.len, 0)
self.assertTrue(b.data is None)

b = ymq.Bytes(b"")
self.assertEqual(b.len, 0)
self.assertEqual(b.data, b"")

b = ymq.Bytes(array.array("B", [115, 99, 97, 108, 101, 114]))
assert b.len == 6
assert b.data == b"scaler"

def test_message(self):
m = ymq.Message(b"address", b"payload")
assert m.address is not None
self.assertEqual(m.address.data, b"address")
self.assertEqual(m.payload.data, b"payload")

m = ymq.Message(address=None, payload=ymq.Bytes(b"scaler"))
self.assertTrue(m.address is None)
self.assertEqual(m.payload.data, b"scaler")

m = ymq.Message(b"address", payload=b"payload")
assert m.address is not None
self.assertEqual(m.address.data, b"address")
self.assertEqual(m.payload.data, b"payload")

def test_io_context(self):
ctx = ymq.IOContext()
self.assertEqual(ctx.num_threads, 1)

ctx = ymq.IOContext(2)
self.assertEqual(ctx.num_threads, 2)

ctx = ymq.IOContext(num_threads=3)
self.assertEqual(ctx.num_threads, 3)

# TODO: backporting to 3.8 broke this somehow
# it causes a segmentation fault
# re-enable once fixed
@unittest.skip("causes segmentation fault")
def test_io_socket(self):
# check that we can't create io socket instances directly
self.assertRaises(TypeError, lambda: ymq.IOSocket()) # type: ignore

def test_io_socket_type(self):
self.assertTrue(issubclass(ymq.IOSocketType, IntEnum)) # type: ignore

def test_bad_socket_type(self):
ctx = ymq.IOContext()

# TODO: should the core reject this?
socket = ctx.createIOSocket_sync("identity", ymq.IOSocketType.Uninit)
self.assertEqual(socket.socket_type, ymq.IOSocketType.Uninit)
Loading