diff --git a/reladiff/__init__.py b/reladiff/__init__.py index 7028133..f5106b9 100644 --- a/reladiff/__init__.py +++ b/reladiff/__init__.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple, Iterable, Optional, Union +from typing import Tuple, Iterable, Optional, Union from sqeleton.abcs import DbTime, DbPath @@ -15,7 +15,7 @@ def connect_to_table( db_info: Union[str, dict], table_name: Union[DbPath, str], - key_columns: Union[str, Sequence[str]] = ("id",), + key_columns: Union[Iterable[str], str] = ("id",), thread_count: Optional[int] = 1, **kwargs, ) -> TableSegment: @@ -46,7 +46,7 @@ def diff_tables( table2: TableSegment, *, # Name of the key column, which uniquely identifies each row (usually id) - key_columns: Sequence[str] = None, + key_columns: Optional[Union[Iterable[str], str]] = None, # Name of updated column, which signals that rows changed (usually updated_at or last_update) update_column: str = None, # Extra columns to compare @@ -89,7 +89,7 @@ def diff_tables( """Finds the diff between table1 and table2. Parameters: - key_columns (Tuple[str, ...]): Name of the key column, which uniquely identifies each row (usually id) + key_columns (Union[Iterable[str], str], optional): Name of the key column, which uniquely identifies each row (usually id) update_column (str, optional): Name of updated column, which signals that rows changed. Usually updated_at or last_update. Used by `min_update` and `max_update`. extra_columns (Tuple[str, ...], optional): Extra columns to compare @@ -145,6 +145,8 @@ def diff_tables( """ if isinstance(key_columns, str): key_columns = (key_columns,) + else: + key_columns = tuple(key_columns or ()) tables = [table1, table2] override_attrs = {