9
9
import typing
10
10
11
11
from sqlalchemy import create_engine
12
- from sqlalchemy import orm
12
+ from sqlalchemy import orm , inspect , MetaData
13
+ from sqlalchemy .exc import SQLAlchemyError
13
14
14
15
from recordlinker import models
15
16
from recordlinker .config import settings
16
17
17
-
18
- def create_sessionmaker (init_tables : bool = True ) -> orm .sessionmaker :
18
+ #TODO: figure out wher this variable `verify_tables` is gonna come from
19
+ def create_sessionmaker (init_tables : bool = True , verify_tables : bool = True ) -> orm .sessionmaker :
19
20
"""
20
21
Create a new sessionmaker for the database connection.
21
22
"""
@@ -25,10 +26,50 @@ def create_sessionmaker(init_tables: bool = True) -> orm.sessionmaker:
25
26
if settings .connection_pool_max_overflow is not None :
26
27
kwargs ["max_overflow" ] = settings .connection_pool_max_overflow
27
28
engine = create_engine (settings .db_uri , ** kwargs )
29
+
28
30
if init_tables :
29
31
models .Base .metadata .create_all (engine )
32
+ if verify_tables :
33
+ verify_tables_match_orm (engine )
34
+
30
35
return orm .sessionmaker (bind = engine )
31
36
37
+ def verify_tables_match_orm (engine ):
38
+ """
39
+ Verify that database tables match ORM definitions.
40
+ """
41
+ inspector = inspect (engine )
42
+ orm_metadata = models .Base .metadata # Use ORM schema
43
+
44
+ for table_name , orm_table in orm_metadata .tables .items ():
45
+ # Check if the table exists in the database
46
+ if not inspector .has_table (table_name ):
47
+ raise SQLAlchemyError (
48
+ f"Table '{ table_name } ' is missing in the database."
49
+ )
50
+
51
+ # Retrieve database columns
52
+ db_columns = inspector .get_columns (table_name )
53
+ db_column_details = {col ['name' ]: col for col in db_columns }
54
+
55
+ # Compare ORM and database columns
56
+ for orm_column in orm_table .columns :
57
+ column_name = orm_column .name
58
+
59
+ if column_name not in db_column_details :
60
+ raise SQLAlchemyError (
61
+ f"Column '{ column_name } ' is missing in the database for table '{ table_name } '."
62
+ )
63
+
64
+ db_col_type = db_column_details [column_name ]['type' ]
65
+ orm_col_type = orm_column .type
66
+
67
+ if type (db_col_type ) != type (orm_col_type ):
68
+ raise SQLAlchemyError (
69
+ f"Type mismatch for column '{ column_name } ' in table '{ table_name } ': "
70
+ f"DB type is { db_col_type } , ORM type is { orm_col_type } ."
71
+ )
72
+
32
73
33
74
def get_session () -> typing .Iterator [orm .Session ]:
34
75
"""
0 commit comments