Skip to content

Commit

Permalink
Fixed bug with inheritance and various other issues
Browse files Browse the repository at this point in the history
  • Loading branch information
maximebf committed Jun 4, 2024
1 parent 0dad52d commit 60fc72e
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 46 deletions.
80 changes: 44 additions & 36 deletions sqlorm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .engine import Engine, ensure_transaction, _signals, _signal_rv
from .sqlfunc import is_sqlfunc, sqlfunc, fetchall, fetchone, execute, update
from .resultset import ResultSet, CompositeResultSet
from .types import SQLType
from .types import SQLType, Integer
from .mapper import (
Mapper,
MappedColumnMixin,
Expand All @@ -17,23 +17,27 @@

class ModelMetaclass(abc.ABCMeta):
def __new__(cls, name, bases, dct):
if not bases or abc.ABC in bases:
if len(bases) == 1 and bases[0] is abc.ABC: # BaseModel
return super().__new__(cls, name, bases, dct)
dct = cls.pre_process_model_class_dict(name, bases, dct)

model_registry = cls.find_model_registry(bases)
mapped_attrs = cls.process_mapped_attributes(dct)
cls.process_sql_methods(dct, model_registry)
model_class = super().__new__(cls, name, bases, dct)
cls.process_meta_inheritance(model_class)
return cls.post_process_model_class(model_class)
if abc.ABC not in bases:
cls.create_mapper(model_class, mapped_attrs)
model_class.__model_registry__.register(model_class)
return model_class

@classmethod
def pre_process_model_class_dict(cls, name, bases, dct):
model_registry = {}
def find_model_registry(bases):
for base in bases:
if issubclass(base, BaseModel):
model_registry = base.__model_registry__
break

dct["table"] = SQL.Id(dct.get("__table__", dct.get("table", name.lower())))
if hasattr(base, "__model_registry__"):
return base.__model_registry__
return ModelRegistry()

@staticmethod
def process_mapped_attributes(dct):
mapped_attrs = {}
for name, annotation in dct.get("__annotations__", {}).items():
primary_key = False
Expand All @@ -45,11 +49,11 @@ def pre_process_model_class_dict(cls, name, bases, dct):
dct[name] = mapped_attrs[name] = Column(name, annotation, primary_key=primary_key)
elif isinstance(dct[name], Column):
mapped_attrs[name] = dct[name]
dct[name].type = SQLType.from_pytype(annotation)
if dct[name].type is None:
dct[name].type = SQLType.from_pytype(annotation)
elif isinstance(dct[name], Relationship):
# add now to keep the declaration order
mapped_attrs[name] = dct[name]

for attr_name, attr in dct.items():
if isinstance(attr, Column) and not attr.name:
# in the case of models, we allow column object to be initialized without names
Expand All @@ -58,27 +62,28 @@ def pre_process_model_class_dict(cls, name, bases, dct):
if isinstance(attr, (Column, Relationship)) and attr_name not in mapped_attrs:
# not annotated attributes
mapped_attrs[attr_name] = attr
continue

return mapped_attrs

@classmethod
def process_sql_methods(cls, dct, model_registry=None):
for attr_name, attr in dct.items():
wrapper = type(attr) if isinstance(attr, (staticmethod, classmethod)) else False
if wrapper:
# the only way to replace the wrapped function for a class/static method is before the class initialization.
attr = attr.__wrapped__
if callable(attr):
if is_sqlfunc(attr):
dct[attr_name] = cls.make_sqlfunc_from_method(attr, wrapper, model_registry)

dct["__mapper__"] = mapped_attrs
return dct
if callable(attr) and is_sqlfunc(attr):
# the model registry is passed as template locals to sql func methods
# so model classes are available in the evaluation scope of SQLTemplate
dct[attr_name] = cls.make_sqlfunc_from_method(attr, wrapper, model_registry)

@staticmethod
def make_sqlfunc_from_method(func, decorator, model_registry):
def make_sqlfunc_from_method(func, decorator, template_locals=None):
doc = inspect.getdoc(func)
accessor = "cls" if decorator is classmethod else "self"
if doc.upper().startswith("SELECT WHERE"):
doc = doc[7:]
if doc.upper().startswith("WHERE"):
func.__doc__ = "{%s.select_from()} %s" % (accessor, doc)
doc = "{%s.select_from()} %s" % (accessor, doc)
if doc.upper().startswith("INSERT INTO ("):
doc = "INSERT INTO {%s.table} %s" % (accessor, doc[12:])
if doc.upper().startswith("UPDATE SET"):
Expand All @@ -87,21 +92,26 @@ def make_sqlfunc_from_method(func, decorator, model_registry):
doc = "DELETE FROM {%s.table} %s" % (accessor, doc[7:])
if "WHERE SELF" in doc.upper():
doc = doc.replace("WHERE SELF", "WHERE {self.__mapper__.primary_key_condition(self)}")
func.__doc__ = doc
if not getattr(func, "query_decorator", None) and ".select_from(" in doc:
# because the statement does not start with SELECT, it would default to execute when using .select_from()
func = fetchall(func)
# the model registry is passed as template locals to sql func methods
# so model classes are available in the evaluation scope of SQLTemplate
method = sqlfunc(func, is_method=True, template_locals=model_registry)
method = sqlfunc(func, is_method=True, template_locals=template_locals)
return decorator(method) if decorator else method

@staticmethod
def post_process_model_class(cls):
mapped_attrs = cls.__mapper__
def create_mapper(cls, mapped_attrs=None):
cls.table = SQL.Id(getattr(cls, "__table__", getattr(cls, "table", cls.__name__.lower())))
cls.__mapper__ = ModelMapper(
cls, cls.table.name, allow_unknown_columns=cls.Meta.allow_unknown_columns
)
cls.__mapper__.map(mapped_attrs)

for attr_name in dir(cls):
if isinstance(getattr(cls, attr_name), (Column, Relationship)) and attr_name not in mapped_attrs:
cls.__mapper__.map(attr_name, getattr(cls, attr_name))
if mapped_attrs:
cls.__mapper__.map(mapped_attrs)

cls.c = cls.__mapper__.columns # handy shortcut

auto_primary_key = cls.Meta.auto_primary_key
Expand All @@ -110,14 +120,11 @@ def post_process_model_class(cls):
# we force the usage of SELECT * as we auto add a primary key without any other mapped columns
# without doing this, only the primary key would be selected
cls.__mapper__.force_select_wildcard = True
cls.__mapper__.map(auto_primary_key, Column(auto_primary_key, primary_key=True))

cls.__model_registry__.register(cls)
return cls
cls.__mapper__.map(auto_primary_key, Column(auto_primary_key, type=cls.Meta.auto_primary_key_type, primary_key=True))

@staticmethod
def process_meta_inheritance(cls):
if getattr(cls.Meta, "__inherit__", True):
if hasattr(cls, "Meta") and getattr(cls.Meta, "__inherit__", True):
bases_meta = ModelMetaclass.aggregate_bases_meta_attrs(cls)
for key, value in bases_meta.items():
if not hasattr(cls.Meta, key):
Expand All @@ -130,7 +137,7 @@ def process_meta_inheritance(cls):
def aggregate_bases_meta_attrs(cls):
meta = {}
for base in cls.__bases__:
if issubclass(base, BaseModel):
if hasattr(base, "Meta"):
if getattr(base.Meta, "__inherit__", True):
meta.update(ModelMetaclass.aggregate_bases_meta_attrs(base))
meta.update(
Expand Down Expand Up @@ -331,6 +338,7 @@ class Meta:
auto_primary_key: t.Optional[str] = (
"id" # auto generate a primary key with this name if no primary key are declared
)
auto_primary_key_type: SQLType = Integer
allow_unknown_columns: bool = True # hydrate() will set attributes for unknown columns

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion sqlorm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_all(model_registry=None, engine=None, check_missing=False, logger=Non
missing = False
with ensure_transaction(engine) as tx:
try:
tx.execute(model.find_one())
model.find_one()
except Exception:
missing = True
if missing:
Expand Down
7 changes: 5 additions & 2 deletions sqlorm/sql_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ def __init__(self, template, code):
self.template = template
self.code = code

def _render(self, params):
def eval(self):
return eval(self.code, self.template.eval_globals, self.template.locals)

def _render(self, params):
return SQL(self.eval())._render(params)


class ParametrizedEvalBlock(EvalBlock):
def _render(self, params):
return params.add(super()._render(params))
return params.add(self.eval())


class SQLTemplateError(Exception):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_pool():
assert conn in e.pool

with e as tx:
tx.session.connect()
assert not e.pool
assert len(e.active_conns) == 2
assert tx.session.conn is conn
Expand Down Expand Up @@ -95,7 +96,7 @@ def test_session():
e = Engine.from_uri("sqlite://:memory:")
with e.session() as sess:
assert isinstance(sess, Session)
assert sess.conn
assert not sess.conn
assert sess.engine is e
assert not sess.virtual_tx

Expand All @@ -104,4 +105,5 @@ def test_session():
assert tx.session is sess
assert not tx.virtual

sess.connect()
assert sess.conn
32 changes: 26 additions & 6 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sqlorm import Model, SQL, Relationship, is_dirty, PrimaryKey
from sqlorm import Model, SQL, Relationship, is_dirty, PrimaryKey, Column
from sqlorm.mapper import Mapper
from models import *
import abc


def test_model_registry():
Expand Down Expand Up @@ -68,6 +69,25 @@ def test_mapper():
assert User.tasks in User.__mapper__.relationships


def test_inheritance():
class A(Model, abc.ABC):
col1: str
col2 = Column(type=int)
col3 = Column(type=bool)

assert not hasattr(A, "__mapper__")
assert isinstance(A.col1, Column)
assert A.col3.type.sql_type == "boolean"

class B(A):
col3: str
col4 = Column(type=int)

assert B.__mapper__
assert B.__mapper__.columns.names == ["col1", "col2", "col3", "col4", "id"]
assert B.col3.type.sql_type == "text"


def test_find_all(engine):
listener_called = False

Expand Down Expand Up @@ -194,10 +214,10 @@ def test_update(cls):
def test_delete(cls):
"DELETE WHERE col1 = 'foo'"

assert TestModel.find_all.sql(TestModel) == "SELECT test.id , test.col1 FROM test WHERE col1 = 'foo'"
assert TestModel.test_insert.sql(TestModel) == "INSERT INTO test (col1) VALUES ('bar')"
assert TestModel.test_insert.sql(TestModel) == "UPDATE test SET col1 = 'bar'"
assert TestModel.test_insert.sql(TestModel) == "DELETE FROM test WHERE col1 = 'foo'"
assert str(TestModel.find_all.sql(TestModel)) == "SELECT test.id , test.col1 FROM test WHERE col1 = 'foo'"
assert str(TestModel.test_insert.sql(TestModel)) == "INSERT INTO test (col1) VALUES ('bar')"
assert str(TestModel.test_update.sql(TestModel)) == "UPDATE test SET col1 = 'bar'"
assert str(TestModel.test_delete.sql(TestModel)) == "DELETE FROM test WHERE col1 = 'foo'"


def test_dirty_tracking(engine):
Expand Down Expand Up @@ -351,4 +371,4 @@ def on_after_delete(sender, obj):
assert listener_called == 2

user = User.get(4)
assert not user
assert not user

0 comments on commit 60fc72e

Please sign in to comment.