diff --git a/README.md b/README.md index efea9cf..fbd26a9 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,106 @@ for namespace in namespaces: ns.delete([1, 2]) ``` +Async API +------- + +The turbopuffer client also provides a fully-featured async API for use with asyncio-based applications. The async API follows the same patterns as the synchronous API but with async/await syntax. + +```py +import asyncio +import turbopuffer as tpuf + +async def main(): + # Set API key and base URL + tpuf.api_key = 'your-token' + tpuf.api_base_url = "https://gcp-us-east4.turbopuffer.com" + + # Create an AsyncNamespace instance + async_ns = tpuf.AsyncNamespace('hello_world') + + # Check if namespace exists + if await async_ns.aexists(): + print(f'Namespace {async_ns.name} exists with {await async_ns.adimensions()} dimensions') + print(f'and approximately {await async_ns.aapprox_count()} vectors.') + + # Upsert data asynchronously + await async_ns.aupsert( + ids=[1, 2, 3], + vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], + attributes={'name': ['foo', 'bar', 'baz']}, + distance_metric='cosine_distance', + ) + + # Upsert using row iterator or generator + await async_ns.aupsert( + { + 'id': id, + 'vector': [id/10, id/10], + 'attributes': {'name': f'item_{id}', 'value': id*10} + } for id in range(4, 10) + ) + + # Query vectors asynchronously + result = await async_ns.aquery( + vector=[0.2, 0.3], + distance_metric='cosine_distance', + top_k=5, + include_vectors=True, + include_attributes=True, + ) + + # AsyncVectorResult can be used with async for + async for row in result: + print(f"ID: {row.id}, Distance: {row.dist}") + + # Or loaded completely into memory + all_results = await result.load() + print(f"Found {len(all_results)} results") + + # List all vectors in the namespace + all_vectors = await async_ns.avectors() + # Load all vectors into memory + vectors_list = await all_vectors.load() + print(f"Namespace contains {len(vectors_list)} vectors") + + # Delete vectors asynchronously + await async_ns.adelete([1, 2]) + + # List all namespaces asynchronously + namespaces_iterator = await tpuf.anamespaces() + # Use async for to iterate through namespaces + async for namespace in namespaces_iterator: + print(f"Namespace: {namespace.name}") + + # Or load all namespaces at once + all_namespaces = await namespaces_iterator.load() + print(f"Total namespaces: {len(all_namespaces)}") + +# Run the async main function +asyncio.run(main()) +``` + +### Context Manager Support + +AsyncNamespace instances can be used as async context managers to ensure proper resource cleanup: + +```py +async def process_data(): + async with tpuf.AsyncNamespace('my_data') as ns: + # Perform operations + await ns.aupsert(ids=[1, 2], vectors=[[0.1, 0.2], [0.3, 0.4]]) + results = await ns.aquery(vector=[0.15, 0.25], top_k=5) + # Resources will be cleaned up when exiting the context +``` + +### Converting Between Sync and Async APIs + +The synchronous Namespace methods internally use the async methods by running them in an event loop. 
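+For example, a synchronous call such as `ns.query(...)` runs the corresponding
+`aquery(...)` coroutine to completion on an event loop under the hood. A simplified
+sketch of that delegation (not the exact implementation):
+
+```py
+import asyncio
+
+def query_sync(ns, **kwargs):
+    # Roughly what Namespace.query() does: obtain (or create) an event loop
+    # and block until the async method finishes.
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    return loop.run_until_complete(ns.aquery(**kwargs))
+```
+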
If you're mixing sync and async code, be aware of these considerations: + +- Synchronous methods create an event loop if needed +- For best performance in async applications, use the async API directly +- In async contexts, avoid calling synchronous methods as they may cause event loop issues + Endpoint Documentation ---------------------- diff --git a/pyproject.toml b/pyproject.toml index 67be1d9..f25dff8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "turbopuffer" -version = "0.1.32" +version = "0.2.0" description = "Python Client for accessing the turbopuffer API" authors = ["turbopuffer Inc. "] homepage = "https://turbopuffer.com" @@ -30,6 +30,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" requests = "^2.31" +aiohttp = "^3.9.1" iso8601 = "^2.1.0" orjson = { version = ">=3.9, <=3.10.3", optional = true } # 3.10.4 errors on install numpy = { version = ">=1.24.0", optional = true } diff --git a/tests/test_backend.py b/tests/test_backend.py index 7927b0f..e035fd8 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,9 +1,15 @@ +import asyncio +import time +import sys +from unittest import mock +from unittest.mock import AsyncMock, MagicMock, patch + import pytest import requests -import time +import aiohttp import turbopuffer as tpuf from turbopuffer import backend as tpuf_backend -from unittest import mock +from turbopuffer.error import APIError, raise_api_error def mock_response_returning(status_code, reason): @@ -20,44 +26,337 @@ def mock_response_returning(status_code, reason): return response +def setup_failing_async_response(status_code, reason): + """Helper function to configure backend with failing async response""" + # Create backend with test API key + backend = tpuf_backend.Backend("fake_api_key") + + # Create a proper mock response with spec for better compatibility + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.status = status_code + mock_response.ok = False + mock_response.headers = {"Content-Type": "application/json"} + mock_response.text = AsyncMock(return_value=reason) + # This method will be called when we need to parse json response + mock_response.json = AsyncMock(return_value={"status": "error", "error": reason}) + + # Configure raise_for_status to raise the appropriate exception + error = aiohttp.ClientResponseError( + request_info=AsyncMock(), + history=tuple(), + status=status_code, + message=reason, + headers={} + ) + mock_response.raise_for_status.side_effect = error + + # Create a proper context manager class + class MockContextManager: + async def __aenter__(self): + return mock_response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + # Create a session mock that returns our context manager + mock_session = AsyncMock() + mock_session.request = MagicMock(return_value=MockContextManager()) + mock_session.closed = False + + # Patch the _get_async_session method to return our mock session + with patch.object(backend, '_get_async_session', AsyncMock(return_value=mock_session)): + pass # This applies the patch for the whole function scope + + return backend, mock_session.request + + +async def _test_500_retried_async(): + """Async implementation for the 500 error retry test""" + # Create backend with test API key + backend = tpuf_backend.Backend("fake_api_key") + + # Create mock response with internal server error + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.status = 500 + mock_response.ok = False + 
mock_response.headers = {"Content-Type": "application/json"} + mock_response.text = AsyncMock(return_value="Internal Server Error") + mock_response.json = AsyncMock(return_value={"status": "error", "error": "Internal Server Error"}) + + # Configure raise_for_status to raise an appropriate exception + error = aiohttp.ClientResponseError( + request_info=AsyncMock(), + history=tuple(), + status=500, + message="Internal Server Error", + headers={} + ) + mock_response.raise_for_status.side_effect = error + + # Create context manager + class MockContextManager: + async def __aenter__(self): + return mock_response + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + # Create session + mock_session = AsyncMock() + mock_session.request = MagicMock(return_value=MockContextManager()) + mock_session.closed = False + + # Patching + with patch.object(backend, '_get_async_session', AsyncMock(return_value=mock_session)): + # Directly patch asyncio.sleep to count retries + with patch.object(asyncio, 'sleep', AsyncMock()) as sleep: + # We expect an API error when we make the request + with pytest.raises(APIError): + await backend.amake_api_request("vectors") + + # Since the 500 error should trigger retries, verify the sleep and request counts + assert sleep.call_count == tpuf.max_retries - 1 + assert mock_session.request.call_count == tpuf.max_retries + +def test_500_retried_async(): + """Synchronous wrapper for the async retry test""" + asyncio.run(_test_500_retried_async()) + +# Helper function removed as we're using asyncio.run directly + def test_500_retried(): + """Test that 500 errors are properly retried in the sync API""" + # Since the sync API now just calls the async API, we can reuse the async test + asyncio.run(_test_500_retried_sync()) + +async def _test_500_retried_sync(): + """Async version of the sync retry test""" + # Create backend with test API key backend = tpuf_backend.Backend("fake_api_key") - backend.session.send = mock.MagicMock() - backend.session.send.return_value = mock_response_returning(500, 'Internal Error') + + # Create mock response with internal server error + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.status = 500 + mock_response.ok = False + mock_response.headers = {"Content-Type": "application/json"} + mock_response.text = AsyncMock(return_value="Internal Server Error") + mock_response.json = AsyncMock(return_value={"status": "error", "error": "Internal Server Error"}) + + # Configure raise_for_status to raise an appropriate exception + error = aiohttp.ClientResponseError( + request_info=AsyncMock(), + history=tuple(), + status=500, + message="Internal Server Error", + headers={} + ) + mock_response.raise_for_status.side_effect = error + + # Create context manager + class MockContextManager: + async def __aenter__(self): + return mock_response + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + # Create session + mock_session = AsyncMock() + mock_session.request = MagicMock(return_value=MockContextManager()) + mock_session.closed = False + + # Patching + with patch.object(backend, '_get_async_session', AsyncMock(return_value=mock_session)): + # Directly patch asyncio.sleep to count retries + with patch.object(asyncio, 'sleep', AsyncMock()) as sleep: + # Use run_in_executor to call the sync API from async context + loop = asyncio.get_event_loop() + + # We expect an API error when calling the sync version + with pytest.raises(APIError): + # Call the sync method which internally calls the async one + 
await loop.run_in_executor(None, lambda: backend.make_api_request("vectors")) + + # Since we're using the same underlying amake_api_request, count should be the same + assert sleep.call_count == tpuf.max_retries - 1 + assert mock_session.request.call_count == tpuf.max_retries + - with mock.patch.object(time, 'sleep', return_value=None) as sleep: - with pytest.raises(tpuf.error.APIError): - backend.make_api_request('namespaces', payload={}) - assert sleep.call_count == tpuf.max_retries - 1 +async def _test_429_retried_async(): + """Async implementation for 429 error retry test""" + + # Setup backend with failing response + backend = tpuf_backend.Backend("fake_api_key") + + # Create a response mock with the appropriate async context manager methods + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.status = 429 + mock_response.ok = False + mock_response.headers = {"Content-Type": "application/json"} + mock_response.text = AsyncMock(return_value="Too Many Requests") + mock_response.json = AsyncMock(return_value={"status": "error", "error": "Rate limit exceeded"}) + + # Set up the raise_for_status method to raise an appropriate exception + error = aiohttp.ClientResponseError( + request_info=AsyncMock(), + history=tuple(), + status=429, + message="Too Many Requests", + headers={} + ) + mock_response.raise_for_status.side_effect = error + + # Create a proper context manager class + class MockContextManager: + async def __aenter__(self): + return mock_response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + # Create a session mock that returns our context manager + mock_session = AsyncMock() + mock_session.request = MagicMock(return_value=MockContextManager()) + mock_session.closed = False + + # Patch the _get_async_session method to return our mock + with patch.object(backend, '_get_async_session', AsyncMock(return_value=mock_session)): + # Patch asyncio.sleep to avoid waiting + with patch.object(asyncio, 'sleep', AsyncMock(return_value=None)) as sleep: + # Expect an APIError when we call make_api_request + with pytest.raises(APIError): + await backend.amake_api_request("namespaces", payload={}) + + # Verify retry behavior + assert sleep.call_count == tpuf.max_retries - 1 + assert mock_session.request.call_count == tpuf.max_retries + +def test_429_retried_async(): + """Synchronous wrapper for the async 429 retry test""" + asyncio.run(_test_429_retried_async()) def test_429_retried(): + """Test that 429 errors are properly retried in sync API which calls async""" + # For the sync API, we can use the make_api_request method with the same setup + asyncio.run(_test_429_retried_sync()) + +async def _test_429_retried_sync(): + """Test async function for the sync API that internally calls async""" + # Same setup as the async test backend = tpuf_backend.Backend("fake_api_key") - backend.session.send = mock.MagicMock() - backend.session.send.return_value = mock_response_returning(429, 'Too Many Requests') - - with mock.patch.object(time, 'sleep', return_value=None) as sleep: - with pytest.raises(tpuf.error.APIError): - backend.make_api_request('namespaces', payload={}) - assert sleep.call_count == tpuf.max_retries - 1 + + # Create mock response + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.status = 429 + mock_response.ok = False + mock_response.headers = {"Content-Type": "application/json"} + mock_response.text = AsyncMock(return_value="Too Many Requests") + mock_response.json = AsyncMock(return_value={"status": "error", 
"error": "Rate limit exceeded"}) + + # Set up exception + error = aiohttp.ClientResponseError( + request_info=AsyncMock(), + history=tuple(), + status=429, + message="Too Many Requests", + headers={} + ) + mock_response.raise_for_status.side_effect = error + + # Context manager + class MockContextManager: + async def __aenter__(self): + return mock_response + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + # Session + mock_session = AsyncMock() + mock_session.request = MagicMock(return_value=MockContextManager()) + mock_session.closed = False + + # Patch + with patch.object(backend, '_get_async_session', AsyncMock(return_value=mock_session)): + with patch.object(asyncio, 'sleep', AsyncMock(return_value=None)) as sleep: + # Use the sync API which calls the async API internally + with pytest.raises(APIError): + # Use run_in_executor to call sync from async + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: backend.make_api_request("namespaces", payload={})) + + # Verify + assert sleep.call_count == tpuf.max_retries - 1 + assert mock_session.request.call_count == tpuf.max_retries def test_custom_headers(): - backend = tpuf_backend.Backend("fake_api_key", headers = {"foo": "bar"}) - assert backend.session.headers["foo"] == "bar" + """Test that custom headers are properly passed to sync sessions""" + # Since we mock the find_api_key function, no API key errors will occur + with patch.object(tpuf_backend, 'find_api_key', return_value='fake_api_key'): + backend = tpuf_backend.Backend(headers={"foo": "bar"}) + # Test sync session headers + assert backend.session.headers["foo"] == "bar" + + # Test that namespace correctly passes headers to backend + ns = tpuf.Namespace("fake_namespace", headers={"foo": "bar"}) + assert ns.backend.session.headers["foo"] == "bar" + + # Test headers are stored in the backend for later async session creation + assert backend.headers == {"foo": "bar"} - ns = tpuf.Namespace('fake_namespace', headers = {"foo": "bar"}) - assert ns.backend.session.headers["foo"] == "bar" +async def _test_custom_headers_async(): + """Async implementation for testing custom headers in async contexts""" + # Mock the find_api_key to avoid auth errors + with patch.object(tpuf_backend, 'find_api_key', return_value='fake_api_key'): + # Create backend with custom headers + backend = tpuf_backend.Backend(headers={"foo": "bar"}) + assert backend.headers == {"foo": "bar"} + + # Create a mock session to be returned by ClientSession + mock_session = AsyncMock() + mock_session.closed = False + + # Mock the ClientSession constructor + with patch('aiohttp.ClientSession', return_value=mock_session) as mock_client_session: + # Call the method that creates the session + await backend._get_async_session() + + # Verify ClientSession was called with headers + mock_client_session.assert_called_once() + call_kwargs = mock_client_session.call_args.kwargs + assert 'headers' in call_kwargs + assert 'foo' in call_kwargs['headers'] + assert call_kwargs['headers']['foo'] == 'bar' + + # Test with AsyncNamespace + with patch.object(tpuf, 'api_key', 'fake_api_key'): + ns = tpuf.AsyncNamespace("fake_namespace", headers={"foo": "bar"}) + assert ns.backend.headers == {"foo": "bar"} + +def test_custom_headers_async(): + """Synchronous wrapper for async headers test""" + asyncio.run(_test_custom_headers_async()) -def test_backend_eq(): - backend = tpuf_backend.Backend("fake_api_key", headers = {"foo": "bar"}) +# Remove redundant test that uses run_async_test helper - backend2 = 
tpuf_backend.Backend("fake_api_key", headers = {"foo": "notbar"}) - assert backend != backend2 - backend2 = tpuf_backend.Backend("fake_api_key", headers = {"foo": "bar"}) - assert backend == backend2 - - backend2 = tpuf_backend.Backend("fake_api_key2", headers = {"foo": "bar"}) - assert backend != backend2 +def test_backend_eq(): + """Test that backend equality checks work correctly""" + # Mock find_api_key to avoid auth errors + with patch.object(tpuf_backend, 'find_api_key', return_value='fake_api_key'): + backend = tpuf_backend.Backend(headers={"foo": "bar"}) + + # Different headers should not be equal + with patch.object(tpuf_backend, 'find_api_key', return_value='fake_api_key'): + backend2 = tpuf_backend.Backend(headers={"foo": "notbar"}) + assert backend != backend2 + + # Same api key and headers should be equal + with patch.object(tpuf_backend, 'find_api_key', return_value='fake_api_key'): + backend2 = tpuf_backend.Backend(headers={"foo": "bar"}) + assert backend == backend2 + + # Different api key should not be equal + with patch.object(tpuf_backend, 'find_api_key', return_value='fake_api_key2'): + backend2 = tpuf_backend.Backend("fake_api_key2", headers={"foo": "bar"}) + assert backend != backend2 diff --git a/turbopuffer/__init__.py b/turbopuffer/__init__.py index 5e7e161..85fb992 100644 --- a/turbopuffer/__init__.py +++ b/turbopuffer/__init__.py @@ -27,7 +27,7 @@ def default(self, obj): def dump_json_bytes(obj): return json.dumps(obj, cls=NumpyEncoder).encode() from turbopuffer.version import VERSION -from turbopuffer.namespace import Namespace, namespaces, AttributeSchema, FullTextSearchParams -from turbopuffer.vectors import VectorColumns, VectorRow, VectorResult +from turbopuffer.namespace import Namespace, AsyncNamespace, namespaces, anamespaces, AttributeSchema, FullTextSearchParams +from turbopuffer.vectors import VectorColumns, VectorRow, VectorResult, AsyncVectorResult from turbopuffer.query import VectorQuery, Filters from turbopuffer.error import TurbopufferError, AuthenticationError, APIError, NotFoundError diff --git a/turbopuffer/backend.py b/turbopuffer/backend.py index b8e5ced..6be37c4 100644 --- a/turbopuffer/backend.py +++ b/turbopuffer/backend.py @@ -1,12 +1,15 @@ +import asyncio +import gzip import json import re import time import traceback +from typing import Any, Dict, List, Optional, Union + +import aiohttp import requests import turbopuffer as tpuf -import gzip from turbopuffer.error import AuthenticationError, raise_api_error -from typing import Optional, List def find_api_key(api_key: Optional[str] = None) -> str: @@ -15,13 +18,16 @@ def find_api_key(api_key: Optional[str] = None) -> str: elif tpuf.api_key is not None: return tpuf.api_key else: - raise AuthenticationError("No turbopuffer API key was provided.\n" - "Set the TURBOPUFFER_API_KEY environment variable, " - "or pass `api_key=` when creating a Namespace.") + raise AuthenticationError( + "No turbopuffer API key was provided.\n" + "Set the TURBOPUFFER_API_KEY environment variable, " + "or pass `api_key=` when creating a Namespace." 
+ ) + def clean_api_base_url(base_url: str) -> str: - if base_url.endswith(('/v1', '/v1/', '/')): - return re.sub('(/v1|/v1/|/)$', '', base_url) + if base_url.endswith(("/v1", "/v1/", "/")): + return re.sub("(/v1|/v1/|/)$", "", base_url) else: return base_url @@ -30,133 +36,253 @@ class Backend: api_key: str api_base_url: str session: requests.Session + _async_session: Optional[aiohttp.ClientSession] = None def __init__(self, api_key: Optional[str] = None, headers: Optional[dict] = None): self.api_key = find_api_key(api_key) self.api_base_url = clean_api_base_url(tpuf.api_base_url) self.headers = headers self.session = requests.Session() - self.session.headers.update({ - 'Authorization': f'Bearer {self.api_key}', - 'User-Agent': f'tpuf-python/{tpuf.VERSION} {requests.utils.default_headers()["User-Agent"]}', - }) + self.session.headers.update( + { + "Authorization": f"Bearer {self.api_key}", + "User-Agent": f"tpuf-python/{tpuf.VERSION} {requests.utils.default_headers()['User-Agent']}", + } + ) if headers is not None: self.session.headers.update(headers) def __eq__(self, other): if isinstance(other, Backend): - return self.api_key == other.api_key and self.api_base_url == other.api_base_url and self.headers == other.headers + return ( + self.api_key == other.api_key + and self.api_base_url == other.api_base_url + and self.headers == other.headers + ) else: return False - def make_api_request(self, - *args: List[str], - method: Optional[str] = None, - query: Optional[dict] = None, - payload: Optional[dict] = None) -> dict: + async def __aenter__(self): + """Enable use as an async context manager.""" + await self._get_async_session() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Clean up resources when exiting the async context.""" + if self._async_session and not self._async_session.closed: + await self._async_session.close() + + async def _get_async_session(self) -> aiohttp.ClientSession: + """Get or create an async session.""" + if self._async_session is None or self._async_session.closed: + headers = { + "Authorization": f"Bearer {self.api_key}", + "User-Agent": f"tpuf-python/{tpuf.VERSION} aiohttp", + } + if self.headers is not None: + headers.update(self.headers) + self._async_session = aiohttp.ClientSession(headers=headers) + return self._async_session + + async def amake_api_request( + self, + *args: List[str], + method: Optional[str] = None, + query: Optional[dict] = None, + payload: Optional[Union[dict, bytes]] = None, + ) -> dict: + """ + Asynchronous version of make_api_request. + Makes API requests to the Turbopuffer API. 
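+
+        Requests that fail with HTTP 5xx, 408, or 429, a connection error, or a
+        timeout are retried with exponential backoff (2 ** attempt seconds), up to
+        tpuf.max_retries attempts in total.
+
+        Illustrative usage (mirrors how arefresh_metadata() checks a namespace;
+        the namespace name is a placeholder):
+
+            response = await backend.amake_api_request('namespaces', 'my-namespace', method='HEAD')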
+ """ start = time.monotonic() if method is None and payload is not None: - method = 'POST' - - request = requests.Request(method or 'GET', self.api_base_url + '/v1/' + '/'.join(args)) + method = "POST" - if query is not None: - request.params = query + url = self.api_base_url + "/v1/" + "/".join(args) performance = dict() + headers = {} + + # Prepare payload data + data = None if payload is not None: payload_start = time.monotonic() if isinstance(payload, dict): json_payload = tpuf.dump_json_bytes(payload) - performance['json_time'] = time.monotonic() - payload_start + performance["json_time"] = time.monotonic() - payload_start elif isinstance(payload, bytes): json_payload = payload else: - raise ValueError(f'Unsupported POST payload type: {type(payload)}') + raise ValueError(f"Unsupported POST payload type: {type(payload)}") gzip_start = time.monotonic() gzip_payload = gzip.compress(json_payload, compresslevel=1) - performance['gzip_time'] = time.monotonic() - gzip_start + performance["gzip_time"] = time.monotonic() - gzip_start if len(gzip_payload) > 0: - performance['gzip_ratio'] = len(json_payload) / len(gzip_payload) + performance["gzip_ratio"] = len(json_payload) / len(gzip_payload) - request.headers.update({ - 'Content-Type': 'application/json', - 'Content-Encoding': 'gzip', - }) - request.data = gzip_payload + headers.update( + { + "Content-Type": "application/json", + "Content-Encoding": "gzip", + } + ) + data = gzip_payload - prepared = self.session.prepare_request(request) + session = await self._get_async_session() retry_attempt = 0 - timeouts = (tpuf.connect_timeout, tpuf.read_timeout) + timeout = aiohttp.ClientTimeout( + connect=tpuf.connect_timeout, total=tpuf.read_timeout + ) + while retry_attempt < tpuf.max_retries: request_start = time.monotonic() try: - # print(f'Sending request:', prepared.path_url, prepared.headers) - response = self.session.send(prepared, allow_redirects=False, timeout=timeouts) - performance['request_time'] = time.monotonic() - request_start - # print(f'Request time (HTTP {response.status_code}):', performance['request_time']) - - if response.status_code >= 500 or response.status_code == 408 or response.status_code == 429: - response.raise_for_status() - - server_timing_str = response.headers.get('Server-Timing', '') - if len(server_timing_str) > 0: - match_cache_hit_ratio = re.match(r'.*cache_hit_ratio;ratio=([\d\.]+)', server_timing_str) - if match_cache_hit_ratio: - try: - performance['cache_hit_ratio'] = float(match_cache_hit_ratio.group(1)) - except ValueError: - pass - match_processing = re.match(r'.*processing_time;dur=([\d\.]+)', server_timing_str) - if match_processing: - try: - performance['server_time'] = float(match_processing.group(1)) / 1000.0 - except ValueError: - pass - match_exhaustive = re.match(r'.*exhaustive_search_count;count=([\d\.]+)', server_timing_str) - if match_exhaustive: + async with session.request( + method or "GET", + url, + params=query, + data=data, + headers=headers, + allow_redirects=False, + timeout=timeout, + ) as response: + performance["request_time"] = time.monotonic() - request_start + + # Handle server errors with retries + if ( + response.status >= 500 + or response.status == 408 + or response.status == 429 + ): + response.raise_for_status() + + # Extract server timing information + server_timing_str = response.headers.get("Server-Timing", "") + if server_timing_str: + match_cache_hit_ratio = re.match( + r".*cache_hit_ratio;ratio=([\d\.]+)", server_timing_str + ) + if match_cache_hit_ratio: + try: + 
performance["cache_hit_ratio"] = float( + match_cache_hit_ratio.group(1) + ) + except ValueError: + pass + match_processing = re.match( + r".*processing_time;dur=([\d\.]+)", server_timing_str + ) + if match_processing: + try: + performance["server_time"] = ( + float(match_processing.group(1)) / 1000.0 + ) + except ValueError: + pass + match_exhaustive = re.match( + r".*exhaustive_search_count;count=([\d\.]+)", + server_timing_str, + ) + if match_exhaustive: + try: + performance["exhaustive_search_count"] = int( + match_exhaustive.group(1) + ) + except ValueError: + pass + + # Handle HEAD request + if method == "HEAD": + return { + "status_code": response.status, + "headers": dict(response.headers), + "performance": performance, + } + + # Process response content + content_type = response.headers.get("Content-Type", "text/plain") + if content_type == "application/json": try: - performance['exhaustive_search_count'] = int(match_exhaustive.group(1)) - except ValueError: - pass - - if method == 'HEAD': - return dict(response.__dict__, **{ - 'performance': performance, - }) - - content_type = response.headers.get('Content-Type', 'text/plain') - if content_type == 'application/json': - try: - content = response.json() - except json.JSONDecodeError as err: - raise_api_error(response.status_code, traceback.format_exception_only(err), response.text) - - if response.ok: - performance['total_time'] = time.monotonic() - start - return dict(response.__dict__, **{ - 'content': content, - 'performance': performance, - }) + content = await response.json() + except json.JSONDecodeError as err: + text = await response.text() + raise_api_error( + response.status, + traceback.format_exception_only(err), + text, + ) + + if response.ok: + performance["total_time"] = time.monotonic() - start + return { + "status_code": response.status, + "ok": response.ok, + "headers": dict(response.headers), + "content": content, + "performance": performance, + } + else: + raise_api_error( + response.status, + content.get("status", "error"), + content.get("error", ""), + ) else: - raise_api_error(response.status_code, content.get('status', 'error'), content.get('error', '')) - else: - raise_api_error(response.status_code, 'Server returned non-JSON response', response.text) - except (requests.HTTPError, requests.ConnectionError, requests.Timeout) as err: + text = await response.text() + raise_api_error( + response.status, "Server returned non-JSON response", text + ) + + except ( + aiohttp.ClientResponseError, + aiohttp.ClientConnectionError, + asyncio.TimeoutError, + ) as err: retry_attempt += 1 - # print(traceback.format_exc()) + if retry_attempt < tpuf.max_retries: - # print(f'Retrying request in {2 ** retry_attempt}s...') - time.sleep(2 ** retry_attempt) # exponential falloff up to 64 seconds for 6 retries. 
+ await asyncio.sleep( + 2**retry_attempt + ) # exponential falloff up to 64 seconds for 6 retries else: - print(f'Request failed after {retry_attempt} attempts...') + print(f"Request failed after {retry_attempt} attempts...") - if isinstance(err, requests.HTTPError): - raise_api_error(err.response.status_code, - f'Request to {err.request.url} failed after {retry_attempt} attempts', - str(err)) + if isinstance(err, aiohttp.ClientResponseError): + raise_api_error( + err.status, + f"Request to {url} failed after {retry_attempt} attempts", + str(err), + ) else: raise + + def make_api_request( + self, + *args: List[str], + method: Optional[str] = None, + query: Optional[dict] = None, + payload: Optional[dict] = None, + ) -> dict: + """ + Makes synchronous API requests to the Turbopuffer API by calling the async version. + If this call succeeds, data is guaranteed to be durably written to object storage. + """ + # Create an event loop or use the existing one + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No event loop exists in this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Call the async version and wait for the result + return loop.run_until_complete( + self.amake_api_request(*args, method=method, query=query, payload=payload) + ) diff --git a/turbopuffer/namespace.py b/turbopuffer/namespace.py index 59c4a1e..142ac80 100644 --- a/turbopuffer/namespace.py +++ b/turbopuffer/namespace.py @@ -1,17 +1,22 @@ +from dataclasses import dataclass, asdict import sys import iso8601 import json +import asyncio from datetime import datetime from turbopuffer.error import APIError -from turbopuffer.vectors import Cursor, VectorResult, VectorColumns, VectorRow, batch_iter +from turbopuffer.vectors import Cursor, VectorResult, AsyncVectorResult, VectorColumns, VectorRow, batch_iter, abatch_iter from turbopuffer.backend import Backend from turbopuffer.query import VectorQuery, Filters, RankInput, ConsistencyDict -from typing import Dict, List, Literal, Optional, Iterable, Union, overload +from typing import Dict, List, Literal, Optional, Iterable, Union, overload, AsyncIterable, Type, TypeVar, Generic, AsyncIterator import turbopuffer as tpuf +T = TypeVar('T') + CmekDict = Dict[Literal['key_name'], str] EncryptionDict = Dict[Literal['cmek'], CmekDict] +@dataclass(frozen=True) class FullTextSearchParams: """ Used for configuring BM25 full-text indexing for a given attribute. 
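+
+    Illustrative construction (the field values shown are placeholders):
+
+        FullTextSearchParams(language='english', stemming=False,
+                             remove_stopwords=True, case_sensitive=False)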
@@ -22,19 +27,8 @@ class FullTextSearchParams: remove_stopwords: bool case_sensitive: bool - def __init__(self, language: str, stemming: bool, remove_stopwords: bool, case_sensitive: bool): - self.language = language - self.stemming = stemming - self.remove_stopwords = remove_stopwords - self.case_sensitive = case_sensitive - def as_dict(self) -> dict: - return { - "language": self.language, - "stemming": self.stemming, - "remove_stopwords": self.remove_stopwords, - "case_sensitive": self.case_sensitive, - } + return asdict(self) class AttributeSchema: """ @@ -114,67 +108,74 @@ def __eq__(self, other): else: return False + def _run_async(self, coroutine): + """Helper method to run async code from sync methods""" + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No event loop exists in this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(coroutine) + def refresh_metadata(self): - response = self.backend.make_api_request('namespaces', self.name, method='HEAD') - status_code = response.get('status_code') - if status_code == 200: - headers = response.get('headers', dict()) - dimensions = int(headers.get('x-turbopuffer-dimensions', '0')) - approx_count = int(headers.get('x-turbopuffer-approx-num-vectors', '0')) - self.metadata = { - 'exists': dimensions != 0, - 'dimensions': dimensions, - 'approx_count': approx_count, - 'created_at': iso8601.parse_date(headers.get('x-turbopuffer-created-at')), - } - elif status_code == 404: - self.metadata = { - 'exists': False, - 'dimensions': 0, - 'approx_count': 0, - 'created_at': None, - } - else: - raise APIError(response.status_code, 'Unexpected status code', response.get('content')) + """ + Refreshes the namespace metadata. + """ + # Create an async version if it doesn't exist yet + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.arefresh_metadata()) def exists(self) -> bool: """ Returns True if the namespace exists, and False if the namespace is missing or empty. """ - # Always refresh the exists check since metadata from namespaces() might be delayed. - self.refresh_metadata() - return self.metadata['exists'] + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.aexists()) def dimensions(self) -> int: """ Returns the number of vector dimensions stored in this namespace. """ - if self.metadata is None or 'dimensions' not in self.metadata: - self.refresh_metadata() - return self.metadata.pop('dimensions', 0) + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.adimensions()) def approx_count(self) -> int: """ Returns the approximate number of vectors stored in this namespace. """ - if self.metadata is None or 'approx_count' not in self.metadata: - self.refresh_metadata() - return self.metadata.pop('approx_count', 0) + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.aapprox_count()) def created_at(self) -> Optional[datetime]: """ Returns the creation date of this namespace. 
""" - if self.metadata is None or 'created_at' not in self.metadata: - self.refresh_metadata() - return self.metadata.pop('created_at', None) + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.acreated_at()) def schema(self) -> NamespaceSchema: """ Returns the current schema for the namespace. """ - response = self.backend.make_api_request('namespaces', self.name, 'schema', method='GET') - return parse_namespace_schema(response["content"]) + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.aschema()) def update_schema(self, schema_updates: NamespaceSchema): """ @@ -183,9 +184,10 @@ def update_schema(self, schema_updates: NamespaceSchema): See https://turbopuffer.com/docs/schema for specifics on allowed updates. """ - request_payload = json.dumps({key: value.as_dict() for key, value in schema_updates.items()}).encode() - response = self.backend.make_api_request('namespaces', self.name, 'schema', method='POST', payload=request_payload) - return parse_namespace_schema(response["content"]) + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.aupdate_schema(schema_updates)) def copy_from_namespace(self, source_namespace: str): """ @@ -194,11 +196,10 @@ def copy_from_namespace(self, source_namespace: str): See: https://turbopuffer.com/docs/upsert#parameters `copy_from_namespace` for specifics on how this works. """ - payload = { - "copy_from_namespace": source_namespace - } - response = self.backend.make_api_request('namespaces', self.name, payload=payload) - assert response.get('content', dict()).get('status', '') == 'OK', f'Invalid copy_from_namespace() response: {response}' + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.acopy_from_namespace(source_namespace)) @overload def upsert(self, @@ -258,127 +259,32 @@ def upsert(self, """ ... - def upsert(self, - data=None, - ids=None, - vectors=None, - attributes=None, - schema=None, - distance_metric=None, - encryption= None) -> None: - if data is None: - if ids is not None and vectors is not None: - return self.upsert(VectorColumns(ids=ids, vectors=vectors, attributes=attributes), schema=schema, distance_metric=distance_metric, encryption=encryption) - else: - raise ValueError('upsert() requires both ids= and vectors= be set.') - elif (ids is not None and attributes is None) or (attributes is not None and schema is None): - # Offset arguments to handle positional arguments case with no data field. - return self.upsert(VectorColumns(ids=data, vectors=ids, attributes=vectors), schema=attributes, distance_metric=distance_metric, encryption=encryption) - elif isinstance(data, VectorColumns): - payload = {**data.__dict__} - - if distance_metric is not None: - payload["distance_metric"] = distance_metric - - if schema is not None: - payload["schema"] = schema + def upsert(self, *args, **kwargs) -> None: + """ + Creates or updates vectors. 
+ """ + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) - if encryption is not None: - payload["encryption"] = encryption - - response = self.backend.make_api_request('namespaces', self.name, payload=payload) - - assert response.get('content', dict()).get('status', '') == 'OK', f'Invalid upsert() response: {response}' - self.metadata = None # Invalidate cached metadata - elif isinstance(data, VectorRow): - raise ValueError('upsert() should be called on a list of vectors, got single vector.') - elif isinstance(data, list): - if isinstance(data[0], dict): - return self.upsert(VectorColumns.from_rows(data), schema=schema, distance_metric=distance_metric, encryption=encryption) - elif isinstance(data[0], VectorRow): - return self.upsert(VectorColumns.from_rows(data), schema=schema, distance_metric=distance_metric, encryption=encryption) - elif isinstance(data[0], VectorColumns): - for columns in data: - self.upsert(columns, schema=schema, distance_metric=distance_metric, encryption=encryption) - return - else: - raise ValueError(f'Unsupported list data type: {type(data[0])}') - elif isinstance(data, dict): - if 'id' in data: - raise ValueError('upsert() should be called on a list of vectors, got single vector.') - elif 'ids' in data: - return self.upsert(VectorColumns.from_dict(data), schema=data.get('schema', None), distance_metric=distance_metric, encryption=encryption) - else: - raise ValueError('Provided dict is missing ids.') - elif 'pandas' in sys.modules and isinstance(data, sys.modules['pandas'].DataFrame): - if 'id' not in data.keys(): - raise ValueError('Provided pd.DataFrame is missing an id column.') - if 'vector' not in data.keys(): - raise ValueError('Provided pd.DataFrame is missing a vector column.') - # start = time.monotonic() - for i in range(0, len(data), tpuf.upsert_batch_size): - batch = data[i:i+tpuf.upsert_batch_size] - attributes = dict() - for key, values in batch.items(): - if key != 'id' and key != 'vector': - attributes[key] = values.tolist() - columns = tpuf.VectorColumns( - ids=batch['id'].tolist(), - vectors=batch['vector'].transform(lambda x: x.tolist()).tolist(), - attributes=attributes - ) - # time_diff = time.monotonic() - start - # print(f"Batch {columns.ids[0]}..{columns.ids[-1]} begin:", time_diff, '/', len(batch), '=', len(batch)/time_diff) - # before = time.monotonic() - # print(columns) - self.upsert(columns, schema=schema, distance_metric=distance_metric, encryption=encryption) - # time_diff = time.monotonic() - before - # print(f"Batch {columns.ids[0]}..{columns.ids[-1]} time:", time_diff, '/', len(batch), '=', len(batch)/time_diff) - # start = time.monotonic() - return - elif isinstance(data, Iterable): - # start = time.monotonic() - for batch in batch_iter(data, tpuf.upsert_batch_size): - # time_diff = time.monotonic() - start - # print('Batch begin:', time_diff, '/', len(batch), '=', len(batch)/time_diff) - # before = time.monotonic() - self.upsert(batch, schema=schema, distance_metric=distance_metric, encryption=encryption) - # time_diff = time.monotonic() - before - # print('Batch time:', time_diff, '/', len(batch), '=', len(batch)/time_diff) - # start = time.monotonic() - return - else: - raise ValueError(f'Unsupported data type: {type(data)}') + return self._run_async(self._async_namespace.aupsert(*args, **kwargs)) def delete(self, ids: Union[int, str, List[int], List[str]]) -> None: """ Deletes vectors by id. 
""" - - if isinstance(ids, int) or isinstance(ids, str): - response = self.backend.make_api_request('namespaces', self.name, payload={ - 'ids': [ids], - 'vectors': [None], - }) - elif isinstance(ids, list): - response = self.backend.make_api_request('namespaces', self.name, payload={ - 'ids': ids, - 'vectors': [None] * len(ids), - }) - else: - raise ValueError(f'Unsupported ids type: {type(ids)}') - - assert response.get('content', dict()).get('status', '') == 'OK', f'Invalid delete() response: {response}' - self.metadata = None # Invalidate cached metadata + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.adelete(ids)) def delete_by_filter(self, filters: Filters) -> int: - response = self.backend.make_api_request('namespaces', self.name, payload={ - 'delete_by_filter': filters - }) - response_content = response.get('content', dict()) - assert response_content.get('status', '') == 'OK', f'Invalid delete_by_filter() response: {response}' - self.metadata = None # Invalidate cached metadata - return response_content.get('rows_affected') + """ + Deletes vectors by filter. + """ + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.adelete_by_filter(filters)) @overload def query(self, @@ -401,43 +307,18 @@ def query(self, query_data: VectorQuery) -> VectorResult: def query(self, query_data: dict) -> VectorResult: ... - def query(self, - query_data=None, - vector=None, - distance_metric=None, - top_k=None, - include_vectors=None, - include_attributes=None, - filters=None, - rank_by=None, - consistency=None) -> VectorResult: + def query(self, *args, **kwargs) -> VectorResult: """ Searches vectors matching the search query. - + See https://turbopuffer.com/docs/reference/query for query filter parameters. """ - - if query_data is None: - return self.query(VectorQuery( - vector=vector, - distance_metric=distance_metric, - top_k=top_k, - include_vectors=include_vectors, - include_attributes=include_attributes, - filters=filters, - rank_by=rank_by, - consistency=consistency - )) - if not isinstance(query_data, VectorQuery): - if isinstance(query_data, dict): - query_data = VectorQuery.from_dict(query_data) - else: - raise ValueError(f'query() input type must be compatible with turbopuffer.VectorQuery: {type(query_data)}') - - response = self.backend.make_api_request('namespaces', self.name, 'query', payload=query_data.__dict__) - result = VectorResult(response.get('content', dict()), namespace=self) - result.performance = response.get('performance') - return result + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + async_result = self._run_async(self._async_namespace.aquery(*args, **kwargs)) + # Convert AsyncVectorResult to VectorResult for synchronous API + return VectorResult(async_result.data, namespace=self, next_cursor=async_result.next_cursor) def vectors(self, cursor: Optional[Cursor] = None) -> VectorResult: """ @@ -446,31 +327,31 @@ def vectors(self, cursor: Optional[Cursor] = None) -> VectorResult: If you want to look up vectors by ID, use the query function with an id filter. 
""" - - response = self.backend.make_api_request('namespaces', self.name, query={'cursor': cursor}) - content = response.get('content', dict()) - next_cursor = content.pop('next_cursor', None) - result = VectorResult(content, namespace=self, next_cursor=next_cursor) - result.performance = response.get('performance') - return result + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + async_result = self._run_async(self._async_namespace.avectors(cursor)) + # Convert AsyncVectorResult to VectorResult for synchronous API + return VectorResult(async_result.data, namespace=self, next_cursor=async_result.next_cursor) def delete_all_indexes(self) -> None: """ Deletes all indexes in a namespace. """ - - response = self.backend.make_api_request('namespaces', self.name, 'index', method='DELETE') - assert response.get('content', dict()).get('status', '') == 'ok', f'Invalid delete_all_indexes() response: {response}' + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.adelete_all_indexes()) def delete_all(self) -> None: """ Deletes all data as well as all indexes. """ - - response = self.backend.make_api_request('namespaces', self.name, method='DELETE') - assert response.get('content', dict()).get('status', '') == 'ok', f'Invalid delete_all() response: {response}' - self.metadata = None # Invalidate cached metadata - + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.adelete_all()) + def recall(self, num=20, top_k=10) -> float: """ This function evaluates the recall performance of ANN queries in this namespace. @@ -480,18 +361,17 @@ def recall(self, num=20, top_k=10) -> float: Recall is calculated as the ratio of matching vectors between the two search results. """ - - response = self.backend.make_api_request('namespaces', self.name, '_debug', 'recall', query={'num': num, 'top_k': top_k}) - content = response.get('content', dict()) - assert 'avg_recall' in content, f'Invalid recall() response: {response}' - return float(content.get('avg_recall')) + if not hasattr(self, '_async_namespace'): + self._async_namespace = AsyncNamespace(self.name, self.backend.api_key) + + return self._run_async(self._async_namespace.arecall(num, top_k)) class NamespaceIterator: """ - The VectorResult type represents a set of vectors that are the result of a query. + The NamespaceIterator type represents a set of namespaces that can be iterated over. - A VectorResult can be treated as either a lazy iterator or a list by the user. + A NamespaceIterator can be treated as either a lazy iterator or a list by the user. Reading the length of the result will internally buffer the full result. 
""" @@ -513,6 +393,7 @@ def __init__(self, backend: Backend, initial_set: Union[List[Namespace], List[di else: self.namespaces = NamespaceIterator.load_namespaces(backend.api_key, initial_set) + @staticmethod def load_namespaces(api_key: Optional[str], initial_set: List[dict]) -> List[Namespace]: output = [] for input in initial_set: @@ -547,7 +428,7 @@ def __len__(self) -> int: self.next_cursor = None return len(self.namespaces) - def __getitem__(self, index) -> VectorRow: + def __getitem__(self, index) -> Namespace: if index >= len(self.namespaces) and self.next_cursor: it = iter(self) self.namespaces = [next for next in it] @@ -579,13 +460,562 @@ def __next__(self): return self.__next__() -def namespaces(api_key: Optional[str] = None) -> Iterable[Namespace]: +class AsyncNamespaceIterator: """ - Lists all turbopuffer namespaces for a given api_key. + The AsyncNamespaceIterator type represents a set of namespaces that can be asynchronously iterated over. + + An AsyncNamespaceIterator implements an async iterator interface, allowing you to use it with + async for loops. For full buffer access, await the load() method to retrieve all results. + """ + + backend: Backend + namespaces: List['AsyncNamespace'] = [] + index: int = -1 + offset: int = 0 + next_cursor: Optional[Cursor] = None + + def __init__(self, backend: Backend, initial_set: Union[List['AsyncNamespace'], List[dict]] = [], next_cursor: Optional[Cursor] = None): + self.backend = backend + self.index = -1 + self.offset = 0 + self.next_cursor = next_cursor + + if len(initial_set): + if isinstance(initial_set[0], dict): + self.namespaces = AsyncNamespaceIterator.load_namespaces(backend.api_key, initial_set) + else: + self.namespaces = initial_set + + @staticmethod + def load_namespaces(api_key: Optional[str], initial_set: List[dict]) -> List['AsyncNamespace']: + output = [] + for input in initial_set: + ns = AsyncNamespace(input['id'], api_key=api_key) + ns.metadata = { + 'exists': True, + } + output.append(ns) + + return output + + def __str__(self) -> str: + str_list = [ns.name for ns in self.namespaces] + if not self.next_cursor and self.offset == 0: + return str(str_list) + else: + return ("AsyncNamespaceIterator(" + f"offset={self.offset}, " + f"next_cursor='{self.next_cursor}', " + f"namespaces={str_list})") + + async def load(self) -> List['AsyncNamespace']: + """ + Loads and returns all namespaces, fetching additional pages as needed. + + This method buffers all data in memory. For large result sets, consider + using the async iterator interface instead. 
+ """ + if not self.next_cursor: + return self.namespaces + + # Create new iterator and exhaust it to load all results + result = [] + async for item in self.__aiter__(): + result.append(item) + + # Update our state with the fully loaded data + self.namespaces = result + self.offset = 0 + self.index = -1 + self.next_cursor = None + + return result + + def __aiter__(self) -> 'AsyncNamespaceIterator': + # Reset state to start fresh iteration + return AsyncNamespaceIterator(self.backend, self.namespaces, self.next_cursor) + + async def __anext__(self) -> 'AsyncNamespace': + # Handle the case where we have data in memory + if self.index + 1 < len(self.namespaces): + self.index += 1 + return self.namespaces[self.index] + # Handle the case where we're at the end of our data + elif self.next_cursor is None: + raise StopAsyncIteration + # Handle the case where we need to fetch more data + else: + response = await self.backend.amake_api_request( + 'namespaces', + query={'cursor': self.next_cursor} + ) + content = response.get('content', dict()) + + self.offset += len(self.namespaces) + self.index = -1 + self.next_cursor = content.pop('next_cursor', None) + self.namespaces = AsyncNamespaceIterator.load_namespaces( + self.backend.api_key, + content.pop('namespaces', list()) + ) + + # Recursively call __anext__ to handle the case where data is empty + return await self.__anext__() + + +class AsyncNamespace: + """ + The AsyncNamespace type represents a set of vectors stored in turbopuffer, + with asynchronous access. + + Within a namespace, vectors are uniquely referred to by their ID. + All vectors in a namespace must have the same dimensions. + + This class provides async versions of all Namespace methods. + """ + + name: str + backend: Backend + + metadata: Optional[dict] = None + + def __init__(self, name: str, api_key: Optional[str] = None, headers: Optional[dict] = None): + """ + Creates a new turbopuffer.AsyncNamespace object for querying the turbopuffer API asynchronously. + + This function does not make any API calls on its own. + + Specifying an api_key here will override the global configuration for API calls to this namespace. + """ + self.name = name + self.backend = Backend(api_key, headers) + + def __str__(self) -> str: + return f'tpuf-async-namespace:{self.name}' + + def __eq__(self, other): + if isinstance(other, AsyncNamespace): + return self.name == other.name and self.backend == other.backend + else: + return False + + async def __aenter__(self): + """Enable use as an async context manager.""" + await self.backend._get_async_session() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Clean up resources when exiting the async context.""" + if self.backend._async_session and not self.backend._async_session.closed: + await self.backend._async_session.close() + + async def arefresh_metadata(self): + """ + Asynchronously refreshes the namespace metadata. 
+ """ + response = await self.backend.amake_api_request('namespaces', self.name, method='HEAD') + status_code = response.get('status_code') + if status_code == 200: + headers = response.get('headers', dict()) + dimensions = int(headers.get('x-turbopuffer-dimensions', '0')) + approx_count = int(headers.get('x-turbopuffer-approx-num-vectors', '0')) + self.metadata = { + 'exists': dimensions != 0, + 'dimensions': dimensions, + 'approx_count': approx_count, + 'created_at': iso8601.parse_date(headers.get('x-turbopuffer-created-at')), + } + elif status_code == 404: + self.metadata = { + 'exists': False, + 'dimensions': 0, + 'approx_count': 0, + 'created_at': None, + } + else: + raise APIError(response.get('status_code'), 'Unexpected status code', response.get('content')) + + async def aexists(self) -> bool: + """ + Asynchronously checks if the namespace exists. + Returns True if the namespace exists, and False if the namespace is missing or empty. + """ + # Always refresh the exists check since metadata from namespaces() might be delayed. + await self.arefresh_metadata() + return self.metadata['exists'] + + async def adimensions(self) -> int: + """ + Asynchronously returns the number of vector dimensions stored in this namespace. + """ + if self.metadata is None or 'dimensions' not in self.metadata: + await self.arefresh_metadata() + return self.metadata.pop('dimensions', 0) + + async def aapprox_count(self) -> int: + """ + Asynchronously returns the approximate number of vectors stored in this namespace. + """ + if self.metadata is None or 'approx_count' not in self.metadata: + await self.arefresh_metadata() + return self.metadata.pop('approx_count', 0) + + async def acreated_at(self) -> Optional[datetime]: + """ + Asynchronously returns the creation date of this namespace. + """ + if self.metadata is None or 'created_at' not in self.metadata: + await self.arefresh_metadata() + return self.metadata.pop('created_at', None) + + async def aschema(self) -> NamespaceSchema: + """ + Asynchronously returns the current schema for the namespace. + """ + response = await self.backend.amake_api_request('namespaces', self.name, 'schema', method='GET') + return parse_namespace_schema(response["content"]) + + async def aupdate_schema(self, schema_updates: NamespaceSchema): + """ + Asynchronously writes updates to the schema for a namespace. + Returns the final schema after updates are done. + + See https://turbopuffer.com/docs/schema for specifics on allowed updates. + """ + request_payload = json.dumps({key: value.as_dict() for key, value in schema_updates.items()}).encode() + response = await self.backend.amake_api_request('namespaces', self.name, 'schema', method='POST', payload=request_payload) + return parse_namespace_schema(response["content"]) + + async def acopy_from_namespace(self, source_namespace: str): + """ + Asynchronously copies all documents from another namespace to this namespace. + + See: https://turbopuffer.com/docs/upsert#parameters `copy_from_namespace` + for specifics on how this works. 
+ """ + payload = { + "copy_from_namespace": source_namespace + } + response = await self.backend.amake_api_request('namespaces', self.name, payload=payload) + assert response.get('content', dict()).get('status', '') == 'OK', f'Invalid copy_from_namespace() response: {response}' + + @overload + async def aupsert(self, + ids: Union[List[int], List[str]], + vectors: List[List[float]], + attributes: Optional[Dict[str, List[Optional[Union[str, int]]]]] = None, + schema: Optional[Dict] = None, + distance_metric: Optional[str] = None, + encryption: Optional[EncryptionDict] = None) -> None: + """ + Asynchronously creates or updates multiple vectors provided in a column-oriented layout. + If this call succeeds, data is guaranteed to be durably written to object storage. + + Upserting a vector will overwrite any existing vector with the same ID. + """ + ... + + @overload + async def aupsert(self, + data: Union[dict, VectorColumns], + distance_metric: Optional[str] = None, + schema: Optional[Dict] = None, + encryption: Optional[EncryptionDict] = None) -> None: + """ + Asynchronously creates or updates multiple vectors provided in a column-oriented layout. + If this call succeeds, data is guaranteed to be durably written to object storage. + + Upserting a vector will overwrite any existing vector with the same ID. + """ + ... + + @overload + async def aupsert(self, + data: Union[Iterable[dict], Iterable[VectorRow]], + distance_metric: Optional[str] = None, + schema: Optional[Dict] = None, + encryption: Optional[EncryptionDict] = None) -> None: + """ + Asynchronously creates or updates a multiple vectors provided as a list or iterator. + If this call succeeds, data is guaranteed to be durably written to object storage. + + Upserting a vector will overwrite any existing vector with the same ID. + """ + ... + + @overload + async def aupsert(self, + data: VectorResult, + distance_metric: Optional[str] = None, + schema: Optional[Dict] = None, + encryption: Optional[EncryptionDict] = None) -> None: + """ + Asynchronously creates or updates multiple vectors. + If this call succeeds, data is guaranteed to be durably written to object storage. + + Upserting a vector will overwrite any existing vector with the same ID. + """ + ... + + async def aupsert(self, + data=None, + ids=None, + vectors=None, + attributes=None, + schema=None, + distance_metric=None, + encryption=None) -> None: + """ + Asynchronously creates or updates vectors. + """ + if data is None: + if ids is not None and vectors is not None: + return await self.aupsert(VectorColumns(ids=ids, vectors=vectors, attributes=attributes), schema=schema, distance_metric=distance_metric, encryption=encryption) + else: + raise ValueError('upsert() requires both ids= and vectors= be set.') + elif (ids is not None and attributes is None) or (attributes is not None and schema is None): + # Offset arguments to handle positional arguments case with no data field. 
+ return await self.aupsert(VectorColumns(ids=data, vectors=ids, attributes=vectors), schema=attributes, distance_metric=distance_metric, encryption=encryption) + elif isinstance(data, VectorColumns): + payload = {**data.__dict__} + + if distance_metric is not None: + payload["distance_metric"] = distance_metric + + if schema is not None: + payload["schema"] = schema + + if encryption is not None: + payload["encryption"] = encryption + + response = await self.backend.amake_api_request('namespaces', self.name, payload=payload) + + assert response.get('content', dict()).get('status', '') == 'OK', f'Invalid upsert() response: {response}' + self.metadata = None # Invalidate cached metadata + elif isinstance(data, VectorRow): + raise ValueError('upsert() should be called on a list of vectors, got single vector.') + elif isinstance(data, list): + if len(data) == 0: + return + if isinstance(data[0], dict): + return await self.aupsert(VectorColumns.from_rows(data), schema=schema, distance_metric=distance_metric, encryption=encryption) + elif isinstance(data[0], VectorRow): + return await self.aupsert(VectorColumns.from_rows(data), schema=schema, distance_metric=distance_metric, encryption=encryption) + elif isinstance(data[0], VectorColumns): + for columns in data: + await self.aupsert(columns, schema=schema, distance_metric=distance_metric, encryption=encryption) + return + else: + raise ValueError(f'Unsupported list data type: {type(data[0])}') + elif isinstance(data, dict): + if 'id' in data: + raise ValueError('upsert() should be called on a list of vectors, got single vector.') + elif 'ids' in data: + return await self.aupsert(VectorColumns.from_dict(data), schema=data.get('schema', None), distance_metric=distance_metric, encryption=encryption) + else: + raise ValueError('Provided dict is missing ids.') + elif 'pandas' in sys.modules and isinstance(data, sys.modules['pandas'].DataFrame): + if 'id' not in data.keys(): + raise ValueError('Provided pd.DataFrame is missing an id column.') + if 'vector' not in data.keys(): + raise ValueError('Provided pd.DataFrame is missing a vector column.') + + for i in range(0, len(data), tpuf.upsert_batch_size): + batch = data[i:i+tpuf.upsert_batch_size] + attributes = dict() + for key, values in batch.items(): + if key != 'id' and key != 'vector': + attributes[key] = values.tolist() + columns = tpuf.VectorColumns( + ids=batch['id'].tolist(), + vectors=batch['vector'].transform(lambda x: x.tolist()).tolist(), + attributes=attributes + ) + await self.aupsert(columns, schema=schema, distance_metric=distance_metric, encryption=encryption) + return + elif isinstance(data, Iterable): + async for batch in abatch_iter(data, tpuf.upsert_batch_size): + await self.aupsert(batch, schema=schema, distance_metric=distance_metric, encryption=encryption) + return + else: + raise ValueError(f'Unsupported data type: {type(data)}') + + async def adelete(self, ids: Union[int, str, List[int], List[str]]) -> None: + """ + Asynchronously deletes vectors by id. 
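+
+        Illustrative sketch (assumes `ns` is an AsyncNamespace and that ids 1
+        and 2 were previously upserted):
+
+            await ns.adelete(1)        # delete a single vector
+            await ns.adelete([1, 2])   # delete a batch of vectors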
+ """ + if isinstance(ids, int) or isinstance(ids, str): + response = await self.backend.amake_api_request('namespaces', self.name, payload={ + 'ids': [ids], + 'vectors': [None], + }) + elif isinstance(ids, list): + response = await self.backend.amake_api_request('namespaces', self.name, payload={ + 'ids': ids, + 'vectors': [None] * len(ids), + }) + else: + raise ValueError(f'Unsupported ids type: {type(ids)}') + + assert response.get('content', dict()).get('status', '') == 'OK', f'Invalid delete() response: {response}' + self.metadata = None # Invalidate cached metadata + + async def adelete_by_filter(self, filters: Filters) -> int: + """ + Asynchronously deletes vectors by filter. + """ + response = await self.backend.amake_api_request('namespaces', self.name, payload={ + 'delete_by_filter': filters + }) + response_content = response.get('content', dict()) + assert response_content.get('status', '') == 'OK', f'Invalid delete_by_filter() response: {response}' + self.metadata = None # Invalidate cached metadata + return response_content.get('rows_affected') + + @overload + async def aquery(self, + vector: Optional[List[float]] = None, + distance_metric: Optional[str] = None, + top_k: int = 10, + include_vectors: bool = False, + include_attributes: Optional[Union[List[str], bool]] = None, + filters: Optional[Filters] = None, + rank_by: Optional[RankInput] = None, + consistency: Optional[ConsistencyDict] = None + ) -> AsyncVectorResult: + ... + + @overload + async def aquery(self, query_data: VectorQuery) -> AsyncVectorResult: + ... + + @overload + async def aquery(self, query_data: dict) -> AsyncVectorResult: + ... + + async def aquery(self, + query_data=None, + vector=None, + distance_metric=None, + top_k=None, + include_vectors=None, + include_attributes=None, + filters=None, + rank_by=None, + consistency=None) -> AsyncVectorResult: + """ + Asynchronously searches vectors matching the search query. + + See https://turbopuffer.com/docs/reference/query for query filter parameters. + """ + if query_data is None: + return await self.aquery(VectorQuery( + vector=vector, + distance_metric=distance_metric, + top_k=top_k, + include_vectors=include_vectors, + include_attributes=include_attributes, + filters=filters, + rank_by=rank_by, + consistency=consistency + )) + if not isinstance(query_data, VectorQuery): + if isinstance(query_data, dict): + query_data = VectorQuery.from_dict(query_data) + else: + raise ValueError(f'query() input type must be compatible with turbopuffer.VectorQuery: {type(query_data)}') + + response = await self.backend.amake_api_request('namespaces', self.name, 'query', payload=query_data.__dict__) + result = AsyncVectorResult(response.get('content', dict()), namespace=self) + result.performance = response.get('performance') + return result + + async def avectors(self, cursor: Optional[Cursor] = None) -> AsyncVectorResult: + """ + Asynchronously exports the entire dataset at full precision. + An AsyncVectorResult is returned that will lazily load batches of vectors if treated as an async Iterator. + + If you want to look up vectors by ID, use the query function with an id filter. 
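+
+        Illustrative sketch (assumes `ns` is an AsyncNamespace that already
+        contains data):
+
+            result = await ns.avectors()
+            async for row in result:
+                print(row.id, row.vector)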
+        """
+        # Omit the cursor query parameter entirely when no cursor is provided
+        query_params = {}
+        if cursor is not None:
+            query_params['cursor'] = cursor
+
+        response = await self.backend.amake_api_request('namespaces', self.name, query=query_params)
+        content = response.get('content', dict())
+        next_cursor = content.pop('next_cursor', None)
+        result = AsyncVectorResult(content, namespace=self, next_cursor=next_cursor)
+        result.performance = response.get('performance')
+        return result
+
+    async def adelete_all_indexes(self) -> None:
+        """
+        Asynchronously deletes all indexes in a namespace.
+        """
+        response = await self.backend.amake_api_request('namespaces', self.name, 'index', method='DELETE')
+        assert response.get('content', dict()).get('status', '') == 'ok', f'Invalid delete_all_indexes() response: {response}'
+
+    async def adelete_all(self) -> None:
+        """
+        Asynchronously deletes all data as well as all indexes.
+        """
+        response = await self.backend.amake_api_request('namespaces', self.name, method='DELETE')
+        assert response.get('content', dict()).get('status', '') == 'ok', f'Invalid delete_all() response: {response}'
+        self.metadata = None  # Invalidate cached metadata
+
+    async def arecall(self, num=20, top_k=10) -> float:
+        """
+        Asynchronously evaluates the recall performance of ANN queries in this namespace.
+
+        When you call this function, it selects 'num' random vectors that were previously inserted.
+        For each of these vectors, it performs an ANN index search as well as a ground truth exhaustive search.
+
+        Recall is calculated as the ratio of matching vectors between the two search results.
+        """
+        response = await self.backend.amake_api_request('namespaces', self.name, '_debug', 'recall', query={'num': num, 'top_k': top_k})
+        content = response.get('content', dict())
+        assert 'avg_recall' in content, f'Invalid recall() response: {response}'
+        return float(content.get('avg_recall'))
+
+    # The synchronous Namespace methods are implemented with a _run_async() helper
+    # that delegates to the corresponding async methods on AsyncNamespace
+
+
+def namespaces(api_key: Optional[str] = None) -> Iterable[Namespace]:
+    """
+    Lists all turbopuffer namespaces for a given api_key.
+    If no api_key is provided, the globally configured API key will be used.
+    """
+    # Use the async version with run_until_complete
+    try:
+        loop = asyncio.get_event_loop()
+        if loop.is_closed():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+    except RuntimeError:
+        # No event loop exists in this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    # Call the async version and load the full result set into memory
+    async_iterator = loop.run_until_complete(anamespaces(api_key))
+    async_namespaces = loop.run_until_complete(async_iterator.load())
+
+    # Convert AsyncNamespace instances to Namespace instances
+    sync_namespaces = [Namespace(ns.name, api_key=api_key) for ns in async_namespaces]
+    return sync_namespaces
+
+
+async def anamespaces(api_key: Optional[str] = None) -> AsyncIterator[AsyncNamespace]:
+    """
+    Asynchronously lists all turbopuffer namespaces for a given api_key.
+    If no api_key is provided, the globally configured API key will be used.
+
+    Returns an async iterator that can be used with 'async for'.
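+
+    Illustrative sketch (assumes the API key has been configured globally):
+
+        iterator = await anamespaces()
+        async for ns in iterator:
+            print(ns.name)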
""" backend = Backend(api_key) - response = backend.make_api_request('namespaces') + response = await backend.amake_api_request('namespaces') content = response.get('content', dict()) next_cursor = content.pop('next_cursor', None) - return NamespaceIterator(backend, content.pop('namespaces', list()), next_cursor) + return AsyncNamespaceIterator(backend, content.pop('namespaces', list()), next_cursor) diff --git a/turbopuffer/vectors.py b/turbopuffer/vectors.py index 3a29760..3bc9116 100644 --- a/turbopuffer/vectors.py +++ b/turbopuffer/vectors.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import sys -from typing import Optional, Union, List, Iterable, Dict, overload +import asyncio +from typing import Optional, Union, List, Iterable, Dict, overload, AsyncIterator, TypeVar, Generic from itertools import islice @@ -12,6 +13,27 @@ def batch_iter(iterable, n): return yield batch +async def abatch_iter(async_iterable, n): + """Async version of batch_iter that works with async iterables.""" + batch = [] + # Handle both async iterables and regular iterables + if hasattr(async_iterable, '__aiter__'): + async for item in async_iterable: + batch.append(item) + if len(batch) >= n: + yield batch + batch = [] + if batch: + yield batch + else: + # Handle regular iterables by converting to list and using batch_iter + try: + for items in batch_iter(async_iterable, n): + yield items + except TypeError: + # If we can't iterate, treat as a single item + yield [async_iterable] + class Cursor(str): pass @@ -274,9 +296,12 @@ def __init__(self, initial_data: Optional[DATA] = None, namespace: Optional['Nam self.data = VectorResult.load_data(initial_data) + @staticmethod def load_data(initial_data: DATA) -> SET_DATA: if initial_data: if isinstance(initial_data, list): + if len(initial_data) == 0: + return [] if isinstance(initial_data[0], dict): return [VectorRow.from_dict(row) for row in initial_data] elif isinstance(initial_data[0], VectorRow): @@ -291,6 +316,7 @@ def load_data(initial_data: DATA) -> SET_DATA: raise ValueError('VectorResult from Iterable not yet supported.') else: raise ValueError(f'Unsupported data type: {type(initial_data)}') + return [] def __str__(self) -> str: if not self.next_cursor and self.offset == 0: @@ -348,3 +374,101 @@ def __next__(self): self.next_cursor = content.pop('next_cursor', None) self.data = VectorResult.load_data(content) return self.__next__() + + +class AsyncVectorResult: + """ + The AsyncVectorResult type represents a set of vectors that are the result of an async query. + + AsyncVectorResult implements an async iterator interface, allowing you to use it with + async for loops. For full buffer access, await the load() method to retrieve all results. 
+ """ + + namespace: Optional['AsyncNamespace'] = None + data: Optional[SET_DATA] = None + index: int = -1 + offset: int = 0 + next_cursor: Optional[Cursor] = None + + performance: Optional[dict] = None + + def __init__(self, initial_data: Optional[DATA] = None, namespace: Optional['AsyncNamespace'] = None, next_cursor: Optional[Cursor] = None): + self.namespace = namespace + self.index = -1 + self.offset = 0 + self.next_cursor = next_cursor + + self.data = VectorResult.load_data(initial_data) + + def __str__(self) -> str: + if not self.next_cursor and self.offset == 0: + return str(self.data) + else: + return ("AsyncVectorResult(" + f"namespace='{self.namespace.name}', " + f"offset={self.offset}, " + f"next_cursor='{self.next_cursor}', " + f"data={self.data})") + + async def load(self) -> List[VectorRow]: + """ + Loads and returns all results, fetching additional pages as needed. + + This method buffers all data in memory. For large result sets, consider + using the async iterator interface instead. + """ + if not self.next_cursor: + if isinstance(self.data, list): + return self.data + elif isinstance(self.data, VectorColumns): + return [self.data[i] for i in range(len(self.data))] + return [] + + # Create new iterator and exhaust it to load all results + result = [] + async for item in self.__aiter__(): + result.append(item) + + # Update our state with the fully loaded data + self.data = result + self.offset = 0 + self.index = -1 + self.next_cursor = None + + return result + + def __aiter__(self) -> 'AsyncVectorResult': + # Reset state to start fresh iteration + return AsyncVectorResult(self.data, self.namespace, self.next_cursor) + + async def __anext__(self) -> VectorRow: + # Handle the case where we have data in memory + if self.data is not None and self.index + 1 < len(self.data): + self.index += 1 + if isinstance(self.data, list): + return self.data[self.index] + elif isinstance(self.data, VectorColumns): + return self.data[self.index] + else: + raise ValueError(f'Unsupported data type: {type(self.data)}') + # Handle the case where we're at the end of our data + elif self.next_cursor is None: + raise StopAsyncIteration + # Handle the case where we need to fetch more data + else: + response = await self.namespace.backend.amake_api_request( + 'vectors', + self.namespace.name, + query={'cursor': self.next_cursor} + ) + content = response.get('content', dict()) + + # Update our state + if self.data is not None: + self.offset += len(self.data) + self.index = -1 + self.next_cursor = content.pop('next_cursor', None) + self.data = VectorResult.load_data(content) + + # Recursively call __anext__ to handle the case where data is empty + return await self.__anext__() diff --git a/turbopuffer/version.py b/turbopuffer/version.py index 09829e4..6c5007c 100644 --- a/turbopuffer/version.py +++ b/turbopuffer/version.py @@ -1 +1 @@ -VERSION = "0.1.32" +VERSION = "0.2.0"