Skip to content

Commit d97f71b

Browse files
authored
Relay pagination optimizations (#777)
1 parent 3cf0e74 commit d97f71b

File tree

7 files changed

+332
-45
lines changed

7 files changed

+332
-45
lines changed

strawberry_django/optimizer.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from graphql.execution.collect_fields import collect_sub_fields
3838
from graphql.language.ast import OperationType
3939
from graphql.type.definition import GraphQLResolveInfo, get_named_type
40-
from strawberry import relay
40+
from strawberry import UNSET, relay
4141
from strawberry.extensions import SchemaExtension
4242
from strawberry.relay.utils import SliceMetadata
4343
from strawberry.schema.schema import Schema
@@ -82,10 +82,13 @@
8282
from collections.abc import Generator
8383

8484
from django.contrib.contenttypes.fields import GenericRelation
85+
from strawberry.relay import Edge
8586
from strawberry.types.execution import ExecutionContext
8687
from strawberry.types.field import StrawberryField
8788
from strawberry.utils.await_maybe import AwaitableOrValue
8889

90+
from strawberry_django.pagination import OffsetPaginationInfo
91+
8992

9093
__all__ = [
9194
"DjangoOptimizerExtension",
@@ -582,19 +585,33 @@ def _optimize_prefetch_queryset(
582585
connection_type is relay.ListConnection
583586
or connection_type is DjangoListConnection
584587
):
588+
field_def_ = connection_type_def.get_field("edges")
589+
assert field_def_
590+
field_ = field_def_.resolve_type(type_definition=connection_type_def)
591+
field_ = unwrap_type(field_)
592+
edge_class = cast("Edge", field_)
593+
585594
slice_metadata = SliceMetadata.from_arguments(
586595
Info(_raw_info=info, _field=field),
587596
first=field_kwargs.get("first"),
588597
last=field_kwargs.get("last"),
589598
before=field_kwargs.get("before"),
590599
after=field_kwargs.get("after"),
600+
max_results=connection_extension.max_results,
601+
prefix=edge_class.CURSOR_PREFIX,
591602
)
603+
mark_reversed = slice_metadata.expected is None
592604
qs = apply_window_pagination(
593605
qs,
594606
related_field_id=related_field_id,
595607
offset=slice_metadata.start,
596-
limit=slice_metadata.end - slice_metadata.start,
608+
limit=(
609+
field_kwargs.get("last", UNSET)
610+
if mark_reversed
611+
else slice_metadata.end - slice_metadata.start
612+
),
597613
max_results=connection_extension.max_results,
614+
reverse=mark_reversed,
598615
)
599616
elif connection_type is DjangoCursorConnection:
600617
qs, _ = apply_cursor_pagination(
@@ -611,7 +628,7 @@ def _optimize_prefetch_queryset(
611628
mark_optimized = False
612629

613630
if isinstance(field.type, type) and issubclass(field.type, OffsetPaginated):
614-
pagination = field_kwargs.get("pagination")
631+
pagination: OffsetPaginationInfo | None = field_kwargs.get("pagination")
615632
qs = apply_window_pagination(
616633
qs,
617634
related_field_id=related_field_id,

strawberry_django/pagination.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def apply_window_pagination(
172172
offset: int = 0,
173173
limit: Optional[int] = UNSET,
174174
max_results: Optional[int] = None,
175+
reverse: bool = False,
175176
) -> _QS:
176177
"""Apply pagination using window functions.
177178
@@ -186,8 +187,17 @@ def apply_window_pagination(
186187
related_field_id: The related field id to apply pagination to.
187188
offset: The offset to start the pagination from.
188189
limit: The limit of items to return.
190+
reverse: The need to reverse queryset ordering for backwards relay pagination
189191
190192
"""
193+
if limit is UNSET:
194+
settings = strawberry_django_settings()
195+
limit = (
196+
max_results
197+
if max_results is not None
198+
else settings["PAGINATION_DEFAULT_LIMIT"]
199+
)
200+
191201
order_by = [
192202
expr
193203
for expr, _ in queryset.query.get_compiler(
@@ -210,13 +220,23 @@ def apply_window_pagination(
210220
if offset:
211221
queryset = queryset.filter(_strawberry_row_number__gt=offset)
212222

213-
if limit is UNSET:
214-
settings = strawberry_django_settings()
215-
limit = (
216-
max_results
217-
if max_results is not None
218-
else settings["PAGINATION_DEFAULT_LIMIT"]
223+
if reverse:
224+
order_by_reverse = [
225+
expr
226+
for expr, _ in queryset.reverse()
227+
.query.get_compiler(
228+
using=queryset._db or DEFAULT_DB_ALIAS # type: ignore
229+
)
230+
.get_order_by()
231+
]
232+
queryset = queryset.annotate(
233+
_strawberry_row_number_reversed=_PaginationWindow(
234+
RowNumber(),
235+
partition_by=related_field_id,
236+
order_by=order_by_reverse,
237+
),
219238
)
239+
return queryset.filter(_strawberry_row_number_reversed__lte=limit)
220240

221241
# Limit == -1 means no limit. sys.maxsize is set by relay when paginating
222242
# from the end to as a way to mimic a "not limit" as well

strawberry_django/relay/list_connection.py

Lines changed: 147 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,50 @@
99
from strawberry.relay.types import NodeIterableType
1010
from strawberry.types import get_object_definition
1111
from strawberry.types.base import StrawberryContainer
12+
from strawberry.types.nodes import InlineFragment, Selection
1213
from strawberry.utils.await_maybe import AwaitableOrValue
14+
from strawberry.utils.inspect import in_async_context
1315
from typing_extensions import Self, deprecated
1416

1517
from strawberry_django.pagination import get_total_count
18+
from strawberry_django.queryset import get_queryset_config
1619
from strawberry_django.resolvers import django_resolver
20+
from strawberry_django.utils.typing import unwrap_type
21+
22+
23+
def _should_optimize_total_count(info: Info) -> bool:
24+
"""Check if the user requested to resolve the `totalCount` field of a connection.
25+
26+
Taken and adjusted from strawberry.relay.utils
27+
"""
28+
resolve_for_field_names = {"totalCount"}
29+
30+
def _check_selection(selection: Selection) -> bool:
31+
"""Recursively inspect the selection to check if the user requested to resolve the `edges` field.
32+
33+
Args:
34+
selection (Selection): The selection to check.
35+
36+
Returns:
37+
bool: True if the user requested to resolve the `edges` field of a connection, False otherwise.
38+
39+
"""
40+
if (
41+
not isinstance(selection, InlineFragment)
42+
and selection.name in resolve_for_field_names
43+
):
44+
return True
45+
if selection.selections:
46+
return any(
47+
_check_selection(selection) for selection in selection.selections
48+
)
49+
return False
50+
51+
for selection_field in info.selected_fields:
52+
for selection in selection_field.selections:
53+
if _check_selection(selection):
54+
return True
55+
return False
1756

1857

1958
@strawberry.type(name="Connection", description="A connection to a list of items.")
@@ -25,6 +64,11 @@ class DjangoListConnection(relay.ListConnection[relay.NodeType]):
2564
def total_count(self) -> Optional[int]:
2665
assert self.nodes is not None
2766

67+
try:
68+
return self.edges[0].node._strawberry_total_count # type: ignore
69+
except (IndexError, AttributeError):
70+
pass
71+
2872
if isinstance(self.nodes, models.QuerySet):
2973
return get_total_count(self.nodes)
3074

@@ -42,32 +86,52 @@ def resolve_connection(
4286
last: Optional[int] = None,
4387
**kwargs: Any,
4488
) -> AwaitableOrValue[Self]:
45-
from strawberry_django.optimizer import is_optimized_by_prefetching
89+
if isinstance(nodes, models.QuerySet) and (
90+
queryset_config := get_queryset_config(nodes)
91+
):
92+
if queryset_config.optimized_by_prefetching:
93+
try:
94+
conn = cls.resolve_optimized_connection_by_prefetch(
95+
nodes,
96+
info=info,
97+
before=before,
98+
after=after,
99+
first=first,
100+
last=last,
101+
**kwargs,
102+
)
103+
except AttributeError:
104+
warnings.warn(
105+
(
106+
"Pagination annotations not found, falling back to QuerySet resolution. "
107+
"This might cause N+1 issues..."
108+
),
109+
RuntimeWarning,
110+
stacklevel=2,
111+
)
112+
else:
113+
conn = cast("Self", conn)
114+
conn.nodes = nodes
115+
return conn
46116

47-
if isinstance(nodes, models.QuerySet) and is_optimized_by_prefetching(nodes):
48-
try:
49-
conn = cls.resolve_connection_from_cache(
50-
nodes,
51-
info=info,
52-
before=before,
53-
after=after,
54-
first=first,
55-
last=last,
56-
**kwargs,
57-
)
58-
except AttributeError:
59-
warnings.warn(
60-
(
61-
"Pagination annotations not found, falling back to QuerySet resolution. "
62-
"This might cause N+1 issues..."
63-
),
64-
RuntimeWarning,
65-
stacklevel=2,
66-
)
67-
else:
68-
conn = cast("Self", conn)
69-
conn.nodes = nodes
70-
return conn
117+
if queryset_config.optimized:
118+
if (last or 0) > 0 and before is None:
119+
return cls.resolve_optimized_last_connection(
120+
nodes,
121+
info=info,
122+
before=before,
123+
after=after,
124+
first=first,
125+
last=last,
126+
**kwargs,
127+
)
128+
129+
if _should_optimize_total_count(info):
130+
nodes = nodes.annotate(
131+
_strawberry_total_count=models.Window(
132+
expression=models.Count(1), partition_by=None
133+
)
134+
)
71135

72136
conn = super().resolve_connection(
73137
nodes,
@@ -93,7 +157,7 @@ async def wrapper():
93157
return conn
94158

95159
@classmethod
96-
def resolve_connection_from_cache(
160+
def resolve_optimized_connection_by_prefetch(
97161
cls,
98162
nodes: NodeIterableType[relay.NodeType],
99163
*,
@@ -147,6 +211,63 @@ def resolve_connection_from_cache(
147211
),
148212
)
149213

214+
@classmethod
215+
def resolve_optimized_last_connection(
216+
cls,
217+
nodes: NodeIterableType[relay.NodeType],
218+
*,
219+
info: Info,
220+
before: Optional[str] = None,
221+
after: Optional[str] = None,
222+
first: Optional[int] = None,
223+
last: Optional[int] = None,
224+
**kwargs: Any,
225+
) -> AwaitableOrValue[Self]:
226+
"""Resolve the connection being paginated only via `last`.
227+
228+
In order to prevent fetching the entire table, QuerySet is first counted & the
229+
amount is used instead of `before=None`.
230+
"""
231+
assert isinstance(nodes, models.QuerySet)
232+
233+
type_def = get_object_definition(cls)
234+
assert type_def
235+
field_def = type_def.get_field("edges")
236+
assert field_def
237+
field = field_def.resolve_type(type_definition=type_def)
238+
field = unwrap_type(field)
239+
edge_class = cast("relay.Edge[relay.NodeType]", field)
240+
241+
if in_async_context():
242+
243+
async def wrapper():
244+
total_count = await nodes.acount()
245+
before = relay.to_base64(edge_class.CURSOR_PREFIX, total_count)
246+
conn = cls.resolve_connection(
247+
nodes,
248+
info=info,
249+
before=before,
250+
after=after,
251+
first=first,
252+
last=last,
253+
**kwargs,
254+
)
255+
return await conn if inspect.isawaitable(conn) else conn
256+
257+
return wrapper()
258+
259+
total_count = nodes.count()
260+
before = relay.to_base64(edge_class.CURSOR_PREFIX, total_count)
261+
return cls.resolve_connection(
262+
nodes,
263+
info=info,
264+
before=before,
265+
after=after,
266+
first=first,
267+
last=last,
268+
**kwargs,
269+
)
270+
150271

151272
if TYPE_CHECKING:
152273

tests/polymorphism_custom/test_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_polymorphic_interface_connection():
186186
}
187187
"""
188188

189-
with assert_num_queries(2):
189+
with assert_num_queries(1):
190190
result = schema.execute_sync(query)
191191
assert not result.errors
192192
assert result.data == {

0 commit comments

Comments
 (0)