Skip to content

Commit

Permalink
Shared memory IPC
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 30, 2022
1 parent 4240a57 commit 08420ae
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 45 deletions.
150 changes: 110 additions & 40 deletions zict/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,18 @@
import mmap
import os
import pathlib
from collections.abc import Iterator
from collections.abc import Iterator, KeysView
from urllib.parse import quote, unquote

from zict.common import ZictBase


def _safe_key(key: str) -> str:
"""
Escape key so as to be usable on all filesystems.
"""
# Even directory separators are unsafe.
return quote(key, safe="")


def _unsafe_key(key: str) -> str:
"""
Undo the escaping done by _safe_key().
"""
return unquote(key)


class File(ZictBase[str, bytes]):
"""Mutable Mapping interface to a directory
Keys must be strings, values must be buffers
Note this shouldn't be used for interprocess persistence, as keys
are cached in memory.
Keys must be strings, values must be buffers.
Keys are cached in memory and you shouldn't share the directory with other File
objects. However, see :meth:`link` for inter-process comunication.
Parameters
----------
Expand Down Expand Up @@ -60,27 +44,48 @@ class File(ZictBase[str, bytes]):

directory: str
memmap: bool
_keys: set[str]
filenames: dict[str, str]
_inc: int

def __init__(self, directory: str | pathlib.Path, memmap: bool = False):
self.directory = str(directory)
self.memmap = memmap
self._keys = set()
self.filenames = {}
self._inc = 0

if not os.path.exists(self.directory):
os.makedirs(self.directory, exist_ok=True)
else:
for n in os.listdir(self.directory):
self._keys.add(_unsafe_key(n))
for fn in os.listdir(self.directory):
self.filenames[self._unsafe_key(fn)] = fn
self._inc += 1

def _safe_key(self, key: str) -> str:
"""Escape key so as to be usable on all filesystems.
Append to the filenames a unique suffix that changes every time this method is
called. This prevents race conditions when another thread/process opens the
files for read (see :meth:`link` below), as it guarantees that a file is either
complete and coherent or it does not exist.
"""
# `#` is escaped by quote and is supported by most file systems
key = quote(key, safe="") + f"#{self._inc}"
self._inc += 1
return key

@staticmethod
def _unsafe_key(key: str) -> str:
"""Undo the escaping done by _safe_key()"""
key = key.split("#")[0]
return unquote(key)

def __str__(self) -> str:
return f"<File: {self.directory}, {len(self)} elements>"

__repr__ = __str__

def __getitem__(self, key: str) -> bytearray | memoryview:
if key not in self._keys:
raise KeyError(key)
fn = os.path.join(self.directory, _safe_key(key))
fn = os.path.join(self.directory, self.filenames[key])

# distributed.protocol.numpy.deserialize_numpy_ndarray makes sure that, if the
# numpy array was writeable before serialization, remains writeable afterwards.
Expand Down Expand Up @@ -108,29 +113,94 @@ def __setitem__(
| list[bytes | bytearray | memoryview]
| tuple[bytes | bytearray | memoryview, ...],
) -> None:
fn = os.path.join(self.directory, _safe_key(key))
with open(fn, "wb") as fh:
try:
del self[key]
except KeyError:
pass

fn = self._safe_key(key)
with open(os.path.join(self.directory, fn), "wb") as fh:
if isinstance(value, (tuple, list)):
fh.writelines(value)
else:
fh.write(value)
self._keys.add(key)
self.filenames[key] = fn

def __contains__(self, key: object) -> bool:
return key in self._keys
return key in self.filenames

# FIXME dictionary views https://github.com/dask/zict/issues/61
def keys(self) -> set[str]: # type: ignore
return self._keys
def keys(self) -> KeysView[str]:
return self.filenames.keys()

def __iter__(self) -> Iterator[str]:
return iter(self._keys)
return iter(self.filenames)

def __delitem__(self, key: str) -> None:
if key not in self._keys:
raise KeyError(key)
os.remove(os.path.join(self.directory, _safe_key(key)))
self._keys.remove(key)
fn = self.filenames.pop(key)
os.remove(os.path.join(self.directory, fn))

def __len__(self) -> int:
return len(self._keys)
return len(self.filenames)

def get_path(self, key: str) -> str:
return os.path.join(self.directory, self.filenames[key])

def link(self, key: str, path: str) -> None:
"""Hardlink an external file into self.directory.
The file must be on the same filesystem as self.directory. This is an atomic
operation which allows for data transfer between multiple File instances (or
from an external data creator to a File instance) running on different
processes, and is particularly useful in conjunction with memory mapping.
Examples
--------
In process 1:
.. code-block:: python
z1 = File("/dev/shm/z1", memmap=True)
z1["x"] = b"Hello world!"
send_to_proc2("x", z1.get_path("x"))
In process 2:
.. code-block:: python
z2 = File("/dev/shm/z2", memmap=True)
key, path = receive_from_proc1()
try:
z2.link(key, path)
except FileNotFoundError:
# Handle race condition: key was deleted from z1
Now z1["x"] and z2["x"] share the same memory. Updating the bytearray contents
on one (``z1["x"][:] = ...``) will immediately be reflected onto the other.
Setting a new value on either (``z1["x"] = ...``) will decouple them.
There are now two files on disk, ``/dev/shm/z1/x#0`` and ``/dev/shm/z2/x#0``,
which share the same inode. The memory is released when both z1 and z2 delete
the key.
.. note::
Filenames change every time you set a new value for a key; this prevents a
race condition when z1 is in the process of replacing ``x`` with an entirely
new value while z2 acquires it.
You may also use link() to create aliases to its own data.
This reads x back into memory and then writes a deep copy of it into y::
>>> z["y"] = z["x"] # doctest: +SKIP
This creates a second, shallow reference to x and is the same as writing
``z["z"] = z["x"]`` on a regular in-memory dict::
>>> z.link("z", z.get_path("x")) # doctest: +SKIP
"""
try:
del self[key]
except KeyError:
pass

fn = self._safe_key(key)
os.link(path, os.path.join(self.directory, fn))
self.filenames[key] = fn
50 changes: 45 additions & 5 deletions zict/tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def test_implementation(tmpdir, dirtype):
assert not z

z["x"] = b"123"
assert os.listdir(tmpdir) == ["x"]
with open(tmpdir / "x", "rb") as f:
assert os.listdir(tmpdir) == ["x#0"]
with open(tmpdir / "x#0", "rb") as f:
assert f.read() == b"123"

assert "x" in z
Expand All @@ -38,7 +38,7 @@ def test_memmap_implementation(tmpdir):
mv = memoryview(b"123")
assert "x" not in z
z["x"] = mv
assert os.listdir(tmpdir) == ["x"]
assert os.listdir(tmpdir) == ["x#0"]
assert "x" in z
mv2 = z["x"]
assert mv2 == b"123"
Expand All @@ -62,17 +62,23 @@ def test_contextmanager(tmpdir):
with File(tmpdir) as z:
z["x"] = b"123"

with open(tmpdir / "x", "rb") as fh:
with open(tmpdir / "x#0", "rb") as fh:
assert fh.read() == b"123"


def test_delitem(tmpdir):
z = File(tmpdir)

z["x"] = b"123"
assert os.listdir(tmpdir) == ["x"]
assert os.listdir(tmpdir) == ["x#0"]
del z["x"]
assert os.listdir(tmpdir) == []
# File name is never repeated
z["x"] = b"123"
assert os.listdir(tmpdir) == ["x#1"]
# __setitem__ deletes the previous file
z["x"] = b"123"
assert os.listdir(tmpdir) == ["x#2"]


def test_missing_key(tmpdir):
Expand Down Expand Up @@ -116,3 +122,37 @@ def test_write_list_of_bytes(tmpdir):

z["x"] = [b"123", b"4567"]
assert z["x"] == b"1234567"


@pytest.mark.parametrize("memmap", [False, True])
def test_link(tmpdir, memmap):
z1 = File(tmpdir / "a", memmap=memmap)
z2 = File(tmpdir / "b", memmap=memmap)
z1["x"] = b"123"

z1.link("y", z1.get_path("x"))
z2.link("x", z1.get_path("x"))
assert z1["x"] == b"123"
assert z1["y"] == b"123"
assert z2["x"] == b"123"
assert sorted(os.listdir(tmpdir / "a")) == ["x#0", "y#1"]
assert os.listdir(tmpdir / "b") == ["x#0"]

if not memmap:
return

z1["x"].cast("c")[0] = b"4"
assert z1["y"] == b"423"
assert z2["x"] == b"423"

z1["x"] = b"567"
assert z1["x"] == b"567"
assert z1["y"] == b"423"
assert z2["x"] == b"423"
assert sorted(os.listdir(tmpdir / "a")) == ["x#2", "y#1"]
assert os.listdir(tmpdir / "b") == ["x#0"]

del z1["y"]
assert z2["x"] == b"423"
assert os.listdir(tmpdir / "a") == ["x#2"]
assert os.listdir(tmpdir / "b") == ["x#0"]

0 comments on commit 08420ae

Please sign in to comment.