diff --git a/src/dsdk/mssql.py b/src/dsdk/mssql.py index c8c41d6..3b9d4cb 100644 --- a/src/dsdk/mssql.py +++ b/src/dsdk/mssql.py @@ -95,11 +95,20 @@ def connect(self) -> Generator[Any, None, None]: def dry_run( self, - query_parameters, + parameters: Dict[str, Any], exceptions=(DatabaseError, InterfaceError), ): """Dry run.""" - super().dry_run(query_parameters, exceptions) + super().dry_run(parameters, exceptions) + + def dry_run_query( + self, + query: str, + parameters: Dict[str, Any], + exceptions=(DatabaseError, InterfaceError), + ): + """Dry run.""" + super().dry_run_query(query, parameters, exceptions) class Mixin(BaseMixin): diff --git a/src/dsdk/persistor.py b/src/dsdk/persistor.py index c9c349e..a892d27 100644 --- a/src/dsdk/persistor.py +++ b/src/dsdk/persistor.py @@ -168,35 +168,45 @@ def __init__(self, sql: Asset): def dry_run( self, - query_parameters: Dict[str, Any], + parameters: Dict[str, Any], exceptions: Tuple = (), ): """Execute sql found in asse with dry_run parameter set to 1.""" logger.info(self.ON) - query_parameters = query_parameters.copy() - query_parameters["dry_run"] = 1 + parameters = parameters.copy() + parameters["dry_run"] = 1 errors = [] - for path, value in self.sql(): + for path, query in self.sql(): logger.info(self.DRY_RUN, path) - with self.rollback() as cur: - rendered = self.render_without_keys( - cur, - value, - query_parameters, - ) - with NamedTemporaryFile( - "w", delete=False, suffix=".sql" - ) as fout: - fout.write(rendered) - try: - cur.execute(rendered) - except exceptions as e: - logger.warning(self.ERROR, path) - errors.append(e) + error = self.dry_run_query(query, parameters, exceptions) + if error is not None: + errors.append(error) + logger.warning(self.ERROR, path) if bool(errors): raise RuntimeError(self.ERRORS, errors) logger.info(self.END) + def dry_run_query( + self, + query, + parameters, + exceptions: Tuple = (), + ) -> Optional[Exception]: + """Dry run query.""" + with self.rollback() as cur: + rendered = self.render_without_keys( + cur, + query, + parameters, + ) + with NamedTemporaryFile("w", delete=False, suffix=".sql") as fout: + fout.write(rendered) + try: + cur.execute(rendered) + except exceptions as error: + return error + return None + @contextmanager def commit(self) -> Generator[Any, None, None]: """Commit.""" diff --git a/src/dsdk/postgres.py b/src/dsdk/postgres.py index 23172cd..df75202 100644 --- a/src/dsdk/postgres.py +++ b/src/dsdk/postgres.py @@ -143,11 +143,20 @@ def __init__( def dry_run( self, - query_parameters, + parameters: Dict[str, Any], exceptions=(DatabaseError, InterfaceError), ): """Dry run.""" - super().dry_run(query_parameters, exceptions) + super().dry_run(parameters, exceptions) + + def dry_run_query( + self, + query: str, + parameters: Dict[str, Any], + exceptions=(DatabaseError, InterfaceError), + ): + """Dry run query.""" + super().dry_run_query(query, parameters, exceptions) @contextmanager def commit(self) -> Generator[Any, None, None]: