diff --git a/tap_postgres/client.py b/tap_postgres/client.py index a6ae564c..42c3ef9f 100644 --- a/tap_postgres/client.py +++ b/tap_postgres/client.py @@ -20,6 +20,8 @@ import sqlalchemy.types from psycopg2 import extras from singer_sdk import SQLConnector, SQLStream +from singer_sdk import typing as th +from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema from singer_sdk.connectors.sql import SQLToJSONSchema from singer_sdk.helpers._state import increment_state from singer_sdk.helpers._typing import TypeConformanceLevel @@ -32,6 +34,12 @@ from singer_sdk.helpers.types import Context from sqlalchemy.dialects import postgresql from sqlalchemy.engine import Engine + from sqlalchemy.engine.interfaces import ( # type: ignore[attr-defined] + ReflectedColumn, + ReflectedIndex, + ReflectedPrimaryKeyConstraint, + TableKey, + ) from sqlalchemy.engine.reflection import Inspector @@ -183,6 +191,135 @@ def get_schema_names(self, engine: Engine, inspected: Inspector) -> list[str]: return self.config["filter_schemas"] return super().get_schema_names(engine, inspected) + # Uses information_schema for speed. + def discover_catalog_entry_optimized( # noqa: PLR0913 + self, + engine: Engine, + inspected: Inspector, + schema_name: str, + table_name: str, + is_view: bool, + table_data: dict[TableKey, list[ReflectedColumn]], + pk_data: dict[TableKey, ReflectedPrimaryKeyConstraint], + index_data: dict[TableKey, list[ReflectedIndex]], + ) -> CatalogEntry: + """Create `CatalogEntry` object for the given table or a view. + + Args: + engine: SQLAlchemy engine + inspected: SQLAlchemy inspector instance for engine + schema_name: Schema name to inspect + table_name: Name of the table or a view + is_view: Flag whether this object is a view, returned by `get_object_names` + table_data: Cached inspector data for the relevant tables + pk_data: Cached inspector data for the relevant primary keys + index_data: Cached inspector data for the relevant indexes + + Returns: + `CatalogEntry` object for the given table or a view + """ + # Initialize unique stream name + unique_stream_id = f"{schema_name}-{table_name}" + table_key = (schema_name, table_name) + + # Detect key properties + possible_primary_keys: list[list[str]] = [] + pk_def = pk_data.get(table_key, {}) + if pk_def and "constrained_columns" in pk_def: + possible_primary_keys.append(pk_def["constrained_columns"]) + + # An element of the columns list is ``None`` if it's an expression and is + # returned in the ``expressions`` list of the reflected index. + possible_primary_keys.extend( + index_def["column_names"] + for index_def in index_data.get(table_key, []) + if index_def.get("unique", False) + ) + + key_properties = next(iter(possible_primary_keys), None) + + # Initialize columns list + table_schema = th.PropertiesList() + for column_def in table_data.get(table_key, []): + column_name = column_def["name"] + is_nullable = column_def.get("nullable", False) + jsonschema_type: dict = self.to_jsonschema_type(column_def["type"]) + table_schema.append( + th.Property( + name=column_name, + wrapped=th.CustomType(jsonschema_type), + nullable=is_nullable, + required=column_name in key_properties if key_properties else False, + ), + ) + schema = table_schema.to_dict() + + # Initialize available replication methods + addl_replication_methods: list[str] = [""] # By default an empty list. + # Notes regarding replication methods: + # - 'INCREMENTAL' replication must be enabled by the user by specifying + # a replication_key value. + # - 'LOG_BASED' replication must be enabled by the developer, according + # to source-specific implementation capabilities. + replication_method = next(reversed(["FULL_TABLE", *addl_replication_methods])) + + # Create the catalog entry object + return CatalogEntry( + tap_stream_id=unique_stream_id, + stream=unique_stream_id, + table=table_name, + key_properties=key_properties, + schema=Schema.from_dict(schema), + is_view=is_view, + replication_method=replication_method, + metadata=MetadataMapping.get_standard_metadata( + schema_name=schema_name, + schema=schema, + replication_method=replication_method, + key_properties=key_properties, + valid_replication_keys=None, # Must be defined by user + ), + database=None, # Expects single-database context + row_count=None, + stream_alias=None, + replication_key=None, # Must be defined by user + ) + + def discover_catalog_entries(self) -> list[dict]: + """Return a list of catalog entries from discovery. + + Returns: + The discovered catalog entries as a list. + """ + result: list[dict] = [] + engine = self._engine + inspected = sa.inspect(engine) + for schema_name in self.get_schema_names(engine, inspected): + # Use get_multi_* data here instead of pulling per-table + table_data = inspected.get_multi_columns(schema=schema_name) + pk_data = inspected.get_multi_pk_constraint(schema=schema_name) + index_data = inspected.get_multi_indexes(schema=schema_name) + + # Iterate through each table and view + for table_name, is_view in self.get_object_names( + engine, + inspected, + schema_name, + ): + catalog_entry = self.discover_catalog_entry_optimized( + engine, + inspected, + schema_name, + table_name, + is_view, + table_data, + pk_data, + index_data, + ) + result.append(catalog_entry.to_dict()) + + return result + class PostgresStream(SQLStream): """Stream class for Postgres streams."""