55from .engine import Engine , ensure_transaction , _signals , _signal_rv
66from .sqlfunc import is_sqlfunc , sqlfunc , fetchall , fetchone , execute , update
77from .resultset import ResultSet , CompositeResultSet
8- from .types import SQLType
8+ from .types import SQLType , Integer
99from .mapper import (
1010 Mapper ,
1111 MappedColumnMixin ,
1717
1818class ModelMetaclass (abc .ABCMeta ):
1919 def __new__ (cls , name , bases , dct ):
20- if not bases or abc .ABC in bases :
20+ if len ( bases ) == 1 and bases [ 0 ] is abc .ABC : # BaseModel
2121 return super ().__new__ (cls , name , bases , dct )
22- dct = cls .pre_process_model_class_dict (name , bases , dct )
22+
23+ model_registry = cls .find_model_registry (bases )
24+ mapped_attrs = cls .process_mapped_attributes (dct )
25+ cls .process_sql_methods (dct , model_registry )
2326 model_class = super ().__new__ (cls , name , bases , dct )
2427 cls .process_meta_inheritance (model_class )
25- return cls .post_process_model_class (model_class )
28+ if abc .ABC not in bases :
29+ cls .create_mapper (model_class , mapped_attrs )
30+ model_class .__model_registry__ .register (model_class )
31+ return model_class
2632
27- @classmethod
28- def pre_process_model_class_dict (cls , name , bases , dct ):
29- model_registry = {}
33+ def find_model_registry (bases ):
3034 for base in bases :
31- if issubclass (base , BaseModel ):
32- model_registry = base .__model_registry__
33- break
34-
35- dct ["table" ] = SQL .Id (dct .get ("__table__" , dct .get ("table" , name .lower ())))
35+ if hasattr (base , "__model_registry__" ):
36+ return base .__model_registry__
37+ return ModelRegistry ()
3638
39+ @staticmethod
40+ def process_mapped_attributes (dct ):
3741 mapped_attrs = {}
3842 for name , annotation in dct .get ("__annotations__" , {}).items ():
3943 primary_key = False
@@ -45,11 +49,11 @@ def pre_process_model_class_dict(cls, name, bases, dct):
4549 dct [name ] = mapped_attrs [name ] = Column (name , annotation , primary_key = primary_key )
4650 elif isinstance (dct [name ], Column ):
4751 mapped_attrs [name ] = dct [name ]
48- dct [name ].type = SQLType .from_pytype (annotation )
52+ if dct [name ].type is None :
53+ dct [name ].type = SQLType .from_pytype (annotation )
4954 elif isinstance (dct [name ], Relationship ):
5055 # add now to keep the declaration order
5156 mapped_attrs [name ] = dct [name ]
52-
5357 for attr_name , attr in dct .items ():
5458 if isinstance (attr , Column ) and not attr .name :
5559 # in the case of models, we allow column object to be initialized without names
@@ -58,27 +62,28 @@ def pre_process_model_class_dict(cls, name, bases, dct):
5862 if isinstance (attr , (Column , Relationship )) and attr_name not in mapped_attrs :
5963 # not annotated attributes
6064 mapped_attrs [attr_name ] = attr
61- continue
62-
65+ return mapped_attrs
66+
67+ @classmethod
68+ def process_sql_methods (cls , dct , model_registry = None ):
69+ for attr_name , attr in dct .items ():
6370 wrapper = type (attr ) if isinstance (attr , (staticmethod , classmethod )) else False
6471 if wrapper :
6572 # the only way to replace the wrapped function for a class/static method is before the class initialization.
6673 attr = attr .__wrapped__
67- if callable (attr ):
68- if is_sqlfunc (attr ):
69- dct [attr_name ] = cls .make_sqlfunc_from_method (attr , wrapper , model_registry )
70-
71- dct ["__mapper__" ] = mapped_attrs
72- return dct
74+ if callable (attr ) and is_sqlfunc (attr ):
75+ # the model registry is passed as template locals to sql func methods
76+ # so model classes are available in the evaluation scope of SQLTemplate
77+ dct [attr_name ] = cls .make_sqlfunc_from_method (attr , wrapper , model_registry )
7378
7479 @staticmethod
75- def make_sqlfunc_from_method (func , decorator , model_registry ):
80+ def make_sqlfunc_from_method (func , decorator , template_locals = None ):
7681 doc = inspect .getdoc (func )
7782 accessor = "cls" if decorator is classmethod else "self"
7883 if doc .upper ().startswith ("SELECT WHERE" ):
7984 doc = doc [7 :]
8085 if doc .upper ().startswith ("WHERE" ):
81- func . __doc__ = "{%s.select_from()} %s" % (accessor , doc )
86+ doc = "{%s.select_from()} %s" % (accessor , doc )
8287 if doc .upper ().startswith ("INSERT INTO (" ):
8388 doc = "INSERT INTO {%s.table} %s" % (accessor , doc [12 :])
8489 if doc .upper ().startswith ("UPDATE SET" ):
@@ -87,21 +92,26 @@ def make_sqlfunc_from_method(func, decorator, model_registry):
8792 doc = "DELETE FROM {%s.table} %s" % (accessor , doc [7 :])
8893 if "WHERE SELF" in doc .upper ():
8994 doc = doc .replace ("WHERE SELF" , "WHERE {self.__mapper__.primary_key_condition(self)}" )
95+ func .__doc__ = doc
9096 if not getattr (func , "query_decorator" , None ) and ".select_from(" in doc :
9197 # because the statement does not start with SELECT, it would default to execute when using .select_from()
9298 func = fetchall (func )
93- # the model registry is passed as template locals to sql func methods
94- # so model classes are available in the evaluation scope of SQLTemplate
95- method = sqlfunc (func , is_method = True , template_locals = model_registry )
99+ method = sqlfunc (func , is_method = True , template_locals = template_locals )
96100 return decorator (method ) if decorator else method
97101
98102 @staticmethod
99- def post_process_model_class (cls ):
100- mapped_attrs = cls . __mapper__
103+ def create_mapper (cls , mapped_attrs = None ):
104+ cls . table = SQL . Id ( getattr ( cls , "__table__" , getattr ( cls , "table" , cls . __name__ . lower ())))
101105 cls .__mapper__ = ModelMapper (
102106 cls , cls .table .name , allow_unknown_columns = cls .Meta .allow_unknown_columns
103107 )
104- cls .__mapper__ .map (mapped_attrs )
108+
109+ for attr_name in dir (cls ):
110+ if isinstance (getattr (cls , attr_name ), (Column , Relationship )) and attr_name not in mapped_attrs :
111+ cls .__mapper__ .map (attr_name , getattr (cls , attr_name ))
112+ if mapped_attrs :
113+ cls .__mapper__ .map (mapped_attrs )
114+
105115 cls .c = cls .__mapper__ .columns # handy shortcut
106116
107117 auto_primary_key = cls .Meta .auto_primary_key
@@ -110,14 +120,11 @@ def post_process_model_class(cls):
110120 # we force the usage of SELECT * as we auto add a primary key without any other mapped columns
111121 # without doing this, only the primary key would be selected
112122 cls .__mapper__ .force_select_wildcard = True
113- cls .__mapper__ .map (auto_primary_key , Column (auto_primary_key , primary_key = True ))
114-
115- cls .__model_registry__ .register (cls )
116- return cls
123+ cls .__mapper__ .map (auto_primary_key , Column (auto_primary_key , type = cls .Meta .auto_primary_key_type , primary_key = True ))
117124
118125 @staticmethod
119126 def process_meta_inheritance (cls ):
120- if getattr (cls .Meta , "__inherit__" , True ):
127+ if hasattr ( cls , "Meta" ) and getattr (cls .Meta , "__inherit__" , True ):
121128 bases_meta = ModelMetaclass .aggregate_bases_meta_attrs (cls )
122129 for key , value in bases_meta .items ():
123130 if not hasattr (cls .Meta , key ):
@@ -130,7 +137,7 @@ def process_meta_inheritance(cls):
130137 def aggregate_bases_meta_attrs (cls ):
131138 meta = {}
132139 for base in cls .__bases__ :
133- if issubclass (base , BaseModel ):
140+ if hasattr (base , "Meta" ):
134141 if getattr (base .Meta , "__inherit__" , True ):
135142 meta .update (ModelMetaclass .aggregate_bases_meta_attrs (base ))
136143 meta .update (
@@ -331,6 +338,7 @@ class Meta:
331338 auto_primary_key : t .Optional [str ] = (
332339 "id" # auto generate a primary key with this name if no primary key are declared
333340 )
341+ auto_primary_key_type : SQLType = Integer
334342 allow_unknown_columns : bool = True # hydrate() will set attributes for unknown columns
335343
336344 @classmethod
0 commit comments