From 64f4332db64223316cbd808b4ee4859b35ffb559 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Mon, 2 Dec 2024 17:27:09 -0600 Subject: [PATCH] perf(taps): Improved discovery performance for SQL taps --- singer_sdk/connectors/sql.py | 104 ++++++++++++++++++++++++++++++----- 1 file changed, 91 insertions(+), 13 deletions(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 540de023e6..22eed2810f 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -13,6 +13,7 @@ from functools import lru_cache import sqlalchemy as sa +from sqlalchemy.engine import reflection from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema @@ -960,6 +961,73 @@ def discover_catalog_entry( replication_key=None, # Must be defined by user ) + def _discover_catalog_entry_from_inspected( + self, + *, + table_name: str, + schema_name: str | None, + columns: list[reflection.ReflectedColumn], + primary_key: reflection.ReflectedPrimaryKeyConstraint | None, + unique_constraints: list[reflection.ReflectedUniqueConstraint], + is_view: bool = False, + ) -> CatalogEntry: + unique_stream_id = f"{schema_name}-{table_name}" if schema_name else table_name + + # Detect key properties + possible_primary_keys: list[list[str]] = [] + if primary_key and "constrained_columns" in primary_key: + possible_primary_keys.append(primary_key["constrained_columns"]) + + # Check UNIQUE constraints + possible_primary_keys.extend( + unique_constraint["column_names"] + for unique_constraint in unique_constraints + ) + + key_properties = next(iter(possible_primary_keys), None) + + # Initialize columns list + table_schema = th.PropertiesList() + for column_def in columns: + column_name = column_def["name"] + is_nullable = column_def.get("nullable", False) + jsonschema_type: dict = self.to_jsonschema_type(column_def["type"]) + table_schema.append( + th.Property( + name=column_name, + wrapped=th.CustomType(jsonschema_type), + nullable=is_nullable, + required=column_name in key_properties if key_properties else False, + ), + ) + schema = table_schema.to_dict() + + # Initialize available replication methods + addl_replication_methods: list[str] = [] # By default an empty list. + replication_method = next(reversed(["FULL_TABLE", *addl_replication_methods])) + + # Create the catalog entry object + return CatalogEntry( + tap_stream_id=unique_stream_id, + stream=unique_stream_id, + table=table_name, + key_properties=key_properties, + schema=Schema.from_dict(schema), + is_view=is_view, + replication_method=replication_method, + metadata=MetadataMapping.get_standard_metadata( + schema_name=schema_name, + schema=schema, + replication_method=replication_method, + key_properties=key_properties, + valid_replication_keys=None, # Must be defined by user + ), + database=None, # Expects single-database context + row_count=None, + stream_alias=None, + replication_key=None, # Must be defined by user + ) + def discover_catalog_entries(self) -> list[dict]: """Return a list of catalog entries from discovery. @@ -969,21 +1037,31 @@ def discover_catalog_entries(self) -> list[dict]: result: list[dict] = [] engine = self._engine inspected = sa.inspect(engine) + object_kinds = ( + (reflection.ObjectKind.TABLE, False), + (reflection.ObjectKind.ANY_VIEW, True), + ) for schema_name in self.get_schema_names(engine, inspected): - # Iterate through each table and view - for table_name, is_view in self.get_object_names( - engine, - inspected, - schema_name, - ): - catalog_entry = self.discover_catalog_entry( - engine, - inspected, - schema_name, - table_name, - is_view, + columns = inspected.get_multi_columns(schema=schema_name) + pk = inspected.get_multi_pk_constraint(schema=schema_name) + unique = inspected.get_multi_unique_constraints(schema=schema_name) + for object_kind, is_view in object_kinds: + columns = inspected.get_multi_columns( + schema=schema_name, + kind=object_kind, + ) + + result.extend( + self._discover_catalog_entry_from_inspected( + table_name=_table, + schema_name=_schema, + columns=columns[_schema, _table], + primary_key=pk.get((_schema, _table)), + unique_constraints=unique.get((_schema, _table), []), + is_view=is_view, + ).to_dict() + for _schema, _table in columns ) - result.append(catalog_entry.to_dict()) return result