diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f88408c8..70abceab9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,5 +61,6 @@ repos: tests/curators/conftest.py| tests/permissions/conftest.py| tests/writelog/conftest.py| + tests/writelog_sqlite/conftest.py| tests/curators/test_curators_examples.py ) diff --git a/lamindb/core/writelog/_constants.py b/lamindb/core/writelog/_constants.py new file mode 100644 index 000000000..8a452ef2b --- /dev/null +++ b/lamindb/core/writelog/_constants.py @@ -0,0 +1 @@ +FOREIGN_KEYS_LIST_COLUMN_NAME = "_lamin_fks" diff --git a/lamindb/core/writelog/_db_metadata_wrapper.py b/lamindb/core/writelog/_db_metadata_wrapper.py index b20034efc..0e22ac021 100644 --- a/lamindb/core/writelog/_db_metadata_wrapper.py +++ b/lamindb/core/writelog/_db_metadata_wrapper.py @@ -6,7 +6,7 @@ from django.db.models import ManyToManyField from typing_extensions import override -from ._types import KeyConstraint, TableUID, UIDColumns +from ._types import Column, ColumnType, KeyConstraint, TableUID, UIDColumns class DatabaseMetadataWrapper(ABC): @@ -20,9 +20,12 @@ class DatabaseMetadataWrapper(ABC): """ @abstractmethod - def get_column_names(self, table: str, cursor: CursorWrapper) -> set[str]: + def get_columns(self, table: str, cursor: CursorWrapper) -> set[Column]: raise NotImplementedError() + def get_column_names(self, table: str, cursor: CursorWrapper) -> set[str]: + return {c.name for c in self.get_columns(table, cursor)} + @abstractmethod def get_tables_with_installed_triggers(self, cursor: CursorWrapper) -> set[str]: raise NotImplementedError() @@ -67,6 +70,25 @@ def get_many_to_many_db_tables(self) -> set[str]: return many_to_many_tables + def _get_columns_by_name( + self, table: str, column_names: list[str], cursor: CursorWrapper + ) -> list[Column]: + columns = self.get_columns(table=table, cursor=cursor) + + column_list: list[Column] = [] + + for column_name in column_names: + column = next((c for c in columns if c.name == column_name), None) + + if column is None: + raise ValueError( + f"Table '{table}' doesn't have a column named '{column_name}'" + ) + + column_list.append(column) + + return column_list + def get_uid_columns(self, table: str, cursor: CursorWrapper) -> UIDColumns: """Get the UID columns for a given table.""" if table == "lamindb_featurevalue": @@ -74,7 +96,19 @@ def get_uid_columns(self, table: str, cursor: CursorWrapper) -> UIDColumns: return [ TableUID( source_table_name=table, - uid_columns=["value", "created_at"], + uid_columns=self._get_columns_by_name( + table, ["value", "created_at"], cursor + ), + key_constraint=None, + ) + ] + elif table == "lamindb_param": + return [ + TableUID( + source_table_name=table, + uid_columns=self._get_columns_by_name( + table, ["name", "dtype", "created_at"], cursor + ), key_constraint=None, ) ] @@ -86,7 +120,7 @@ def get_uid_columns(self, table: str, cursor: CursorWrapper) -> UIDColumns: return [ TableUID( source_table_name=table, - uid_columns=["uid"], + uid_columns=self._get_columns_by_name(table, ["uid"], cursor), key_constraint=None, ) ] @@ -128,6 +162,11 @@ def get_uid_columns(self, table: str, cursor: CursorWrapper) -> UIDColumns: class PostgresDatabaseMetadataWrapper(DatabaseMetadataWrapper): + def __init__(self) -> None: + super().__init__() + + self._columns: dict[str, set[Column]] | None = None + @override def get_table_key_constraints( self, table: str, cursor: CursorWrapper @@ -141,7 +180,11 @@ def get_table_key_constraints( WHEN 'f' THEN 'FOREIGN KEY' END 
AS constraint_type,
                a.attname AS source_column,
+               a.attnum AS source_column_position,
+               pg_catalog.format_type(a.atttypid, a.atttypmod) AS source_column_type,
                CASE WHEN tc.contype = 'f' THEN af.attname ELSE NULL END AS target_column,
+               CASE WHEN tc.contype = 'f' THEN af.attnum ELSE NULL END AS target_column_position,
+               CASE WHEN tc.contype = 'f' THEN pg_catalog.format_type(af.atttypid, af.atttypmod) ELSE NULL END AS target_column_type,
                CASE WHEN tc.contype = 'f' THEN tf.relname ELSE NULL END AS target_table
            FROM pg_constraint tc
@@ -171,12 +214,27 @@
             (
                 constraint_name,
                 constraint_type,
-                source_column,
-                target_column,
+                source_column_name,
+                source_column_position,
+                source_column_type,
+                target_column_name,
+                target_column_position,
+                target_column_type,
                 target_table,
             ) = k
 
+            source_column = Column(
+                name=source_column_name,
+                type=self._get_column_type(source_column_type),
+                ordinal_position=source_column_position,
+            )
+
             if constraint_type == "PRIMARY KEY":
+                if target_table is not None or target_column_name is not None:
+                    raise ValueError(
+                        "Expected a primary key constraint's target table/column to be NULL"
+                    )
+
                 if primary_key_constraint is None:
                     primary_key_constraint = KeyConstraint(
                         constraint_name=constraint_name,
@@ -187,8 +245,14 @@
                 )
 
                 primary_key_constraint.source_columns.append(source_column)
-                primary_key_constraint.target_columns.append(target_column)
+
             elif constraint_type == "FOREIGN KEY":
+                target_column = Column(
+                    name=target_column_name,
+                    type=self._get_column_type(target_column_type),
+                    ordinal_position=target_column_position,
+                )
+
                 if constraint_name not in foreign_key_constraints:
                     foreign_key_constraints[constraint_name] = KeyConstraint(
                         constraint_name=constraint_name,
@@ -214,14 +278,63 @@
 
         return (primary_key_constraint, list(foreign_key_constraints.values()))
 
+    def _get_column_type(self, data_type: str) -> ColumnType:
+        column_type: ColumnType
+
+        if data_type in ("smallint", "integer", "bigint"):
+            column_type = ColumnType.INT
+        elif data_type in ("boolean",):
+            column_type = ColumnType.BOOL
+        elif data_type in ("character varying", "text"):
+            column_type = ColumnType.STR
+        elif data_type in ("jsonb",):
+            column_type = ColumnType.JSON
+        elif data_type in ("date",):
+            column_type = ColumnType.DATE
+        elif data_type in ("timestamp with time zone",):
+            column_type = ColumnType.TIMESTAMPTZ
+        elif data_type in ("double precision",):
+            column_type = ColumnType.FLOAT
+        else:
+            raise ValueError(
+                f"Don't know how to canonicalize column type '{data_type}'"
+            )
+
+        return column_type
+
     @override
-    def get_column_names(self, table: str, cursor: CursorWrapper) -> set[str]:
-        cursor.execute(
-            "SELECT column_name FROM information_schema.columns WHERE TABLE_NAME = %s ORDER BY ordinal_position",
-            (table,),
-        )
+    def get_columns(self, table: str, cursor: CursorWrapper) -> set[Column]:
+        if self._columns is None:
+            cursor.execute("""
+SELECT
+    table_name,
+    column_name,
+    data_type,
+    ordinal_position
+FROM information_schema.columns
+WHERE
+    table_schema not in ('pg_catalog', 'information_schema')
+ORDER BY table_name, ordinal_position
+""")
+            self._columns = {}
 
-        return {r[0] for r in cursor.fetchall()}
+            for row in cursor.fetchall():
+                table_name, column_name, data_type, ordinal_position = row
 
+                column_type = self._get_column_type(data_type)
+
+                if table_name not in self._columns:
+                    self._columns[table_name] = set()
+
+                self._columns[table_name].add(
+                    Column(
+                        name=column_name,
+                        type=column_type,
+                        ordinal_position=ordinal_position,
+                    )
+                )
+
+        return self._columns[table]
 
     @override
     def get_tables_with_installed_triggers(self, cursor: CursorWrapper) -> set[str]:
@@ -246,15 +360,124 @@ class SQLiteDatabaseMetadataWrapper(DatabaseMetadataWrapper):
     """This is a placeholder until we implement the SQLite side of synchronization."""
 
     @override
-    def get_column_names(self, table: str, cursor: CursorWrapper) -> set[str]:
-        raise NotImplementedError()
+    def get_columns(self, table: str, cursor: CursorWrapper) -> set[Column]:
+        cursor.execute(
+            """
+SELECT
+    name,
+    "type",
+    cid AS ordinal_position
+FROM pragma_table_info(%s)
+""",
+            [table],
+        )
+
+        return {
+            Column(
+                name=r[0],
+                type=self._data_type_to_column_type(r[1]),
+                ordinal_position=r[2],
+            )
+            for r in cursor.fetchall()
+        }
 
     @override
     def get_tables_with_installed_triggers(self, cursor: CursorWrapper) -> set[str]:
-        raise NotImplementedError()
+        cursor.execute("SELECT tbl_name FROM sqlite_master WHERE type = 'trigger';")
+
+        return {r[0] for r in cursor.fetchall()}
 
     @override
     def get_table_key_constraints(
         self, table: str, cursor: CursorWrapper
     ) -> tuple[KeyConstraint, list[KeyConstraint]]:
-        raise NotImplementedError()
+        cursor.execute(
+            """
+SELECT
+    p.name AS column_name,
+    p.type AS data_type,
+    p.cid AS ordinal_position
+FROM sqlite_schema m
+JOIN pragma_table_info(m.name) p ON p.pk > 0
+WHERE m.type = 'table'
+  AND m.name = %s
+ORDER BY m.name, p.pk;
+""",
+            [table],
+        )
+
+        primary_key_columns = cursor.fetchall()
+
+        primary_key_constraint = KeyConstraint(
+            constraint_name="primary",
+            constraint_type="PRIMARY KEY",
+            source_columns=[
+                Column(
+                    name=row[0],
+                    type=self._data_type_to_column_type(row[1]),
+                    ordinal_position=row[2],
+                )
+                for row in primary_key_columns
+            ],
+            target_columns=[],
+            target_table=table,
+        )
+
+        foreign_key_constraints: dict[int, KeyConstraint] = {}
+
+        cursor.execute(
+            """
+SELECT
+    id AS fk_id,
+    "from" AS from_column,
+    "table" AS referenced_table,
+    "to" AS referenced_column
+FROM pragma_foreign_key_list(%s)
+ORDER BY id, seq;
+""",
+            [table],
+        )
+
+        rows = cursor.fetchall()
+
+        for row in rows:
+            fk_id, from_column, referenced_table, referenced_column = row
+
+            if fk_id not in foreign_key_constraints:
+                foreign_key_constraints[fk_id] = KeyConstraint(
+                    constraint_name=f"foreign_key_{fk_id}",
+                    constraint_type="FOREIGN KEY",
+                    source_columns=[],
+                    target_columns=[],
+                    target_table=referenced_table,
+                )
+
+            # pragma_foreign_key_list only yields column names, so resolve them
+            # into full Column objects to match KeyConstraint's declared types.
+            foreign_key_constraints[fk_id].source_columns.append(
+                self._get_columns_by_name(table, [from_column], cursor)[0]
+            )
+            foreign_key_constraints[fk_id].target_columns.append(
+                self._get_columns_by_name(referenced_table, [referenced_column], cursor)[0]
+            )
 
+        return (
+            primary_key_constraint,
+            sorted(foreign_key_constraints.values(), key=lambda c: c.constraint_name),
+        )
+
+    def _data_type_to_column_type(self, data_type: str) -> ColumnType:
+        if data_type.lower() in ("integer", "bigint", "smallint", "smallint unsigned"):
+            return ColumnType.INT
+        elif data_type.lower() in ("bool",):
+            return ColumnType.BOOL
+        elif data_type.lower().startswith("varchar") or data_type.lower() == "text":
+            return ColumnType.STR
+        elif data_type.lower() == "datetime":
+            return ColumnType.TIMESTAMPTZ
+        elif data_type.lower() == "date":
+            return ColumnType.DATE
+        elif data_type.lower() == "real":
+            return ColumnType.FLOAT
+        else:
+            raise ValueError(f"Unhandled data type '{data_type}'")
diff --git a/lamindb/core/writelog/_replayer.py b/lamindb/core/writelog/_replayer.py
new file mode 100644
index 000000000..052588a3e
--- /dev/null
+++ b/lamindb/core/writelog/_replayer.py
@@ -0,0 +1,376 @@
+import datetime
+from dataclasses import dataclass
+from typing import Any
+
+from django.db import connection
+from django.db.backends.utils import CursorWrapper
+
+from lamindb.core.writelog._constants import FOREIGN_KEYS_LIST_COLUMN_NAME
+from lamindb.core.writelog._trigger_installer import WriteLogEventTypes
+from lamindb.core.writelog._types import Column, ColumnType
+from lamindb.models.writelog import WriteLog, WriteLogTableState
+
+from ._db_metadata_wrapper import DatabaseMetadataWrapper
+
+
+@dataclass
+class WriteLogForeignKey:
+    table_name: str
+    foreign_key_columns: list[str]
+    foreign_uid: dict[str, str]
+
+
+class WriteLogReplayer:
+    def __init__(
+        self,
+        db_metadata: DatabaseMetadataWrapper,
+        cursor: CursorWrapper,
+    ):
+        self.db_metadata = db_metadata
+        self.cursor = cursor
+
+    def replay(self, write_log_record: WriteLog):
+        table_name = write_log_record.table.table_name
+
+        # If we're deleting the record, its (resolved) UID is the only thing we need
+        if write_log_record.event_type == WriteLogEventTypes.DELETE.value:
+            record_id = self._resolve_uid(table_name, write_log_record.record_uid)
+
+            self.cursor.execute(f"""
+                DELETE FROM {table_name}
+                WHERE {" AND ".join(f"{k} = {v}" for k, v in record_id.items())}
+                LIMIT 1
+            """)  # noqa: S608
+        else:
+            record_data = self._build_record_data(write_log_record)
+
+            if write_log_record.event_type == WriteLogEventTypes.INSERT.value:
+                record_column_names = [d[0].name for d in record_data]
+                record_values = [
+                    self._cast_column(column=d[0], value=d[1]) for d in record_data
+                ]
+
+                if table_name in self.db_metadata.get_many_to_many_db_tables():
+                    record_id = self._resolve_foreign_keys_list(
+                        write_log_record.record_uid
+                    )
+
+                    for column_name, value in record_id.items():
+                        record_column_names.append(column_name)
+                        record_values.append(str(value))
+                else:
+                    uid_list = self.db_metadata.get_uid_columns(
+                        table=table_name, cursor=self.cursor
+                    )
+
+                    if len(uid_list) != 1:
+                        raise ValueError(
+                            f"Table {table_name} is not marked as many-to-many, but does "
+                            "not have exactly one table reference in its UID"
+                        )
+
+                    record_uid = uid_list[0]
+
+                    if len(record_uid.uid_columns) != len(write_log_record.record_uid):
+                        raise ValueError(
+                            f"Write log record {write_log_record.uid} expected to "
+                            f"have {len(record_uid.uid_columns)} components to its "
+                            f"UID, but only had {len(write_log_record.record_uid)}"
+                        )
+
+                    for uid_column, value in zip(
+                        record_uid.uid_columns, write_log_record.record_uid
+                    ):
+                        record_column_names.append(uid_column.name)
+                        record_values.append(self._cast_column(uid_column, value))
+
+                self.cursor.execute(f"""
+                    INSERT INTO {table_name}
+                    ({", ".join(record_column_names)}) VALUES ({", ".join(record_values)})
+                """)  # noqa: S608
+            elif write_log_record.event_type == WriteLogEventTypes.UPDATE.value:
+                record_id = self._resolve_uid(table_name, write_log_record.record_uid)
+
+                set_statements = ", ".join(
+                    f"{column.name} = {self._cast_column(column, value)}"
+                    for (column, value) in record_data
+                )
+
+                self.cursor.execute(f"""
+                    UPDATE {table_name}
+                    SET {set_statements}
+                    WHERE {" AND ".join(f"{k} = {v}" for (k, v) in record_id.items())} LIMIT 1
+                """)  # noqa: S608
+            else:
+                raise ValueError(
+                    f"Unhandled record event type {write_log_record.event_type}"
+                )
+
+    def _cast_column(self, column: Column, value: Any) -> str:
+        """Returns a string representation of the column that will cast it to the appropriate type."""
+        if connection.vendor == "postgresql":
+            return self._cast_column_postgres(column, value)
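+        # For instance (hypothetical values): an INT column with value 5 renders
+        # as "5" on either backend, while a BOOL True renders as "TRUE" on
+        # Postgres but "1" on SQLite.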
+        elif connection.vendor == "sqlite":
+            return self._cast_column_sqlite(column, value)
+        else:
+            raise ValueError(f"Unsupported connection vendor '{connection.vendor}'")
+
+    def _cast_column_postgres(self, column: Column, value: Any) -> str:
+        if value is None:
+            return "NULL"
+        elif column.type == ColumnType.INT:
+            return str(value)
+        elif column.type == ColumnType.BOOL:
+            return "TRUE" if value is True else "FALSE"
+        elif column.type == ColumnType.STR:
+            return f"'{value}'"
+        elif column.type == ColumnType.DATE:
+            return f"date('{value}')"
+        elif column.type == ColumnType.FLOAT:
+            return str(value)
+        elif column.type == ColumnType.JSON:
+            return f"to_jsonb('{value}'::text)"
+        elif column.type == ColumnType.TIMESTAMPTZ:
+            return f"timestamptz('{value}')"
+        else:
+            raise ValueError(f"Unhandled type {column.type}")
+
+    def _cast_column_sqlite(self, column: Column, value: Any) -> str:
+        if value is None:
+            return "NULL"
+        elif column.type == ColumnType.INT:
+            return str(value)
+        elif column.type == ColumnType.BOOL:
+            return "1" if value is True else "0"
+        elif column.type == ColumnType.STR:
+            return f"'{value}'"
+        elif column.type == ColumnType.DATE:
+            return f"'{value}'"
+        elif column.type == ColumnType.FLOAT:
+            return f"CAST('{str(value)}' AS REAL)"
+        elif column.type == ColumnType.JSON:
+            return f"'{value}'"
+        elif column.type == ColumnType.TIMESTAMPTZ:
+            formatted_datetime = (
+                datetime.datetime.fromisoformat(value)
+                .astimezone(datetime.timezone.utc)
+                .strftime("%Y-%m-%d %H:%M:%S.%f")
+            )
+
+            return f"'{formatted_datetime}'"
+        else:
+            raise ValueError(f"Unhandled type {column.type}")
+
+    def _resolve_uid(self, table_name: str, record_uid) -> dict[str, int]:
+        if table_name in self.db_metadata.get_many_to_many_db_tables():
+            resolved_record_uid = self._resolve_foreign_keys_list(record_uid)
+        else:
+            table_uid_list = self.db_metadata.get_uid_columns(
+                table=table_name, cursor=self.cursor
+            )
+
+            if (
+                len(table_uid_list) != 1
+                or table_uid_list[0].source_table_name != table_name
+            ):
+                raise ValueError(
+                    f"Expected standard table {table_name}'s UID to refer only to itself"
+                )
+
+            uid_columns = table_uid_list[0].uid_columns
+
+            if not (
+                isinstance(record_uid, list) and len(record_uid) == len(uid_columns)
+            ):
+                raise ValueError(
+                    f"Expected standard table {table_name}'s UID to be a list of length {len(uid_columns)}"
+                )
+
+            lookup_where_clause = " AND ".join(
+                f"{k.name} = {self._cast_column(k, v)}"
+                for (k, v) in zip(uid_columns, record_uid)
+            )
+
+            primary_key, _ = self.db_metadata.get_table_key_constraints(
+                table=table_name, cursor=self.cursor
+            )
+
+            self.cursor.execute(f"""
+                SELECT {", ".join(c.name for c in primary_key.source_columns)}
+                FROM {table_name}
+                WHERE {lookup_where_clause}
+                LIMIT 1
+            """)  # noqa: S608
+
+            table_primary_key_values = self.cursor.fetchone()
+
+            if table_primary_key_values is None:
+                raise ValueError(
+                    f"Unable to locate a record in {table_name} with UID {record_uid}"
+                )
+
+            return dict(
+                zip(
+                    [c.name for c in primary_key.source_columns],
+                    table_primary_key_values,
+                )
+            )
+
+        return resolved_record_uid
+
+    def _build_record_data(self, record: WriteLog) -> list[tuple[Column, str]]:
+        record_data: dict[str, Any] | None = record.record_data
+
+        if record_data is None:
+            raise ValueError(
+                f"Expected non-null record data for write log record {record.uid} (type: {record.event_type})"
+            )
+
+        # We're outputting this data as tuples so that it's easy to output keys and values
+        # as two separate lists with the same order in INSERT.
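+        # For example (hypothetical record_data): {"text_col": "hi", "_lamin_fks": []}
+        # would yield [(Column(name="text_col", ...), "hi")], ready to be split
+        # into parallel column and value lists by the INSERT branch of replay().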
+        record_data_tuples: list[tuple[Column, str]] = []
+
+        columns_by_name = {
+            c.name: c
+            for c in self.db_metadata.get_columns(
+                table=record.table.table_name, cursor=self.cursor
+            )
+        }
+
+        for key, value in record_data.items():
+            if key == FOREIGN_KEYS_LIST_COLUMN_NAME:
+                for (
+                    foreign_key_column,
+                    foreign_key_value,
+                ) in self._resolve_foreign_keys_list(value).items():
+                    if foreign_key_column not in columns_by_name:
+                        raise ValueError(
+                            f"Table {record.table.table_name} does not have a column named {foreign_key_column}"
+                        )
+
+                    record_data_tuples.append(
+                        (columns_by_name[foreign_key_column], str(foreign_key_value))
+                    )
+            else:
+                if key not in columns_by_name:
+                    raise ValueError(
+                        f"Table {record.table.table_name} does not have a column named {key}"
+                    )
+
+                record_data_tuples.append((columns_by_name[key], value))
+
+        return record_data_tuples
+
+    def _resolve_foreign_keys_list(self, foreign_keys_list) -> dict[str, int]:
+        """Resolves a reference to a foreign record by UID into the foreign keys needed by the source table.
+
+        Write logs refer to records in other tables by their UID. These references are stored as a list of 3-tuples,
+        where each 3-tuple is:
+
+        [
+            ID of the destination table in WriteLogTableState,
+            list of table fields containing the foreign key,
+            column/value pairs defining the foreign record's UID
+        ]
+
+        For example, [12, ["simple_a_id"], {"uid": "SimpleRecord1"}] (hypothetical
+        values) maps column "simple_a_id" to the primary key of the record with
+        uid "SimpleRecord1" in the table whose WriteLogTableState ID is 12.
+
+        This method resolves these 3-tuples into a dict that maps column names to record IDs.
+
+        This assumes that all tables are keyed with integers for simplicity.
+        """
+        if not isinstance(foreign_keys_list, list) or any(
+            not (isinstance(x, list) and len(x) == 3) for x in foreign_keys_list
+        ):
+            raise ValueError("Expected a foreign keys list to be a list of 3-tuples")
+
+        resolved_foreign_keys: dict[str, int] = {}
+
+        for foreign_table_uid in foreign_keys_list:
+            resolved_foreign_key = self._foreign_key_from_json(
+                json_obj=foreign_table_uid
+            )
+            resolved_foreign_keys.update(
+                self._get_foreign_key_values(resolved_foreign_key)
+            )
+
+        return resolved_foreign_keys
+
+    def _foreign_key_from_json(self, json_obj: list) -> "WriteLogForeignKey":
+        try:
+            table_id, foreign_key_columns, foreign_uid = json_obj
+
+            if not isinstance(table_id, int):
+                raise ValueError(
+                    f"Expected the first element of a foreign key's JSON representation to be a table state ID (JSON: {json_obj})"
+                )
+
+            if not isinstance(foreign_key_columns, list):
+                raise ValueError(
+                    f"Expected the second element of a foreign key's JSON representation to be a list of column names (JSON: {json_obj})"
+                )
+
+            if not isinstance(foreign_uid, dict):
+                raise ValueError(
+                    f"Expected the third element of a foreign key's JSON representation to be a dict mapping columns to values (JSON: {json_obj})"
+                )
+
+            table_name = WriteLogTableState.objects.get(id=table_id).table_name
+
+            return WriteLogForeignKey(
+                table_name=table_name,
+                foreign_key_columns=foreign_key_columns,
+                foreign_uid=foreign_uid,
+            )
+        except ValueError as e:
+            raise ValueError(
+                f"Malformed write log foreign key '{json_obj}' encountered"
+            ) from e
+
+    def _get_foreign_key_values(
+        self, foreign_key: WriteLogForeignKey
+    ) -> dict[str, int]:
+        primary_key, _ = self.db_metadata.get_table_key_constraints(
+            foreign_key.table_name, cursor=self.cursor
+        )
+
+        if len(foreign_key.foreign_key_columns) != len(primary_key.source_columns):
+            raise ValueError(
+                f"Expected number of primary key columns for {primary_key.target_table} "
+                f"({len(primary_key.source_columns)}) and number of foreign key columns "
+                f"for 
{foreign_key.table_name} ({len(foreign_key.foreign_key_columns)}) to match" + ) + + uid_constraints = " AND ".join( + f"{k} = '{v}'" for (k, v) in foreign_key.foreign_uid.items() + ) + + self.cursor.execute( + f"SELECT {','.join(c.name for c in primary_key.source_columns)} " # noqa: S608 + f"FROM {foreign_key.table_name} " + f"WHERE {uid_constraints}" + ) + + rows = self.cursor.fetchall() + + if len(rows) == 0: + raise ValueError( + f"Record matching foreign key constraint {foreign_key} not found" + ) + + if len(rows) > 1: + raise ValueError( + f"Found more than one record matching foreign key constraint {foreign_key}" + ) + + row = rows[0] + + return dict(zip(foreign_key.foreign_key_columns, row)) diff --git a/lamindb/core/writelog/_trigger_installer.py b/lamindb/core/writelog/_trigger_installer.py index 8ab6d1388..391b4fe92 100644 --- a/lamindb/core/writelog/_trigger_installer.py +++ b/lamindb/core/writelog/_trigger_installer.py @@ -3,19 +3,23 @@ from abc import ABC, abstractmethod from typing import Any -from django.db import models, transaction +from django.db import transaction from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import CursorWrapper from lamin_utils import logger from typing_extensions import override +from lamindb.core.writelog._constants import FOREIGN_KEYS_LIST_COLUMN_NAME from lamindb.core.writelog._graph_utils import find_cycle, topological_sort +from lamindb.core.writelog._utils import ( + update_migration_state, + update_write_log_table_state, +) from lamindb.models.writelog import ( DEFAULT_BRANCH_CODE, DEFAULT_CREATED_BY_UID, DEFAULT_RUN_UID, WriteLogLock, - WriteLogMigrationState, WriteLogTableState, ) @@ -32,20 +36,6 @@ class WriteLogEventTypes(enum.Enum): DELETE = 2 -class DjangoMigration(models.Model): - """This model class allows us to access the migrations table using normal Django syntax.""" - - app = models.CharField(max_length=255) - name = models.CharField(max_length=255) - applied = models.DateTimeField() - - class Meta: - managed = False # Tell Django not to manage this table - db_table = "django_migrations" # Specify the actual table name - - -FOREIGN_KEYS_LIST_COLUMN_NAME = "_lamin_fks" - RESERVED_COLUMNS: tuple[str] = (FOREIGN_KEYS_LIST_COLUMN_NAME,) # Certain tables must be excluded from write log triggers to avoid @@ -91,63 +81,16 @@ def __init__( def install_triggers(self, table: str, cursor: CursorWrapper): raise NotImplementedError() - def _update_write_log_table_state(self, tables: set[str]): - existing_tables = set( - WriteLogTableState.objects.filter(table_name__in=tables).values_list( - "table_name", flat=True - ) - ) - - table_states_to_add = [ - WriteLogTableState(table_name=t, backfilled=False) - for t in tables - if t not in existing_tables - ] - - WriteLogTableState.objects.bulk_create(table_states_to_add) - - def _update_migration_state(self, cursor: CursorWrapper): - app_migrations = {} - - for app in DjangoMigration.objects.values_list("app", flat=True).distinct(): - migrations = DjangoMigration.objects.filter(app=app) - max_migration_id = 0 - - for migration in migrations: - # Extract the number from the migration name - match = re.match(r"^([0-9]+)_", migration.name) # type: ignore - if match: - migration_id = int(match.group(1)) - max_migration_id = max(max_migration_id, migration_id) - - app_migrations[app] = max_migration_id - - current_state = [ - {"migration_id": mig_id, "app": app} - for app, mig_id in sorted(app_migrations.items()) - ] - - try: - latest_state = 
WriteLogMigrationState.objects.order_by("-id").first()
-            latest_state_json = (
-                latest_state.migration_state_id if latest_state else None
-            )
-        except WriteLogMigrationState.DoesNotExist:
-            latest_state_json = None
-
-        if current_state and current_state != latest_state_json:
-            WriteLogMigrationState.objects.create(migration_state_id=current_state)
-
     def update_write_log_triggers(self, update_all: bool = False):
         with transaction.atomic():
             # Ensure that the write log lock exists
             WriteLogLock.load()
 
             tables = self.db_metadata.get_db_tables()
 
-            self._update_write_log_table_state(tables)
+            update_write_log_table_state(tables)
 
             cursor = self.connection.cursor()
 
-            self._update_migration_state(cursor)
+            update_migration_state()
 
             tables_with_installed_triggers = (
                 self.db_metadata.get_tables_with_installed_triggers(cursor)
@@ -296,7 +239,9 @@ def backfill_self_referential_table(
         primary_key_constraint: KeyConstraint,
         foreign_key_constraints: list[KeyConstraint],
     ):
-        primary_key_columns_set: set[str] = set(primary_key_constraint.source_columns)
+        primary_key_columns_set: set[str] = {
+            c.name for c in primary_key_constraint.source_columns
+        }
         self_referential_constraints = [
             fk for fk in foreign_key_constraints if fk.target_table == table
         ]
@@ -304,7 +249,9 @@
         self_reference_columns_set: set[str] = set()
 
         for foreign_key_constraint in self_referential_constraints:
-            self_reference_columns_set.add(*(foreign_key_constraint.source_columns))
+            self_reference_columns_set.update(
+                c.name for c in foreign_key_constraint.source_columns
+            )
 
         # We need to specify these columns in a fixed order so that we can figure out
         # which elements in the output row correspond to each column after the lookup query
@@ -335,20 +282,20 @@
 
             for foreign_key_constraint in self_referential_constraints:
                 # If all source columns are null, skip this constraint.
                 if not any(
-                    row_dict[source_col] is not None
+                    row_dict[source_col.name] is not None
                     for source_col in foreign_key_constraint.source_columns
                 ):
                     continue
 
-                referenced_pk = {}
+                referenced_pk: dict[str, Any] = {}
 
                 # We need to map the source columns in the constraint to their corresponding target columns.
                 for i, source_column in enumerate(
                     foreign_key_constraint.source_columns
                 ):
-                    referenced_pk[foreign_key_constraint.target_columns[i]] = row_dict[
-                        source_column
-                    ]
+                    referenced_pk[foreign_key_constraint.target_columns[i].name] = (
+                        row_dict[source_column.name]
+                    )
 
                 hashed_referenced_pk = self._hash_pk(referenced_pk)
@@ -480,7 +427,7 @@ def _build_record_uid_inner(self, is_delete: bool) -> str:
 
             return self._build_jsonb_array(
                 [
-                    f"{table_name_in_trigger}.{column}"
+                    f"{table_name_in_trigger}.{column.name}"
                     for column in table_uid.uid_columns
                 ]
             )
@@ -512,7 +459,7 @@ def _build_record_uid_inner(self, is_delete: bool) -> str:
                             table_id_var,
                             self._build_jsonb_array(
                                 [
-                                    f"'{c}'"
+                                    f"'{c.name}'"
                                     for c in foreign_key_constraint.source_columns
                                 ]
                             ),
@@ -540,21 +487,23 @@ def _build_record_data(self) -> str:
 
         # We don't need to store the table's primary key columns in its data object,
         # since the object will be identified by its UID columns.
-        non_key_columns = table_columns.difference(set(primary_key.source_columns))
+        non_key_columns = table_columns.difference(
+            {c.name for c in primary_key.source_columns}
+        )
 
         # We also don't need to store any foreign-key columns in its data object,
        # since the references those columns encode will be captured by reference to
        # their UID columns.
for foreign_key_constraint in foreign_key_constraints: non_key_columns.difference_update( - set(foreign_key_constraint.source_columns) + {c.name for c in foreign_key_constraint.source_columns} ) # Since we're recording the record's UID, we don't need to store the UID columns in the # data object as well. if len(uid_columns_list) == 1: table_uid = uid_columns_list[0] - non_key_columns.difference_update(table_uid.uid_columns) + non_key_columns.difference_update(c.name for c in table_uid.uid_columns) record_data: dict[str | int, Any] = {} @@ -571,7 +520,8 @@ def _build_record_data(self) -> str: # Don't record foreign-keys to space, since we store space_uid separately if not ( foreign_key_constraint.target_table == "lamindb_space" - and foreign_key_constraint.source_columns == ["space_id"] + and [c.name for c in foreign_key_constraint.source_columns] + == ["space_id"] ) ] @@ -626,7 +576,7 @@ def add_foreign_key_source_columns_variable( [ table_id, self._build_jsonb_array( - [f"'{c}'" for c in foreign_key_constraint.source_columns] + [f"'{c.name}'" for c in foreign_key_constraint.source_columns] ), uid_lookup_variable, ] @@ -667,7 +617,7 @@ def add_foreign_key_uid_lookup_variable( table_uid = uid_column_list[0] where_clause = " AND ".join( - f"{target_col} = {source_record}.{source_col}" + f"{target_col.name} = {source_record}.{source_col.name}" for source_col, target_col in zip( foreign_key_constraint.source_columns, foreign_key_constraint.target_columns, @@ -686,11 +636,11 @@ def add_foreign_key_uid_lookup_variable( f""" coalesce( ( - SELECT {self._build_jsonb_object({c: c for c in table_uid.uid_columns})} + SELECT {self._build_jsonb_object({c.name: c.name for c in table_uid.uid_columns})} FROM {foreign_key_constraint.target_table} WHERE {where_clause} ), - {self._build_jsonb_object(dict.fromkeys(table_uid.uid_columns, "NULL"))} + {self._build_jsonb_object(dict.fromkeys([c.name for c in table_uid.uid_columns], "NULL"))} ) """, # noqa: S608 ) diff --git a/lamindb/core/writelog/_types.py b/lamindb/core/writelog/_types.py index fab7c7f30..d162ef0ff 100644 --- a/lamindb/core/writelog/_types.py +++ b/lamindb/core/writelog/_types.py @@ -1,13 +1,23 @@ from dataclasses import dataclass +from enum import Enum from typing import Literal, Optional -@dataclass -class ManyToManyRelationship: - first_column: str - first_table: str - second_column: str - second_table: str +class ColumnType(Enum): + INT = 0 + BOOL = 1 + STR = 2 + DATE = 3 + FLOAT = 4 + JSON = 5 + TIMESTAMPTZ = 6 + + +@dataclass(frozen=True) +class Column: + name: str + type: ColumnType + ordinal_position: int @dataclass @@ -18,8 +28,8 @@ class KeyConstraint: constraint_type: Literal["PRIMARY KEY", "FOREIGN KEY"] # These need to be a list to account for composite primary keys - source_columns: list[str] - target_columns: list[str] + source_columns: list[Column] + target_columns: list[Column] target_table: str @@ -27,7 +37,7 @@ class KeyConstraint: @dataclass class TableUID: source_table_name: str - uid_columns: list[str] + uid_columns: list[Column] key_constraint: Optional[KeyConstraint] diff --git a/lamindb/core/writelog/_utils.py b/lamindb/core/writelog/_utils.py new file mode 100644 index 000000000..aabd0e5b9 --- /dev/null +++ b/lamindb/core/writelog/_utils.py @@ -0,0 +1,64 @@ +import re + +from django.db import models + +from lamindb.models.writelog import WriteLogMigrationState, WriteLogTableState + + +class DjangoMigration(models.Model): + """This model class allows us to access the migrations table using normal Django syntax.""" + + 
app = models.CharField(max_length=255) + name = models.CharField(max_length=255) + applied = models.DateTimeField() + + class Meta: + managed = False # Tell Django not to manage this table + db_table = "django_migrations" # Specify the actual table name + + +def update_write_log_table_state(tables: set[str]): + existing_tables = set( + WriteLogTableState.objects.filter(table_name__in=tables).values_list( + "table_name", flat=True + ) + ) + + table_states_to_add = [ + WriteLogTableState(table_name=t, backfilled=False) + for t in tables + if t not in existing_tables + ] + + WriteLogTableState.objects.bulk_create(table_states_to_add) + + +def update_migration_state(): + app_migrations = {} + + for app in DjangoMigration.objects.values_list("app", flat=True).distinct(): + migrations = DjangoMigration.objects.filter(app=app) + max_migration_id = 0 + + for migration in migrations: + # Extract the number from the migration name + match = re.match(r"^([0-9]+)_", migration.name) # type: ignore + if match: + migration_id = int(match.group(1)) + max_migration_id = max(max_migration_id, migration_id) + + app_migrations[app] = max_migration_id + + current_state = [ + {"migration_id": mig_id, "app": app} + for app, mig_id in sorted(app_migrations.items()) + ] + + try: + latest_state = WriteLogMigrationState.objects.order_by("-id").first() + latest_state_json = latest_state.migration_state_id if latest_state else None + except WriteLogMigrationState.DoesNotExist: + latest_state_json = None + + if current_state and current_state != latest_state_json: + WriteLogMigrationState.objects.create(migration_state_id=current_state) diff --git a/lamindb/migrations/0102_alter_writelog_record_uid.py b/lamindb/migrations/0102_alter_writelog_record_uid.py new file mode 100644 index 000000000..07db4cc27 --- /dev/null +++ b/lamindb/migrations/0102_alter_writelog_record_uid.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2 on 2025-05-27 08:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("lamindb", "0101_alter_artifact_hash_alter_feature_name_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="writelog", + name="record_uid", + field=models.JSONField(default=1), + preserve_default=False, + ), + ] diff --git a/lamindb/models/writelog.py b/lamindb/models/writelog.py index 65015f982..f3515bb3e 100644 --- a/lamindb/models/writelog.py +++ b/lamindb/models/writelog.py @@ -52,8 +52,7 @@ class WriteLog(models.Model): created_by_uid = models.CharField(max_length=8, default=DEFAULT_CREATED_BY_UID) branch_code = models.IntegerField(default=DEFAULT_BRANCH_CODE) run_uid = models.CharField(max_length=16, default=DEFAULT_RUN_UID) - # Many-to-many tables don't have row UIDs, so this needs to be nullable. 
- record_uid = models.JSONField(null=True) + record_uid = models.JSONField() record_data = models.JSONField(null=True) event_type = models.PositiveSmallIntegerField() created_at = models.DateTimeField() diff --git a/noxfile.py b/noxfile.py index d1c44263c..628fb6389 100644 --- a/noxfile.py +++ b/noxfile.py @@ -256,6 +256,7 @@ def test(session, group): run(session, f"pytest {coverage_args} ./tests/storage --durations=50") elif group == "unit-writelog": run(session, f"pytest {coverage_args} ./tests/writelog --durations=50") + run(session, f"pytest {coverage_args} ./tests/writelog_sqlite --durations=50") elif group == "tutorial": run(session, "lamin logout") run( diff --git a/tests/writelog/test_postgres_trigger_installer.py b/tests/writelog/test_postgres_trigger_installer.py index 61f17f65d..f92c90363 100644 --- a/tests/writelog/test_postgres_trigger_installer.py +++ b/tests/writelog/test_postgres_trigger_installer.py @@ -7,15 +7,15 @@ from django.db import connection as django_connection_proxy from django.db import transaction from django.db.backends.utils import CursorWrapper +from lamindb.core.writelog._constants import FOREIGN_KEYS_LIST_COLUMN_NAME from lamindb.core.writelog._db_metadata_wrapper import ( PostgresDatabaseMetadataWrapper, ) from lamindb.core.writelog._trigger_installer import ( - FOREIGN_KEYS_LIST_COLUMN_NAME, PostgresWriteLogRecordingTriggerInstaller, WriteLogEventTypes, ) -from lamindb.core.writelog._types import TableUID, UIDColumns +from lamindb.core.writelog._types import Column, ColumnType, TableUID, UIDColumns from lamindb.models.artifact import Artifact from lamindb.models.run import Run from lamindb.models.sqlrecord import Space @@ -864,7 +864,10 @@ def test_triggers_with_compound_table_uid(compound_uid_table, compound_uid_child compound_uid_table: [ TableUID( source_table_name=compound_uid_table, - uid_columns=["uid_1", "uid_2"], + uid_columns=[ + Column(name="uid_1", type=ColumnType.STR, ordinal_position=1), + Column(name="uid_2", type=ColumnType.STR, ordinal_position=3), + ], key_constraint=None, ) ] @@ -974,7 +977,10 @@ def test_triggers_many_to_many_to_compound_uid_with_self_links( compound_uid_table: [ TableUID( source_table_name=compound_uid_table, - uid_columns=["uid_1", "uid_2"], + uid_columns=[ + Column(name="uid_1", type=ColumnType.STR, ordinal_position=1), + Column(name="uid_2", type=ColumnType.STR, ordinal_position=3), + ], key_constraint=None, ) ] diff --git a/tests/writelog_sqlite/conftest.py b/tests/writelog_sqlite/conftest.py new file mode 100644 index 000000000..bbb048180 --- /dev/null +++ b/tests/writelog_sqlite/conftest.py @@ -0,0 +1,34 @@ +import shutil +from time import perf_counter + +import lamindb_setup as ln_setup +import pytest + +AUTO_CONNECT = ln_setup.settings.auto_connect +ln_setup.settings.auto_connect = False + +import lamindb as ln + + +def pytest_sessionstart(): + t_execute_start = perf_counter() + + ln_setup._TESTING = True + ln.setup.init( + storage="./default_storage_writelog_core", + modules="bionty", + name="lamindb-writelog-tests-core", + db=None, + ) + ln.setup.settings.auto_connect = True + ln.settings.creation.artifact_silence_missing_run_warning = True + + total_time_elapsed = perf_counter() - t_execute_start + print(f"Time to setup the instance: {total_time_elapsed:.3f}s") + + +def pytest_sessionfinish(session: pytest.Session): + shutil.rmtree("./default_storage_writelog_core") + + ln.setup.delete("lamindb-writelog-tests-core", force=True) + ln.setup.settings.auto_connect = AUTO_CONNECT diff --git 
a/tests/writelog_sqlite/test_replayer.py b/tests/writelog_sqlite/test_replayer.py new file mode 100644 index 000000000..9040b76e3 --- /dev/null +++ b/tests/writelog_sqlite/test_replayer.py @@ -0,0 +1,392 @@ +import datetime +from typing import TYPE_CHECKING, cast + +import pytest +from django.db import connection +from django.db import connection as django_connection_proxy +from django.db.backends.utils import CursorWrapper +from lamindb.core.writelog._constants import FOREIGN_KEYS_LIST_COLUMN_NAME +from lamindb.core.writelog._db_metadata_wrapper import SQLiteDatabaseMetadataWrapper +from lamindb.core.writelog._replayer import WriteLogReplayer +from lamindb.core.writelog._trigger_installer import WriteLogEventTypes +from lamindb.core.writelog._types import UIDColumns +from lamindb.core.writelog._utils import ( + update_migration_state, + update_write_log_table_state, +) +from lamindb.models.writelog import ( + WriteLog, + WriteLogLock, + WriteLogMigrationState, + WriteLogTableState, +) +from typing_extensions import override + +if TYPE_CHECKING: + from django.db.backends.base.base import BaseDatabaseWrapper + +django_connection = cast("BaseDatabaseWrapper", django_connection_proxy) + + +class FakeMetadataWrapper(SQLiteDatabaseMetadataWrapper): + """A fake DB metadata wrapper that allows us to control which database tables the installer will see and target.""" + + def __init__(self): + super().__init__() + self._tables_with_triggers = set() + self._db_tables = set() + self._many_to_many_tables = set() + self._uid_columns: dict[str, UIDColumns] = {} + + @override + def get_tables_with_installed_triggers(self, cursor: CursorWrapper) -> set[str]: + return self._tables_with_triggers + + def set_tables_with_installed_triggers(self, tables: set[str]): + self._tables_with_triggers = tables + + @override + def get_db_tables(self) -> set[str]: + return self._db_tables + + def set_db_tables(self, tables: set[str]): + self._db_tables = tables + + @override + def get_many_to_many_db_tables(self) -> set[str]: + return self._many_to_many_tables + + def set_many_to_many_db_tables(self, tables: set[str]): + self._many_to_many_tables = tables + + @override + def get_uid_columns(self, table: str, cursor: CursorWrapper) -> UIDColumns: + if table in self._uid_columns: + return self._uid_columns[table] + else: + return super().get_uid_columns(table, cursor) + + def set_uid_columns(self, table: str, uid_columns: UIDColumns): + self._uid_columns[table] = uid_columns + + +@pytest.fixture(scope="function") +def write_log_state(): + WriteLog.objects.all().delete() + WriteLogTableState.objects.all().delete() + WriteLogMigrationState.objects.all().delete() + + yield + + WriteLog.objects.all().delete() + WriteLogTableState.objects.all().delete() + WriteLogMigrationState.objects.all().delete() + + +def test_connection_is_sqlite(): + assert connection.vendor == "sqlite" + + +@pytest.fixture(scope="function") +def simple_table(write_log_state): + cursor = django_connection.cursor() + + cursor.execute(""" +CREATE TABLE simple_table ( + id integer NOT NULL PRIMARY KEY AUTOINCREMENT, + uid varchar(20) NOT NULL UNIQUE, + bool_col bool, + text_col TEXT, + timestamp_col datetime, + date_col date, + float_col REAL +) +""") + + yield "simple_table" + + cursor.execute("DROP TABLE IF EXISTS simple_table") + + +@pytest.fixture(scope="function") +def write_log_lock(): + write_log_lock = WriteLogLock.load() + + assert write_log_lock is not None + + write_log_lock.lock() + + yield write_log_lock + + write_log_lock.unlock() + + +def 
test_replayer_happy_path(simple_table, write_log_lock, write_log_state): + db_metadata = FakeMetadataWrapper() + db_metadata.set_db_tables({simple_table}) + + cursor = django_connection.cursor() + + update_write_log_table_state({simple_table}) + update_migration_state() + + current_migration_state = WriteLogMigrationState.objects.order_by("-id").first() + + simple_table_state = WriteLogTableState.objects.get(table_name=simple_table) + + replayer = WriteLogReplayer(db_metadata=db_metadata, cursor=cursor) + + input_write_log = [ + WriteLog( + migration_state=current_migration_state, + table=simple_table_state, + uid="Hist001", + record_uid=["SimpleRecord1"], + record_data={ + "bool_col": True, + "text_col": "hello world", + "timestamp_col": "2025-05-23T02:03:16.913425+00:00", + "date_col": "2025-05-23", + "float_col": 8.675309, + FOREIGN_KEYS_LIST_COLUMN_NAME: [], + }, + event_type=WriteLogEventTypes.INSERT.value, + created_at=datetime.datetime(2025, 5, 23, 12, 34, 56), + ), + WriteLog( + migration_state=current_migration_state, + table=simple_table_state, + uid="Hist002", + record_uid=["SimpleRecord2"], + record_data={ + "bool_col": False, + "text_col": "Hallo, Welt!", + "timestamp_col": "2025-05-23T02:03:20.392310+00:00", + "date_col": "2025-05-23", + "float_col": 1.8007777777, + FOREIGN_KEYS_LIST_COLUMN_NAME: [], + }, + event_type=WriteLogEventTypes.INSERT.value, + created_at=datetime.datetime(2025, 5, 23, 12, 41, 00), + ), + WriteLog( + migration_state=current_migration_state, + table=simple_table_state, + uid="Hist003", + record_uid=["SimpleRecord1"], + record_data={ + "bool_col": False, + "text_col": "hello world", + "timestamp_col": "2025-05-23T02:03:16.913425+00:00", + "date_col": "2025-05-24", + "float_col": 8.675309, + FOREIGN_KEYS_LIST_COLUMN_NAME: [], + }, + event_type=WriteLogEventTypes.UPDATE.value, + created_at=datetime.datetime(2025, 5, 23, 12, 55, 00), + ), + WriteLog( + migration_state=current_migration_state, + table=simple_table_state, + uid="Hist004", + record_uid=["SimpleRecord2"], + record_data=None, + event_type=WriteLogEventTypes.DELETE.value, + created_at=datetime.datetime(2025, 5, 23, 12, 56, 00), + ), + ] + + for write_log_entry in input_write_log: + replayer.replay(write_log_entry) + write_log_entry.save() + + write_log = WriteLog.objects.all().order_by("id") + + assert len(write_log) == 4 + assert list(write_log) == input_write_log + + cursor.execute(f"SELECT uid, bool_col FROM {simple_table} ORDER BY uid ASC") # noqa: S608 + + rows = cursor.fetchall() + assert len(rows) == 1 + assert rows[0][0] == "SimpleRecord1" + assert rows[0][1] is False + + +@pytest.fixture(scope="function") +def many_to_many_table(simple_table, write_log_state): + cursor = django_connection.cursor() + + cursor.execute(f""" +CREATE TABLE many_to_many_table ( + id integer NOT NULL PRIMARY KEY AUTOINCREMENT, + simple_a_id integer, + simple_b_id integer, + FOREIGN KEY (simple_a_id) REFERENCES {simple_table}(id), + FOREIGN KEY (simple_b_id) REFERENCES {simple_table}(id) +) +""") + + yield "many_to_many_table" + + cursor.execute("DROP TABLE IF EXISTS many_to_many_table") + + +def test_replayer_many_to_many( + simple_table, many_to_many_table, write_log_lock, write_log_state +): + db_metadata = FakeMetadataWrapper() + db_metadata.set_db_tables({simple_table, many_to_many_table}) + db_metadata.set_many_to_many_db_tables({many_to_many_table}) + + cursor = django_connection.cursor() + + update_write_log_table_state({simple_table, many_to_many_table}) + update_migration_state() + + 
current_migration_state = WriteLogMigrationState.objects.order_by("-id").first() + + simple_table_state = WriteLogTableState.objects.get(table_name=simple_table) + many_to_many_table_state = WriteLogTableState.objects.get( + table_name=many_to_many_table + ) + + replayer = WriteLogReplayer(db_metadata=db_metadata, cursor=cursor) + + input_write_log = [ + WriteLog( + migration_state=current_migration_state, + table=simple_table_state, + uid="Hist001", + record_uid=["SimpleRecord1"], + record_data={ + "bool_col": True, + "text_col": "hello world", + "timestamp_col": "2025-05-23T02:03:16.913425+00:00", + "date_col": "2025-05-23", + "float_col": 8.675309, + FOREIGN_KEYS_LIST_COLUMN_NAME: [], + }, + event_type=WriteLogEventTypes.INSERT.value, + created_at=datetime.datetime( + 2025, 5, 23, 12, 34, 56, tzinfo=datetime.timezone.utc + ), + ), + WriteLog( + migration_state=current_migration_state, + table=simple_table_state, + uid="Hist002", + record_uid=["SimpleRecord2"], + record_data={ + "bool_col": False, + "text_col": "Hallo, Welt!", + "timestamp_col": "2025-05-23T02:03:20.392310+00:00", + "date_col": "2025-05-23", + "float_col": 1.8007777777, + FOREIGN_KEYS_LIST_COLUMN_NAME: [], + }, + event_type=WriteLogEventTypes.INSERT.value, + created_at=datetime.datetime( + 2025, 5, 23, 12, 41, 00, tzinfo=datetime.timezone.utc + ), + ), + WriteLog( + migration_state=current_migration_state, + table=many_to_many_table_state, + uid="ManyToMany1", + record_uid=[ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord1"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord2"}], + ], + record_data={ + FOREIGN_KEYS_LIST_COLUMN_NAME: [ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord1"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord2"}], + ], + }, + event_type=WriteLogEventTypes.INSERT.value, + created_at=datetime.datetime( + 2025, 5, 23, 12, 55, 00, tzinfo=datetime.timezone.utc + ), + ), + WriteLog( + migration_state=current_migration_state, + table=many_to_many_table_state, + uid="ManyToMany2", + record_uid=[ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord2"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord2"}], + ], + record_data={ + FOREIGN_KEYS_LIST_COLUMN_NAME: [ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord2"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord2"}], + ], + }, + event_type=WriteLogEventTypes.INSERT.value, + created_at=datetime.datetime( + 2025, 5, 23, 12, 56, 00, tzinfo=datetime.timezone.utc + ), + ), + WriteLog( + migration_state=current_migration_state, + table=many_to_many_table_state, + uid="ManyToMany3", + record_uid=[ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord2"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord2"}], + ], + record_data={ + FOREIGN_KEYS_LIST_COLUMN_NAME: [ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord2"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord1"}], + ], + }, + event_type=WriteLogEventTypes.UPDATE.value, + created_at=datetime.datetime( + 2025, 5, 23, 12, 56, 30, tzinfo=datetime.timezone.utc + ), + ), + WriteLog( + migration_state=current_migration_state, + table=many_to_many_table_state, + uid="ManyToMany4", + record_uid=[ + [simple_table_state.id, ["simple_a_id"], {"uid": "SimpleRecord2"}], + [simple_table_state.id, ["simple_b_id"], {"uid": "SimpleRecord1"}], + ], + record_data=None, + event_type=WriteLogEventTypes.DELETE.value, + 
created_at=datetime.datetime( + 2025, 5, 23, 12, 56, 45, tzinfo=datetime.timezone.utc + ), + ), + ] + + for write_log_entry in input_write_log: + replayer.replay(write_log_entry) + write_log_entry.save() + + write_log = WriteLog.objects.all().order_by("id") + + assert len(write_log) == 6 + assert list(write_log) == input_write_log + + cursor.execute(f"SELECT id, uid FROM {simple_table}") # noqa: S608 + + rows = cursor.fetchall() + + assert [r[1] for r in rows] == ["SimpleRecord1", "SimpleRecord2"] + + simple_uid_to_id = {r[1]: r[0] for r in rows} + + cursor.execute(f"SELECT id, simple_a_id, simple_b_id FROM {many_to_many_table}") # noqa: S608 + + rows = cursor.fetchall() + + assert len(rows) == 1 + + assert rows[0][1:] == ( + simple_uid_to_id["SimpleRecord1"], + simple_uid_to_id["SimpleRecord2"], + )