From fc0f03bb06b29c019b4dd7d5a4eea2221f6758db Mon Sep 17 00:00:00 2001 From: M Aswin Kishore <60577077+mak626@users.noreply.github.com> Date: Thu, 21 Dec 2023 19:37:51 +0530 Subject: [PATCH] fix: pagination errors hasNextPage didn't become false when using after and first together Fixed reverse Querying using last and before, in compliance to graphql relay spec https://relay.dev/graphql/connections.htm#sec-Backward-pagination-arguments https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo --- graphene_mongo/fields.py | 114 ++++++++++----------------------- graphene_mongo/fields_async.py | 52 ++++++--------- graphene_mongo/utils.py | 16 ++--- 3 files changed, 61 insertions(+), 121 deletions(-) diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index f7a448f..5ba8317 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -39,6 +39,7 @@ find_skip_and_limit, get_model_reference_fields, get_query_fields, + has_page_info, ) PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) @@ -276,7 +277,7 @@ def fields(self): return self._type._meta.fields def get_queryset( - self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args + self, model, info, required_fields=None, skip=None, limit=None, **args ) -> QuerySet: if required_fields is None: required_fields = list() @@ -325,49 +326,22 @@ def get_queryset( else: args.update(queryset_or_filters) if limit is not None: - if reversed: - if self.order_by: - order_by = self.order_by + ",-pk" - else: - order_by = "-pk" - return ( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(order_by) - .skip(skip if skip else 0) - .limit(limit) - ) - else: - return ( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(self.order_by) - .skip(skip if skip else 0) - .limit(limit) - ) + return ( + model.objects(**args) + .no_dereference() + .only(*required_fields) + .order_by(self.order_by) + .skip(skip if skip else 0) + .limit(limit) + ) elif skip is not None: - if reversed: - if self.order_by: - order_by = self.order_by + ",-pk" - else: - order_by = "-pk" - return ( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(order_by) - .skip(skip) - ) - else: - return ( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(self.order_by) - .skip(skip) - ) + return ( + model.objects(**args) + .no_dereference() + .only(*required_fields) + .order_by(self.order_by) + .skip(skip) + ) return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by) def default_resolver(self, _root, info, required_fields=None, resolved=None, **args): @@ -401,7 +375,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a skip = 0 count = 0 limit = None - reverse = False first = args.pop("first", None) after = args.pop("after", None) if after: @@ -410,6 +383,7 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a before = args.pop("before", None) if before: before = cursor_to_offset(before) + requires_page_info = has_page_info(info) has_next_page = False if resolved is not None: @@ -417,7 +391,7 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a if isinstance(items, QuerySet): try: - if last is not None and after is not None: + if last is not None: count = items.count(with_limit_and_skip=False) else: count = None @@ -426,29 +400,24 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a else: count = len(items) - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, last=last, after=after, before=before, count=count ) if isinstance(items, QuerySet): if limit: - _base_query: QuerySet = ( - items.order_by("-pk").skip(skip) if reverse else items.skip(skip) - ) + _base_query: QuerySet = items.skip(skip) items = _base_query.limit(limit) - has_next_page = len(_base_query.skip(limit).only("id").limit(1)) != 0 + has_next_page = len(_base_query.skip(skip + limit).only("id").limit(1)) != 0 elif skip: items = items.skip(skip) else: if limit: - if reverse: - _base_query = items[::-1] - items = _base_query[skip : skip + limit] - has_next_page = (skip + limit) < len(_base_query) - else: - _base_query = items - items = items[skip : skip + limit] - has_next_page = (skip + limit) < len(_base_query) + _base_query = items + items = items[skip : skip + limit] + has_next_page = ( + (skip + limit) < len(_base_query) if requires_page_info else False + ) elif skip: items = items[skip:] iterables = list(items) @@ -503,11 +472,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a else: count = self.model.objects(args_copy).count() if count != 0: - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, after=after, last=last, before=before, count=count ) iterables = self.get_queryset( - self.model, info, required_fields, skip, limit, reverse, **args + self.model, info, required_fields, skip, limit, **args ) list_length = len(iterables) if isinstance(info, GraphQLResolveInfo): @@ -519,14 +488,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a elif "pk__in" in args and args["pk__in"]: count = len(args["pk__in"]) - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, last=last, after=after, before=before, count=count ) if limit: - if reverse: - args["pk__in"] = args["pk__in"][::-1][skip : skip + limit] - else: - args["pk__in"] = args["pk__in"][skip : skip + limit] + args["pk__in"] = args["pk__in"][skip : skip + limit] elif skip: args["pk__in"] = args["pk__in"][skip:] iterables = self.get_queryset(self.model, info, required_fields, **args) @@ -542,18 +508,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a field_name = to_snake_case(info.field_name) items = getattr(_root, field_name, []) count = len(items) - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, last=last, after=after, before=before, count=count ) if limit: - if reverse: - _base_query = items[::-1] - items = _base_query[skip : skip + limit] - has_next_page = (skip + limit) < len(_base_query) - else: - _base_query = items - items = items[skip : skip + limit] - has_next_page = (skip + limit) < len(_base_query) + _base_query = items + items = items[skip : skip + limit] + has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False elif skip: items = items[skip:] iterables = items @@ -567,11 +528,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a ) has_previous_page = True if skip else False - if reverse: - iterables = list(iterables) - iterables.reverse() - skip = limit - connection = connection_from_iterables( edges=iterables, start_offset=skip, diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py index dcfe8cf..bbbd96a 100644 --- a/graphene_mongo/fields_async.py +++ b/graphene_mongo/fields_async.py @@ -92,7 +92,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non skip = 0 count = 0 limit = None - reverse = False first = args.pop("first", None) after = args.pop("after", None) if after: @@ -109,7 +108,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non if isinstance(items, QuerySet): try: - if last is not None and after is not None: + if last is not None: count = await sync_to_async(items.count)(with_limit_and_skip=False) else: count = None @@ -118,20 +117,21 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non else: count = len(items) - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, last=last, after=after, before=before, count=count ) if isinstance(items, QuerySet): if limit: - _base_query: QuerySet = ( - await sync_to_async(items.order_by("-pk").skip)(skip) - if reverse - else await sync_to_async(items.skip)(skip) - ) + _base_query: QuerySet = await sync_to_async(items.skip)(skip) items = await sync_to_async(_base_query.limit)(limit) has_next_page = ( - (await sync_to_async(len)(_base_query.skip(limit).only("id").limit(1)) != 0) + ( + await sync_to_async(len)( + _base_query.skip(skip + limit).only("id").limit(1) + ) + != 0 + ) if requires_page_info else False ) @@ -139,12 +139,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non items = await sync_to_async(items.skip)(skip) else: if limit: - if reverse: - _base_query = items[::-1] - items = _base_query[skip : skip + limit] - else: - _base_query = items - items = items[skip : skip + limit] + _base_query = items + items = items[skip : skip + limit] has_next_page = ( (skip + limit) < len(_base_query) if requires_page_info else False ) @@ -195,11 +191,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non else: count = await sync_to_async(self.model.objects(args_copy).count)() if count != 0: - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, after=after, last=last, before=before, count=count ) iterables = self.get_queryset( - self.model, info, required_fields, skip, limit, reverse, **args + self.model, info, required_fields, skip, limit, **args ) iterables = await sync_to_async(list)(iterables) list_length = len(iterables) @@ -212,14 +208,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non elif "pk__in" in args and args["pk__in"]: count = len(args["pk__in"]) - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, last=last, after=after, before=before, count=count ) if limit: - if reverse: - args["pk__in"] = args["pk__in"][::-1][skip : skip + limit] - else: - args["pk__in"] = args["pk__in"][skip : skip + limit] + args["pk__in"] = args["pk__in"][skip : skip + limit] elif skip: args["pk__in"] = args["pk__in"][skip:] iterables = self.get_queryset(self.model, info, required_fields, **args) @@ -236,16 +229,12 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non field_name = to_snake_case(info.field_name) items = getattr(_root, field_name, []) count = len(items) - skip, limit, reverse = find_skip_and_limit( + skip, limit = find_skip_and_limit( first=first, last=last, after=after, before=before, count=count ) if limit: - if reverse: - _base_query = items[::-1] - items = _base_query[skip : skip + limit] - else: - _base_query = items - items = items[skip : skip + limit] + _base_query = items + items = items[skip : skip + limit] has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False elif skip: items = items[skip:] @@ -261,11 +250,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non ) has_previous_page = True if requires_page_info and skip else False - if reverse: - iterables = await sync_to_async(list)(iterables) - iterables.reverse() - skip = limit - connection = connection_from_iterables( edges=iterables, start_offset=skip, diff --git a/graphene_mongo/utils.py b/graphene_mongo/utils.py index ad18933..87699b6 100644 --- a/graphene_mongo/utils.py +++ b/graphene_mongo/utils.py @@ -259,9 +259,12 @@ def ast_to_dict(node, include_loc=False): def find_skip_and_limit(first, last, after, before, count=None): - reverse = False skip = 0 limit = None + + if last is not None and count is None: + raise ValueError("Count Missing") + if first is not None and after is not None: skip = after + 1 limit = first @@ -274,29 +277,26 @@ def find_skip_and_limit(first, last, after, before, count=None): skip = 0 limit = first elif last is not None and before is not None: - reverse = False if last >= before: limit = before else: limit = last skip = before - last elif last is not None and after is not None: - if not count: - raise ValueError("Count Missing") - reverse = True + skip = after + 1 if last + after < count: limit = last else: limit = count - after - 1 elif last is not None: - skip = 0 + skip = count - last limit = last - reverse = True elif after is not None: skip = after + 1 elif before is not None: limit = before - return skip, limit, reverse + + return skip, limit def connection_from_iterables(