Skip to content

Implementation of automatic batching for async #554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ The complete documentation for GQL can be found at
* Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage)
* Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html)
* Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html)
* Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html)
* [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line
* [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically

96 changes: 96 additions & 0 deletions docs/advanced/batching_requests.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
.. _batching_requests:

Batching requests
=================

If you need to send multiple GraphQL queries to a backend,
and if the backend supports batch requests,
then you might want to send those requests in a batch instead of
making multiple execution requests.

.. warning::
- Some backends do not support batch requests
- File uploads and subscriptions are not supported with batch requests

Batching requests manually
^^^^^^^^^^^^^^^^^^^^^^^^^^

To execute a batch of requests manually:

- First Make a list of :class:`GraphQLRequest <gql.GraphQLRequest>` objects, containing:
* your GraphQL query
* Optional variable_values
* Optional operation_name

.. code-block:: python
request1 = GraphQLRequest("""
query getContinents {
continents {
code
name
}
}
"""
)
request2 = GraphQLRequest("""
query getContinentName ($code: ID!) {
continent (code: $code) {
name
}
}
""",
variable_values={
"code": "AF",
},
)
requests = [request1, request2]
- Then use one of the `execute_batch` methods, either on Client,
or in a sync or async session

**Sync**:

.. code-block:: python
transport = RequestsHTTPTransport(url=url)
# Or transport = HTTPXTransport(url=url)
with Client(transport=transport) as session:
results = session.execute_batch(requests)
result1 = results[0]
result2 = results[1]
**Async**:

.. code-block:: python
transport = AIOHTTPTransport(url=url)
# Or transport = HTTPXAsyncTransport(url=url)
async with Client(transport=transport) as session:
results = await session.execute_batch(requests)
result1 = results[0]
result2 = results[1]
.. note::
If any request in the batch returns an error, then a TransportQueryError will be raised
with the first error found.

Automatic Batching of requests
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your code execute multiple requests independently in a short time
(either from different threads in sync code, or from different asyncio tasks in async code),
then you can use gql automatic batching of request functionality.

You define a :code:`batching_interval` in your :class:`Client <gql.Client>`
and each time a new execution request is received through an `execute` method,
we will wait that interval (in seconds) for other requests to arrive
before sending all the requests received in that interval in a single batch.
1 change: 1 addition & 0 deletions docs/advanced/index.rst
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ Advanced

async_advanced_usage
async_permanent_session
batching_requests
logging
error_handling
local_schema
177 changes: 160 additions & 17 deletions gql/client.py
Original file line number Diff line number Diff line change
@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):

if reconnecting:
self.session = ReconnectingAsyncClientSession(client=self, **kwargs)
await self.session.start_connecting_task()
else:
try:
await self.transport.connect()
except Exception as e:
await self.transport.close()
raise e
self.session = AsyncClientSession(client=self)

await self.session.connect()

# Get schema from transport if needed
try:
if self.fetch_schema_from_transport and not self.schema:
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
# we don't know what type of exception is thrown here because it
# depends on the underlying transport; we just make sure that the
# transport is closed and re-raise the exception
await self.transport.close()
await self.session.close()
raise

return self.session

async def close_async(self):
"""Close the async transport and stop the optional reconnecting task."""

if isinstance(self.session, ReconnectingAsyncClientSession):
await self.session.stop_connecting_task()

await self.transport.close()
await self.session.close()

async def __aenter__(self):
return await self.connect_async()
@@ -1564,12 +1557,17 @@ async def _execute(
):
request = request.serialize_variable_values(self.client.schema)

# Execute the query with the transport with a timeout
with fail_after(self.client.execute_timeout):
result = await self.transport.execute(
request,
**kwargs,
)
# Check if batching is enabled
if self.client.batching_enabled:
future_result = await self._execute_future(request)
result = await future_result
else:
# Execute the query with the transport with a timeout
with fail_after(self.client.execute_timeout):
result = await self.transport.execute(
request,
**kwargs,
)

# Unserialize the result if requested
if self.client.schema:
@@ -1828,6 +1826,134 @@ async def execute_batch(

return cast(List[Dict[str, Any]], [result.data for result in results])

async def _batch_loop(self) -> None:
"""Main loop of the task used to wait for requests
to execute them in a batch"""

stop_loop = False

while not stop_loop:
# First wait for a first request in from the batch queue
requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = []

# Wait for the first request
request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = (
await self.batch_queue.get()
)

if request_and_future is None:
# None is our sentinel value to stop the loop
break

requests_and_futures.append(request_and_future)

# Then wait the requested batch interval except if we already
# have the maximum number of requests in the queue
if self.batch_queue.qsize() < self.client.batch_max - 1:
# Wait for the batch interval
await asyncio.sleep(self.client.batch_interval)

# Then get the requests which had been made during that wait interval
for _ in range(self.client.batch_max - 1):
try:
# Use get_nowait since we don't want to wait here
request_and_future = self.batch_queue.get_nowait()

if request_and_future is None:
# Sentinel value - stop after processing current batch
stop_loop = True
break

requests_and_futures.append(request_and_future)

except asyncio.QueueEmpty:
# No more requests in queue, that's fine
break

# Extract requests and futures
requests = [request for request, _ in requests_and_futures]
futures = [future for _, future in requests_and_futures]

# Execute the batch
try:
results: List[ExecutionResult] = await self._execute_batch(
requests,
serialize_variables=False, # already done
parse_result=False, # will be done later
validate_document=False, # already validated
)

# Set the result for each future
for result, future in zip(results, futures):
if not future.cancelled():
future.set_result(result)

except Exception as exc:
# If batch execution fails, propagate the error to all futures
for future in futures:
if not future.cancelled():
future.set_exception(exc)

# Signal that the task has stopped
self._batch_task_stopped_event.set()

async def _execute_future(
self,
request: GraphQLRequest,
) -> asyncio.Future:
"""If batching is enabled, this method will put a request in the batching queue
instead of executing it directly so that the requests could be put in a batch.
"""

assert hasattr(self, "batch_queue"), "Batching is not enabled"
assert not self._batch_task_stop_requested, "Batching task has been stopped"

future: asyncio.Future = asyncio.Future()
await self.batch_queue.put((request, future))

return future

async def _batch_init(self):
"""Initialize the batch task loop if batching is enabled."""
if self.client.batching_enabled:
self.batch_queue: asyncio.Queue = asyncio.Queue()
self._batch_task_stop_requested = False
self._batch_task_stopped_event = asyncio.Event()
self._batch_task = asyncio.create_task(self._batch_loop())

async def _batch_cleanup(self):
"""Cleanup the batching task if batching is enabled."""
if hasattr(self, "_batch_task_stopped_event"):
# Send a None in the queue to indicate that the batching task must stop
# after having processed the remaining requests in the queue
self._batch_task_stop_requested = True
await self.batch_queue.put(None)

# Wait for the task to process remaining requests and stop
await self._batch_task_stopped_event.wait()

async def connect(self):
"""Connect the transport and initialize the batch task loop if batching
is enabled."""

await self._batch_init()

try:
await self.transport.connect()
except Exception as e:
await self.transport.close()
raise e

async def close(self):
"""Close the transport and cleanup the batching task if batching is enabled.
Will wait until all the remaining requests in the batch processing queue
have been executed.
"""
await self._batch_cleanup()

await self.transport.close()

async def fetch_schema(self) -> None:
"""Fetch the GraphQL schema explicitly using introspection.
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
self._connect_task.cancel()
self._connect_task = None

async def connect(self):
"""Start the connect task and initialize the batch task loop if batching
is enabled."""

await self._batch_init()

await self.start_connecting_task()

async def close(self):
"""Stop the connect task and cleanup the batching task
if batching is enabled."""
await self._batch_cleanup()

await self.stop_connecting_task()

await self.transport.close()

async def _execute_once(
self,
request: GraphQLRequest,
39 changes: 27 additions & 12 deletions gql/graphql_request.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from graphql import DocumentNode, GraphQLSchema, print_ast

from .gql import gql
from .utilities import serialize_variable_values


@dataclass(frozen=True)
class GraphQLRequest:
"""GraphQL Request to be executed."""

document: DocumentNode
"""GraphQL query as AST Node object."""
def __init__(
self,
document: Union[DocumentNode, str],
*,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
):
"""
Initialize a GraphQL request.
variable_values: Optional[Dict[str, Any]] = None
"""Dictionary of input parameters (Default: None)."""
Args:
document: GraphQL query as AST Node object or as a string.
If string, it will be converted to DocumentNode using gql().
variable_values: Dictionary of input parameters (Default: None).
operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
"""
if isinstance(document, str):
self.document = gql(document)
else:
self.document = document

operation_name: Optional[str] = None
"""
Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
"""
self.variable_values = variable_values
self.operation_name = operation_name

def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
assert self.variable_values
@@ -48,3 +60,6 @@ def payload(self) -> Dict[str, Any]:
payload["variables"] = self.variable_values

return payload

def __str__(self):
return str(self.payload)
48 changes: 30 additions & 18 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
@@ -274,22 +274,35 @@ def _prepare_file_uploads(

return post_args

async def raise_response_error(
self,
@staticmethod
def _raise_transport_server_error_if_status_more_than_400(
resp: aiohttp.ClientResponse,
reason: str,
) -> None:
# We raise a TransportServerError if status code is 400 or higher
# We raise a TransportProtocolError in the other cases

# If the status is >400,
# then we need to raise a TransportServerError
try:
# Raise ClientResponseError if response status is 400 or higher
resp.raise_for_status()
except ClientResponseError as e:
raise TransportServerError(str(e), e.status) from e

@classmethod
async def _raise_response_error(
cls,
resp: aiohttp.ClientResponse,
reason: str,
) -> None:
# We raise a TransportServerError if status code is 400 or higher
# We raise a TransportProtocolError in the other cases

cls._raise_transport_server_error_if_status_more_than_400(resp)

result_text = await resp.text()
self._raise_invalid_result(result_text, reason)
raise TransportProtocolError(
f"Server did not return a valid GraphQL result: "
f"{reason}: "
f"{result_text}"
)

async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any:

@@ -304,10 +317,10 @@ async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any:
log.debug("<<< %s", result_text)

except Exception:
await self.raise_response_error(response, "Not a JSON answer")
await self._raise_response_error(response, "Not a JSON answer")

if result is None:
await self.raise_response_error(response, "Not a JSON answer")
await self._raise_response_error(response, "Not a JSON answer")

return result

@@ -318,7 +331,7 @@ async def _prepare_result(
result = await self._get_json_result(response)

if "errors" not in result and "data" not in result:
await self.raise_response_error(
await self._raise_response_error(
response, 'No "data" or "errors" keys in answer'
)

@@ -336,14 +349,13 @@ async def _prepare_batch_result(

answers = await self._get_json_result(response)

return get_batch_execution_result_list(reqs, answers)

def _raise_invalid_result(self, result_text: str, reason: str) -> None:
raise TransportProtocolError(
f"Server did not return a valid GraphQL result: "
f"{reason}: "
f"{result_text}"
)
try:
return get_batch_execution_result_list(reqs, answers)
except TransportProtocolError:
# Raise a TransportServerError if status > 400
self._raise_transport_server_error_if_status_more_than_400(response)
# In other cases, raise a TransportProtocolError
raise

async def execute(
self,
29 changes: 22 additions & 7 deletions gql/transport/httpx.py
Original file line number Diff line number Diff line change
@@ -195,18 +195,33 @@ def _prepare_batch_result(

answers = self._get_json_result(response)

return get_batch_execution_result_list(reqs, answers)

def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn:
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases

try:
# Raise a HTTPError if response status is 400 or higher
return get_batch_execution_result_list(reqs, answers)
except TransportProtocolError:
# Raise a TransportServerError if status > 400
self._raise_transport_server_error_if_status_more_than_400(response)
# In other cases, raise a TransportProtocolError
raise

@staticmethod
def _raise_transport_server_error_if_status_more_than_400(
response: httpx.Response,
) -> None:
# If the status is >400,
# then we need to raise a TransportServerError
try:
# Raise a HTTPStatusError if response status is 400 or higher
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise TransportServerError(str(e), e.response.status_code) from e

@classmethod
def _raise_response_error(cls, response: httpx.Response, reason: str) -> NoReturn:
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases

cls._raise_transport_server_error_if_status_more_than_400(response)

raise TransportProtocolError(
f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}"
)
70 changes: 35 additions & 35 deletions gql/transport/requests.py
Original file line number Diff line number Diff line change
@@ -258,24 +258,6 @@ def execute( # type: ignore

self.response_headers = response.headers

def raise_response_error(resp: requests.Response, reason: str) -> NoReturn:
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases

try:
# Raise a HTTPError if response status is 400 or higher
resp.raise_for_status()
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else None
raise TransportServerError(str(e), status_code) from e

result_text = resp.text
raise TransportProtocolError(
f"Server did not return a GraphQL result: "
f"{reason}: "
f"{result_text}"
)

try:
if self.json_deserialize == json.loads:
result = response.json()
@@ -286,17 +268,42 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn:
log.debug("<<< %s", response.text)

except Exception:
raise_response_error(response, "Not a JSON answer")
self._raise_response_error(response, "Not a JSON answer")

if "errors" not in result and "data" not in result:
raise_response_error(response, 'No "data" or "errors" keys in answer')
self._raise_response_error(response, 'No "data" or "errors" keys in answer')

return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)

@staticmethod
def _raise_transport_server_error_if_status_more_than_400(
response: requests.Response,
) -> None:
# If the status is >400,
# then we need to raise a TransportServerError
try:
# Raise a HTTPError if response status is 400 or higher
response.raise_for_status()
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else None
raise TransportServerError(str(e), status_code) from e

@classmethod
def _raise_response_error(cls, resp: requests.Response, reason: str) -> NoReturn:
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases

cls._raise_transport_server_error_if_status_more_than_400(resp)

result_text = resp.text
raise TransportProtocolError(
f"Server did not return a GraphQL result: " f"{reason}: " f"{result_text}"
)

def execute_batch(
self,
reqs: List[GraphQLRequest],
@@ -330,30 +337,23 @@ def execute_batch(

answers = self._extract_response(response)

return get_batch_execution_result_list(reqs, answers)

def _raise_invalid_result(self, result_text: str, reason: str) -> None:
raise TransportProtocolError(
f"Server did not return a valid GraphQL result: "
f"{reason}: "
f"{result_text}"
)
try:
return get_batch_execution_result_list(reqs, answers)
except TransportProtocolError:
# Raise a TransportServerError if status > 400
self._raise_transport_server_error_if_status_more_than_400(response)
# In other cases, raise a TransportProtocolError
raise

def _extract_response(self, response: requests.Response) -> Any:
try:
response.raise_for_status()
result = response.json()

if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", response.text)

except requests.HTTPError as e:
raise TransportServerError(
str(e), e.response.status_code if e.response is not None else None
) from e

except Exception:
self._raise_invalid_result(str(response.text), "Not a JSON answer")
self._raise_response_error(response, "Not a JSON answer")

return result

190 changes: 190 additions & 0 deletions tests/test_aiohttp_batch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Mapping

import pytest
@@ -7,6 +8,7 @@
TransportClosed,
TransportProtocolError,
TransportQueryError,
TransportServerError,
)

# Marking all tests in this file with the aiohttp marker
@@ -29,6 +31,21 @@
'{"code":"SA","name":"South America"}]}}]'
)

query1_server_answer_twice_list = (
"["
'{"data":{"continents":['
'{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},'
'{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},'
'{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},'
'{"code":"SA","name":"South America"}]}},'
'{"data":{"continents":['
'{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},'
'{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},'
'{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},'
'{"code":"SA","name":"South America"}]}}'
"]"
)


@pytest.mark.asyncio
async def test_aiohttp_batch_query(aiohttp_server):
@@ -72,6 +89,179 @@ async def handler(request):
assert transport.response_headers["dummy"] == "test1234"


@pytest.mark.asyncio
async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test):
from aiohttp import web

from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(
text=query1_server_answer_list,
content_type="application/json",
headers={"dummy": "test1234"},
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

transport = AIOHTTPTransport(url=url, timeout=10)

async with Client(
transport=transport,
batch_interval=0.01, # 10ms batch interval
) as session:

query = gql(query1_str)

result = await session.execute(query)

continents = result["continents"]

africa = continents[0]

assert africa["code"] == "AF"

# Checking response headers are saved in the transport
assert hasattr(transport, "response_headers")
assert isinstance(transport.response_headers, Mapping)
assert transport.response_headers["dummy"] == "test1234"


@pytest.mark.asyncio
async def test_aiohttp_batch_auto_two_requests(aiohttp_server):
from aiohttp import web

from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(
text=query1_server_answer_twice_list,
content_type="application/json",
headers={"dummy": "test1234"},
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")
transport = AIOHTTPTransport(url=url, timeout=10)

async with Client(
transport=transport,
batch_interval=0.01,
) as session:

async def test_coroutine():
query = gql(query1_str)

# Execute query asynchronously
result = await session.execute(query)

continents = result["continents"]

africa = continents[0]

assert africa["code"] == "AF"

# Create two concurrent tasks that will be batched together
tasks = []
for _ in range(2):
task = asyncio.create_task(test_coroutine())
tasks.append(task)

# Wait for all tasks to complete
await asyncio.gather(*tasks)


@pytest.mark.asyncio
async def test_aiohttp_batch_auto_two_requests_close_session_directly(aiohttp_server):
from aiohttp import web

from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(
text=query1_server_answer_twice_list,
content_type="application/json",
headers={"dummy": "test1234"},
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")
transport = AIOHTTPTransport(url=url, timeout=10)

async with Client(
transport=transport,
batch_interval=0.1,
) as session:

async def test_coroutine():
query = gql(query1_str)

# Execute query asynchronously
result = await session.execute(query)

continents = result["continents"]

africa = continents[0]

assert africa["code"] == "AF"

# Create two concurrent tasks that will be batched together
tasks = []
for _ in range(2):
task = asyncio.create_task(test_coroutine())
tasks.append(task)

await asyncio.sleep(0.01)

# Wait for all tasks to complete
await asyncio.gather(*tasks)


@pytest.mark.asyncio
async def test_aiohttp_batch_error_code_401(aiohttp_server):
from aiohttp import web

from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
# Will generate http error code 401
return web.Response(
text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}',
content_type="application/json",
status=401,
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

transport = AIOHTTPTransport(url=url, timeout=10)

async with Client(
transport=transport,
batch_interval=0.01, # 10ms batch interval
) as session:

query = gql(query1_str)

with pytest.raises(TransportServerError) as exc_info:
await session.execute(query)

assert "401, message='Unauthorized'" in str(exc_info.value)


@pytest.mark.asyncio
async def test_aiohttp_batch_query_without_session(aiohttp_server, run_sync_test):
from aiohttp import web
12 changes: 11 additions & 1 deletion tests/test_graphql_request.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@

from gql import GraphQLRequest, gql

from .conftest import MS
from .conftest import MS, strip_braces_spaces

# Marking all tests in this file with the aiohttp marker
pytestmark = pytest.mark.aiohttp
@@ -200,3 +200,13 @@ def test_serialize_variables_using_money_example():
req = req.serialize_variable_values(schema)

assert req.variable_values == {"money": {"amount": 10, "currency": "DM"}}


def test_graphql_request_using_string_instead_of_document():
request = GraphQLRequest("{balance}")

expected_payload = "{'query': '{\\n balance\\n}'}"

print(request)

assert str(request) == strip_braces_spaces(expected_payload)