Skip to content

Commit

Permalink
Extract dry_run function for individual queries
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubken committed Jan 11, 2022
1 parent c463ed0 commit 47cf570
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 23 deletions.
13 changes: 11 additions & 2 deletions src/dsdk/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,20 @@ def connect(self) -> Generator[Any, None, None]:

def dry_run(
self,
query_parameters,
parameters: Dict[str, Any],
exceptions=(DatabaseError, InterfaceError),
):
"""Dry run."""
super().dry_run(query_parameters, exceptions)
super().dry_run(parameters, exceptions)

def dry_run_query(
self,
query: str,
parameters: Dict[str, Any],
exceptions=(DatabaseError, InterfaceError),
):
"""Dry run."""
super().dry_run_query(query, parameters, exceptions)


class Mixin(BaseMixin):
Expand Down
48 changes: 29 additions & 19 deletions src/dsdk/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,35 +168,45 @@ def __init__(self, sql: Asset):

def dry_run(
self,
query_parameters: Dict[str, Any],
parameters: Dict[str, Any],
exceptions: Tuple = (),
):
"""Execute sql found in asse with dry_run parameter set to 1."""
logger.info(self.ON)
query_parameters = query_parameters.copy()
query_parameters["dry_run"] = 1
parameters = parameters.copy()
parameters["dry_run"] = 1
errors = []
for path, value in self.sql():
for path, query in self.sql():
logger.info(self.DRY_RUN, path)
with self.rollback() as cur:
rendered = self.render_without_keys(
cur,
value,
query_parameters,
)
with NamedTemporaryFile(
"w", delete=False, suffix=".sql"
) as fout:
fout.write(rendered)
try:
cur.execute(rendered)
except exceptions as e:
logger.warning(self.ERROR, path)
errors.append(e)
error = self.dry_run_query(query, parameters, exceptions)
if error is not None:
errors.append(error)
logger.warning(self.ERROR, path)
if bool(errors):
raise RuntimeError(self.ERRORS, errors)
logger.info(self.END)

def dry_run_query(
self,
query,
parameters,
exceptions: Tuple = (),
) -> Optional[Exception]:
"""Dry run query."""
with self.rollback() as cur:
rendered = self.render_without_keys(
cur,
query,
parameters,
)
with NamedTemporaryFile("w", delete=False, suffix=".sql") as fout:
fout.write(rendered)
try:
cur.execute(rendered)
except exceptions as error:
return error
return None

@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
Expand Down
13 changes: 11 additions & 2 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,20 @@ def __init__(

def dry_run(
self,
query_parameters,
parameters: Dict[str, Any],
exceptions=(DatabaseError, InterfaceError),
):
"""Dry run."""
super().dry_run(query_parameters, exceptions)
super().dry_run(parameters, exceptions)

def dry_run_query(
self,
query: str,
parameters: Dict[str, Any],
exceptions=(DatabaseError, InterfaceError),
):
"""Dry run query."""
super().dry_run_query(query, parameters, exceptions)

@contextmanager
def commit(self) -> Generator[Any, None, None]:
Expand Down

0 comments on commit 47cf570

Please sign in to comment.