Skip to content

Commit

Permalink
Add check
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubken committed Jun 2, 2020
1 parent f15da56 commit c4c4fc2
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion src/dsdk/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@

from configargparse import ArgParser as ArgumentParser

from .service import Service
from .service import Service, Task
from .utils import get_logger

logger = get_logger(__name__, INFO)

try:
# Since not everyone will use mssql
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError, InterfaceError
except ImportError:
create_engine = None
DatabaseError = InterfaceError = Exception


if TYPE_CHECKING:
Expand Down Expand Up @@ -67,3 +69,70 @@ def open_mssql(self) -> Generator:
with self._mssql.connect() as con:
yield con
logger.info('"action": "connect"')


class TablePriviledgeCheck(Task): # pylint: disable=too-few-public-methods
"""Table priviledge check."""

CONNECT = """
select 1 as n
"""

EXTANT = """
select 1 as n where exists (select 1 as n from {table})
"""

KEY = "table_priviledge_check"

ON = "".join(("{", f'"key": "{KEY}.on"', "}"))

END = "".join(("{", f'"key": "{KEY}.end"', "}"))

COLUMN_PRIVILEDGE = "".join(
(
"{",
", ".join(
(f'"key": "{KEY}.column_priviledge_warning"', '"value": "%s"')
),
"}",
)
)

FAILED = "".join(
("{", ", ".join((f'"key": "{KEY}.failed"', '"value": "%s"')), "}")
)

FAILURES = "".join(
("{", ", ".join((f'"key": "{KEY}.failures"', '"value": "%s"')), "}")
)

def __init__(self, tables):
"""__init__."""
self.tables = tables

def __call__(self, batch, service):
"""__call__."""
logger.info(self.ON)
with service.open_mssql() as con:
# force lazy connection open.
cur = con.execute(self.CONNECT)
for _ in cur.fetchall():
pass
failures = []
for table in self.tables:
sql = self.EXTANT.format(table=table)
try:
cur = con.execute(sql)
for _ in cur.fetchall():
pass
except (DatabaseError, InterfaceError) as error:
number, *_ = error.orig.args
# column privileges are a non-standard breaking "feature"
if number == 230:
logger.info(self.COLUMN_PRIVILEDGE, table)
continue
logger.warning(self.FAILED, table)
failures.append(table)
if bool(failures):
raise RuntimeError(self.FAILURES, failures)
logger.info(self.END)

0 comments on commit c4c4fc2

Please sign in to comment.