9
9
from strawberry .relay .types import NodeIterableType
10
10
from strawberry .types import get_object_definition
11
11
from strawberry .types .base import StrawberryContainer
12
+ from strawberry .types .nodes import InlineFragment , Selection
12
13
from strawberry .utils .await_maybe import AwaitableOrValue
14
+ from strawberry .utils .inspect import in_async_context
13
15
from typing_extensions import Self , deprecated
14
16
15
17
from strawberry_django .pagination import get_total_count
18
+ from strawberry_django .queryset import get_queryset_config
16
19
from strawberry_django .resolvers import django_resolver
20
+ from strawberry_django .utils .typing import unwrap_type
21
+
22
+
23
+ def _should_optimize_total_count (info : Info ) -> bool :
24
+ """Check if the user requested to resolve the `totalCount` field of a connection.
25
+
26
+ Taken and adjusted from strawberry.relay.utils
27
+ """
28
+ resolve_for_field_names = {"totalCount" }
29
+
30
+ def _check_selection (selection : Selection ) -> bool :
31
+ """Recursively inspect the selection to check if the user requested to resolve the `edges` field.
32
+
33
+ Args:
34
+ selection (Selection): The selection to check.
35
+
36
+ Returns:
37
+ bool: True if the user requested to resolve the `edges` field of a connection, False otherwise.
38
+
39
+ """
40
+ if (
41
+ not isinstance (selection , InlineFragment )
42
+ and selection .name in resolve_for_field_names
43
+ ):
44
+ return True
45
+ if selection .selections :
46
+ return any (
47
+ _check_selection (selection ) for selection in selection .selections
48
+ )
49
+ return False
50
+
51
+ for selection_field in info .selected_fields :
52
+ for selection in selection_field .selections :
53
+ if _check_selection (selection ):
54
+ return True
55
+ return False
17
56
18
57
19
58
@strawberry .type (name = "Connection" , description = "A connection to a list of items." )
@@ -25,6 +64,11 @@ class DjangoListConnection(relay.ListConnection[relay.NodeType]):
25
64
def total_count (self ) -> Optional [int ]:
26
65
assert self .nodes is not None
27
66
67
+ try :
68
+ return self .edges [0 ].node ._strawberry_total_count # type: ignore
69
+ except (IndexError , AttributeError ):
70
+ pass
71
+
28
72
if isinstance (self .nodes , models .QuerySet ):
29
73
return get_total_count (self .nodes )
30
74
@@ -42,32 +86,52 @@ def resolve_connection(
42
86
last : Optional [int ] = None ,
43
87
** kwargs : Any ,
44
88
) -> AwaitableOrValue [Self ]:
45
- from strawberry_django .optimizer import is_optimized_by_prefetching
89
+ if isinstance (nodes , models .QuerySet ) and (
90
+ queryset_config := get_queryset_config (nodes )
91
+ ):
92
+ if queryset_config .optimized_by_prefetching :
93
+ try :
94
+ conn = cls .resolve_optimized_connection_by_prefetch (
95
+ nodes ,
96
+ info = info ,
97
+ before = before ,
98
+ after = after ,
99
+ first = first ,
100
+ last = last ,
101
+ ** kwargs ,
102
+ )
103
+ except AttributeError :
104
+ warnings .warn (
105
+ (
106
+ "Pagination annotations not found, falling back to QuerySet resolution. "
107
+ "This might cause N+1 issues..."
108
+ ),
109
+ RuntimeWarning ,
110
+ stacklevel = 2 ,
111
+ )
112
+ else :
113
+ conn = cast ("Self" , conn )
114
+ conn .nodes = nodes
115
+ return conn
46
116
47
- if isinstance (nodes , models .QuerySet ) and is_optimized_by_prefetching (nodes ):
48
- try :
49
- conn = cls .resolve_connection_from_cache (
50
- nodes ,
51
- info = info ,
52
- before = before ,
53
- after = after ,
54
- first = first ,
55
- last = last ,
56
- ** kwargs ,
57
- )
58
- except AttributeError :
59
- warnings .warn (
60
- (
61
- "Pagination annotations not found, falling back to QuerySet resolution. "
62
- "This might cause N+1 issues..."
63
- ),
64
- RuntimeWarning ,
65
- stacklevel = 2 ,
66
- )
67
- else :
68
- conn = cast ("Self" , conn )
69
- conn .nodes = nodes
70
- return conn
117
+ if queryset_config .optimized :
118
+ if (last or 0 ) > 0 and before is None :
119
+ return cls .resolve_optimized_last_connection (
120
+ nodes ,
121
+ info = info ,
122
+ before = before ,
123
+ after = after ,
124
+ first = first ,
125
+ last = last ,
126
+ ** kwargs ,
127
+ )
128
+
129
+ if _should_optimize_total_count (info ):
130
+ nodes = nodes .annotate (
131
+ _strawberry_total_count = models .Window (
132
+ expression = models .Count (1 ), partition_by = None
133
+ )
134
+ )
71
135
72
136
conn = super ().resolve_connection (
73
137
nodes ,
@@ -93,7 +157,7 @@ async def wrapper():
93
157
return conn
94
158
95
159
@classmethod
96
- def resolve_connection_from_cache (
160
+ def resolve_optimized_connection_by_prefetch (
97
161
cls ,
98
162
nodes : NodeIterableType [relay .NodeType ],
99
163
* ,
@@ -147,6 +211,63 @@ def resolve_connection_from_cache(
147
211
),
148
212
)
149
213
214
+ @classmethod
215
+ def resolve_optimized_last_connection (
216
+ cls ,
217
+ nodes : NodeIterableType [relay .NodeType ],
218
+ * ,
219
+ info : Info ,
220
+ before : Optional [str ] = None ,
221
+ after : Optional [str ] = None ,
222
+ first : Optional [int ] = None ,
223
+ last : Optional [int ] = None ,
224
+ ** kwargs : Any ,
225
+ ) -> AwaitableOrValue [Self ]:
226
+ """Resolve the connection being paginated only via `last`.
227
+
228
+ In order to prevent fetching the entire table, QuerySet is first counted & the
229
+ amount is used instead of `before=None`.
230
+ """
231
+ assert isinstance (nodes , models .QuerySet )
232
+
233
+ type_def = get_object_definition (cls )
234
+ assert type_def
235
+ field_def = type_def .get_field ("edges" )
236
+ assert field_def
237
+ field = field_def .resolve_type (type_definition = type_def )
238
+ field = unwrap_type (field )
239
+ edge_class = cast ("relay.Edge[relay.NodeType]" , field )
240
+
241
+ if in_async_context ():
242
+
243
+ async def wrapper ():
244
+ total_count = await nodes .acount ()
245
+ before = relay .to_base64 (edge_class .CURSOR_PREFIX , total_count )
246
+ conn = cls .resolve_connection (
247
+ nodes ,
248
+ info = info ,
249
+ before = before ,
250
+ after = after ,
251
+ first = first ,
252
+ last = last ,
253
+ ** kwargs ,
254
+ )
255
+ return await conn if inspect .isawaitable (conn ) else conn
256
+
257
+ return wrapper ()
258
+
259
+ total_count = nodes .count ()
260
+ before = relay .to_base64 (edge_class .CURSOR_PREFIX , total_count )
261
+ return cls .resolve_connection (
262
+ nodes ,
263
+ info = info ,
264
+ before = before ,
265
+ after = after ,
266
+ first = first ,
267
+ last = last ,
268
+ ** kwargs ,
269
+ )
270
+
150
271
151
272
if TYPE_CHECKING :
152
273
0 commit comments