Skip to content

Commit fbfa648

Browse files
feat: add initial check function
1 parent a58652c commit fbfa648

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

src/recordlinker/database/__init__.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
import typing
1010

1111
from sqlalchemy import create_engine
12-
from sqlalchemy import orm
12+
from sqlalchemy import orm, inspect, MetaData
13+
from sqlalchemy.exc import SQLAlchemyError
1314

1415
from recordlinker import models
1516
from recordlinker.config import settings
1617

17-
18-
def create_sessionmaker(init_tables: bool = True) -> orm.sessionmaker:
18+
#TODO: figure out wher this variable `verify_tables` is gonna come from
19+
def create_sessionmaker(init_tables: bool = True, verify_tables: bool = True) -> orm.sessionmaker:
1920
"""
2021
Create a new sessionmaker for the database connection.
2122
"""
@@ -25,10 +26,50 @@ def create_sessionmaker(init_tables: bool = True) -> orm.sessionmaker:
2526
if settings.connection_pool_max_overflow is not None:
2627
kwargs["max_overflow"] = settings.connection_pool_max_overflow
2728
engine = create_engine(settings.db_uri, **kwargs)
29+
2830
if init_tables:
2931
models.Base.metadata.create_all(engine)
32+
if verify_tables:
33+
verify_tables_match_orm(engine)
34+
3035
return orm.sessionmaker(bind=engine)
3136

37+
def verify_tables_match_orm(engine):
38+
"""
39+
Verify that database tables match ORM definitions.
40+
"""
41+
inspector = inspect(engine)
42+
orm_metadata = models.Base.metadata # Use ORM schema
43+
44+
for table_name, orm_table in orm_metadata.tables.items():
45+
# Check if the table exists in the database
46+
if not inspector.has_table(table_name):
47+
raise SQLAlchemyError(
48+
f"Table '{table_name}' is missing in the database."
49+
)
50+
51+
# Retrieve database columns
52+
db_columns = inspector.get_columns(table_name)
53+
db_column_details = {col['name']: col for col in db_columns}
54+
55+
# Compare ORM and database columns
56+
for orm_column in orm_table.columns:
57+
column_name = orm_column.name
58+
59+
if column_name not in db_column_details:
60+
raise SQLAlchemyError(
61+
f"Column '{column_name}' is missing in the database for table '{table_name}'."
62+
)
63+
64+
db_col_type = db_column_details[column_name]['type']
65+
orm_col_type = orm_column.type
66+
67+
if type(db_col_type) != type(orm_col_type):
68+
raise SQLAlchemyError(
69+
f"Type mismatch for column '{column_name}' in table '{table_name}': "
70+
f"DB type is {db_col_type}, ORM type is {orm_col_type}."
71+
)
72+
3273

3374
def get_session() -> typing.Iterator[orm.Session]:
3475
"""

test.db

44 KB
Binary file not shown.

0 commit comments

Comments
 (0)