From d41460de2e6dbf507cc553c51ecb473eb1537eee Mon Sep 17 00:00:00 2001 From: Johnson Kwok Date: Tue, 12 Nov 2024 12:11:02 -0800 Subject: [PATCH] [jk] Refactor code for compatibility with SQLAlchemy v1 --- pyiceberg/catalog/sql.py | 367 ++++++++++++++++++++------------------- 1 file changed, 193 insertions(+), 174 deletions(-) diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index 6a4318253f..78ae3b5f83 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -32,15 +32,10 @@ select, union, update, + Column, ) from sqlalchemy.exc import IntegrityError, NoResultFound, OperationalError, ProgrammingError -from sqlalchemy.orm import ( - DeclarativeBase, - Mapped, - MappedAsDataclass, - Session, - mapped_column, -) +from sqlalchemy.orm import declarative_base, Session from pyiceberg.catalog import ( METADATA_LOCATION, @@ -78,30 +73,29 @@ DEFAULT_POOL_PRE_PING_VALUE = "false" DEFAULT_INIT_CATALOG_TABLES = "true" - -class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase): - pass +# Use the traditional ORM base class +Base = declarative_base() -class IcebergTables(SqlCatalogBaseTable): +class IcebergTables(Base): __tablename__ = "iceberg_tables" - catalog_name: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - table_namespace: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - table_name: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - metadata_location: Mapped[Optional[str]] = mapped_column(String(1000), nullable=True) - previous_metadata_location: Mapped[Optional[str]] = mapped_column(String(1000), nullable=True) + catalog_name = Column(String(255), nullable=False, primary_key=True) + table_namespace = Column(String(255), nullable=False, primary_key=True) + table_name = Column(String(255), nullable=False, primary_key=True) + metadata_location = Column(String(1000), nullable=True) + previous_metadata_location = Column(String(1000), nullable=True) -class IcebergNamespaceProperties(SqlCatalogBaseTable): +class IcebergNamespaceProperties(Base): __tablename__ = "iceberg_namespace_properties" # Catalog minimum Namespace Properties NAMESPACE_MINIMAL_PROPERTIES = {"exists": "true"} - catalog_name: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - namespace: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - property_key: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - property_value: Mapped[str] = mapped_column(String(1000), nullable=False) + catalog_name = Column(String(255), nullable=False, primary_key=True) + namespace = Column(String(255), nullable=False, primary_key=True) + property_key = Column(String(255), nullable=False, primary_key=True) + property_value = Column(String(1000), nullable=False) class SqlCatalog(MetastoreCatalog): @@ -137,18 +131,16 @@ def _ensure_tables_exist(self) -> None: stmt = select(1).select_from(table) try: session.scalar(stmt) - except ( - OperationalError, - ProgrammingError, - ): # sqlalchemy returns OperationalError in case of sqlite and ProgrammingError with postgres. + except (OperationalError, ProgrammingError): + # Handle missing tables (SQLite vs Postgres differences) self.create_tables() return def create_tables(self) -> None: - SqlCatalogBaseTable.metadata.create_all(self.engine) + Base.metadata.create_all(self.engine) def destroy_tables(self) -> None: - SqlCatalogBaseTable.metadata.drop_all(self.engine) + Base.metadata.drop_all(self.engine) def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table: # Check for expected properties. @@ -194,15 +186,15 @@ def create_table( Table: the created table instance. Raises: - AlreadyExistsError: If a table with the name already exists. - ValueError: If the identifier is invalid, or no path is given to store metadata. - + TableAlreadyExistsError: If a table with the name already exists. + NoSuchNamespaceError: If the identifier's namespace is invalid. """ - schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + schema = self._convert_schema_if_needed(schema) # type: ignore identifier_nocatalog = self._identifier_to_tuple_without_catalog(identifier) namespace_identifier = Catalog.namespace_from(identifier_nocatalog) table_name = Catalog.table_name_from(identifier_nocatalog) + if not self._namespace_exists(namespace_identifier): raise NoSuchNamespaceError(f"Namespace does not exist: {namespace_identifier}") @@ -210,25 +202,32 @@ def create_table( location = self._resolve_table_location(location, namespace, table_name) metadata_location = self._get_metadata_location(location=location) metadata = new_table_metadata( - location=location, schema=schema, partition_spec=partition_spec, sort_order=sort_order, properties=properties + location=location, + schema=schema, + partition_spec=partition_spec, + sort_order=sort_order, + properties=properties, ) io = load_file_io(properties=self.properties, location=metadata_location) self._write_metadata(metadata, io, metadata_location) - with Session(self.engine) as session: - try: - session.add( - IcebergTables( - catalog_name=self.name, - table_namespace=namespace, - table_name=table_name, - metadata_location=metadata_location, - previous_metadata_location=None, - ) + session = Session(self.engine) + try: + session.add( + IcebergTables( + catalog_name=self.name, + table_namespace=namespace, + table_name=table_name, + metadata_location=metadata_location, + previous_metadata_location=None, ) - session.commit() - except IntegrityError as e: - raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e + ) + session.commit() + except IntegrityError as e: + session.rollback() + raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e + finally: + session.close() return self.load_table(identifier=identifier) @@ -250,23 +249,27 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: namespace_tuple = Catalog.namespace_from(identifier_tuple) namespace = Catalog.namespace_to_string(namespace_tuple) table_name = Catalog.table_name_from(identifier_tuple) + if not self._namespace_exists(namespace): raise NoSuchNamespaceError(f"Namespace does not exist: {namespace}") - with Session(self.engine) as session: - try: - session.add( - IcebergTables( - catalog_name=self.name, - table_namespace=namespace, - table_name=table_name, - metadata_location=metadata_location, - previous_metadata_location=None, - ) + session = Session(self.engine) + try: + session.add( + IcebergTables( + catalog_name=self.name, + table_namespace=namespace, + table_name=table_name, + metadata_location=metadata_location, + previous_metadata_location=None, ) - session.commit() - except IntegrityError as e: - raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e + ) + session.commit() + except IntegrityError as e: + session.rollback() + raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e + finally: + session.close() return self.load_table(identifier=identifier) @@ -289,16 +292,20 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: namespace_tuple = Catalog.namespace_from(identifier_tuple) namespace = Catalog.namespace_to_string(namespace_tuple) table_name = Catalog.table_name_from(identifier_tuple) - with Session(self.engine) as session: + + session = Session(self.engine) + try: stmt = select(IcebergTables).where( IcebergTables.catalog_name == self.name, IcebergTables.table_namespace == namespace, IcebergTables.table_name == table_name, ) result = session.scalar(stmt) - if result: - return self._convert_orm_to_iceberg(result) - raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") + if result: + return self._convert_orm_to_iceberg(result) + raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") + finally: + session.close() def drop_table(self, identifier: Union[str, Identifier]) -> None: """Drop a table. @@ -313,33 +320,24 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: namespace_tuple = Catalog.namespace_from(identifier_tuple) namespace = Catalog.namespace_to_string(namespace_tuple) table_name = Catalog.table_name_from(identifier_tuple) - with Session(self.engine) as session: - if self.engine.dialect.supports_sane_rowcount: - res = session.execute( - delete(IcebergTables).where( - IcebergTables.catalog_name == self.name, - IcebergTables.table_namespace == namespace, - IcebergTables.table_name == table_name, - ) + + session = Session(self.engine) + try: + res = session.execute( + delete(IcebergTables).where( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == namespace, + IcebergTables.table_name == table_name, ) - if res.rowcount < 1: - raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") - else: - try: - tbl = ( - session.query(IcebergTables) - .with_for_update(of=IcebergTables) - .filter( - IcebergTables.catalog_name == self.name, - IcebergTables.table_namespace == namespace, - IcebergTables.table_name == table_name, - ) - .one() - ) - session.delete(tbl) - except NoResultFound as e: - raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") from e + ) + if res.rowcount < 1: + raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") session.commit() + except NoResultFound: + session.rollback() + raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") + finally: + session.close() def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: """Rename a fully classified table name. @@ -353,7 +351,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U Raises: NoSuchTableError: If a table with the name does not exist. - TableAlreadyExistsError: If a table with the new name already exist. + TableAlreadyExistsError: If a table with the new name already exists. NoSuchNamespaceError: If the target namespace does not exist. """ from_identifier_tuple = self._identifier_to_tuple_without_catalog(from_identifier) @@ -364,42 +362,48 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U to_namespace_tuple = Catalog.namespace_from(to_identifier_tuple) to_namespace = Catalog.namespace_to_string(to_namespace_tuple) to_table_name = Catalog.table_name_from(to_identifier_tuple) + if not self._namespace_exists(to_namespace): raise NoSuchNamespaceError(f"Namespace does not exist: {to_namespace}") - with Session(self.engine) as session: - try: - if self.engine.dialect.supports_sane_rowcount: - stmt = ( - update(IcebergTables) - .where( + + session = Session(self.engine) + try: + if self.engine.dialect.supports_sane_rowcount: + stmt = ( + update(IcebergTables) + .where( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == from_namespace, + IcebergTables.table_name == from_table_name, + ) + .values(table_namespace=to_namespace, table_name=to_table_name) + ) + result = session.execute(stmt) + if result.rowcount < 1: + raise NoSuchTableError(f"Table does not exist: {from_table_name}") + else: + try: + tbl = ( + session.query(IcebergTables) + .with_for_update(of=IcebergTables) + .filter( IcebergTables.catalog_name == self.name, IcebergTables.table_namespace == from_namespace, IcebergTables.table_name == from_table_name, ) - .values(table_namespace=to_namespace, table_name=to_table_name) + .one() ) - result = session.execute(stmt) - if result.rowcount < 1: - raise NoSuchTableError(f"Table does not exist: {from_table_name}") - else: - try: - tbl = ( - session.query(IcebergTables) - .with_for_update(of=IcebergTables) - .filter( - IcebergTables.catalog_name == self.name, - IcebergTables.table_namespace == from_namespace, - IcebergTables.table_name == from_table_name, - ) - .one() - ) - tbl.table_namespace = to_namespace - tbl.table_name = to_table_name - except NoResultFound as e: - raise NoSuchTableError(f"Table does not exist: {from_table_name}") from e - session.commit() - except IntegrityError as e: - raise TableAlreadyExistsError(f"Table {to_namespace}.{to_table_name} already exists") from e + tbl.table_namespace = to_namespace + tbl.table_name = to_table_name + except NoResultFound: + raise NoSuchTableError(f"Table does not exist: {from_table_name}") + session.commit() + except IntegrityError: + session.rollback() + raise TableAlreadyExistsError(f"Table {to_namespace}.{to_table_name} already exists") + finally: + session.close() + return self.load_table(to_identifier) def commit_table( @@ -431,18 +435,24 @@ def commit_table( current_table = None updated_staged_table = self._update_and_stage_table(current_table, table.identifier, requirements, updates) + if current_table and updated_staged_table.metadata == current_table.metadata: # no changes, do nothing - return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location) + return CommitTableResponse( + metadata=current_table.metadata, + metadata_location=current_table.metadata_location + ) + self._write_metadata( metadata=updated_staged_table.metadata, io=updated_staged_table.io, metadata_path=updated_staged_table.metadata_location, ) - with Session(self.engine) as session: + session = Session(self.engine) + try: if current_table: - # table exists, update it + # Table exists, update it if self.engine.dialect.supports_sane_rowcount: stmt = ( update(IcebergTables) @@ -459,7 +469,9 @@ def commit_table( ) result = session.execute(stmt) if result.rowcount < 1: - raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") + raise CommitFailedException( + f"Table has been updated by another process: {namespace}.{table_name}" + ) else: try: tbl = ( @@ -475,27 +487,31 @@ def commit_table( ) tbl.metadata_location = updated_staged_table.metadata_location tbl.previous_metadata_location = current_table.metadata_location - except NoResultFound as e: - raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e - session.commit() - else: - # table does not exist, create it - try: - session.add( - IcebergTables( - catalog_name=self.name, - table_namespace=namespace, - table_name=table_name, - metadata_location=updated_staged_table.metadata_location, - previous_metadata_location=None, + except NoResultFound: + raise CommitFailedException( + f"Table has been updated by another process: {namespace}.{table_name}" ) + else: + # Table does not exist, create it + session.add( + IcebergTables( + catalog_name=self.name, + table_namespace=namespace, + table_name=table_name, + metadata_location=updated_staged_table.metadata_location, + previous_metadata_location=None, ) - session.commit() - except IntegrityError as e: - raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e + ) + session.commit() + except IntegrityError: + session.rollback() + raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") + finally: + session.close() return CommitTableResponse( - metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location + metadata=updated_staged_table.metadata, + metadata_location=updated_staged_table.metadata_location ) def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: @@ -504,11 +520,13 @@ def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: with Session(self.engine) as session: stmt = ( select(IcebergTables) - .where(IcebergTables.catalog_name == self.name, IcebergTables.table_namespace == namespace) + .where( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == namespace + ) .limit(1) ) - result = session.execute(stmt).all() - if result: + if session.execute(stmt).first(): return True stmt = ( select(IcebergNamespaceProperties) @@ -518,10 +536,7 @@ def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: ) .limit(1) ) - result = session.execute(stmt).all() - if result: - return True - return False + return bool(session.execute(stmt).first()) def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. @@ -536,15 +551,14 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper if self._namespace_exists(namespace): raise NamespaceAlreadyExistsError(f"Namespace {namespace} already exists") - if not properties: - properties = IcebergNamespaceProperties.NAMESPACE_MINIMAL_PROPERTIES - create_properties = properties if properties else IcebergNamespaceProperties.NAMESPACE_MINIMAL_PROPERTIES + create_properties = properties or IcebergNamespaceProperties.NAMESPACE_MINIMAL_PROPERTIES + namespace_str = Catalog.namespace_to_string(namespace, NoSuchNamespaceError) with Session(self.engine) as session: for key, value in create_properties.items(): session.add( IcebergNamespaceProperties( catalog_name=self.name, - namespace=Catalog.namespace_to_string(namespace, NoSuchNamespaceError), + namespace=namespace_str, property_key=key, property_value=value, ) @@ -592,11 +606,17 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: if namespace and not self._namespace_exists(namespace): raise NoSuchNamespaceError(f"Namespace does not exist: {namespace}") - namespace = Catalog.namespace_to_string(namespace) - stmt = select(IcebergTables).where(IcebergTables.catalog_name == self.name, IcebergTables.table_namespace == namespace) + namespace_str = Catalog.namespace_to_string(namespace) + stmt = select(IcebergTables).where( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == namespace_str + ) with Session(self.engine) as session: result = session.scalars(stmt) - return [(Catalog.identifier_to_tuple(table.table_namespace) + (table.table_name,)) for table in result] + return [ + Catalog.identifier_to_tuple(table.table_namespace) + (table.table_name,) + for table in result + ] def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: """List namespaces from the given namespace. If not given, list top-level namespaces from the catalog. @@ -613,18 +633,19 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi if namespace and not self._namespace_exists(namespace): raise NoSuchNamespaceError(f"Namespace does not exist: {namespace}") - table_stmt = select(IcebergTables.table_namespace).where(IcebergTables.catalog_name == self.name) - namespace_stmt = select(IcebergNamespaceProperties.namespace).where(IcebergNamespaceProperties.catalog_name == self.name) - if namespace: - namespace_str = Catalog.namespace_to_string(namespace, NoSuchNamespaceError) - table_stmt = table_stmt.where(IcebergTables.table_namespace.like(namespace_str)) - namespace_stmt = namespace_stmt.where(IcebergNamespaceProperties.namespace.like(namespace_str)) - stmt = union( - table_stmt, - namespace_stmt, + namespace_str = Catalog.namespace_to_string(namespace, NoSuchNamespaceError) if namespace else '%' + table_stmt = select(IcebergTables.table_namespace).where( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace.like(namespace_str) ) + namespace_stmt = select(IcebergNamespaceProperties.namespace).where( + IcebergNamespaceProperties.catalog_name == self.name, + IcebergNamespaceProperties.namespace.like(namespace_str) + ) + + stmt = union(table_stmt, namespace_stmt) with Session(self.engine) as session: - return [Catalog.identifier_to_tuple(namespace_col) for namespace_col in session.execute(stmt).scalars()] + return [Catalog.identifier_to_tuple(namespace_col) for namespace_col in session.scalars(stmt)] def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: """Get properties for a namespace. @@ -640,10 +661,11 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper """ namespace_str = Catalog.namespace_to_string(namespace) if not self._namespace_exists(namespace): - raise NoSuchNamespaceError(f"Namespace {namespace_str} does not exists") + raise NoSuchNamespaceError(f"Namespace {namespace_str} does not exist") stmt = select(IcebergNamespaceProperties).where( - IcebergNamespaceProperties.catalog_name == self.name, IcebergNamespaceProperties.namespace == namespace_str + IcebergNamespaceProperties.catalog_name == self.name, + IcebergNamespaceProperties.namespace == namespace_str ) with Session(self.engine) as session: result = session.scalars(stmt) @@ -665,7 +687,7 @@ def update_namespace_properties( """ namespace_str = Catalog.namespace_to_string(namespace) if not self._namespace_exists(namespace): - raise NoSuchNamespaceError(f"Namespace {namespace_str} does not exists") + raise NoSuchNamespaceError(f"Namespace {namespace_str} does not exist") current_properties = self.load_namespace_properties(namespace=namespace) properties_update_summary = self._get_updated_props_and_update_summary( @@ -682,26 +704,23 @@ def update_namespace_properties( session.execute(delete_stmt) if updates: - # SQLAlchemy does not (yet) support engine agnostic UPSERT - # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements - # This is not a problem since it runs in a single transaction delete_stmt = delete(IcebergNamespaceProperties).where( IcebergNamespaceProperties.catalog_name == self.name, IcebergNamespaceProperties.namespace == namespace_str, - IcebergNamespaceProperties.property_key.in_(set(updates.keys())), + IcebergNamespaceProperties.property_key.in_(updates.keys()), ) session.execute(delete_stmt) + insert_stmt_values = [ { IcebergNamespaceProperties.catalog_name: self.name, IcebergNamespaceProperties.namespace: namespace_str, - IcebergNamespaceProperties.property_key: property_key, - IcebergNamespaceProperties.property_value: property_value, + IcebergNamespaceProperties.property_key: key, + IcebergNamespaceProperties.property_value: value, } - for property_key, property_value in updates.items() + for key, value in updates.items() ] - insert_stmt = insert(IcebergNamespaceProperties).values(insert_stmt_values) - session.execute(insert_stmt) + session.execute(insert(IcebergNamespaceProperties).values(insert_stmt_values)) session.commit() return properties_update_summary