From e9ddc1e2b0d92138f11de12d6a44333a55a7c1e0 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Sat, 4 Jan 2025 08:45:38 +0900 Subject: [PATCH] Add `migrate_data` (#1029) --- pyproject.toml | 4 +- requirements.txt | 2 +- requirements/cvxpy.txt | 2 +- requirements/scipy.txt | 2 +- src/tests/test_sqlalchemy.py | 42 ++++++++++++++ src/utilities/__init__.py | 2 +- src/utilities/sqlalchemy.py | 105 +++++++++++++++++++++++++++++++++++ 7 files changed, 153 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 57dda2600..c2ef7f467 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ dev = [ "python-dotenv >= 1.0.1, < 1.1", "redis >= 5.2.1, < 5.3", "rich >= 13.8.1, < 13.9", # if 13.9, twine upload fails https://github.com/dycw/python-utilities/actions/runs/11125686648/job/30913966455 - "scipy >= 1.14.1, < 1.15", + "scipy >= 1.15.0, < 1.16", "slack-sdk >= 3.34.0, < 3.35", "sqlalchemy >= 2.0.36, < 2.1", "streamlit >= 1.41.1, < 1.42", @@ -256,7 +256,7 @@ zzz-test-redis = [ "whenever >= 0.6.16, < 0.7", ] zzz-test-rich = ["rich >= 13.8.1, < 13.9"] -zzz-test-scipy = ["scipy >= 1.14.1, < 1.15"] +zzz-test-scipy = ["scipy >= 1.15.0, < 1.16"] zzz-test-sentinel = [] zzz-test-slack-sdk = [ "aiohttp >= 3.11.7, < 3.12", # for slack-sdk diff --git a/requirements.txt b/requirements.txt index bd837c784..3ef8b52a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -363,7 +363,7 @@ rpds-py==0.22.3 # via # jsonschema # referencing -scipy==1.14.1 +scipy==1.15.0 # via # dycw-utilities (pyproject.toml) # clarabel diff --git a/requirements/cvxpy.txt b/requirements/cvxpy.txt index 795660415..eef65d24e 100644 --- a/requirements/cvxpy.txt +++ b/requirements/cvxpy.txt @@ -61,7 +61,7 @@ pytest-xdist==3.6.1 # via dycw-utilities (pyproject.toml) qdldl==0.1.7.post5 # via osqp -scipy==1.14.1 +scipy==1.15.0 # via # clarabel # cvxpy diff --git a/requirements/scipy.txt b/requirements/scipy.txt index 7506e557e..bc072ed55 100644 --- a/requirements/scipy.txt +++ b/requirements/scipy.txt @@ -47,7 +47,7 @@ pytest-rerunfailures==15.0 # via dycw-utilities (pyproject.toml) pytest-xdist==3.6.1 # via dycw-utilities (pyproject.toml) -scipy==1.14.1 +scipy==1.15.0 # via dycw-utilities (pyproject.toml) sortedcontainers==2.4.0 # via hypothesis diff --git a/src/tests/test_sqlalchemy.py b/src/tests/test_sqlalchemy.py index 3ce48f280..787ce20e4 100644 --- a/src/tests/test_sqlalchemy.py +++ b/src/tests/test_sqlalchemy.py @@ -81,6 +81,7 @@ insert_items, is_orm, is_table_or_orm, + migrate_data, selectable_to_string, upsert_items, yield_primary_key_columns, @@ -839,6 +840,47 @@ def test_error_snake_non_unique_error(self, *, id_: int, value: bool) -> None: _ = _map_mapping_to_table(mapping, table, snake=True) +class TestMigrateData: + @FLAKY + @given( + data=data(), + names=sets(_table_names(), min_size=2, max_size=2), + values=lists( + tuples(integers(0, 10), booleans() | none()), + min_size=1, + unique_by=lambda x: x[0], + ), + ) + @settings_with_reduced_examples(phases={Phase.generate}) + async def test_main( + self, *, data: DataObject, names: set[str], values: list[tuple[int, bool]] + ) -> None: + engine1 = await sqlalchemy_engines(data) + name1, name2 = names + table1 = self._make_table(name1) + await insert_items( + engine1, [({"id_": id_, "value": v}, table1) for id_, v in values] + ) + async with engine1.begin() as conn: + result1 = (await conn.execute(select(table1))).all() + assert len(result1) == len(values) + + engine2 = await sqlalchemy_engines(data) + table2 = self._make_table(name2) + await migrate_data(table1, engine1, engine2, table_or_orm_to=table2) + async with engine2.begin() as conn: + result2 = (await conn.execute(select(table2))).all() + assert len(result2) == len(values) + + def _make_table(self, name: str, /) -> Table: + return Table( + name, + MetaData(), + Column("id_", Integer, primary_key=True), + Column("value", Boolean, nullable=True), + ) + + class TestNormalizeInsertItem: @given(case=sampled_from(["tuple", "dict"]), id_=integers(0, 10)) def test_pair_of_tuple_or_str_mapping_and_table( diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index 82b41bec2..83a4a1f83 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.88.10" +__version__ = "0.88.11" diff --git a/src/utilities/sqlalchemy.py b/src/utilities/sqlalchemy.py index d9cf3fac0..e12217b68 100644 --- a/src/utilities/sqlalchemy.py +++ b/src/utilities/sqlalchemy.py @@ -35,6 +35,7 @@ and_, case, insert, + select, text, ) from sqlalchemy.dialects.mssql import dialect as mssql_dialect @@ -101,6 +102,9 @@ CHUNK_SIZE_FRAC = 0.95 +## + + async def check_engine( engine: AsyncEngine, /, @@ -162,6 +166,9 @@ def __str__(self) -> str: return f"{get_repr(self.engine)} must have {self.expected} table(s); got {len(self.rows)}" +## + + def columnwise_max(*columns: Any) -> Any: """Compute the columnwise max of a number of columns.""" return _columnwise_minmax(*columns, op=ge) @@ -199,6 +206,9 @@ def func(x: Any, y: Any, /) -> Any: return reduce(func, columns) +## + + def create_async_engine( drivername: str, /, @@ -232,6 +242,9 @@ def func(x: MaybeIterable[str], /) -> list[str] | str: return _create_async_engine(url, poolclass=poolclass) +## + + async def ensure_tables_created( engine: AsyncEngine, /, @@ -301,6 +314,9 @@ async def ensure_tables_dropped( _ensure_tables_maybe_reraise(error, match) +## + + def get_chunk_size( engine_or_conn: _EngineOrConnectionOrAsync, /, @@ -313,16 +329,25 @@ def get_chunk_size( return max(floor(chunk_size_frac * max_params / scaling), 1) +## + + def get_column_names(table_or_orm: TableOrORMInstOrClass, /) -> list[str]: """Get the column names from a table or ORM instance/class.""" return [col.name for col in get_columns(table_or_orm)] +## + + def get_columns(table_or_orm: TableOrORMInstOrClass, /) -> list[Column[Any]]: """Get the columns from a table or ORM instance/class.""" return list(get_table(table_or_orm).columns) +## + + def get_table(table_or_orm: TableOrORMInstOrClass, /) -> Table: """Get the table from a Table or mapped class.""" if isinstance(table_or_orm, Table): @@ -341,17 +366,26 @@ def __str__(self) -> str: return f"Object {self.obj} must be a Table or mapped class; got {get_class_name(self.obj)!r}" +## + + def get_table_name(table_or_orm: TableOrORMInstOrClass, /) -> str: """Get the table name from a Table or mapped class.""" return get_table(table_or_orm).name +## + + def hash_primary_key_columns(orm: DeclarativeBase, /) -> int: """Compute a hash of the primary key columns.""" values = tuple(getattr(orm, c.name) for c in yield_primary_key_columns(orm)) return hash(values) +## + + _PairOfTupleAndTable: TypeAlias = tuple[tuple[Any, ...], TableOrORMInstOrClass] _PairOfStrMappingAndTable: TypeAlias = tuple[StrMapping, TableOrORMInstOrClass] _PairOfTupleOrStrMappingAndTable: TypeAlias = tuple[ @@ -450,6 +484,9 @@ def __str__(self) -> str: return f"Item must be valid; got {self.item}" +## + + def is_orm(obj: Any, /) -> TypeGuard[ORMInstOrClass]: """Check if an object is an ORM instance/class.""" if isinstance(obj, type): @@ -461,11 +498,39 @@ def is_orm(obj: Any, /) -> TypeGuard[ORMInstOrClass]: return is_orm(type(obj)) +## + + def is_table_or_orm(obj: Any, /) -> TypeGuard[TableOrORMInstOrClass]: """Check if an object is a Table or an ORM instance/class.""" return isinstance(obj, Table) or is_orm(obj) +## + + +async def migrate_data( + table_or_orm_from: TableOrORMInstOrClass, + engine_from: AsyncEngine, + engine_to: AsyncEngine, + /, + *, + table_or_orm_to: TableOrORMInstOrClass | None = None, +) -> None: + """Migrate the contents of a table from one database to another.""" + table_from = get_table(table_or_orm_from) + async with engine_from.begin() as conn: + rows = (await conn.execute(select(table_from))).all() + table_to = table_from if table_or_orm_to is None else get_table(table_or_orm_to) + await ensure_tables_created(engine_to, table_to) + mappings = [dict(r._mapping) for r in rows] # noqa: SLF001 + async with engine_to.begin() as conn: + _ = await conn.execute(insert(table_to).values(mappings)) + + +## + + def _normalize_insert_item( item: _InsertItem, /, *, snake: bool = False ) -> list[_NormalizedItem]: @@ -537,6 +602,9 @@ def _normalize_upsert_item( assert_never(never) +## + + def selectable_to_string( selectable: Selectable[Any], engine_or_conn: _EngineOrConnectionOrAsync, / ) -> str: @@ -547,6 +615,9 @@ def selectable_to_string( return str(com) +## + + class TablenameMixin: """Mix-in for an auto-generated tablename.""" @@ -557,6 +628,9 @@ def __tablename__(cls) -> str: # noqa: N805 return snake_case(get_class_name(cls)) +## + + @dataclass(kw_only=True, slots=True) class Upserter: """Upsert a set of items into a database.""" @@ -633,6 +707,9 @@ async def _run(self, *items: _InsertItem) -> None: await self._post_upsert(items) +## + + _SelectedOrAll: TypeAlias = Literal["selected", "all"] @@ -766,6 +843,9 @@ def __str__(self) -> str: return f"Item must be valid; got {self.item}" +## + + def yield_primary_key_columns( obj: TableOrORMInstOrClass, /, @@ -779,12 +859,18 @@ def yield_primary_key_columns( yield column +## + + def _ensure_tables_maybe_reraise(error: DatabaseError, match: str, /) -> None: """Re-raise the error if it does not match the required statement.""" if not search(match, ensure_str(one(error.args))): raise error # pragma: no cover +## + + def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect: """Get the dialect of a database.""" dialect = engine_or_conn.dialect @@ -804,6 +890,9 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect: raise NotImplementedError(msg) # pragma: no cover +## + + def _get_dialect_max_params( dialect_or_engine_or_conn: Dialect | _EngineOrConnectionOrAsync, / ) -> int: @@ -831,6 +920,9 @@ def _get_dialect_max_params( assert_never(never) +## + + def _is_pair_of_sequence_of_tuple_or_string_mapping_and_table( obj: Any, / ) -> TypeGuard[_PairOfSequenceOfTupleOrStrMappingAndTable]: @@ -869,6 +961,9 @@ def _is_pair_with_predicate_and_table( ) +## + + def _map_mapping_to_table( mapping: StrMapping, table_or_orm: TableOrORMInstOrClass, /, *, snake: bool = False ) -> StrMapping: @@ -941,6 +1036,9 @@ def __str__(self) -> str: return f"Mapping {get_repr(self.mapping)} must be a subset of table columns {get_repr(self.columns)}; found columns {self.first!r}, {self.second!r} and perhaps more to map to {self.key!r} modulo snake casing" +## + + def _orm_inst_to_dict(obj: DeclarativeBase, /) -> StrMapping: """Map an ORM instance to a dictionary.""" cls = type(obj) @@ -960,6 +1058,9 @@ def yield_items() -> Iterator[tuple[str, Any]]: return dict(yield_items()) +## + + @dataclass(kw_only=True, slots=True) class _PrepareInsertOrUpsertItems: mapping: dict[Table, list[StrMapping]] = field(default_factory=dict) @@ -1038,6 +1139,9 @@ def _prepare_insert_or_upsert_items_merge_items( ] + unchanged +## + + def _tuple_to_mapping( values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, / ) -> dict[str, Any]: @@ -1068,6 +1172,7 @@ def _tuple_to_mapping( "insert_items", "is_orm", "is_table_or_orm", + "migrate_data", "selectable_to_string", "upsert_items", "yield_primary_key_columns",