Skip to content

Commit 0b3a36b

Browse files
authored
Add AsyncConcatenateIterator (#67)
* Add AsyncConcatenateIterator * Add Async test runners * Fix recognition of AsyncIterator
1 parent 8b98e46 commit 0b3a36b

File tree

5 files changed

+32
-6
lines changed

5 files changed

+32
-6
lines changed

python/cog/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import coglet
22
from coglet.api import (
3+
AsyncConcatenateIterator,
34
BaseModel,
45
BasePredictor,
56
CancelationException,
@@ -14,6 +15,7 @@
1415
__version__ = coglet.__version__
1516

1617
__all__ = [
18+
'AsyncConcatenateIterator',
1719
'BaseModel',
1820
'BasePredictor',
1921
'CancelationException',

python/coglet/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pathlib
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
4-
from typing import Any, Iterator, List, Optional, Type, TypeVar, Union
4+
from typing import Any, AsyncIterator, Iterator, List, Optional, Type, TypeVar, Union
55

66
########################################
77
# Custom encoding
@@ -73,6 +73,11 @@ class ConcatenateIterator(Iterator[_T_co]):
7373
def __next__(self) -> _T_co: ...
7474

7575

76+
class AsyncConcatenateIterator(AsyncIterator[_T_co]):
77+
@abstractmethod
78+
async def __anext__(self) -> _T_co: ...
79+
80+
7681
########################################
7782
# Input, Output
7883
########################################

python/coglet/inspector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing
55
import warnings
66
from types import ModuleType
7-
from typing import Any, Callable, Dict, Iterator, Optional, Type
7+
from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, Type
88

99
from coglet import adt, api
1010

@@ -139,13 +139,13 @@ def _output_adt(tpe: type) -> adt.Output:
139139
origin = typing.get_origin(tpe)
140140
kind = None
141141
ft = None
142-
if origin is typing.get_origin(Iterator):
142+
if origin in {typing.get_origin(Iterator), typing.get_origin(AsyncIterator)}:
143143
kind = adt.Kind.ITERATOR
144144
t_args = typing.get_args(tpe)
145145
assert len(t_args) == 1, 'iterator type must have one type argument'
146146
ft = adt.FieldType.from_type(t_args[0])
147147
assert ft.repetition is adt.Repetition.REQUIRED
148-
elif origin is api.ConcatenateIterator:
148+
elif origin in {api.ConcatenateIterator, api.AsyncConcatenateIterator}:
149149
kind = adt.Kind.CONCAT_ITERATOR
150150
t_args = typing.get_args(tpe)
151151
assert len(t_args) == 1, 'iterator type must have one type argument'
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import asyncio
2+
3+
from cog import AsyncConcatenateIterator, BasePredictor
4+
5+
6+
class Predictor(BasePredictor):
7+
test_inputs = {'i': 3, 's': 'foo'}
8+
9+
async def predict(self, i: int, s: str) -> AsyncConcatenateIterator[str]:
10+
await asyncio.sleep(0.1)
11+
print('starting prediction')
12+
if i > 0:
13+
await asyncio.sleep(0.6)
14+
for x in range(i):
15+
print(f'prediction in progress {x + 1}/{i}')
16+
await asyncio.sleep(0.6)
17+
yield f'*{s}-{x}*'
18+
await asyncio.sleep(0.6)
19+
print('completed prediction')

python/tests/runners/async_iterator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import asyncio
2-
from typing import Iterator
2+
from typing import AsyncIterator
33

44
from cog import BasePredictor
55

66

77
class Predictor(BasePredictor):
88
test_inputs = {'i': 3, 's': 'foo'}
99

10-
async def predict(self, i: int, s: str) -> Iterator[str]:
10+
async def predict(self, i: int, s: str) -> AsyncIterator[str]:
1111
await asyncio.sleep(0.1)
1212
print('starting prediction')
1313
if i > 0:

0 commit comments

Comments
 (0)