diff --git a/src/appose/types.py b/src/appose/types.py index dbd877e..363ea7e 100644 --- a/src/appose/types.py +++ b/src/appose/types.py @@ -30,13 +30,65 @@ import json import re from math import ceil, prod -from multiprocessing import resource_tracker -from multiprocessing.shared_memory import SharedMemory +from multiprocessing import resource_tracker, shared_memory from typing import Any, Dict, Sequence, Union Args = Dict[str, Any] +class SharedMemory(shared_memory.SharedMemory): + """ + An enhanced version of Python's multiprocessing.shared_memory.SharedMemory + class which can be used with a `with` statement. When the program flow + exits the `with` block, this class's `dispose()` method will be invoked, + which might call `close()` or `unlink()` depending on the value of its + `unlink_on_dispose` flag. + """ + + def __init__(self, name: str = None, create: bool = False, size: int = 0): + super().__init__(name=name, create=create, size=size) + self._unlink_on_dispose = create + if _is_worker: + # HACK: Remove this shared memory block from the resource_tracker, + # which wants to clean up shared memory blocks after all known + # references are done using them. + # + # There is one resource_tracker per Python process, and they will + # each try to delete shared memory blocks known to them when they + # are shutting down, even when other processes still need them. + # + # As such, the rule Appose follows is: let the service process + # always handle cleanup of shared memory blocks, regardless of + # which process initially allocated it. + resource_tracker.unregister(self._name, "shared_memory") + + def unlink_on_dispose(self, value: bool) -> None: + """ + Set whether the `unlink()` method should be invoked to destroy + the shared memory block when the `dispose()` method is called. + + Note: dispose() is the method called when exiting a `with` block. + + By default, shared memory objects constructed with `create=True` + will behave this way, whereas shared memory objects constructed + with `create=False` will not. But this method allows to override + the behavior. + """ + self._unlink_on_dispose = value + + def dispose(self) -> None: + if self._unlink_on_dispose: + self.unlink() + else: + self.close() + + def __enter__(self) -> "SharedMemory": + return self + + def __exit__(self, exc_type, exc_value, exc_tb) -> None: + self.dispose() + + def encode(data: Args) -> str: return json.dumps(data, cls=_ApposeJSONEncoder, separators=(",", ":")) @@ -63,7 +115,9 @@ def __init__(self, dtype: str, shape: Sequence[int], shm: SharedMemory = None): self.dtype = dtype self.shape = shape self.shm = ( - _create_shm(create=True, size=ceil(prod(shape) * _bytes_per_element(dtype))) + SharedMemory( + create=True, size=ceil(prod(shape) * _bytes_per_element(dtype)) + ) if shm is None else shm ) @@ -91,6 +145,12 @@ def ndarray(self): except ModuleNotFoundError: raise ImportError("NumPy is not available.") + def __enter__(self) -> "NDArray": + return self + + def __exit__(self, exc_type, exc_value, exc_tb) -> None: + self.shm.dispose() + class _ApposeJSONEncoder(json.JSONEncoder): def default(self, obj): @@ -114,7 +174,7 @@ def _appose_object_hook(obj: Dict): atype = obj.get("appose_type") if atype == "shm": # Attach to existing shared memory block. - return _create_shm(name=(obj["name"]), size=(obj["size"])) + return SharedMemory(name=(obj["name"]), size=(obj["size"])) elif atype == "ndarray": return NDArray(obj["dtype"], obj["shape"], obj["shm"]) else: @@ -129,23 +189,6 @@ def _bytes_per_element(dtype: str) -> Union[int, float]: return bits / 8 -def _create_shm(name: str = None, create: bool = False, size: int = 0): - shm = SharedMemory(name=name, create=create, size=size) - if _is_worker: - # HACK: Disable this process's resource_tracker, which wants to clean up - # shared memory blocks after all known references are done using them. - # - # There is one resource_tracker per Python process, and they will each - # try to delete shared memory blocks known to them when they are - # shutting down, even when other processes still need them. - # - # As such, the rule Appose follows is: let the service process always - # do the cleanup of shared memory blocks, regardless of which process - # initially allocated it. - resource_tracker.unregister(shm._name, "shared_memory") - return shm - - _is_worker = False diff --git a/tests/test_shm.py b/tests/test_shm.py index 965e398..68dcf8b 100644 --- a/tests/test_shm.py +++ b/tests/test_shm.py @@ -41,23 +41,20 @@ def test_ndarray(): env = appose.system() with env.python() as service: - # Construct the data. - shm = appose.SharedMemory(create=True, size=2 * 2 * 20 * 25) - shm.buf[0] = 123 - shm.buf[456] = 78 - shm.buf[1999] = 210 - data = appose.NDArray("uint16", [2, 20, 25], shm) + with appose.SharedMemory(create=True, size=2 * 2 * 20 * 25) as shm: + # Construct the data. + shm.buf[0] = 123 + shm.buf[456] = 78 + shm.buf[1999] = 210 + data = appose.NDArray("uint16", [2, 20, 25], shm) - # Run the task. - task = service.task(ndarray_inspect, {"data": data}) - task.wait_for() + # Run the task. + task = service.task(ndarray_inspect, {"data": data}) + task.wait_for() - # Validate the execution result. - assert TaskStatus.COMPLETE == task.status - assert 2 * 20 * 25 * 2 == task.outputs["size"] - assert "uint16" == task.outputs["dtype"] - assert [2, 20, 25] == task.outputs["shape"] - assert 123 + 78 + 210 == task.outputs["sum"] - - # Clean up. - shm.unlink() + # Validate the execution result. + assert TaskStatus.COMPLETE == task.status + assert 2 * 20 * 25 * 2 == task.outputs["size"] + assert "uint16" == task.outputs["dtype"] + assert [2, 20, 25] == task.outputs["shape"] + assert 123 + 78 + 210 == task.outputs["sum"] diff --git a/tests/test_types.py b/tests/test_types.py index 445c749..dc38d23 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -65,40 +65,38 @@ def test_encode(self): "numbers": self.NUMBERS, "words": self.WORDS, } - ndarray = appose.NDArray("float32", [2, 20, 25]) - shm_name = ndarray.shm.name - data["ndArray"] = ndarray - json_str = appose.types.encode(data) - self.assertIsNotNone(json_str) - expected = self.JSON.replace("SHM_NAME", shm_name) - self.assertEqual(expected, json_str) - ndarray.shm.unlink() + with appose.NDArray("float32", [2, 20, 25]) as ndarray: + shm_name = ndarray.shm.name + data["ndArray"] = ndarray + json_str = appose.types.encode(data) + self.assertIsNotNone(json_str) + expected = self.JSON.replace("SHM_NAME", shm_name) + self.assertEqual(expected, json_str) def test_decode(self): - shm = appose.SharedMemory(create=True, size=4000) - shm_name = shm.name - data = appose.types.decode(self.JSON.replace("SHM_NAME", shm_name)) - self.assertIsNotNone(data) - self.assertEqual(19, len(data)) - self.assertEqual(123, data["posByte"]) - self.assertEqual(-98, data["negByte"]) - self.assertEqual(9.876543210123456, data["posDouble"]) - self.assertEqual(-1.234567890987654e302, data["negDouble"]) - self.assertEqual(9.876543, data["posFloat"]) - self.assertEqual(-1.2345678, data["negFloat"]) - self.assertEqual(1234567890, data["posInt"]) - self.assertEqual(-987654321, data["negInt"]) - self.assertEqual(12345678987654321, data["posLong"]) - self.assertEqual(-98765432123456789, data["negLong"]) - self.assertEqual(32109, data["posShort"]) - self.assertEqual(-23456, data["negShort"]) - self.assertTrue(data["trueBoolean"]) - self.assertFalse(data["falseBoolean"]) - self.assertEqual("\0", data["nullChar"]) - self.assertEqual(self.STRING, data["aString"]) - self.assertEqual(self.NUMBERS, data["numbers"]) - self.assertEqual(self.WORDS, data["words"]) - ndArray = data["ndArray"] - self.assertEqual("float32", ndArray.dtype) - self.assertEqual([2, 20, 25], ndArray.shape) - shm.unlink() + with appose.SharedMemory(create=True, size=4000) as shm: + shm_name = shm.name + data = appose.types.decode(self.JSON.replace("SHM_NAME", shm_name)) + self.assertIsNotNone(data) + self.assertEqual(19, len(data)) + self.assertEqual(123, data["posByte"]) + self.assertEqual(-98, data["negByte"]) + self.assertEqual(9.876543210123456, data["posDouble"]) + self.assertEqual(-1.234567890987654e302, data["negDouble"]) + self.assertEqual(9.876543, data["posFloat"]) + self.assertEqual(-1.2345678, data["negFloat"]) + self.assertEqual(1234567890, data["posInt"]) + self.assertEqual(-987654321, data["negInt"]) + self.assertEqual(12345678987654321, data["posLong"]) + self.assertEqual(-98765432123456789, data["negLong"]) + self.assertEqual(32109, data["posShort"]) + self.assertEqual(-23456, data["negShort"]) + self.assertTrue(data["trueBoolean"]) + self.assertFalse(data["falseBoolean"]) + self.assertEqual("\0", data["nullChar"]) + self.assertEqual(self.STRING, data["aString"]) + self.assertEqual(self.NUMBERS, data["numbers"]) + self.assertEqual(self.WORDS, data["words"]) + ndArray = data["ndArray"] + self.assertEqual("float32", ndArray.dtype) + self.assertEqual([2, 20, 25], ndArray.shape)