diff --git a/.gitignore b/.gitignore index 218ba63..eb93365 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,5 @@ nosetests.xml .idea .idea/ +.venv/ +.envrc diff --git a/dominate/dom_tag.py b/dominate/dom_tag.py index 532d916..44ca8ba 100644 --- a/dominate/dom_tag.py +++ b/dominate/dom_tag.py @@ -23,6 +23,9 @@ from collections import defaultdict, namedtuple from functools import wraps import threading +from asyncio import get_event_loop +from uuid import uuid4 +from contextvars import ContextVar try: # Python 3 @@ -37,19 +40,44 @@ basestring = str unicode = str - try: import greenlet except ImportError: greenlet = None +# We want dominate to work in async contexts - however, the problem is +# when we bind a tag using "with", we set what is essentially a global variable. +# If we are processing multiple documents at the same time, one context +# can "overwrite" the "bound tag" of another - this can cause documents to +# sort of bleed into one another... + +# The solution is to use a ContextVar - which provides async context local storage. +# We use this to store a unique ID for each async context. We then use thie ID to +# form the key (in _get_thread_context) that is used to index the _with_context defaultdict. +# The presense of this key ensures that each async context has its own stack and doesn't conflict. +async_context_id = ContextVar('async_context_id', default = None) + +def _get_async_context_id(): + if async_context_id.get() is None: + async_context_id.set(uuid4().hex) + return async_context_id.get() def _get_thread_context(): context = [threading.current_thread()] + # Tag extra content information with a name to make sure + # a greenlet.getcurrent() == 1 doesn't get confused with a + # a _get_thread_context() == 1. if greenlet: - context.append(greenlet.getcurrent()) - return hash(tuple(context)) - + context.append(("greenlet", greenlet.getcurrent())) + + try: + if get_event_loop().is_running(): + # Only add this extra information if we are actually in a running event loop + context.append(("async", _get_async_context_id())) + # A runtime error is raised if there is no async loop... + except RuntimeError: + pass + return tuple(context) class dom_tag(object): is_single = False # Tag does not require matching end tag (ex.
) diff --git a/tests/test_dom_tag_async.py b/tests/test_dom_tag_async.py new file mode 100644 index 0000000..648f9c8 --- /dev/null +++ b/tests/test_dom_tag_async.py @@ -0,0 +1,75 @@ +from asyncio import gather, run, Semaphore +from dominate.dom_tag import async_context_id +from textwrap import dedent + +from dominate import tags + +# To simulate sleep without making the tests take a hella long time to complete +# lets use a pair of semaphores to explicitly control when our coroutines run. +# The order of execution will be marked as comments below: +def test_async_bleed(): + async def tag_routine_1(sem_1, sem_2): + root = tags.div(id = 1) # [1] + with root: # [2] + sem_2.release() # [3] + await sem_1.acquire() # [4] + tags.div(id = 2) # [11] + return str(root) # [12] + + async def tag_routine_2(sem_1, sem_2): + await sem_2.acquire() # [5] + root = tags.div(id = 3) # [6] + with root: # [7] + tags.div(id = 4) # [8] + sem_1.release() # [9] + return str(root) # [10] + + async def merge(): + sem_1 = Semaphore(0) + sem_2 = Semaphore(0) + return await gather( + tag_routine_1(sem_1, sem_2), + tag_routine_2(sem_1, sem_2) + ) + + # Set this test up for failure - pre-set the context to a non-None value. + # As it is already set, _get_async_context_id will not set it to a new, unique value + # and thus we won't be able to differentiate between the two contexts. This essentially simulates + # the behavior before our async fix was implemented (the bleed): + async_context_id.set(0) + tag_1, tag_2 = run(merge()) + + # This looks wrong - but its what we would expect if we don't + # properly handle async... + assert tag_1 == dedent("""\ +
+
+
+
+
+
+ """).strip() + + assert tag_2 == dedent("""\ +
+
+
+ """).strip() + + # Okay, now lets do it right - lets clear the context. Now when each async function + # calls _get_async_context_id, it will get a unique value and we can differentiate. + async_context_id.set(None) + tag_1, tag_2 = run(merge()) + + # Ah, much better... + assert tag_1 == dedent("""\ +
+
+
+ """).strip() + + assert tag_2 == dedent("""\ +
+
+
+ """).strip()