Skip to content

Commit

Permalink
WIP: shared memory without tmpfs
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 7, 2023
1 parent 49d2035 commit 4a3ee96
Showing 1 changed file with 216 additions and 0 deletions.
216 changes: 216 additions & 0 deletions zict/shared_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from __future__ import annotations

import ctypes
import mmap
import os
import re
import secrets
from collections.abc import Callable, Iterator, KeysView
from urllib.parse import quote, unquote

from zict.common import ZictBase


class SharedMemory(ZictBase[str, bytes]):
"""Mutable Mapping interface to a directory
Keys must be strings, values must be buffers.
Keys are cached in memory and you shouldn't share the directory with other File
objects. However, see :meth:`link` for inter-process comunication.
Parameters
----------
directory: str
Directory to write to. If it already exists, existing files will be imported as
mapping elements. If it doesn't exists, it will be created.
memmap: bool (optional)
If True, use `mmap` for reading. Defaults to False.
Examples
--------
>>> z = File('myfile') # doctest: +SKIP
>>> z['x'] = b'123' # doctest: +SKIP
>>> z['x'] # doctest: +SKIP
b'123'
Also supports writing lists of bytes objects
>>> z['y'] = [b'123', b'4567'] # doctest: +SKIP
>>> z['y'] # doctest: +SKIP
b'1234567'
Or anything that can be used with file.write, like a memoryview
>>> z['data'] = np.ones(5).data # doctest: +SKIP
"""

_data: dict[str, tuple[int, str]] # {key: (fd, unique token)}
_memfd_create: Callable[[bytes, int], int]

def __init__(self):
self._data = {}
libc = ctypes.CDLL("libc.so.6")
self._memfd_create = libc.memfd_create

@staticmethod
def _safe_key(key: str) -> tuple[str, str]:
"""Escape key so as to be usable on all filesystems.
Append to the filenames a unique suffix that changes every time this method is
called. This prevents race conditions when another thread/process opens the
files for read (see :meth:`link` below), as it guarantees that a file is either
complete and coherent or it does not exist.
"""
token = secrets.token_bytes(4).hex()
return quote(key, safe="") + "#" + token, token

@staticmethod
def _unsafe_key(key: str, expect_token: str) -> str:
"""Undo the escaping done by _safe_key()"""
key, actual_token = key.split("#") # raises ValueError on malformation
if actual_token != expect_token:
raise ValueError("token mismatch")
return unquote(key)

def __str__(self) -> str:
return f"<SharedMemory: {len(self)} elements>"

__repr__ = __str__

def __setitem__(
self,
key: str,
value: bytes
| bytearray
| memoryview
| list[bytes | bytearray | memoryview]
| tuple[bytes | bytearray | memoryview, ...],
) -> None:
try:
del self[key]
except KeyError:
pass

safe_key, token = self._safe_key(key)
fd = self._memfd_create(safe_key.encode("ascii"), 0)
if fd == -1:
raise OSError("Call to memfd_create failed") # pragma: nocover

fh = os.fdopen(fd, "wb", closefd=False)
if isinstance(value, (tuple, list)):
fh.writelines(value)
else:
fh.write(value)
fh.flush()

self._data[key] = fd, token

def __getitem__(self, key: str) -> memoryview:
fd, _ = self._data[key]
return memoryview(mmap.mmap(fd, 0))

def __delitem__(self, key: str) -> None:
fd, _ = self._data.pop(key)
os.close(fd)

def __del__(self) -> None:
for fd, _ in self._data.values():
os.close(fd)

def __contains__(self, key: object) -> bool:
return key in self._data

def keys(self) -> KeysView[str]:
return self._data.keys()

def __iter__(self) -> Iterator[str]:
return iter(self._data)

def __len__(self) -> int:
return len(self._data)

def shm_export(self, key: str) -> tuple[int, int, str]:
"""TODO
Returns
-------
- pid of the current process
- fd of the shared memory
- unique token (used to prevent race conditions)
"""
fd, token = self._data[key]
return os.getpid(), fd, token

def shm_import(self, pid: int, fd: int, token: str) -> str:
"""TODO
Hardlink an external file into self.directory.
The file must be on the same filesystem as self.directory. This is an atomic
operation which allows for data transfer between multiple File instances (or
from an external data creator to a File instance) running on different
processes, and is particularly useful in conjunction with memory mapping.
Examples
--------
In process 1:
.. code-block:: python
z1 = File("/dev/shm/z1", memmap=True)
z1["x"] = b"Hello world!"
send_to_proc2("x", z1.get_path("x"))
In process 2:
.. code-block:: python
z2 = File("/dev/shm/z2", memmap=True)
key, path = receive_from_proc1()
try:
z2.link(key, path)
except FileNotFoundError:
# Handle race condition: key was deleted from z1
Now z1["x"] and z2["x"] share the same memory. Updating the bytearray contents
on one (``z1["x"][:] = ...``) will immediately be reflected onto the other.
Setting a new value on either (``z1["x"] = ...``) will decouple them.
There are now two files on disk, ``/dev/shm/z1/x#0`` and ``/dev/shm/z2/x#0``,
which share the same inode. The memory is released when both z1 and z2 delete
the key.
.. note::
Filenames change every time you set a new value for a key; this prevents a
race condition when z1 is in the process of replacing ``x`` with an entirely
new value while z2 acquires it.
You may also use link() to create aliases to its own data.
This reads x back into memory and then writes a deep copy of it into y::
>>> z["y"] = z["x"] # doctest: +SKIP
This creates a second, shallow reference to x and is the same as writing
``z["z"] = z["x"]`` on a regular in-memory dict::
>>> z.link("z", z.get_path("x")) # doctest: +SKIP
"""
fail = FileNotFoundError("Peer process no longer holds the key")
try:
fd = os.open(f"/proc/{pid}/{fd}", os.O_RDWR)
except OSError:
raise fail

target = os.readlink(f"/proc/{os.getpid()}/{fd}")
m = re.match(r"/memfd:(.+) \(deleted\)$", target)
if not m:
os.close(fd)
raise fail

try:
key = self._unsafe_key(m.group(1), token)
except ValueError:
os.close(fd)
raise fail

self._data[key] = fd, token
return key

0 comments on commit 4a3ee96

Please sign in to comment.