diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 540de023e6..5ade1bd776 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -13,6 +13,7 @@ from functools import lru_cache import sqlalchemy as sa +from sqlalchemy.engine import reflection from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema @@ -960,6 +961,75 @@ def discover_catalog_entry( replication_key=None, # Must be defined by user ) + def _discover_catalog_entry_from_inspected( + self, + *, + table_name: str, + schema_name: str | None, + columns: list[reflection.ReflectedColumn], + primary_key: reflection.ReflectedPrimaryKeyConstraint, + indices: list[reflection.ReflectedIndex], + is_view: bool = False, + ) -> CatalogEntry: + unique_stream_id = f"{schema_name}-{table_name}" if schema_name else table_name + + # Detect key properties + possible_primary_keys: list[list[str]] = [] + if "constrained_columns" in primary_key: # type: ignore[redundant-expr] + possible_primary_keys.append(primary_key["constrained_columns"]) + + # An element of the columns list is ``None`` if it's an expression and is + # returned in the ``expressions`` list of the reflected index. + possible_primary_keys.extend( + index_def["column_names"] # type: ignore[misc] + for index_def in indices + if index_def.get("unique", False) + ) + + key_properties = next(iter(possible_primary_keys), None) + + # Initialize columns list + table_schema = th.PropertiesList() + for column_def in columns: + column_name = column_def["name"] + is_nullable = column_def.get("nullable", False) + jsonschema_type: dict = self.to_jsonschema_type(column_def["type"]) + table_schema.append( + th.Property( + name=column_name, + wrapped=th.CustomType(jsonschema_type), + nullable=is_nullable, + required=column_name in key_properties if key_properties else False, + ), + ) + schema = table_schema.to_dict() + + # Initialize available replication methods + addl_replication_methods: list[str] = [] # By default an empty list. + replication_method = next(reversed(["FULL_TABLE", *addl_replication_methods])) + + # Create the catalog entry object + return CatalogEntry( + tap_stream_id=unique_stream_id, + stream=unique_stream_id, + table=table_name, + key_properties=key_properties, + schema=Schema.from_dict(schema), + is_view=is_view, + replication_method=replication_method, + metadata=MetadataMapping.get_standard_metadata( + schema_name=schema_name, + schema=schema, + replication_method=replication_method, + key_properties=key_properties, + valid_replication_keys=None, # Must be defined by user + ), + database=None, # Expects single-database context + row_count=None, + stream_alias=None, + replication_key=None, # Must be defined by user + ) + def discover_catalog_entries(self) -> list[dict]: """Return a list of catalog entries from discovery. @@ -969,21 +1039,30 @@ def discover_catalog_entries(self) -> list[dict]: result: list[dict] = [] engine = self._engine inspected = sa.inspect(engine) + object_kinds = ( + (reflection.ObjectKind.TABLE, False), + (reflection.ObjectKind.ANY_VIEW, True), + ) for schema_name in self.get_schema_names(engine, inspected): - # Iterate through each table and view - for table_name, is_view in self.get_object_names( - engine, - inspected, - schema_name, - ): - catalog_entry = self.discover_catalog_entry( - engine, - inspected, - schema_name, - table_name, - is_view, + for object_kind, is_view in object_kinds: + columns = inspected.get_multi_columns( + schema=schema_name, + kind=object_kind, + ) + pk = inspected.get_multi_pk_constraint(schema=schema_name) + indices = inspected.get_multi_indexes(schema=schema_name) + + result.extend( + self._discover_catalog_entry_from_inspected( + table_name=_table, + schema_name=_schema, + columns=columns[_schema, _table], + primary_key=pk[_schema, _table], + indices=indices[_schema, _table], + is_view=is_view, + ).to_dict() + for _schema, _table in columns ) - result.append(catalog_entry.to_dict()) return result