Skip to content

Support for jsonb array comparision for $in $nin operator #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select
from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select, or_, case
from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.ext.asyncio import (
Expand Down Expand Up @@ -1117,10 +1117,29 @@ def _handle_field_filter(

queried_field = self.EmbeddingStore.cmetadata[field].astext

if operator in {"$in"}:
return queried_field.in_([str(val) for val in filter_value])
elif operator in {"$nin"}:
return ~queried_field.in_([str(val) for val in filter_value])
if operator in {"$in", "$nin"}:
for val in filter_value:
if not isinstance(val, (str, int, float, bool)):
raise NotImplementedError(
f"Unsupported type: {type(val)} for value: {val}"
)
is_array = func.jsonb_typeof(self.EmbeddingStore.cmetadata[field]) == 'array'

# For array fields, Use the @> operator to check if the JSONB field contains any of the values
array_check = or_(*[
self.EmbeddingStore.cmetadata[field].op("@>")(cast(val, JSONB))
for val in filter_value
])

# For non-array fields, use in_
non_array_check = queried_field.in_([str(val) for val in filter_value])

result = case(
(is_array, array_check),
else_=non_array_check
)

return result if operator == "$in" else ~result
elif operator in {"$like"}:
return queried_field.like(filter_value)
elif operator in {"$ilike"}:
Expand Down
7 changes: 7 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import sys
import asyncio

# Only preform check if your code will run on non-windows environments.
if sys.platform == 'win32':
# Set the policy to prevent "Event loop is closed" error on Windows - https://github.com/encode/httpx/issues/914
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
28 changes: 28 additions & 0 deletions tests/unit_tests/fixtures/filtering_test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,18 @@
{"id": {"$in": [1, 2]}},
[1, 2],
),
(
{"tags": {"$in": ["c", "d"]}},
[2, 3],
),
(
{"tags": {"$in": ["b", "d"]}},
[1, 2, 3],
),
(
{"location": {"$in": [1, 2.0]}},
[1, 2],
),
# Test nin
(
{"name": {"$nin": ["adam", "bob"]}},
Expand All @@ -251,6 +263,22 @@
{"id": {"$nin": [1, 2]}},
[3],
),
(
{"tags": {"$nin": ["c", "d"]}},
[1],
),
(
{"tags": {"$nin": ["d"]}},
[1, 2],
),
(
{"tags": {"$nin": ["e", "f"]}},
[1, 2, 3],
),
(
{"location": {"$nin": [1.0, 2]}},
[3],
),
]

TYPE_5_FILTERING_TEST_CASES = [
Expand Down