Skip to content

Commit 7799453

Browse files
refactor: Use a TypedDict to annotate state dictionaries
1 parent 22d4eae commit 7799453

File tree

3 files changed

+57
-21
lines changed

3 files changed

+57
-21
lines changed

singer_sdk/helpers/_state.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
from __future__ import annotations
44

55
import logging
6+
import sys
67
import typing as t
78

89
from singer_sdk.exceptions import InvalidStreamSortException
910
from singer_sdk.helpers._typing import to_json_compatible
1011

12+
if sys.version_info < (3, 10):
13+
from typing_extensions import TypeAlias
14+
else:
15+
from typing import TypeAlias # noqa: ICN003
16+
1117
if t.TYPE_CHECKING:
1218
import datetime
1319

@@ -21,14 +27,25 @@
2127
STARTING_MARKER = "starting_replication_value"
2228

2329
logger = logging.getLogger("singer_sdk")
30+
StreamStateDict: TypeAlias = t.Dict[str, t.Any]
31+
32+
33+
class PartitionsStateDict(t.TypedDict, total=False):
34+
partitions: list[StreamStateDict]
35+
36+
37+
class TapStateDict(t.TypedDict, total=False):
38+
"""State dictionary type."""
39+
40+
bookmarks: dict[str, StreamStateDict | PartitionsStateDict]
2441

2542

2643
def get_state_if_exists(
27-
tap_state: dict,
44+
tap_state: TapStateDict,
2845
tap_stream_id: str,
29-
state_partition_context: dict | None = None,
46+
state_partition_context: dict[str, t.Any] | None = None,
3047
key: str | None = None,
31-
) -> t.Any | None: # noqa: ANN401
48+
) -> StreamStateDict | None:
3249
"""Return the stream or partition state, creating a new one if it does not exist.
3350
3451
Args:
@@ -46,34 +63,49 @@ def get_state_if_exists(
4663
ValueError: Raised if state is invalid or cannot be parsed.
4764
"""
4865
if "bookmarks" not in tap_state:
66+
# Not a valid state, e.g. {}
4967
return None
68+
5069
if tap_stream_id not in tap_state["bookmarks"]:
70+
# Stream not present in state, e.g. {"bookmarks": {}}
5171
return None
5272

73+
# At this point state looks like {"bookmarks": {"my_stream": {"key": "value"}}}
74+
75+
# stream_state: {"key": "value", "partitions"?: ...}
5376
stream_state = tap_state["bookmarks"][tap_stream_id]
5477
if not state_partition_context:
55-
return stream_state.get(key, None) if key else stream_state
78+
# Either the value stored under `key` if a key is specified, or the whole stream state dict
79+
return stream_state.get(key, None) if key else stream_state # type: ignore[return-value]
80+
5681
if "partitions" not in stream_state:
5782
return None # No partitions defined
5883

84+
# stream_state: {"partitions": [{"context": {"key": "value"}}]} # noqa: ERA001
85+
5986
matched_partition = _find_in_partitions_list(
6087
stream_state["partitions"],
6188
state_partition_context,
6289
)
90+
6391
if matched_partition is None:
6492
return None # Partition definition not present
93+
6594
return matched_partition.get(key, None) if key else matched_partition
6695

6796

68-
def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None:
97+
def get_state_partitions_list(
98+
tap_state: TapStateDict,
99+
tap_stream_id: str,
100+
) -> list[StreamStateDict] | None:
69101
"""Return a list of partitions defined in the state, or None if not defined."""
70102
return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) # type: ignore[no-any-return]
71103

72104

73105
def _find_in_partitions_list(
74-
partitions: list[dict],
106+
partitions: list[StreamStateDict],
75107
state_partition_context: types.Context,
76-
) -> dict | None:
108+
) -> StreamStateDict | None:
77109
found = [
78110
partition_state
79111
for partition_state in partitions
@@ -99,10 +131,10 @@ def _create_in_partitions_list(
99131

100132

101133
def get_writeable_state_dict(
102-
tap_state: dict,
134+
tap_state: TapStateDict,
103135
tap_stream_id: str,
104136
state_partition_context: types.Context | None = None,
105-
) -> dict:
137+
) -> StreamStateDict:
106138
"""Return the stream or partition state, creating a new one if it does not exist.
107139
108140
Args:
@@ -125,13 +157,13 @@ def get_writeable_state_dict(
125157
tap_state["bookmarks"] = {}
126158
if tap_stream_id not in tap_state["bookmarks"]:
127159
tap_state["bookmarks"][tap_stream_id] = {}
128-
stream_state = t.cast(dict, tap_state["bookmarks"][tap_stream_id])
160+
stream_state = tap_state["bookmarks"][tap_stream_id]
129161
if not state_partition_context:
130-
return stream_state
162+
return stream_state # type: ignore[return-value]
131163

132164
if "partitions" not in stream_state:
133165
stream_state["partitions"] = []
134-
stream_state_partitions: list[dict] = stream_state["partitions"]
166+
stream_state_partitions: list[StreamStateDict] = stream_state["partitions"]
135167
if found := _find_in_partitions_list(
136168
stream_state_partitions,
137169
state_partition_context,
@@ -142,7 +174,7 @@ def get_writeable_state_dict(
142174

143175

144176
def write_stream_state(
145-
tap_state: dict,
177+
tap_state: TapStateDict,
146178
tap_stream_id: str,
147179
key: str,
148180
val: t.Any, # noqa: ANN401
@@ -158,12 +190,14 @@ def write_stream_state(
158190
state_dict[key] = val
159191

160192

161-
def reset_state_progress_markers(stream_or_partition_state: dict) -> dict | None:
193+
def reset_state_progress_markers(
194+
stream_or_partition_state: StreamStateDict | PartitionsStateDict,
195+
) -> dict | None:
162196
"""Wipe the state once sync is complete.
163197
164198
For logging purposes, return the wiped 'progress_markers' object if it existed.
165199
"""
166-
progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {})
200+
progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {}) # type: ignore[misc]
167201
# Remove auto-generated human-readable note:
168202
progress_markers.pop(PROGRESS_MARKER_NOTE, None)
169203
# Return remaining 'progress_markers' if any:

singer_sdk/streams/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
from singer_sdk.helpers import types
5656
from singer_sdk.helpers._compat import Traversable
57+
from singer_sdk.helpers._state import TapStateDict
5758
from singer_sdk.tap_base import Tap
5859

5960
# Replication methods
@@ -147,7 +148,7 @@ def __init__(
147148
self._mask: singer.SelectionMask | None = None
148149
self._schema: dict
149150
self._is_state_flushed: bool = True
150-
self._last_emitted_state: dict | None = None
151+
self._last_emitted_state: TapStateDict | None = None
151152
self._sync_costs: dict[str, int] = {}
152153
self.child_streams: list[Stream] = []
153154
if schema:
@@ -645,7 +646,7 @@ def replication_method(self) -> str:
645646
# State properties:
646647

647648
@property
648-
def tap_state(self) -> dict:
649+
def tap_state(self) -> TapStateDict:
649650
"""Return a writeable state dict for the entire tap.
650651
651652
Note: This dictionary is shared (and writable) across all streams.
@@ -790,7 +791,7 @@ def _write_state_message(self) -> None:
790791
if (not self._is_state_flushed) and (
791792
self.tap_state != self._last_emitted_state
792793
):
793-
self._tap.write_message(singer.StateMessage(value=self.tap_state))
794+
self._tap.write_message(singer.StateMessage(value=self.tap_state)) # type: ignore[arg-type]
794795
self._last_emitted_state = copy.deepcopy(self.tap_state)
795796
self._is_state_flushed = True
796797

singer_sdk/tap_base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pathlib import PurePath
3434

3535
from singer_sdk.connectors import SQLConnector
36+
from singer_sdk.helpers._state import TapStateDict
3637
from singer_sdk.mapper import PluginMapper
3738
from singer_sdk.streams import SQLStream, Stream
3839

@@ -92,7 +93,7 @@ def __init__(
9293
# Declare private members
9394
self._streams: dict[str, Stream] | None = None
9495
self._input_catalog: Catalog | None = None
95-
self._state: dict[str, Stream] = {}
96+
self._state: TapStateDict = {}
9697
self._catalog: Catalog | None = None # Tap's working catalog
9798

9899
# Process input catalog
@@ -138,7 +139,7 @@ def streams(self) -> dict[str, Stream]:
138139
return self._streams
139140

140141
@property
141-
def state(self) -> dict:
142+
def state(self) -> TapStateDict: # type: ignore[override]
142143
"""Get tap state.
143144
144145
Returns:
@@ -445,7 +446,7 @@ def sync_all(self) -> None:
445446
"""Sync all streams."""
446447
self._reset_state_progress_markers()
447448
self._set_compatible_replication_methods()
448-
self.write_message(StateMessage(value=self.state))
449+
self.write_message(StateMessage(value=self.state)) # type: ignore[arg-type]
449450

450451
stream: Stream
451452
for stream in self.streams.values():

0 commit comments

Comments
 (0)