Skip to content

Commit

Permalink
refactor: Use a TypedDict to annotate state dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed May 8, 2024
1 parent 9d0c08b commit ae0ec18
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions singer_sdk/helpers/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from __future__ import annotations

import logging
import sys
import typing as t

from singer_sdk.exceptions import InvalidStreamSortException
from singer_sdk.helpers._typing import to_json_compatible

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

if t.TYPE_CHECKING:
import datetime

Expand All @@ -19,14 +25,25 @@
STARTING_MARKER = "starting_replication_value"

logger = logging.getLogger("singer_sdk")
StreamStateDict: TypeAlias = t.Dict[str, t.Any]


class PartitionsStateDict(t.TypedDict, total=False):
partitions: list[StreamStateDict]


class TapStateDict(t.TypedDict, total=False):
"""State dictionary type."""

bookmarks: dict[str, StreamStateDict | PartitionsStateDict]


def get_state_if_exists(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
state_partition_context: dict | None = None,
state_partition_context: dict[str, t.Any] | None = None,
key: str | None = None,
) -> t.Any | None: # noqa: ANN401
) -> StreamStateDict | None:
"""Return the stream or partition state, or None if it does not exist.
Args:
Expand All @@ -44,34 +61,49 @@ def get_state_if_exists(
ValueError: Raised if state is invalid or cannot be parsed.
"""
if "bookmarks" not in tap_state:
# Not a valid state, e.g. {}
return None

if tap_stream_id not in tap_state["bookmarks"]:
# Stream not present in state, e.g. {"bookmarks": {}}
return None

# At this point state looks like {"bookmarks": {"my_stream": {"key": "value"}}}

# stream_state: {"key": "value", "partitions"?: ...}
stream_state = tap_state["bookmarks"][tap_stream_id]
if not state_partition_context:
return stream_state.get(key, None) if key else stream_state
# Either the value stored under `key` (None if absent) when a key is given, or the whole stream state dict
return stream_state.get(key, None) if key else stream_state # type: ignore[return-value]

if "partitions" not in stream_state:
return None # No partitions defined

# stream_state: {"partitions": [{"context": {"key": "value"}}]} # noqa: ERA001

matched_partition = _find_in_partitions_list(
stream_state["partitions"],
state_partition_context,
)

if matched_partition is None:
return None # Partition definition not present

return matched_partition.get(key, None) if key else matched_partition


def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None:
def get_state_partitions_list(
tap_state: TapStateDict,
tap_stream_id: str,
) -> list[StreamStateDict] | None:
"""Return a list of partitions defined in the state, or None if not defined."""
return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) # type: ignore[no-any-return]
return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None)


def _find_in_partitions_list(
partitions: list[dict],
state_partition_context: dict,
) -> dict | None:
partitions: list[StreamStateDict],
state_partition_context: dict[str, t.Any],
) -> StreamStateDict | None:
found = [
partition_state
for partition_state in partitions
Expand Down

0 comments on commit ae0ec18

Please sign in to comment.