Skip to content

Commit

Permalink
Merge pull request #63 from pennsignals/dry_run
Browse files Browse the repository at this point in the history
Dry run
  • Loading branch information
cjbayesian authored Dec 22, 2021
2 parents 87ca271 + 98cae71 commit f4fafcc
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 146 deletions.
Empty file added assets/mssql/.gitignore
Empty file.
3 changes: 0 additions & 3 deletions assets/mssql/extant.sql

This file was deleted.

3 changes: 0 additions & 3 deletions assets/postgres/extant.sql

This file was deleted.

1 change: 0 additions & 1 deletion local/notifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ postgres: !postgres
sql: !asset
ext: .sql
path: ./assets/postgres
tables: []
username: ${POSTGRES_USERNAME}
username: ${EPIC_USERNAME}
user_id: ${EPIC_USER_ID}
Expand Down
1 change: 0 additions & 1 deletion local/verifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ postgres: !postgres
sql: !asset
ext: .sql
path: ./assets/postgres
tables: []
username: ${POSTGRES_USERNAME}
username: ${EPIC_USERNAME}
user_id: ${EPIC_USER_ID}
Expand Down
4 changes: 0 additions & 4 deletions src/dsdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from .interval import Interval
from .model import Mixin as ModelMixin
from .model import Model
from .mssql import CheckTablePrivileges as CheckMssqlTablePrivileges
from .mssql import Mixin as MssqlMixin
from .mssql import Persistor as Mssql
from .postgres import CheckTablePrivileges as CheckPostgresTablePrivileges
from .postgres import Mixin as PostgresMixin
from .postgres import Persistor as Postgres
from .postgres import PredictionMixin as PostgresPredictionMixin
Expand All @@ -34,8 +32,6 @@
"ModelMixin",
"MssqlMixin",
"Mssql",
"CheckMssqlTablePrivileges",
"CheckPostgresTablePrivileges",
"PostgresPredictionMixin",
"PostgresMixin",
"Postgres",
Expand Down
50 changes: 10 additions & 40 deletions src/dsdk/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Any, Dict, Generator

from .persistor import Persistor as BasePersistor
from .service import Service, Task
from .service import Service
from .utils import StubError

logger = getLogger(__name__)
Expand Down Expand Up @@ -43,8 +43,6 @@ class Messages: # pylint: disable=too-few-public-methods
COLUMN_PRIVILEGE = dumps({"key": f"{KEY}.warn", "value": "%s"})
COMMIT = dumps({"key": f"{KEY}.commit"})
END = dumps({"key": f"{KEY}.end"})
ERROR = dumps({"key": f"{KEY}.table.error", "table": "%s"})
ERRORS = dumps({"key": f"{KEY}.tables.error", "tables": "%s"})
EXTANT = dumps({"key": f"{KEY}.sql.extant", "value": "%s"})
ON = dumps({"key": f"{KEY}.on"})
OPEN = dumps({"key": f"{KEY}.open"})
Expand All @@ -66,33 +64,6 @@ def mogrify(
"""Safely mogrify parameters into query or fragment."""
return _mssql.substitute_params(query, parameters)

def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
    """Check every table in ``self.tables`` with its extant statement.

    For each table the statement from ``self.extant(table)`` is executed
    on ``cur`` and the first column of every returned row is asserted to
    equal 1.

    Args:
        cur: cursor used to execute the statements and iterate the rows.
        exceptions: exception types that mark a table as failed instead
            of propagating.

    Raises:
        RuntimeError: with ``self.ERRORS`` and the list of failed tables
            when at least one table failed.
    """
    logger.info(self.ON)
    errors = []
    for table in self.tables:
        try:
            statement = self.extant(table)
            logger.info(self.EXTANT, table)
            logger.debug(self.EXTANT, statement)
            cur.execute(statement)
            for row in cur:
                # only the first column of each row is inspected
                n, *_ = row
                assert n == 1
                continue
        # pylint: disable=catching-non-exception
        except exceptions as error:
            number, *_ = error.args  # args are not wrapped
            # column privileges are a standards-breaking mssql mis-feature
            if number == 230:
                # error number 230 is logged at info level and the table
                # is not counted as a failure
                logger.info(self.COLUMN_PRIVILEGE, table)
                continue
            logger.warning(self.ERROR, table)
            errors.append(table)
    # failures are collected across all tables and reported together
    if bool(errors):
        raise RuntimeError(self.ERRORS, errors)
    logger.info(self.END)

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
Expand All @@ -112,6 +83,15 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

def dry_run(
    self,
    query_parameters,
    skip=(),
    exceptions=(DatabaseError, InterfaceError),
):
    """Dry run.

    Forwards ``query_parameters``, ``skip`` and ``exceptions``
    positionally to the parent class's ``dry_run``. This override only
    supplies the defaults: an empty ``skip`` and
    ``(DatabaseError, InterfaceError)`` for ``exceptions``.
    """
    super().dry_run(query_parameters, skip, exceptions)


class Mixin(BaseMixin):
"""Mixin."""
Expand All @@ -133,13 +113,3 @@ def as_yaml(self) -> Dict[str, Any]:
"mssql": self.mssql,
**super().as_yaml(),
}


class CheckTablePrivileges(Task):  # pylint: disable=too-few-public-methods
    """Task that runs the mssql persistor's table check."""

    def __call__(self, batch, service):
        """Run ``service.mssql.check`` on a cursor from ``rollback()``."""
        persistor = service.mssql
        with persistor.rollback() as cursor:
            persistor.check(cursor)
84 changes: 59 additions & 25 deletions src/dsdk/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from json import dumps
from logging import getLogger
from re import compile as re_compile
from string import Formatter
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Generator, Optional, Sequence, Tuple
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple

from cfgenvy import yaml_type
from pandas import DataFrame, concat
Expand All @@ -30,9 +31,8 @@ class AbstractPersistor:
CLOSE = dumps({"key": f"{KEY}.close"})
COMMIT = dumps({"key": f"{KEY}.commit"})
END = dumps({"key": f"{KEY}.end"})
ERROR = dumps({"key": f"{KEY}.table.error", "table": "%s"})
ERRORS = dumps({"key": f"{KEY}.tables.error", "tables": "%s"})
EXTANT = dumps({"key": f"{KEY}.sql.extant", "value": "%s"})
ERROR = dumps({"key": f"{KEY}.table.error", "query": "%s"})
ERRORS = dumps({"key": f"{KEY}.dry_run.error", "query": "%s"})
ON = dumps({"key": f"{KEY}.on"})
OPEN = dumps({"key": f"{KEY}.open"})
ROLLBACK = dumps({"key": f"{KEY}.rollback"})
Expand Down Expand Up @@ -130,6 +130,15 @@ def df_from_query_by_keys(
df.columns = columns
return df

@classmethod
def render_without_keys(cls, cur, query, parameters):
    """Render query without keys.

    Every ``str.format`` style replacement field (``{...}``) is dropped
    from ``query`` so that only its literal text remains, then
    ``parameters`` are mogrified into the result.

    Args:
        cur: cursor handed through to ``cls.mogrify``.
        query: query text that may contain replacement fields.
        parameters: parameters for ``cls.mogrify``; ``None`` is treated
            as an empty dict.

    Returns:
        The result of ``cls.mogrify`` decoded as utf-8.
    """
    if parameters is None:
        parameters = {}
    # Formatter.parse yields 4-tuples of
    # (literal_text, field_name, format_spec, conversion); keep only the
    # literal text. Unpacking each item into a one-element target
    # raises ValueError for any non-empty query.
    query = "".join(literal for literal, *_ in Formatter().parse(query))
    return cls.mogrify(cur, query, parameters).decode("utf-8")

@classmethod
def mogrify(
cls,
Expand All @@ -152,28 +161,55 @@ def union_all(
union = cls.mogrify(cur, union, parameters).decode("utf-8")
return union

def __init__(self, sql: Asset, tables: Tuple[str, ...]):
def __init__(self, sql: Asset):
"""__init__."""
self.sql = sql
self.tables = tables

def check(self, cur, exceptions):
"""Check."""
def on_dry_run(
    self,
    sql: Asset,
    query_parameters: Dict[str, Any],
    skip: Tuple,
    exceptions: Tuple,
) -> List[Exception]:
    """On dry run.

    Walk the attributes of ``sql``. Attributes that are themselves
    assets are walked recursively; every other attribute whose name is
    not in ``skip`` is rendered with ``render_without_keys`` and
    executed on a cursor from ``self.rollback()``.

    Args:
        sql: asset whose attributes are walked.
        query_parameters: parameters used to render each query.
        skip: attribute names that are not rendered or executed.
        exceptions: exception types that are caught, logged and
            collected instead of propagating.

    Returns:
        The exceptions caught while executing, including those from
        nested assets.
    """
    errors: List[Exception] = []
    for key, value in vars(sql).items():
        # isinstance rather than an exact class comparison, so that a
        # subclass of Asset is walked instead of rendered as a query.
        # Nested assets are walked even when their key is in skip.
        if isinstance(value, Asset):
            errors += self.on_dry_run(
                value, query_parameters, skip, exceptions
            )
            continue
        if key in skip:
            continue
        with self.rollback() as cur:
            rendered = self.render_without_keys(
                cur,
                value,
                query_parameters,
            )
            # The rendered statement is also written to a temporary
            # .sql file that is left on disk (delete=False).
            # NOTE(review): presumably kept for inspection -- confirm.
            with NamedTemporaryFile(
                "w", delete=False, suffix=".sql"
            ) as fout:
                fout.write(rendered)
            try:
                cur.execute(rendered)
            except exceptions as e:
                logger.warning(self.ERROR, key)
                errors.append(e)
    return errors

def dry_run(
self,
query_parameters: Dict[str, Any],
skip: Tuple = (),
exceptions: Tuple = (),
):
"""Execute sql found in asset with dry_run parameter set to 1."""
logger.info(self.ON)
errors = []
for table in self.tables:
try:
statement = self.extant(table)
logger.info(self.EXTANT, table)
logger.debug(self.EXTANT, statement)
cur.execute(statement)
for row in cur:
n, *_ = row
assert n == 1
continue
except exceptions:
logger.warning(self.ERROR, table)
errors.append(table)
query_parameters = query_parameters.copy()
query_parameters["dry_run"] = 1
errors = self.on_dry_run(self.sql, query_parameters, skip, exceptions)
if bool(errors):
raise RuntimeError(self.ERRORS, errors)
logger.info(self.END)
Expand Down Expand Up @@ -259,7 +295,6 @@ def __init__( # pylint: disable=too-many-arguments
password: str,
port: int,
sql: Asset,
tables: Tuple[str, ...],
username: str,
):
"""__init__."""
Expand All @@ -268,7 +303,7 @@ def __init__( # pylint: disable=too-many-arguments
self.password = password
self.port = port
self.username = username
super().__init__(sql, tables)
super().__init__(sql)

def as_yaml(self) -> Dict[str, Any]:
"""As yaml."""
Expand All @@ -278,7 +313,6 @@ def as_yaml(self) -> Dict[str, Any]:
"password": self.password,
"port": self.port,
"sql": self.sql,
"tables": self.tables,
"username": self.username,
}

Expand Down
23 changes: 9 additions & 14 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .interval import Interval
from .persistor import Persistor as BasePersistor
from .service import Delegate, Service, Task
from .service import Delegate, Service
from .utils import StubError, retry

logger = getLogger(__name__)
Expand Down Expand Up @@ -131,9 +131,14 @@ def mogrify(
"""Safely mogrify parameters into query or fragment."""
return cur.mogrify(query, parameters)

def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
"""Check."""
super().check(cur, exceptions)
def dry_run(
    self,
    query_parameters,
    skip=("schema",),
    exceptions=(DatabaseError, InterfaceError),
):
    """Dry run.

    Forwards ``query_parameters``, ``skip`` and ``exceptions``
    positionally to the parent class's ``dry_run``. This override only
    supplies the defaults: ``("schema",)`` for ``skip`` and
    ``(DatabaseError, InterfaceError)`` for ``exceptions``.
    """
    super().dry_run(query_parameters, skip, exceptions)

@contextmanager
def listen(self, *listens: str) -> Generator[Any, None, None]:
Expand Down Expand Up @@ -356,13 +361,3 @@ def open_batch(self) -> Generator[Run, None, None]:
with super().open_batch() as parent:
with self.postgres.open_run(parent) as run:
yield run


class CheckTablePrivileges(Task):  # pylint: disable=too-few-public-methods
    """Task that runs the postgres persistor's table check."""

    def __call__(self, batch, service):
        """Run ``service.postgres.check`` on a cursor from ``rollback()``."""
        persistor = service.postgres
        with persistor.rollback() as cursor:
            persistor.check(cursor)
3 changes: 2 additions & 1 deletion src/dsdk/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

from cfgenvy import Parser, yaml_type
from numpy import allclose
from pandas import DataFrame
from pkg_resources import DistributionNotFound, get_distribution

Expand Down Expand Up @@ -386,7 +387,7 @@ def on_validate_gold(self) -> Batch:

n_tests += 1
try:
assert (scores == test).all()
assert allclose(scores, test)
logger.info(self.MATCH, "pass")
n_passes += 1
except AssertionError:
Expand Down
1 change: 1 addition & 0 deletions test/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
model.pkl
Binary file removed test/model.pkl
Binary file not shown.
22 changes: 0 additions & 22 deletions test/test_dsdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def __init__(self, **kwargs):
sql: !asset
ext: .sql
path: ./assets/mssql
tables:
- a
- b
- c
username: mssql
model: !model ./test/model.pkl
postgres: !postgres
Expand All @@ -95,12 +91,6 @@ def __init__(self, **kwargs):
sql: !asset
ext: .sql
path: ./assets/postgres
tables:
- ichi
- ni
- san
- shi
- go
username: postgres
""".strip()

Expand All @@ -123,10 +113,6 @@ def __init__(self, **kwargs):
sql: !asset
ext: .sql
path: ./assets/mssql
tables:
- a
- b
- c
username: mssql
postgres: !postgres
database: test
Expand All @@ -136,12 +122,6 @@ def __init__(self, **kwargs):
sql: !asset
ext: .sql
path: ./assets/postgres
tables:
- ichi
- ni
- san
- shi
- go
username: postgres
time_zone: null
""".strip()
Expand All @@ -161,7 +141,6 @@ def build(
port=1433,
database="test",
sql=Asset.build(path="./assets/mssql", ext=".sql"),
tables=("a", "b", "c"),
)
postgres = Postgres(
username="postgres",
Expand All @@ -170,7 +149,6 @@ def build(
port=5432,
database="test",
sql=Asset.build(path="./assets/postgres", ext=".sql"),
tables=("ichi", "ni", "san", "shi", "go"),
)
return (
cls,
Expand Down
Loading

0 comments on commit f4fafcc

Please sign in to comment.