diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 95023f1..5a8c5f2 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -29,7 +29,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore -from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select +from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select, or_, case from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert from sqlalchemy.engine import Connection, Engine from sqlalchemy.ext.asyncio import ( @@ -1117,10 +1117,29 @@ def _handle_field_filter( queried_field = self.EmbeddingStore.cmetadata[field].astext - if operator in {"$in"}: - return queried_field.in_([str(val) for val in filter_value]) - elif operator in {"$nin"}: - return ~queried_field.in_([str(val) for val in filter_value]) + if operator in {"$in", "$nin"}: + for val in filter_value: + if not isinstance(val, (str, int, float, bool)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + is_array = func.jsonb_typeof(self.EmbeddingStore.cmetadata[field]) == 'array' + + # For array fields, Use the @> operator to check if the JSONB field contains any of the values + array_check = or_(*[ + self.EmbeddingStore.cmetadata[field].op("@>")(cast(val, JSONB)) + for val in filter_value + ]) + + # For non-array fields, use in_ + non_array_check = queried_field.in_([str(val) for val in filter_value]) + + result = case( + (is_array, array_check), + else_=non_array_check + ) + + return result if operator == "$in" else ~result elif operator in {"$like"}: return queried_field.like(filter_value) elif operator in {"$ilike"}: diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..96045e5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,7 @@ +import sys +import asyncio + +# Only preform check if your code will run on non-windows environments. +if sys.platform == 'win32': + # Set the policy to prevent "Event loop is closed" error on Windows - https://github.com/encode/httpx/issues/914 + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) \ No newline at end of file diff --git a/tests/unit_tests/fixtures/filtering_test_cases.py b/tests/unit_tests/fixtures/filtering_test_cases.py index 181e8ba..411ae44 100644 --- a/tests/unit_tests/fixtures/filtering_test_cases.py +++ b/tests/unit_tests/fixtures/filtering_test_cases.py @@ -241,6 +241,18 @@ {"id": {"$in": [1, 2]}}, [1, 2], ), + ( + {"tags": {"$in": ["c", "d"]}}, + [2, 3], + ), + ( + {"tags": {"$in": ["b", "d"]}}, + [1, 2, 3], + ), + ( + {"location": {"$in": [1, 2.0]}}, + [1, 2], + ), # Test nin ( {"name": {"$nin": ["adam", "bob"]}}, @@ -251,6 +263,22 @@ {"id": {"$nin": [1, 2]}}, [3], ), + ( + {"tags": {"$nin": ["c", "d"]}}, + [1], + ), + ( + {"tags": {"$nin": ["d"]}}, + [1, 2], + ), + ( + {"tags": {"$nin": ["e", "f"]}}, + [1, 2, 3], + ), + ( + {"location": {"$nin": [1.0, 2]}}, + [3], + ), ] TYPE_5_FILTERING_TEST_CASES = [