Skip to content

Commit

Permalink
Add Profile class for profile contextmanager
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubken committed Jan 10, 2022
1 parent ffec384 commit c463ed0
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 30 deletions.
58 changes: 37 additions & 21 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ def dry_run(
"""Dry run."""
super().dry_run(query_parameters, exceptions)

@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
with super().commit() as cur:
cur.execute(f"set search_path={self.schema};")
yield cur

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
# The `with ... as con:` formulation does not close the connection:
# https://www.psycopg.org/docs/usage.html#with-statement
con = self.retry_connect()
logger.info(self.OPEN)
try:
yield con
finally:
con.close()
logger.info(self.CLOSE)

@contextmanager
def listen(self, *listens: str) -> Generator[Any, None, None]:
"""Listen."""
Expand All @@ -160,6 +181,7 @@ def listen(self, *listens: str) -> Generator[Any, None, None]:
# users to pop the last notify
con.notifies = deque(con.notifies)
with con.cursor() as cur:
cur.execute(f"set search_path={self.schema};")
for each in listens:
logger.debug(self.LISTEN, each)
cur.execute(each)
Expand All @@ -168,28 +190,13 @@ def listen(self, *listens: str) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
# The `with ... as con:` formulation does not close the connection:
# https://www.psycopg.org/docs/usage.html#with-statement
con = self.retry_connect()
logger.info(self.OPEN)
try:
yield con
finally:
con.close()
logger.info(self.CLOSE)

@contextmanager
def open_run(self, parent: Any) -> Generator[Run, None, None]:
"""Open batch."""
# Replace return type with ContextManager[Run] when mypy is fixed.
sql = self.sql
columns = parent.as_insert_sql()
with self.commit() as cur:
cur.execute(f"set search_path={self.schema}")
cur.execute(sql.runs.open, columns)
for row in cur:
(
Expand Down Expand Up @@ -217,7 +224,6 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]:
yield run

with self.commit() as cur:
cur.execute(f"set search_path={self.schema}")
predictions = run.predictions
if predictions is not None:
# pylint: disable=unsupported-assignment-operation
Expand All @@ -244,11 +250,17 @@ def retry_connect(self):
dbname=self.database,
)

@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
with super().rollback() as cur:
cur.execute(f"set search_path={self.schema};")
yield cur

def scores(self, run_id) -> Series:
"""Return scores series."""
sql = self.sql
with self.rollback() as cur:
cur.execute(f"set search_path={self.schema}")
return self.df_from_query(
cur,
sql.predictions.gold,
Expand Down Expand Up @@ -289,28 +301,32 @@ def _store_df(
out = df.to_dict("records")
try:
with self.commit() as cur:
cur.execute(f"set search_path={self.schema}")
execute_batch(
cur,
insert,
out,
)
except DatabaseError as e:
# figure out all rows which failed,
# rolling back any successful insertions
# enumeration is a generator
enumeration = enumerate(out)
while True:
with self.rollback() as cur:
# enumeration is a generator
# it will pick up where it left off
for i, row in enumeration:
try:
cur.execute(insert, row)
except DatabaseError:
# assumes the client encoding is the default utf-8!
value = dumps(cur.mogrify(insert, row).decode())
logger.error(self.DATA_TYPE_ERROR, i, value)
break
break # DatabaseError: break for loop
else:
# GeneratorExit: enumeration is exhausted
# break while loop
break
# DatabaseError: rollback and continue while loop
# enumeration will pick up where it left off
raise e


Expand Down
41 changes: 41 additions & 0 deletions src/dsdk/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
"""Profile."""

from typing import Any, Dict, Optional

from cfgenvy import yaml_type


class Profile:
"""Profile."""

YAML = "!profile"

@classmethod
def as_yaml_type(cls, tag: Optional[str] = None):
"""As yaml type."""
yaml_type(
cls,
tag or cls.YAML,
init=cls._yaml_init,
repr=cls._yaml_repr,
)

@classmethod
def _yaml_init(cls, loader, node):
"""Yaml init."""
return cls(**loader.construct_mapping(node, deep=True))

@classmethod
def _yaml_repr(cls, dumper, self, *, tag: str):
"""Yaml repr."""
return dumper.represent_mapping(tag, self.as_yaml())

def __init__(self, on: int, end: Optional[int] = None):
"""__init__."""
self.on = on
self.end = end

def as_yaml(self) -> Dict[str, Any]:
"""As yaml."""
return {"end": self.end, "on": self.on}
26 changes: 17 additions & 9 deletions src/dsdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

from dateutil import parser, tz

from .profile import Profile


logger = getLogger(__name__)


Expand Down Expand Up @@ -108,16 +111,21 @@ def now_utc_datetime() -> datetime:


@contextmanager
def profile(key: str) -> Generator[Any, None, None]:
def profile(key: str) -> Generator[Profile, None, None]:
"""Profile."""
# Replace return type with ContextManager[Any] when mypy is fixed.
begin = perf_counter_ns()
logger.info('{"key": "%s.begin", "ns": "%s"}', key, begin)
yield
end = perf_counter_ns()
logger.info(
'{"key": "%s.end", "ns": "%s", "elapsed": "%s"}', key, end, end - begin
)
# Replace return type with ContextManager[Profile] when mypy is fixed.
i = Profile(perf_counter_ns())
logger.info('{"key": "%s.on", "ns": "%s"}', key, i.on)
try:
yield i
finally:
i.end = perf_counter_ns()
logger.info(
'{"key": "%s.end", "ns": "%s", "elapsed": "%s"}',
key,
i.end,
i.end - i.on,
)


def retry(
Expand Down

0 comments on commit c463ed0

Please sign in to comment.