Skip to content

Commit

Permalink
refactor: Use a TypedDict to annotate state dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Aug 9, 2024
1 parent 22d4eae commit 7799453
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 21 deletions.
64 changes: 49 additions & 15 deletions singer_sdk/helpers/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from __future__ import annotations

import logging
import sys
import typing as t

from singer_sdk.exceptions import InvalidStreamSortException
from singer_sdk.helpers._typing import to_json_compatible

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

if t.TYPE_CHECKING:
import datetime

Expand All @@ -21,14 +27,25 @@
STARTING_MARKER = "starting_replication_value"

logger = logging.getLogger("singer_sdk")
StreamStateDict: TypeAlias = t.Dict[str, t.Any]


class PartitionsStateDict(t.TypedDict, total=False):
# Shape of a stream state that carries per-partition states.
# ``total=False`` makes the ``partitions`` key optional.
# Each list entry is the state dictionary of one partition.
partitions: list[StreamStateDict]


class TapStateDict(t.TypedDict, total=False):
"""State dictionary type for the whole tap.

The optional ``bookmarks`` key maps each stream ID to that stream's state,
which is either a plain stream state dictionary or one holding a
``partitions`` list.
"""

bookmarks: dict[str, StreamStateDict | PartitionsStateDict]


def get_state_if_exists(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
state_partition_context: dict | None = None,
state_partition_context: dict[str, t.Any] | None = None,
key: str | None = None,
) -> t.Any | None: # noqa: ANN401
) -> StreamStateDict | None:
"""Return the stream or partition state if it exists, or None otherwise.
Args:
Expand All @@ -46,34 +63,49 @@ def get_state_if_exists(
ValueError: Raised if state is invalid or cannot be parsed.
"""
if "bookmarks" not in tap_state:
# Not a valid state, e.g. {}
return None

if tap_stream_id not in tap_state["bookmarks"]:
# Stream not present in state, e.g. {"bookmarks": {}}
return None

# At this point state looks like {"bookmarks": {"my_stream": {"key": "value"}}}

# stream_state: {"key": "value", "partitions"?: ...}
stream_state = tap_state["bookmarks"][tap_stream_id]
if not state_partition_context:
return stream_state.get(key, None) if key else stream_state
# Either the value stored under `key` if a key is specified, or the whole stream state
return stream_state.get(key, None) if key else stream_state # type: ignore[return-value]

if "partitions" not in stream_state:
return None # No partitions defined

# stream_state: {"partitions": [{"context": {"key": "value"}}]} # noqa: ERA001

matched_partition = _find_in_partitions_list(
stream_state["partitions"],
state_partition_context,
)

if matched_partition is None:
return None # Partition definition not present

return matched_partition.get(key, None) if key else matched_partition


def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None:
def get_state_partitions_list(
tap_state: TapStateDict,
tap_stream_id: str,
) -> list[StreamStateDict] | None:
"""Return a list of partitions defined in the state, or None if not defined."""
return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) # type: ignore[no-any-return]


def _find_in_partitions_list(
partitions: list[dict],
partitions: list[StreamStateDict],
state_partition_context: types.Context,
) -> dict | None:
) -> StreamStateDict | None:
found = [
partition_state
for partition_state in partitions
Expand All @@ -99,10 +131,10 @@ def _create_in_partitions_list(


def get_writeable_state_dict(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
state_partition_context: types.Context | None = None,
) -> dict:
) -> StreamStateDict:
"""Return the stream or partition state, creating a new one if it does not exist.
Args:
Expand All @@ -125,13 +157,13 @@ def get_writeable_state_dict(
tap_state["bookmarks"] = {}
if tap_stream_id not in tap_state["bookmarks"]:
tap_state["bookmarks"][tap_stream_id] = {}
stream_state = t.cast(dict, tap_state["bookmarks"][tap_stream_id])
stream_state = tap_state["bookmarks"][tap_stream_id]
if not state_partition_context:
return stream_state
return stream_state # type: ignore[return-value]

if "partitions" not in stream_state:
stream_state["partitions"] = []
stream_state_partitions: list[dict] = stream_state["partitions"]
stream_state_partitions: list[StreamStateDict] = stream_state["partitions"]
if found := _find_in_partitions_list(
stream_state_partitions,
state_partition_context,
Expand All @@ -142,7 +174,7 @@ def get_writeable_state_dict(


def write_stream_state(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
key: str,
val: t.Any, # noqa: ANN401
Expand All @@ -158,12 +190,14 @@ def write_stream_state(
state_dict[key] = val


def reset_state_progress_markers(stream_or_partition_state: dict) -> dict | None:
def reset_state_progress_markers(
stream_or_partition_state: StreamStateDict | PartitionsStateDict,
) -> dict | None:
"""Wipe the state once sync is complete.
For logging purposes, return the wiped 'progress_markers' object if it existed.
"""
progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {})
progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {}) # type: ignore[misc]
# Remove auto-generated human-readable note:
progress_markers.pop(PROGRESS_MARKER_NOTE, None)
# Return remaining 'progress_markers' if any:
Expand Down
7 changes: 4 additions & 3 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

from singer_sdk.helpers import types
from singer_sdk.helpers._compat import Traversable
from singer_sdk.helpers._state import TapStateDict
from singer_sdk.tap_base import Tap

# Replication methods
Expand Down Expand Up @@ -147,7 +148,7 @@ def __init__(
self._mask: singer.SelectionMask | None = None
self._schema: dict
self._is_state_flushed: bool = True
self._last_emitted_state: dict | None = None
self._last_emitted_state: TapStateDict | None = None
self._sync_costs: dict[str, int] = {}
self.child_streams: list[Stream] = []
if schema:
Expand Down Expand Up @@ -645,7 +646,7 @@ def replication_method(self) -> str:
# State properties:

@property
def tap_state(self) -> dict:
def tap_state(self) -> TapStateDict:
"""Return a writeable state dict for the entire tap.
Note: This dictionary is shared (and writable) across all streams.
Expand Down Expand Up @@ -790,7 +791,7 @@ def _write_state_message(self) -> None:
if (not self._is_state_flushed) and (
self.tap_state != self._last_emitted_state
):
self._tap.write_message(singer.StateMessage(value=self.tap_state))
self._tap.write_message(singer.StateMessage(value=self.tap_state)) # type: ignore[arg-type]
self._last_emitted_state = copy.deepcopy(self.tap_state)
self._is_state_flushed = True

Expand Down
7 changes: 4 additions & 3 deletions singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pathlib import PurePath

from singer_sdk.connectors import SQLConnector
from singer_sdk.helpers._state import TapStateDict
from singer_sdk.mapper import PluginMapper
from singer_sdk.streams import SQLStream, Stream

Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
# Declare private members
self._streams: dict[str, Stream] | None = None
self._input_catalog: Catalog | None = None
self._state: dict[str, Stream] = {}
self._state: TapStateDict = {}
self._catalog: Catalog | None = None # Tap's working catalog

# Process input catalog
Expand Down Expand Up @@ -138,7 +139,7 @@ def streams(self) -> dict[str, Stream]:
return self._streams

@property
def state(self) -> dict:
def state(self) -> TapStateDict: # type: ignore[override]
"""Get tap state.
Returns:
Expand Down Expand Up @@ -445,7 +446,7 @@ def sync_all(self) -> None:
"""Sync all streams."""
self._reset_state_progress_markers()
self._set_compatible_replication_methods()
self.write_message(StateMessage(value=self.state))
self.write_message(StateMessage(value=self.state)) # type: ignore[arg-type]

stream: Stream
for stream in self.streams.values():
Expand Down

0 comments on commit 7799453

Please sign in to comment.