|
28 | 28 | ###
|
29 | 29 |
|
30 | 30 | import json
|
31 |
| -from typing import Any, Dict |
| 31 | +import re |
| 32 | +from math import ceil, prod |
| 33 | +from multiprocessing import resource_tracker, shared_memory |
| 34 | +from typing import Any, Dict, Sequence, Union |
32 | 35 |
|
33 | 36 | Args = Dict[str, Any]
|
34 | 37 |
|
35 | 38 |
|
| 39 | +class SharedMemory(shared_memory.SharedMemory): |
| 40 | + """ |
| 41 | + An enhanced version of Python's multiprocessing.shared_memory.SharedMemory |
| 42 | + class which can be used with a `with` statement. When the program flow |
| 43 | + exits the `with` block, this class's `dispose()` method will be invoked, |
| 44 | + which might call `close()` or `unlink()` depending on the value of its |
| 45 | + `unlink_on_dispose` flag. |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__(self, name: str = None, create: bool = False, size: int = 0): |
| 49 | + super().__init__(name=name, create=create, size=size) |
| 50 | + self._unlink_on_dispose = create |
| 51 | + if _is_worker: |
| 52 | + # HACK: Remove this shared memory block from the resource_tracker, |
| 53 | + # which wants to clean up shared memory blocks after all known |
| 54 | + # references are done using them. |
| 55 | + # |
| 56 | + # There is one resource_tracker per Python process, and they will |
| 57 | + # each try to delete shared memory blocks known to them when they |
| 58 | + # are shutting down, even when other processes still need them. |
| 59 | + # |
| 60 | + # As such, the rule Appose follows is: let the service process |
| 61 | + # always handle cleanup of shared memory blocks, regardless of |
| 62 | + # which process initially allocated it. |
| 63 | + resource_tracker.unregister(self._name, "shared_memory") |
| 64 | + |
| 65 | + def unlink_on_dispose(self, value: bool) -> None: |
| 66 | + """ |
| 67 | + Set whether the `unlink()` method should be invoked to destroy |
| 68 | + the shared memory block when the `dispose()` method is called. |
| 69 | +
|
| 70 | + Note: dispose() is the method called when exiting a `with` block. |
| 71 | +
|
| 72 | + By default, shared memory objects constructed with `create=True` |
| 73 | + will behave this way, whereas shared memory objects constructed |
| 74 | + with `create=False` will not. But this method allows to override |
| 75 | + the behavior. |
| 76 | + """ |
| 77 | + self._unlink_on_dispose = value |
| 78 | + |
| 79 | + def dispose(self) -> None: |
| 80 | + if self._unlink_on_dispose: |
| 81 | + self.unlink() |
| 82 | + else: |
| 83 | + self.close() |
| 84 | + |
| 85 | + def __enter__(self) -> "SharedMemory": |
| 86 | + return self |
| 87 | + |
| 88 | + def __exit__(self, exc_type, exc_value, exc_tb) -> None: |
| 89 | + self.dispose() |
| 90 | + |
| 91 | + |
36 | 92 | def encode(data: Args) -> str:
|
37 |
| - return json.dumps(data) |
| 93 | + return json.dumps(data, cls=_ApposeJSONEncoder, separators=(",", ":")) |
38 | 94 |
|
39 | 95 |
|
40 | 96 | def decode(the_json: str) -> Args:
|
41 |
| - return json.loads(the_json) |
| 97 | + return json.loads(the_json, object_hook=_appose_object_hook) |
| 98 | + |
| 99 | + |
| 100 | +class NDArray: |
| 101 | + """ |
| 102 | + Data structure for a multi-dimensional array. |
| 103 | + The array contains elements of a data type, arranged in |
| 104 | + a particular shape, and flattened into SharedMemory. |
| 105 | + """ |
| 106 | + |
| 107 | + def __init__(self, dtype: str, shape: Sequence[int], shm: SharedMemory = None): |
| 108 | + """ |
| 109 | + Create an NDArray. |
| 110 | + :param dtype: The type of the data elements; e.g. int8, uint8, float32, float64. |
| 111 | + :param shape: The dimensional extents; e.g. a stack of 7 image planes |
| 112 | + with resolution 512x512 would have shape [7, 512, 512]. |
| 113 | + :param shm: The SharedMemory containing the array data, or None to create it. |
| 114 | + """ |
| 115 | + self.dtype = dtype |
| 116 | + self.shape = shape |
| 117 | + self.shm = ( |
| 118 | + SharedMemory( |
| 119 | + create=True, size=ceil(prod(shape) * _bytes_per_element(dtype)) |
| 120 | + ) |
| 121 | + if shm is None |
| 122 | + else shm |
| 123 | + ) |
| 124 | + |
| 125 | + def __str__(self): |
| 126 | + return ( |
| 127 | + f"NDArray(" |
| 128 | + f"dtype='{self.dtype}', " |
| 129 | + f"shape={self.shape}, " |
| 130 | + f"shm='{self.shm.name}' ({self.shm.size}))" |
| 131 | + ) |
| 132 | + |
| 133 | + def ndarray(self): |
| 134 | + """ |
| 135 | + Create a NumPy ndarray object for working with the array data. |
| 136 | + No array data is copied; the NumPy array wraps the same SharedMemory. |
| 137 | + Requires the numpy package to be installed. |
| 138 | + """ |
| 139 | + try: |
| 140 | + import numpy |
| 141 | + |
| 142 | + return numpy.ndarray( |
| 143 | + prod(self.shape), dtype=self.dtype, buffer=self.shm.buf |
| 144 | + ).reshape(self.shape) |
| 145 | + except ModuleNotFoundError: |
| 146 | + raise ImportError("NumPy is not available.") |
| 147 | + |
| 148 | + def __enter__(self) -> "NDArray": |
| 149 | + return self |
| 150 | + |
| 151 | + def __exit__(self, exc_type, exc_value, exc_tb) -> None: |
| 152 | + self.shm.dispose() |
| 153 | + |
| 154 | + |
| 155 | +class _ApposeJSONEncoder(json.JSONEncoder): |
| 156 | + def default(self, obj): |
| 157 | + if isinstance(obj, SharedMemory): |
| 158 | + return { |
| 159 | + "appose_type": "shm", |
| 160 | + "name": obj.name, |
| 161 | + "size": obj.size, |
| 162 | + } |
| 163 | + if isinstance(obj, NDArray): |
| 164 | + return { |
| 165 | + "appose_type": "ndarray", |
| 166 | + "dtype": obj.dtype, |
| 167 | + "shape": obj.shape, |
| 168 | + "shm": obj.shm, |
| 169 | + } |
| 170 | + return super().default(obj) |
| 171 | + |
| 172 | + |
| 173 | +def _appose_object_hook(obj: Dict): |
| 174 | + atype = obj.get("appose_type") |
| 175 | + if atype == "shm": |
| 176 | + # Attach to existing shared memory block. |
| 177 | + return SharedMemory(name=(obj["name"]), size=(obj["size"])) |
| 178 | + elif atype == "ndarray": |
| 179 | + return NDArray(obj["dtype"], obj["shape"], obj["shm"]) |
| 180 | + else: |
| 181 | + return obj |
| 182 | + |
| 183 | + |
| 184 | +def _bytes_per_element(dtype: str) -> Union[int, float]: |
| 185 | + try: |
| 186 | + bits = int(re.sub("[^0-9]", "", dtype)) |
| 187 | + except ValueError: |
| 188 | + raise ValueError(f"Invalid dtype: {dtype}") |
| 189 | + return bits / 8 |
| 190 | + |
| 191 | + |
| 192 | +_is_worker = False |
| 193 | + |
| 194 | + |
| 195 | +def _set_worker(value: bool) -> None: |
| 196 | + global _is_worker |
| 197 | + _is_worker = value |
0 commit comments