From c223213c95968024e504f9c4fd990c2b96012de4 Mon Sep 17 00:00:00 2001 From: Paulius Sarka Date: Fri, 1 Sep 2023 22:50:45 +0200 Subject: [PATCH] Add aquire_read_lock et. al. methods to ReaderWriterLock --- CHANGELOG.md | 3 + fasteners/lock.py | 141 +++++++++++++++++++++++++--------- tests/test_lock.py | 186 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 269 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64641c4..bdfa44f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] + - Add `.acquire_read_lock`, `.release_read_lock`, `.acquire_write_lock`, and + `.release_write_lock` methods to the inter thread `ReaderWriterLock` as was + promised in the README. ## [0.18] - Reshuffle the process lock code and properly document it. diff --git a/fasteners/lock.py b/fasteners/lock.py index 6c872d4..fdfbee5 100644 --- a/fasteners/lock.py +++ b/fasteners/lock.py @@ -44,6 +44,7 @@ def __init__(self, threads are not properly identified by threading.current_thread """ self._writer = None + self._writer_entries = 0 self._pending_writers = collections.deque() self._readers = {} self._cond = condition_cls() @@ -99,9 +100,8 @@ def owner(self) -> Optional[str]: return self.READER return None - @contextlib.contextmanager - def read_lock(self): - """Context manager that grants a read lock. + def acquire_read_lock(self): + """Acquire a read lock. Will wait until no active or pending writers. @@ -109,6 +109,18 @@ def read_lock(self): RuntimeError: if a pending writer tries to acquire a read lock. """ me = self._current_thread() + self._acquire_read_lock(me) + + def release_read_lock(self): + """Release a read lock. + + Raises: + RuntimeError: if the current thread does not own a read lock. + """ + me = self._current_thread() + self._release_read_lock(me) + + def _acquire_read_lock(self, me): if me in self._pending_writers: raise RuntimeError("Writer %s can not acquire a read lock" " while waiting for the write lock" @@ -128,23 +140,91 @@ def read_lock(self): break # An active or pending writer; guess we have to wait. self._cond.wait() + + def _release_read_lock(self, me, raise_on_not_owned=True): + # I am no longer a reader, remove *one* occurrence of myself. + # If the current thread acquired two read locks, then it will + # still have to remove that other read lock; this allows for + # basic reentrancy to be possible. + with self._cond: + try: + me_instances = self._readers[me] + if me_instances > 1: + self._readers[me] = me_instances - 1 + else: + self._readers.pop(me) + except KeyError: + if raise_on_not_owned: + raise RuntimeError(f"Thread {me} does not own a read lock") + self._cond.notify_all() + + @contextlib.contextmanager + def read_lock(self): + """Context manager that grants a read lock. + + Will wait until no active or pending writers. + + Raises: + RuntimeError: if a pending writer tries to acquire a read lock. + """ + me = self._current_thread() + self._acquire_read_lock(me) try: yield self finally: - # I am no longer a reader, remove *one* occurrence of myself. - # If the current thread acquired two read locks, then it will - # still have to remove that other read lock; this allows for - # basic reentrancy to be possible. - with self._cond: - try: - me_instances = self._readers[me] - if me_instances > 1: - self._readers[me] = me_instances - 1 - else: - self._readers.pop(me) - except KeyError: - pass - self._cond.notify_all() + self._release_read_lock(me, raise_on_not_owned=False) + + def _acquire_write_lock(self, me): + if self.is_reader(): + raise RuntimeError("Reader %s to writer privilege" + " escalation not allowed" % me) + + with self._cond: + self._pending_writers.append(me) + while True: + # No readers, and no active writer, am I next?? + if len(self._readers) == 0 and self._writer is None: + if self._pending_writers[0] == me: + self._writer = self._pending_writers.popleft() + self._writer_entries = 1 + break + self._cond.wait() + + def _release_write_lock(self, me, raise_on_not_owned=True): + with self._cond: + self._writer = None + self._writer_entries = 0 + self._cond.notify_all() + + def acquire_write_lock(self): + """Acquire a write lock. + + Will wait until no active readers. Blocks readers after acquiring. + + Guaranteed for locks to be processed in fair order (FIFO). + + Raises: + RuntimeError: if an active reader attempts to acquire a lock. + """ + me = self._current_thread() + if self._writer == me: + self._writer_entries += 1 + else: + self._acquire_write_lock(me) + + def release_write_lock(self): + """Release a write lock. + + Raises: + RuntimeError: if the current thread does not own a write lock. + """ + me = self._current_thread() + if self._writer == me: + self._writer_entries -= 1 + if self._writer_entries == 0: + self._release_write_lock(me) + else: + raise RuntimeError(f"Thread {me} does not own a write lock") @contextlib.contextmanager def write_lock(self): @@ -158,29 +238,18 @@ def write_lock(self): RuntimeError: if an active reader attempts to acquire a lock. """ me = self._current_thread() - i_am_writer = self.is_writer(check_pending=False) - if self.is_reader() and not i_am_writer: - raise RuntimeError("Reader %s to writer privilege" - " escalation not allowed" % me) - if i_am_writer: - # Already the writer; this allows for basic reentrancy. - yield self + if self.is_writer(check_pending=False): + self._writer_entries += 1 + try: + yield self + finally: + self._writer_entries -= 1 else: - with self._cond: - self._pending_writers.append(me) - while True: - # No readers, and no active writer, am I next?? - if len(self._readers) == 0 and self._writer is None: - if self._pending_writers[0] == me: - self._writer = self._pending_writers.popleft() - break - self._cond.wait() + self._acquire_write_lock(me) try: yield self finally: - with self._cond: - self._writer = None - self._cond.notify_all() + self._release_write_lock(me) def locked(*args, **kwargs): diff --git a/tests/test_lock.py b/tests/test_lock.py index 1765c10..28df834 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -96,28 +96,40 @@ def _daemon_thread(target): return t -def test_no_double_writers(): +@pytest.mark.parametrize("contextmanager", [True, False]) +def test_no_double_writers(contextmanager): lock = fasteners.ReaderWriterLock() watch = _utils.StopWatch(duration=5) watch.start() dups = collections.deque() active = collections.deque() - def acquire_check(me): + def acquire_check_ctx(me): with lock.write_lock(): if len(active) >= 1: dups.append(me) dups.extend(active) active.append(me) - try: - time.sleep(random.random() / 100) - finally: - active.remove(me) + time.sleep(random.random() / 100) + active.remove(me) + + def acquire_check_plain(me): + lock.acquire_write_lock() + if len(active) >= 1: + dups.append(me) + dups.extend(active) + active.append(me) + time.sleep(random.random() / 100) + active.remove(me) + lock.release_write_lock() def run(): me = threading.current_thread() while not watch.expired(): - acquire_check(me) + if contextmanager: + acquire_check_ctx(me) + else: + acquire_check_plain(me) threads = [] for i in range(0, THREAD_COUNT): @@ -132,14 +144,15 @@ def run(): assert not active -def test_no_concurrent_readers_writers(): +@pytest.mark.parametrize("contextmanager", [True, False]) +def test_no_concurrent_readers_writers(contextmanager): lock = fasteners.ReaderWriterLock() watch = _utils.StopWatch(duration=5) watch.start() dups = collections.deque() active = collections.deque() - def acquire_check(me, reader): + def acquire_check_ctx(me, reader): if reader: lock_func = lock.read_lock else: @@ -153,15 +166,35 @@ def acquire_check(me, reader): dups.append(me) dups.extend(active) active.append(me) - try: - time.sleep(random.random() / 100) - finally: - active.remove(me) + time.sleep(random.random() / 100) + active.remove(me) + + def acquire_check_plain(me, reader): + if reader: + lock_func, unlock_func = lock.acquire_read_lock, lock.release_read_lock + else: + lock_func, unlock_func = lock.acquire_write_lock, lock.release_write_lock + + lock_func() + if not reader: + # There should be no-one else currently active, if there + # is ensure we capture them so that we can later blow-up + # the test. + if len(active) >= 1: + dups.append(me) + dups.extend(active) + active.append(me) + time.sleep(random.random() / 100) + active.remove(me) + unlock_func() def run(): me = threading.current_thread() while not watch.expired(): - acquire_check(me, random.choice([True, False])) + if contextmanager: + acquire_check_ctx(me, random.choice([True, False])) + else: + acquire_check_plain(me, random.choice([True, False])) threads = [] for i in range(0, THREAD_COUNT): @@ -305,7 +338,7 @@ def happy_reader(): assert sum(a == 'r' for a in activated) == 10 -def test_writer_reader_writer(): +def test_writer_reader_writer_ctx(): lock = fasteners.ReaderWriterLock() with lock.write_lock(): assert lock.is_writer() @@ -315,23 +348,48 @@ def test_writer_reader_writer(): assert lock.is_writer() -def test_single_reader_writer(): - results = [] +def test_writer_reader_writer_plain(): + lock = fasteners.ReaderWriterLock() + lock.acquire_write_lock() + assert lock.is_writer() + lock.acquire_read_lock() + assert lock.is_reader() + lock.acquire_write_lock() + assert lock.is_writer() + + +def test_single_reader_writer_ctx(): lock = fasteners.ReaderWriterLock() with lock.read_lock(): assert lock.is_reader() - assert not results with lock.write_lock(): - results.append(1) assert lock.is_writer() with lock.read_lock(): assert lock.is_reader() - assert len(results) == 1 assert not lock.is_reader() assert not lock.is_writer() -def test_reader_to_writer(): +def test_single_reader_writer_plain(): + lock = fasteners.ReaderWriterLock() + + lock.acquire_read_lock() + assert lock.is_reader() + lock.release_read_lock() + + lock.acquire_write_lock() + assert lock.is_writer() + lock.release_write_lock() + + lock.acquire_read_lock() + assert lock.is_reader() + lock.release_read_lock() + + assert not lock.is_reader() + assert not lock.is_writer() + + +def test_reader_to_writer_ctx(): lock = fasteners.ReaderWriterLock() with lock.read_lock(): @@ -345,11 +403,24 @@ def test_reader_to_writer(): assert not lock.is_writer() -def test_writer_to_reader(): +def test_reader_to_writer_plain(): lock = fasteners.ReaderWriterLock() - with lock.write_lock(): + lock.acquire_read_lock() + with pytest.raises(RuntimeError): + lock.acquire_write_lock() + assert lock.is_reader() + assert not lock.is_writer() + lock.release_read_lock() + assert not lock.is_reader() + assert not lock.is_writer() + + +def test_writer_to_reader_ctx(): + lock = fasteners.ReaderWriterLock() + + with lock.write_lock(): with lock.read_lock(): assert lock.is_writer() assert lock.is_reader() @@ -361,7 +432,24 @@ def test_writer_to_reader(): assert not lock.is_reader() -def test_double_writer(): +def test_writer_to_reader_plain(): + lock = fasteners.ReaderWriterLock() + + lock.acquire_write_lock() + lock.acquire_read_lock() + assert lock.is_writer() + assert lock.is_reader() + + lock.release_read_lock() + assert lock.is_writer() + assert not lock.is_reader() + + lock.release_write_lock() + assert not lock.is_writer() + assert not lock.is_reader() + + +def test_double_writer_ctx(): lock = fasteners.ReaderWriterLock() with lock.write_lock(): @@ -377,7 +465,25 @@ def test_double_writer(): assert not lock.is_writer() -def test_double_reader(): +def test_double_writer_plain(): + lock = fasteners.ReaderWriterLock() + + lock.acquire_write_lock() + assert not lock.is_reader() + assert lock.is_writer() + + lock.acquire_write_lock() + assert lock.is_writer() + + lock.release_write_lock() + assert lock.is_writer() + + lock.release_write_lock() + assert not lock.is_reader() + assert not lock.is_writer() + + +def test_double_reader_ctx(): lock = fasteners.ReaderWriterLock() with lock.read_lock(): @@ -393,6 +499,24 @@ def test_double_reader(): assert not lock.is_writer() +def test_double_reader_plain(): + lock = fasteners.ReaderWriterLock() + + lock.acquire_read_lock() + assert lock.is_reader() + assert not lock.is_writer() + + lock.acquire_read_lock() + assert lock.is_reader() + + lock.release_read_lock() + assert lock.is_reader() + + lock.release_read_lock() + assert not lock.is_reader() + assert not lock.is_writer() + + def test_multi_reader_multi_writer(): writer_times, reader_times = _spawn_variation(10, 10) assert len(writer_times) == 10 @@ -456,3 +580,15 @@ def thread2(): while threads: t = threads.pop() t.join() + + +def test_error_when_releasing_write_lock_without_write_lock(): + lock = fasteners.ReaderWriterLock() + with pytest.raises(RuntimeError): + lock.release_write_lock() + + +def test_error_when_releasing_read_lock_without_read_lock(): + lock = fasteners.ReaderWriterLock() + with pytest.raises(RuntimeError): + lock.release_read_lock()