-
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
49d2035
commit 4a3ee96
Showing
1 changed file
with
216 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
from __future__ import annotations | ||
|
||
import ctypes | ||
import mmap | ||
import os | ||
import re | ||
import secrets | ||
from collections.abc import Callable, Iterator, KeysView | ||
from urllib.parse import quote, unquote | ||
|
||
from zict.common import ZictBase | ||
|
||
|
||
class SharedMemory(ZictBase[str, bytes]): | ||
"""Mutable Mapping interface to a directory | ||
Keys must be strings, values must be buffers. | ||
Keys are cached in memory and you shouldn't share the directory with other File | ||
objects. However, see :meth:`link` for inter-process comunication. | ||
Parameters | ||
---------- | ||
directory: str | ||
Directory to write to. If it already exists, existing files will be imported as | ||
mapping elements. If it doesn't exists, it will be created. | ||
memmap: bool (optional) | ||
If True, use `mmap` for reading. Defaults to False. | ||
Examples | ||
-------- | ||
>>> z = File('myfile') # doctest: +SKIP | ||
>>> z['x'] = b'123' # doctest: +SKIP | ||
>>> z['x'] # doctest: +SKIP | ||
b'123' | ||
Also supports writing lists of bytes objects | ||
>>> z['y'] = [b'123', b'4567'] # doctest: +SKIP | ||
>>> z['y'] # doctest: +SKIP | ||
b'1234567' | ||
Or anything that can be used with file.write, like a memoryview | ||
>>> z['data'] = np.ones(5).data # doctest: +SKIP | ||
""" | ||
|
||
_data: dict[str, tuple[int, str]] # {key: (fd, unique token)} | ||
_memfd_create: Callable[[bytes, int], int] | ||
|
||
def __init__(self): | ||
self._data = {} | ||
libc = ctypes.CDLL("libc.so.6") | ||
self._memfd_create = libc.memfd_create | ||
|
||
@staticmethod | ||
def _safe_key(key: str) -> tuple[str, str]: | ||
"""Escape key so as to be usable on all filesystems. | ||
Append to the filenames a unique suffix that changes every time this method is | ||
called. This prevents race conditions when another thread/process opens the | ||
files for read (see :meth:`link` below), as it guarantees that a file is either | ||
complete and coherent or it does not exist. | ||
""" | ||
token = secrets.token_bytes(4).hex() | ||
return quote(key, safe="") + "#" + token, token | ||
|
||
@staticmethod | ||
def _unsafe_key(key: str, expect_token: str) -> str: | ||
"""Undo the escaping done by _safe_key()""" | ||
key, actual_token = key.split("#") # raises ValueError on malformation | ||
if actual_token != expect_token: | ||
raise ValueError("token mismatch") | ||
return unquote(key) | ||
|
||
def __str__(self) -> str: | ||
return f"<SharedMemory: {len(self)} elements>" | ||
|
||
__repr__ = __str__ | ||
|
||
def __setitem__( | ||
self, | ||
key: str, | ||
value: bytes | ||
| bytearray | ||
| memoryview | ||
| list[bytes | bytearray | memoryview] | ||
| tuple[bytes | bytearray | memoryview, ...], | ||
) -> None: | ||
try: | ||
del self[key] | ||
except KeyError: | ||
pass | ||
|
||
safe_key, token = self._safe_key(key) | ||
fd = self._memfd_create(safe_key.encode("ascii"), 0) | ||
if fd == -1: | ||
raise OSError("Call to memfd_create failed") # pragma: nocover | ||
|
||
fh = os.fdopen(fd, "wb", closefd=False) | ||
if isinstance(value, (tuple, list)): | ||
fh.writelines(value) | ||
else: | ||
fh.write(value) | ||
fh.flush() | ||
|
||
self._data[key] = fd, token | ||
|
||
def __getitem__(self, key: str) -> memoryview: | ||
fd, _ = self._data[key] | ||
return memoryview(mmap.mmap(fd, 0)) | ||
|
||
def __delitem__(self, key: str) -> None: | ||
fd, _ = self._data.pop(key) | ||
os.close(fd) | ||
|
||
def __del__(self) -> None: | ||
for fd, _ in self._data.values(): | ||
os.close(fd) | ||
|
||
def __contains__(self, key: object) -> bool: | ||
return key in self._data | ||
|
||
def keys(self) -> KeysView[str]: | ||
return self._data.keys() | ||
|
||
def __iter__(self) -> Iterator[str]: | ||
return iter(self._data) | ||
|
||
def __len__(self) -> int: | ||
return len(self._data) | ||
|
||
def shm_export(self, key: str) -> tuple[int, int, str]: | ||
"""TODO | ||
Returns | ||
------- | ||
- pid of the current process | ||
- fd of the shared memory | ||
- unique token (used to prevent race conditions) | ||
""" | ||
fd, token = self._data[key] | ||
return os.getpid(), fd, token | ||
|
||
def shm_import(self, pid: int, fd: int, token: str) -> str: | ||
"""TODO | ||
Hardlink an external file into self.directory. | ||
The file must be on the same filesystem as self.directory. This is an atomic | ||
operation which allows for data transfer between multiple File instances (or | ||
from an external data creator to a File instance) running on different | ||
processes, and is particularly useful in conjunction with memory mapping. | ||
Examples | ||
-------- | ||
In process 1: | ||
.. code-block:: python | ||
z1 = File("/dev/shm/z1", memmap=True) | ||
z1["x"] = b"Hello world!" | ||
send_to_proc2("x", z1.get_path("x")) | ||
In process 2: | ||
.. code-block:: python | ||
z2 = File("/dev/shm/z2", memmap=True) | ||
key, path = receive_from_proc1() | ||
try: | ||
z2.link(key, path) | ||
except FileNotFoundError: | ||
# Handle race condition: key was deleted from z1 | ||
Now z1["x"] and z2["x"] share the same memory. Updating the bytearray contents | ||
on one (``z1["x"][:] = ...``) will immediately be reflected onto the other. | ||
Setting a new value on either (``z1["x"] = ...``) will decouple them. | ||
There are now two files on disk, ``/dev/shm/z1/x#0`` and ``/dev/shm/z2/x#0``, | ||
which share the same inode. The memory is released when both z1 and z2 delete | ||
the key. | ||
.. note:: | ||
Filenames change every time you set a new value for a key; this prevents a | ||
race condition when z1 is in the process of replacing ``x`` with an entirely | ||
new value while z2 acquires it. | ||
You may also use link() to create aliases to its own data. | ||
This reads x back into memory and then writes a deep copy of it into y:: | ||
>>> z["y"] = z["x"] # doctest: +SKIP | ||
This creates a second, shallow reference to x and is the same as writing | ||
``z["z"] = z["x"]`` on a regular in-memory dict:: | ||
>>> z.link("z", z.get_path("x")) # doctest: +SKIP | ||
""" | ||
fail = FileNotFoundError("Peer process no longer holds the key") | ||
try: | ||
fd = os.open(f"/proc/{pid}/{fd}", os.O_RDWR) | ||
except OSError: | ||
raise fail | ||
|
||
target = os.readlink(f"/proc/{os.getpid()}/{fd}") | ||
m = re.match(r"/memfd:(.+) \(deleted\)$", target) | ||
if not m: | ||
os.close(fd) | ||
raise fail | ||
|
||
try: | ||
key = self._unsafe_key(m.group(1), token) | ||
except ValueError: | ||
os.close(fd) | ||
raise fail | ||
|
||
self._data[key] = fd, token | ||
return key |