From 7a2172194e43a0eef46e24f020c7e695666a4aa8 Mon Sep 17 00:00:00 2001 From: srimon12 Date: Sat, 30 May 2026 20:10:55 +0530 Subject: [PATCH 1/3] Refactor QQL Executor and Utility Functions - Moved various utility functions from `executor.py` to a new `utils.py` file for better organization and maintainability. - Simplified the `_resolve_topology` method by utilizing the new `collection_topology_kwargs` function. - Updated the `_build_dense_point_vector` and `_extract_point_id_and_payload` methods to use the new utility functions. - Enhanced error handling in the `_ensure_collection` method to catch specific exceptions. - Improved the handling of hybrid queries and prefetching by introducing `build_hybrid_prefetches` and `build_dense_query` functions. - Added tests for the `AsyncConnection` class to ensure proper initialization, query execution, and lifecycle management. - Updated existing tests to cover new functionality and ensure backward compatibility. --- README.md | 18 +- benchmark/qql_transport_benchmark.py | 165 ++++ docs/programmatic.md | 58 +- src/qql/__init__.py | 8 + src/qql/async_connection.py | 138 +++ src/qql/async_executor.py | 1201 ++++++++++++++++++++++++++ src/qql/connection.py | 10 +- src/qql/executor.py | 367 ++------ src/qql/utils.py | 453 ++++++++++ tests/test_async_connection.py | 301 +++++++ tests/test_connection.py | 18 +- 11 files changed, 2424 insertions(+), 313 deletions(-) create mode 100644 benchmark/qql_transport_benchmark.py create mode 100644 src/qql/async_connection.py create mode 100644 src/qql/async_executor.py create mode 100644 src/qql/utils.py create mode 100644 tests/test_async_connection.py diff --git a/README.md b/README.md index 5823e27..8bc0a02 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![PyPI version](https://img.shields.io/pypi/v/qql-cli?color=blue&label=PyPI)](https://pypi.org/project/qql-cli/) [![Python 3.12+](https://img.shields.io/pypi/pyversions/qql-cli)](https://pypi.org/project/qql-cli/) [![MIT License](https://img.shields.io/badge/license-MIT-green)](LICENSE) -[![Tests](https://img.shields.io/badge/tests-549%20passing-brightgreen)](tests/) +[![Tests](https://img.shields.io/badge/tests-633%20passing-brightgreen)](tests/) -Write `INSERT`, `SELECT`, `SEARCH`, `SCROLL`, `RECOMMEND`, `UPDATE`, `DELETE`, and `CREATE COLLECTION` statements instead of Python SDK calls. Supports hybrid dense+sparse vector search, grouped search (GROUP BY), cross-encoder reranking, quantization (scalar, turbo, binary, product), SQL-style `WHERE` filters, script execution, and collection dump/restore. +Write `INSERT`, `SELECT`, `SEARCH`, `SCROLL`, `RECOMMEND`, `UPDATE`, `DELETE`, and `CREATE COLLECTION` statements instead of Python SDK calls. Supports hybrid dense+sparse vector search, grouped search (GROUP BY), cross-encoder reranking, quantization (scalar, turbo, binary, product), SQL-style `WHERE` filters, script execution, collection dump/restore, async execution, and programmatic gRPC transport. ``` qql> INSERT INTO COLLECTION notes VALUES {'text': 'Qdrant is a vector database', 'author': 'alice', 'year': 2024} @@ -50,7 +50,7 @@ Your query string When you run `INSERT`, the `text` field is automatically converted into a dense vector using [Fastembed](https://github.com/qdrant/fastembed). In **hybrid mode** (`USING HYBRID`), a sparse BM25 vector is also generated alongside the dense vector, and searches use Qdrant's Reciprocal Rank Fusion (RRF) by default to merge the results of both retrieval methods. You can switch hybrid search to DBSF with `FUSION 'dbsf'`. -QQL also exposes a **programmatic API** for use inside Python applications — no CLI required: +QQL also exposes a **programmatic API** for use inside Python applications — no CLI required. Use `Connection` for sync code and `AsyncConnection` for async apps: ```python from qql import Connection @@ -62,6 +62,14 @@ with Connection("http://localhost:6333") as conn: print(hit["score"], hit["payload"]) ``` +```python +from qql import AsyncConnection + +async with AsyncConnection("http://localhost:6333", prefer_grpc=True) as conn: + result = await conn.run_query("SHOW COLLECTIONS") + print(result.data) +``` + --- ## Installation @@ -104,7 +112,7 @@ Full documentation lives in the [`docs/`](docs/) folder and at **[pavanjava.gith | [WHERE Filters](docs/filters.md) | Full SQL-style filter operators | | [Collections & Quantization](docs/collections.md) | SHOW, CREATE, DROP, QUANTIZE (scalar/turbo/binary/product), CREATE INDEX, UPDATE VECTOR, UPDATE PAYLOAD | | [Scripts: EXECUTE / DUMP](docs/scripts.md) | Script files, collection backup/restore | -| [Programmatic Usage](docs/programmatic.md) | Use QQL as a Python library via `Connection` or `run_query()` | +| [Programmatic Usage](docs/programmatic.md) | Sync/async Python APIs and gRPC | | [Reference: Models / Config / Errors](docs/reference.md) | Embedding models, config file, error reference | --- @@ -188,7 +196,7 @@ Tests do not require a running Qdrant instance — the Qdrant client is mocked. pytest tests/ -v ``` -Expected: **549 tests passing**. +Expected: **633 tests passing**. --- diff --git a/benchmark/qql_transport_benchmark.py b/benchmark/qql_transport_benchmark.py new file mode 100644 index 0000000..c5ab7f0 --- /dev/null +++ b/benchmark/qql_transport_benchmark.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import asyncio +import os +import time +from dataclasses import dataclass + +from qql import AsyncConnection, Connection + + +URL = os.getenv("QQL_BENCH_URL", "http://localhost:6333") +SECRET = os.getenv("QQL_BENCH_SECRET") or None +ITERATIONS = int(os.getenv("QQL_BENCH_ITERATIONS", "50")) +WARMUP = int(os.getenv("QQL_BENCH_WARMUP", "5")) +CONCURRENCY = int(os.getenv("QQL_BENCH_CONCURRENCY", "10")) + +DOCS = [ + "Qdrant stores vectors and payloads for semantic search workloads", + "FastEmbed generates local dense embeddings for short text queries", + "gRPC can reduce transport overhead for high volume vector database calls", + "REST remains simple and reliable for operational database workflows", + "Async clients help Python applications keep network requests in flight", + "Local embedding models can dominate latency before the database is called", + "Hybrid search combines dense vectors with sparse lexical retrieval", + "Payload filters narrow search results by metadata fields and values", + "Collection topology determines named dense and sparse vector behavior", + "Benchmark results should separate setup cost from measured query latency", +] +QUERY_TEXT = "local embedding vector database transport benchmark" + + +@dataclass(frozen=True) +class Result: + mode: str + total_ms: float + avg_ms: float + qps: float + + +def search_query(collection: str) -> str: + return f"SEARCH {collection} SIMILAR TO '{QUERY_TEXT}' LIMIT 5" + + +def insert_query(collection: str, idx: int, text: str) -> str: + return ( + f"INSERT INTO COLLECTION {collection} " + f"VALUES {{'id': {idx}, 'text': '{text}', 'kind': 'bench'}}" + ) + + +def ignore_drop(conn: Connection, collection: str) -> None: + try: + conn.run_query(f"DROP COLLECTION {collection}") + except Exception: + pass + + +async def ignore_drop_async(conn: AsyncConnection, collection: str) -> None: + try: + await conn.run_query(f"DROP COLLECTION {collection}") + except Exception: + pass + + +def setup_sync(collection: str, *, prefer_grpc: bool) -> None: + with Connection(URL, secret=SECRET, prefer_grpc=prefer_grpc) as conn: + ignore_drop(conn, collection) + for idx, text in enumerate(DOCS, start=1): + conn.run_query(insert_query(collection, idx, text)) + + +async def setup_async(collection: str, *, prefer_grpc: bool) -> None: + async with AsyncConnection(URL, secret=SECRET, prefer_grpc=prefer_grpc) as conn: + await ignore_drop_async(conn, collection) + for idx, text in enumerate(DOCS, start=1): + await conn.run_query(insert_query(collection, idx, text)) + + +def bench_sync(mode: str, collection: str, *, prefer_grpc: bool) -> Result: + setup_sync(collection, prefer_grpc=prefer_grpc) + query = search_query(collection) + with Connection(URL, secret=SECRET, prefer_grpc=prefer_grpc) as conn: + for _ in range(WARMUP): + conn.run_query(query) + start = time.perf_counter() + for _ in range(ITERATIONS): + conn.run_query(query) + total_ms = (time.perf_counter() - start) * 1000 + return Result(mode, total_ms, total_ms / ITERATIONS, ITERATIONS / (total_ms / 1000)) + + +async def bench_async(mode: str, collection: str, *, prefer_grpc: bool) -> Result: + await setup_async(collection, prefer_grpc=prefer_grpc) + query = search_query(collection) + async with AsyncConnection(URL, secret=SECRET, prefer_grpc=prefer_grpc) as conn: + for _ in range(WARMUP): + await conn.run_query(query) + start = time.perf_counter() + for _ in range(ITERATIONS): + await conn.run_query(query) + total_ms = (time.perf_counter() - start) * 1000 + return Result(mode, total_ms, total_ms / ITERATIONS, ITERATIONS / (total_ms / 1000)) + + +async def bench_async_concurrent( + mode: str, + collection: str, + *, + prefer_grpc: bool, +) -> Result: + query = search_query(collection) + async with AsyncConnection(URL, secret=SECRET, prefer_grpc=prefer_grpc) as conn: + for _ in range(WARMUP): + await conn.run_query(query) + sem = asyncio.Semaphore(CONCURRENCY) + + async def one() -> None: + async with sem: + await conn.run_query(query) + + start = time.perf_counter() + await asyncio.gather(*(one() for _ in range(ITERATIONS))) + total_ms = (time.perf_counter() - start) * 1000 + return Result(mode, total_ms, total_ms / ITERATIONS, ITERATIONS / (total_ms / 1000)) + + +def print_table(title: str, results: list[Result]) -> None: + print(f"\n### {title}\n") + print("| Mode | Total ms | Avg ms/op | Ops/sec |") + print("|---|---:|---:|---:|") + for r in results: + print(f"| {r.mode} | {r.total_ms:,.2f} | {r.avg_ms:,.2f} | {r.qps:,.2f} |") + + +async def main() -> None: + print("QQL SEARCH benchmark") + print(f"URL: {URL}") + print(f"Workload: {ITERATIONS} measured SEARCH queries, {WARMUP} warmup") + print("Embedding: local FastEmbed dense model, warmed before timing") + + latency = [ + bench_sync("sync REST", "qql_bench_sync_rest", prefer_grpc=False), + await bench_async("async REST", "qql_bench_async_rest", prefer_grpc=False), + bench_sync("sync gRPC", "qql_bench_sync_grpc", prefer_grpc=True), + await bench_async("async gRPC", "qql_bench_async_grpc", prefer_grpc=True), + ] + print_table("Single-flight latency", latency) + + concurrent = [ + await bench_async_concurrent( + f"async REST x{CONCURRENCY}", + "qql_bench_async_rest", + prefer_grpc=False, + ), + await bench_async_concurrent( + f"async gRPC x{CONCURRENCY}", + "qql_bench_async_grpc", + prefer_grpc=True, + ), + ] + print_table("Async concurrent throughput", concurrent) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/programmatic.md b/docs/programmatic.md index 007c9d6..ae70e1a 100644 --- a/docs/programmatic.md +++ b/docs/programmatic.md @@ -11,10 +11,10 @@ QQL can be used as a Python library without the CLI. ## `Connection` — Primary API -`Connection` is the recommended way to use QQL programmatically. It opens a -single connection to Qdrant once and reuses it for every `run_query()` call — -more efficient than the legacy `run_query()` function, which creates a new -client on every invocation. +`Connection` is the recommended sync API for using QQL programmatically. It +opens a single connection to Qdrant once and reuses it for every `run_query()` +call — more efficient than the legacy `run_query()` function, which creates a +new client on every invocation. ### Basic usage @@ -180,9 +180,55 @@ with Connection("http://localhost:6333") as conn: | `secret` | `str \| None` | `None` | API key; `None` for unauthenticated | | `default_model` | `str \| None` | `None` → `sentence-transformers/all-MiniLM-L6-v2` | Dense embedding model used when no `USING MODEL` clause is given | | `verify` | `bool \| str` | `True` | TLS verification setting; use `False` to skip verification or a CA bundle path for internal/self-signed certificates | +| `prefer_grpc` | `bool` | `False` | Use Qdrant's gRPC transport when available | +| `grpc_port` | `int` | `6334` | Qdrant gRPC port used when `prefer_grpc=True` | | `default_dense_vector_name` | `str` | `"dense"` | Dense vector name used when QQL creates a collection and no explicit `USING VECTOR` name is given | | `default_sparse_vector_name` | `str` | `"sparse"` | Sparse vector name used when QQL creates a hybrid collection and no explicit sparse vector name is given | +### gRPC transport + +Pass `prefer_grpc=True` when your Qdrant deployment exposes the gRPC port: + +```python +from qql import Connection + +with Connection("http://localhost:6333", prefer_grpc=True) as conn: + result = conn.run_query("SHOW COLLECTIONS") + print(result.data) +``` + +Use `grpc_port` when the deployment uses a non-default gRPC port. + +--- + +## `AsyncConnection` — Async API + +`AsyncConnection` mirrors the sync `Connection` API and uses Qdrant's +`AsyncQdrantClient` under the hood. + +```python +from qql import AsyncConnection + +async with AsyncConnection("http://localhost:6333") as conn: + await conn.run_query( + "INSERT INTO COLLECTION notes VALUES {'text': 'hello async world'}" + ) + result = await conn.run_query("SEARCH notes SIMILAR TO 'async' LIMIT 5") + for hit in result.data: + print(hit["score"], hit["payload"]) +``` + +Async connections support the same `url`, `secret`, `default_model`, `verify`, +`prefer_grpc`, and `grpc_port` parameters: + +```python +async with AsyncConnection( + "http://localhost:6333", + prefer_grpc=True, +) as conn: + result = await conn.run_query("SHOW COLLECTIONS") +``` + ### Power-user: `executor` property For low-level access to the pipeline, use `conn.executor` directly: @@ -225,8 +271,8 @@ for hit in result.data: print(hit["score"], hit["payload"]) ``` -`run_query()` accepts the same `url`, `secret`, `default_model`, and `verify` -parameters as `Connection.__init__()`. +`run_query()` accepts the same `url`, `secret`, `default_model`, `verify`, +`prefer_grpc`, and `grpc_port` parameters as `Connection.__init__()`. --- diff --git a/src/qql/__init__.py b/src/qql/__init__.py index da43725..9305306 100644 --- a/src/qql/__init__.py +++ b/src/qql/__init__.py @@ -12,6 +12,8 @@ QQLConfig, load_config, ) +from .async_connection import AsyncConnection +from .async_executor import AsyncExecutor from .connection import Connection from .exceptions import QQLError, QQLRuntimeError, QQLSyntaxError from .executor import ExecutionResult, Executor @@ -20,6 +22,8 @@ __all__ = [ "__version__", + "AsyncConnection", + "AsyncExecutor", "Connection", "DEFAULT_DENSE_VECTOR_NAME", "DEFAULT_MODEL", @@ -43,6 +47,8 @@ def run_query( secret: str | None = None, default_model: str | None = None, verify: bool | str = True, + prefer_grpc: bool = False, + grpc_port: int = 6334, ) -> ExecutionResult: """One-shot convenience function kept for backward compatibility. @@ -61,5 +67,7 @@ def run_query( secret=secret, default_model=default_model, verify=verify, + prefer_grpc=prefer_grpc, + grpc_port=grpc_port, ) as conn: return conn.run_query(query) diff --git a/src/qql/async_connection.py b/src/qql/async_connection.py new file mode 100644 index 0000000..fb78bc3 --- /dev/null +++ b/src/qql/async_connection.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import Any +from qdrant_client import AsyncQdrantClient + +from .config import DEFAULT_MODEL, QQLConfig +from .async_executor import AsyncExecutor +from .executor import ExecutionResult +from .lexer import Lexer +from .parser import Parser + + +class AsyncConnection: + """Stateful asynchronous connection to a Qdrant instance. + + Creates a single ``AsyncQdrantClient`` and ``AsyncExecutor`` once and reuses + them for every :meth:`run_query` call — more efficient than the standalone + one-shot helpers, which create a fresh client on every invocation. + + **Basic usage**:: + + conn = AsyncConnection("http://localhost:6333", secret="my-key") + result = await conn.run_query( + "INSERT INTO COLLECTION docs VALUES {'text': 'hello world'}" + ) + result = await conn.run_query("SEARCH docs SIMILAR TO 'hello' LIMIT 5") + await conn.close() + + **Context manager (preferred)**:: + + async with AsyncConnection("http://localhost:6333") as conn: + result = await conn.run_query("SHOW COLLECTIONS") + print(result.data) + + **Qdrant Cloud**:: + + async with AsyncConnection("https://.qdrant.io", secret="") as conn: + result = await conn.run_query("SHOW COLLECTIONS") + + **Custom embedding model**:: + + async with AsyncConnection( + "http://localhost:6333", + default_model="BAAI/bge-base-en-v1.5", + ) as conn: + result = await conn.run_query( + "INSERT INTO COLLECTION docs VALUES {'text': 'hello'}" + ) + """ + + def __init__( + self, + url: str = "http://localhost:6333", + secret: str | None = None, + default_model: str | None = None, + verify: bool | str = True, + prefer_grpc: bool = False, + grpc_port: int = 6334, + ) -> None: + """Create an asynchronous connection to a Qdrant instance. + + Args: + url: Base URL of the Qdrant instance (default: ``http://localhost:6333``). + secret: API key for authenticated instances; ``None`` for unauthenticated. + default_model: Dense embedding model used when no ``USING MODEL`` clause + is specified. Defaults to ``sentence-transformers/all-MiniLM-L6-v2``. + verify: SSL certificate verification. Set to ``False`` to skip + verification for self-signed/internal certificates, or pass + a path to a custom CA bundle (default: ``True``). + prefer_grpc: Whether to connect via fast gRPC transport. + grpc_port: The gRPC port of Qdrant instance (default: 6334). + """ + self._config = QQLConfig( + url=url, + secret=secret, + default_model=default_model or DEFAULT_MODEL, + verify=verify, + ) + client_kwargs = {"url": url, "api_key": secret, "verify": verify} + if prefer_grpc: + client_kwargs["prefer_grpc"] = True + client_kwargs["grpc_port"] = grpc_port + self._client = AsyncQdrantClient(**client_kwargs) + self._executor = AsyncExecutor(self._client, self._config) + + # ── Public API ──────────────────────────────────────────────────────── + + async def run_query(self, query: str) -> ExecutionResult: + """Parse and execute a single QQL statement asynchronously. + + Args: + query: A QQL query string, e.g. ``"SEARCH docs SIMILAR TO 'hello' LIMIT 5"``. + + Returns: + An :class:`~qql.ExecutionResult` with ``success``, ``message``, and ``data`` fields. + + Raises: + QQLSyntaxError: The query string could not be parsed. + QQLRuntimeError: The query parsed correctly but Qdrant rejected it. + """ + tokens = Lexer().tokenize(query) + node = Parser(tokens).parse() + return await self._executor.execute(node) + + async def close(self) -> None: + """Close the underlying Qdrant asynchronous connection pool.""" + await self._client.close() + + # ── Context manager ─────────────────────────────────────────────────── + + async def __aenter__(self) -> AsyncConnection: + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + await self.close() + + # ── Power-user properties ───────────────────────────────────────────── + + @property + def config(self) -> QQLConfig: + """The :class:`~qql.QQLConfig` in use (url, secret, default_model).""" + return self._config + + @property + def executor(self) -> AsyncExecutor: + """Direct access to the :class:`~qql.AsyncExecutor` for low-level use. + + Example — run multiple statements sharing a pre-built AST node:: + + from qql.lexer import Lexer + from qql.parser import Parser + + conn = AsyncConnection("http://localhost:6333") + tokens = Lexer().tokenize("SHOW COLLECTIONS") + node = Parser(tokens).parse() + result = await conn.executor.execute(node) + """ + return self._executor diff --git a/src/qql/async_executor.py b/src/qql/async_executor.py new file mode 100644 index 0000000..282cfe3 --- /dev/null +++ b/src/qql/async_executor.py @@ -0,0 +1,1201 @@ +from __future__ import annotations + +import time +import asyncio +from typing import Any + +from qdrant_client import AsyncQdrantClient +from qdrant_client.http.exceptions import UnexpectedResponse +from qdrant_client.models import ( + Distance, + Filter, + FusionQuery, + LookupLocation, + Modifier, + PointStruct, + PointVectors, + RecommendInput, + RecommendQuery, + SearchParams, + SparseVector, + SparseVectorParams, + VectorParams, + PayloadSchemaType, +) + +from .ast_nodes import ( + ASTNode, + AlterCollectionStmt, + CreateCollectionStmt, + CreateIndexStmt, + DeleteStmt, + DropCollectionStmt, + InsertBulkStmt, + InsertStmt, + RecommendStmt, + SelectStmt, + ScrollStmt, + SearchStmt, + ShowCollectionStmt, + ShowCollectionsStmt, + UpdateVectorStmt, + UpdatePayloadStmt, +) +from .config import QQLConfig +from .embedder import Embedder, SparseEmbedder +from .exceptions import QQLRuntimeError +from .executor import Executor, ExecutionResult, CollectionTopology +from .utils import ( + build_dense_point_vector, + build_dense_query, + build_hybrid_prefetches, + exclude_ids_from_filter, + extract_point_id_and_payload, + has_mmr, + is_grpc_not_found_error, + parse_recommend_strategy, + resolve_hybrid_fusion, + validate_search_mmr_usage, +) + +_RERANK_FETCH_MULTIPLIER = 4 +_HYBRID_PREFETCH_MULTIPLIER = 4 +_COLLECTION_VISIBILITY_TIMEOUT_SECONDS = 5.0 +_COLLECTION_VISIBILITY_POLL_SECONDS = 0.05 + + +class AsyncExecutor(Executor): + """Asynchronous QQL execution engine for ``AsyncQdrantClient``. + + The async executor mirrors :class:`~qql.Executor` at the statement boundary: + every AST node supported by the sync executor has an async execution path + here. Pure parsing, validation, vector-shaping, filter-building, and result + formatting helpers live in ``qql.utils`` or are inherited from + :class:`~qql.Executor`; only Qdrant client calls and collection-creation + coordination are implemented with ``async``/``await`` in this module. + """ + + def __init__(self, client: AsyncQdrantClient, config: QQLConfig) -> None: + super().__init__(client=client, config=config) # type: ignore[arg-type] + self._client: AsyncQdrantClient = client + self._creation_lock = asyncio.Lock() + + async def execute(self, node: ASTNode) -> ExecutionResult: + if isinstance(node, InsertBulkStmt): + return await self._execute_insert_bulk(node) + if isinstance(node, InsertStmt): + return await self._execute_insert(node) + if isinstance(node, CreateCollectionStmt): + return await self._execute_create(node) + if isinstance(node, AlterCollectionStmt): + return await self._execute_alter_collection(node) + if isinstance(node, CreateIndexStmt): + return await self._execute_create_index(node) + if isinstance(node, DropCollectionStmt): + return await self._execute_drop(node) + if isinstance(node, ShowCollectionsStmt): + return await self._execute_show(node) + if isinstance(node, ShowCollectionStmt): + return await self._execute_show_collection(node) + if isinstance(node, ScrollStmt): + return await self._execute_scroll(node) + if isinstance(node, SelectStmt): + return await self._execute_select(node) + if isinstance(node, SearchStmt): + return await self._execute_search(node) + if isinstance(node, RecommendStmt): + return await self._execute_recommend(node) + if isinstance(node, DeleteStmt): + return await self._execute_delete(node) + if isinstance(node, UpdateVectorStmt): + return await self._execute_update_vector(node) + if isinstance(node, UpdatePayloadStmt): + return await self._execute_update_payload(node) + raise QQLRuntimeError(f"Unknown AST node type: {type(node)}") + + # ── Topology & Helper methods ───────────────────────────────────────── + + async def _fetch_collection_info(self, name: str): + try: + return await self._client.get_collection(name) + except UnexpectedResponse as e: + if e.status_code == 404: + return None + raise QQLRuntimeError( + f"Qdrant error fetching collection '{name}': {e}" + ) from e + except ValueError as e: + if f"Collection {name} not found" in str(e): + return None + raise + except Exception as e: + if is_grpc_not_found_error(e): + return None + raise + + async def _resolve_topology(self, name: str) -> CollectionTopology: + info = await self._fetch_collection_info(name) + if info is None: + return CollectionTopology(exists=False, is_named_dense=False) + return self._topology_from_collection_info(info) + + async def _ensure_collection( + self, + name: str, + vector_size: int, + topology: CollectionTopology, + explicit_vector: str | None, + ) -> CollectionTopology: + if topology.exists: + sizes = topology.dense_size_map() + if topology.is_named_dense: + vector_name = topology.dense_using(explicit_vector) + expected_size = sizes.get(vector_name) # type: ignore[arg-type] + if expected_size is not None and expected_size != vector_size: + raise QQLRuntimeError( + f"Vector dimension mismatch: collection '{name}' vector " + f"'{vector_name}' expects {expected_size} dims, but " + f"model produces {vector_size} dims. Specify a compatible " + "model with USING MODEL ''." + ) + elif topology.has_unnamed_dense: + expected_size = sizes.get("") + if expected_size is not None and expected_size != vector_size: + raise QQLRuntimeError( + f"Vector dimension mismatch: collection '{name}' expects " + f"{expected_size} dims, but model produces {vector_size} dims. " + f"Specify a compatible model with USING MODEL ''." + ) + else: + raise QQLRuntimeError("Collection has no dense vector") + return topology + + vector_name = explicit_vector or self._default_dense_vector_name() + created_topology = CollectionTopology( + exists=True, + is_named_dense=True, + dense_names=(vector_name,), + dense_sizes=((vector_name, vector_size),), + ) + async with self._creation_lock: + current_topology = await self._resolve_topology(name) + if current_topology.exists: + return await self._ensure_collection( + name, vector_size, current_topology, explicit_vector + ) + + await self._create_collection_and_wait( + collection_name=name, + vectors_config={ + vector_name: VectorParams( + size=vector_size, distance=Distance.COSINE + ) + }, + ) + return created_topology + + async def _create_collection_and_wait(self, **kwargs: Any) -> None: + collection_name = kwargs["collection_name"] + await self._client.create_collection(**kwargs) + + deadline = time.monotonic() + _COLLECTION_VISIBILITY_TIMEOUT_SECONDS + while time.monotonic() < deadline: + if await self._client.collection_exists(collection_name): + return + await asyncio.sleep(_COLLECTION_VISIBILITY_POLL_SECONDS) + + raise QQLRuntimeError( + f"Collection '{collection_name}' was created but did not become visible in time" + ) + + async def _build_hybrid_vectors( + self, + query_text: str, + dense_model: str, + sparse_model_name: str, + ) -> tuple[list[float], SparseVector]: + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + + dense_vector = dense_embedder.embed(query_text) + sparse_obj = sparse_embedder.query_embed(query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + return dense_vector, sparse_vector + + # ── Statement executors ─────────────────────────────────────────────── + + async def _execute_insert(self, node: InsertStmt) -> ExecutionResult: + if "text" not in node.values: + raise QQLRuntimeError("INSERT requires a 'text' field in VALUES") + + topology = await self._resolve_topology(node.collection) + use_hybrid = node.hybrid or (topology.exists and topology.is_hybrid) + + if use_hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + + dense_vector = dense_embedder.embed(node.values["text"]) + sparse_obj = sparse_embedder.embed(node.values["text"]) + + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + + if topology.exists: + resolved_dense = topology.dense_using(node.dense_vector) + if resolved_dense is None: + raise QQLRuntimeError( + "Hybrid collections must use named dense vectors" + ) + dense_name = resolved_dense + sparse_name = topology.sparse_using(node.sparse_vector) + else: + async with self._creation_lock: + current_topology = await self._resolve_topology(node.collection) + if not current_topology.exists: + await self._create_collection_and_wait( + collection_name=node.collection, + vectors_config={ + dense_name: VectorParams( + size=len(dense_vector), distance=Distance.COSINE + ) + }, + sparse_vectors_config={ + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + ) + else: + dense_name = current_topology.dense_using(node.dense_vector) or dense_name + sparse_name = current_topology.sparse_using(node.sparse_vector) + + point_id, payload = extract_point_id_and_payload(node.values) + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=[ + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ], + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted 1 point [{point_id}] (hybrid)", + data={"id": point_id, "collection": node.collection}, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.values["text"]) + + topology = await self._ensure_collection( + node.collection, len(vector), topology, node.dense_vector + ) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) + + point_id, payload = extract_point_id_and_payload(node.values) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=[PointStruct(id=point_id, vector=point_vector, payload=payload)], + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted 1 point [{point_id}]", + data={"id": point_id, "collection": node.collection}, + ) + + async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: + if not node.values_list: + raise QQLRuntimeError("INSERT BULK VALUES list is empty") + for i, vals in enumerate(node.values_list): + if "text" not in vals: + raise QQLRuntimeError( + f"INSERT BULK: item at index {i} is missing required 'text' field" + ) + + topology = await self._resolve_topology(node.collection) + use_hybrid = node.hybrid or (topology.exists and topology.is_hybrid) + + if use_hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + if topology.exists: + resolved_dense = topology.dense_using(node.dense_vector) + if resolved_dense is None: + raise QQLRuntimeError( + "Hybrid collections must use named dense vectors" + ) + dense_name = resolved_dense + sparse_name = topology.sparse_using(node.sparse_vector) + + dense_vectors = [ + dense_embedder.embed(vals["text"]) for vals in node.values_list + ] + sparse_objs = [sparse_embedder.embed(vals["text"]) for vals in node.values_list] + + first_dense_vector = dense_vectors[0] if dense_vectors else None + if not topology.exists: + assert first_dense_vector is not None + async with self._creation_lock: + current_topology = await self._resolve_topology(node.collection) + if not current_topology.exists: + await self._create_collection_and_wait( + collection_name=node.collection, + vectors_config={ + dense_name: VectorParams(size=len(first_dense_vector), distance=Distance.COSINE) + }, + sparse_vectors_config={ + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + ) + else: + dense_name = current_topology.dense_using(node.dense_vector) or dense_name + sparse_name = current_topology.sparse_using(node.sparse_vector) + + points: list[PointStruct] = [] + for idx, vals in enumerate(node.values_list): + point_id, payload = extract_point_id_and_payload(vals) + dense_vector = dense_vectors[idx] + sparse_obj = sparse_objs[idx] + sparse_vector = SparseVector( + indices=sparse_obj["indices"], values=sparse_obj["values"] + ) + points.append( + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted {len(points)} points (hybrid)", + data={"ids": [p.id for p in points]}, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + + vectors = [embedder.embed(vals["text"]) for vals in node.values_list] + + first_vector = vectors[0] if vectors else None + points = [] + for idx, vals in enumerate(node.values_list): + vector = vectors[idx] + point_id, payload = extract_point_id_and_payload(vals) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) + points.append( + PointStruct(id=point_id, vector=point_vector, payload=payload) + ) + + assert first_vector is not None + topology = await self._ensure_collection( + node.collection, len(first_vector), topology, node.dense_vector + ) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted {len(points)} points", + data={"ids": [p.id for p in points]}, + ) + + async def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: + if await self._client.collection_exists(node.collection): + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' already exists", + ) + + dense_model_name = node.model or self._config.default_model + + quant_config = ( + self._build_quantization_config(node.quantization) + if node.quantization is not None + else None + ) + quant_label = ( + f", {node.quantization.type.value} quantization" + if node.quantization is not None + else "" + ) + hnsw_config = self._build_hnsw_config(node.config) + optimizers_config = self._build_optimizers_config(node.config) + params_config = self._build_collection_params_create_kwargs(node.config) + config_label = self._describe_collection_config(node.config) + vector_on_disk = ( + node.config.vectors.on_disk + if node.config is not None and node.config.vectors is not None + else None + ) + + if node.hybrid: + embedder = Embedder(dense_model_name) + dims = embedder.dimensions + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + create_kwargs: dict[str, Any] = { + "collection_name": node.collection, + "vectors_config": { + dense_name: VectorParams( + size=dims, + distance=Distance.COSINE, + on_disk=vector_on_disk, + ) + }, + "sparse_vectors_config": { + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + } + if quant_config is not None: + create_kwargs["quantization_config"] = quant_config + if hnsw_config is not None: + create_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + create_kwargs["optimizers_config"] = optimizers_config + create_kwargs.update(params_config) + await self._create_collection_and_wait(**create_kwargs) + return ExecutionResult( + success=True, + message=( + f"Collection '{node.collection}' created " + f"(hybrid: {dims}-dim dense + BM25 sparse, cosine distance{quant_label}{config_label})" + ), + ) + + embedder = Embedder(dense_model_name) + dims = embedder.dimensions + dense_name = node.dense_vector or self._default_dense_vector_name() + create_kwargs = { + "collection_name": node.collection, + "vectors_config": { + dense_name: VectorParams( + size=dims, + distance=Distance.COSINE, + on_disk=vector_on_disk, + ) + }, + } + if quant_config is not None: + create_kwargs["quantization_config"] = quant_config + if hnsw_config is not None: + create_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + create_kwargs["optimizers_config"] = optimizers_config + create_kwargs.update(params_config) + await self._create_collection_and_wait(**create_kwargs) + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' created ({dims}-dimensional vectors, cosine distance{quant_label}{config_label})", + ) + + async def _execute_alter_collection(self, node: AlterCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + + update_kwargs: dict[str, Any] = {"collection_name": node.collection} + vectors_config = self._build_vectors_config_diff(topology, node.config) + hnsw_config = self._build_hnsw_config(node.config) + optimizers_config = self._build_optimizers_config(node.config) + collection_params = self._build_collection_params_diff(node.config) + quantization_config = self._build_alter_quantization_config(node.quantization) + + if vectors_config is not None: + update_kwargs["vectors_config"] = vectors_config + if hnsw_config is not None: + update_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + update_kwargs["optimizers_config"] = optimizers_config + if collection_params is not None: + update_kwargs["collection_params"] = collection_params + if quantization_config is not None: + update_kwargs["quantization_config"] = quantization_config + + try: + await self._client.update_collection(**update_kwargs) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during ALTER COLLECTION: {e}") from e + + return ExecutionResult( + success=True, + message=( + f"Collection '{node.collection}' altered" + f"{self._describe_collection_config(node.config)}" + f"{self._describe_quantization_update(node.quantization)}" + ), + ) + + async def _execute_create_index(self, node: CreateIndexStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + schema_map = { + "keyword": PayloadSchemaType.KEYWORD, + "integer": PayloadSchemaType.INTEGER, + "float": PayloadSchemaType.FLOAT, + "bool": PayloadSchemaType.BOOL, + "text": PayloadSchemaType.TEXT, + "geo": PayloadSchemaType.GEO, + "datetime": PayloadSchemaType.DATETIME, + "uuid": PayloadSchemaType.UUID, + } + try: + schema_map[node.schema] + except KeyError as e: + raise QQLRuntimeError( + "Unknown index type '" + f"{node.schema}'. Expected one of: keyword, integer, float, bool, text, geo, datetime, uuid" + ) from e + field_schema = self._build_payload_index_schema(node) + + try: + await self._client.create_payload_index( + collection_name=node.collection, + field_name=node.field_name, + field_schema=field_schema, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during CREATE INDEX: {e}") from e + + option_label = f" with options {node.options}" if node.options else "" + return ExecutionResult( + success=True, + message=( + f"Created index on '{node.collection}.{node.field_name}' as '{node.schema}'{option_label}" + ), + ) + + async def _execute_drop(self, node: DropCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + await self._client.delete_collection(node.collection) + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' dropped", + ) + + async def _execute_show(self, node: ShowCollectionsStmt) -> ExecutionResult: + response = await self._client.get_collections() + names = [c.name for c in response.collections] + return ExecutionResult( + success=True, + message=f"{len(names)} collection(s) found", + data=names, + ) + + async def _execute_show_collection(self, node: ShowCollectionStmt) -> ExecutionResult: + info = await self._fetch_collection_info(node.collection) + if info is None: + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + config = info.config + params = config.params + + vectors = params.vectors # type: ignore[union-attr] + sparse_vector_params = params.sparse_vectors or {} + if isinstance(vectors, dict): + vector_details = {} + for vname, vconfig in vectors.items(): + vector_details[vname] = { + "size": vconfig.size, + "distance": str(vconfig.distance) if vconfig.distance else None, + "on_disk": vconfig.on_disk, + } + elif vectors is None: + raise QQLRuntimeError( + f"Collection '{node.collection}' has no vector configuration" + ) + else: + vector_details = { + "": { + "size": vectors.size, + "distance": str(vectors.distance) if vectors.distance else None, + "on_disk": vectors.on_disk, + } + } + topology = "hybrid" if sparse_vector_params else "dense" + + sparse_vectors = {} + if sparse_vector_params: + for sname, sconfig in sparse_vector_params.items(): + sparse_vectors[sname] = { + "modifier": str(sconfig.modifier) if sconfig.modifier else None, + } + + quant_config = config.quantization_config + quantization = None + if quant_config is not None: + qtype = type(quant_config).__name__ + if hasattr(quant_config, "scalar"): + quantization = "scalar" + elif hasattr(quant_config, "binary"): + quantization = "binary" + elif hasattr(quant_config, "product"): + quantization = "product" + elif hasattr(quant_config, "turbo"): + quantization = "turbo" + else: + quantization = qtype + + hnsw = { + "m": config.hnsw_config.m, + "ef_construct": config.hnsw_config.ef_construct, + } + if config.hnsw_config.full_scan_threshold is not None: + hnsw["full_scan_threshold"] = config.hnsw_config.full_scan_threshold + if config.hnsw_config.max_indexing_threads is not None: + hnsw["max_indexing_threads"] = config.hnsw_config.max_indexing_threads + if config.hnsw_config.on_disk is not None: + hnsw["on_disk"] = config.hnsw_config.on_disk + if config.hnsw_config.payload_m is not None: + hnsw["payload_m"] = config.hnsw_config.payload_m + if config.hnsw_config.inline_storage is not None: + hnsw["inline_storage"] = config.hnsw_config.inline_storage + + payload_indexes = {} + for field_name, idx_info in (info.payload_schema or {}).items(): + payload_indexes[field_name] = self._serialize_payload_index_info(idx_info) + + sharding = { + "shard_number": params.shard_number, + "replication_factor": params.replication_factor, + "write_consistency_factor": params.write_consistency_factor, + "read_fan_out_factor": params.read_fan_out_factor, + "read_fan_out_delay_ms": params.read_fan_out_delay_ms, + "on_disk_payload": params.on_disk_payload, + } + + data = { + "name": node.collection, + "status": str(info.status), + "points_count": info.points_count, + "indexed_vectors_count": info.indexed_vectors_count, + "segments_count": info.segments_count, + "topology": topology, + "vectors": vector_details, + "sparse_vectors": sparse_vectors or None, + "quantization": quantization, + "hnsw_config": hnsw, + "payload_schema": payload_indexes or None, + "sharding": sharding, + } + + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' diagnostics", + data=data, + ) + + async def _execute_scroll(self, node: ScrollStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + scroll_filter: Filter | None = None + if node.query_filter is not None: + scroll_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + try: + records, next_offset = await self._client.scroll( + collection_name=node.collection, + scroll_filter=scroll_filter, + limit=node.limit, + offset=node.after, + with_payload=True, + with_vectors=False, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SCROLL: {e}") from e + + points = [ + {"id": str(rec.id), "payload": rec.payload or {}} + for rec in records + ] + return ExecutionResult( + success=True, + message=f"Scrolled {len(points)} point(s) from '{node.collection}'", + data={"points": points, "next_offset": next_offset}, + ) + + async def _execute_select(self, node: SelectStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + try: + records = await self._client.retrieve( + collection_name=node.collection, + ids=[node.point_id], + with_payload=True, + with_vectors=False, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SELECT: {e}") from e + + if not records: + return ExecutionResult( + success=True, + message=f"Point '{node.point_id}' not found in '{node.collection}'", + ) + + record = records[0] + return ExecutionResult( + success=True, + message=f"Retrieved point '{node.point_id}' from '{node.collection}'", + data={"id": str(record.id), "payload": record.payload or {}}, + ) + + async def _execute_search(self, node: SearchStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + validate_search_mmr_usage(node) + + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if node.group_by is not None: + return await self._execute_search_groups( + node, qdrant_filter, search_params, topology + ) + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + prefetch=build_hybrid_prefetches( + topology, + node, + dense_vector, + sparse_vector, + search_params, + _HYBRID_PREFETCH_MULTIPLIER, + ), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid, reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid)", + data=results, + ) + + if node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (sparse, reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (sparse)", + data=results, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + + try: + query_using = topology.dense_using(node.dense_vector) + response = await self._client.query_points( + collection_name=node.collection, + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s)", + data=results, + ) + + async def _execute_search_groups( + self, + node: SearchStmt, + qdrant_filter: Filter | None, + search_params: SearchParams | None, + topology: CollectionTopology, + ) -> ExecutionResult: + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + prefetch=build_hybrid_prefetches( + topology, + node, + dense_vector, + sparse_vector, + search_params, + _HYBRID_PREFETCH_MULTIPLIER, + ), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "hybrid, grouped" + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_obj = SparseEmbedder(sparse_model_name).query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "sparse, grouped" + else: + model_name = node.model or self._config.default_model + vector = Embedder(model_name).embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "grouped" + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during GROUP BY SEARCH: {e}") from e + + groups = [ + { + "group_id": str(g.id), + "hits": [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in g.hits + ], + } + for g in response.groups + ] + return ExecutionResult( + success=True, + message=f"Found {len(groups)} group(s) by '{node.group_by}' ({label})", + data=groups, + ) + + async def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + + search_params = self._build_search_params(node.with_clause) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during RECOMMEND: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + return ExecutionResult( + success=True, + message=f"Found {len(results)} recommendation(s)", + data=results, + ) + + async def _execute_delete(self, node: DeleteStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + try: + if node.query_filter is not None: + await self._client.delete( + collection_name=node.collection, + wait=True, + points_selector=self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ), + ) + return ExecutionResult( + success=True, + message=f"Deleted points from '{node.collection}' by filter", + ) + + from qdrant_client.models import PointIdsList + + if node.point_id is None: + raise QQLRuntimeError("DELETE requires either a point id or a filter") + + await self._client.delete( + collection_name=node.collection, + wait=True, + points_selector=PointIdsList(points=[node.point_id]), + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during DELETE: {e}") from e + + return ExecutionResult( + success=True, + message=f"Deleted point '{node.point_id}' from '{node.collection}'", + ) + + async def _execute_update_vector(self, node: UpdateVectorStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + vector_name = topology.dense_payload_name(node.vector_name) + vector_struct: Any = ( + {vector_name: list(node.vector)} if vector_name else list(node.vector) + ) + try: + await self._client.update_vectors( + collection_name=node.collection, + points=[PointVectors(id=node.point_id, vector=vector_struct)], + wait=True, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during UPDATE VECTOR: {e}") from e + return ExecutionResult( + success=True, + message=f"Updated vector for point [{node.point_id}] in '{node.collection}'", + data=[], + ) + + async def _execute_update_payload(self, node: UpdatePayloadStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + try: + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + await self._client.set_payload( + collection_name=node.collection, + payload=node.payload, + points=qdrant_filter, + wait=True, + ) + return ExecutionResult( + success=True, + message=f"Payload updated in '{node.collection}' (filter-based)", + data=[], + ) + await self._client.set_payload( + collection_name=node.collection, + payload=node.payload, + points=[node.point_id], + wait=True, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during UPDATE PAYLOAD: {e}") from e + return ExecutionResult( + success=True, + message=f"Payload updated for point [{node.point_id}] in '{node.collection}'", + data=[], + ) diff --git a/src/qql/connection.py b/src/qql/connection.py index fe72824..64e86c8 100644 --- a/src/qql/connection.py +++ b/src/qql/connection.py @@ -52,6 +52,8 @@ def __init__( secret: str | None = None, default_model: str | None = None, verify: bool | str = True, + prefer_grpc: bool = False, + grpc_port: int = 6334, ) -> None: """Create a connection to a Qdrant instance. @@ -64,6 +66,8 @@ def __init__( verify: SSL certificate verification. Set to ``False`` to skip verification for self-signed/internal certificates, or pass a path to a custom CA bundle (default: ``True``). + prefer_grpc: Whether to connect via fast gRPC transport. + grpc_port: The gRPC port of Qdrant instance (default: 6334). """ from qdrant_client import QdrantClient @@ -73,7 +77,11 @@ def __init__( default_model=default_model or DEFAULT_MODEL, verify=verify, ) - self._client = QdrantClient(url=url, api_key=secret, verify=verify) + client_kwargs = {"url": url, "api_key": secret, "verify": verify} + if prefer_grpc: + client_kwargs["prefer_grpc"] = True + client_kwargs["grpc_port"] = grpc_port + self._client = QdrantClient(**client_kwargs) self._executor = Executor(self._client, self._config) # ── Public API ──────────────────────────────────────────────────────── diff --git a/src/qql/executor.py b/src/qql/executor.py index 975947d..f767b1a 100644 --- a/src/qql/executor.py +++ b/src/qql/executor.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -import uuid from dataclasses import dataclass from typing import Any @@ -15,40 +14,23 @@ CompressionRatio, Distance, Disabled, - FieldCondition, Filter, - Fusion, FusionQuery, - HasIdCondition, HnswConfigDiff, - IsEmptyCondition, - IsNullCondition, KeywordIndexParams, KeywordIndexType, Language, LookupLocation, - MatchAny, - MatchExcept, - MatchPhrase, - MatchText, - MatchTextAny, - MatchValue, - Mmr, Modifier, - NearestQuery, OptimizersConfigDiff, - PayloadField, PayloadSchemaType, PointStruct, PointVectors, - Prefetch, ProductQuantization, ProductQuantizationConfig, QuantizationSearchParams, - Range, RecommendInput, RecommendQuery, - RecommendStrategy, ScalarQuantization, ScalarQuantizationConfig, ScalarType, @@ -72,28 +54,14 @@ from .ast_nodes import ( ASTNode, AlterCollectionStmt, - AndExpr, - BetweenExpr, CollectionConfig, - CompareExpr, CreateCollectionStmt, CreateIndexStmt, DeleteStmt, DropCollectionStmt, FilterExpr, - InExpr, InsertBulkStmt, InsertStmt, - IsEmptyExpr, - IsNotEmptyExpr, - IsNotNullExpr, - IsNullExpr, - MatchAnyExpr, - MatchPhraseExpr, - MatchTextExpr, - NotExpr, - NotInExpr, - OrExpr, QuantizationUpdate, QuantizationConfig, QuantizationType, @@ -110,6 +78,21 @@ from .config import QQLConfig from .embedder import CrossEncoderEmbedder, Embedder, SparseEmbedder from .exceptions import QQLRuntimeError +from .utils import ( + build_dense_point_vector, + build_dense_query, + build_hybrid_prefetches, + build_qdrant_filter, + collection_topology_kwargs, + exclude_ids_from_filter, + extract_point_id_and_payload, + has_mmr, + is_grpc_not_found_error, + parse_recommend_strategy, + resolve_hybrid_fusion, + validate_search_mmr_usage, + wrap_as_filter, +) _RERANK_FETCH_MULTIPLIER = 4 _HYBRID_PREFETCH_MULTIPLIER = 4 @@ -255,6 +238,14 @@ def _fetch_collection_info(self, name: str): raise QQLRuntimeError( f"Qdrant error fetching collection '{name}': {e}" ) from e + except ValueError as e: + if f"Collection {name} not found" in str(e): + return None + raise + except Exception as e: + if is_grpc_not_found_error(e): + return None + raise def _topology_from_collection_info(self, info: Any) -> CollectionTopology: """Parse a CollectionInfo object into a :class:`CollectionTopology`. @@ -265,40 +256,7 @@ def _topology_from_collection_info(self, info: Any) -> CollectionTopology: params = info.config.params vectors = params.vectors # type: ignore[union-attr] sparse_vectors = params.sparse_vectors or {} - - if isinstance(vectors, dict): - dense_names = tuple(vectors.keys()) - dense_sizes: tuple[tuple[str, int], ...] = tuple( - (k, v.size) - for k, v in vectors.items() - if getattr(v, "size", None) is not None - ) - has_unnamed_dense = False - is_named_dense = True - elif vectors is None: - dense_names = () - dense_sizes = () - has_unnamed_dense = False - is_named_dense = False - else: - # Single unnamed dense vector - dense_names = () - unnamed_size = getattr(vectors, "size", None) - dense_sizes = (("", unnamed_size),) if unnamed_size is not None else () - has_unnamed_dense = True - is_named_dense = False - - sparse_names = ( - tuple(sparse_vectors.keys()) if isinstance(sparse_vectors, dict) else () - ) - return CollectionTopology( - exists=True, - is_named_dense=is_named_dense, - has_unnamed_dense=has_unnamed_dense, - dense_names=dense_names, - sparse_names=sparse_names, - dense_sizes=dense_sizes, - ) + return CollectionTopology(**collection_topology_kwargs(vectors, sparse_vectors)) def _resolve_topology(self, name: str) -> CollectionTopology: """Return the topology for *name* using exactly one Qdrant API call. @@ -362,7 +320,7 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: }, ) - point_id, payload = self._extract_point_id_and_payload(node.values) + point_id, payload = extract_point_id_and_payload(node.values) try: self._client.upsert( collection_name=node.collection, @@ -392,9 +350,14 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: self._ensure_collection( node.collection, len(vector), topology, node.dense_vector ) - point_vector = self._build_dense_point_vector(topology, vector, node.dense_vector) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) - point_id, payload = self._extract_point_id_and_payload(node.values) + point_id, payload = extract_point_id_and_payload(node.values) try: self._client.upsert( @@ -443,7 +406,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: first_dense_vector: list[float] | None = None points: list[PointStruct] = [] for vals in node.values_list: - point_id, payload = self._extract_point_id_and_payload(vals) + point_id, payload = extract_point_id_and_payload(vals) dense_vector = dense_embedder.embed(vals["text"]) if first_dense_vector is None: first_dense_vector = dense_vector @@ -483,6 +446,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: return ExecutionResult( success=True, message=f"Inserted {len(points)} points (hybrid)", + data={"ids": [p.id for p in points]}, ) # ── Standard dense-only bulk INSERT ─────────────────────────────── @@ -495,9 +459,12 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: vector = embedder.embed(vals["text"]) if first_vector is None: first_vector = vector - point_id, payload = self._extract_point_id_and_payload(vals) - point_vector = self._build_dense_point_vector( - topology, vector, node.dense_vector + point_id, payload = extract_point_id_and_payload(vals) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), ) points.append( PointStruct(id=point_id, vector=point_vector, payload=payload) @@ -520,6 +487,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: return ExecutionResult( success=True, message=f"Inserted {len(points)} points", + data={"ids": [p.id for p in points]}, ) def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: @@ -889,7 +857,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: ) search_params = self._build_search_params(node.with_clause) - self._validate_search_mmr_usage(node) + validate_search_mmr_usage(node) # When reranking is requested, fetch more candidates so the reranker has # enough material to reorder; only `node.limit` results are returned. @@ -919,21 +887,15 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: try: response = self._client.query_points( collection_name=node.collection, - prefetch=[ - Prefetch( - query=self._build_dense_query(dense_vector, node.with_clause), - using=topology.dense_using(node.dense_vector), - limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, - params=search_params, - ), - Prefetch( - query=sparse_vector, - using=topology.sparse_using(node.sparse_vector), - limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, - params=search_params, - ), - ], - query=FusionQuery(fusion=self._resolve_hybrid_fusion(node.fusion)), + prefetch=build_hybrid_prefetches( + topology, + node, + dense_vector, + sparse_vector, + search_params, + _HYBRID_PREFETCH_MULTIPLIER, + ), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), limit=fetch_limit, offset=node.offset or None, query_filter=qdrant_filter, @@ -1015,7 +977,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: query_using = topology.dense_using(node.dense_vector) response = self._client.query_points( collection_name=node.collection, - query=self._build_dense_query(vector, node.with_clause), + query=build_dense_query(vector, node.with_clause), using=query_using, limit=fetch_limit, offset=node.offset or None, @@ -1066,15 +1028,6 @@ def _build_hybrid_vectors( ) return dense_vector, sparse_vector - def _resolve_hybrid_fusion(self, fusion: str | None) -> Fusion: - if fusion is None or fusion == "rrf": - return Fusion.RRF - if fusion == "dbsf": - return Fusion.DBSF - raise QQLRuntimeError( - f"Unsupported hybrid fusion '{fusion}'; expected 'rrf' or 'dbsf'" - ) - def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: if not self._client.collection_exists(node.collection): raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") @@ -1084,7 +1037,7 @@ def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: qdrant_filter = self._wrap_as_filter( self._build_qdrant_filter(node.query_filter) ) - qdrant_filter = self._exclude_ids_from_filter( + qdrant_filter = exclude_ids_from_filter( qdrant_filter, [*node.positive_ids, *node.negative_ids], ) @@ -1092,11 +1045,11 @@ def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: recommend_input = RecommendInput( positive=list(node.positive_ids), negative=list(node.negative_ids) or None, - strategy=self._parse_recommend_strategy(node.strategy), + strategy=parse_recommend_strategy(node.strategy), ) search_params = self._build_search_params(node.with_clause) - if self._has_mmr(node.with_clause): + if has_mmr(node.with_clause): raise QQLRuntimeError("MMR is supported only for SEARCH statements") lookup_from: LookupLocation | None = None @@ -1501,107 +1454,6 @@ def _describe_quantization_update( return f", quantization={quantization.config.type.value}" return "" - def _has_mmr(self, with_clause: SearchWith | None) -> bool: - return with_clause is not None and ( - with_clause.mmr_diversity is not None or with_clause.mmr_candidates is not None - ) - - def _validate_search_mmr_usage(self, node: SearchStmt) -> None: - if not self._has_mmr(node.with_clause): - return - if node.sparse_only: - raise QQLRuntimeError("MMR is not supported with USING SPARSE yet") - - def _build_dense_query( - self, - vector: list[float], - with_clause: SearchWith | None, - ) -> list[float] | NearestQuery: - if not self._has_mmr(with_clause): - return vector - return NearestQuery( - nearest=vector, - mmr=Mmr( - diversity=with_clause.mmr_diversity, - candidates_limit=with_clause.mmr_candidates, - ), - ) - - def _parse_recommend_strategy( - self, strategy: str | None - ) -> RecommendStrategy | None: - if strategy is None: - return None - try: - return RecommendStrategy(strategy) - except ValueError as e: - raise QQLRuntimeError( - "Unknown recommend strategy " - f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" - ) from e - - def _exclude_ids_from_filter( - self, - query_filter: Filter | None, - point_ids: list[str | int], - ) -> Filter | None: - if not point_ids: - return query_filter - - exclude_condition = HasIdCondition(has_id=point_ids) - if query_filter is None: - return Filter(must_not=[exclude_condition]) - - return Filter( - must=list(query_filter.must or []), - should=list(query_filter.should or []), - must_not=[*(query_filter.must_not or []), exclude_condition], - min_should=query_filter.min_should, - ) - - def _extract_point_id_and_payload( - self, values: dict[str, Any] - ) -> tuple[str | int, dict[str, Any]]: - payload = dict(values) - if "id" not in payload: - return str(uuid.uuid4()), payload - - point_id = payload.pop("id") - if isinstance(point_id, bool): - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - if isinstance(point_id, int): - if point_id < 0: - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - return point_id, payload - if isinstance(point_id, str): - try: - uuid.UUID(point_id) - except ValueError as e: - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) from e - return point_id, payload - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - - def _build_dense_point_vector( - self, - topology: CollectionTopology, - vector: list[float], - explicit_vector: str | None, - ) -> list[float] | dict[str, list[float]]: - if not topology.exists: - return {explicit_vector or self._default_dense_vector_name(): vector} - vector_name = topology.dense_payload_name(explicit_vector) - if vector_name is None: - return vector - return {vector_name: vector} - def _apply_reranking( self, query: str, @@ -1680,21 +1532,15 @@ def _execute_search_groups( response = self._client.query_points_groups( collection_name=node.collection, group_by=node.group_by, - prefetch=[ - Prefetch( - query=self._build_dense_query(dense_vector, node.with_clause), - using=topology.dense_using(node.dense_vector), - limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, - params=search_params, - ), - Prefetch( - query=sparse_vector, - using=topology.sparse_using(node.sparse_vector), - limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, - params=search_params, - ), - ], - query=FusionQuery(fusion=self._resolve_hybrid_fusion(node.fusion)), + prefetch=build_hybrid_prefetches( + topology, + node, + dense_vector, + sparse_vector, + search_params, + _HYBRID_PREFETCH_MULTIPLIER, + ), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), limit=node.limit, group_size=node.group_size, query_filter=qdrant_filter, @@ -1729,7 +1575,7 @@ def _execute_search_groups( response = self._client.query_points_groups( collection_name=node.collection, group_by=node.group_by, - query=self._build_dense_query(vector, node.with_clause), + query=build_dense_query(vector, node.with_clause), using=query_using, limit=node.limit, group_size=node.group_size, @@ -1823,90 +1669,11 @@ def _build_qdrant_filter(self, expr: FilterExpr) -> Any: Returns one of: Filter, FieldCondition, IsNullCondition, IsEmptyCondition. Use _wrap_as_filter() to guarantee the top-level result is a Filter. """ - # ── Logical combinators ─────────────────────────────────────────── - if isinstance(expr, AndExpr): - return Filter(must=[self._build_qdrant_filter(op) for op in expr.operands]) - - if isinstance(expr, OrExpr): - return Filter(should=[self._build_qdrant_filter(op) for op in expr.operands]) - - if isinstance(expr, NotExpr): - return Filter(must_not=[self._build_qdrant_filter(expr.operand)]) - - # ── Comparison ──────────────────────────────────────────────────── - if isinstance(expr, CompareExpr): - if expr.op == "=": - return FieldCondition( - key=expr.field, match=MatchValue(value=expr.value) - ) - if expr.op == "!=": - return Filter( - must_not=[ - FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) - ] - ) - _range_key = {">": "gt", ">=": "gte", "<": "lt", "<=": "lte"}[expr.op] - return FieldCondition( - key=expr.field, range=Range(**{_range_key: expr.value}) - ) - - # ── BETWEEN ─────────────────────────────────────────────────────── - if isinstance(expr, BetweenExpr): - return FieldCondition( - key=expr.field, range=Range(gte=expr.low, lte=expr.high) - ) - - # ── IN / NOT IN ─────────────────────────────────────────────────── - if isinstance(expr, InExpr): - return FieldCondition( - key=expr.field, match=MatchAny(any=list(expr.values)) - ) - - if isinstance(expr, NotInExpr): - return FieldCondition( - key=expr.field, - match=MatchExcept(**{"except": list(expr.values)}), - ) - - # ── IS NULL / IS NOT NULL ───────────────────────────────────────── - if isinstance(expr, IsNullExpr): - return IsNullCondition(is_null=PayloadField(key=expr.field)) - - if isinstance(expr, IsNotNullExpr): - return Filter( - must_not=[IsNullCondition(is_null=PayloadField(key=expr.field))] - ) - - # ── IS EMPTY / IS NOT EMPTY ─────────────────────────────────────── - if isinstance(expr, IsEmptyExpr): - return IsEmptyCondition(is_empty=PayloadField(key=expr.field)) - - if isinstance(expr, IsNotEmptyExpr): - return Filter( - must_not=[IsEmptyCondition(is_empty=PayloadField(key=expr.field))] - ) - - # ── Full-text MATCH ─────────────────────────────────────────────── - if isinstance(expr, MatchTextExpr): - return FieldCondition(key=expr.field, match=MatchText(text=expr.text)) - - if isinstance(expr, MatchAnyExpr): - return FieldCondition( - key=expr.field, match=MatchTextAny(text_any=expr.text) - ) - - if isinstance(expr, MatchPhraseExpr): - return FieldCondition( - key=expr.field, match=MatchPhrase(phrase=expr.text) - ) - - raise QQLRuntimeError(f"Unknown filter expression type: {type(expr)}") + return build_qdrant_filter(expr) def _wrap_as_filter(self, qdrant_expr: Any) -> Filter: """Ensure the top-level expression is a Filter (required by query_points).""" - if isinstance(qdrant_expr, Filter): - return qdrant_expr - return Filter(must=[qdrant_expr]) + return wrap_as_filter(qdrant_expr) # ── Collection helpers ──────────────────────────────────────────────── diff --git a/src/qql/utils.py b/src/qql/utils.py new file mode 100644 index 0000000..d1a3e23 --- /dev/null +++ b/src/qql/utils.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from typing import Any + +from grpc import RpcError, StatusCode +from qdrant_client.models import ( + FieldCondition, + Filter, + Fusion, + HasIdCondition, + IsEmptyCondition, + IsNullCondition, + MatchAny, + MatchExcept, + MatchPhrase, + MatchText, + MatchTextAny, + MatchValue, + Mmr, + NearestQuery, + PayloadField, + Prefetch, + Range, + RecommendStrategy, +) + +from .ast_nodes import ( + AndExpr, + BetweenExpr, + CompareExpr, + FilterExpr, + InExpr, + IsEmptyExpr, + IsNotEmptyExpr, + IsNotNullExpr, + IsNullExpr, + MatchAnyExpr, + MatchPhraseExpr, + MatchTextExpr, + NotExpr, + NotInExpr, + OrExpr, + SearchStmt, + SearchWith, +) +from .exceptions import QQLRuntimeError, QQLSyntaxError +from .lexer import TokenKind + +_HYBRID_FUSION_VALUES = {"rrf", "dbsf"} + + +def is_grpc_not_found_error(error: BaseException) -> bool: + return isinstance(error, RpcError) and error.code() == StatusCode.NOT_FOUND + + +@dataclass(frozen=True) +class SearchUsingOptions: + model: str | None = None + hybrid: bool = False + fusion: str | None = None + sparse_only: bool = False + sparse_model: str | None = None + dense_vector: str | None = None + sparse_vector: str | None = None + + +@dataclass(frozen=True) +class SearchGroupByOptions: + group_by: str | None = None + group_size: int = 3 + + +def collection_topology_kwargs(vectors: Any, sparse_vectors: Any) -> dict[str, Any]: + if isinstance(vectors, dict): + dense_names = tuple(vectors.keys()) + dense_sizes = tuple( + (k, v.size) + for k, v in vectors.items() + if getattr(v, "size", None) is not None + ) + has_unnamed_dense = False + is_named_dense = True + elif vectors is None: + dense_names = () + dense_sizes = () + has_unnamed_dense = False + is_named_dense = False + else: + dense_names = () + unnamed_size = getattr(vectors, "size", None) + dense_sizes = (("", unnamed_size),) if unnamed_size is not None else () + has_unnamed_dense = True + is_named_dense = False + + sparse_names = tuple(sparse_vectors.keys()) if isinstance(sparse_vectors, dict) else () + return { + "exists": True, + "is_named_dense": is_named_dense, + "has_unnamed_dense": has_unnamed_dense, + "dense_names": dense_names, + "sparse_names": sparse_names, + "dense_sizes": dense_sizes, + } + + +def build_qdrant_filter(expr: FilterExpr) -> Any: + if isinstance(expr, AndExpr): + return Filter(must=[build_qdrant_filter(op) for op in expr.operands]) + if isinstance(expr, OrExpr): + return Filter(should=[build_qdrant_filter(op) for op in expr.operands]) + if isinstance(expr, NotExpr): + return Filter(must_not=[build_qdrant_filter(expr.operand)]) + if isinstance(expr, CompareExpr): + if expr.op == "=": + return FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) + if expr.op == "!=": + return Filter( + must_not=[ + FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) + ] + ) + range_key = {">": "gt", ">=": "gte", "<": "lt", "<=": "lte"}[expr.op] + return FieldCondition(key=expr.field, range=Range(**{range_key: expr.value})) + if isinstance(expr, BetweenExpr): + return FieldCondition(key=expr.field, range=Range(gte=expr.low, lte=expr.high)) + if isinstance(expr, InExpr): + return FieldCondition(key=expr.field, match=MatchAny(any=list(expr.values))) + if isinstance(expr, NotInExpr): + return FieldCondition( + key=expr.field, + match=MatchExcept(**{"except": list(expr.values)}), + ) + if isinstance(expr, IsNullExpr): + return IsNullCondition(is_null=PayloadField(key=expr.field)) + if isinstance(expr, IsNotNullExpr): + return Filter(must_not=[IsNullCondition(is_null=PayloadField(key=expr.field))]) + if isinstance(expr, IsEmptyExpr): + return IsEmptyCondition(is_empty=PayloadField(key=expr.field)) + if isinstance(expr, IsNotEmptyExpr): + return Filter(must_not=[IsEmptyCondition(is_empty=PayloadField(key=expr.field))]) + if isinstance(expr, MatchTextExpr): + return FieldCondition(key=expr.field, match=MatchText(text=expr.text)) + if isinstance(expr, MatchAnyExpr): + return FieldCondition(key=expr.field, match=MatchTextAny(text_any=expr.text)) + if isinstance(expr, MatchPhraseExpr): + return FieldCondition(key=expr.field, match=MatchPhrase(phrase=expr.text)) + raise QQLRuntimeError(f"Unknown filter expression type: {type(expr)}") + + +def wrap_as_filter(qdrant_expr: Any) -> Filter: + if isinstance(qdrant_expr, Filter): + return qdrant_expr + return Filter(must=[qdrant_expr]) + + +def resolve_hybrid_fusion(fusion: str | None) -> Fusion: + if fusion is None or fusion == "rrf": + return Fusion.RRF + if fusion == "dbsf": + return Fusion.DBSF + raise QQLRuntimeError( + f"Unsupported hybrid fusion '{fusion}'; expected 'rrf' or 'dbsf'" + ) + + +def has_mmr(with_clause: SearchWith | None) -> bool: + return with_clause is not None and ( + with_clause.mmr_diversity is not None or with_clause.mmr_candidates is not None + ) + + +def validate_search_mmr_usage(node: SearchStmt) -> None: + if not has_mmr(node.with_clause): + return + if node.sparse_only: + raise QQLRuntimeError("MMR is not supported with USING SPARSE yet") + + +def build_dense_query( + vector: list[float], + with_clause: SearchWith | None, +) -> list[float] | NearestQuery: + if not has_mmr(with_clause): + return vector + return NearestQuery( + nearest=vector, + mmr=Mmr( + diversity=with_clause.mmr_diversity, + candidates_limit=with_clause.mmr_candidates, + ), + ) + + +def build_hybrid_prefetches( + topology: Any, + node: SearchStmt, + dense_vector: list[float], + sparse_vector: Any, + search_params: Any, + prefetch_multiplier: int, +) -> list[Prefetch]: + prefetch_limit = node.limit * prefetch_multiplier + return [ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=prefetch_limit, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=prefetch_limit, + params=search_params, + ), + ] + + +def parse_recommend_strategy(strategy: str | None) -> RecommendStrategy | None: + if strategy is None: + return None + try: + return RecommendStrategy(strategy) + except ValueError as e: + raise QQLRuntimeError( + "Unknown recommend strategy " + f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" + ) from e + + +def exclude_ids_from_filter( + query_filter: Filter | None, + point_ids: list[str | int], +) -> Filter | None: + if not point_ids: + return query_filter + + exclude_condition = HasIdCondition(has_id=point_ids) + if query_filter is None: + return Filter(must_not=[exclude_condition]) + + return Filter( + must=list(query_filter.must or []), + should=list(query_filter.should or []), + must_not=[*(query_filter.must_not or []), exclude_condition], + min_should=query_filter.min_should, + ) + + +def extract_point_id_and_payload( + values: dict[str, Any], +) -> tuple[str | int, dict[str, Any]]: + payload = dict(values) + if "id" not in payload: + return str(uuid.uuid4()), payload + + point_id = payload.pop("id") + if isinstance(point_id, bool): + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + if isinstance(point_id, int): + if point_id < 0: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + return point_id, payload + if isinstance(point_id, str): + try: + uuid.UUID(point_id) + except ValueError as e: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) from e + return point_id, payload + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + + +def build_dense_point_vector( + topology: Any, + vector: list[float], + explicit_vector: str | None, + default_dense_vector_name: str, +) -> list[float] | dict[str, list[float]]: + if not topology.exists: + return {explicit_vector or default_dense_vector_name: vector} + vector_name = topology.dense_payload_name(explicit_vector) + if vector_name is None: + return vector + return {vector_name: vector} + + +def merge_search_with(base: SearchWith | None, override: SearchWith) -> SearchWith: + if base is None: + return override + return SearchWith( + hnsw_ef=override.hnsw_ef or base.hnsw_ef, + exact=override.exact or base.exact, + acorn=override.acorn or base.acorn, + indexed_only=override.indexed_only or base.indexed_only, + quantization=override.quantization or base.quantization, + mmr_diversity=( + override.mmr_diversity + if override.mmr_diversity is not None + else base.mmr_diversity + ), + mmr_candidates=override.mmr_candidates or base.mmr_candidates, + ) + + +def parse_search_lookup(parser: Any) -> tuple[str, str | None] | None: + if parser._peek().kind != TokenKind.LOOKUP: + return None + parser._advance() + parser._expect(TokenKind.FROM) + lookup_collection = parser._parse_identifier() + lookup_vector: str | None = None + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + lookup_vector = parser._expect(TokenKind.STRING).value + return lookup_collection, lookup_vector + + +def parse_search_using(parser: Any) -> SearchUsingOptions: + if parser._peek().kind != TokenKind.USING: + return SearchUsingOptions() + + parser._advance() + if parser._peek().kind == TokenKind.HYBRID: + return _parse_hybrid_using(parser) + if parser._peek().kind == TokenKind.SPARSE: + return _parse_sparse_using(parser) + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + return SearchUsingOptions(dense_vector=parser._expect(TokenKind.STRING).value) + + parser._expect(TokenKind.MODEL) + return SearchUsingOptions(model=parser._expect(TokenKind.STRING).value) + + +def _parse_hybrid_using(parser: Any) -> SearchUsingOptions: + parser._advance() + model: str | None = None + fusion: str | None = None + sparse_model: str | None = None + dense_vector: str | None = None + sparse_vector: str | None = None + + while parser._peek().kind in (TokenKind.FUSION, TokenKind.DENSE, TokenKind.SPARSE): + sub = parser._advance() + if sub.kind == TokenKind.FUSION: + value_tok = parser._expect(TokenKind.STRING) + fusion = value_tok.value.lower() + if fusion not in _HYBRID_FUSION_VALUES: + raise QQLSyntaxError( + f"Unsupported hybrid fusion '{value_tok.value}'; expected 'rrf' or 'dbsf'", + value_tok.pos, + ) + continue + if parser._peek().kind == TokenKind.MODEL: + parser._advance() + parsed_model = parser._expect(TokenKind.STRING).value + if sub.kind == TokenKind.DENSE: + model = parsed_model + else: + sparse_model = parsed_model + continue + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + name = parser._expect(TokenKind.STRING).value + if sub.kind == TokenKind.DENSE: + dense_vector = name + else: + sparse_vector = name + continue + raise QQLSyntaxError( + "Expected MODEL or VECTOR after DENSE/SPARSE in USING HYBRID", + parser._peek().pos, + ) + + return SearchUsingOptions( + model=model, + hybrid=True, + fusion=fusion, + sparse_model=sparse_model, + dense_vector=dense_vector, + sparse_vector=sparse_vector, + ) + + +def _parse_sparse_using(parser: Any) -> SearchUsingOptions: + parser._advance() + sparse_model: str | None = None + sparse_vector: str | None = None + while parser._peek().kind in (TokenKind.MODEL, TokenKind.VECTOR): + sub = parser._advance() + if sub.kind == TokenKind.MODEL: + sparse_model = parser._expect(TokenKind.STRING).value + else: + sparse_vector = parser._expect(TokenKind.STRING).value + return SearchUsingOptions( + sparse_only=True, + sparse_model=sparse_model, + sparse_vector=sparse_vector, + ) + + +def parse_search_with(parser: Any, with_clause: SearchWith | None) -> SearchWith | None: + if parser._peek().kind == TokenKind.EXACT: + parser._advance() + with_clause = merge_search_with(with_clause, SearchWith(exact=True)) + + if parser._peek().kind == TokenKind.WITH: + parser._advance() + with_clause = merge_search_with(with_clause, parser._parse_with_clause()) + + return with_clause + + +def parse_search_group_by( + parser: Any, + offset: int, + rerank: bool, +) -> SearchGroupByOptions: + if parser._peek().kind != TokenKind.GROUP: + return SearchGroupByOptions() + + if offset > 0: + raise QQLSyntaxError("OFFSET cannot be used with GROUP BY", parser._peek().pos) + parser._advance() + parser._expect(TokenKind.BY) + group_by = parser._parse_field_path() + if rerank: + raise QQLSyntaxError( + "GROUP BY and RERANK cannot be combined in the same SEARCH statement", + parser._peek().pos, + ) + + group_size = 3 + if parser._peek().kind == TokenKind.GROUP_SIZE: + parser._advance() + group_size_tok = parser._peek() + group_size = int(parser._expect(TokenKind.INTEGER).value) + if group_size <= 0: + raise QQLSyntaxError( + f"GROUP_SIZE must be a positive integer, got {group_size}", + group_size_tok.pos, + ) + return SearchGroupByOptions(group_by=group_by, group_size=group_size) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py new file mode 100644 index 0000000..8726fb2 --- /dev/null +++ b/tests/test_async_connection.py @@ -0,0 +1,301 @@ +"""Tests for the AsyncConnection class. + +All tests mock AsyncQdrantClient so no live Qdrant instance is required. +""" +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock + +from qql import ( + AsyncConnection, + QQLConfig, + AsyncExecutor, + ExecutionResult, +) +from qql.exceptions import QQLSyntaxError + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +# ── TestAsyncConnectionInit ─────────────────────────────────────────────────── + +class TestAsyncConnectionInit: + """AsyncConnection.__init__ stores config and wires up the async executor.""" + + def test_default_url_and_no_secret(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection() + assert conn.config.url == "http://localhost:6333" + assert conn.config.secret is None + + def test_custom_url_and_secret_passed_to_async_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qql.async_connection.AsyncQdrantClient") + AsyncConnection("https://cloud.example.io", secret="s3cr3t") + mock_client_cls.assert_called_once_with( + url="https://cloud.example.io", api_key="s3cr3t", verify=True + ) + + def test_grpc_options_passed_to_async_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qql.async_connection.AsyncQdrantClient") + AsyncConnection( + "http://localhost:6333", + verify=False, + prefer_grpc=True, + grpc_port=9999, + ) + mock_client_cls.assert_called_once_with( + url="http://localhost:6333", + api_key=None, + verify=False, + prefer_grpc=True, + grpc_port=9999, + ) + + def test_custom_default_model_stored_in_config(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333", default_model="BAAI/bge-small-en-v1.5") + assert conn.config.default_model == "BAAI/bge-small-en-v1.5" + + def test_config_and_executor_properties_return_correct_types(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333") + assert isinstance(conn.config, QQLConfig) + assert isinstance(conn.executor, AsyncExecutor) + + +# ── TestAsyncConnectionRunQuery ──────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionRunQuery: + """AsyncConnection.run_query() pipes through the Lexer → Parser → AsyncExecutor.""" + + async def test_run_query_calls_executor_execute(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[] + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + await conn.run_query("SHOW COLLECTIONS") + mock_executor.execute.assert_called_once() + + async def test_executor_instance_reused_across_queries(self, mocker): + """AsyncExecutor() is constructed once; run_query() never re-instantiates it.""" + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[] + ) + executor_cls = mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + await conn.run_query("SHOW COLLECTIONS") + await conn.run_query("SHOW COLLECTIONS") + await conn.run_query("SHOW COLLECTIONS") + + # AsyncExecutor constructor called exactly once, not once per query + executor_cls.assert_called_once() + # But execute() called three times + assert mock_executor.execute.call_count == 3 + + async def test_invalid_query_raises_qql_syntax_error(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333") + with pytest.raises(QQLSyntaxError): + await conn.run_query("TOTALLY INVALID QUERY GIBBERISH") + + async def test_run_query_returns_execution_result(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="1 collection(s) found", data=["docs"] + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + result = await conn.run_query("SHOW COLLECTIONS") + assert isinstance(result, ExecutionResult) + assert result.success is True + + +# ── TestAsyncConnectionLifecycle ─────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionLifecycle: + """AsyncConnection.close() and the async context-manager protocol.""" + + async def test_close_calls_client_close(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + conn = AsyncConnection("http://localhost:6333") + await conn.close() + mock_client.close.assert_called_once() + + async def test_context_manager_closes_on_exit(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + + async with AsyncConnection("http://localhost:6333") as conn: + assert conn._client is mock_client + + mock_client.close.assert_called_once() + + async def test_context_manager_closes_on_exception(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + + with pytest.raises(RuntimeError, match="oops"): + async with AsyncConnection("http://localhost:6333"): + raise RuntimeError("oops") + + mock_client.close.assert_called_once() + + +# ── TestArchitecturalGapsClosed ──────────────────────────────────────────────── + +@pytest.mark.anyio +class TestArchitecturalGapsClosed: + """Rigorous tests covering async execution and collection creation races.""" + + async def test_async_topology_uses_single_get_collection_call(self, mocker): + """Async topology resolution should mirror sync executor and avoid a separate exists call.""" + from qdrant_client.models import Distance, VectorParams + + mock_client = AsyncMock() + mock_info = mocker.MagicMock() + mock_info.config.params.vectors = VectorParams(size=2, distance=Distance.COSINE) + mock_info.config.params.sparse_vectors = None + mock_client.get_collection.return_value = mock_info + + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + topology = await executor._resolve_topology("docs") + + assert topology.exists is True + assert topology.has_unnamed_dense is True + mock_client.get_collection.assert_called_once_with("docs") + mock_client.collection_exists.assert_not_called() + + async def test_async_insert_uses_raced_existing_unnamed_topology(self, mocker): + """If another creator wins the race with an unnamed vector, send a plain vector payload.""" + mock_client = AsyncMock() + mock_client.upsert.return_value = None + + from qql.executor import CollectionTopology + + topology_sequence = [ + CollectionTopology(exists=False, is_named_dense=False), + CollectionTopology( + exists=True, + is_named_dense=False, + has_unnamed_dense=True, + dense_sizes=(("", 2),), + ), + ] + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + mocker.patch( + "qql.async_executor.AsyncExecutor._resolve_topology", + side_effect=topology_sequence, + ) + create = mocker.patch( + "qql.async_executor.AsyncExecutor._create_collection_and_wait", + new_callable=AsyncMock, + ) + + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + node = Parser( + Lexer().tokenize( + "INSERT INTO COLLECTION docs VALUES {'text': 'a', 'id': 1}" + ) + ).parse() + result = await executor.execute(node) + + assert result.success is True + create.assert_not_called() + point = mock_client.upsert.call_args.kwargs["points"][0] + assert point.vector == [0.1, 0.2] + + async def test_async_search_embeds_once(self, mocker): + """AsyncExecutor keeps the hot path direct and avoids threadpool overhead for cached embeddings.""" + mock_client = AsyncMock() + mock_client.collection_exists.return_value = True + + # Mock embedders to track how they are called + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mock_embed = mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + node = Parser(Lexer().tokenize("SEARCH docs SIMILAR TO 'neurology' LIMIT 5")).parse() + + result = await executor.execute(node) + assert result.success is True + mock_embed.assert_called_once_with("neurology") + + async def test_race_condition_collection_creation(self, mocker): + """Concurrent inserts into a non-existent collection serialize creation to avoid Qdrant conflicts.""" + import asyncio + mock_client = AsyncMock() + + # Mock get_collection to return a mock config with matching vector size + mock_info = mocker.MagicMock() + mock_info.config.params.vectors.size = 2 + mock_client.get_collection.return_value = mock_info + + from qql.executor import CollectionTopology + # Mock resolve_topology sequence using real CollectionTopology objects + topology_sequence = [ + CollectionTopology(exists=False, is_named_dense=False), # First insert task resolve topology + CollectionTopology(exists=False, is_named_dense=False), # Second insert task resolve topology + CollectionTopology(exists=False, is_named_dense=False), # Inside lock for first insert + CollectionTopology(exists=True, is_named_dense=False, has_unnamed_dense=True, dense_names=(), sparse_names=()), # Inside lock for second insert + ] + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + # Override _resolve_topology to yield the sequence + calls = 0 + async def mock_resolve(*args, **kwargs): + nonlocal calls + val = topology_sequence[calls] + calls += 1 + return val + + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", side_effect=mock_resolve) + mocker.patch("qql.async_executor.AsyncExecutor._create_collection_and_wait", return_value=None) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + insert_node_1 = Parser(Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'a', 'id': 1}")).parse() + insert_node_2 = Parser(Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'b', 'id': 2}")).parse() + + # Fire both concurrently + res1, res2 = await asyncio.gather( + executor.execute(insert_node_1), + executor.execute(insert_node_2), + ) + + assert res1.success is True + assert res2.success is True + # Verify that _create_collection_and_wait was called exactly once despite concurrency! + executor._create_collection_and_wait.assert_called_once() diff --git a/tests/test_connection.py b/tests/test_connection.py index 2278002..ebcce74 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -44,6 +44,17 @@ def test_custom_ca_bundle_passed_to_qdrant_client(self, mocker): ) assert conn.config.verify == "/etc/ssl/internal-ca.pem" + def test_grpc_options_passed_to_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qdrant_client.QdrantClient") + Connection("http://localhost:6333", prefer_grpc=True, grpc_port=9999) + mock_client_cls.assert_called_once_with( + url="http://localhost:6333", + api_key=None, + verify=True, + prefer_grpc=True, + grpc_port=9999, + ) + def test_custom_default_model_stored_in_config(self, mocker): mocker.patch("qdrant_client.QdrantClient") conn = Connection("http://localhost:6333", default_model="BAAI/bge-small-en-v1.5") @@ -174,7 +185,12 @@ def test_run_query_delegates_to_connection(self, mocker): conn_cls = mocker.patch("qql.Connection", return_value=conn_instance) run_query("SHOW COLLECTIONS", url="http://localhost:6333") conn_cls.assert_called_once_with( - url="http://localhost:6333", secret=None, default_model=None, verify=True + url="http://localhost:6333", + secret=None, + default_model=None, + verify=True, + prefer_grpc=False, + grpc_port=6334, ) conn_instance.run_query.assert_called_once_with("SHOW COLLECTIONS") From 712b0bf3823b12f1f8d6e40aeb60ca0b626a2592 Mon Sep 17 00:00:00 2001 From: Srimon Date: Sun, 31 May 2026 07:30:43 +0530 Subject: [PATCH 2/3] feat: enhance error handling in Executor and improve async connection tests --- src/qql/connection.py | 8 +++++++- src/qql/executor.py | 8 ++++++-- tests/test_async_connection.py | 4 ++-- tests/test_executor.py | 14 ++++++++++++++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/qql/connection.py b/src/qql/connection.py index 64e86c8..12e8691 100644 --- a/src/qql/connection.py +++ b/src/qql/connection.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + from .config import DEFAULT_MODEL, QQLConfig from .executor import Executor, ExecutionResult from .lexer import Lexer @@ -77,7 +79,11 @@ def __init__( default_model=default_model or DEFAULT_MODEL, verify=verify, ) - client_kwargs = {"url": url, "api_key": secret, "verify": verify} + client_kwargs: dict[str, Any] = { + "url": url, + "api_key": secret, + "verify": verify, + } if prefer_grpc: client_kwargs["prefer_grpc"] = True client_kwargs["grpc_port"] = grpc_port diff --git a/src/qql/executor.py b/src/qql/executor.py index f767b1a..20a206a 100644 --- a/src/qql/executor.py +++ b/src/qql/executor.py @@ -241,11 +241,15 @@ def _fetch_collection_info(self, name: str): except ValueError as e: if f"Collection {name} not found" in str(e): return None - raise + raise QQLRuntimeError( + f"Qdrant error fetching collection '{name}': {e}" + ) from e except Exception as e: if is_grpc_not_found_error(e): return None - raise + raise QQLRuntimeError( + f"Qdrant error fetching collection '{name}': {e}" + ) from e def _topology_from_collection_info(self, info: Any) -> CollectionTopology: """Parse a CollectionInfo object into a :class:`CollectionTopology`. diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index 8726fb2..c739364 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -83,7 +83,7 @@ async def test_run_query_calls_executor_execute(self, mocker): conn = AsyncConnection("http://localhost:6333") await conn.run_query("SHOW COLLECTIONS") - mock_executor.execute.assert_called_once() + mock_executor.execute.assert_awaited_once() async def test_executor_instance_reused_across_queries(self, mocker): """AsyncExecutor() is constructed once; run_query() never re-instantiates it.""" @@ -102,7 +102,7 @@ async def test_executor_instance_reused_across_queries(self, mocker): # AsyncExecutor constructor called exactly once, not once per query executor_cls.assert_called_once() # But execute() called three times - assert mock_executor.execute.call_count == 3 + assert mock_executor.execute.await_count == 3 async def test_invalid_query_raises_qql_syntax_error(self, mocker): mocker.patch("qql.async_connection.AsyncQdrantClient") diff --git a/tests/test_executor.py b/tests/test_executor.py index 0a2a74f..85adc6f 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -86,6 +86,20 @@ def mock_embedder(mocker): return mock_embed +class TestFetchCollectionInfo: + def test_non_not_found_value_error_is_wrapped(self, executor, mock_client): + mock_client.get_collection.side_effect = ValueError("transport failed") + + with pytest.raises(QQLRuntimeError, match="Qdrant error fetching collection"): + executor._fetch_collection_info("docs") + + def test_non_not_found_exception_is_wrapped(self, executor, mock_client): + mock_client.get_collection.side_effect = RuntimeError("transport failed") + + with pytest.raises(QQLRuntimeError, match="Qdrant error fetching collection"): + executor._fetch_collection_info("docs") + + class TestInsert: def test_insert_creates_collection_when_missing(self, executor, mock_client): node = InsertStmt(collection="notes", values={"text": "hello"}, model=None) From 76a8c6c6b0b1d3a85a1f151f14d04c9357f5edb3 Mon Sep 17 00:00:00 2001 From: Srimon Date: Sun, 31 May 2026 07:39:47 +0530 Subject: [PATCH 3/3] test: update async connection tests to assert awaited close calls and improve error handling in executor tests --- tests/test_async_connection.py | 13 ++++++++----- tests/test_executor.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index c739364..ea5d9ba 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -135,7 +135,7 @@ async def test_close_calls_client_close(self, mocker): mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) conn = AsyncConnection("http://localhost:6333") await conn.close() - mock_client.close.assert_called_once() + mock_client.close.assert_awaited_once() async def test_context_manager_closes_on_exit(self, mocker): mock_client = AsyncMock() @@ -144,7 +144,7 @@ async def test_context_manager_closes_on_exit(self, mocker): async with AsyncConnection("http://localhost:6333") as conn: assert conn._client is mock_client - mock_client.close.assert_called_once() + mock_client.close.assert_awaited_once() async def test_context_manager_closes_on_exception(self, mocker): mock_client = AsyncMock() @@ -154,7 +154,7 @@ async def test_context_manager_closes_on_exception(self, mocker): async with AsyncConnection("http://localhost:6333"): raise RuntimeError("oops") - mock_client.close.assert_called_once() + mock_client.close.assert_awaited_once() # ── TestArchitecturalGapsClosed ──────────────────────────────────────────────── @@ -278,7 +278,10 @@ async def mock_resolve(*args, **kwargs): return val mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", side_effect=mock_resolve) - mocker.patch("qql.async_executor.AsyncExecutor._create_collection_and_wait", return_value=None) + create_collection = mocker.patch( + "qql.async_executor.AsyncExecutor._create_collection_and_wait", + new_callable=AsyncMock, + ) from qql import QQLConfig executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) @@ -298,4 +301,4 @@ async def mock_resolve(*args, **kwargs): assert res1.success is True assert res2.success is True # Verify that _create_collection_and_wait was called exactly once despite concurrency! - executor._create_collection_and_wait.assert_called_once() + create_collection.assert_awaited_once() diff --git a/tests/test_executor.py b/tests/test_executor.py index 85adc6f..0ed3eb8 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -87,13 +87,13 @@ def mock_embedder(mocker): class TestFetchCollectionInfo: - def test_non_not_found_value_error_is_wrapped(self, executor, mock_client): + def test_value_error_wrapped_as_qql_runtime_error(self, executor, mock_client): mock_client.get_collection.side_effect = ValueError("transport failed") with pytest.raises(QQLRuntimeError, match="Qdrant error fetching collection"): executor._fetch_collection_info("docs") - def test_non_not_found_exception_is_wrapped(self, executor, mock_client): + def test_runtime_error_wrapped_as_qql_runtime_error(self, executor, mock_client): mock_client.get_collection.side_effect = RuntimeError("transport failed") with pytest.raises(QQLRuntimeError, match="Qdrant error fetching collection"):