From 736742cea88a409a093e4d930528779cf1338043 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= Date: Thu, 21 Dec 2023 20:54:13 -0600 Subject: [PATCH] refactor: Use a `TypedDict` to annotate state dictionaries --- singer_sdk/helpers/_state.py | 66 +++++++++++++++++++++++++++--------- singer_sdk/streams/core.py | 7 ++-- singer_sdk/tap_base.py | 7 ++-- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/singer_sdk/helpers/_state.py b/singer_sdk/helpers/_state.py index ed3d345eb0..40de9909d4 100644 --- a/singer_sdk/helpers/_state.py +++ b/singer_sdk/helpers/_state.py @@ -3,11 +3,17 @@ from __future__ import annotations import logging +import sys import typing as t from singer_sdk.exceptions import InvalidStreamSortException from singer_sdk.helpers._typing import to_json_compatible +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias # noqa: ICN003 + if t.TYPE_CHECKING: import datetime @@ -19,14 +25,25 @@ STARTING_MARKER = "starting_replication_value" logger = logging.getLogger("singer_sdk") +StreamStateDict: TypeAlias = t.Dict[str, t.Any] + + +class PartitionsStateDict(t.TypedDict, total=False): + partitions: list[StreamStateDict] + + +class TapStateDict(t.TypedDict, total=False): + """State dictionary type.""" + + bookmarks: dict[str, StreamStateDict | PartitionsStateDict] def get_state_if_exists( - tap_state: dict, + tap_state: TapStateDict, tap_stream_id: str, - state_partition_context: dict | None = None, + state_partition_context: dict[str, t.Any] | None = None, key: str | None = None, -) -> t.Any | None: # noqa: ANN401 +) -> StreamStateDict | None: """Return the stream or partition state, creating a new one if it does not exist. Args: @@ -44,34 +61,49 @@ def get_state_if_exists( ValueError: Raised if state is invalid or cannot be parsed. """ if "bookmarks" not in tap_state: + # Not a valid state, e.g. 
{} return None + if tap_stream_id not in tap_state["bookmarks"]: + # Stream not present in state, e.g. {"bookmarks": {}} return None + # At this point state looks like {"bookmarks": {"my_stream": {"key": "value"}}} + + # stream_state: {"key": "value", "partitions"?: ...} stream_state = tap_state["bookmarks"][tap_stream_id] if not state_partition_context: - return stream_state.get(key, None) if key else stream_state + # Either 'value' if key is specified, or {} + return stream_state.get(key, None) if key else stream_state # type: ignore[return-value] + if "partitions" not in stream_state: return None # No partitions defined + # stream_state: {"partitions": [{"context": {"key": "value"}}]} # noqa: ERA001 + matched_partition = _find_in_partitions_list( stream_state["partitions"], state_partition_context, ) + if matched_partition is None: return None # Partition definition not present + return matched_partition.get(key, None) if key else matched_partition -def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None: +def get_state_partitions_list( + tap_state: TapStateDict, + tap_stream_id: str, +) -> list[StreamStateDict] | None: """Return a list of partitions defined in the state, or None if not defined.""" return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) # type: ignore[no-any-return] def _find_in_partitions_list( - partitions: list[dict], - state_partition_context: dict, -) -> dict | None: + partitions: list[StreamStateDict], + state_partition_context: dict[str, t.Any], +) -> StreamStateDict | None: found = [ partition_state for partition_state in partitions @@ -97,10 +129,10 @@ def _create_in_partitions_list( def get_writeable_state_dict( - tap_state: dict, + tap_state: TapStateDict, tap_stream_id: str, state_partition_context: dict | None = None, -) -> dict: +) -> StreamStateDict: """Return the stream or partition state, creating a new one if it does not exist. 
Args: @@ -123,13 +155,13 @@ def get_writeable_state_dict( tap_state["bookmarks"] = {} if tap_stream_id not in tap_state["bookmarks"]: tap_state["bookmarks"][tap_stream_id] = {} - stream_state = t.cast(dict, tap_state["bookmarks"][tap_stream_id]) + stream_state = tap_state["bookmarks"][tap_stream_id] if not state_partition_context: - return stream_state + return stream_state # type: ignore[return-value] if "partitions" not in stream_state: stream_state["partitions"] = [] - stream_state_partitions: list[dict] = stream_state["partitions"] + stream_state_partitions: list[StreamStateDict] = stream_state["partitions"] if found := _find_in_partitions_list( stream_state_partitions, state_partition_context, @@ -140,7 +172,7 @@ def get_writeable_state_dict( def write_stream_state( - tap_state: dict, + tap_state: TapStateDict, tap_stream_id: str, key: str, val: t.Any, # noqa: ANN401 @@ -156,12 +188,14 @@ def write_stream_state( state_dict[key] = val -def reset_state_progress_markers(stream_or_partition_state: dict) -> dict | None: +def reset_state_progress_markers( + stream_or_partition_state: StreamStateDict | PartitionsStateDict, +) -> dict | None: """Wipe the state once sync is complete. For logging purposes, return the wiped 'progress_markers' object if it existed. 
""" - progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {}) + progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {}) # type: ignore[misc] # Remove auto-generated human-readable note: progress_markers.pop(PROGRESS_MARKER_NOTE, None) # Return remaining 'progress_markers' if any: diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index 886466c5d2..b98e7074b3 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -60,6 +60,7 @@ import logging from singer_sdk.helpers._compat import Traversable + from singer_sdk.helpers._state import TapStateDict from singer_sdk.tap_base import Tap # Replication methods @@ -149,7 +150,7 @@ def __init__( self._mask: singer.SelectionMask | None = None self._schema: dict self._is_state_flushed: bool = True - self._last_emitted_state: dict | None = None + self._last_emitted_state: TapStateDict | None = None self._sync_costs: dict[str, int] = {} self.child_streams: list[Stream] = [] if schema: @@ -638,7 +639,7 @@ def replication_method(self) -> str: # State properties: @property - def tap_state(self) -> dict: + def tap_state(self) -> TapStateDict: """Return a writeable state dict for the entire tap. Note: This dictionary is shared (and writable) across all streams. 
@@ -783,7 +784,7 @@ def _write_state_message(self) -> None: if (not self._is_state_flushed) and ( self.tap_state != self._last_emitted_state ): - self._tap.write_message(singer.StateMessage(value=self.tap_state)) + self._tap.write_message(singer.StateMessage(value=self.tap_state)) # type: ignore[arg-type] self._last_emitted_state = copy.deepcopy(self.tap_state) self._is_state_flushed = True diff --git a/singer_sdk/tap_base.py b/singer_sdk/tap_base.py index d8fb75a8ff..f49243b657 100644 --- a/singer_sdk/tap_base.py +++ b/singer_sdk/tap_base.py @@ -34,6 +34,7 @@ from pathlib import PurePath from singer_sdk.connectors import SQLConnector + from singer_sdk.helpers._state import TapStateDict from singer_sdk.mapper import PluginMapper from singer_sdk.streams import SQLStream, Stream @@ -93,7 +94,7 @@ def __init__( # Declare private members self._streams: dict[str, Stream] | None = None self._input_catalog: Catalog | None = None - self._state: dict[str, Stream] = {} + self._state: TapStateDict = {} self._catalog: Catalog | None = None # Tap's working catalog # Process input catalog @@ -139,7 +140,7 @@ def streams(self) -> dict[str, Stream]: return self._streams @property - def state(self) -> dict: + def state(self) -> TapStateDict: # type: ignore[override] """Get tap state. Returns: @@ -446,7 +447,7 @@ def sync_all(self) -> None: """Sync all streams.""" self._reset_state_progress_markers() self._set_compatible_replication_methods() - self.write_message(StateMessage(value=self.state)) + self.write_message(StateMessage(value=self.state)) # type: ignore[arg-type] stream: Stream for stream in self.streams.values():