Skip to content

Commit

Permalink
Starts ends contains (#126)
Browse files Browse the repository at this point in the history
* Starts/ends/contains keywords

* Multiple contains

* Fix

* comment

* fix

* Fix

* Fix

* Fix

* Typo
  • Loading branch information
0ssigeno authored Sep 26, 2024
1 parent 7c6e587 commit 008e766
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 18 deletions.
86 changes: 69 additions & 17 deletions atlasq/queryset/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
import re
from typing import Any, Dict, List, Tuple, Union

from atlasq.queryset.exceptions import AtlasFieldError, AtlasIndexFieldError
Expand All @@ -10,6 +11,22 @@
logger = logging.getLogger(__name__)


def mergedicts(dict1, dict2):
    """Recursively merge two dicts, yielding the merged ``(key, value)`` pairs.

    Wrap the call in ``dict(...)`` to obtain the merged mapping.

    Merge rules, applied per key:

    * key in both and both values are dicts: the values are merged recursively;
    * key in both, but at least one value is not a dict: the value from
      ``dict2`` overrides the one from ``dict1``;
    * key in only one of the inputs: its value is passed through untouched.

    Pairs are yielded in a deterministic order: first every key of ``dict1``
    in its insertion order, then the keys that only exist in ``dict2``.
    (Iterating over a ``set`` of the keys would make the order of the
    resulting dict vary between runs for string keys.)

    :param dict1: base mapping.
    :param dict2: mapping whose values win on non-mergeable conflicts.
    :return: generator of ``(key, merged_value)`` tuples.
    """
    for key, left in dict1.items():
        if key not in dict2:
            yield key, left
            continue
        right = dict2[key]
        if isinstance(left, dict) and isinstance(right, dict):
            yield key, dict(mergedicts(left, right))
        else:
            # If one of the values is not a dict we cannot merge any deeper:
            # the value from the second dict overrides the first one.
            # Alternatively, raise here to be alerted of value conflicts.
            yield key, right
    for key, right in dict2.items():
        if key not in dict1:
            yield key, right


class AtlasTransform:

id_keywords = [
Expand All @@ -35,6 +52,8 @@ class AtlasTransform:
"icontains",
"startswith",
"istartswith",
"iendswith",
"endswith",
"iwholeword",
"wholeword",
"not",
Expand All @@ -51,9 +70,10 @@ class AtlasTransform:
range_keywords = ["gt", "gte", "lt", "lte"]
equals_keywords = []
equals_type_supported = (bool, ObjectId, int, datetime.datetime)
startswith_keywords = ["startswith", "istartswith"]
endswith_keywords = ["endswith", "iendswith"]
contains_keywords = ["contains", "icontains"]
text_keywords = [
"contains",
"icontains",
"iwholeword",
"wholeword",
"exact",
Expand All @@ -64,10 +84,6 @@ class AtlasTransform:
regex_keywords = ["regex", "iregex"]
size_keywords = ["size"]
not_converted = [
"istartswith",
"startswith",
"contains",
"icontains",
"mod",
"match",
]
Expand Down Expand Up @@ -147,9 +163,13 @@ def _single_equals(self, path: str, value: Union[ObjectId, bool]):
}
}

def _contains(self, path: str, value: Any, keyword: str = None):
if not keyword:
return {path: {"$elemMatch": value}}
return {path: {"$elemMatch": {f"${keyword}": value}}}

def _equals(self, path: str, value: Union[List[Union[ObjectId, bool]], ObjectId, bool]) -> Dict:
if isinstance(value, list):

values = value
if not values:
raise AtlasFieldError(f"Text search for equals on {path=} cannot be empty")
Expand All @@ -166,6 +186,16 @@ def _text(self, path: str, value: Any) -> Dict:
"text": {"query": value, "path": path},
}

def _startswith(self, path: str, value: Any) -> Dict:
    """Build a regex clause matching values of ``path`` that start with ``value``.

    ``value`` is escaped with ``re.escape`` so it is matched literally, then
    ``.*`` is appended and the pattern is handed to ``self._regex``.

    :raises AtlasFieldError: if ``value`` is falsy (e.g. ``None`` or ``""``).
    """
    if value:
        prefix_pattern = f"{re.escape(value)}.*"
        return self._regex(path, prefix_pattern)
    raise AtlasFieldError(f"Text search for {path} cannot be {value}")

def _endswith(self, path: str, value: Any) -> Dict:
    """Build a regex clause matching values of ``path`` that end with ``value``.

    ``value`` is escaped with ``re.escape`` so it is matched literally, then
    prefixed with ``.*`` and the pattern is handed to ``self._regex``.

    :raises AtlasFieldError: if ``value`` is falsy (e.g. ``None`` or ``""``).
    """
    if value:
        suffix_pattern = f".*{re.escape(value)}"
        return self._regex(path, suffix_pattern)
    raise AtlasFieldError(f"Text search for {path} cannot be {value}")

def _size(self, path: str, value: int, operator: str) -> Dict:
if not isinstance(value, int):
raise AtlasFieldError(f"Size search for {path} must be an int")
Expand Down Expand Up @@ -230,17 +260,16 @@ def transform(self) -> Tuple[List[Dict], List[Dict], List[Dict]]:
negative = []

for key, value in self.atlas_query.items():
# if to_go is positive, we add the element in the positive list
# if to_go is negative, we add the element in the negative list
to_go = 1
# if the value is positive, we add the element in the positive list
# if the value is negative, we add the element in the negative list
positive = 1
if isinstance(value, QuerySet):
logger.debug("Casting queryset to list, otherwise the aggregation will fail")
value = list(value)
key_parts = key.split("__")
obj = None
path = ""
for i, keyword in enumerate(key_parts):

if keyword in self.id_keywords:
keyword = "_id"
key_parts[i] = keyword
Expand All @@ -255,17 +284,17 @@ def transform(self) -> Tuple[List[Dict], List[Dict], List[Dict]]:
if keyword in self.not_converted:
raise NotImplementedError(f"Keyword {keyword} not implemented yet")
if keyword in self.negative_keywords:
to_go *= -1
positive *= -1

if keyword in self.size_keywords:
# it must be the last keyword, otherwise we do not support it
if i != len(key_parts) - 1:
raise NotImplementedError(f"Keyword {keyword} not implemented yet")
other_aggregations.append(self._size(path, value, "eq" if to_go == 1 else "ne"))
other_aggregations.append(self._size(path, value, "eq" if positive == 1 else "ne"))
break
if keyword in self.exists_keywords:
if value is False:
to_go *= -1
positive *= -1
obj = self._exists(path)
break

Expand All @@ -284,8 +313,31 @@ def transform(self) -> Tuple[List[Dict], List[Dict], List[Dict]]:
if keyword in self.all_keywords:
obj = self._all(path, value)
break
if keyword in self.startswith_keywords:
obj = self._startswith(path, value)
break
if keyword in self.endswith_keywords:
obj = self._endswith(path, value)
break
if keyword in self.contains_keywords:
# this is because we could have contains__gte=3
try:
comparison_keyword = key_parts[i + 1]
except IndexError:
aggregation = self._contains(path, value, "eq")
else:
aggregation = self._contains(path, value, comparison_keyword)
# we merge the contains clauses together, because in 100% of cases we want to match the same object
for j, aggr in enumerate(other_aggregations):
if path in aggr:
# if we have another path__contains__keyword, we merge them
other_aggregations[j] = dict(mergedicts(aggr, aggregation))
break
else:
other_aggregations.append(aggregation)
break
if keyword in self.type_keywords:
if to_go == -1:
if positive == -1:
raise NotImplementedError(f"At the moment you can't have a negative `{keyword}` keyword")
other_aggregations.append(self._type(path, value))
else:
Expand All @@ -297,13 +349,13 @@ def transform(self) -> Tuple[List[Dict], List[Dict], List[Dict]]:
if self.atlas_index.ensured:
self._ensure_path_is_indexed(path.split("."))
# we are wrapping the result to an embedded document
converted = self._convert_to_embedded_document(path.split("."), obj, positive=to_go == 1)
converted = self._convert_to_embedded_document(path.split("."), obj, positive=positive == 1)
if obj != converted:
# we have an embedded object
# the mustNot is done inside the embedded document clause
affirmative = self.merge_embedded_documents(converted, affirmative)
else:
if to_go == 1:
if positive == 1:
affirmative.append(converted)
else:
negative.append(converted)
Expand Down
44 changes: 43 additions & 1 deletion tests/queryset/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def test__equals_list_bool(self):
self.assertEqual(res["compound"]["should"][0]["equals"]["path"], "field")
self.assertEqual(res["compound"]["should"][0]["equals"]["value"], True)

def test(self):
def test_equal(self):
q = AtlasQ(f=3)
t = AtlasTransform(q.query, AtlasIndex("test"))
res = t._text("field", "aaa")
Expand All @@ -768,6 +768,48 @@ def test_none(self):
with self.assertRaises(AtlasFieldError):
t._text("field", None)

def test_convert_startswith(self):
    """``_startswith`` yields a regex clause with the value escaped and ``.*`` appended."""
    query = AtlasQ(f__startswith="test?")
    transformer = AtlasTransform(query.query, AtlasIndex("test"))
    result = transformer._startswith("f", "test?")
    self.assertIn("regex", result)
    regex_clause = result["regex"]
    self.assertIn("query", regex_clause)
    # the "?" must be escaped so it is matched literally
    self.assertIn("test\\?.*", regex_clause["query"])
    self.assertIn("path", regex_clause)
    self.assertIn("f", regex_clause["path"])

def test_convert_endswith(self):
    """``_endswith`` yields a regex clause with ``.*`` prepended to the escaped value."""
    query = AtlasQ(f__endswith="test?")
    transformer = AtlasTransform(query.query, AtlasIndex("test"))
    result = transformer._endswith("f", "test?")
    self.assertIn("regex", result)
    regex_clause = result["regex"]
    self.assertIn("query", regex_clause)
    # the "?" must be escaped so it is matched literally
    self.assertIn(".*test\\?", regex_clause["query"])
    self.assertIn("path", regex_clause)
    self.assertIn("f", regex_clause["path"])

def test_contains(self):
    """A bare ``contains`` produces only an ``$elemMatch``/``$eq`` aggregation."""
    query = AtlasQ(f__contains="test")
    transformer = AtlasTransform(query.query, AtlasIndex("test"))
    positive, negative, aggregations = transformer.transform()
    self.assertEqual(positive, [])
    self.assertEqual(negative, [])
    expected = {"f": {"$elemMatch": {"$eq": "test"}}}
    # dump every aggregation in the failure message to ease debugging
    failure_message = json.dumps(aggregations, indent=4)
    self.assertEqual(expected, aggregations[0], failure_message)

def test_multiple_contains(self):
    """Two ``contains`` lookups on the same path are merged into one ``$elemMatch``."""
    query = AtlasQ(f__contains__gte="test1", f__contains__lte="test2")
    transformer = AtlasTransform(query.query, AtlasIndex("test"))
    positive, negative, aggregations = transformer.transform()
    self.assertEqual(positive, [])
    self.assertEqual(negative, [])
    expected = {"f": {"$elemMatch": {"$gte": "test1", "$lte": "test2"}}}
    # dump every aggregation in the failure message to ease debugging
    failure_message = json.dumps(aggregations, indent=4)
    self.assertEqual(expected, aggregations[0], failure_message)

def test__size_operator_not_supported(self):
q = AtlasQ(f=3)
t = AtlasTransform(q.query, AtlasIndex("test"))
Expand Down

0 comments on commit 008e766

Please sign in to comment.