Skip to content

Commit 30d48c3

Browse files
authored
🐛 Fix: rest paginator iteration error (#212)
1 parent 7ccbfbf commit 30d48c3

File tree

2 files changed

+85
-57
lines changed

2 files changed

+85
-57
lines changed

githubkit/rest/paginator.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Callable,
77
Generic,
88
Optional,
9+
TypedDict,
910
TypeVar,
1011
Union,
1112
cast,
@@ -16,6 +17,7 @@
1617
import httpx
1718

1819
from githubkit.response import Response
20+
from githubkit.typing import HeaderTypes
1921
from githubkit.utils import is_async
2022

2123
if TYPE_CHECKING:
@@ -35,6 +37,12 @@
3537
NEXT_LINK_PATTERN = r'<([^<>]+)>;\s*rel="next"'
3638

3739

40+
class PaginatorState(TypedDict):
41+
next_link: Optional[httpx.URL]
42+
request_method: str
43+
response_model: Any
44+
45+
3846
# https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api
3947
# https://github.com/octokit/plugin-paginate-rest.js/blob/1f44b5469b31ddec9621000e6e1aee63c71ea8bf/src/iterator.ts
4048
class Paginator(Generic[RT]):
@@ -76,33 +84,34 @@ def __init__(
7684

7785
self.map_func = map_func
7886

79-
self._initialized: bool = False
80-
self._request_method: Optional[str] = None
81-
self._response_model: Optional[Any] = None
82-
self._next_link: Optional[httpx.URL] = None
87+
self._state: Optional[PaginatorState] = None
8388

8489
self._index: int = 0
8590
self._cached_data: list[RT] = []
8691

8792
@property
8893
def finalized(self) -> bool:
8994
"""Whether the paginator is finalized or not."""
90-
return self._initialized and self._next_link is None
95+
return (self._state["next_link"] is None) if self._state is not None else False
96+
97+
@property
98+
def _headers(self) -> Optional[HeaderTypes]:
99+
return self.kwargs.get("headers") # type: ignore
91100

92101
def reset(self) -> None:
93102
"""Reset the paginator to the initial state."""
94103

95-
self._initialized = False
96-
self._next_link = None
104+
self._state = None
97105
self._index = 0
98106
self._cached_data = []
99107

100108
def __next__(self) -> RT:
101109
while self._index >= len(self._cached_data):
102-
self._get_next_page()
103110
if self.finalized:
104111
raise StopIteration
105112

113+
self._get_next_page()
114+
106115
current = self._cached_data[self._index]
107116
self._index += 1
108117
return current
@@ -114,10 +123,11 @@ def __iter__(self: Self) -> Self:
114123

115124
async def __anext__(self) -> RT:
116125
while self._index >= len(self._cached_data):
117-
await self._aget_next_page()
118126
if self.finalized:
119127
raise StopAsyncIteration
120128

129+
await self._aget_next_page()
130+
121131
current = self._cached_data[self._index]
122132
self._index += 1
123133
return current
@@ -151,64 +161,56 @@ def _fill_cache_data(self, data: list[RT]) -> None:
151161
self._index = 0
152162

153163
def _get_next_page(self) -> None:
154-
if not self._initialized:
164+
if self._state is None:
155165
# First request
156-
response = cast(
157-
Response[Any],
158-
self.request(*self.args, **self.kwargs),
159-
)
160-
self._initialized = True
161-
self._request_method = response.raw_request.method
166+
response = cast(Response[Any], self.request(*self.args, **self.kwargs))
162167
else:
163-
# Next request
164-
if self._next_link is None:
165-
raise RuntimeError("Paginator is finalized, no more pages to fetch.")
166-
if self._request_method is None:
167-
raise RuntimeError("Request method is not set, this should not happen.")
168-
if self._response_model is None:
169-
raise RuntimeError("Response model is not set, this should not happen.")
170-
171168
# we request the next page with the same method and response model
169+
if self._state["next_link"] is None:
170+
raise RuntimeError("No next page to request")
171+
172172
response = cast(
173173
Response[Any],
174174
self.rest._github.request(
175-
self._request_method,
176-
self._next_link,
177-
headers=self.kwargs.get("headers"), # type: ignore
178-
response_model=self._response_model, # type: ignore
175+
self._state["request_method"],
176+
self._state["next_link"],
177+
headers=self._headers, # type: ignore
178+
response_model=self._state["response_model"], # type: ignore
179179
),
180180
)
181181

182-
self._next_link = self._find_next_link(response)
182+
self._state = PaginatorState(
183+
next_link=self._find_next_link(response),
184+
request_method=response.raw_request.method,
185+
response_model=response._data_model,
186+
)
183187
self._fill_cache_data(self._apply_map_func(response))
184188

185189
async def _aget_next_page(self) -> None:
186-
if not self._initialized:
190+
if self._state is None:
187191
# First request
188192
response = cast(
189193
Response[Any],
190194
await self.request(*self.args, **self.kwargs), # type: ignore
191195
)
192-
self._initialized = True
193-
self._request_method = response.raw_request.method
194196
else:
195-
# Next request
196-
if self._next_link is None:
197-
raise RuntimeError("Paginator is finalized, no more pages to fetch.")
198-
if self._request_method is None:
199-
raise RuntimeError("Request method is not set, this should not happen.")
200-
if self._response_model is None:
201-
raise RuntimeError("Response model is not set, this should not happen.")
197+
# we request the next page with the same method and response model
198+
if self._state["next_link"] is None:
199+
raise RuntimeError("No next page to request")
202200

203201
response = cast(
204202
Response[Any],
205-
await self.rest._github.request(
206-
self._request_method,
207-
self._next_link,
208-
headers=self.kwargs.get("headers"), # type: ignore
209-
response_model=self._response_model, # type: ignore
203+
await self.rest._github.arequest(
204+
self._state["request_method"],
205+
self._state["next_link"],
206+
headers=self._headers, # type: ignore
207+
response_model=self._state["response_model"], # type: ignore
210208
),
211209
)
212210

213-
self._next_link = self._find_next_link(response)
211+
self._state = PaginatorState(
212+
next_link=self._find_next_link(response),
213+
request_method=response.raw_request.method,
214+
response_model=response._data_model,
215+
)
214216
self._fill_cache_data(self._apply_map_func(response))

tests/test_rest/test_call.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,19 @@
44

55
from githubkit import GitHub
66
from githubkit.versions import LATEST_VERSION
7-
from githubkit.versions.latest.models import FullRepository
7+
from githubkit.versions.latest.models import FullRepository, Issue
88

99
OWNER = "yanyongyu"
1010
REPO = "githubkit"
11+
ISSUE_COUNT_QUERY = """
12+
query($owner: String!, $repo: String!) {
13+
repository(owner: $owner, name: $repo) {
14+
issues {
15+
totalCount
16+
}
17+
}
18+
}
19+
"""
1120

1221

1322
def test_call(g: GitHub):
@@ -56,34 +65,51 @@ async def test_async_call_with_raw_body(g: GitHub):
5665

5766
def test_paginate(g: GitHub):
5867
paginator = g.rest.paginate(
59-
g.rest.issues.list_for_repo, owner=OWNER, repo=REPO, per_page=50
68+
g.rest.issues.list_for_repo, owner=OWNER, repo=REPO, state="all", per_page=50
6069
)
61-
for _ in paginator:
62-
...
70+
count = 0
71+
for issue in paginator:
72+
assert isinstance(issue, Issue)
73+
if not issue.pull_request:
74+
count += 1
75+
76+
result = g.graphql.request(ISSUE_COUNT_QUERY, {"owner": OWNER, "repo": REPO})
77+
assert result["repository"]["issues"]["totalCount"] == count
6378

6479

6580
def test_paginate_with_partial(g: GitHub):
6681
paginator = g.rest.paginate(
67-
partial(g.rest.issues.list_for_repo, OWNER, REPO), per_page=50
82+
partial(g.rest.issues.list_for_repo, OWNER, REPO), state="all", per_page=50
6883
)
69-
for _ in paginator:
70-
...
84+
for issue in paginator:
85+
assert isinstance(issue, Issue)
7186

7287

7388
@pytest.mark.anyio
7489
async def test_async_paginate(g: GitHub):
7590
paginator = g.rest.paginate(
76-
g.rest.issues.async_list_for_repo, owner=OWNER, repo=REPO, per_page=50
91+
g.rest.issues.async_list_for_repo,
92+
owner=OWNER,
93+
repo=REPO,
94+
state="all",
95+
per_page=50,
7796
)
78-
async for _ in paginator:
79-
...
97+
count = 0
98+
async for issue in paginator:
99+
assert isinstance(issue, Issue)
100+
if not issue.pull_request:
101+
count += 1
102+
103+
result = g.graphql.request(ISSUE_COUNT_QUERY, {"owner": OWNER, "repo": REPO})
104+
assert result["repository"]["issues"]["totalCount"] == count
80105

81106

82107
@pytest.mark.anyio
83108
async def test_async_paginate_with_partial(g: GitHub):
84109
paginator = g.rest.paginate(
85110
partial(g.rest.issues.async_list_for_repo, OWNER, REPO),
111+
state="all",
86112
per_page=50,
87113
)
88-
async for _ in paginator:
89-
...
114+
async for issue in paginator:
115+
assert isinstance(issue, Issue)

0 commit comments

Comments
 (0)