Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Parametrized query support for SELECTs and EXPLAINs #1725

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pypika-tortoise = "^0.1.6"
pypika-tortoise = { path = "vendor/pypika-tortoise", develop = true}
iso8601 = "^1.0.2"
aiosqlite = ">=0.16.0, <0.18.0"
pytz = "*"
Expand Down Expand Up @@ -187,6 +187,7 @@ filterwarnings = [
'ignore:`pk` is deprecated:DeprecationWarning',
'ignore:`index` is deprecated:DeprecationWarning',
]
addopts = "--ignore=vendor"

[tool.coverage.run]
branch = true
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,14 @@ async def test_index_access(self):

async def test_index_badval(self):
with self.assertRaises(ObjectDoesNotExistError) as cm:
await self.cls[100000]
await self.cls[32767]
the_exception = cm.exception
# For compatibility reasons this should be an instance of KeyError
self.assertIsInstance(the_exception, KeyError)
self.assertIs(the_exception.model, self.cls)
self.assertEqual(the_exception.pk_name, "id")
self.assertEqual(the_exception.pk_val, 100000)
self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=100000")
self.assertEqual(the_exception.pk_val, 32767)
self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=32767")

async def test_index_badtype(self):
with self.assertRaises(ObjectDoesNotExistError) as cm:
Expand Down
30 changes: 20 additions & 10 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pypika import JoinType, Parameter, Query, Table
from pypika.queries import QueryBuilder
from pypika.terms import ArithmeticExpression, Function
from pypika.terms import ArithmeticExpression, Function, ListParameter

from tortoise.exceptions import OperationalError
from tortoise.expressions import F, RawSQL
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
table = self.model._meta.basetable
self.delete_query = str(
self.model._meta.basequery.where(
table[self.model._meta.db_pk_column] == self.parameter(0)
table[self.model._meta.db_pk_column] == self.insert_parameter(0)
).delete()
)
self.update_cache: Dict[str, str] = {}
Expand All @@ -122,13 +122,17 @@ def __init__(
) = EXECUTOR_CACHE[key]

async def execute_explain(self, query: Query) -> Any:
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql()))
return (await self.db.execute_query(sql))[1]
param = self.parameter()
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql(parameter=param)))
return (await self.db.execute_query(sql, param.get_parameters()))[1]

async def execute_select(
self, query: Union[Query, RawSQL], custom_fields: Optional[list] = None
) -> list:
_, raw_results = await self.db.execute_query(query.get_sql())
param = self.parameter()
_, raw_results = await self.db.execute_query(
query.get_sql(parameter=param), param.get_parameters()
)
instance_list = []
for row in raw_results:
if self.select_related_idx:
Expand Down Expand Up @@ -186,7 +190,7 @@ def _prepare_insert_statement(
query = (
self.db.query_class.into(self.model._meta.basetable)
.columns(*columns)
.insert(*[self.parameter(i) for i in range(len(columns))])
.insert(*[self.insert_parameter(i) for i in range(len(columns))])
)
if ignore_conflicts:
query = query.on_conflict().do_nothing()
Expand All @@ -195,7 +199,10 @@ def _prepare_insert_statement(
async def _process_insert_result(self, instance: "Model", results: Any) -> None:
raise NotImplementedError() # pragma: nocoverage

def parameter(self, pos: int) -> Parameter:
def insert_parameter(self, pos: int) -> Parameter:
raise NotImplementedError() # pragma: nocoverage

def parameter(self) -> ListParameter:
raise NotImplementedError() # pragma: nocoverage

async def execute_insert(self, instance: "Model") -> None:
Expand Down Expand Up @@ -264,15 +271,15 @@ def get_update_sql(
field_object = self.model._meta.fields_map[field]
if not field_object.pk:
if field not in arithmetic_or_function.keys():
query = query.set(db_column, self.parameter(count))
query = query.set(db_column, self.insert_parameter(count))
count += 1
else:
value = F.resolver_arithmetic_expression(
self.model, arithmetic_or_function.get(field)
)[0]
query = query.set(db_column, value)

query = query.where(table[self.model._meta.db_pk_column] == self.parameter(count))
query = query.where(table[self.model._meta.db_pk_column] == self.insert_parameter(count))

sql = query.get_sql()
if not arithmetic_or_function:
Expand Down Expand Up @@ -454,7 +461,10 @@ async def _prefetch_m2m_relation(
if having_criterion:
query = query.having(having_criterion)

_, raw_results = await self.db.execute_query(query.get_sql())
param = self.parameter()
_, raw_results = await self.db.execute_query(
query.get_sql(parameter=param), param.get_parameters()
)
relations: List[Tuple[Any, Any]] = []
related_object_list: List["Model"] = []
model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk
Expand Down
9 changes: 6 additions & 3 deletions tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pypika import Parameter
from pypika.dialects import PostgreSQLQueryBuilder
from pypika.terms import Term
from pypika.terms import ListParameter, Term

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
Expand Down Expand Up @@ -38,16 +38,19 @@ class BasePostgresExecutor(BaseExecutor):
posix_regex: postgres_posix_regex,
}

def parameter(self, pos: int) -> Parameter:
def insert_parameter(self, pos: int) -> Parameter:
return Parameter("$%d" % (pos + 1,))

def parameter(self) -> ListParameter:
return ListParameter(lambda idx: "$%d" % (idx + 1,))

def _prepare_insert_statement(
self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False
) -> PostgreSQLQueryBuilder:
query = (
self.db.query_class.into(self.model._meta.basetable)
.columns(*columns)
.insert(*[self.parameter(i) for i in range(len(columns))])
.insert(*[self.insert_parameter(i) for i in range(len(columns))])
)
if has_generated:
generated_fields = self.model._meta.generated_db_fields
Expand Down
7 changes: 5 additions & 2 deletions tortoise/backends/mysql/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pypika import Parameter, functions
from pypika.enums import SqlTypes
from pypika.terms import BasicCriterion, Criterion
from pypika.terms import BasicCriterion, Criterion, ListParameter
from pypika.utils import format_quotes

from tortoise import Model
Expand Down Expand Up @@ -117,9 +117,12 @@ class MySQLExecutor(BaseExecutor):
}
EXPLAIN_PREFIX = "EXPLAIN FORMAT=JSON"

def parameter(self, pos: int) -> Parameter:
def insert_parameter(self, pos: int) -> Parameter:
return Parameter("%s")

def parameter(self) -> ListParameter:
return ListParameter("%s")

async def _process_insert_result(self, instance: Model, results: int) -> None:
pk_field_object = self.model._meta.pk
if (
Expand Down
8 changes: 6 additions & 2 deletions tortoise/backends/odbc/executor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from pypika import Parameter
from pypika import Parameter, QmarkParameter
from pypika.terms import ListParameter

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.fields import BigIntField, IntField, SmallIntField


class ODBCExecutor(BaseExecutor):
def parameter(self, pos: int) -> Parameter:
def insert_parameter(self, pos: int) -> Parameter:
return Parameter("?")

def parameter(self) -> ListParameter:
return QmarkParameter()

async def _process_insert_result(self, instance: Model, results: int) -> None:
pk_field_object = self.model._meta.pk
if (
Expand Down
6 changes: 5 additions & 1 deletion tortoise/backends/psycopg/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from pypika import Parameter
from pypika.terms import ListParameter

from tortoise import Model
from tortoise.backends.base_postgres.executor import BasePostgresExecutor
Expand All @@ -24,5 +25,8 @@ async def _process_insert_result(
for key, val in zip(generated_fields, results):
setattr(instance, db_projection[key], val)

def parameter(self, pos: int) -> Parameter:
def insert_parameter(self, pos: int) -> Parameter:
return Parameter("%s")

def parameter(self) -> ListParameter:
return ListParameter("%s")
8 changes: 6 additions & 2 deletions tortoise/backends/sqlite/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Optional, Type, Union

import pytz
from pypika import Parameter
from pypika import Parameter, QmarkParameter
from pypika.terms import ListParameter

from tortoise import Model, fields, timezone
from tortoise.backends.base.executor import BaseExecutor
Expand Down Expand Up @@ -83,9 +84,12 @@ class SqliteExecutor(BaseExecutor):
EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN"
DB_NATIVE = {bytes, str, int, float}

def parameter(self, pos: int) -> Parameter:
def insert_parameter(self, pos: int) -> Parameter:
return Parameter("?")

def parameter(self) -> ListParameter:
return QmarkParameter()

async def _process_insert_result(self, instance: Model, results: int) -> None:
pk_field_object = self.model._meta.pk
if (
Expand Down
8 changes: 4 additions & 4 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def resolver_arithmetic_expression(
left_field_object,
) = cls.resolver_arithmetic_expression(model, left)
if left_field_object:
if field_object and type(field_object) != type(left_field_object):
if field_object and type(field_object) is not type(left_field_object):
raise FieldError(
"Cannot use arithmetic expression between different field type"
)
Expand All @@ -82,7 +82,7 @@ def resolver_arithmetic_expression(
right_field_object,
) = cls.resolver_arithmetic_expression(model, right)
if right_field_object:
if field_object and type(field_object) != type(right_field_object):
if field_object and type(field_object) is not type(right_field_object):
raise FieldError(
"Cannot use arithmetic expression between different field type"
)
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None:
#: Contains the sub-Q's that this Q is made up of
self.children: Tuple[Q, ...] = args
#: Contains the filters applied to this Q
self.filters: Dict[str, FilterInfoDict] = kwargs
self.filters: Dict[str, Any] = kwargs
if join_type not in {self.AND, self.OR}:
raise OperationalError("join_type must be AND or OR")
#: Specifies if this Q does an AND or OR on its children
Expand Down Expand Up @@ -276,7 +276,7 @@ def _process_filter_kwarg(
)
op = param["operator"]
# this is an ugly hack
if op == operator.eq:
if op == operator.eq and not isinstance(encoded_value, Term):
encoded_value = model._meta.db.query_class._builder()._wrapper_cls(encoded_value)
criterion = op(table[param["source_field"]], encoded_value)
return criterion, join
Expand Down
Loading
Loading