Skip to content

Commit 381958a

Browse files
authored
Feature/async checkpoints (#22)
* Fix healthcheck port. * Fix import slipstream without installing extra cache dependency. * Allow generic parameters for AsyncCallable. * Added helper to handle potential awaitable/non-awaitables. * Add test.
1 parent d257597 commit 381958a

File tree

6 files changed

+83
-18
lines changed

6 files changed

+83
-18
lines changed

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ services:
2626
"kafka-topics",
2727
"--list",
2828
"--bootstrap-server",
29-
"localhost:9091",
29+
"localhost:29091",
3030
]
3131
interval: 5s
3232
timeout: 10s

slipstream/caching.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@
66
from contextlib import asynccontextmanager
77
from pathlib import Path
88
from types import TracebackType
9-
from typing import (
10-
Any,
11-
TypeVar,
12-
)
13-
14-
from rocksdict import WriteBatch
9+
from typing import Any, TypeVar
1510

1611
from slipstream.interfaces import ICache, Key
1712

@@ -49,6 +44,7 @@
4944
RdictIter,
5045
ReadOptions,
5146
Snapshot,
47+
WriteBatch,
5248
WriteOptions,
5349
)
5450
from rocksdict.rocksdict import (

slipstream/checkpointing.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
import logging
44
from collections.abc import AsyncIterable, Callable, Generator
55
from datetime import datetime, timedelta
6-
from typing import (
7-
Any,
8-
)
6+
from typing import Any
97

108
from slipstream.core import Conf, Signal
119
from slipstream.interfaces import ICache
12-
from slipstream.utils import iscoroutinecallable
10+
from slipstream.utils import AsyncCallable, awaitable, iscoroutinecallable
1311

1412
_logger = logging.getLogger(__name__)
1513

@@ -35,12 +33,16 @@ class Dependency:
3533
"""
3634

3735
@property
38-
def downtime_check(self) -> Callable[['Checkpoint', 'Dependency'], Any]:
36+
def downtime_check(
37+
self,
38+
) -> AsyncCallable[['Checkpoint', 'Dependency'], Any]:
3939
"""Is called when downtime is detected."""
4040
return self._downtime_check
4141

4242
@property
43-
def recovery_check(self) -> Callable[['Checkpoint', 'Dependency'], bool]:
43+
def recovery_check(
44+
self,
45+
) -> AsyncCallable[['Checkpoint', 'Dependency'], bool]:
4446
"""Is called when downtime is resolved."""
4547
return self._recovery_check
4648

@@ -49,9 +51,9 @@ def __init__(
4951
name: str,
5052
dependency: AsyncIterable[Any],
5153
downtime_threshold: Any = timedelta(minutes=10),
52-
downtime_check: Callable[['Checkpoint', 'Dependency'], Any]
54+
downtime_check: AsyncCallable[['Checkpoint', 'Dependency'], Any]
5355
| None = None,
54-
recovery_check: Callable[['Checkpoint', 'Dependency'], bool]
56+
recovery_check: AsyncCallable[['Checkpoint', 'Dependency'], bool]
5557
| None = None,
5658
) -> None:
5759
"""Initialize dependency for checkpointing."""
@@ -286,7 +288,7 @@ async def heartbeat(
286288
self._save_checkpoint(dependency, self.state, marker)
287289

288290
if dependency.is_down:
289-
if dependency.recovery_check(self, dependency):
291+
if await awaitable(dependency.recovery_check(self, dependency)):
290292
dependency.is_down = False
291293

292294
if not any(_.is_down for _ in self.dependencies.values()):
@@ -342,7 +344,9 @@ async def check_pulse(
342344

343345
# Trigger on the first dependency that is down and
344346
# pause the dependent stream
345-
if downtime := dependency.downtime_check(self, dependency):
347+
if downtime := await awaitable(
348+
dependency.downtime_check(self, dependency)
349+
):
346350
log_msg = (
347351
f'Downtime of dependency "{dependency.name}" detected'
348352
)

slipstream/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
from typing import (
88
Any,
99
ClassVar,
10+
ParamSpec,
1011
TypeAlias,
1112
TypeVar,
1213
)
1314

1415
T = TypeVar('T')
16+
P = ParamSpec('P')
1517

16-
AsyncCallable: TypeAlias = Callable[..., Awaitable[Any]] | Callable[..., Any]
18+
AsyncCallable: TypeAlias = Callable[P, T | Awaitable[T]]
1719
Pipe: TypeAlias = Callable[[AsyncIterable[Any]], AsyncIterable[Any]]
1820

1921

@@ -31,6 +33,11 @@ class Signal(Enum):
3133
STOP = 3
3234

3335

36+
async def awaitable(x: Any) -> Any:
37+
"""Convert into awaitable."""
38+
return await x if isinstance(x, Awaitable) else x
39+
40+
3441
def iscoroutinecallable(o: Any) -> bool:
3542
"""Check whether object is coroutine."""
3643
call = o.__call__ if callable(o) else None # type: ignore[attr-defined]

tests/test_checkpointing.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,58 @@ async def test_custom_callbacks(is_async, checkpoint, mocker):
268268
)
269269

270270

271+
@pytest.mark.parametrize('is_async', [True, False])
272+
@pytest.mark.asyncio
273+
async def test_custom_checks(is_async, mock_cache, mocker):
274+
"""Check custom check functions called."""
275+
if is_async:
276+
downtime_check = mocker.AsyncMock(return_value=timedelta(hours=1))
277+
recovery_check = mocker.AsyncMock(return_value=timedelta(hours=1))
278+
else:
279+
downtime_check = mocker.Mock(return_value=timedelta(hours=1))
280+
recovery_check = mocker.Mock(return_value=timedelta(hours=1))
281+
282+
async def messages():
283+
yield {
284+
'event_timestamp': datetime(2025, 1, 1, 10, tzinfo=UTC),
285+
}
286+
287+
dependency = Dependency(
288+
'dependency',
289+
messages(),
290+
downtime_check=downtime_check,
291+
recovery_check=recovery_check,
292+
)
293+
294+
async def dependent():
295+
yield {
296+
'event_timestamp': datetime(2025, 1, 1, 10, tzinfo=UTC),
297+
}
298+
299+
checkpoint = Checkpoint(
300+
'test', dependent(), [dependency], cache=mock_cache
301+
)
302+
303+
# Trigger downtime
304+
await checkpoint.check_pulse(
305+
datetime(2025, 1, 1, 10, tzinfo=UTC),
306+
state={'offset': 0},
307+
)
308+
assert dependency.is_down is True
309+
await checkpoint.check_pulse(
310+
datetime(2025, 1, 1, 11, tzinfo=UTC),
311+
state={'offset': 1},
312+
)
313+
downtime_check.assert_called()
314+
315+
# Trigger recovery
316+
await checkpoint.heartbeat(
317+
datetime(2025, 1, 1, 11, 1, tzinfo=UTC),
318+
)
319+
recovery_check.assert_called()
320+
assert dependency.is_down is False
321+
322+
271323
def test_repr(checkpoint):
272324
"""Should print representation without crashing."""
273325
assert str(checkpoint)

tests/test_core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ async def test_conf(mocker: MockerFixture):
127127
assert c.__getattr__('group.id') == 'test'
128128
assert c.iterables == {}
129129

130+
# Missing prop
131+
with pytest.raises(
132+
AttributeError, match='object has no attribute "missing_prop"'
133+
):
134+
assert c.missing_prop
135+
130136
# Register iterable
131137
iterable = emoji()
132138
iterable_key = str(id(iterable))

0 commit comments

Comments
 (0)