From c463ed0af1d37acc1d44ad41dbd9016ecb5f4e01 Mon Sep 17 00:00:00 2001 From: Jason Lubken Date: Mon, 10 Jan 2022 18:22:17 -0500 Subject: [PATCH] Add Profile class for profile contextmanager --- src/dsdk/postgres.py | 58 ++++++++++++++++++++++++++++---------------- src/dsdk/profile.py | 41 +++++++++++++++++++++++++++++++ src/dsdk/utils.py | 26 +++++++++++++------- 3 files changed, 95 insertions(+), 30 deletions(-) create mode 100644 src/dsdk/profile.py diff --git a/src/dsdk/postgres.py b/src/dsdk/postgres.py index 24f51a4..23172cd 100644 --- a/src/dsdk/postgres.py +++ b/src/dsdk/postgres.py @@ -149,6 +149,27 @@ def dry_run( """Dry run.""" super().dry_run(query_parameters, exceptions) + @contextmanager + def commit(self) -> Generator[Any, None, None]: + """Commit.""" + with super().commit() as cur: + cur.execute(f"set search_path={self.schema};") + yield cur + + @contextmanager + def connect(self) -> Generator[Any, None, None]: + """Connect.""" + # Replace return type with ContextManager[Any] when mypy is fixed. + # The `with ... as con:` formulation does not close the connection: + # https://www.psycopg.org/docs/usage.html#with-statement + con = self.retry_connect() + logger.info(self.OPEN) + try: + yield con + finally: + con.close() + logger.info(self.CLOSE) + @contextmanager def listen(self, *listens: str) -> Generator[Any, None, None]: """Listen.""" @@ -160,6 +181,7 @@ def listen(self, *listens: str) -> Generator[Any, None, None]: # users to pop the last notify con.notifies = deque(con.notifies) with con.cursor() as cur: + cur.execute(f"set search_path={self.schema};") for each in listens: logger.debug(self.LISTEN, each) cur.execute(each) @@ -168,20 +190,6 @@ def listen(self, *listens: str) -> Generator[Any, None, None]: con.close() logger.info(self.CLOSE) - @contextmanager - def connect(self) -> Generator[Any, None, None]: - """Connect.""" - # Replace return type with ContextManager[Any] when mypy is fixed. - # The `with ... 
as con:` formulation does not close the connection: - # https://www.psycopg.org/docs/usage.html#with-statement - con = self.retry_connect() - logger.info(self.OPEN) - try: - yield con - finally: - con.close() - logger.info(self.CLOSE) - @contextmanager def open_run(self, parent: Any) -> Generator[Run, None, None]: """Open batch.""" @@ -189,7 +197,6 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]: sql = self.sql columns = parent.as_insert_sql() with self.commit() as cur: - cur.execute(f"set search_path={self.schema}") cur.execute(sql.runs.open, columns) for row in cur: ( @@ -217,7 +224,6 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]: yield run with self.commit() as cur: - cur.execute(f"set search_path={self.schema}") predictions = run.predictions if predictions is not None: # pylint: disable=unsupported-assignment-operation @@ -244,11 +250,17 @@ def retry_connect(self): dbname=self.database, ) + @contextmanager + def rollback(self) -> Generator[Any, None, None]: + """Rollback.""" + with super().rollback() as cur: + cur.execute(f"set search_path={self.schema};") + yield cur + def scores(self, run_id) -> Series: """Return scores series.""" sql = self.sql with self.rollback() as cur: - cur.execute(f"set search_path={self.schema}") return self.df_from_query( cur, sql.predictions.gold, @@ -289,18 +301,18 @@ def _store_df( out = df.to_dict("records") try: with self.commit() as cur: - cur.execute(f"set search_path={self.schema}") execute_batch( cur, insert, out, ) except DatabaseError as e: + # figure out all rows which failed, + # rolling back any successful insertions + # enumeration is a generator enumeration = enumerate(out) while True: with self.rollback() as cur: - # enumeration is a generator - # it will pick up where it left off for i, row in enumeration: try: cur.execute(insert, row) @@ -308,9 +320,13 @@ def _store_df( # assumes the client encoding is the default utf-8! 
value = dumps(cur.mogrify(insert, row).decode()) logger.error(self.DATA_TYPE_ERROR, i, value) - break + break # DatabaseError: break for loop else: + # StopIteration: enumeration is exhausted + # break while loop break + # DatabaseError: rollback and continue while loop + # enumeration will pick up where it left off raise e diff --git a/src/dsdk/profile.py b/src/dsdk/profile.py new file mode 100644 index 0000000..f113fd3 --- /dev/null +++ b/src/dsdk/profile.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +"""Profile.""" + +from typing import Any, Dict, Optional + +from cfgenvy import yaml_type + + +class Profile: + """Profile.""" + + YAML = "!profile" + + @classmethod + def as_yaml_type(cls, tag: Optional[str] = None): + """As yaml type.""" + yaml_type( + cls, + tag or cls.YAML, + init=cls._yaml_init, + repr=cls._yaml_repr, + ) + + @classmethod + def _yaml_init(cls, loader, node): + """Yaml init.""" + return cls(**loader.construct_mapping(node, deep=True)) + + @classmethod + def _yaml_repr(cls, dumper, self, *, tag: str): + """Yaml repr.""" + return dumper.represent_mapping(tag, self.as_yaml()) + + def __init__(self, on: int, end: Optional[int] = None): + """__init__.""" + self.on = on + self.end = end + + def as_yaml(self) -> Dict[str, Any]: + """As yaml.""" + return {"end": self.end, "on": self.on} diff --git a/src/dsdk/utils.py b/src/dsdk/utils.py index 388c056..c0d8ca5 100644 --- a/src/dsdk/utils.py +++ b/src/dsdk/utils.py @@ -18,6 +18,9 @@ from dateutil import parser, tz +from .profile import Profile + + logger = getLogger(__name__) @@ -108,16 +111,21 @@ def now_utc_datetime() -> datetime: @contextmanager -def profile(key: str) -> Generator[Any, None, None]: +def profile(key: str) -> Generator[Profile, None, None]: """Profile.""" - # Replace return type with ContextManager[Any] when mypy is fixed. 
- begin = perf_counter_ns() - logger.info('{"key": "%s.begin", "ns": "%s"}', key, begin) - yield - end = perf_counter_ns() - logger.info( - '{"key": "%s.end", "ns": "%s", "elapsed": "%s"}', key, end, end - begin - ) + # Replace return type with ContextManager[Profile] when mypy is fixed. + i = Profile(perf_counter_ns()) + logger.info('{"key": "%s.on", "ns": "%s"}', key, i.on) + try: + yield i + finally: + i.end = perf_counter_ns() + logger.info( + '{"key": "%s.end", "ns": "%s", "elapsed": "%s"}', + key, + i.end, + i.end - i.on, + ) def retry(