Skip to content

Commit

Permalink
refactor: Override superclass to_jsonschema in `PostgresSQLToJSONSc…
Browse files Browse the repository at this point in the history
…hema`
  • Loading branch information
edgarrmondragon committed Nov 28, 2024
1 parent 7e7d194 commit 1b53d23
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions tap_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import json
import select
import typing as t
from functools import cached_property
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import psycopg2
import singer_sdk.helpers._typing
Expand All @@ -26,11 +24,10 @@
from singer_sdk.streams.core import REPLICATION_INCREMENTAL
from sqlalchemy.dialects import postgresql

if TYPE_CHECKING:
if t.TYPE_CHECKING:
from collections.abc import Iterable, Mapping

from singer_sdk.helpers.types import Context
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import Engine
from sqlalchemy.engine.reflection import Inspector

Expand All @@ -44,7 +41,12 @@ def __init__(self, dates_as_string: bool, json_as_object: bool, *args, **kwargs)
self.dates_as_string = dates_as_string
self.json_as_object = json_as_object

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@functools.singledispatchmethod
def to_jsonschema(self, column_type: t.Any) -> dict:
"""Customize the JSON Schema for Postgres types."""
return super().to_jsonschema(column_type)

@to_jsonschema.register
def array_to_jsonschema(self, column_type: postgresql.ARRAY) -> dict:
"""Override the default mapping for NUMERIC columns.
Expand All @@ -55,32 +57,29 @@ def array_to_jsonschema(self, column_type: postgresql.ARRAY) -> dict:
"items": self.to_jsonschema(column_type.item_type),
}

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@to_jsonschema.register
def json_to_jsonschema(self, column_type: postgresql.JSON) -> dict:
"""Override the default mapping for JSON and JSONB columns."""
if self.json_as_object:
return {"type": ["object", "null"]}
return {"type": ["string", "number", "integer", "array", "object", "boolean"]}

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@to_jsonschema.register
def datetime_to_jsonschema(self, column_type: sqlalchemy.types.DateTime) -> dict:
"""Override the default mapping for DATETIME columns."""
if self.dates_as_string:
return {"type": ["string", "null"]}
return super().datetime_to_jsonschema(column_type)

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@to_jsonschema.register
def date_to_jsonschema(self, column_type: sqlalchemy.types.Date) -> dict:
"""Override the default mapping for DATE columns."""
if self.dates_as_string:
return {"type": ["string", "null"]}
return super().date_to_jsonschema(column_type)


def patched_conform(
elem: Any,
property_schema: dict,
) -> Any:
def patched_conform(elem: t.Any, property_schema: dict) -> t.Any:
"""Overrides Singer SDK type conformance.
Most logic here is from singer_sdk.helpers._typing._conform_primitive_property, as
Expand Down Expand Up @@ -272,11 +271,11 @@ class PostgresLogBasedStream(SQLStream):
replication_key = "_sdc_lsn"

@property
def config(self) -> Mapping[str, Any]:
def config(self) -> Mapping[str, t.Any]:
"""Return a read-only config dictionary."""
return MappingProxyType(self._config)

@cached_property
@functools.cached_property
def schema(self) -> dict:
"""Override schema for log-based replication adding _sdc columns."""
schema_dict = t.cast(dict, self._singer_catalog_entry.schema.to_dict())
Expand All @@ -293,7 +292,7 @@ def schema(self) -> dict:

def _increment_stream_state(
self,
latest_record: dict[str, Any],
latest_record: dict[str, t.Any],
*,
context: Context | None = None,
) -> None:
Expand Down Expand Up @@ -326,7 +325,7 @@ def _increment_stream_state(
check_sorted=self.check_sorted,
)

def get_records(self, context: Context | None) -> Iterable[dict[str, Any]]:
def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]:
"""Return a generator of row-type dictionary objects."""
status_interval = 5.0 # if no records in 5 seconds the tap can exit
start_lsn = self.get_starting_replication_key_value(context=context)
Expand Down

0 comments on commit 1b53d23

Please sign in to comment.