From 5d98422e006fe700a62c59babec716e58dc0bbfd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?=
Date: Fri, 12 Jul 2024 20:38:07 -0600
Subject: [PATCH] Re-use message writing and formatting logic

---
Reviewer notes (text between the "---" line above and the first
"diff --git" line is not part of the commit message or of the diff):

* Moves SingerMessageType, exclude_null_dict, Message, RecordMessage,
  SchemaMessage, StateMessage and ActivateVersionMessage out of
  singer_sdk/_singerlib/messages.py and into
  singer_sdk/_singerlib/encoding/base.py. The moved code is unchanged.
* messages.py now re-exports those names (see __all__) and replaces the
  removed module-level format_message() / write_message() functions with
  WRITER.format_message / WRITER.write_message, where WRITER is a
  module-level SingerWriter() instance.
* NOTE(review): encoding/base.py now defines Message itself but still
  imports Message from singer_sdk._singerlib.messages under
  t.TYPE_CHECKING. That import looks redundant; confirm and drop it in a
  follow-up.
* NOTE(review): Message.from_dict() calls data.pop("type"), which mutates
  the caller's dict and raises KeyError when the key is absent. Behaviour
  is carried over verbatim by this patch; worth a separate fix.

 singer_sdk/_singerlib/encoding/base.py | 202 +++++++++++++++++++-
 singer_sdk/_singerlib/messages.py      | 249 +++----------------------
 2 files changed, 227 insertions(+), 224 deletions(-)

diff --git a/singer_sdk/_singerlib/encoding/base.py b/singer_sdk/_singerlib/encoding/base.py
index 91d8a1c08e..cafca412cf 100644
--- a/singer_sdk/_singerlib/encoding/base.py
+++ b/singer_sdk/_singerlib/encoding/base.py
@@ -3,12 +3,23 @@
 from __future__ import annotations
 
 import abc
+import enum
 import logging
+import sys
 import typing as t
 from collections import Counter, defaultdict
+from dataclasses import asdict, dataclass, field
+from datetime import datetime, timezone
 
 from singer_sdk._singerlib import exceptions
-from singer_sdk._singerlib.messages import Message, SingerMessageType
+
+if sys.version_info < (3, 11):
+    from backports.datetime_fromisoformat import MonkeyPatch
+
+    MonkeyPatch.patch_fromisoformat()
+
+if t.TYPE_CHECKING:
+    from singer_sdk._singerlib.messages import Message
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +29,195 @@
 T = t.TypeVar("T", str, bytes)
 
 
+class SingerMessageType(str, enum.Enum):
+    """Singer specification message types."""
+
+    RECORD = "RECORD"
+    SCHEMA = "SCHEMA"
+    STATE = "STATE"
+    ACTIVATE_VERSION = "ACTIVATE_VERSION"
+    BATCH = "BATCH"
+
+
+def exclude_null_dict(pairs: list[tuple[str, t.Any]]) -> dict[str, t.Any]:
+    """Exclude null values from a dictionary.
+
+    Args:
+        pairs: The dictionary key-value pairs.
+
+    Returns:
+        The filtered key-value pairs.
+    """
+    return {key: value for key, value in pairs if value is not None}
+
+
+@dataclass
+class Message:
+    """Singer base message."""
+
+    type: SingerMessageType = field(init=False)
+    """The message type."""
+
+    def to_dict(self) -> dict[str, t.Any]:
+        """Return a dictionary representation of the message.
+
+        Returns:
+            A dictionary with the defined message fields.
+        """
+        return asdict(self, dict_factory=exclude_null_dict)
+
+    @classmethod
+    def from_dict(
+        cls: t.Type[Message],  # noqa: UP006
+        data: dict[str, t.Any],
+    ) -> Message:
+        """Create an encoding from a dictionary.
+
+        Args:
+            data: The dictionary to create the message from.
+
+        Returns:
+            The created message.
+        """
+        data.pop("type")
+        return cls(**data)
+
+
+@dataclass
+class RecordMessage(Message):
+    """Singer record message."""
+
+    stream: str
+    """The stream name."""
+
+    record: dict[str, t.Any]
+    """The record data."""
+
+    version: int | None = None
+    """The record version."""
+
+    time_extracted: datetime | None = None
+    """The time the record was extracted."""
+
+    @classmethod
+    def from_dict(cls: type[RecordMessage], data: dict[str, t.Any]) -> RecordMessage:
+        """Create a record message from a dictionary.
+
+        This overrides the default conversion logic, since it uses unnecessary
+        deep copying and is very slow.
+
+        Args:
+            data: The dictionary to create the message from.
+
+        Returns:
+            The created message.
+        """
+        time_extracted = data.get("time_extracted")
+        return cls(
+            stream=data["stream"],
+            record=data["record"],
+            version=data.get("version"),
+            time_extracted=datetime.fromisoformat(time_extracted)
+            if time_extracted
+            else None,
+        )
+
+    def to_dict(self) -> dict[str, t.Any]:
+        """Return a dictionary representation of the message.
+
+        This overrides the default conversion logic, since it uses unnecessary
+        deep copying and is very slow.
+
+        Returns:
+            A dictionary with the defined message fields.
+        """
+        result: dict[str, t.Any] = {
+            "type": "RECORD",
+            "stream": self.stream,
+            "record": self.record,
+        }
+        if self.version is not None:
+            result["version"] = self.version
+        if self.time_extracted is not None:
+            result["time_extracted"] = self.time_extracted
+        return result
+
+    def __post_init__(self) -> None:
+        """Post-init processing.
+
+        Raises:
+            ValueError: If the time_extracted is not timezone-aware.
+        """
+        self.type = SingerMessageType.RECORD
+        if self.time_extracted and not self.time_extracted.tzinfo:
+            msg = (
+                "'time_extracted' must be either None or an aware datetime (with a "
+                "time zone)"
+            )
+            raise ValueError(msg)
+
+        if self.time_extracted:
+            self.time_extracted = self.time_extracted.astimezone(timezone.utc)
+
+
+@dataclass
+class SchemaMessage(Message):
+    """Singer schema message."""
+
+    stream: str
+    """The stream name."""
+
+    schema: dict[str, t.Any]
+    """The schema definition."""
+
+    key_properties: t.Sequence[str] | None = None
+    """The key properties."""
+
+    bookmark_properties: list[str] | None = None
+    """The bookmark properties."""
+
+    def __post_init__(self) -> None:
+        """Post-init processing.
+
+        Raises:
+            ValueError: If bookmark_properties is not a string or list of strings.
+        """
+        self.type = SingerMessageType.SCHEMA
+
+        if isinstance(self.bookmark_properties, (str, bytes)):
+            self.bookmark_properties = [self.bookmark_properties]
+        if self.bookmark_properties and not isinstance(self.bookmark_properties, list):
+            msg = "bookmark_properties must be a string or list of strings"
+            raise ValueError(msg)
+
+
+@dataclass
+class StateMessage(Message):
+    """Singer state message."""
+
+    value: dict[str, t.Any]
+    """The state value."""
+
+    def __post_init__(self) -> None:
+        """Post-init processing."""
+        self.type = SingerMessageType.STATE
+
+
+@dataclass
+class ActivateVersionMessage(Message):
+    """Singer activate version message."""
+
+    stream: str
+    """The stream name."""
+
+    version: int
+    """The version to activate."""
+
+    def __post_init__(self) -> None:
+        """Post-init processing."""
+        self.type = SingerMessageType.ACTIVATE_VERSION
+
+
 class GenericSingerReader(t.Generic[T], metaclass=abc.ABCMeta):
     """Interface for all plugins reading Singer messages as strings or bytes."""
 
diff --git a/singer_sdk/_singerlib/messages.py b/singer_sdk/_singerlib/messages.py
index ae8977c9e1..ae5572c4c5 100644
--- a/singer_sdk/_singerlib/messages.py
+++ b/singer_sdk/_singerlib/messages.py
@@ -2,226 +2,29 @@
 
 from __future__ import annotations
 
-import enum
-import sys
-import typing as t
-from dataclasses import asdict, dataclass, field
-from datetime import datetime, timezone
-
-from singer_sdk._singerlib.json import serialize_json
-
-if sys.version_info < (3, 11):
-    from backports.datetime_fromisoformat import MonkeyPatch
-
-    MonkeyPatch.patch_fromisoformat()
-
-
-class SingerMessageType(str, enum.Enum):
-    """Singer specification message types."""
-
-    RECORD = "RECORD"
-    SCHEMA = "SCHEMA"
-    STATE = "STATE"
-    ACTIVATE_VERSION = "ACTIVATE_VERSION"
-    BATCH = "BATCH"
-
-
-def exclude_null_dict(pairs: list[tuple[str, t.Any]]) -> dict[str, t.Any]:
-    """Exclude null values from a dictionary.
-
-    Args:
-        pairs: The dictionary key-value pairs.
-
-    Returns:
-        The filtered key-value pairs.
-    """
-    return {key: value for key, value in pairs if value is not None}
-
-
-@dataclass
-class Message:
-    """Singer base message."""
-
-    type: SingerMessageType = field(init=False)
-    """The message type."""
-
-    def to_dict(self) -> dict[str, t.Any]:
-        """Return a dictionary representation of the message.
-
-        Returns:
-            A dictionary with the defined message fields.
-        """
-        return asdict(self, dict_factory=exclude_null_dict)
-
-    @classmethod
-    def from_dict(
-        cls: t.Type[Message],  # noqa: UP006
-        data: dict[str, t.Any],
-    ) -> Message:
-        """Create an encoding from a dictionary.
-
-        Args:
-            data: The dictionary to create the message from.
-
-        Returns:
-            The created message.
-        """
-        data.pop("type")
-        return cls(**data)
-
-
-@dataclass
-class RecordMessage(Message):
-    """Singer record message."""
-
-    stream: str
-    """The stream name."""
-
-    record: dict[str, t.Any]
-    """The record data."""
-
-    version: int | None = None
-    """The record version."""
-
-    time_extracted: datetime | None = None
-    """The time the record was extracted."""
-
-    @classmethod
-    def from_dict(cls: type[RecordMessage], data: dict[str, t.Any]) -> RecordMessage:
-        """Create a record message from a dictionary.
-
-        This overrides the default conversion logic, since it uses unnecessary
-        deep copying and is very slow.
-
-        Args:
-            data: The dictionary to create the message from.
-
-        Returns:
-            The created message.
-        """
-        time_extracted = data.get("time_extracted")
-        return cls(
-            stream=data["stream"],
-            record=data["record"],
-            version=data.get("version"),
-            time_extracted=datetime.fromisoformat(time_extracted)
-            if time_extracted
-            else None,
-        )
-
-    def to_dict(self) -> dict[str, t.Any]:
-        """Return a dictionary representation of the message.
-
-        This overrides the default conversion logic, since it uses unnecessary
-        deep copying and is very slow.
-
-        Returns:
-            A dictionary with the defined message fields.
-        """
-        result: dict[str, t.Any] = {
-            "type": "RECORD",
-            "stream": self.stream,
-            "record": self.record,
-        }
-        if self.version is not None:
-            result["version"] = self.version
-        if self.time_extracted is not None:
-            result["time_extracted"] = self.time_extracted
-        return result
-
-    def __post_init__(self) -> None:
-        """Post-init processing.
-
-        Raises:
-            ValueError: If the time_extracted is not timezone-aware.
-        """
-        self.type = SingerMessageType.RECORD
-        if self.time_extracted and not self.time_extracted.tzinfo:
-            msg = (
-                "'time_extracted' must be either None or an aware datetime (with a "
-                "time zone)"
-            )
-            raise ValueError(msg)
-
-        if self.time_extracted:
-            self.time_extracted = self.time_extracted.astimezone(timezone.utc)
-
-
-@dataclass
-class SchemaMessage(Message):
-    """Singer schema message."""
-
-    stream: str
-    """The stream name."""
-
-    schema: dict[str, t.Any]
-    """The schema definition."""
-
-    key_properties: t.Sequence[str] | None = None
-    """The key properties."""
-
-    bookmark_properties: list[str] | None = None
-    """The bookmark properties."""
-
-    def __post_init__(self) -> None:
-        """Post-init processing.
-
-        Raises:
-            ValueError: If bookmark_properties is not a string or list of strings.
-        """
-        self.type = SingerMessageType.SCHEMA
-
-        if isinstance(self.bookmark_properties, (str, bytes)):
-            self.bookmark_properties = [self.bookmark_properties]
-        if self.bookmark_properties and not isinstance(self.bookmark_properties, list):
-            msg = "bookmark_properties must be a string or list of strings"
-            raise ValueError(msg)
-
-
-@dataclass
-class StateMessage(Message):
-    """Singer state message."""
-
-    value: dict[str, t.Any]
-    """The state value."""
-
-    def __post_init__(self) -> None:
-        """Post-init processing."""
-        self.type = SingerMessageType.STATE
-
-
-@dataclass
-class ActivateVersionMessage(Message):
-    """Singer activate version message."""
-
-    stream: str
-    """The stream name."""
-
-    version: int
-    """The version to activate."""
-
-    def __post_init__(self) -> None:
-        """Post-init processing."""
-        self.type = SingerMessageType.ACTIVATE_VERSION
-
-
-def format_message(message: Message) -> str:
-    """Format a message as a JSON string.
-
-    Args:
-        message: The message to format.
-
-    Returns:
-        The formatted message.
-    """
-    return serialize_json(message.to_dict())
-
-
-def write_message(message: Message) -> None:
-    """Write a message to stdout.
-
-    Args:
-        message: The message to write.
-    """
-    sys.stdout.write(format_message(message) + "\n")
-    sys.stdout.flush()
+from singer_sdk._singerlib.encoding import SingerWriter
+from singer_sdk._singerlib.encoding.base import (
+    ActivateVersionMessage,
+    Message,
+    RecordMessage,
+    SchemaMessage,
+    SingerMessageType,
+    StateMessage,
+    exclude_null_dict,
+)
+
+__all__ = [
+    "ActivateVersionMessage",
+    "Message",
+    "RecordMessage",
+    "SchemaMessage",
+    "SingerMessageType",
+    "StateMessage",
+    "exclude_null_dict",
+    "format_message",
+    "write_message",
+]
+
+WRITER = SingerWriter()
+format_message = WRITER.format_message
+write_message = WRITER.write_message