Skip to content

Commit

Permalink
Added "WHERE SELF" shortcut in SQL methods
Browse files Browse the repository at this point in the history
  • Loading branch information
maximebf committed May 21, 2024
1 parent af623bd commit 51b9d0f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 26 deletions.
6 changes: 3 additions & 3 deletions sqlorm/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def primary_key(self):

def primary_key_condition(self, pk, table_alias=None, prefix=None) -> SQL:
"""Returns the SQL condition to query a single row from this mapped table matching the primary key"""
if isinstance(pk, self.object_class):
pk = self.get_primary_key(pk)
cols = self.primary_key
if isinstance(cols, list):
if not isinstance(pk, (list, tuple)):
Expand Down Expand Up @@ -282,9 +284,7 @@ def update(self, obj, **dehydrate_kwargs):
values = self.dehydrate(obj, with_primary_key=False, **dehydrate_kwargs)
if not values:
return
return SQL.update(self.table, values).where(
self.primary_key_condition(self.get_primary_key(obj))
)
return SQL.update(self.table, values).where(self.primary_key_condition(obj))

def delete(self, obj):
pk = self.get_primary_key(obj)
Expand Down
50 changes: 27 additions & 23 deletions sqlorm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __new__(cls, name, bases, dct):
cls.process_meta_inheritance(model_class)
return cls.post_process_model_class(model_class)

@staticmethod
def pre_process_model_class_dict(name, bases, dct):
@classmethod
def pre_process_model_class_dict(cls, name, bases, dct):
model_registry = {}
for base in bases:
if issubclass(base, BaseModel):
Expand Down Expand Up @@ -66,30 +66,34 @@ def pre_process_model_class_dict(name, bases, dct):
attr = attr.__wrapped__
if callable(attr):
if is_sqlfunc(attr):
doc = inspect.getdoc(attr)
if doc.upper().startswith("SELECT WHERE"):
doc = doc[7:]
if doc.upper().startswith("WHERE"):
if wrapper is classmethod:
attr.__doc__ = "{cls.select_from()} " + doc
elif not wrapper:
attr.__doc__ = "{self.select_from()} " + doc
if doc.upper().startswith("INSERT INTO ("):
doc = f"INSERT INTO {dct['table']} {doc[12:]}"
if doc.upper().startswith("UPDATE SET"):
doc = f"UPDATE {dct['table']} {doc[7:]}"
if doc.upper().startswith("DELETE WHERE"):
doc = f"DELETE FROM {dct['table']} {doc[7:]}"
if not getattr(attr, "query_decorator", None) and ".select_from(" in doc:
# because the statement does not start with SELECT, it would default to execute when using .select_from()
attr = fetchall(attr)
# the model registry is passed as template locals to sql func methods
# so model classes are available in the evaluation scope of SQLTemplate
method = sqlfunc(attr, is_method=True, template_locals=model_registry)
dct[attr_name] = wrapper(method) if wrapper else method
dct[attr_name] = cls.make_sqlfunc_from_method(attr, wrapper, model_registry)

dct["__mapper__"] = mapped_attrs
return dct

@staticmethod
def make_sqlfunc_from_method(func, decorator, model_registry):
doc = inspect.getdoc(func)
accessor = "cls" if decorator is classmethod else "self"
if doc.upper().startswith("SELECT WHERE"):
doc = doc[7:]
if doc.upper().startswith("WHERE"):
func.__doc__ = "{%s.select_from()} %s" % (accessor, doc)
if doc.upper().startswith("INSERT INTO ("):
doc = "INSERT INTO {%s.table} %s" % (accessor, doc[12:])
if doc.upper().startswith("UPDATE SET"):
doc = "UPDATE {%s.table} %s" % (accessor, doc[7:])
if doc.upper().startswith("DELETE WHERE"):
doc = "DELETE FROM {%s.table} %s" % (accessor, doc[7:])
if "WHERE SELF" in doc.upper():
doc = doc.replace("WHERE SELF", "WHERE {self.__mapper__.primary_key_condition(self)}")
if not getattr(func, "query_decorator", None) and ".select_from(" in doc:
# because the statement does not start with SELECT, it would default to execute when using .select_from()
func = fetchall(func)
# the model registry is passed as template locals to sql func methods
# so model classes are available in the evaluation scope of SQLTemplate
method = sqlfunc(func, is_method=True, template_locals=model_registry)
return decorator(method) if decorator else method

@staticmethod
def post_process_model_class(cls):
Expand Down

0 comments on commit 51b9d0f

Please sign in to comment.