Skip to content

Commit

Permalink
Re-use message writing and formatting logic
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jul 13, 2024
1 parent 168485c commit 8f390f6
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 224 deletions.
199 changes: 198 additions & 1 deletion singer_sdk/_singerlib/encoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@
from __future__ import annotations

import abc
import enum
import logging
import sys
import typing as t
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone

from singer_sdk._singerlib import exceptions
from singer_sdk._singerlib.messages import Message, SingerMessageType

if sys.version_info < (3, 11):
from backports.datetime_fromisoformat import MonkeyPatch

MonkeyPatch.patch_fromisoformat()

logger = logging.getLogger(__name__)

Expand All @@ -18,6 +26,195 @@
T = t.TypeVar("T", str, bytes)


class SingerMessageType(str, enum.Enum):
"""Singer specification message types."""

RECORD = "RECORD"
SCHEMA = "SCHEMA"
STATE = "STATE"
ACTIVATE_VERSION = "ACTIVATE_VERSION"
BATCH = "BATCH"


def exclude_null_dict(pairs: list[tuple[str, t.Any]]) -> dict[str, t.Any]:
"""Exclude null values from a dictionary.
Args:
pairs: The dictionary key-value pairs.
Returns:
The filtered key-value pairs.
"""
return {key: value for key, value in pairs if value is not None}


@dataclass
class Message:
"""Singer base message."""

type: SingerMessageType = field(init=False)
"""The message type."""

def to_dict(self) -> dict[str, t.Any]:
"""Return a dictionary representation of the message.
Returns:
A dictionary with the defined message fields.
"""
return asdict(self, dict_factory=exclude_null_dict)

@classmethod
def from_dict(
cls: t.Type[Message], # noqa: UP006
data: dict[str, t.Any],
) -> Message:
"""Create an encoding from a dictionary.
Args:
data: The dictionary to create the message from.
Returns:
The created message.
"""
data.pop("type")
return cls(**data)


@dataclass
class RecordMessage(Message):
"""Singer record message."""

stream: str
"""The stream name."""

record: dict[str, t.Any]
"""The record data."""

version: int | None = None
"""The record version."""

time_extracted: datetime | None = None
"""The time the record was extracted."""

@classmethod
def from_dict(cls: type[RecordMessage], data: dict[str, t.Any]) -> RecordMessage:
"""Create a record message from a dictionary.
This overrides the default conversion logic, since it uses unnecessary
deep copying and is very slow.
Args:
data: The dictionary to create the message from.
Returns:
The created message.
"""
time_extracted = data.get("time_extracted")
return cls(
stream=data["stream"],
record=data["record"],
version=data.get("version"),
time_extracted=datetime.fromisoformat(time_extracted)
if time_extracted
else None,
)

def to_dict(self) -> dict[str, t.Any]:
"""Return a dictionary representation of the message.
This overrides the default conversion logic, since it uses unnecessary
deep copying and is very slow.
Returns:
A dictionary with the defined message fields.
"""
result: dict[str, t.Any] = {
"type": "RECORD",
"stream": self.stream,
"record": self.record,
}
if self.version is not None:
result["version"] = self.version

Check warning on line 137 in singer_sdk/_singerlib/encoding/base.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/base.py#L137

Added line #L137 was not covered by tests
if self.time_extracted is not None:
result["time_extracted"] = self.time_extracted
return result

def __post_init__(self) -> None:
"""Post-init processing.
Raises:
ValueError: If the time_extracted is not timezone-aware.
"""
self.type = SingerMessageType.RECORD
if self.time_extracted and not self.time_extracted.tzinfo:
msg = (
"'time_extracted' must be either None or an aware datetime (with a "
"time zone)"
)
raise ValueError(msg)

if self.time_extracted:
self.time_extracted = self.time_extracted.astimezone(timezone.utc)


@dataclass
class SchemaMessage(Message):
"""Singer schema message."""

stream: str
"""The stream name."""

schema: dict[str, t.Any]
"""The schema definition."""

key_properties: t.Sequence[str] | None = None
"""The key properties."""

bookmark_properties: list[str] | None = None
"""The bookmark properties."""

def __post_init__(self) -> None:
"""Post-init processing.
Raises:
ValueError: If bookmark_properties is not a string or list of strings.
"""
self.type = SingerMessageType.SCHEMA

if isinstance(self.bookmark_properties, (str, bytes)):
self.bookmark_properties = [self.bookmark_properties]
if self.bookmark_properties and not isinstance(self.bookmark_properties, list):
msg = "bookmark_properties must be a string or list of strings"
raise ValueError(msg)


@dataclass
class StateMessage(Message):
"""Singer state message."""

value: dict[str, t.Any]
"""The state value."""

def __post_init__(self) -> None:
"""Post-init processing."""
self.type = SingerMessageType.STATE


@dataclass
class ActivateVersionMessage(Message):
"""Singer activate version message."""

stream: str
"""The stream name."""

version: int
"""The version to activate."""

def __post_init__(self) -> None:
"""Post-init processing."""
self.type = SingerMessageType.ACTIVATE_VERSION


class GenericSingerReader(t.Generic[T], metaclass=abc.ABCMeta):
"""Interface for all plugins reading Singer messages as strings or bytes."""

Expand Down
Loading

0 comments on commit 8f390f6

Please sign in to comment.